03-Megatron-LM源码解析
概述
Megatron-LM是NVIDIA开发的大规模语言模型训练框架,提出了面向Transformer的张量并行(Tensor Parallel),并与流水线并行、数据并行组合成高效的3D并行策略。本章深入解析Megatron-LM的核心实现,包括并行状态管理(MPU)、张量并行、序列并行以及并行Transformer层。
Megatron-LM 架构
整体架构
┌─────────────────────────────────────────────────────────────────────────┐
│ Megatron-LM 架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 用户接口层 │ │
│ │ │ │
│ │ pretrain_gpt.py arguments.py training.py │ │
│ │ (GPT预训练入口) (参数解析) (训练循环) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 模型定义层 │ │
│ │ │ │
│ │ ┌────────────────────────────────────────────────────────────┐ │ │
│ │ │ GPTModel / BertModel / T5Model │ │ │
│ │ │ ├─ ParallelTransformer │ │ │
│ │ │ │ ├─ ParallelTransformerLayer │ │ │
│ │ │ │ │ ├─ ParallelAttention │ │ │
│ │ │ │ │ └─ ParallelMLP │ │ │
│ │ │ │ └─ sequence parallelism hooks │ │ │
│ │ │ ├─ Embedding │ │ │
│ │ │ └─ Output layer │ │ │
│ │ └────────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 并行原语层 │ │
│ │ │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │ │
│ │ │ Tensor │ │ Pipeline │ │ Sequence │ │ Data │ │ │
│ │ │ Parallel │ │ Parallel │ │ Parallel │ │ Parallel │ │ │
│ │ │ (TP) │ │ (PP) │ │ (SP) │ │ (DP) │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ └───────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ MPU (Model Parallel Unit) │ │
│ │ │ │
│ │ parallel_state.py: │ │
│ │ • tensor_model_parallel_group │ │
│ │ • pipeline_model_parallel_group │ │
│ │ • data_parallel_group │ │
│ │ • get_tensor_model_parallel_rank/world_size │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
源码目录结构
megatron/
├── arguments.py # 参数解析
├── training.py # 训练循环
├── initialize.py # 初始化
├── global_vars.py # 全局变量
│
├── core/
│ ├── parallel_state.py # 并行状态管理 (MPU)
│   ├── tensor_parallel/
│   │   ├── layers.py             # TP层实现 (含Sequence Parallel支持)
│   │   ├── mappings.py           # 通信原语 (TP与Sequence Parallel)
│   │   └── utils.py
│   └── pipeline_parallel/
│       ├── schedules.py          # PP调度
│       └── p2p_communication.py
│
├── model/
│ ├── gpt_model.py # GPT模型
│ ├── language_model.py # 语言模型基类
│ ├── transformer.py # Transformer实现
│ └── module.py # 模块基类
│
├── data/
│ ├── gpt_dataset.py # 数据集
│ └── data_samplers.py # 采样器
│
└── optimizer/
├── optimizer.py # 优化器
└── grad_scaler.py # 梯度缩放
Model Parallel Unit (MPU)
并行组初始化
# megatron/core/parallel_state.py (简化示意, 省略embedding组等细节)
from typing import Optional

import torch

# 全局并行组变量
_TENSOR_MODEL_PARALLEL_GROUP = None
_PIPELINE_MODEL_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP = None
_EMBEDDING_GROUP = None
_POSITION_EMBEDDING_GROUP = None
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
):
"""
初始化模型并行组
假设world_size = 16, TP=2, PP=4, 则DP=2
GPU布局:
┌────────────────────────────────────────────────────────────┐
│ │
│ DP组0: DP组1: │
│ ┌──────────────────────┐ ┌──────────────────────────┐ │
│ │ PP Stage 0 │ │ PP Stage 0 │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │GPU 0│ │GPU 1│ TP组│ │ │GPU 8│ │GPU 9│ TP组 │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ │ │ │ │ │
│ │ PP Stage 1 │ │ PP Stage 1 │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │GPU 2│ │GPU 3│ │ │ │GPU10│ │GPU11│ │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ │ │ │ │ │
│ │ PP Stage 2 │ │ PP Stage 2 │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │GPU 4│ │GPU 5│ │ │ │GPU12│ │GPU13│ │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ │ │ │ │ │
│ │ PP Stage 3 │ │ PP Stage 3 │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │GPU 6│ │GPU 7│ │ │ │GPU14│ │GPU15│ │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ └──────────────────────┘ └──────────────────────────┘ │
│ │
└────────────────────────────────────────────────────────────┘
"""
global _TENSOR_MODEL_PARALLEL_GROUP
global _PIPELINE_MODEL_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP
# 计算数据并行大小
world_size = torch.distributed.get_world_size()
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank()
# ==== 创建Tensor Parallel组 ====
# 同一个TP组内的GPU需要AllReduce
# 例如: [0,1], [2,3], [4,5], ... 每组2个GPU
for i in range(num_tensor_model_parallel_groups):
start_rank = i * tensor_model_parallel_size
end_rank = start_rank + tensor_model_parallel_size
ranks = list(range(start_rank, end_rank))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
# ==== 创建Pipeline Parallel组 ====
# 同一个PP组内的GPU需要P2P通信
# 例如: [0,2,4,6], [1,3,5,7], [8,10,12,14], [9,11,13,15]
    for i in range(num_pipeline_model_parallel_groups):
        # 同一个PP组: DP索引与TP索引固定, 只有pipeline stage变化
        # (这里按上图的布局 rank = dp*(PP*TP) + stage*TP + tp 给出简化实现,
        #  上游源码中实际的rank排布细节略有不同)
        dp_index = i // tensor_model_parallel_size
        tp_index = i % tensor_model_parallel_size
        ranks = []
        for j in range(pipeline_model_parallel_size):
            rank_in_pipeline = dp_index * (pipeline_model_parallel_size *
                                           tensor_model_parallel_size) \
                               + j * tensor_model_parallel_size + tp_index
            ranks.append(rank_in_pipeline)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
# ==== 创建Data Parallel组 ====
# 同一个DP组内的GPU需要AllReduce梯度
# 例如: [0,8], [1,9], [2,10], ...
for i in range(num_data_parallel_groups):
ranks = []
for j in range(data_parallel_size):
ranks.append(i + j * num_data_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# 辅助函数
def get_tensor_model_parallel_group():
return _TENSOR_MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_rank():
return torch.distributed.get_rank(group=_TENSOR_MODEL_PARALLEL_GROUP)
def get_tensor_model_parallel_world_size():
return torch.distributed.get_world_size(group=_TENSOR_MODEL_PARALLEL_GROUP)
def get_pipeline_model_parallel_group():
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=_PIPELINE_MODEL_PARALLEL_GROUP)
def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
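下面给出一个使用上述简化版MPU接口的最小示例(示意性质: 假设用 torchrun --nproc_per_node=16 启动, TP=2、PP=4; 真实的Megatron-LM入口由initialize_megatron统一完成这些初始化):

# 最小使用示意 (假设上文的简化实现可直接import)
import os
import torch

def main():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.distributed.init_process_group(backend="nccl")

    # world_size=16, TP=2, PP=4 => DP=2
    initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
    )

    # 每个进程可以查询自己在各并行组中的位置
    print(f"rank {torch.distributed.get_rank()}: "
          f"tp_rank={get_tensor_model_parallel_rank()}, "
          f"pp_rank={get_pipeline_model_parallel_rank()}, "
          f"dp_size={torch.distributed.get_world_size(group=get_data_parallel_group())}")

if __name__ == "__main__":
    main()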
Tensor Parallel 实现
ColumnParallelLinear
# megatron/core/tensor_parallel/layers.py (简化示意)
# 依赖: 上文parallel_state.py中的get_*接口, 以及下文mappings.py中的通信包装函数
import torch
import torch.nn.functional as F
from torch.nn import Parameter, init


def divide(numerator, denominator):
    """对应megatron.core.utils.divide: 检查整除并返回商"""
    assert numerator % denominator == 0
    return numerator // denominator

class ColumnParallelLinear(torch.nn.Module):
"""
列并行线性层
将权重按列切分到不同GPU:
Y = XA, 其中A按列切分: A = [A_1, A_2, ..., A_n]
Y_i = X @ A_i
┌─────────────────────────────────────────────────────────┐
│ │
│ X (input) │
│ ┌───────────────┐ │
│ │ │ │
│ │ [B, S, H] │ 广播到所有TP rank │
│ │ │ │
│ └───────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Weight A │ │
│ │ │ │
│ │ ┌─────────┬─────────┬─────────┬─────────┐ │ │
│ │ │ A_0 │ A_1 │ A_2 │ A_3 │ │ │
│ │ │ [H,H/4] │ [H,H/4] │ [H,H/4] │ [H,H/4] │ │ │
│ │ │ GPU 0 │ GPU 1 │ GPU 2 │ GPU 3 │ │ │
│ │ └─────────┴─────────┴─────────┴─────────┘ │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Output Y │ │
│ │ │ │
│ │ ┌─────────┬─────────┬─────────┬─────────┐ │ │
│ │ │ Y_0 │ Y_1 │ Y_2 │ Y_3 │ │ │
│ │ │[B,S,H/4]│[B,S,H/4]│[B,S,H/4]│[B,S,H/4]│ │ │
│ │ │ GPU 0 │ GPU 1 │ GPU 2 │ GPU 3 │ │ │
│ │ └─────────┴─────────┴─────────┴─────────┘ │ │
│ │ │ │
│ │ 后续通常是AllGather (如果需要完整输出) │ │
│ │ 或直接输入到RowParallelLinear │ │
│ └───────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = True,
init_method=init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
async_tensor_model_parallel_allreduce: bool = True,
sequence_parallel_enabled: bool = False,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.sequence_parallel = sequence_parallel_enabled
# 获取TP配置
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
# 初始化权重 (只初始化本partition)
self.weight = Parameter(torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=torch.float32
))
init_method(self.weight)
        # Bias: 每个rank只持有本partition对应的bias分片
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=torch.float32
))
# 初始化为0
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = async_tensor_model_parallel_allreduce
def forward(self, input_):
"""
Forward pass
如果sequence_parallel: 输入是[S/TP, B, H], 需要先AllGather
否则: 输入是[S, B, H], 直接计算
"""
        if self.sequence_parallel:
            # Sequence Parallel: 沿序列维AllGather输入
            # (其backward是ReduceScatter, 已完成对输入梯度的归约)
            input_parallel = gather_from_sequence_parallel_region(input_)
        else:
            # 标准TP: forward恒等复制, backward对输入梯度做AllReduce
            # (真实实现中async_tensor_model_parallel_allreduce会把该AllReduce
            #  与权重梯度计算重叠, 这里省略该优化)
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # 线性计算: Y = X @ W^T + b
        # skip_bias_add时不在这里加bias, 由调用方负责(便于与dropout/残差融合)
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# AllGather输出
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class RowParallelLinear(torch.nn.Module):
"""
行并行线性层
将权重按行切分到不同GPU:
Y = XA, 其中X按列切分: X = [X_1, X_2, ..., X_n]
A按行切分: A = [A_1; A_2; ...; A_n]
Y = sum(X_i @ A_i)
通常接在ColumnParallelLinear之后
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = False,
init_method=init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
sequence_parallel_enabled: bool = False,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
self.sequence_parallel = sequence_parallel_enabled
# 获取TP配置
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
# 初始化权重 (只初始化本partition)
self.weight = Parameter(torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=torch.float32
))
init_method(self.weight)
# Bias (完整bias, 但只在AllReduce后加一次)
if bias:
self.bias = Parameter(torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=torch.float32
))
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
"""
Forward pass
        输入: [S, B, H/TP] (来自ColumnParallelLinear, 已按最后一维分片)
        输出: [S, B, H] (AllReduce求和; sequence_parallel时为[S/TP, B, H])
"""
if self.input_is_parallel:
input_parallel = input_
else:
# 切分输入
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# 线性计算
output_parallel = F.linear(input_parallel, self.weight)
# AllReduce求和
if self.sequence_parallel:
# Sequence Parallel: ReduceScatter
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# 标准TP: AllReduce
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        # 加bias: 在AllReduce/ReduceScatter之后再加, 避免被求和TP次
        if self.bias is not None and not self.skip_bias_add:
            output = output_ + self.bias
        else:
            output = output_
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
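ColumnParallelLinear与RowParallelLinear成对使用(如MLP中Column→激活→Row)时, 中间的分片激活不需要任何通信, 只在Row的输出处做一次AllReduce。下面用一个单进程的数值小实验验证这一点(纯数学示意: 张量尺寸为假设值, 用ReLU代替GeLU, 不涉及真实的TP通信):

# 验证: "按列切分 -> 逐元素激活 -> 按行切分 -> 求和" 与完整计算等价
import torch

torch.manual_seed(0)
B, S, H, FFN, TP = 2, 4, 8, 32, 4

X = torch.randn(S, B, H)
W1 = torch.randn(FFN, H)    # ColumnParallel权重: 按输出维(列)切分
W2 = torch.randn(H, FFN)    # RowParallel权重: 按输入维(行)切分

# 完整两层计算
Y_full = torch.relu(X @ W1.t()) @ W2.t()

# 模拟TP=4: 每个"rank"只持有W1的一个列分片和W2对应的行分片
Y_tp = torch.zeros_like(Y_full)
for r in range(TP):
    W1_r = W1[r * (FFN // TP):(r + 1) * (FFN // TP)]      # [FFN/TP, H]
    W2_r = W2[:, r * (FFN // TP):(r + 1) * (FFN // TP)]   # [H, FFN/TP]
    partial = torch.relu(X @ W1_r.t()) @ W2_r.t()         # 各rank的部分结果
    Y_tp += partial                                       # 对应RowParallel后的AllReduce

print(torch.allclose(Y_full, Y_tp, atol=1e-4))  # True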
Tensor Parallel 通信原语
# megatron/core/tensor_parallel/mappings.py (简化示意)
import torch
class _CopyToModelParallelRegion(torch.autograd.Function):
"""
Forward: 复制 (无操作)
Backward: AllReduce梯度
"""
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
# AllReduce梯度
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""
Forward: AllReduce
Backward: 复制 (无操作)
"""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""
Forward: AllGather
Backward: 切分梯度 (取对应partition)
"""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""
Forward: 切分 (取对应partition)
Backward: AllGather梯度
"""
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
def _reduce(input_):
"""AllReduce"""
if get_tensor_model_parallel_world_size() == 1:
return input_
torch.distributed.all_reduce(
input_,
group=get_tensor_model_parallel_group()
)
return input_
def _gather(input_):
"""AllGather along last dimension"""
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
return input_
# Gather
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
torch.distributed.all_gather(
tensor_list,
input_,
group=get_tensor_model_parallel_group()
)
# Concatenate along last dimension
output = torch.cat(tensor_list, dim=-1)
return output
def _split(input_):
"""Split along last dimension"""
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
return input_
rank = get_tensor_model_parallel_rank()
# Split
input_list = torch.chunk(input_, world_size, dim=-1)
output = input_list[rank].contiguous()
return output
# 包装函数
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
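这4个算子在forward/backward上两两共轭, 即Megatron论文中的f/g算子: f = copy_to(forward恒等、backward AllReduce), g = reduce_from(forward AllReduce、backward恒等)。下面是一个基于上文简化实现的小实验(示意: 2个CPU进程、gloo后端, 端口号为任意假设值, 并假设简化版parallel_state与mappings可直接import), 验证copy_to_tensor_model_parallel_region的梯度语义:

# 验证 _CopyToModelParallelRegion: forward恒等, backward对梯度AllReduce
import torch
import torch.multiprocessing as mp

def worker(rank, world_size):
    torch.distributed.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511",
        rank=rank, world_size=world_size,
    )
    # 只开TP: TP=2, PP=DP=1
    initialize_model_parallel(tensor_model_parallel_size=world_size,
                              pipeline_model_parallel_size=1)

    x = torch.full((2,), float(rank + 1), requires_grad=True)
    y = copy_to_tensor_model_parallel_region(x)   # forward: 原样返回
    y.sum().backward()                            # backward: 梯度被AllReduce
    # 本地梯度是全1, AllReduce(求和)后每个元素应等于world_size
    print(f"rank {rank}: grad = {x.grad}")

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)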
Sequence Parallel 实现
Sequence Parallel 原理
┌─────────────────────────────────────────────────────────────────────────┐
│ Sequence Parallel 原理 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│   问题: Tensor Parallel只切分了Linear层的计算, LayerNorm和Dropout的激活     │
│         在每个TP rank上仍是完整的[S, B, H], 成为显存与冗余计算的瓶颈         │
│ │
│ 解决: 将序列维度也分片, 在需要时通过通信重建 │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 普通 Tensor Parallel: │ │
│ │ │ │
│ │ Input [S, B, H] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ LayerNorm [S, B, H] <── 需要完整H, 每个rank持有完整activation │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ColumnLinear <── 每个rank计算 [S, B, H] @ [H, H/TP] │ │
│ │ [S, B, H/TP] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ RowLinear + AllReduce <── AllReduce恢复完整输出 │ │
│ │ [S, B, H] │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Sequence Parallel: │ │
│ │ │ │
│ │ Input [S/TP, B, H] <── 序列也分片! │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ LayerNorm [S/TP, B, H] <── 每个rank只处理部分序列 │ │
│ │ │ │ │
│ │ ▼ AllGather along S │ │
│ │ [S, B, H] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ColumnLinear │ │
│ │ [S, B, H/TP] │ │
│ │ │ │ │
│ │ ▼ ReduceScatter along S │ │
│ │ RowLinear │ │
│ │ [S/TP, B, H] <── 输出也是序列分片 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ 显存节省: │
│ • LayerNorm激活: H → H/TP │
│ • Dropout mask: S → S/TP │
│ • 总节省约 2x (取决于配置) │
│ │
│ 通信开销: │
│ • AllGather: 每个Transformer层forward 1次 │
│ • ReduceScatter: 每个Transformer层forward 1次 │
│ • 与AllReduce通信量相同, 只是模式不同 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
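下面用一段粗略的估算代码量化上图的两个结论(示意: 配置数字与ring算法的通信量公式均为简化假设): Sequence Parallel把LayerNorm/Dropout一类激活的显存降到1/TP, 而AllGather+ReduceScatter的总通信量与一次AllReduce相同:

# 粗略估算: 激活显存与单次前向的TP通信量 (每rank发送的数据量)
S, B, H, TP = 4096, 1, 8192, 8
bytes_per_elem = 2   # bf16

# 不开SP: LayerNorm/Dropout的激活在每个rank上都是完整的[S, B, H]
act_no_sp = S * B * H * bytes_per_elem
# 开SP: 这些激活变成[S/TP, B, H]
act_sp = act_no_sp // TP
print(f"LayerNorm类激活: {act_no_sp / 2**20:.0f} MiB -> {act_sp / 2**20:.0f} MiB / rank")

# ring算法下每rank的通信量
msg = S * B * H * bytes_per_elem
allreduce      = 2 * (TP - 1) / TP * msg    # 标准TP: 1次AllReduce
allgather      = (TP - 1) / TP * msg        # SP: 1次AllGather
reduce_scatter = (TP - 1) / TP * msg        # SP: 1次ReduceScatter
print(f"AllReduce: {allreduce / 2**20:.0f} MiB, "
      f"AllGather+ReduceScatter: {(allgather + reduce_scatter) / 2**20:.0f} MiB")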
Sequence Parallel 实现
# Sequence Parallel通信原语 (对应megatron/core/tensor_parallel/mappings.py, 简化示意)
import torch
import torch.nn.functional as F
from torch.nn import Parameter
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""
Forward: AllGather along sequence dimension
Backward: ReduceScatter gradient
"""
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""
Forward: ReduceScatter
Backward: AllGather gradient
"""
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
def _gather_along_first_dim(input_):
"""AllGather along first dimension (sequence)"""
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
torch.distributed.all_gather_into_tensor(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
)
return output
def _reduce_scatter_along_first_dim(input_):
"""ReduceScatter along first dimension (sequence)"""
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
torch.distributed.reduce_scatter_tensor(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
)
return output
# Sequence Parallel LayerNorm
class LayerNorm(torch.nn.Module):
"""
Sequence Parallel aware LayerNorm
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
):
super().__init__()
self.sequence_parallel = sequence_parallel
self.hidden_size = hidden_size
self.eps = eps
self.weight = Parameter(torch.ones(hidden_size))
self.bias = Parameter(torch.zeros(hidden_size))
def forward(self, input_):
# 输入: [S/TP, B, H] 如果sequence_parallel
# [S, B, H] 否则
# LayerNorm对每个token独立计算, 不需要跨序列聚合
# 所以可以直接在分片序列上计算
output = F.layer_norm(input_, (self.hidden_size,),
self.weight, self.bias, self.eps)
return output
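LayerNorm只在hidden维上做归一化, 各token相互独立, 这正是它可以直接作用在序列分片上的原因。下面用一个单进程的数值小实验验证(张量尺寸为假设值):

# 验证: 在完整序列上做LayerNorm后取分片 == 直接在分片上做LayerNorm
import torch
import torch.nn.functional as F

S, B, H, TP = 8, 2, 16, 4
x = torch.randn(S, B, H)
weight, bias = torch.ones(H), torch.zeros(H)

full_then_slice = F.layer_norm(x, (H,), weight, bias)[: S // TP]
slice_then_norm = F.layer_norm(x[: S // TP], (H,), weight, bias)

print(torch.allclose(full_then_slice, slice_then_norm))  # True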
Parallel Transformer 实现
# megatron/model/transformer.py (简化示意)
import torch
import torch.nn.functional as F
class ParallelTransformerLayer(MegatronModule):
"""
并行Transformer层
结合Tensor Parallel和Sequence Parallel
"""
def __init__(
self,
config,
layer_number,
self_attn_mask_type=AttnMaskType.padding,
):
super().__init__()
self.layer_number = layer_number
self.sequence_parallel = config.sequence_parallel
# LayerNorm 1
self.input_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel,
)
# Self-Attention
self.self_attention = ParallelAttention(
config,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
# LayerNorm 2
self.post_attention_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel,
)
# MLP
self.mlp = ParallelMLP(config)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
# 输入: [S/TP, B, H] (如果sequence_parallel) 或 [S, B, H]
# ==== Self-Attention Block ====
# 1. LayerNorm
layernorm_output = self.input_layernorm(hidden_states)
# 2. Attention
attention_output, attention_bias = self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params,
)
        # 3. Residual connection
        # 真实实现中这里是融合的bias_dropout_add, 简化起见省略dropout
        # sequence_parallel时attention_output与hidden_states同为[S/TP, B, H],
        # 可以直接相加; attention_bias是完整的[H], 沿最后一维广播
        hidden_states = hidden_states + attention_output + attention_bias
# ==== MLP Block ====
# 1. LayerNorm
layernorm_output = self.post_attention_layernorm(hidden_states)
# 2. MLP
mlp_output, mlp_bias = self.mlp(layernorm_output)
# 3. Residual connection
hidden_states = hidden_states + mlp_output + mlp_bias
return hidden_states
class ParallelAttention(MegatronModule):
"""
并行Attention
Q, K, V的投影使用ColumnParallelLinear
Output投影使用RowParallelLinear
"""
def __init__(self, config, layer_number, attention_type, attn_mask_type):
super().__init__()
self.config = config
self.layer_number = layer_number
# 每个head的维度
self.hidden_size_per_attention_head = (
config.hidden_size // config.num_attention_heads
)
# TP分片后每个rank的head数
self.num_attention_heads_per_partition = divide(
config.num_attention_heads,
get_tensor_model_parallel_world_size()
)
# QKV投影 (ColumnParallel)
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
gather_output=False, # 不gather, 直接传给下一层
init_method=config.init_method,
sequence_parallel_enabled=config.sequence_parallel,
)
# Attention计算
self.core_attention = CoreAttention(
config, layer_number, attn_mask_type
)
# Output投影 (RowParallel)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
input_is_parallel=True, # 输入来自ColumnParallel, 已经分片
init_method=config.output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=config.sequence_parallel,
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
inference_params=None,
):
# 1. QKV投影
# 输入: [S/TP, B, H] (sequence_parallel) 或 [S, B, H]
# 输出: [S, B, 3*H/TP] (ColumnParallel分片后)
mixed_x_layer, _ = self.query_key_value(hidden_states)
# 2. 拆分Q, K, V
# [S, B, 3*H/TP] -> [S, B, num_heads/TP, 3*head_dim]
new_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_shape)
# 分离Q, K, V
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer,
self.hidden_size_per_attention_head,
dim=-1
)
# 3. Attention计算
# 每个rank只计算自己负责的head
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
# 4. Output投影
# 输入: [S, B, H/TP]
# 输出: [S/TP, B, H] (sequence_parallel) 或 [S, B, H]
output, output_bias = self.dense(context_layer)
return output, output_bias
class ParallelMLP(MegatronModule):
"""
并行MLP
First Linear: ColumnParallel
Second Linear: RowParallel
"""
def __init__(self, config):
super().__init__()
# 第一层: hidden_size -> 4*hidden_size (按列分片)
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size,
gather_output=False,
init_method=config.init_method,
skip_bias_add=True,
sequence_parallel_enabled=config.sequence_parallel,
)
# 激活函数
if config.activation_func == 'gelu':
self.activation_func = F.gelu
        elif config.activation_func == 'swiglu':
            # SwiGLU需要第一层输出加倍宽度并做门控, 本简化实现未包含
            raise NotImplementedError('SwiGLU omitted in this simplified version')
else:
self.activation_func = F.relu
# 第二层: 4*hidden_size -> hidden_size (按行分片)
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
input_is_parallel=True,
init_method=config.output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=config.sequence_parallel,
)
def forward(self, hidden_states):
# 1. 第一层 + 激活
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# 2. 第二层
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
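结合上面的ParallelAttention与ParallelMLP, 可以粗略估算TP切分后每个rank持有的单层Transformer权重参数量(示意计算: 忽略bias与LayerNorm, 配置数字为假设值):

# 单层Transformer在TP=8下每个rank的权重参数量
H, FFN, TP = 8192, 4 * 8192, 8

qkv  = H * 3 * H // TP     # ColumnParallel: QKV投影
proj = H * H // TP         # RowParallel: attention输出投影
h4h  = H * FFN // TP       # ColumnParallel: dense_h_to_4h
f4h  = FFN * H // TP       # RowParallel: dense_4h_to_h

per_rank = qkv + proj + h4h + f4h
total = per_rank * TP
print(f"{per_rank / 1e6:.1f}M 参数/rank, 合计 {total / 1e6:.1f}M (= 12*H^2)")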
总结
Megatron-LM 核心技术
┌─────────────────────────────────────────────────────────────────────────┐
│ Megatron-LM 核心技术 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Tensor Parallel (TP) │
│ ──────────────────── │
│ • ColumnParallelLinear: 权重按列分片, 输出也分片 │
│ • RowParallelLinear: 权重按行分片, 需要AllReduce输出 │
│ • 成对使用减少通信 (Column后接Row, 中间不AllReduce) │
│ │
│ Sequence Parallel (SP) │
│ ────────────────────── │
│ • 序列维度也分片, 减少激活显存 │
│ • AllGather替代复制, ReduceScatter替代AllReduce │
│ • 通信量相同, 但显存更省 │
│ │
│ MPU (Model Parallel Unit) │
│ ───────────────────────── │
│ • 管理TP/PP/DP三种并行组 │
│ • GPU逻辑布局与物理布局映射 │
│ • 提供并行rank和world_size查询 │
│ │
│ Pipeline Parallel (PP) │
│ ────────────────────── │
│ • 模型按层切分到不同stage │
│ • 1F1B调度最小化bubble │
│ • P2P通信传递activation │
│ │
└─────────────────────────────────────────────────────────────────────────┘
面试高频问题
- ColumnParallelLinear和RowParallelLinear如何成对使用?
- Sequence Parallel如何减少显存占用?
- Megatron-LM的3D并行如何组织GPU?
- Tensor Parallel的通信原语有哪些?
- 为什么Sequence Parallel的通信量与Tensor Parallel相同?