SGLang 的 Expert Parallel 特性解读

0x0. 前言

最近在 SGlang 仓库下做了一段时间的开发和学习,对 SGLang 的一些比较新的 Feature 也开始有一些了解。这篇文章就是尝试来梳理一下 SGLang 中 Expert Parallel 的实现,据我所知 SGlang 应该是开源推理框架中率先实现 Expert Parallel 的。我们可以学习一下它是如何实现的,以及它相比于普通的 EP 主要优化点在哪。SGLang 在 https://github.com/sgl-project/sglang/pull/2371 中实现了 Expert Parallel,我们从这里看就行。如果对 MoE EP 不熟悉可以参考 https://zhuanlan.zhihu.com/p/681154742 这篇文章或者阅读 DeepSeek 相关的资料。

0x1. 上层的接口

image.png

Image

首先我们可以看到 server_args.py 中的改动,Expert Parallel 接管了 Tensor Parallel 的位置,以 Deepseek V3 为例子,有 256 个 Expert,现在打开 Expert Parallel 并且把expert_parallel_size设置为 8 的话,那么每张卡上分得完整的 32 个 Expert。另外可以看到在初始化参数的时候,如果开启了 Expert Paralle 会先把expert_parallel_size设置为 TP 的大小。

Image

接着看一下 Mixtral 模型实现上的修改,值得注意的是在调用 EPMoE 接口的时候没有reduce_results=True,这个参数了,但是在 EPMoE 计算完成之后对结果调用了tensor_model_parallel_all_reduce 。去掉reduce_results=True,参数比较好理解,在 EP 中我们没有对 Expert 的参数做切分,只需要把 token 分到对应的 expert 上,做的矩阵乘都是完整的,所以获得的结果也是完整的。为什么要对结果使用tensor_model_parallel_all_reduce?继续读一下代码寻找答案,我在之后的 0x4 节给出了原因。

上层接口差不多看到这里就可以了,核心实现分成两部分,一部分是 EP MoE Layer,一部分是 EP MoE 的 kernel。需要耐心点看这两个。

0x2. SGLang EP MoE Layer 实现

文件位置:https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/ep_moe/layer.py

0x2.1 GroupedGemmRunner

首先看到一个用于执行 Group GEMM 的工具类,首先简单解析一下这个类,降低后续理解的负担。我先添加一下注释:

#  用于执行分组矩阵乘法的 Runner 类  
class GroupedGemmRunner(torch.nn.Module):  
    # flashinfer 的 gemm 包装器,用于加速计算  
    flashinfer_gemm_warpper = None

def init(self, device, use_flashinfer: bool = False):  
        """  
         初始化 GroupedGemmRunner  
        Args:  
            device:  运行设备  
            use_flashinfer:  是否使用 flashinfer 加速  
        """  
        super().init()  
        self.device = device  
        self.use_flashinfer = use_flashinfer  
        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:  
            GroupedGemmRunner._init_flashinfer_wrapper(device)

@classmethod  
    def _init_flashinfer_wrapper(cls, device):  
        """  
         初始化 flashinfer 的 gemm 包装器  
        Args:  
            device:  运行设备  
        """  
        from flashinfer import SegmentGEMMWrapper

#  创建工作空间缓冲区  
        workspace_buffer = torch.empty(  
            128 * 1024 * 1024, dtype=torch.int8, device=device  
        )  
        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)

# c = a * b  
    def forward(  
        self,  
        a: torch.Tensor,  #  输入矩阵 a  
        b: torch.Tensor,  #  输入矩阵 b  
        c: torch.Tensor,  #  输出矩阵 c  
        batch_size: int,  # batch 大小  
        weight_column_major: bool,  #  权重是否为列主序  
        seg_indptr: Optional[torch.Tensor] = None,  #  分段指针  
        weight_indices: Optional[torch.Tensor] = None,  #  权重索引  
        use_fp8_w8a8: bool = False,  #  是否使用 fp8 量化  
        scale_a: torch.Tensor = None,  # a 的缩放因子  
        scale_b: torch.Tensor = None,  # b 的缩放因子  
    ):  
        """执行分组矩阵乘法"""  
        if self.use_flashinfer:  
            # TODO: flashinfer  
            assert False  
            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None  
            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(  
                x=a,  
                weights=b,  
                batch_size=batch_size,  
                weight_column_major=weight_column_major,  
                seg_indptr=seg_indptr,  
                weight_indices=weight_indices,  
            )  
        else:  
            #  使用 triton 实现的分组矩阵乘法  
            assert weight_column_major == True  
            c = grouped_gemm_triton(  
                a,  
                b,  
                c,  
                batch_size,  
                weight_column_major,  
                seg_indptr,  
                weight_indices,  
                use_fp8_w8a8,  
                scale_a,  
                scale_b,  
            )  
        return c  

总的来说,这个类把两种做 Group GEMM 的方法抽象了一下,我们可以选择使用 CUDA 实现的 FlashInfer,也可以选择 Triton 的实现。

0x2.2 EPMoE 类

这个类是连接上层的模型实现和底层的 EPMoE Kernel 的关键组件,我们需要先理解一下这个类的实现。

EPMoE 类的定义
class EPMoE(torch.nn.Module):  
    """  
    MoE 专家并行实现

  
    Args:  
        num_experts:  专家总数  
        top_k:  每个 token 选择的专家数量  
        hidden_size:  隐藏层大小  
        intermediate_size:  中间层大小  
        params_dtype:  参数数据类型,默认为 None 使用系统默认类型  
        renormalize:  是否重新归一化,默认 True  
        use_grouped_topk:  是否使用分组 topk,默认 False  
        num_expert_group:  专家组数量,仅在 use_grouped_topk=True 时使用  
        topk_group:  每组选择的专家数量,仅在 use_grouped_topk=True 时使用  
        quant_config:  量化配置,默认 None  
        tp_size:  张量并行大小,默认 None  
        prefix:  前缀,默认空字符串  
        correction_bias:  修正偏置,默认 None  
    """

def init(  
        self,  
        num_experts: int,  
        top_k: int,  
        hidden_size: int,  
        intermediate_size: int,  
        params_dtype: Optional[torch.dtype] = None,  
        renormalize: bool = True,  
        use_grouped_topk: bool = False,  
        num_expert_group: Optional[int] = None,  
        topk_group: Optional[int] = None,  
        quant_config: Optional[QuantizationConfig] = None,  
        tp_size: Optional[int] = None,  
        prefix: str = "",  
        correction_bias: Optional[torch.Tensor] = None,  
    ):  
        super().init()

#  如果未指定参数类型,使用系统默认类型  
        if params_dtype is None:  
            params_dtype = torch.get_default_dtype()

#  设置张量并行相关参数  
        self.tp_size = (  
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()  
        )  
        self.tp_rank = get_tensor_model_parallel_rank()

#  设置专家相关参数  
        self.num_experts = num_experts  
        assert self.num_experts % self.tp_size == 0  #  确保专家数可以被 tp_size 整除  
        self.num_experts_per_partition = self.num_experts // self.tp_size  #  每个分区的专家数  
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition  #  当前分区起始专家 ID  
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1  #  当前分区结束专家 ID

#  设置其他参数  
        self.top_k = top_k  
        self.intermediate_size = intermediate_size  
        self.renormalize = renormalize  
        self.use_grouped_topk = use_grouped_topk  
        if self.use_grouped_topk:  
            assert num_expert_group is not None and topk_group is not None  
        self.num_expert_group = num_expert_group  
        self.topk_group = topk_group  
        self.correction_bias = correction_bias

#  设置量化方法  
        if quant_config is None:  
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()  
            self.use_fp8_w8a8 = False  
            self.activation_scheme = None  
        else:  
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(  
                quant_config  
            )  
            self.use_fp8_w8a8 = True  
            self.fp8_dtype = torch.float8_e4m3fn  
            self.activation_scheme = quant_config.activation_scheme

#  创建权重  
        self.quant_method.create_weights(  
            layer=self,  
            num_experts_per_partition=self.num_experts_per_partition,  
            hidden_size=hidden_size,  
            intermediate_size=self.intermediate_size,  
            params_dtype=params_dtype,  
            weight_loader=self.weight_loader,  
        )

#  初始化分组矩阵乘法运行器  
        self.grouped_gemm_runner = None  

这个类定义中我们可以看到它主要是做一些准备工作,同时 EPMoE 复用了 Tensor Parallel 的进程组,所以也是直接在 Tensor Parallel 进程组上获取当前 Rank 需要处理的是哪些 Expert ID。

EPMoE 类的 Forward

简单添加几行注释:

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):  
        """前向传播函数  
        Args:  
            hidden_states:  输入的隐藏状态张量  
            router_logits:  路由器输出的 logits 张量  
        Returns:  
            output:  经过 MoE 层处理后的输出张量  
        """  
        assert self.quant_method is not None

#  初始化分组矩阵乘法运行器  
        if self.grouped_gemm_runner is None:  
            self.grouped_gemm_runner = GroupedGemmRunner(  
                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer  
            )

#  选择专家,获取 topk 权重和 ID  
        topk_weights, topk_ids = select_experts(  
            hidden_states=hidden_states,  
            router_logits=router_logits,  
            top_k=self.top_k,  
            use_grouped_topk=self.use_grouped_topk,  
            renormalize=self.renormalize,  
            topk_group=self.topk_group,  
            num_expert_group=self.num_expert_group,  
            correction_bias=self.correction_bias,  
        )

#  预处理 topk ID,获取重排序信息  
        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(  
            topk_ids, self.num_experts  
        )

#  初始化门控输入张量  
        gateup_input = torch.empty(  
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),  
            device=hidden_states.device,  
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,  
        )

  
        #  动态量化时计算输入缩放因子  
        if self.activation_scheme == "dynamic":  
            max_value = (  
                torch.max(hidden_states)  
                .repeat(self.num_experts_per_partition)  
                .to(torch.float32)  
            )  
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

#  预重排序,重新排列输入数据  
        pre_reorder_triton_kernel[(hidden_states.shape[0],)](  
            hidden_states,  
            gateup_input,  
            src2dst,  
            topk_ids,  
            self.w13_input_scale,  
            self.start_expert_id,  
            self.end_expert_id,  
            self.top_k,  
            hidden_states.shape[1],  
            BLOCK_SIZE=512,  
        )

#  获取当前 rank 的分段指针和权重索引  
        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]  
        weight_indices_cur_rank = torch.arange(  
            0,  
            self.num_experts_per_partition,  
            device=hidden_states.device,  
            dtype=torch.int64,  
        )

  
        #  第一次分组矩阵乘法  
        gateup_output = torch.empty(  
            gateup_input.shape[0],  
            self.w13_weight.shape[1],  
            device=hidden_states.device,  
            dtype=hidden_states.dtype,  
        )  
        gateup_output = self.grouped_gemm_runner(  
            a=gateup_input,  
            b=self.w13_weight,  
            c=gateup_output,  
            batch_size=self.num_experts_per_partition,  
            weight_column_major=True,  
            seg_indptr=seg_indptr_cur_rank,  
            weight_indices=weight_indices_cur_rank,  
            use_fp8_w8a8=self.use_fp8_w8a8,  
            scale_a=self.w13_input_scale,  
            scale_b=self.w13_weight_scale,  
        )

#  激活函数处理  
        down_input = torch.empty(  
            gateup_output.shape[0],  
            gateup_output.shape[1] // 2,  
            device=gateup_output.device,  
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,  
        )  
        if self.w2_input_scale is None:  
            self.w2_input_scale = torch.ones(  
                self.num_experts_per_partition,  
                dtype=torch.float32,  
                device=hidden_states.device,  
            )  
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](  
            gateup_output,  
            down_input,  
            gateup_output.shape[1],  
            reorder_topk_ids,  
            self.w2_input_scale,  
            self.start_expert_id,  
            self.end_expert_id,  
            BLOCK_SIZE=512,  
        )

#  第二次分组矩阵乘法  
        down_output = torch.empty(  
            down_input.shape[0],  
            self.w2_weight.shape[1],  
            device=hidden_states.device,  
            dtype=hidden_states.dtype,  
        )  
        down_output = self.grouped_gemm_runner(  
            a=down_input,  
            b=self.w2_weight,  
            c=down_output,  
            batch_size=self.num_experts_per_partition,  
            weight_column_major=True,  
            seg_indptr=seg_indptr_cur_rank,  
            weight_indices=weight_indices_cur_rank,  
            use_fp8_w8a8=self.use_fp8_w8a8,  
            scale_a=self.w2_input_scale,  
            scale_b=self.w2_weight_scale,  
        )

#  后重排序,生成最终输出  
        output = torch.empty_like(hidden_states)  
        post_reorder_triton_kernel[(hidden_states.size(0),)](  
            down_output,  
            output,  
            src2dst,  
            topk_ids,  
            topk_weights,  
            self.start_expert_id,  
            self.end_expert_id,  
            self.top_k,  
            hidden_states.size(1),  
            BLOCK_SIZE=512,  
        )  
        return output  

这个 forward 函数的流程还是比较清晰的:

  • 首先根据 router_logits 选择每个 token 要使用的 top-k 个专家及其权重
  • 对输入数据进行预处理和重排序,将相同专家的数据分组在一起以便后续批量计算
  • 执行第一次分组矩阵乘法(grouped gemm),将输入与 gate 和 up 投影权重(w13_weight)相乘
  • 对第一次矩阵乘法的结果应用 SiLU 激活函数并进行处理
  • 执行第二次分组矩阵乘法,将激活后的结果与 down 投影权重(w2_weight)相乘
  • 最后进行后重排序,将各个专家的输出按原始 token 顺序重组,并根据专家权重进行加权组合得到最终输出

这个过程基本上和 EP MoE 训练时的步骤一致,其中第二步和最后一步就对应了 EP 中的两次 All2All。

权重加载逻辑

笔者注:对于本篇文章的主题来说,可以不用在意这几个工具函数。

EPMoE 类中还有 3 个和权重加载相关的函数,这里也顺便添加了注释。

    @classmethod  
    def make_expert_params_mapping(  
        cls,  
        ckpt_gate_proj_name: str,  
        ckpt_down_proj_name: str,  
        ckpt_up_proj_name: str,  
        num_experts: int,  
    ) -> List[Tuple[str, str, int, str]]:  
        """生成专家参数映射关系

  
        Args:  
            ckpt_gate_proj_name:  检查点中 gate 投影层的名称  
            ckpt_down_proj_name:  检查点中 down 投影层的名称    
            ckpt_up_proj_name:  检查点中 up 投影层的名称  
            num_experts:  专家总数

  
        Returns:  
            List[Tuple[str, str, int, str]]:  返回参数映射列表,每个元素为元组:  
                - param*name:  参数名称前缀(w13 或 w2)  
                - weight_name:  权重完整名称  
                - expert_id:  专家 ID  
                - shard_id:  分片 ID(w1/w2/w3)  
        """  
        return [  
            # (param_name, weight_name, expert_id, shard_id)  
            (  
                (  
                    "experts.w13*"  
                    if weight*name in [ckpt_gate_proj_name, ckpt_up_proj_name]  
                    else "experts.w2*"  
                ),  
                f"experts.{expert_id}.{weight_name}.",  
                expert_id,  
                shard_id,  
            )  
            for expert_id in range(num_experts)  
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]  
        ]

def weight_loader(  
        self,  
        param: torch.nn.Parameter,  
        loaded_weight: torch.Tensor,  
        weight_name: str,  
        shard_id: str,  
        expert_id: int,  
    ) -> None:  
        """加载权重参数

  
        Args:  
            param:  目标参数  
            loaded_weight:  加载的权重张量  
            weight_name:  权重名称  
            shard_id:  分片 ID(w1/w2/w3)  
            expert_id:  专家 ID

  
        Raises:  
            ValueError:  当 shard_id 不合法时抛出异常  
        """  
        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:  
            return  
        expert_id = expert_id - self.start_expert_id

if shard_id not in ("w1", "w2", "w3"):  
            raise ValueError(  
                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."  
            )

#  处理 FP8 缩放因子的特殊情况  
        if "scale" in weight_name:  
            self._load_fp8_scale(  
                param.data, loaded_weight, weight_name, shard_id, expert_id  
            )  
            return

expert_data = param.data[expert_id]  
        if shard_id == "w2":  
            param.data[expert_id] = loaded_weight  
        elif shard_id == "w1":  
            param.data[expert_id][: self.intermediate_size, :] = loaded_weight  
        elif shard_id == "w3":  
            param.data[expert_id][self.intermediate_size :, :] = loaded_weight  
        else:  
            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")

def _load_fp8_scale(  
        self,  
        param: torch.nn.Parameter,  
        loaded_weight: torch.Tensor,  
        weight_name: str,  
        shard_id: str,  
        expert_id: int,  
    ) -> None:  
        """加载 FP8 量化的缩放因子

  
        Args:  
            param:  目标参数  
            loaded_weight:  加载的权重张量  
            weight_name:  权重名称  
            shard_id:  分片 ID(w1/w2/w3)  
            expert_id:  专家 ID

  
        Raises:  
            ValueError:  当输入缩放因子不相等时抛出异常  
        """  
        param_data = param.data

#  输入缩放因子可以直接加载,且必须相等  
        if "input_scale" in weight_name:  
            if (  
                param_data[expert_id] != 1  
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5  
            ):  
                raise ValueError(  
                    "input_scales of w1 and w3 of a layer "  
                    f"must be equal. But got {param_data[expert_id]} "  
                    f"vs. {loaded_weight}"  
                )  
            param_data[expert_id] = loaded_weight  
        #  权重缩放因子  
        elif "weight_scale" in weight_name:  
            #  合并列的情况(gate_up_proj)  
            if shard_id in ("w1", "w3"):  
                #  需要保留 w1 和 w3 的权重缩放因子,因为加载权重后需要重新量化  
                idx = 0 if shard_id == "w1" else 1  
                param_data[expert_id][idx] = loaded_weight  
            #  行并行的情况(down_proj)  
            else:  
                param_data[expert_id] = loaded_weight  

这几个权重加载相关的工具函数会在模型实现中的load_weights方法中被调用,本文就不继续关注这部分了,感兴趣的读者可以查看一下 VLLM 和 SGLang 是如何优雅的做模型权重加载工作的。

解析到这里就可以了,我们把握住 EPMoE 类的 forward 的整体逻辑就行。

0x3. SGLang EP MoE Kernel 实现

代码位置:https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/ep_moe/kernels.py

再复述一次 EPMoE Layer 实现中的 forward 的主要流程,和本节要解析的 kernel 可以对应起来。EPMoE Layer 的 forward 主要流程为:

  • 首先根据 router_logits 选择每个 token 要使用的 top-k 个专家及其权重
  • 对输入数据进行预处理和重排序,将相同专家的数据分组在一起以便后续批量计算
  • 执行第一次分组矩阵乘法(grouped gemm),将输入与 gate 和 up 投影权重(w13_weight)相乘
  • 对第一次矩阵乘法的结果应用 SiLU 激活函数并进行处理
  • 执行第二次分组矩阵乘法,将激活后的结果与 down 投影权重(w2_weight)相乘
  • 最后进行后重排序,将各个专家的输出按原始 token 顺序重组,并根据专家权重进行加权组合得到最终输出

Token 按照 Expert 重排 index 信息预处理

在 forward 函数中获得了 topk_ids 之后首先进行了预处理 topk ID,获取重排序信息:

reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, self.num_experts)  

对应的 Triton Kernel 添加注释:

@triton.jit  
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):  
    """计算每个专家对应 token 段的起始位置

  
    Args:  
        reorder_topk_ids:  排序后的专家 ID  
        seg_indptr:  分段指针数组  
        num_toks: token 总数  
    """  
    #  获取当前专家 ID  
    expert = tl.program_id(0)

  
    #  二分查找当前专家对应的 token 段位置  
    low = 0  
    high = num_toks - 1  
    target_location = -1  
    while low <= high:  
        mid = (low + high) // 2

#  如果中间位置的专家 ID 大于当前专家 ID,在左半部分继续查找  
        if tl.load(reorder_topk_ids + mid) > expert:  
            high = mid - 1  
        #  否则在右半部分继续查找,并更新目标位置  
        else:  
            low = mid + 1  
            target_location = mid

  
    #  存储当前专家对应 token 段的结束位置  
    tl.store(seg_indptr + expert + 1, target_location + 1)

@triton.jit  
def compute_src2dst_triton_kernel(  
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr  
):  
    """计算源索引到目标索引的映射

  
    Args:  
        reorder_ids:  重排序后的索引  
        src2dst:  源索引到目标索引的映射数组  
        num_toks: token 总数  
        BLOCK_SIZE:  每个线程块处理的 token 数量  
    """  
    #  获取当前程序块 ID  
    pid = tl.program_id(axis=0)

  
    #  计算当前块内的目标索引  
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

  
    #  生成有效 token 的掩码  
    mask = dst_id < num_toks

  
    #  加载源索引  
    src_id = tl.load(reorder_ids + dst_id, mask=mask)

  
    #  存储源索引到目标索引的映射  
    tl.store(src2dst + src_id, dst_id, mask=mask)

def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):  
    """预处理 MoE 专家并行的 topk ID,生成重排序信息

  
    Args:  
        topk_ids:  每个 token 选择的专家 ID 张量  
        num_experts:  专家总数

  
    Returns:  
        reorder_topk_ids:  排序后的专家 ID  
        src2dst:  源索引到目标索引的映射  
        seg_indptr:  每个专家对应的 token 段的起始位置  
    """  
    #  对专家 ID 进行稳定排序  
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

  
    #  初始化分段指针和源目标映射数组  
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)  
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

#  计算每个专家对应 token 段的起始位置  
    compute_seg_indptr_triton_kernel[(num_experts,)](  
        reorder_topk_ids, seg_indptr, topk_ids.numel()  
    )

#  计算源索引到目标索引的映射  
    BLOCK_SIZE = 512  
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)  
    compute_src2dst_triton_kernel[grid](  
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE  
    )  
    return reorder_topk_ids, src2dst, seg_indptr  

这段代码实际上还是比较好理解的,我这里举个例子来说明一下。

假设有 10 个 token,4 个专家(expert_id: 0,1,2,3),每个 token 选择的专家分配如下:

# 原始的token到专家的分配 (topk_ids)  
token_idx:     [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  
expert_ids:    [1, 3, 2, 1, 0, 2, 3, 1, 2, 0]  

上面代码处理后会得到:

  1. 排序后的专家 ID (reorder_topk_ids):
[0, 0, 1, 1, 1, 2, 2, 2, 3, 3]  
  1. 每个专家负责的 token 段位置 (seg_indptr):
expert_id:     [0,    1,    2,    3,    4]  
seg_indptr:    [0,    2,    5,    8,    10]  
# 含义:  
# - expert 0 处理索引 0-1 的token  
# - expert 1 处理索引 2-4 的token  
# - expert 2 处理索引 5-7 的token  
# - expert 3 处理索引 8-9 的token  
  1. 原始位置到重排序后位置的映射 (src2dst):
原始位置:      [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  
重排序后位置:   [4, 9, 2, 3, 7, 5, 8, 6, 0, 1]  

这样重排序后,相同专家要处理的 token 就被组织在了一起

重排序后位置:   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  
专家ID:        [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]  

执行真正的 Token 按照 Expert 重排(等价于第一次 All2All)

对应 EPMoE Layer forward 的下面一行代码:

pre_reorder_triton_kernel[(hidden_states.shape[0],)](  
            hidden_states,  
            gateup_input,  
            src2dst,  
            topk_ids,  
            self.w13_input_scale,  
            self.start_expert_id,  
            self.end_expert_id,  
            self.top_k,  
            hidden_states.shape[1],  
            BLOCK_SIZE=512,  
        )  

我们看一下 Triton 的实现:

@triton.jit  
def pre_reorder_triton_kernel(  
    input_ptr,          #  输入张量指针  
    gateup_input_ptr,   #  门控输入张量指针  
    src2dst_ptr,        #  源到目标索引映射指针  
    topk_ids_ptr,       # topk 专家 ID 指针  
    a1_scales_ptr,      #  输入缩放因子指针  
    start_expert_id,    #  当前 rank 起始专家 ID  
    end_expert_id,      #  当前 rank 结束专家 ID  
    topk,               #  每个 token 选择的专家数量  
    hidden_size,        #  隐藏层大小  
    BLOCK_SIZE: tl.constexpr,  #  计算块大小  
):  
    """预重排序 kernel,将输入数据重新排列并应用缩放

  
     该 kernel 将输入数据按照专家分配重新排列,并对分配到当前 rank 的专家数据进行缩放处理。  
     对于每个输入 token,遍历其选择的 topk 个专家,如果专家属于当前 rank,则将该 token 的数据  
     拷贝到对应位置并应用缩放因子。  
    """  
    #  获取输出数据类型  
    OutDtype = gateup_input_ptr.dtype.element_ty

#  获取当前处理的输入 token 索引  
    src_idx = tl.program_id(0)  
    #  计算当前 token 的 src2dst 和 topk_ids 指针位置  
    src2dst_ptr = src2dst_ptr + src_idx * topk  
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

#  计算输入数据指针位置  
    src_ptr = input_ptr + src_idx * hidden_size

  
    #  遍历当前 token 选择的 topk 个专家  
    for idx in range(topk):  
        #  加载专家 ID  
        expert_id = tl.load(topk_ids_ptr + idx)  
        #  检查专家是否属于当前 rank  
        if expert_id >= start_expert_id and expert_id <= end_expert_id:  
            #  计算缩放因子  
            if a1_scales_ptr is not None:  
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)  
            else:  
                scale = 1.0

#  获取目标位置索引和指针  
            dst_idx = tl.load(src2dst_ptr + idx)  
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size

  
            #  按块处理 hidden_size 维度的数据  
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):  
                offset = start_offset + tl.arange(0, BLOCK_SIZE)  
                mask = offset < hidden_size  
                #  加载输入数据并转换为 float32  
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)  
                #  应用缩放并转换为输出类型  
                out_data = (in_data * scale).to(OutDtype)  
                #  存储到目标位置  
                tl.store(dst_ptr + offset, out_data, mask=mask)  

这个 kernel 就是根据我们上一步获得的重排信息来执行真正的重排。

Group GEMM 和激活函数

接下来就是执行 gateup 和 down 的 Group GEMM 以及夹在它们中间的 silu_and_mul 激活操作。在 EPMoE Forward 对应的函数为:

#  获取当前 rank 的分段指针和权重索引  
        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]  
        weight_indices_cur_rank = torch.arange(  
            0,  
            self.num_experts_per_partition,  
            device=hidden_states.device,  
            dtype=torch.int64,  
        )

  
        #  第一次分组矩阵乘法  
        gateup_output = torch.empty(  
            gateup_input.shape[0],  
            self.w13_weight.shape[1],  
            device=hidden_states.device,  
            dtype=hidden_states.dtype,  
        )  
        gateup_output = self.grouped_gemm_runner(  
            a=gateup_input,  
            b=self.w13_weight,  
            c=gateup_output,  
            batch_size=self.num_experts_per_partition,  
            weight_column_major=True,  
            seg_indptr=seg_indptr_cur_rank,  
            weight_indices=weight_indices_cur_rank,  
            use_fp8_w8a8=self.use_fp8_w8a8,  
            scale_a=self.w13_input_scale,  
            scale_b=self.w13_weight_scale,  
        )

#  激活函数处理  
        down_input = torch.empty(  
            gateup_output.shape[0],  
            gateup_output.shape[1] // 2,  
            device=gateup_output.device,  
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,  
        )  
        if self.w2_input_scale is None:  
            self.w2_input_scale = torch.ones(  
                self.num_experts_per_partition,  
                dtype=torch.float32,  
                device=hidden_states.device,  
            )  
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](  
            gateup_output,  
            down_input,  
            gateup_output.shape[1],  
            reorder_topk_ids,  
            self.w2_input_scale,  
            self.start_expert_id,  
            self.end_expert_id,  
            BLOCK_SIZE=512,  
        )

#  第二次分组矩阵乘法  
        down_output = torch.empty(  
            down_input.shape[0],  
            self.w2_weight.shape[1],  
            device=hidden_states.device,  
            dtype=hidden_states.dtype,  
        )  
        down_output = self.grouped_gemm_runner(  
            a=down_input,  
            b=self.w2_weight,  
            c=down_output,  
            batch_size=self.num_experts_per_partition,  
            weight_column_major=True,  
            seg_indptr=seg_indptr_cur_rank,  
            weight_indices=weight_indices_cur_rank,  
            use_fp8_w8a8=self.use_fp8_w8a8,  
            scale_a=self.w2_input_scale,  
            scale_b=self.w2_weight_scale,  
        )

Group GEMM 和激活函数都是比较常规的,这里就不解析这两个比较长的 Triton 实现了,使用 Triton 来实现这两个操作也是相当不高效的。

后重排序(等价于第二次 All2All),生成最终输出

对应了 EPMoE 的最后 2 行代码:

output = torch.empty_like(hidden_states)  
post_reorder_triton_kernel[(hidden_states.size(0),)](  
    down_output,  
    output,  
    src2dst,  
    topk_ids,  
    topk_weights,  
    self.start_expert_id,  
    self.end_expert_id,  
    self.top_k,  
    hidden_states.size(1),  
    BLOCK_SIZE=512,  
)  

Triton Kernel 代码如下:

@triton.jit  
def post_reorder_triton_kernel(  
    down_output_ptr  #  存储专家处理后的输出  
    output_ptr       #  最终输出结果的存储位置  
    src2dst_ptr      #  重排序映射关系  
    topk_ids_ptr     #  每个 token 对应的专家 ID  
    topk_weights_ptr #  每个 token 对应的专家权重  
    start_expert_id,    #  起始专家 ID  
    end_expert_id,      #  结束专家 ID  
    topk,               # topk 值  
    hidden_size,        #  隐藏层大小  
    BLOCK_SIZE: tl.constexpr,  #  块大小常量  
):  
    """后重排序 triton 核函数

  
     该函数将专家输出重新排序并加权求和,生成最终输出。  
     主要步骤:  
    1.  获取输入数据类型和程序 ID  
    2.  计算各指针偏移量  
    3.  对每个 block:  
       -  创建零向量用于累加  
       -  对每个 topk 专家:  
         *  如果专家 ID 在范围内,加载并累加其输出  
    4.  如果没有计算过的专家,输出全零向量  
    """  
    #  获取输入数据类型  
    InDtype = down_output_ptr.dtype.element_ty

#  获取当前程序 ID 作为源索引  
    src_idx = tl.program_id(0)  
    #  计算各指针的实际位置  
    src2dst_ptr = src2dst_ptr + src_idx * topk  
    topk_ids_ptr = topk_ids_ptr + src_idx * topk  
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

#  标记是否有专家参与计算  
    computed = False  
    #  计算存储位置  
    store_ptr = output_ptr + src_idx * hidden_size

  
    #  按 block 大小遍历 hidden_size  
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):  
        offset = start_offset + tl.arange(0, BLOCK_SIZE)  
        mask = offset < hidden_size

#  创建零向量用于累加  
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)  
        #  遍历 topk 个专家  
        for idx in range(topk):  
            expert_id = tl.load(topk_ids_ptr + idx)  
            #  检查专家 ID 是否在有效范围内  
            if expert_id >= start_expert_id and expert_id <= end_expert_id:  
                computed = True  
                #  加载目标索引和权重  
                dst_idx = tl.load(src2dst_ptr + idx)  
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)  
                #  计算加载位置并加载数据  
                load_ptr = down_output_ptr + dst_idx * hidden_size  
                in_data = tl.load(load_ptr + offset, mask=mask)  
                #  加权累加  
                sum_vec += in_data * weigh_scale  
        #  存储累加结果  
        tl.store(store_ptr + offset, sum_vec, mask=mask)

#  如果没有专家参与计算,输出全零  
    if computed == False:  
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):  
            offset = start_offset + tl.arange(0, BLOCK_SIZE)  
            mask = offset < hidden_size  
            tl.store(  
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask  
            )

只看这个代码可能还是有点抽象,我继续用上面的 Token 按照 Expert 重排 index 的例子来说明下,我新增一组 expert_weights:

注意,每个 token 有 topk 个权重,也就是 select_experts 输出的 topk_weights。
#  原始的 token 分配和权重  
token_idx:        [0,  1,  2,  3,  4,  5,  6,  7,  8,  9]  
expert_ids:       [1,  3,  2,  1,  0,  2,  3,  1,  2,  0]  
expert_weights:   [0.6,0.8,0.7,0.5,0.9,0.6,0.7,0.4,0.8,0.5]

#  重排序后的顺序(之前例子的结果)  
重排序位置:        [0,  1,  2,  3,  4,  5,  6,  7,  8,  9]  
专家 ID:           [0,  0,  1,  1,  1,  2,  2,  2,  3,  3]  

现在,post_reorder_triton_kernel kernel 的工作流程是:

  1. 对每个原始 token 位置(通过 src_idx = tl.program_id(0)获取):
# 比如处理原始token_idx=0的数据时:  
 expert_id = 1  
 weight = 0.6  
 # 需要从重排序后的位置2,3,4中找到对应的输出结果  
 也就是下面这行代码:  
src2dst_ptr = src2dst_ptr + src_idx * topk  
  1. 处理 hidden_size 维度的数据:

假设 hidden_size=1024,BLOCK_SIZE=256,代码会将 1024 维的数据分成 4 个块来处理,每个块创建一个零向量用于累加结果

  1. 对每个 token 的专家输出进行加权组合:
# 以token_idx=0为例:  
   sum_vec = 0  # 初始化累加向量  
   expert_output = load_expert_output(expert_id=1)  # 加载专家1的输出  
   sum_vec += expert_output * 0.6  # 应用权重0.6  
  1. 如果当前 token 有专家处理(computed=True),存储加权后的结果,否则存储全零向量。

通过这个后重排序,我们就可以支持一个 token 被多个专家并行处理以及使用 topk weights 来控制不同专家的贡献程度。

0x4. SGLang EPMoE 和 MoE EP 训练流程的区别

回收开头,EPMoE Layer forward 的最后为什么要对结果使用tensor_model_parallel_all_reduce

实际上从上面的 EPMoE Forward 的流程来看,我么们发现它是直接实现了几个 Triton Kernel 来等价原始的 Expert Parallel 中的 2 次 All2All,并没有像训练那样调用通信源语来做 All2All。然后从上面的post_reorder_triton_kernel中对每个 token 的累加过程来看,如果某个 Rank 上的当前 token 没有被这个 Rank 持有的 Expert 处理的话,它的输出会设置为 0,但是如果在另外一个 EP Rank 上对当前这个 token 是会被它持有的 Expert 处理的话,我们最终就需要做一次 allreduce 把所有 rank 上的结果加起来。在推理的时候 All2All 几乎没有重叠机会,而 All2All 的速度是比较慢的,通过这里的对 All2All 流程的优化其实也可以降低通信的成本。

0x5. 总结

SGLang EPMoE 目前这个实现整体上比较清晰,但笔者目前没有详细实测过这个 Feature,所以不确定它和普通的 TP 的性能谁更好,此外这个 EPMoE 计算流程中最耗时的 Group GEMM 也暂时没有使用 FalshInfer 的优化版本,Triton 的实现应该会比较慢。

END

作者:BBuf
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18849
内容数
1389
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息