MoE 架构与训练
概述
Mixture of Experts (MoE) 是一种稀疏激活的模型架构,通过路由机制选择性激活部分专家网络,在保持计算效率的同时大幅提升模型容量。本章深入讲解 MoE 的架构设计、训练策略和工程实现。
MoE 架构原理
基本结构
┌─────────────────────────────────────────────────────────────────────────────┐
│ Mixture of Experts (MoE) Layer │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Input Token │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ Router Network │ │
│ │ │ │
│ │ Input ──▶ Linear ──▶ Softmax ──▶ Top-K Selection ──▶ Weights │ │
│ │ │ │
│ │ Output: (expert_indices, routing_weights) │ │
│ └───────────────────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌───────────────────────────┼───────────────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐│
│ │ Expert 1 │ │ Expert 2 │ ... │ Expert N ││
│ │ │ │ │ │ ││
│ │ FFN Layer │ │ FFN Layer │ │ FFN Layer ││
│ │ (Hidden×4) │ │ (Hidden×4) │ │ (Hidden×4) ││
│ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘│
│ │ │ │ │
│ └───────────────────────────┼───────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────┐ │
│ │ Weighted Sum │ │
│ │ Σ(weight_i × │ │
│ │ expert_i) │ │
│ └─────────┬─────────┘ │
│ │ │
│ ▼ │
│ Output Token │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
MoE vs Dense Model Comparison:
┌─────────────────────────────────────────────────────────────────────────────┐
│ │
│ Dense Model (7B): │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ FFN: 4096 → 16384 → 4096 │ │
│ │ 每个 token 激活全部参数 │ │
│ │ 参数量: 7B, 计算量: 7B × tokens │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
│ MoE Model (8×7B with top-2): │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 8 个专家,每个 7B 参数 │ │
│ │ 每个 token 只激活 2 个专家 │ │
│ │ 总参数量: 56B, 计算量: 14B × tokens (相当于 2 个专家) │ │
│ │ → 参数增加 8x,计算只增加 2x │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
路由机制
"""
MoE 路由机制实现
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from dataclasses import dataclass
@dataclass
class MoEConfig:
    """Hyper-parameters shared by the MoE routers, experts and losses.

    Raises:
        ValueError: If ``num_experts`` is not positive, or if
            ``num_experts_per_tok`` is outside ``[1, num_experts]``.
    """

    # Width of the token representation entering and leaving each expert.
    hidden_size: int = 4096
    # Inner width of each expert feed-forward network.
    intermediate_size: int = 14336
    # Number of experts in one MoE layer.
    num_experts: int = 8
    # Number of experts selected for every token (the "k" of top-k).
    num_experts_per_tok: int = 2
    # Scale of the Gaussian noise added to router logits in training; 0 disables it.
    router_jitter_noise: float = 0.0
    # Weight of the load-balancing auxiliary loss.
    router_aux_loss_coef: float = 0.01
    # Weight of the router z-loss.
    router_z_loss_coef: float = 0.001

    def __post_init__(self) -> None:
        # Fail at construction time instead of deep inside torch.topk.
        if self.num_experts < 1:
            raise ValueError(
                f"num_experts must be >= 1, got {self.num_experts}"
            )
        if not 1 <= self.num_experts_per_tok <= self.num_experts:
            raise ValueError(
                "num_experts_per_tok must be in [1, num_experts], got "
                f"{self.num_experts_per_tok} with num_experts={self.num_experts}"
            )
class TopKRouter(nn.Module):
    """Token-choice router: every token picks its ``top_k`` best experts."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        # Bias-free linear gate producing one logit per expert.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Args:
            hidden_states: ``[..., hidden_size]``; typically
                ``[batch_size, seq_len, hidden_size]``, but a flattened
                ``[num_tokens, hidden_size]`` tensor is accepted as well.

        Returns:
            routing_weights: ``[..., top_k]`` softmax over the selected logits
                (each row sums to 1), in the dtype of the logits.
            selected_experts: ``[..., top_k]`` expert indices.
            router_logits: ``[..., num_experts]`` logits (after noise, if any).
        """
        router_logits = self.gate(hidden_states)

        # Optional exploration noise, applied in training only.
        if self.training and self.config.router_jitter_noise > 0:
            noise = torch.randn_like(router_logits) * self.config.router_jitter_noise
            router_logits = router_logits + noise

        top_logits, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)

        # Normalise over the selected experts only. The softmax runs in float32
        # for numerical stability under fp16/bf16 and is cast back afterwards.
        routing_weights = F.softmax(top_logits, dim=-1, dtype=torch.float32)
        routing_weights = routing_weights.to(top_logits.dtype)

        return routing_weights, selected_experts, router_logits
class ExpertChoiceRouter(nn.Module):
    """Expert-choice router: every expert picks a fixed number of tokens."""

    def __init__(self, config: MoEConfig, capacity_factor: float = 1.25):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Let each expert select its highest-affinity tokens.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]`` (any leading
                shape is accepted; tokens are flattened internally).

        Returns:
            expert_weights: ``[num_experts, capacity, 1]`` gate value of each
                selected token, i.e. its softmax probability over experts
                (float32).
            expert_indices: ``[num_experts, capacity]`` indices into the
                flattened token dimension.
            router_logits: ``[num_tokens, num_experts]``.
        """
        hidden_size = hidden_states.shape[-1]
        # reshape (not view) so non-contiguous inputs are handled too.
        hidden_flat = hidden_states.reshape(-1, hidden_size)
        num_tokens = hidden_flat.shape[0]

        # Tokens per expert. Clamped so that tiny batches still route at least
        # one token and large factors never request more tokens than exist.
        capacity = int(num_tokens * self.capacity_factor / self.config.num_experts)
        capacity = min(max(capacity, 1), num_tokens)

        router_logits = self.gate(hidden_flat)  # [N, E]

        # Token-to-expert affinity: softmax over the expert axis. Normalising
        # over an expert's selected tokens instead would make every weight
        # shrink as the capacity grows.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        # Transpose so each expert (row) picks its top-`capacity` tokens.
        expert_weights, expert_indices = torch.topk(
            router_probs.t(),  # [E, N]
            capacity,
            dim=-1
        )

        return expert_weights.unsqueeze(-1), expert_indices, router_logits
class LoadBalancingLoss(nn.Module):
    """Router regulariser: load-balancing auxiliary loss plus z-loss."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        router_logits: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Compute the weighted sum of the auxiliary loss and the z-loss.

        Both inputs are flattened over their leading dimensions, so
        ``[batch_size, seq_len, ...]`` and ``[num_tokens, ...]`` layouts give
        the same result.

        Args:
            router_logits: ``[..., num_experts]``.
            selected_experts: ``[..., top_k]`` integer expert indices.

        Returns:
            Scalar ``router_aux_loss_coef * aux + router_z_loss_coef * z``.
        """
        num_experts = self.config.num_experts

        logits = router_logits.reshape(-1, num_experts)  # [N, E]
        choices = selected_experts.reshape(-1, selected_experts.shape[-1])  # [N, K]

        # Number of times each token selected each expert (summed over top-k).
        expert_mask = F.one_hot(choices, num_experts).float().sum(dim=1)  # [N, E]

        router_probs = F.softmax(logits, dim=-1)  # [N, E]

        # f_i: fraction of tokens dispatched to expert i.
        # P_i: mean router probability assigned to expert i.
        tokens_per_expert = expert_mask.mean(dim=0)  # [E]
        router_prob_per_expert = router_probs.mean(dim=0)  # [E]
        aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts

        # z-loss: penalise large router logits.
        z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()

        return (
            self.config.router_aux_loss_coef * aux_loss
            + self.config.router_z_loss_coef * z_loss
        )
MoE 层实现
"""
MoE 层完整实现
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
class Expert(nn.Module):
    """A single expert: a gated feed-forward network (SwiGLU)."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
class SparseMoELayer(nn.Module):
    """Sparse MoE layer: a top-k router in front of a bank of expert FFNs."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        # Router
        self.router = TopKRouter(config)

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])

        # Load-balancing loss
        self.load_balancing_loss = LoadBalancingLoss(config)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the MoE layer.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            output: ``[batch_size, seq_len, hidden_size]``.
            aux_loss: router auxiliary loss in training mode, ``None`` otherwise.
        """
        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        aux_loss = None
        if self.training:
            aux_loss = self.load_balancing_loss(router_logits, selected_experts)

        final_output = self._compute_expert_outputs(
            hidden_states,
            routing_weights,
            selected_experts
        )

        return final_output, aux_loss

    def _compute_expert_outputs(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Weighted sum of expert outputs (reference implementation).

        Each expert runs once, on all tokens that selected it in any of their
        top-k slots, rather than once per slot.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.
            routing_weights: ``[batch_size, seq_len, top_k]``.
            selected_experts: ``[batch_size, seq_len, top_k]``.
        """
        final_output = torch.zeros_like(hidden_states)

        for expert_idx in range(self.num_experts):
            slot_mask = selected_experts == expert_idx  # [B, S, K]
            token_mask = slot_mask.any(dim=-1)  # [B, S]
            if not token_mask.any():
                continue

            # Total weight this expert receives from each token. Top-k indices
            # are distinct per token, so normally exactly one slot contributes.
            gate = (routing_weights * slot_mask).sum(dim=-1)  # [B, S]
            weights = gate[token_mask].unsqueeze(-1)  # [N, 1]

            expert_output = self.experts[expert_idx](hidden_states[token_mask])  # [N, H]
            final_output[token_mask] += weights * expert_output

        return final_output
class OptimizedSparseMoE(nn.Module):
    """Sparse MoE that groups tokens per expert and scatters results back."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        self.router = TopKRouter(config)
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])
        self.load_balancing_loss = LoadBalancingLoss(config)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the MoE layer.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            output: ``[batch_size, seq_len, hidden_size]``.
            aux_loss: router auxiliary loss in training mode, ``None`` otherwise.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        aux_loss = None
        if self.training:
            # The loss is given the un-flattened [B, S, *] tensors, which is
            # the layout LoadBalancingLoss documents for its inputs.
            aux_loss = self.load_balancing_loss(router_logits, selected_experts)

        # Flatten tokens for per-expert batching (reshape tolerates
        # non-contiguous inputs).
        final_output = self._batched_expert_forward(
            hidden_states.reshape(num_tokens, hidden_size),
            routing_weights.reshape(num_tokens, self.top_k),
            selected_experts.reshape(num_tokens, self.top_k)
        )

        return final_output.view(batch_size, seq_len, hidden_size), aux_loss

    def _batched_expert_forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Run each expert once on all of its (token, slot) assignments.

        Args:
            hidden_states: ``[num_tokens, hidden_size]``.
            routing_weights: ``[num_tokens, top_k]``.
            selected_experts: ``[num_tokens, top_k]``.

        Returns:
            ``[num_tokens, hidden_size]`` weighted sum of expert outputs.
        """
        num_tokens, hidden_size = hidden_states.shape
        device = hidden_states.device
        dtype = hidden_states.dtype

        final_output = torch.zeros(num_tokens, hidden_size, device=device, dtype=dtype)

        flat_selected = selected_experts.reshape(-1)  # [N * K]
        flat_weights = routing_weights.reshape(-1, 1)  # [N * K, 1]

        # Source token of every (token, slot) assignment: 0,0,..,1,1,..
        token_indices = torch.arange(num_tokens, device=device)
        token_indices = token_indices.unsqueeze(1).expand(-1, self.top_k).reshape(-1)

        for expert_idx in range(self.num_experts):
            assigned = flat_selected == expert_idx
            if not assigned.any():
                continue

            expert_token_indices = token_indices[assigned]
            expert_output = self.experts[expert_idx](hidden_states[expert_token_indices])

            # index_add_ requires source and destination dtypes to match, so
            # cast in case the routing weights are kept in higher precision.
            weighted_output = (flat_weights[assigned] * expert_output).to(dtype)
            final_output.index_add_(0, expert_token_indices, weighted_output)

        return final_output
训练策略
分布式训练
"""
MoE 分布式训练
"""
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
class ExpertParallelMoE(nn.Module):
    """MoE layer whose experts are sharded across an expert-parallel group.

    Every rank holds ``num_experts // expert_parallel_size`` experts and a
    full copy of the router. Ranks are grouped into consecutive blocks of
    ``expert_parallel_size`` ranks; each block holds one complete expert set.

    The dispatch/compute/combine steps are a teaching simplification: only
    per-rank token counts are exchanged, hidden states are not actually moved
    between ranks, and every local expert processes every token.

    ``torch.distributed`` must be initialised before construction.

    Raises:
        ValueError: If ``expert_parallel_size`` does not evenly divide both
            ``config.num_experts`` and the world size.
    """

    def __init__(
        self,
        config: MoEConfig,
        expert_parallel_size: int = 1
    ):
        super().__init__()
        if expert_parallel_size < 1 or config.num_experts % expert_parallel_size != 0:
            raise ValueError(
                f"expert_parallel_size={expert_parallel_size} must be a positive "
                f"divisor of num_experts={config.num_experts}"
            )
        world_size = dist.get_world_size()
        if world_size % expert_parallel_size != 0:
            raise ValueError(
                f"world size {world_size} is not divisible by "
                f"expert_parallel_size={expert_parallel_size}"
            )

        self.config = config
        self.ep_size = expert_parallel_size

        # Global rank, and position inside this rank's expert-parallel group.
        self.rank = dist.get_rank()
        self.ep_rank = self.rank % self.ep_size

        # Contiguous range of global expert ids owned by this rank.
        self.num_local_experts = config.num_experts // expert_parallel_size
        expert_start = self.ep_rank * self.num_local_experts
        expert_end = expert_start + self.num_local_experts
        self.local_expert_indices = list(range(expert_start, expert_end))

        # Only the local experts are instantiated.
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(self.num_local_experts)
        ])

        # The router is replicated on every rank.
        self.router = TopKRouter(config)

        # Used by forward() in training mode.
        self.load_balancing_loss = LoadBalancingLoss(config)

        self._setup_expert_parallel_group()

    def _setup_expert_parallel_group(self):
        """Create the communication group this rank exchanges tokens with.

        ``dist.new_group`` is a collective call, so every rank creates every
        group and keeps only the one it belongs to.
        """
        world_size = dist.get_world_size()
        for start in range(0, world_size, self.ep_size):
            ranks = list(range(start, start + self.ep_size))
            group = dist.new_group(ranks)
            if self.rank in ranks:
                self.ep_group = group

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Route, dispatch, run local experts and combine.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            output: ``[batch_size, seq_len, hidden_size]``.
            aux_loss: router auxiliary loss in training mode, ``None`` otherwise.
        """
        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        # All-to-all: send tokens to the ranks that own their experts.
        dispatched_input, dispatch_info = self._all_to_all_dispatch(
            hidden_states,
            selected_experts,
            routing_weights
        )

        local_outputs = self._local_expert_compute(dispatched_input)

        # All-to-all: bring the results back.
        combined_output = self._all_to_all_combine(
            local_outputs,
            dispatch_info
        )

        aux_loss = None
        if self.training:
            aux_loss = self.load_balancing_loss(router_logits, selected_experts)

        return combined_output, aux_loss

    def _all_to_all_dispatch(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor
    ):
        """Exchange per-rank token counts and build the dispatch metadata.

        Returns:
            dispatched_input: ``[num_tokens, hidden_size]`` (simplified: the
                local tokens, unchanged).
            dispatch_info: dict with ``send_counts``, ``recv_counts``,
                ``weights``, ``experts`` and ``original_shape``.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        top_k = self.config.num_experts_per_tok

        hidden_flat = hidden_states.reshape(-1, hidden_size)
        selected_flat = selected_experts.reshape(-1, top_k)
        weights_flat = routing_weights.reshape(-1, top_k)

        # Destination rank of every (token, slot) assignment: experts are laid
        # out contiguously, num_local_experts per rank.
        dest_rank = torch.div(
            selected_flat, self.num_local_experts, rounding_mode="floor"
        ).reshape(-1)
        # Count assignments per destination rank via a one-hot sum.
        rank_ids = torch.arange(self.ep_size, device=hidden_states.device)
        send_counts = (dest_rank.unsqueeze(-1) == rank_ids).sum(dim=0)

        # Exchange counts so every rank learns how many tokens it will receive.
        recv_counts = torch.zeros_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts, group=self.ep_group)

        # Simplified: a full implementation would sort tokens by destination
        # rank and exchange them with a second all-to-all.
        dispatched_input = hidden_flat

        dispatch_info = {
            "send_counts": send_counts,
            "recv_counts": recv_counts,
            "weights": weights_flat,
            "experts": selected_flat,
            "original_shape": (batch_size, seq_len, hidden_size)
        }

        return dispatched_input, dispatch_info

    def _local_expert_compute(self, inputs: torch.Tensor) -> torch.Tensor:
        """Average the outputs of all local experts (simplified).

        Every input passes through every local expert; routing weights are
        not applied here.
        """
        outputs = torch.zeros_like(inputs)
        for expert in self.experts:
            outputs += expert(inputs)
        return outputs / len(self.experts)

    def _all_to_all_combine(
        self,
        local_outputs: torch.Tensor,
        dispatch_info: dict
    ) -> torch.Tensor:
        """Restore the original ``[batch, seq, hidden]`` shape (simplified)."""
        batch_size, seq_len, hidden_size = dispatch_info["original_shape"]
        return local_outputs.view(batch_size, seq_len, hidden_size)
class MoETrainer:
    """Minimal training loop helper for MoE models.

    Expected ``config`` keys (both optional):
        aux_loss_weight: multiplier for the model's auxiliary loss (default 0.01).
        max_grad_norm: gradient clipping threshold (default 1.0).
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        config: dict
    ):
        self.model = model
        self.optimizer = optimizer
        self.config = config

        # Auxiliary loss weight
        self.aux_loss_weight = config.get("aux_loss_weight", 0.01)

        # Gradient clipping threshold
        self.max_grad_norm = config.get("max_grad_norm", 1.0)

        # One per-expert mean-probability vector per monitored step.
        # NOTE: grows without bound over a long run.
        self.expert_usage_history = []

    def train_step(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """Run one optimisation step.

        The model is called as ``model(input_ids, labels=labels)`` and must
        return an object with a ``loss`` attribute; ``aux_loss`` and
        ``router_logits`` attributes are used when present.

        Returns:
            dict with float entries ``loss``, ``main_loss`` and ``aux_loss``.
        """
        self.model.train()
        self.optimizer.zero_grad()

        outputs = self.model(input_ids, labels=labels)

        main_loss = outputs.loss
        # Treat both a missing attribute and an explicit None as "no aux loss".
        aux_loss = getattr(outputs, "aux_loss", None)
        if aux_loss is None:
            aux_loss = 0

        total_loss = main_loss + self.aux_loss_weight * aux_loss

        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.optimizer.step()

        router_logits = getattr(outputs, "router_logits", None)
        if router_logits is not None:
            self._monitor_expert_usage(router_logits)

        return {
            "loss": total_loss.item(),
            "main_loss": main_loss.item(),
            "aux_loss": aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
        }

    def _monitor_expert_usage(self, router_logits: torch.Tensor):
        """Record mean router probabilities and report them every 100 steps.

        Args:
            router_logits: ``[batch_size, seq_len, num_experts]``.
        """
        # Imported here so the class does not rely on a module-level numpy import.
        import numpy as np

        with torch.no_grad():
            # float() first: numpy has no bfloat16, and fp16 means are lossy.
            probs = F.softmax(router_logits.float(), dim=-1)
            usage = probs.mean(dim=[0, 1])  # mean probability per expert

        self.expert_usage_history.append(usage.cpu().numpy())

        # Periodically check for overloaded or idle experts.
        if len(self.expert_usage_history) % 100 == 0:
            avg_usage = np.mean(self.expert_usage_history[-100:], axis=0)
            print(f"Expert usage: min={avg_usage.min():.3f}, max={avg_usage.max():.3f}, "
                  f"std={avg_usage.std():.3f}")
容量因子与负载均衡
"""
容量因子与高级负载均衡策略
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CapacityAwareMoE(nn.Module):
    """MoE layer that enforces a per-expert token capacity."""

    def __init__(
        self,
        config: MoEConfig,
        capacity_factor: float = 1.25,
        drop_tokens: bool = True
    ):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.drop_tokens = drop_tokens

        self.router = TopKRouter(config)
        self.experts = nn.ModuleList([
            Expert(config) for _ in range(config.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor):
        """Run the layer with capacity-limited routing.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            output: ``[batch_size, seq_len, hidden_size]``.
            router_logits: as returned by the router.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Assignments each expert may accept. At least 1, so that very small
        # batches (e.g. single-token decoding) are not dropped wholesale.
        capacity = max(1, int(
            self.capacity_factor * num_tokens /
            self.config.num_experts *
            self.config.num_experts_per_tok
        ))

        routing_weights, selected_experts, router_logits = self.router(hidden_states)

        routing_weights, selected_experts, dropped_mask = self._apply_capacity(
            routing_weights.reshape(num_tokens, -1),
            selected_experts.reshape(num_tokens, -1),
            capacity
        )

        hidden_flat = hidden_states.reshape(num_tokens, hidden_size)
        output = self._compute_output(hidden_flat, routing_weights, selected_experts)

        # Tokens flagged as dropped pass their input through unchanged.
        # NOTE(review): the flag means "first choice overflowed", so a token
        # whose second choice was kept is still overwritten here — confirm
        # this is intended.
        if self.drop_tokens and dropped_mask.any():
            output[dropped_mask] = hidden_flat[dropped_mask]

        return output.view(batch_size, seq_len, hidden_size), router_logits

    def _compute_output(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """Weighted sum of expert outputs for flattened tokens.

        Args:
            hidden_states: ``[num_tokens, hidden_size]``.
            routing_weights: ``[num_tokens, top_k]``; zero marks a dropped slot.
            selected_experts: ``[num_tokens, top_k]``.
        """
        output = torch.zeros_like(hidden_states)
        for expert_idx, expert in enumerate(self.experts):
            slot_mask = selected_experts == expert_idx  # [N, K]
            gate = (routing_weights * slot_mask).sum(dim=-1)  # [N]
            # Skip tokens that did not pick this expert or were dropped by it.
            token_mask = gate != 0
            if not token_mask.any():
                continue
            weighted = gate[token_mask].unsqueeze(-1) * expert(hidden_states[token_mask])
            output[token_mask] += weighted.to(output.dtype)
        return output

    def _apply_capacity(
        self,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor,
        capacity: int
    ):
        """Zero out assignments that exceed an expert's capacity.

        Slots are filled in priority order: all first choices in token order,
        then all second choices, and so on. An assignment is kept while its
        expert has accepted fewer than ``capacity`` assignments.

        Args:
            routing_weights: ``[num_tokens, top_k]``.
            selected_experts: ``[num_tokens, top_k]``.
            capacity: maximum assignments per expert.

        Returns:
            new_routing_weights: weights with overflowing slots zeroed and the
                rest renormalised per token (all-zero rows stay zero).
            selected_experts: unchanged.
            dropped_mask: ``[num_tokens]`` bool, True where the token's first
                choice overflowed.
        """
        num_tokens, top_k = selected_experts.shape
        device = selected_experts.device
        num_experts = self.config.num_experts

        # Assignments accepted so far by each expert.
        expert_counts = torch.zeros(num_experts, dtype=torch.long, device=device)
        keep = torch.zeros(num_tokens, top_k, dtype=torch.bool, device=device)

        for k in range(top_k):
            choice = selected_experts[:, k]  # [N]
            one_hot = F.one_hot(choice, num_experts)  # [N, E], long
            # Queue position of each token in its chosen expert: arrival order
            # within this slot, offset by what earlier slots already filled.
            queue_pos = torch.cumsum(one_hot, dim=0) - 1 + expert_counts  # [N, E]
            token_pos = queue_pos.gather(1, choice.unsqueeze(-1)).squeeze(-1)  # [N]
            keep_k = token_pos < capacity
            keep[:, k] = keep_k
            expert_counts = expert_counts + (one_hot * keep_k.unsqueeze(-1)).sum(dim=0)

        new_routing_weights = routing_weights * keep.to(routing_weights.dtype)
        dropped_mask = ~keep[:, 0]

        # Renormalise the surviving weights; the clamp avoids 0/0.
        weight_sum = new_routing_weights.sum(dim=-1, keepdim=True)
        weight_sum = torch.clamp(weight_sum, min=1e-6)
        new_routing_weights = new_routing_weights / weight_sum

        return new_routing_weights, selected_experts, dropped_mask
class SwitchTransformerRouter(nn.Module):
    """Switch Transformer router: top-1 routing with a per-expert capacity."""

    def __init__(self, config: MoEConfig, capacity_factor: float = 1.0):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """Route each token to its single best expert.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            routing_weights: ``[num_tokens]`` softmax probability of the chosen
                expert, or 0 for tokens beyond the expert's capacity.
            selected_experts: ``[num_tokens]`` chosen expert index.
            router_logits: ``[num_tokens, num_experts]``.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len
        num_experts = self.config.num_experts

        hidden_flat = hidden_states.reshape(-1, hidden_size)
        router_logits = self.gate(hidden_flat)  # [N, E]

        # Top-1 selection; the weight is the full-softmax probability of the
        # winning expert.
        selected_experts = router_logits.argmax(dim=-1)
        routing_weights = F.softmax(router_logits, dim=-1).gather(
            1, selected_experts.unsqueeze(-1)
        ).squeeze(-1)

        # At least 1, so batches with fewer tokens than experts (e.g.
        # single-token decoding) are not dropped entirely.
        capacity = max(1, int(self.capacity_factor * num_tokens / num_experts))

        # Arrival order of each token within its expert (0-based): a running
        # count over the one-hot assignment, read off at the chosen column.
        expert_mask = F.one_hot(selected_experts, num_experts)  # [N, E], long
        position_in_expert = (torch.cumsum(expert_mask, dim=0) * expert_mask).sum(dim=-1) - 1

        # Tokens past the capacity are dropped by zeroing their weight.
        within_capacity = position_in_expert < capacity
        routing_weights = routing_weights * within_capacity.to(routing_weights.dtype)

        return routing_weights, selected_experts, router_logits
class GShard_Router(nn.Module):
    """GShard-style top-2 router with random dispatch of the second expert."""

    def __init__(self, config: MoEConfig, capacity_factor: float = 2.0):
        super().__init__()
        self.config = config
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """Select two experts per token.

        In training mode the second expert is kept only when its weight
        exceeds half of a uniform random draw, and the weights are then
        renormalised. In eval mode both experts are always kept.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.

        Returns:
            routing_weights: ``[num_tokens, 2]``, rows summing to 1.
            selected_experts: ``[num_tokens, 2]``.
            router_logits: ``[num_tokens, num_experts]``.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_tokens = batch_size * seq_len

        hidden_flat = hidden_states.reshape(-1, hidden_size)
        router_logits = self.gate(hidden_flat)

        # Top-2 selection, normalised over the two selected experts.
        top_logits, selected_experts = torch.topk(router_logits, 2, dim=-1)
        routing_weights = F.softmax(top_logits, dim=-1)

        if self.training:
            random_threshold = torch.rand(num_tokens, 1, device=hidden_states.device)
            first = routing_weights[:, :1]
            second = routing_weights[:, 1:]
            keep_second = (second > random_threshold * 0.5).to(second.dtype)
            # Build a new tensor rather than writing into the softmax output:
            # autograd needs that output unmodified for the backward pass.
            routing_weights = torch.cat([first, second * keep_second], dim=-1)
            # Renormalise after possibly dropping the second expert.
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-6)

        return routing_weights, selected_experts, router_logits
推理优化
"""
MoE 推理优化
"""
import torch
from typing import List, Dict, Optional
class MoEInferenceOptimizer:
    """Prepares an MoE model for inference.

    Recognised ``config`` keys (both default to False):
        quantize_experts: dynamically quantize the experts' linear layers.
        cache_experts: snapshot the parameters of frequently used experts.
    """

    def __init__(self, model: nn.Module, config: Dict):
        self.model = model
        self.config = config

        # expert index -> {parameter name: detached copy of the parameter}
        self.expert_cache: Dict[int, Dict[str, torch.Tensor]] = {}

        # expert index -> number of times the expert was used
        self.expert_usage_stats: Dict[int, int] = {}

    def optimize_for_inference(self):
        """Switch the model to eval mode and apply the configured steps."""
        # eval() already clears the training flag on every submodule.
        self.model.eval()

        # 1. Quantize experts (optional)
        if self.config.get("quantize_experts", False):
            self._quantize_experts()

        # 2. Warm up the expert cache (optional)
        if self.config.get("cache_experts", False):
            self._warmup_expert_cache()

    def _quantize_experts(self):
        """Dynamically quantize the ``nn.Linear`` layers of every expert to int8."""
        # Collect first: quantization replaces submodules, which must not
        # happen while the module tree is being iterated.
        experts = [m for m in self.model.modules() if isinstance(m, Expert)]
        for expert in experts:
            # inplace=True swaps the linears inside the live expert; without
            # it quantize_dynamic returns a copy and the model is unchanged.
            torch.quantization.quantize_dynamic(
                expert,
                {nn.Linear},
                dtype=torch.qint8,
                inplace=True
            )

    def _warmup_expert_cache(self):
        """Snapshot the parameters of the hot experts into ``expert_cache``.

        The cache is keyed by expert index only, so with several MoE layers
        the entry of the last visited layer wins.
        """
        hot_experts = self._identify_hot_experts()
        if not hot_experts:
            return

        for module in self.model.modules():
            if not isinstance(module, SparseMoELayer):
                continue
            for expert_idx in hot_experts:
                expert = module.experts[expert_idx]
                # detach(): the snapshot must not stay attached to autograd.
                self.expert_cache[expert_idx] = {
                    name: param.detach().clone()
                    for name, param in expert.named_parameters()
                }

    def _identify_hot_experts(self, threshold: float = 0.1) -> List[int]:
        """Return experts whose share of total usage exceeds ``threshold``.

        Returns an empty list when no usage has been recorded.
        """
        total_usage = sum(self.expert_usage_stats.values())
        if total_usage <= 0:
            return []

        return [
            idx for idx, count in self.expert_usage_stats.items()
            if count / total_usage > threshold
        ]
class SpeculativeExpertExecution:
    """Runs the routed expert plus experts speculated from co-occurrence.

    Args:
        model: model containing a ``TopKRouter`` and a ``SparseMoELayer``.
        speculation_threshold: minimum co-occurrence probability for an
            expert to be speculated.
        num_experts: size of the square expert co-occurrence matrix.
    """

    def __init__(
        self,
        model: nn.Module,
        speculation_threshold: float = 0.8,
        num_experts: int = 8
    ):
        self.model = model
        self.speculation_threshold = speculation_threshold

        # expert_cooccurrence[i, j]: how often expert j followed expert i.
        self.expert_cooccurrence = torch.zeros(num_experts, num_experts)

    def speculative_forward(
        self,
        hidden_states: torch.Tensor,
        prev_expert: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the top-1 expert and, alongside it, any speculated experts.

        Returns:
            outputs: output of the main (top-1) expert.
            main_expert: top-1 expert index per token.
        """
        router_logits = self._get_router_logits(hidden_states)
        probs = F.softmax(router_logits, dim=-1)

        main_expert = probs.argmax(dim=-1)

        if prev_expert is not None:
            speculated_experts = self._speculate_next_experts(prev_expert)
        else:
            speculated_experts = []

        outputs = self._parallel_expert_execute(
            hidden_states,
            main_expert,
            speculated_experts
        )

        return outputs, main_expert

    def _speculate_next_experts(self, prev_expert: int) -> List[int]:
        """Return experts likely to follow ``prev_expert``.

        An expert is returned when its share of ``prev_expert``'s recorded
        co-occurrences exceeds ``speculation_threshold``. With no recorded
        co-occurrences the result is empty.
        """
        cooccur = self.expert_cooccurrence[prev_expert]
        total = cooccur.sum()
        if total <= 0:
            # No statistics yet; avoid dividing by zero.
            return []

        cooccur_prob = cooccur / total
        speculated = (cooccur_prob > self.speculation_threshold).nonzero().squeeze(-1)
        return speculated.tolist()

    def _parallel_expert_execute(
        self,
        hidden_states: torch.Tensor,
        main_expert: torch.Tensor,
        speculated_experts: List[int]
    ) -> torch.Tensor:
        """Run the main expert and the speculated experts concurrently.

        On CUDA each expert runs on its own stream. On CPU only the main
        expert runs, since speculative results are not consumed here.

        Returns:
            Output of the main expert.
        """
        if not hidden_states.is_cuda:
            return self._execute_expert(hidden_states, main_expert)

        device = hidden_states.device
        producer = torch.cuda.current_stream(device)
        streams = [
            torch.cuda.Stream(device=device)
            for _ in range(len(speculated_experts) + 1)
        ]
        # Side streams must not read hidden_states before the stream that
        # produced it has finished writing.
        for stream in streams:
            stream.wait_stream(producer)

        results = {}

        with torch.cuda.stream(streams[0]):
            results["main"] = self._execute_expert(hidden_states, main_expert)

        for stream, expert_idx in zip(streams[1:], speculated_experts):
            with torch.cuda.stream(stream):
                results[expert_idx] = self._execute_expert(
                    hidden_states,
                    torch.tensor([expert_idx])
                )

        # Wait for all streams before handing the result back.
        torch.cuda.synchronize()

        return results["main"]

    def _execute_expert(
        self,
        hidden_states: torch.Tensor,
        expert_idx: torch.Tensor
    ) -> torch.Tensor:
        """Run expert(s) from the first ``SparseMoELayer`` found (simplified).

        Args:
            hidden_states: ``[..., hidden_size]``.
            expert_idx: a single index (applied to all tokens) or one index
                per token.

        Returns:
            Tensor shaped like ``hidden_states``; the input itself when the
            model has no ``SparseMoELayer``.

        Raises:
            ValueError: If ``expert_idx`` has neither one element nor one
                element per token.
        """
        for module in self.model.modules():
            if not isinstance(module, SparseMoELayer):
                continue

            flat_idx = expert_idx.reshape(-1)
            if flat_idx.numel() == 1:
                return module.experts[int(flat_idx.item())](hidden_states)

            flat_hidden = hidden_states.reshape(-1, hidden_states.shape[-1])
            if flat_idx.numel() != flat_hidden.shape[0]:
                raise ValueError(
                    f"expected 1 or {flat_hidden.shape[0]} expert indices, "
                    f"got {flat_idx.numel()}"
                )
            flat_idx = flat_idx.to(flat_hidden.device)

            # Per-token routing: run every distinct expert on its own tokens.
            output = torch.zeros_like(flat_hidden)
            for idx in flat_idx.unique().tolist():
                token_mask = flat_idx == idx
                output[token_mask] = module.experts[idx](flat_hidden[token_mask])
            return output.view_as(hidden_states)

        return hidden_states

    def _get_router_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the gate logits of the first ``TopKRouter`` in the model.

        Raises:
            ValueError: If the model contains no ``TopKRouter``.
        """
        for module in self.model.modules():
            if isinstance(module, TopKRouter):
                return module.gate(hidden_states)
        raise ValueError("No router found")
小结
本章深入讲解了 MoE 架构:
- 架构原理:稀疏激活、路由机制、容量因子
- 路由策略:Top-K、Expert Choice、Switch Transformer(Top-1)、GShard(Top-2 随机分发)
- 训练技巧:负载均衡损失(辅助损失与 Z-loss)、容量限制、梯度裁剪
- 分布式训练:专家并行、All-to-All 通信
- 推理优化:专家缓存、量化、推测执行
下一章我们将探讨 Speculative Decoding,讲解如何通过推测解码加速 LLM 推理。