HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

03-Megatron-LM源码解析

概述

Megatron-LM是NVIDIA开发的大规模语言模型训练框架,首创了Tensor Parallel和高效的3D并行策略。本章深入解析Megatron-LM的核心实现,包括张量并行、序列并行、以及与DeepSpeed的集成。

Megatron-LM 架构

整体架构

┌─────────────────────────────────────────────────────────────────────────┐
│                      Megatron-LM 架构                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                        用户接口层                                │    │
│  │                                                                  │    │
│  │   pretrain_gpt.py        arguments.py        training.py        │    │
│  │   (GPT预训练入口)        (参数解析)          (训练循环)         │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                        模型定义层                                │    │
│  │                                                                  │    │
│  │  ┌────────────────────────────────────────────────────────────┐ │    │
│  │  │  GPTModel / BertModel / T5Model                            │ │    │
│  │  │    ├─ ParallelTransformer                                  │ │    │
│  │  │    │    ├─ ParallelTransformerLayer                       │ │    │
│  │  │    │    │    ├─ ParallelAttention                         │ │    │
│  │  │    │    │    └─ ParallelMLP                               │ │    │
│  │  │    │    └─ sequence parallelism hooks                     │ │    │
│  │  │    ├─ Embedding                                            │ │    │
│  │  │    └─ Output layer                                         │ │    │
│  │  └────────────────────────────────────────────────────────────┘ │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                        并行原语层                                │    │
│  │                                                                  │    │
│  │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │    │
│  │  │ Tensor      │ │ Pipeline    │ │ Sequence    │ │ Data      │ │    │
│  │  │ Parallel    │ │ Parallel    │ │ Parallel    │ │ Parallel  │ │    │
│  │  │ (TP)        │ │ (PP)        │ │ (SP)        │ │ (DP)      │ │    │
│  │  └─────────────┘ └─────────────┘ └─────────────┘ └───────────┘ │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                        MPU (Model Parallel Unit)                 │    │
│  │                                                                  │    │
│  │  parallel_state.py:                                              │    │
│  │  • tensor_model_parallel_group                                   │    │
│  │  • pipeline_model_parallel_group                                 │    │
│  │  • data_parallel_group                                           │    │
│  │  • get_tensor_model_parallel_rank/world_size                     │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

源码目录结构

megatron/
├── arguments.py                 # 参数解析
├── training.py                  # 训练循环
├── initialize.py               # 初始化
├── global_vars.py              # 全局变量
│
├── core/
│   ├── parallel_state.py       # 并行状态管理 (MPU)
│   ├── tensor_parallel/
│   │   ├── layers.py           # TP层实现
│   │   ├── mappings.py         # 通信原语
│   │   └── utils.py
│   ├── pipeline_parallel/
│   │   ├── schedules.py        # PP调度
│   │   └── p2p_communication.py
│   └── sequence_parallel/
│       └── layers.py           # SP层实现
│
├── model/
│   ├── gpt_model.py            # GPT模型
│   ├── language_model.py       # 语言模型基类
│   ├── transformer.py          # Transformer实现
│   └── module.py               # 模块基类
│
├── data/
│   ├── gpt_dataset.py          # 数据集
│   └── data_samplers.py        # 采样器
│
└── optimizer/
    ├── optimizer.py            # 优化器
    └── grad_scaler.py          # 梯度缩放

Model Parallel Unit (MPU)

并行组初始化

# megatron/core/parallel_state.py

# 全局并行组变量
_TENSOR_MODEL_PARALLEL_GROUP = None
_PIPELINE_MODEL_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP = None
_EMBEDDING_GROUP = None
_POSITION_EMBEDDING_GROUP = None

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
):
    """
    初始化模型并行组

    假设world_size = 16, TP=2, PP=4, 则DP=2

    GPU布局:
    ┌────────────────────────────────────────────────────────────┐
    │                                                             │
    │  DP组0:                        DP组1:                       │
    │  ┌──────────────────────┐     ┌──────────────────────────┐ │
    │  │ PP Stage 0           │     │ PP Stage 0               │ │
    │  │ ┌─────┐ ┌─────┐     │     │ ┌─────┐ ┌─────┐         │ │
    │  │ │GPU 0│ │GPU 1│ TP组│     │ │GPU 8│ │GPU 9│ TP组    │ │
    │  │ └─────┘ └─────┘     │     │ └─────┘ └─────┘         │ │
    │  │                      │     │                          │ │
    │  │ PP Stage 1           │     │ PP Stage 1               │ │
    │  │ ┌─────┐ ┌─────┐     │     │ ┌─────┐ ┌─────┐         │ │
    │  │ │GPU 2│ │GPU 3│     │     │ │GPU10│ │GPU11│         │ │
    │  │ └─────┘ └─────┘     │     │ └─────┘ └─────┘         │ │
    │  │                      │     │                          │ │
    │  │ PP Stage 2           │     │ PP Stage 2               │ │
    │  │ ┌─────┐ ┌─────┐     │     │ ┌─────┐ ┌─────┐         │ │
    │  │ │GPU 4│ │GPU 5│     │     │ │GPU12│ │GPU13│         │ │
    │  │ └─────┘ └─────┘     │     │ └─────┘ └─────┘         │ │
    │  │                      │     │                          │ │
    │  │ PP Stage 3           │     │ PP Stage 3               │ │
    │  │ ┌─────┐ ┌─────┐     │     │ ┌─────┐ ┌─────┐         │ │
    │  │ │GPU 6│ │GPU 7│     │     │ │GPU14│ │GPU15│         │ │
    │  │ └─────┘ └─────┘     │     │ └─────┘ └─────┘         │ │
    │  └──────────────────────┘     └──────────────────────────┘ │
    │                                                             │
    └────────────────────────────────────────────────────────────┘
    """
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP

    # 计算数据并行大小
    world_size = torch.distributed.get_world_size()
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    rank = torch.distributed.get_rank()

    # ==== 创建Tensor Parallel组 ====
    # 同一个TP组内的GPU需要AllReduce
    # 例如: [0,1], [2,3], [4,5], ... 每组2个GPU
    for i in range(num_tensor_model_parallel_groups):
        start_rank = i * tensor_model_parallel_size
        end_rank = start_rank + tensor_model_parallel_size
        ranks = list(range(start_rank, end_rank))

        group = torch.distributed.new_group(ranks)

        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # ==== 创建Pipeline Parallel组 ====
    # 同一个PP组内的GPU需要P2P通信
    # 例如: [0,2,4,6], [1,3,5,7], [8,10,12,14], [9,11,13,15]
    for i in range(num_pipeline_model_parallel_groups):
        ranks = []
        for j in range(pipeline_model_parallel_size):
            # 计算每个stage的rank
            rank_in_pipeline = (i % data_parallel_size) * tensor_model_parallel_size + \
                              (i // data_parallel_size) + \
                              j * (data_parallel_size * tensor_model_parallel_size)
            ranks.append(rank_in_pipeline)

        group = torch.distributed.new_group(ranks)

        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group

    # ==== 创建Data Parallel组 ====
    # 同一个DP组内的GPU需要AllReduce梯度
    # 例如: [0,8], [1,9], [2,10], ...
    for i in range(num_data_parallel_groups):
        ranks = []
        for j in range(data_parallel_size):
            ranks.append(i + j * num_data_parallel_groups)

        group = torch.distributed.new_group(ranks)

        if rank in ranks:
            _DATA_PARALLEL_GROUP = group


# 辅助函数
def get_tensor_model_parallel_group():
    return _TENSOR_MODEL_PARALLEL_GROUP

def get_tensor_model_parallel_rank():
    return torch.distributed.get_rank(group=_TENSOR_MODEL_PARALLEL_GROUP)

def get_tensor_model_parallel_world_size():
    return torch.distributed.get_world_size(group=_TENSOR_MODEL_PARALLEL_GROUP)

def get_pipeline_model_parallel_group():
    return _PIPELINE_MODEL_PARALLEL_GROUP

def get_pipeline_model_parallel_rank():
    return torch.distributed.get_rank(group=_PIPELINE_MODEL_PARALLEL_GROUP)

def get_data_parallel_group():
    return _DATA_PARALLEL_GROUP

Tensor Parallel 实现

ColumnParallelLinear

# megatron/core/tensor_parallel/layers.py

class ColumnParallelLinear(torch.nn.Module):
    """
    列并行线性层

    将权重按列切分到不同GPU:
    Y = XA, 其中A按列切分: A = [A_1, A_2, ..., A_n]
    Y_i = X @ A_i

    ┌─────────────────────────────────────────────────────────┐
    │                                                          │
    │       X (input)                                          │
    │   ┌───────────────┐                                     │
    │   │               │                                     │
    │   │   [B, S, H]   │     广播到所有TP rank               │
    │   │               │                                     │
    │   └───────────────┘                                     │
    │           │                                              │
    │           ▼                                              │
    │   ┌───────────────────────────────────────────────────┐ │
    │   │              Weight A                              │ │
    │   │                                                    │ │
    │   │  ┌─────────┬─────────┬─────────┬─────────┐       │ │
    │   │  │   A_0   │   A_1   │   A_2   │   A_3   │       │ │
    │   │  │ [H,H/4] │ [H,H/4] │ [H,H/4] │ [H,H/4] │       │ │
    │   │  │  GPU 0  │  GPU 1  │  GPU 2  │  GPU 3  │       │ │
    │   │  └─────────┴─────────┴─────────┴─────────┘       │ │
    │   └───────────────────────────────────────────────────┘ │
    │           │                                              │
    │           ▼                                              │
    │   ┌───────────────────────────────────────────────────┐ │
    │   │              Output Y                              │ │
    │   │                                                    │ │
    │   │  ┌─────────┬─────────┬─────────┬─────────┐       │ │
    │   │  │   Y_0   │   Y_1   │   Y_2   │   Y_3   │       │ │
    │   │  │[B,S,H/4]│[B,S,H/4]│[B,S,H/4]│[B,S,H/4]│       │ │
    │   │  │  GPU 0  │  GPU 1  │  GPU 2  │  GPU 3  │       │ │
    │   │  └─────────┴─────────┴─────────┴─────────┘       │ │
    │   │                                                    │ │
    │   │  后续通常是AllGather (如果需要完整输出)             │ │
    │   │  或直接输入到RowParallelLinear                     │ │
    │   └───────────────────────────────────────────────────┘ │
    │                                                          │
    └─────────────────────────────────────────────────────────┘
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = True,
        init_method=init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        skip_bias_add: bool = False,
        async_tensor_model_parallel_allreduce: bool = True,
        sequence_parallel_enabled: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add
        self.sequence_parallel = sequence_parallel_enabled

        # 获取TP配置
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)

        # 初始化权重 (只初始化本partition)
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition,
            self.input_size,
            device=torch.cuda.current_device(),
            dtype=torch.float32
        ))
        init_method(self.weight)

        # Bias (只有rank 0初始化完整bias, 其他rank初始化本partition)
        if bias:
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=torch.float32
            ))
            # 初始化为0
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = async_tensor_model_parallel_allreduce

    def forward(self, input_):
        """
        Forward pass

        如果sequence_parallel: 输入是[S/TP, B, H], 需要先AllGather
        否则: 输入是[S, B, H], 直接计算
        """
        if self.sequence_parallel:
            # Sequence Parallel: AllGather输入
            input_parallel = gather_from_sequence_parallel_region(input_)
        else:
            input_parallel = input_

        # 异步AllReduce优化 (用于减少TP AllReduce延迟)
        if self.async_tensor_model_parallel_allreduce:
            input_parallel = copy_to_tensor_model_parallel_region(input_parallel)

        # 线性计算: Y = X @ W^T + b
        output_parallel = F.linear(input_parallel, self.weight, self.bias)

        if self.gather_output:
            # AllGather输出
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        return output


class RowParallelLinear(torch.nn.Module):
    """
    行并行线性层

    将权重按行切分到不同GPU:
    Y = XA, 其中X按列切分: X = [X_1, X_2, ..., X_n]
    A按行切分: A = [A_1; A_2; ...; A_n]
    Y = sum(X_i @ A_i)

    通常接在ColumnParallelLinear之后
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = False,
        init_method=init.xavier_normal_,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        skip_bias_add: bool = False,
        sequence_parallel_enabled: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add
        self.sequence_parallel = sequence_parallel_enabled

        # 获取TP配置
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)

        # 初始化权重 (只初始化本partition)
        self.weight = Parameter(torch.empty(
            self.output_size,
            self.input_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=torch.float32
        ))
        init_method(self.weight)

        # Bias (完整bias, 但只在AllReduce后加一次)
        if bias:
            self.bias = Parameter(torch.empty(
                self.output_size,
                device=torch.cuda.current_device(),
                dtype=torch.float32
            ))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """
        Forward pass

        输入: [B, S, H/TP] (来自ColumnParallelLinear, 已经是分片的)
        输出: [B, S, H] (需要AllReduce求和)
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # 切分输入
            input_parallel = scatter_to_tensor_model_parallel_region(input_)

        # 线性计算
        output_parallel = F.linear(input_parallel, self.weight)

        # AllReduce求和
        if self.sequence_parallel:
            # Sequence Parallel: ReduceScatter
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            # 标准TP: AllReduce
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)

        # 加bias (只加一次, 不是每个rank都加)
        if self.bias is not None and not self.skip_bias_add:
            output = output_ + self.bias
        else:
            output = output_

        if self.skip_bias_add:
            return output, self.bias
        return output

Tensor Parallel 通信原语

# megatron/core/tensor_parallel/mappings.py

class _CopyToModelParallelRegion(torch.autograd.Function):
    """
    Forward: 复制 (无操作)
    Backward: AllReduce梯度
    """

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # AllReduce梯度
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """
    Forward: AllReduce
    Backward: 复制 (无操作)
    """

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """
    Forward: AllGather
    Backward: 切分梯度 (取对应partition)
    """

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """
    Forward: 切分 (取对应partition)
    Backward: AllGather梯度
    """

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


def _reduce(input_):
    """AllReduce"""
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    torch.distributed.all_reduce(
        input_,
        group=get_tensor_model_parallel_group()
    )
    return input_


def _gather(input_):
    """AllGather along last dimension"""
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    # Gather
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    torch.distributed.all_gather(
        tensor_list,
        input_,
        group=get_tensor_model_parallel_group()
    )

    # Concatenate along last dimension
    output = torch.cat(tensor_list, dim=-1)
    return output


def _split(input_):
    """Split along last dimension"""
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    rank = get_tensor_model_parallel_rank()

    # Split
    input_list = torch.chunk(input_, world_size, dim=-1)
    output = input_list[rank].contiguous()
    return output


# 包装函数
def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)

def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)

def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)

def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)

Sequence Parallel 实现

Sequence Parallel 原理

┌─────────────────────────────────────────────────────────────────────────┐
│                    Sequence Parallel 原理                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  问题: Tensor Parallel中, LayerNorm和Dropout需要完整序列                 │
│       这部分不能并行, 成为显存瓶颈                                        │
│                                                                          │
│  解决: 将序列维度也分片, 在需要时通过通信重建                             │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                                                                  │    │
│  │  普通 Tensor Parallel:                                           │    │
│  │                                                                  │    │
│  │  Input [S, B, H]                                                │    │
│  │       │                                                          │    │
│  │       ▼                                                          │    │
│  │  LayerNorm [S, B, H]  <── 需要完整H, 每个rank持有完整activation   │    │
│  │       │                                                          │    │
│  │       ▼                                                          │    │
│  │  ColumnLinear          <── 每个rank计算 [S, B, H] @ [H, H/TP]    │    │
│  │  [S, B, H/TP]                                                   │    │
│  │       │                                                          │    │
│  │       ▼                                                          │    │
│  │  RowLinear + AllReduce  <── AllReduce恢复完整输出                │    │
│  │  [S, B, H]                                                      │    │
│  │                                                                  │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                                                                  │    │
│  │  Sequence Parallel:                                              │    │
│  │                                                                  │    │
│  │  Input [S/TP, B, H]    <── 序列也分片!                          │    │
│  │       │                                                          │    │
│  │       ▼                                                          │    │
│  │  LayerNorm [S/TP, B, H]  <── 每个rank只处理部分序列              │    │
│  │       │                                                          │    │
│  │       ▼  AllGather along S                                       │    │
│  │  [S, B, H]                                                      │    │
│  │       │                                                          │    │
│  │       ▼                                                          │    │
│  │  ColumnLinear                                                    │    │
│  │  [S, B, H/TP]                                                   │    │
│  │       │                                                          │    │
│  │       ▼  ReduceScatter along S                                  │    │
│  │  RowLinear                                                       │    │
│  │  [S/TP, B, H]           <── 输出也是序列分片                     │    │
│  │                                                                  │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
│  显存节省:                                                               │
│  • LayerNorm激活: H → H/TP                                              │
│  • Dropout mask: S → S/TP                                               │
│  • 总节省约 2x (取决于配置)                                              │
│                                                                          │
│  通信开销:                                                               │
│  • AllGather: 每个Transformer层forward 1次                              │
│  • ReduceScatter: 每个Transformer层forward 1次                          │
│  • 与AllReduce通信量相同, 只是模式不同                                   │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Sequence Parallel 实现

# megatron/core/tensor_parallel/layers.py (Sequence Parallel部分)

class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """
    Forward: AllGather along sequence dimension
    Backward: ReduceScatter gradient
    """

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """
    Forward: ReduceScatter
    Backward: AllGather gradient
    """

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


def _gather_along_first_dim(input_):
    """AllGather along first dimension (sequence)"""
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)

    torch.distributed.all_gather_into_tensor(
        output,
        input_.contiguous(),
        group=get_tensor_model_parallel_group()
    )

    return output


def _reduce_scatter_along_first_dim(input_):
    """ReduceScatter along first dimension (sequence)"""
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0
    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)

    torch.distributed.reduce_scatter_tensor(
        output,
        input_.contiguous(),
        group=get_tensor_model_parallel_group()
    )

    return output


# Sequence Parallel LayerNorm
class LayerNorm(torch.nn.Module):
    """
    Sequence Parallel aware LayerNorm
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
    ):
        super().__init__()
        self.sequence_parallel = sequence_parallel
        self.hidden_size = hidden_size
        self.eps = eps

        self.weight = Parameter(torch.ones(hidden_size))
        self.bias = Parameter(torch.zeros(hidden_size))

    def forward(self, input_):
        # 输入: [S/TP, B, H] 如果sequence_parallel
        #      [S, B, H] 否则

        # LayerNorm对每个token独立计算, 不需要跨序列聚合
        # 所以可以直接在分片序列上计算
        output = F.layer_norm(input_, (self.hidden_size,),
                             self.weight, self.bias, self.eps)

        return output

Parallel Transformer 实现

# megatron/model/transformer.py

class ParallelTransformerLayer(MegatronModule):
    """
    并行Transformer层

    结合Tensor Parallel和Sequence Parallel
    """

    def __init__(
        self,
        config,
        layer_number,
        self_attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__()

        self.layer_number = layer_number
        self.sequence_parallel = config.sequence_parallel

        # LayerNorm 1
        self.input_layernorm = LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            sequence_parallel=config.sequence_parallel,
        )

        # Self-Attention
        self.self_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type,
        )

        # LayerNorm 2
        self.post_attention_layernorm = LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            sequence_parallel=config.sequence_parallel,
        )

        # MLP
        self.mlp = ParallelMLP(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        inference_params=None,
    ):
        # 输入: [S/TP, B, H] (如果sequence_parallel) 或 [S, B, H]

        # ==== Self-Attention Block ====
        # 1. LayerNorm
        layernorm_output = self.input_layernorm(hidden_states)

        # 2. Attention
        attention_output, attention_bias = self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params,
        )

        # 3. Residual connection
        # 如果sequence_parallel, attention_output已经是[S/TP, B, H]
        if self.sequence_parallel:
            # Bias需要特殊处理 (先AllReduce再加到分片输出上)
            hidden_states = hidden_states + attention_output + attention_bias
        else:
            hidden_states = hidden_states + attention_output + attention_bias

        # ==== MLP Block ====
        # 1. LayerNorm
        layernorm_output = self.post_attention_layernorm(hidden_states)

        # 2. MLP
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # 3. Residual connection
        hidden_states = hidden_states + mlp_output + mlp_bias

        return hidden_states


class ParallelAttention(MegatronModule):
    """
    并行Attention

    Q, K, V的投影使用ColumnParallelLinear
    Output投影使用RowParallelLinear
    """

    def __init__(self, config, layer_number, attention_type, attn_mask_type):
        super().__init__()

        self.config = config
        self.layer_number = layer_number

        # 每个head的维度
        self.hidden_size_per_attention_head = (
            config.hidden_size // config.num_attention_heads
        )

        # TP分片后每个rank的head数
        self.num_attention_heads_per_partition = divide(
            config.num_attention_heads,
            get_tensor_model_parallel_world_size()
        )

        # QKV投影 (ColumnParallel)
        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            gather_output=False,  # 不gather, 直接传给下一层
            init_method=config.init_method,
            sequence_parallel_enabled=config.sequence_parallel,
        )

        # Attention计算
        self.core_attention = CoreAttention(
            config, layer_number, attn_mask_type
        )

        # Output投影 (RowParallel)
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            input_is_parallel=True,  # 输入来自ColumnParallel, 已经分片
            init_method=config.output_layer_init_method,
            skip_bias_add=True,
            sequence_parallel_enabled=config.sequence_parallel,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        inference_params=None,
    ):
        # 1. QKV投影
        # 输入: [S/TP, B, H] (sequence_parallel) 或 [S, B, H]
        # 输出: [S, B, 3*H/TP] (ColumnParallel分片后)
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # 2. 拆分Q, K, V
        # [S, B, 3*H/TP] -> [S, B, num_heads/TP, 3*head_dim]
        new_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_shape)

        # 分离Q, K, V
        (query_layer, key_layer, value_layer) = torch.split(
            mixed_x_layer,
            self.hidden_size_per_attention_head,
            dim=-1
        )

        # 3. Attention计算
        # 每个rank只计算自己负责的head
        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # 4. Output投影
        # 输入: [S, B, H/TP]
        # 输出: [S/TP, B, H] (sequence_parallel) 或 [S, B, H]
        output, output_bias = self.dense(context_layer)

        return output, output_bias


class ParallelMLP(MegatronModule):
    """
    并行MLP

    First Linear: ColumnParallel
    Second Linear: RowParallel
    """

    def __init__(self, config):
        super().__init__()

        # 第一层: hidden_size -> 4*hidden_size (按列分片)
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            gather_output=False,
            init_method=config.init_method,
            skip_bias_add=True,
            sequence_parallel_enabled=config.sequence_parallel,
        )

        # 激活函数
        if config.activation_func == 'gelu':
            self.activation_func = F.gelu
        elif config.activation_func == 'swiglu':
            # SwiGLU需要特殊处理
            self.activation_func = self._swiglu
        else:
            self.activation_func = F.relu

        # 第二层: 4*hidden_size -> hidden_size (按行分片)
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            input_is_parallel=True,
            init_method=config.output_layer_init_method,
            skip_bias_add=True,
            sequence_parallel_enabled=config.sequence_parallel,
        )

    def forward(self, hidden_states):
        # 1. 第一层 + 激活
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if bias_parallel is not None:
            intermediate_parallel = intermediate_parallel + bias_parallel

        intermediate_parallel = self.activation_func(intermediate_parallel)

        # 2. 第二层
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)

        return output, output_bias

总结

Megatron-LM 核心技术

┌─────────────────────────────────────────────────────────────────────────┐
│                     Megatron-LM 核心技术                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Tensor Parallel (TP)                                                   │
│  ────────────────────                                                    │
│  • ColumnParallelLinear: 权重按列分片, 输出也分片                        │
│  • RowParallelLinear: 权重按行分片, 需要AllReduce输出                    │
│  • 成对使用减少通信 (Column后接Row, 中间不AllReduce)                     │
│                                                                          │
│  Sequence Parallel (SP)                                                  │
│  ──────────────────────                                                  │
│  • 序列维度也分片, 减少激活显存                                          │
│  • AllGather替代复制, ReduceScatter替代AllReduce                        │
│  • 通信量相同, 但显存更省                                                │
│                                                                          │
│  MPU (Model Parallel Unit)                                              │
│  ─────────────────────────                                               │
│  • 管理TP/PP/DP三种并行组                                                │
│  • GPU逻辑布局与物理布局映射                                             │
│  • 提供并行rank和world_size查询                                          │
│                                                                          │
│  Pipeline Parallel (PP)                                                  │
│  ──────────────────────                                                  │
│  • 模型按层切分到不同stage                                               │
│  • 1F1B调度最小化bubble                                                  │
│  • P2P通信传递activation                                                 │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

面试高频问题

  1. ColumnParallelLinear和RowParallelLinear如何成对使用?
  2. Sequence Parallel如何减少显存占用?
  3. Megatron-LM的3D并行如何组织GPU?
  4. Tensor Parallel的通信原语有哪些?
  5. 为什么Sequence Parallel的通信量与Tensor Parallel相同?
Prev
02-DeepSpeed源码深度解析
Next
04-vLLM推理引擎源码解析