HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

01-PyTorch分布式源码解析

概述

PyTorch是最流行的深度学习框架之一,其分布式训练模块经过多年发展已经非常成熟。本章深入解析PyTorch分布式训练的核心源码,包括torch.distributed、DistributedDataParallel(DDP)、以及FSDP的实现原理。

PyTorch 分布式架构

整体架构

┌─────────────────────────────────────────────────────────────────────────┐
│                    PyTorch 分布式架构                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                         用户 API 层                              │    │
│  │                                                                  │    │
│  │  torch.distributed        DistributedDataParallel    FSDP       │    │
│  │  ├─ init_process_group   ├─ DDP(model)               ├─ FSDP()  │    │
│  │  ├─ all_reduce           ├─ forward/backward         ├─ shard   │    │
│  │  ├─ broadcast            └─ gradient sync            └─ gather  │    │
│  │  └─ all_gather                                                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      ProcessGroup 层                             │    │
│  │                                                                  │    │
│  │  ProcessGroup              ProcessGroupNCCL    ProcessGroupGloo  │    │
│  │  ├─ size()                 ├─ NCCL backend     ├─ Gloo backend  │    │
│  │  ├─ rank()                 ├─ GPU集合通信       ├─ CPU集合通信   │    │
│  │  └─ work()                 └─ 高性能           └─ 通用          │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      通信后端层                                  │    │
│  │                                                                  │    │
│  │     NCCL             Gloo              MPI             UCC       │    │
│  │     ├─ Ring          ├─ Allreduce      ├─ OpenMPI      ├─ UCX   │    │
│  │     ├─ Tree          ├─ Broadcast      └─ MPICH        └─ 异构  │    │
│  │     └─ P2P           └─ P2P                                      │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      网络/硬件层                                 │    │
│  │                                                                  │    │
│  │     NVLink/NVSwitch        InfiniBand           Ethernet         │    │
│  │     (节点内GPU)            (节点间高速)          (通用网络)       │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

源码目录结构

torch/
├── distributed/
│   ├── __init__.py              # 分布式API入口
│   ├── distributed_c10d.py      # C10d Python绑定
│   ├── rendezvous.py            # 进程会合
│   ├── launch.py                # torchrun启动器
│   │
│   ├── nn/
│   │   ├── parallel/
│   │   │   ├── distributed.py   # DDP实现
│   │   │   └── _functions.py    # 分布式函数
│   │   │
│   │   └── functional/
│   │       └── _fsdp/           # FSDP实现
│   │
│   └── algorithms/
│       └── ddp_comm_hooks/      # DDP通信钩子
│
├── csrc/
│   └── distributed/
│       ├── c10d/
│       │   ├── ProcessGroup.hpp      # ProcessGroup基类
│       │   ├── ProcessGroupNCCL.hpp  # NCCL后端
│       │   ├── ProcessGroupGloo.hpp  # Gloo后端
│       │   └── Store.hpp             # 分布式存储
│       │
│       └── rpc/                       # RPC实现

torch.distributed 核心实现

init_process_group 源码

# torch/distributed/distributed_c10d.py

def init_process_group(
    backend: str,
    init_method: Optional[str] = None,
    timeout: timedelta = default_pg_timeout,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: str = "",
    pg_options: Optional[ProcessGroupOptions] = None,
):
    """
    初始化分布式进程组

    Args:
        backend: 通信后端 ("nccl", "gloo", "mpi")
        init_method: 初始化方法 ("env://", "tcp://", "file://")
        world_size: 进程总数
        rank: 当前进程rank
        store: 分布式KV存储
    """
    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    # 1. 参数验证
    if backend not in ["nccl", "gloo", "mpi", "ucc"]:
        raise ValueError(f"Invalid backend: {backend}")

    # 2. 创建或获取Store (用于进程间协调)
    if store is None:
        if init_method is None:
            init_method = "env://"  # 默认使用环境变量

        # 根据init_method创建Store
        if init_method == "env://":
            # 从环境变量获取: MASTER_ADDR, MASTER_PORT
            store = _create_store_from_env()
        elif init_method.startswith("tcp://"):
            # TCP方式: tcp://master_ip:port
            store = TCPStore(...)
        elif init_method.startswith("file://"):
            # 文件方式: file:///path/to/file
            store = FileStore(...)

    # 3. 如果world_size和rank未指定,从环境变量获取
    if world_size == -1:
        world_size = int(os.environ.get("WORLD_SIZE", -1))
    if rank == -1:
        rank = int(os.environ.get("RANK", -1))

    # 4. 创建ProcessGroup
    if backend == "nccl":
        pg = ProcessGroupNCCL(store, rank, world_size, pg_options)
    elif backend == "gloo":
        pg = ProcessGroupGloo(store, rank, world_size, pg_options)
    elif backend == "mpi":
        pg = ProcessGroupMPI(rank, world_size)

    # 5. 设置为默认ProcessGroup
    _set_default_pg(pg)
    _backend = backend
    _default_pg_init_method = init_method

    # 6. 初始化同步屏障
    barrier()

    return pg


def _create_store_from_env():
    """从环境变量创建Store"""
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = int(os.environ.get("MASTER_PORT", "29500"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))

    # 创建TCPStore
    # rank 0 是server, 其他rank是client
    is_master = (rank == 0)

    store = TCPStore(
        host_name=master_addr,
        port=master_port,
        world_size=world_size,
        is_master=is_master,
        timeout=timedelta(seconds=300),
    )

    return store

ProcessGroup 实现

// torch/csrc/distributed/c10d/ProcessGroup.hpp

class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
    // 进程组的基本属性
    int rank_;           // 当前进程rank
    int size_;           // 进程组大小
    c10::intrusive_ptr<Store> store_;  // KV存储

    // 集合通信接口
    virtual c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()
    ) = 0;

    virtual c10::intrusive_ptr<Work> broadcast(
        std::vector<at::Tensor>& tensors,
        const BroadcastOptions& opts = BroadcastOptions()
    ) = 0;

    virtual c10::intrusive_ptr<Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()
    ) = 0;

    virtual c10::intrusive_ptr<Work> reduce_scatter(
        std::vector<at::Tensor>& outputTensors,
        std::vector<std::vector<at::Tensor>>& inputTensors,
        const ReduceScatterOptions& opts = ReduceScatterOptions()
    ) = 0;

    // P2P通信
    virtual c10::intrusive_ptr<Work> send(
        std::vector<at::Tensor>& tensors,
        int dstRank,
        int tag
    ) = 0;

    virtual c10::intrusive_ptr<Work> recv(
        std::vector<at::Tensor>& tensors,
        int srcRank,
        int tag
    ) = 0;

    // 屏障同步
    virtual c10::intrusive_ptr<Work> barrier(
        const BarrierOptions& opts = BarrierOptions()
    ) = 0;
};


// ProcessGroupNCCL 实现
// torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

class ProcessGroupNCCL : public ProcessGroup {
 private:
    // NCCL通信器 (每个设备一个)
    std::unordered_map<std::string, std::vector<ncclComm_t>> ncclComms_;

    // CUDA流 (用于异步通信)
    std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>> ncclStreams_;

    // 事件 (用于同步)
    std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;

 public:
    c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts
    ) override {
        // 1. 检查tensor设备
        check_gpu_tensors(tensors);

        // 2. 获取或创建NCCL通信器
        auto devices = getDevices(tensors);
        auto key = getKeyFromDevices(devices);
        auto& ncclComms = getNCCLComms(key, devices);

        // 3. 创建Work对象
        auto work = c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>(
            devices, rank_, opts.opType
        );

        // 4. 在NCCL流中执行AllReduce
        for (size_t i = 0; i < tensors.size(); i++) {
            auto& tensor = tensors[i];
            auto& ncclComm = ncclComms[i];
            auto& ncclStream = ncclStreams_[key][i];

            // 切换到NCCL流
            c10::cuda::CUDAStreamGuard guard(ncclStream);

            // 记录开始事件
            work->ncclStartEvents_[i].record(ncclStream);

            // 调用NCCL AllReduce
            NCCL_CHECK(ncclAllReduce(
                tensor.data_ptr(),
                tensor.data_ptr(),
                tensor.numel(),
                getNcclDataType(tensor.scalar_type()),
                getNcclReduceOp(opts.reduceOp),
                ncclComm,
                ncclStream.stream()
            ));

            // 记录结束事件
            work->ncclEndEvents_[i].record(ncclStream);
        }

        return work;
    }
};

all_reduce 函数

# torch/distributed/distributed_c10d.py

def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
    """
    对tensor执行AllReduce操作

    Args:
        tensor: 输入/输出tensor (in-place操作)
        op: 归约操作 (SUM, PRODUCT, MIN, MAX, etc.)
        group: ProcessGroup (默认使用全局默认组)
        async_op: 是否异步执行

    Returns:
        如果async_op=True, 返回Work对象用于同步
        否则返回None
    """
    # 获取ProcessGroup
    if group is None:
        group = _get_default_group()

    # 参数验证
    if not tensor.is_cuda and group._get_backend_name() == "nccl":
        raise ValueError("NCCL backend requires CUDA tensors")

    # 执行AllReduce
    opts = AllreduceOptions()
    opts.reduceOp = op

    work = group.allreduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()
        return None


# 异步AllReduce使用示例
async_work = dist.all_reduce(tensor, async_op=True)
# ... 做其他计算 ...
async_work.wait()  # 等待完成

DistributedDataParallel (DDP) 源码

DDP 初始化

# torch/nn/parallel/distributed.py

class DistributedDataParallel(Module):
    """
    分布式数据并行

    将模型复制到多个GPU,每个GPU处理不同的数据batch
    梯度通过AllReduce同步
    """

    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        find_unused_parameters=False,
        check_reduction=False,
        gradient_as_bucket_view=False,
        static_graph=False,
    ):
        super(DistributedDataParallel, self).__init__()

        # 1. 保存模块引用
        self.module = module

        # 2. 获取ProcessGroup
        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        # 3. 设置设备
        if device_ids is None:
            device_ids = [torch.cuda.current_device()]
        self.device_ids = device_ids
        self.output_device = output_device or device_ids[0]

        # 4. 同步模型参数 (从rank 0广播)
        self._sync_params_and_buffers(authoritative_rank=0)

        # 5. 构建Reducer
        self._build_reducer(
            bucket_cap_mb=bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # 6. 注册梯度钩子
        self._register_grad_hooks()

    def _sync_params_and_buffers(self, authoritative_rank=0):
        """同步所有rank的模型参数"""
        # 收集所有参数和buffer
        module_states = []
        for param in self.module.parameters():
            module_states.append(param.detach())
        for buffer in self.module.buffers():
            module_states.append(buffer.detach())

        # 从authoritative_rank广播到所有rank
        for state in module_states:
            dist.broadcast(state, src=authoritative_rank,
                          group=self.process_group)

    def _build_reducer(self, bucket_cap_mb, find_unused_parameters,
                       gradient_as_bucket_view):
        """构建Reducer用于梯度同步"""

        # 收集需要同步的参数
        parameters = list(self.module.parameters())
        params_for_reduction = [p for p in parameters if p.requires_grad]

        # 构建参数桶
        # 将参数按大小分组,减少通信次数
        self.reducer = dist.Reducer(
            params_for_reduction,
            list(range(len(params_for_reduction))),  # 参数索引
            self.process_group,
            [[]],  # 预期稀疏梯度
            bucket_cap_mb * 1024 * 1024,  # 桶大小 (bytes)
            find_unused_parameters,
            gradient_as_bucket_view,
            self.module.parameters,  # 参数名称
        )

    def _register_grad_hooks(self):
        """为每个参数注册梯度钩子"""
        for param in self.module.parameters():
            if param.requires_grad:
                # 注册backward钩子
                param.register_hook(self._make_param_hook(param))

    def _make_param_hook(self, param):
        """创建参数钩子"""
        def hook(grad):
            # 通知Reducer该参数的梯度已ready
            self.reducer.mark_variable_ready(param)
            return grad
        return hook

DDP Forward 实现

# torch/nn/parallel/distributed.py (续)

class DistributedDataParallel(Module):

    def forward(self, *inputs, **kwargs):
        """
        DDP forward流程:
        1. 同步模型参数 (如果需要)
        2. 准备Reducer
        3. 执行模型forward
        """
        # 1. 如果配置了broadcast_buffers,同步buffer
        if self.broadcast_buffers:
            self._sync_buffers()

        # 2. 准备Reducer (重置状态)
        if self.require_backward_grad_sync:
            self.reducer.prepare_for_backward([])

        # 3. 执行forward
        if self.device_ids:
            # 移动输入到GPU
            inputs = self._move_inputs_to_device(inputs, self.device_ids[0])
            kwargs = self._move_kwargs_to_device(kwargs, self.device_ids[0])

        # 4. 调用原始模块的forward
        output = self.module(*inputs, **kwargs)

        # 5. 如果是第一次迭代,记录计算图
        if self._is_first_iteration():
            self.reducer.prepare_for_backward(
                list(_find_tensors(output))
            )

        return output

    def _sync_buffers(self):
        """同步模型buffer (如BatchNorm的running_mean)"""
        buffers = [b.data for b in self.module.buffers()]
        if buffers:
            # 从rank 0广播
            dist.broadcast_coalesced(buffers, group=self.process_group)

Reducer 实现 (C++)

// torch/csrc/distributed/c10d/reducer.cpp

class Reducer {
 public:
    // 参数桶
    struct Bucket {
        std::vector<at::Tensor> gradients;
        at::Tensor flat_buffer;  // 展平后的梯度缓冲区
        size_t size;             // 桶大小 (bytes)
        bool ready;              // 是否可以开始通信
    };

    std::vector<Bucket> buckets_;

    // 异步AllReduce的Work对象
    std::vector<c10::intrusive_ptr<c10d::Work>> pending_works_;

    void mark_variable_ready(size_t variable_index) {
        /*
        当某个参数的梯度计算完成时调用

        DDP的核心优化: 边计算边通信
        - 不等所有梯度都ready才开始AllReduce
        - 一个bucket满了就开始该bucket的AllReduce
        - 与后续梯度计算重叠
        */

        // 1. 找到该参数所属的bucket
        auto& bucket = buckets_[variable_to_bucket_[variable_index]];

        // 2. 增加ready计数
        bucket.pending--;

        // 3. 如果bucket的所有梯度都ready, 启动AllReduce
        if (bucket.pending == 0) {
            start_bucket_allreduce(bucket);
        }
    }

    void start_bucket_allreduce(Bucket& bucket) {
        // 1. 将梯度拷贝到连续buffer (或使用gradient_as_bucket_view优化)
        if (!gradient_as_bucket_view_) {
            copy_grads_to_bucket(bucket);
        }

        // 2. 异步AllReduce
        auto work = process_group_->allreduce(
            {bucket.flat_buffer},
            AllreduceOptions()
        );

        pending_works_.push_back(work);
    }

    void finalize_backward() {
        /*
        backward完成后调用
        等待所有AllReduce完成
        */

        // 等待所有pending的AllReduce
        for (auto& work : pending_works_) {
            work->wait();
        }
        pending_works_.clear();

        // 如果不是gradient_as_bucket_view, 需要拷贝回参数
        if (!gradient_as_bucket_view_) {
            copy_bucket_to_grads();
        }

        // 除以world_size得到平均梯度
        // (NCCL的AllReduce默认是SUM, 需要手动除)
        for (auto& param : params_) {
            param.grad().div_(world_size_);
        }
    }
};

梯度桶 (Bucket) 机制

┌─────────────────────────────────────────────────────────────────────────┐
│                      DDP Bucket 机制                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  问题: 参数多 + 单独AllReduce = 通信次数太多                             │
│                                                                          │
│  解决: 将参数分组成桶, 以桶为单位通信                                    │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────┐        │
│  │                      参数排列                                │        │
│  │                                                              │        │
│  │  Layer 1         Layer 2         Layer 3        Layer 4     │        │
│  │  [weight, bias]  [weight, bias]  [weight, bias] [weight]    │        │
│  │     ↓               ↓               ↓             ↓         │        │
│  │  [1MB, 0.1MB]   [10MB, 1MB]     [5MB, 0.5MB]   [8MB]       │        │
│  │                                                              │        │
│  └─────────────────────────────────────────────────────────────┘        │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────┐        │
│  │                      桶划分 (25MB/bucket)                    │        │
│  │                                                              │        │
│  │  Bucket 0 (25MB)        Bucket 1 (12.6MB)                   │        │
│  │  ┌──────────────────┐   ┌──────────────────┐                │        │
│  │  │ Layer 4 weight   │   │ Layer 2 weight   │                │        │
│  │  │ Layer 3 weight   │   │ Layer 2 bias     │                │        │
│  │  │ Layer 3 bias     │   │ Layer 1 weight   │                │        │
│  │  │                  │   │ Layer 1 bias     │                │        │
│  │  └──────────────────┘   └──────────────────┘                │        │
│  │                                                              │        │
│  │  注意: 按backward顺序排列, Layer 4最先ready                  │        │
│  └─────────────────────────────────────────────────────────────┘        │
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────┐        │
│  │                      通信与计算重叠                          │        │
│  │                                                              │        │
│  │  Backward:  [Layer4][Layer3][  Layer2  ][Layer1]            │        │
│  │                ↓      ↓                                      │        │
│  │  AllReduce:   [Bucket0 AllReduce] [Bucket1 AllReduce]       │        │
│  │                                                              │        │
│  │  Bucket0 包含 Layer4, Layer3                                 │        │
│  │  当 Layer3 backward完成, Bucket0就可以开始AllReduce          │        │
│  │  与 Layer2, Layer1 的backward并行执行                        │        │
│  │                                                              │        │
│  └─────────────────────────────────────────────────────────────┘        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

DDP 通信钩子

# torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py

def allreduce_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    默认AllReduce钩子

    Args:
        process_group: 通信组
        bucket: 梯度桶 (包含展平后的梯度tensor)

    Returns:
        Future对象, 完成后包含归约后的梯度
    """
    return (
        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    FP16压缩钩子

    将梯度转换为FP16传输, 减少通信量50%
    """
    # 转换为FP16
    compressed = bucket.buffer().half()

    # AllReduce
    future = dist.all_reduce(compressed, group=process_group, async_op=True).get_future()

    # 转回FP32
    def decompress(fut):
        return fut.value()[0].float()

    return future.then(decompress)


def powerSGD_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
    state: PowerSGDState
) -> torch.futures.Future[torch.Tensor]:
    """
    PowerSGD压缩钩子

    使用低秩分解压缩梯度
    """
    input_tensor = bucket.buffer()
    device = input_tensor.device

    # 1. 重塑为矩阵
    matrix = input_tensor.view(state.matrix_approximation_rank, -1)

    # 2. 正交化
    if state.use_error_feedback:
        matrix.add_(state.error_dict[bucket.index()])

    # 3. 低秩分解
    p, q = torch.qr(matrix)
    # 只传输p和q, 大大减少通信量

    # 4. AllReduce p和q
    dist.all_reduce(p, group=process_group)
    dist.all_reduce(q, group=process_group)

    # 5. 重建梯度
    reconstructed = torch.mm(p, q.t())

    # 6. 误差反馈
    if state.use_error_feedback:
        state.error_dict[bucket.index()] = matrix - reconstructed

    return torch.futures.Future().set_result(reconstructed.view_as(input_tensor))


# 使用通信钩子
model = DDP(model)
model.register_comm_hook(process_group, fp16_compress_hook)

FSDP 源码解析

FSDP 基本原理

┌─────────────────────────────────────────────────────────────────────────┐
│                        FSDP vs DDP                                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  DDP (DistributedDataParallel):                                         │
│  ══════════════════════════════                                          │
│                                                                          │
│  Rank 0            Rank 1            Rank 2            Rank 3           │
│  ┌──────────┐     ┌──────────┐     ┌──────────┐     ┌──────────┐       │
│  │ 完整模型  │     │ 完整模型  │     │ 完整模型  │     │ 完整模型  │       │
│  │ 完整梯度  │     │ 完整梯度  │     │ 完整梯度  │     │ 完整梯度  │       │
│  │ 完整优化器│     │ 完整优化器│     │ 完整优化器│     │ 完整优化器│       │
│  └──────────┘     └──────────┘     └──────────┘     └──────────┘       │
│                                                                          │
│  每个rank显存占用 = 模型 + 梯度 + 优化器状态 = 16x参数量 (Adam FP32)     │
│                                                                          │
│  FSDP (FullyShardedDataParallel):                                       │
│  ═════════════════════════════════                                       │
│                                                                          │
│  Rank 0            Rank 1            Rank 2            Rank 3           │
│  ┌──────────┐     ┌──────────┐     ┌──────────┐     ┌──────────┐       │
│  │ 模型分片0 │     │ 模型分片1 │     │ 模型分片2 │     │ 模型分片3 │       │
│  │ 梯度分片0 │     │ 梯度分片1 │     │ 梯度分片2 │     │ 梯度分片3 │       │
│  │ 优化器0   │     │ 优化器1   │     │ 优化器2   │     │ 优化器3   │       │
│  └──────────┘     └──────────┘     └──────────┘     └──────────┘       │
│                                                                          │
│  每个rank显存占用 = (模型 + 梯度 + 优化器) / N = 16x参数量 / N          │
│                                                                          │
│  Forward时:                                                              │
│  需要使用某层参数时, AllGather收集完整参数, 使用后丢弃                   │
│                                                                          │
│  Backward时:                                                             │
│  AllGather收集完整参数和梯度, 计算后ReduceScatter到各自分片              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

FSDP 核心实现

# torch/distributed/fsdp/fully_sharded_data_parallel.py

class FullyShardedDataParallel(nn.Module):
    """
    全分片数据并行

    将模型参数、梯度、优化器状态分片到多个GPU
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: ProcessGroup = None,
        sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
        cpu_offload: CPUOffload = None,
        auto_wrap_policy: Optional[Callable] = None,
        backward_prefetch: BackwardPrefetch = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[nn.Module]] = None,
        param_init_fn: Optional[Callable] = None,
        sync_module_states: bool = False,
    ):
        super().__init__()

        self.process_group = process_group or dist.distributed_c10d._get_default_group()
        self.rank = dist.get_rank(self.process_group)
        self.world_size = dist.get_world_size(self.process_group)

        # 分片策略
        self.sharding_strategy = sharding_strategy

        # CPU卸载
        self.cpu_offload = cpu_offload

        # 混合精度
        self.mixed_precision = mixed_precision

        # 1. 自动wrap子模块 (可选)
        if auto_wrap_policy is not None:
            module = self._auto_wrap(module, auto_wrap_policy)

        self._fsdp_wrapped_module = module

        # 2. 收集并分片参数
        self._shard_parameters()

        # 3. 注册forward/backward钩子
        self._register_hooks()

    def _shard_parameters(self):
        """分片模型参数"""
        self._handles = []

        for param_name, param in self._fsdp_wrapped_module.named_parameters():
            # 1. 创建FlatParameter (展平参数)
            flat_param = FlatParameter(param)

            # 2. 计算每个rank的分片大小
            param_numel = flat_param.numel()
            chunk_size = (param_numel + self.world_size - 1) // self.world_size

            # 3. 只保留本rank的分片
            start_idx = self.rank * chunk_size
            end_idx = min(start_idx + chunk_size, param_numel)

            # 4. 创建分片视图
            sharded_param = flat_param.data[start_idx:end_idx].clone()

            # 5. 替换原参数
            setattr(self._fsdp_wrapped_module, param_name,
                   nn.Parameter(sharded_param))

            # 6. 创建Handle管理分片
            handle = FlatParamHandle(
                flat_param=flat_param,
                sharded_param=sharded_param,
                param_name=param_name,
            )
            self._handles.append(handle)

    def _unshard_params(self, handles=None):
        """取消分片, 收集完整参数"""
        if handles is None:
            handles = self._handles

        for handle in handles:
            # AllGather收集所有分片
            full_param = torch.empty(
                handle.flat_param.shape,
                dtype=handle.sharded_param.dtype,
                device=handle.sharded_param.device,
            )

            # 收集所有rank的分片
            dist.all_gather_into_tensor(
                full_param,
                handle.sharded_param,
                group=self.process_group,
            )

            # 更新模型参数为完整参数
            handle.flat_param.data = full_param

    def _reshard_params(self, handles=None):
        """重新分片, 释放非本地参数"""
        if handles is None:
            handles = self._handles

        for handle in handles:
            # 只保留本地分片
            handle.flat_param.data = handle.sharded_param

    def forward(self, *args, **kwargs):
        """
        FSDP forward流程:
        1. AllGather收集完整参数
        2. 执行forward
        3. 释放参数 (重新分片)
        """
        # 1. Unshard参数
        self._unshard_params()

        # 2. Forward
        output = self._fsdp_wrapped_module(*args, **kwargs)

        # 3. Reshard参数 (如果不需要保留用于backward)
        if self.sharding_strategy == ShardingStrategy.FULL_SHARD:
            self._reshard_params()

        return output

FSDP Backward

# torch/distributed/fsdp/fully_sharded_data_parallel.py (续)

class FullyShardedDataParallel(nn.Module):

    def _register_hooks(self):
        """注册forward/backward钩子"""

        def _pre_forward_hook(module, args):
            """Forward前: Unshard参数"""
            if module._is_root:
                self._unshard_params()

        def _post_forward_hook(module, args, output):
            """Forward后: Reshard参数"""
            if self.sharding_strategy == ShardingStrategy.FULL_SHARD:
                self._reshard_params()

        def _pre_backward_hook(module, grad_output):
            """Backward前: Unshard参数 (如果已reshard)"""
            self._unshard_params()

            # 预取下一层的参数 (优化)
            if self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                self._prefetch_next_layer()

        def _post_backward_hook(param, grad):
            """
            Backward后: ReduceScatter梯度

            不是AllReduce! 每个rank只需要自己分片的梯度
            """
            # 1. ReduceScatter: 每个rank得到梯度的不同分片
            grad_sharded = torch.empty(
                self._chunk_size,
                dtype=grad.dtype,
                device=grad.device,
            )

            dist.reduce_scatter_tensor(
                grad_sharded,
                grad,
                group=self.process_group,
            )

            # 2. 更新本地梯度
            param.grad = grad_sharded

            # 3. Reshard参数
            self._reshard_params()

        # 注册钩子
        self.register_forward_pre_hook(_pre_forward_hook)
        self.register_forward_hook(_post_forward_hook)
        # backward钩子通过autograd hooks注册


class FlatParamHandle:
    """
    管理分片参数的Handle
    """

    def __init__(self, flat_param, sharded_param, param_name):
        self.flat_param = flat_param      # 原始完整参数
        self.sharded_param = sharded_param  # 本rank的分片
        self.param_name = param_name

        # 梯度累积
        self.grad_accumulator = None

        # 通信handle (用于异步)
        self._all_gather_work = None
        self._reduce_scatter_work = None

    def all_gather_params(self, process_group, async_op=True):
        """异步AllGather参数"""
        full_param = torch.empty(
            self.flat_param.shape,
            dtype=self.sharded_param.dtype,
            device=self.sharded_param.device,
        )

        work = dist.all_gather_into_tensor(
            full_param,
            self.sharded_param,
            group=process_group,
            async_op=async_op,
        )

        if async_op:
            self._all_gather_work = work
            self._full_param_buffer = full_param
        else:
            self.flat_param.data = full_param

    def wait_all_gather(self):
        """等待AllGather完成"""
        if self._all_gather_work is not None:
            self._all_gather_work.wait()
            self.flat_param.data = self._full_param_buffer
            self._all_gather_work = None

    def reduce_scatter_grad(self, process_group, async_op=True):
        """ReduceScatter梯度"""
        grad_sharded = torch.empty(
            self.sharded_param.shape,
            dtype=self.flat_param.grad.dtype,
            device=self.flat_param.grad.device,
        )

        work = dist.reduce_scatter_tensor(
            grad_sharded,
            self.flat_param.grad,
            group=process_group,
            async_op=async_op,
        )

        if async_op:
            self._reduce_scatter_work = work
            self._grad_sharded_buffer = grad_sharded
        else:
            self.sharded_param.grad = grad_sharded

    def wait_reduce_scatter(self):
        """等待ReduceScatter完成"""
        if self._reduce_scatter_work is not None:
            self._reduce_scatter_work.wait()
            self.sharded_param.grad = self._grad_sharded_buffer
            self._reduce_scatter_work = None

源码调试技巧

启用调试日志

import os
import logging

# 启用PyTorch分布式日志
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

# 或通过代码设置
import torch.distributed as dist
dist.set_debug_level(dist.DebugLevel.DETAIL)

# NCCL调试
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"

# 查看DDP内部状态
model = DDP(model)
print(model._module_copies)
print(model.reducer.buckets)

常用调试工具

# 检查通信状态
def check_dist_state():
    print(f"Rank: {dist.get_rank()}")
    print(f"World Size: {dist.get_world_size()}")
    print(f"Backend: {dist.get_backend()}")
    print(f"Is Initialized: {dist.is_initialized()}")

# 检查DDP状态
def check_ddp_state(model):
    print(f"Device IDs: {model.device_ids}")
    print(f"Output Device: {model.output_device}")
    print(f"Bucket Cap: {model.bucket_bytes_cap}")
    print(f"Find Unused Params: {model.find_unused_parameters}")

# 检查FSDP状态
def check_fsdp_state(model):
    print(f"Sharding Strategy: {model.sharding_strategy}")
    print(f"CPU Offload: {model.cpu_offload}")
    print(f"Mixed Precision: {model.mixed_precision}")

# 同步检查点
def sync_check(msg):
    rank = dist.get_rank()
    print(f"[Rank {rank}] {msg}")
    dist.barrier()

总结

PyTorch分布式核心组件

┌─────────────────────────────────────────────────────────────────────────┐
│                    PyTorch分布式核心组件                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  torch.distributed                                                       │
│  ─────────────────                                                       │
│  • init_process_group: 初始化分布式环境                                  │
│  • ProcessGroup: 抽象通信组                                              │
│  • 集合通信原语: all_reduce, broadcast, all_gather等                     │
│                                                                          │
│  DistributedDataParallel (DDP)                                          │
│  ─────────────────────────────                                           │
│  • 每rank持有完整模型副本                                                │
│  • Reducer管理梯度同步                                                   │
│  • Bucket机制减少通信次数                                                │
│  • 通信钩子支持梯度压缩                                                  │
│                                                                          │
│  FullyShardedDataParallel (FSDP)                                        │
│  ────────────────────────────────                                        │
│  • 参数、梯度、优化器状态全分片                                          │
│  • Forward时AllGather, Backward时ReduceScatter                          │
│  • 支持CPU Offload                                                       │
│  • 显存效率比DDP高N倍                                                    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

面试高频问题

  1. DDP和FSDP的核心区别是什么?
  2. DDP的Bucket机制如何实现通信与计算重叠?
  3. FSDP为什么使用ReduceScatter而不是AllReduce?
  4. ProcessGroupNCCL是如何调用NCCL的?
  5. 如何自定义DDP的通信钩子实现梯度压缩?
Prev
12-框架源码解析
Next
02-DeepSpeed源码深度解析