HiHuo
MoE Architecture and Training

Overview

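Mixture of Experts (MoE) is a sparsely activated model architecture: a router sends each token to a small subset of expert networks, so the total parameter count (model capacity) can grow much faster than the compute spent per token. This chapter covers MoE architecture design, training strategies, and engineering implementation in depth.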

MoE Architecture Principles

Basic Structure

┌─────────────────────────────────────────────────────────────────────────────┐
│                        Mixture of Experts (MoE) Layer                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│   Input Token                                                               │
│       │                                                                     │
│       ▼                                                                     │
│   ┌──────────────────────────────────────────────────────────────────────┐ │
│   │                           Router Network                              │ │
│   │                                                                       │ │
│   │   Input ──▶ Linear ──▶ Softmax ──▶ Top-K Selection ──▶ Weights       │ │
│   │                                                                       │ │
│   │   Output: (expert_indices, routing_weights)                           │ │
│   └───────────────────────────────────┬──────────────────────────────────┘ │
│                                       │                                     │
│           ┌───────────────────────────┼───────────────────────────┐        │
│           │                           │                           │        │
│           ▼                           ▼                           ▼        │
│   ┌───────────────┐           ┌───────────────┐           ┌───────────────┐│
│   │   Expert 1    │           │   Expert 2    │    ...    │   Expert N    ││
│   │               │           │               │           │               ││
│   │  FFN Layer    │           │  FFN Layer    │           │  FFN Layer    ││
│   │  (Hidden×4)   │           │  (Hidden×4)   │           │  (Hidden×4)   ││
│   └───────┬───────┘           └───────┬───────┘           └───────┬───────┘│
│           │                           │                           │        │
│           └───────────────────────────┼───────────────────────────┘        │
│                                       │                                     │
│                                       ▼                                     │
│                           ┌───────────────────┐                             │
│                           │   Weighted Sum    │                             │
│                           │   Σ(weight_i ×    │                             │
│                           │    expert_i)      │                             │
│                           └─────────┬─────────┘                             │
│                                     │                                       │
│                                     ▼                                       │
│                               Output Token                                  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
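A toy walk-through of one token passing through the layer in the diagram may help. The sketch below uses scalar "experts" (hypothetical stand-ins for the FFNs) and, like the TopKRouter later in this chapter, applies the softmax only over the top-k logits. It records which experts are actually called: only the selected ones run, which is the sparsity the diagram describes.

```python
import math


def top_k_route(logits, k):
    """Indices of the k largest logits and a softmax over only those k logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]


def moe_forward(x, logits, experts, k=2):
    """Route one token, run only the selected experts, and return their weighted sum."""
    indices, weights = top_k_route(logits, k)
    y = sum(w * experts[i](x) for i, w in zip(indices, weights))
    return y, indices, weights


calls = []  # which experts actually ran


def make_expert(idx, fn):
    """Wrap a toy scalar function so that each call is recorded."""
    def expert(x):
        calls.append(idx)
        return fn(x)
    return expert


# Hypothetical toy experts standing in for FFN layers
experts = [make_expert(0, lambda x: 2 * x), make_expert(1, lambda x: x + 1),
           make_expert(2, lambda x: -x), make_expert(3, lambda x: x * x)]
y, idx, w = moe_forward(3.0, [2.0, 1.0, 0.5, -1.0], experts, k=2)
print(idx, [round(v, 4) for v in w], round(y, 4), sorted(calls))  # [0, 1] [0.7311, 0.2689] 5.4621 [0, 1]
```

Experts 2 and 3 are never evaluated for this token, even though their parameters exist.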

MoE vs Dense Model Comparison:
┌─────────────────────────────────────────────────────────────────────────────┐
│                                                                             │
│  Dense Model (7B):                                                          │
│  ┌─────────────────────────────────────────────────────────────┐           │
│  │  FFN: 4096 → 16384 → 4096                                   │           │
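│  │  Every token activates all parameters                       │           │
│  │  Params: 7B, compute: 7B × tokens                           │           │
│  └─────────────────────────────────────────────────────────────┘           │
│                                                                             │
│  MoE Model (8×7B with top-2, idealized):                                    │
│  ┌─────────────────────────────────────────────────────────────┐           │
│  │  8 experts, each the size of the full 7B model              │           │
│  │  Each token activates only 2 experts                        │           │
│  │  Total params: 56B, compute: 14B × tokens (2 experts)       │           │
│  │  → 8x the parameters for only 2x the compute                │           │
│  │  Real 8x7B MoEs replicate only the FFN, so totals are lower │           │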
│  └─────────────────────────────────────────────────────────────┘           │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
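The 8x/2x figures above treat each expert as a complete 7B model. In a real MoE decoder only the FFN is replicated per expert, while attention, norms, and embeddings are shared. The sketch below counts parameters for an assumed Mixtral-8x7B-like configuration (32 layers, hidden size 4096, SwiGLU FFN size 14336, 32 query heads with 8 KV heads, vocabulary 32000, untied embeddings); these hyperparameters are our assumption, not something stated in this chapter.

```python
def moe_param_counts(n_layers=32, hidden=4096, ffn=14336, n_heads=32, n_kv_heads=8,
                     vocab=32000, n_experts=8, top_k=2):
    """Return (total, active-per-token) parameter counts for a decoder-only MoE."""
    head_dim = hidden // n_heads
    # Grouped-query attention: Q and O are hidden x hidden,
    # K and V are hidden x (n_kv_heads * head_dim)
    attn = 2 * hidden * hidden + 2 * hidden * n_kv_heads * head_dim
    expert = 3 * hidden * ffn             # SwiGLU expert: gate_proj, up_proj, down_proj
    router = hidden * n_experts           # gating layer, no bias
    norms = 2 * hidden                    # two RMSNorm weight vectors per layer
    shared = 2 * vocab * hidden + hidden  # untied embedding + LM head + final norm
    total = n_layers * (attn + n_experts * expert + router + norms) + shared
    active = n_layers * (attn + top_k * expert + router + norms) + shared
    return total, active


total, active = moe_param_counts()
print(f"total ≈ {total / 1e9:.1f}B, active ≈ {active / 1e9:.1f}B")  # total ≈ 46.7B, active ≈ 12.9B
```

The result (about 46.7B total, about 12.9B active per token) is, as far as we know, in line with the figures Mistral AI published for Mixtral 8x7B. The parameter-to-compute ratio is still roughly 3.6x, just below the idealized 4x (8x parameters for 2x compute).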

Routing Mechanism

"""
MoE routing implementation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from dataclasses import dataclass

@dataclass
class MoEConfig:
    """MoE 配置"""
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_experts: int = 8
    num_experts_per_tok: int = 2
    router_jitter_noise: float = 0.0
    router_aux_loss_coef: float = 0.01
    router_z_loss_coef: float = 0.001


class TopKRouter(nn.Module):
    """Top-K 路由器"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        # Gating network: one logit per expert
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]

        Returns:
            routing_weights: [batch_size, seq_len, num_experts_per_tok]
            selected_experts: [batch_size, seq_len, num_experts_per_tok]
            router_logits: [batch_size, seq_len, num_experts]
        """
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Routing logits
        router_logits = self.gate(hidden_states)  # [B, S, E]

        # Add jitter noise to the logits during training
        if self.training and self.config.router_jitter_noise > 0:
            router_logits = router_logits + torch.randn_like(router_logits) * self.config.router_jitter_noise

        # Top-K selection
        routing_weights, selected_experts = torch.topk(
            router_logits,
            self.top_k,
            dim=-1
        )

        # Softmax normalization over the selected experts only
        routing_weights = F.softmax(routing_weights, dim=-1)

        return routing_weights, selected_experts, router_logits


class ExpertChoiceRouter(nn.Module):
    """Expert Choice router: each expert selects its tokens (instead of tokens selecting experts)"""

    def __init__(self, config: MoEConfig, capacity_factor: float = 1.25):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Expert Choice: every expert selects a fixed number (capacity) of tokens

        Returns:
            expert_weights: [num_experts, capacity, 1]
            expert_indices: [num_experts, capacity]
            router_logits: [batch_size * seq_len, num_experts]
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Per-expert capacity, clamped to [1, num_tokens] so topk stays valid
        capacity = int(num_tokens * self.capacity_factor / self.config.num_experts)
        capacity = max(1, min(capacity, num_tokens))

        # Routing scores
        hidden_flat = hidden_states.reshape(-1, hidden_size)
        router_logits = self.gate(hidden_flat)  # [N, E]

        # Token-to-expert affinities: softmax over the experts for each token
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [N, E]

        # Transpose so that each expert picks its top-`capacity` tokens
        expert_weights, expert_indices = torch.topk(
            router_probs.t(),  # [E, N]
            capacity,
            dim=-1
        )

        # The gate values are the selected affinities themselves (no second softmax)
        expert_weights = expert_weights.unsqueeze(-1)

        return expert_weights, expert_indices, router_logits


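class LoadBalancingLoss(nn.Module):
    """Load-balancing loss: auxiliary balance term plus router z-loss"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        router_logits: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """
        Auxiliary loss that encourages balanced expert load

        Args:
            router_logits: [..., num_experts], e.g. [batch_size, seq_len, num_experts] or [N, num_experts]
            selected_experts: [..., top_k], same leading dims as router_logits

        Returns:
            Weighted total loss (router_aux_loss_coef and router_z_loss_coef already applied)
        """
        num_experts = self.config.num_experts

        # Flatten the leading dims so 3-D and 2-D inputs are handled identically
        router_logits = router_logits.reshape(-1, num_experts)                  # [N, E]
        selected_experts = selected_experts.reshape(router_logits.size(0), -1)  # [N, K]

        # How many of each token's top-k slots went to each expert
        expert_mask = F.one_hot(selected_experts, num_experts).float()  # [N, K, E]
        expert_mask = expert_mask.sum(dim=1)                             # [N, E]

        # Router probabilities (float32 for numerical stability)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [N, E]

        # Auxiliary loss: encourage a uniform distribution
        # f_i: fraction of routing slots assigned to expert i (sums to top_k)
        # P_i: mean router probability for expert i
        tokens_per_expert = expert_mask.mean(dim=0)        # [E]
        router_prob_per_expert = router_probs.mean(dim=0)  # [E]

        # Equals top_k when both f and P are uniform
        aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts

        # Z-loss: keeps router logits small so the softmax stays numerically stable
        z_loss = torch.logsumexp(router_logits.float(), dim=-1).pow(2).mean()

        total_loss = (
            self.config.router_aux_loss_coef * aux_loss +
            self.config.router_z_loss_coef * z_loss
        )

        return total_loss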
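To build intuition for the balance term E·Σ f_i·P_i, here is a minimal pure-Python sketch for top-1 routing (an assumption; the loss above counts all top-k slots). With uniform router probabilities and evenly spread tokens it equals 1, and it approaches E as the router collapses onto a single expert. With top-k counting as in LoadBalancingLoss, f sums to K and the balanced value becomes K instead of 1.

```python
def switch_aux_loss(router_probs, assignments, num_experts):
    """E * sum_i f_i * P_i for top-1 routing.

    router_probs: one probability list per token; assignments: chosen expert per token.
    """
    n = len(assignments)
    f = [sum(1 for a in assignments if a == i) / n for i in range(num_experts)]  # token fractions
    p = [sum(probs[i] for probs in router_probs) / n for i in range(num_experts)]  # mean probs
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))


E = 4
# Balanced: uniform router probabilities and tokens spread evenly over the experts
balanced = switch_aux_loss([[1 / E] * E for _ in range(8)], [t % E for t in range(8)], E)
# Collapsed: the router puts almost all mass on expert 0 and every token goes there
collapsed = switch_aux_loss([[0.97, 0.01, 0.01, 0.01] for _ in range(8)], [0] * 8, E)
print(balanced, round(collapsed, 2))  # 1.0 3.88
```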

MoE Layer Implementation

"""
Complete MoE layer implementation
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple

class Expert(nn.Module):
    """单个专家网络 (FFN)"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class SparseMoELayer(nn.Module):
    """稀疏 MoE 层"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        # Router
        self.router = TopKRouter(config)

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])

        # Load-balancing loss
        self.load_balancing_loss = LoadBalancingLoss(config)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]

        Returns:
            output: [batch_size, seq_len, hidden_size]
            aux_loss: auxiliary loss (None when not training)
        """
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Routing
        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        # Auxiliary loss (training only)
        aux_loss = None
        if self.training:
            aux_loss = self.load_balancing_loss(router_logits, selected_experts)

        # Expert outputs
        final_output = self._compute_expert_outputs(
            hidden_states,
            routing_weights,
            selected_experts
        )

        return final_output, aux_loss

    def _compute_expert_outputs(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """计算专家输出 - 基础实现"""
        batch_size, seq_len, hidden_size = hidden_states.shape
        final_output = torch.zeros_like(hidden_states)

        # 为每个专家收集 token
        for expert_idx in range(self.num_experts):
            # 找到选择了这个专家的位置
            expert_mask = (selected_experts == expert_idx)  # [B, S, K]

            # 遍历 top-k 位置
            for k in range(self.top_k):
                mask_k = expert_mask[:, :, k]  # [B, S]

                if mask_k.any():
                    # 获取对应的 token
                    expert_input = hidden_states[mask_k]  # [N, H]

                    # 计算专家输出
                    expert_output = self.experts[expert_idx](expert_input)

                    # 加权累加
                    weights = routing_weights[:, :, k][mask_k].unsqueeze(-1)  # [N, 1]
                    final_output[mask_k] += weights * expert_output

        return final_output


class OptimizedSparseMoE(nn.Module):
    """优化的稀疏 MoE 实现 - 使用分组计算"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        self.router = TopKRouter(config)
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])
        self.load_balancing_loss = LoadBalancingLoss(config)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Flatten
        hidden_flat = hidden_states.view(-1, hidden_size)  # [N, H]

        # Routing
        routing_weights, selected_experts, router_logits = self.router(hidden_states)
        routing_weights = routing_weights.view(num_tokens, self.top_k)
        selected_experts = selected_experts.view(num_tokens, self.top_k)

        # Auxiliary loss
        aux_loss = None
        if self.training:
            aux_loss = self.load_balancing_loss(
                router_logits.view(num_tokens, -1),
                selected_experts
            )

        # Group by expert and compute in batches
        final_output = self._batched_expert_forward(
            hidden_flat,
            routing_weights,
            selected_experts
        )

        return final_output.view(batch_size, seq_len, hidden_size), aux_loss

    def _batched_expert_forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Batched expert computation: each expert runs once over all tokens routed to it"""
        num_tokens, hidden_size = hidden_states.shape
        device = hidden_states.device
        dtype = hidden_states.dtype

        final_output = torch.zeros(num_tokens, hidden_size, device=device, dtype=dtype)

        # Flatten the (token, slot) pairs
        flat_selected = selected_experts.reshape(-1)   # [N * K]
        flat_weights = routing_weights.reshape(-1, 1)  # [N * K, 1]

        # Token index of each pair (every token repeated K times)
        token_indices = torch.arange(num_tokens, device=device)
        token_indices = token_indices.unsqueeze(1).expand(-1, self.top_k).reshape(-1)

        # One batched call per expert
        for expert_idx in range(self.num_experts):
            expert_mask = (flat_selected == expert_idx)

            if expert_mask.any():
                # Tokens (and weights) assigned to this expert
                expert_token_indices = token_indices[expert_mask]
                expert_weights = flat_weights[expert_mask]

                # Batched computation
                expert_input = hidden_states[expert_token_indices]
                expert_output = self.experts[expert_idx](expert_input)

                # Weighted scatter-add back into the output (index_add_ needs matching dtypes)
                weighted_output = expert_weights * expert_output
                final_output.index_add_(0, expert_token_indices, weighted_output.to(dtype))

        return final_output
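Why do the two implementations agree? Grouping (token, weight) pairs by expert and scatter-adding the weighted results back yields the same sums as looping over tokens; it just replaces many small expert calls with one batched call per expert. A pure-Python sketch, assuming hypothetical scalar "experts" that take a list of inputs:

```python
from collections import defaultdict


def naive_moe(tokens, routes, experts):
    """One expert call per (token, slot): routes[t] = [(expert_idx, weight), ...]."""
    return [sum(w * experts[e]([x])[0] for e, w in routes[t]) for t, x in enumerate(tokens)]


def grouped_moe(tokens, routes, experts):
    """Group (token, weight) pairs by expert, call each expert once, then scatter-add."""
    buckets = defaultdict(list)
    for t, slots in enumerate(routes):
        for e, w in slots:
            buckets[e].append((t, w))
    out = [0.0] * len(tokens)
    calls = 0
    for e, pairs in buckets.items():
        batch_out = experts[e]([tokens[t] for t, _ in pairs])  # one batched call
        calls += 1
        for (t, w), y in zip(pairs, batch_out):
            out[t] += w * y  # plays the role of index_add_
    return out, calls


# Hypothetical batched experts: each maps a list of scalars to a list of scalars
experts = [lambda xs: [2 * x for x in xs],
           lambda xs: [x + 1 for x in xs],
           lambda xs: [-x for x in xs]]
tokens = [1.0, 2.0, 3.0]
routes = [[(0, 0.7), (1, 0.3)], [(1, 0.6), (2, 0.4)], [(0, 0.5), (2, 0.5)]]

naive = naive_moe(tokens, routes, experts)
grouped, calls = grouped_moe(tokens, routes, experts)
print([round(v, 6) for v in grouped], calls)  # [2.0, 1.0, 1.5] 3
```

The grouped version makes 3 expert calls instead of 6, which on a GPU means fewer, larger matrix multiplications.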

Training Strategies

Distributed Training

"""
MoE 分布式训练
"""
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class ExpertParallelMoE(nn.Module):
    """专家并行 MoE"""

    def __init__(
        self,
        config: MoEConfig,
        expert_parallel_size: int = 1
    ):
        super().__init__()
        self.config = config
        self.ep_size = expert_parallel_size

        # 确定本地专家
        self.rank = dist.get_rank()
        self.num_local_experts = config.num_experts // expert_parallel_size

        expert_start = self.rank * self.num_local_experts
        expert_end = expert_start + self.num_local_experts

        # 只创建本地专家
        self.local_expert_indices = list(range(expert_start, expert_end))
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(self.num_local_experts)
        ])

        # 全局路由器
        self.router = TopKRouter(config)

        # 创建专家并行通信组
        self._setup_expert_parallel_group()

    def _setup_expert_parallel_group(self):
        """设置专家并行通信组"""
        world_size = dist.get_world_size()

        # 假设所有 GPU 在同一个专家并行组
        ranks = list(range(world_size))
        self.ep_group = dist.new_group(ranks)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, hidden_size = hidden_states.shape

        # 路由
        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        # All-to-All: 将 token 发送到对应专家所在的 rank
        dispatched_input, dispatch_info = self._all_to_all_dispatch(
            hidden_states,
            selected_experts,
            routing_weights
        )

        # 本地专家计算
        local_outputs = self._local_expert_compute(dispatched_input)

        # All-to-All: 将结果返回
        combined_output = self._all_to_all_combine(
            local_outputs,
            dispatch_info
        )

        # 辅助损失
        aux_loss = None
        if self.training:
            aux_loss = self.load_balancing_loss(router_logits, selected_experts)

        return combined_output, aux_loss

    def _all_to_all_dispatch(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor
    ):
        """分发 token 到各个专家所在的 rank"""
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        hidden_flat = hidden_states.view(-1, hidden_size)
        selected_flat = selected_experts.view(-1, self.config.num_experts_per_tok)
        weights_flat = routing_weights.view(-1, self.config.num_experts_per_tok)

        # 计算每个 rank 接收的 token 数
        send_counts = torch.zeros(self.ep_size, dtype=torch.long, device=hidden_states.device)

        for k in range(self.config.num_experts_per_tok):
            for rank in range(self.ep_size):
                local_start = rank * self.num_local_experts
                local_end = local_start + self.num_local_experts
                mask = (selected_flat[:, k] >= local_start) & (selected_flat[:, k] < local_end)
                send_counts[rank] += mask.sum()

        # All-to-All 交换 token
        recv_counts = torch.zeros_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts, group=self.ep_group)

        # 准备发送数据
        # 这里简化处理,实际需要更复杂的索引管理
        dispatched_input = hidden_flat  # 简化

        dispatch_info = {
            "send_counts": send_counts,
            "recv_counts": recv_counts,
            "weights": weights_flat,
            "experts": selected_flat,
            "original_shape": (batch_size, seq_len, hidden_size)
        }

        return dispatched_input, dispatch_info

    def _local_expert_compute(self, inputs: torch.Tensor) -> torch.Tensor:
        """在本地专家上计算"""
        outputs = torch.zeros_like(inputs)

        # 简化实现:所有输入通过所有本地专家
        for i, expert in enumerate(self.experts):
            outputs += expert(inputs)

        return outputs / len(self.experts)

    def _all_to_all_combine(
        self,
        local_outputs: torch.Tensor,
        dispatch_info: dict
    ) -> torch.Tensor:
        """合并结果"""
        # 简化实现
        batch_size, seq_len, hidden_size = dispatch_info["original_shape"]
        return local_outputs.view(batch_size, seq_len, hidden_size)


class MoETrainer:
    """MoE 训练器"""

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        config: dict
    ):
        self.model = model
        self.optimizer = optimizer
        self.config = config

        # 辅助损失权重
        self.aux_loss_weight = config.get("aux_loss_weight", 0.01)

        # 梯度裁剪
        self.max_grad_norm = config.get("max_grad_norm", 1.0)

        # 专家容量监控
        self.expert_usage_history = []

    def train_step(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """训练步骤"""
        self.model.train()
        self.optimizer.zero_grad()

        # 前向传播
        outputs = self.model(input_ids, labels=labels)
        main_loss = outputs.loss
        aux_loss = outputs.aux_loss if hasattr(outputs, 'aux_loss') else 0

        # 总损失
        total_loss = main_loss + self.aux_loss_weight * aux_loss

        # 反向传播
        total_loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        # 更新参数
        self.optimizer.step()

        # 监控专家使用情况
        if hasattr(outputs, 'router_logits'):
            self._monitor_expert_usage(outputs.router_logits)

        return {
            "loss": total_loss.item(),
            "main_loss": main_loss.item(),
            "aux_loss": aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
        }

    def _monitor_expert_usage(self, router_logits: torch.Tensor):
        """监控专家使用情况"""
        with torch.no_grad():
            probs = F.softmax(router_logits, dim=-1)
            usage = probs.mean(dim=[0, 1])  # 平均使用率

            self.expert_usage_history.append(usage.cpu().numpy())

            # 检查是否有专家过载或空闲
            if len(self.expert_usage_history) % 100 == 0:
                avg_usage = np.mean(self.expert_usage_history[-100:], axis=0)
                print(f"Expert usage: min={avg_usage.min():.3f}, max={avg_usage.max():.3f}, "
                      f"std={avg_usage.std():.3f}")

Capacity Factor and Load Balancing

"""
Capacity factor and advanced load-balancing strategies
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

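class CapacityAwareMoE(nn.Module):
    """Capacity-aware MoE: each expert accepts at most `capacity` (token, slot) pairs per batch"""

    def __init__(
        self,
        config: MoEConfig,
        capacity_factor: float = 1.25,
        drop_tokens: bool = True
    ):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.drop_tokens = drop_tokens  # False disables the capacity limit entirely

        self.router = TopKRouter(config)
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])
        self.last_drop_rate = None  # fraction of fully dropped tokens in the last forward

    def forward(self, hidden_states: torch.Tensor):
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Capacity = capacity_factor x average number of routing slots per expert
        capacity = max(1, int(
            self.capacity_factor * num_tokens *
            self.config.num_experts_per_tok / self.config.num_experts
        ))

        # Routing
        routing_weights, selected_experts, router_logits = self.router(hidden_states)
        routing_weights = routing_weights.reshape(num_tokens, -1)
        selected_experts = selected_experts.reshape(num_tokens, -1)

        # Apply the capacity limit
        if self.drop_tokens:
            routing_weights, dropped_mask = self._apply_capacity(
                routing_weights, selected_experts, capacity
            )
            self.last_drop_rate = dropped_mask.float().mean().detach()

        # Compute the output. Over-capacity pairs have weight 0, so a token whose every choice
        # was dropped gets zero output here and, assuming the surrounding block computes
        # x + MoE(x), reaches the next layer unchanged through the residual connection
        hidden_flat = hidden_states.reshape(num_tokens, hidden_size)
        output = self._compute_output(hidden_flat, routing_weights, selected_experts)

        return output.view(batch_size, seq_len, hidden_size), router_logits

    def _apply_capacity(
        self,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor,
        capacity: int
    ):
        """Zero the weights of (token, slot) pairs that overflow an expert's capacity.

        Priority: all first choices are placed before any second choice, and earlier
        tokens before later ones (same order as looping over k, then over tokens).
        """
        num_tokens, top_k = selected_experts.shape

        # Slot-major order: [K, N] -> [K*N]
        flat_experts = selected_experts.t().reshape(-1)
        one_hot = F.one_hot(flat_experts, self.config.num_experts)  # [K*N, E]

        # 0-based position of each pair in its expert's queue
        position = (torch.cumsum(one_hot, dim=0) * one_hot).sum(dim=-1) - 1  # [K*N]
        keep = (position < capacity).view(top_k, num_tokens).t()  # [N, K]

        new_routing_weights = routing_weights * keep.to(routing_weights.dtype)

        # Tokens whose every choice overflowed
        dropped_mask = ~keep.any(dim=-1)

        # Re-normalize the surviving weights
        weight_sum = new_routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-6)
        new_routing_weights = new_routing_weights / weight_sum

        return new_routing_weights, dropped_mask

    def _compute_output(
        self,
        hidden_flat: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Batched expert computation; dropped pairs (weight 0) are skipped"""
        output = torch.zeros_like(hidden_flat)
        top_k = selected_experts.size(1)
        token_indices = torch.arange(
            hidden_flat.size(0), device=hidden_flat.device
        ).repeat_interleave(top_k)
        flat_experts = selected_experts.reshape(-1)
        flat_weights = routing_weights.reshape(-1, 1)

        for expert_idx, expert in enumerate(self.experts):
            mask = (flat_experts == expert_idx) & (flat_weights.squeeze(-1) > 0)
            if mask.any():
                idx = token_indices[mask]
                weighted = flat_weights[mask] * expert(hidden_flat[idx])
                output.index_add_(0, idx, weighted.to(output.dtype))

        return output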


class SwitchTransformerRouter(nn.Module):
    """Switch Transformer router (Top-1)"""

    def __init__(self, config: MoEConfig, capacity_factor: float = 1.0):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Router logits
        hidden_flat = hidden_states.reshape(-1, hidden_size)
        router_logits = self.gate(hidden_flat)  # [N, E]

        # Top-1 selection; the gate value is the softmax probability of the chosen expert
        router_probs = F.softmax(router_logits, dim=-1)
        routing_weights, selected_experts = router_probs.max(dim=-1)

        # Capacity
        capacity = int(self.capacity_factor * num_tokens / self.config.num_experts)

        # One-hot encoding
        expert_mask = F.one_hot(selected_experts, self.config.num_experts)  # [N, E]

        # Position of each token in its expert's queue via the cumsum trick (0-based)
        position_in_expert = (torch.cumsum(expert_mask, dim=0) * expert_mask).sum(dim=-1) - 1

        # Tokens beyond capacity are dropped (gate value set to 0)
        within_capacity = (position_in_expert < capacity)
        routing_weights = routing_weights * within_capacity.to(routing_weights.dtype)

        return routing_weights, selected_experts, router_logits


class GShard_Router(nn.Module):
    """GShard router: top-2 with random routing of the second expert"""

    def __init__(self, config: MoEConfig, capacity_factor: float = 2.0):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor  # stored only; no capacity limit in this sketch
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        hidden_flat = hidden_states.reshape(-1, hidden_size)
        router_logits = self.gate(hidden_flat)

        # Top-2 selection
        routing_weights, selected_experts = torch.topk(router_logits, 2, dim=-1)
        routing_weights = F.softmax(routing_weights, dim=-1)

        # Random routing: keep the second expert with probability 2 * w2 (w2 <= 0.5)
        if self.training:
            random_threshold = torch.rand(num_tokens, 1, device=hidden_states.device)
            keep_second = (routing_weights[:, 1:] > random_threshold * 0.5).to(routing_weights.dtype)
            # Build the mask out of place: writing into the softmax output in place
            # would break autograd, since softmax's backward reuses its output
            keep = torch.cat([torch.ones_like(keep_second), keep_second], dim=-1)
            routing_weights = routing_weights * keep

        # Re-normalize
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-6)

        return routing_weights, selected_experts, router_logits
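A small worked example of the capacity arithmetic and of why the cumsum trick is valid. The sketch assumes top-1 choices for a handful of tokens. The loop form counts only kept tokens, while the cumsum form counts every earlier token sent to the expert; both drop exactly the tokens that arrive after the expert is full, so they produce the same mask.

```python
def expert_capacity(num_tokens, num_experts, top_k, capacity_factor):
    """capacity = capacity_factor * num_tokens * top_k / num_experts, truncated to an int."""
    return int(capacity_factor * num_tokens * top_k / num_experts)


def assign_with_capacity(choices, num_experts, capacity):
    """Loop form: tokens are served in order; a full expert drops later tokens."""
    load = [0] * num_experts
    kept = []
    for e in choices:
        if load[e] < capacity:
            load[e] += 1
            kept.append(True)
        else:
            kept.append(False)
    return kept, load


def keep_mask_via_positions(choices, capacity):
    """Cumsum form: keep a token iff its 0-based position in its expert's queue is < capacity."""
    seen = {}
    mask = []
    for e in choices:
        pos = seen.get(e, 0)  # earlier tokens sent to e, kept or not (cumsum - 1)
        seen[e] = pos + 1
        mask.append(pos < capacity)
    return mask


print(expert_capacity(4096, num_experts=8, top_k=2, capacity_factor=1.25))  # 1280
choices = [0, 0, 0, 1, 0, 2]  # top-1 choices of 6 tokens over 3 experts
cap = expert_capacity(len(choices), num_experts=3, top_k=1, capacity_factor=1.0)
kept, load = assign_with_capacity(choices, 3, cap)
print(cap, kept, load)  # 2 [True, True, False, True, False, True] [2, 1, 1]
```

Here expert 0 is over-subscribed, so 2 of 6 tokens are dropped even though experts 1 and 2 have spare room; raising the capacity factor or improving balance reduces such drops.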

Inference Optimization

"""
MoE 推理优化
"""
import torch
from typing import List, Dict, Optional

class MoEInferenceOptimizer:
    """MoE 推理优化器"""

    def __init__(self, model: nn.Module, config: Dict):
        self.model = model
        self.config = config

        # 专家缓存
        self.expert_cache: Dict[int, torch.Tensor] = {}

        # 专家使用统计
        self.expert_usage_stats: Dict[int, int] = {}

    def optimize_for_inference(self):
        """优化推理"""
        self.model.eval()

        # 1. 量化专家(可选)
        if self.config.get("quantize_experts", False):
            self._quantize_experts()

        # 2. 预热专家缓存
        if self.config.get("cache_experts", False):
            self._warmup_expert_cache()

        # 3. 设置推理模式
        for module in self.model.modules():
            if isinstance(module, SparseMoELayer):
                module.training = False

    def _quantize_experts(self):
        """量化专家网络"""
        for module in self.model.modules():
            if isinstance(module, Expert):
                # 动态量化
                torch.quantization.quantize_dynamic(
                    module,
                    {nn.Linear},
                    dtype=torch.qint8
                )

    def _warmup_expert_cache(self):
        """预热专家缓存"""
        # 分析历史数据确定热门专家
        hot_experts = self._identify_hot_experts()

        for expert_idx in hot_experts:
            # 预加载到 GPU
            for module in self.model.modules():
                if isinstance(module, SparseMoELayer):
                    expert = module.experts[expert_idx]
                    self.expert_cache[expert_idx] = {
                        name: param.clone()
                        for name, param in expert.named_parameters()
                    }

    def _identify_hot_experts(self, threshold: float = 0.1) -> List[int]:
        """识别热门专家"""
        if not self.expert_usage_stats:
            return []

        total_usage = sum(self.expert_usage_stats.values())
        hot_experts = [
            idx for idx, count in self.expert_usage_stats.items()
            if count / total_usage > threshold
        ]

        return hot_experts


class SpeculativeExpertExecution:
    """推测性专家执行"""

    def __init__(
        self,
        model: nn.Module,
        speculation_threshold: float = 0.8
    ):
        self.model = model
        self.speculation_threshold = speculation_threshold

        # 专家共现矩阵
        self.expert_cooccurrence = torch.zeros(8, 8)

    def speculative_forward(
        self,
        hidden_states: torch.Tensor,
        prev_expert: Optional[int] = None
    ) -> torch.Tensor:
        """推测性前向传播"""
        # 获取路由分数
        router_logits = self._get_router_logits(hidden_states)
        probs = F.softmax(router_logits, dim=-1)

        # 选择主专家
        main_expert = probs.argmax(dim=-1)

        # 推测下一层可能的专家
        if prev_expert is not None:
            speculated_experts = self._speculate_next_experts(prev_expert)
        else:
            speculated_experts = []

        # 并行执行主专家和推测专家
        outputs = self._parallel_expert_execute(
            hidden_states,
            main_expert,
            speculated_experts
        )

        return outputs, main_expert

    def _speculate_next_experts(self, prev_expert: int) -> List[int]:
        """推测下一个可能的专家"""
        cooccur = self.expert_cooccurrence[prev_expert]
        cooccur_prob = cooccur / cooccur.sum()

        # 选择高概率的专家
        speculated = (cooccur_prob > self.speculation_threshold).nonzero().squeeze(-1)

        return speculated.tolist()

    def _parallel_expert_execute(
        self,
        hidden_states: torch.Tensor,
        main_expert: torch.Tensor,
        speculated_experts: List[int]
    ) -> torch.Tensor:
        """并行执行专家"""
        # 使用 CUDA Streams 并行执行
        streams = [torch.cuda.Stream() for _ in range(len(speculated_experts) + 1)]

        results = {}

        # 主专家在主 stream
        with torch.cuda.stream(streams[0]):
            results["main"] = self._execute_expert(hidden_states, main_expert)

        # 推测专家在其他 streams
        for i, expert_idx in enumerate(speculated_experts):
            with torch.cuda.stream(streams[i + 1]):
                results[expert_idx] = self._execute_expert(
                    hidden_states,
                    torch.tensor([expert_idx])
                )

        # 同步
        torch.cuda.synchronize()

        return results["main"]

    def _execute_expert(
        self,
        hidden_states: torch.Tensor,
        expert_idx: torch.Tensor
    ) -> torch.Tensor:
        """执行单个专家"""
        # 简化实现
        for module in self.model.modules():
            if isinstance(module, SparseMoELayer):
                expert = module.experts[expert_idx.item()]
                return expert(hidden_states)

        return hidden_states

    def _get_router_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """获取路由 logits"""
        for module in self.model.modules():
            if isinstance(module, TopKRouter):
                return module.gate(hidden_states)

        raise ValueError("No router found")

Summary

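This chapter covered the MoE architecture in depth:

  1. Architecture principles: sparse activation, routing, capacity factor
  2. Routing strategies: Top-K, Expert Choice, Switch Transformer, GShard
  3. Training techniques: load-balancing loss, z-loss, capacity limits, gradient clipping
  4. Distributed training: expert parallelism, All-to-All communication
  5. Inference optimization: expert caching, quantization, speculative execution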

The next chapter covers Speculative Decoding and how it accelerates LLM inference.