HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

02-DeepSpeed源码深度解析

概述

DeepSpeed是Microsoft开发的深度学习优化库,提供了ZeRO、3D并行、推理优化等核心技术。本章深入解析DeepSpeed的核心源码实现,帮助理解大规模模型训练的底层机制。

DeepSpeed 架构

整体架构

┌─────────────────────────────────────────────────────────────────────────┐
│                       DeepSpeed 架构                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                         用户 API                                 │    │
│  │                                                                  │    │
│  │  deepspeed.initialize()        model_engine.step()              │    │
│  │  deepspeed.config             model_engine.backward()            │    │
│  │                                                                  │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      DeepSpeed Engine                            │    │
│  │                                                                  │    │
│  │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │    │
│  │  │   ZeRO      │ │   混合精度   │ │   激活检查点 │ │   通信    │ │    │
│  │  │  Optimizer  │ │   FP16/BF16 │ │  Checkpointing│ │  后端    │ │    │
│  │  │  Stage 1-3  │ │   Loss Scale│ │              │ │          │ │    │
│  │  └─────────────┘ └─────────────┘ └─────────────┘ └───────────┘ │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      并行策略层                                  │    │
│  │                                                                  │    │
│  │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │    │
│  │  │ Data Parallel│ │Tensor Parallel│ │Pipeline     │ │ Expert   │ │    │
│  │  │   (DP)      │ │    (TP)      │ │ Parallel(PP)│ │Parallel  │ │    │
│  │  └─────────────┘ └─────────────┘ └─────────────┘ └───────────┘ │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                │                                         │
│  ┌─────────────────────────────┴───────────────────────────────────┐    │
│  │                      通信基础设施                                │    │
│  │                                                                  │    │
│  │         NCCL              torch.distributed              MPI     │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

源码目录结构

deepspeed/
├── __init__.py                 # 入口
├── runtime/
│   ├── engine.py               # DeepSpeedEngine
│   ├── zero/
│   │   ├── stage_1_and_2.py    # ZeRO Stage 1 & 2
│   │   ├── stage3.py           # ZeRO Stage 3
│   │   └── partition_parameters.py
│   ├── pipe/
│   │   └── engine.py           # Pipeline Engine
│   ├── activation_checkpointing/
│   │   └── checkpointing.py
│   └── fp16/
│       └── fused_optimizer.py
├── ops/
│   ├── adam/                   # Fused Adam
│   ├── transformer/            # Fused Transformer
│   └── sparse_attention/       # 稀疏注意力
├── moe/                        # Mixture of Experts
├── inference/                  # 推理优化
└── comm/                       # 通信工具

ZeRO 优化器源码

ZeRO 原理回顾

┌─────────────────────────────────────────────────────────────────────────┐
│                         ZeRO 优化级别                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  显存占用 (以Adam为例, 假设模型参数Ψ):                                   │
│  ═══════════════════════════════════════                                 │
│                                                                          │
│  • 参数 (FP16): 2Ψ bytes                                                │
│  • 梯度 (FP16): 2Ψ bytes                                                │
│  • 优化器状态 (FP32):                                                    │
│    - Adam: 参数副本(4Ψ) + momentum(4Ψ) + variance(4Ψ) = 12Ψ bytes      │
│                                                                          │
│  总计: 2Ψ + 2Ψ + 12Ψ = 16Ψ bytes (每GPU)                               │
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────┐     │
│  │                                                                 │     │
│  │  Stage 0 (无ZeRO):                                              │     │
│  │  每GPU: 参数(2Ψ) + 梯度(2Ψ) + 优化器(12Ψ) = 16Ψ                │     │
│  │                                                                 │     │
│  │  Stage 1 (优化器状态分片):                                      │     │
│  │  每GPU: 参数(2Ψ) + 梯度(2Ψ) + 优化器(12Ψ/N) ≈ 4Ψ + 12Ψ/N      │     │
│  │                                                                 │     │
│  │  Stage 2 (+ 梯度分片):                                          │     │
│  │  每GPU: 参数(2Ψ) + 梯度(2Ψ/N) + 优化器(12Ψ/N) ≈ 2Ψ + 14Ψ/N    │     │
│  │                                                                 │     │
│  │  Stage 3 (+ 参数分片):                                          │     │
│  │  每GPU: 参数(2Ψ/N) + 梯度(2Ψ/N) + 优化器(12Ψ/N) = 16Ψ/N        │     │
│  │                                                                 │     │
│  └────────────────────────────────────────────────────────────────┘     │
│                                                                          │
│  通信开销:                                                               │
│  ────────                                                                │
│  • Stage 1: 与DDP相同 (AllReduce梯度)                                   │
│  • Stage 2: ReduceScatter梯度 + AllGather参数                           │
│  • Stage 3: 每层 AllGather参数 (forward/backward)                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

ZeRO Stage 1 & 2 实现

# deepspeed/runtime/zero/stage_1_and_2.py

class DeepSpeedZeroOptimizer(ZeROOptimizer):
    """
    ZeRO Stage 1 和 Stage 2 优化器
    """

    def __init__(
        self,
        init_optimizer,           # 基础优化器 (如Adam)
        param_names,              # 参数名称
        timers,
        static_loss_scale=1.0,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,  # 500MB
        allgather_bucket_size=500000000,
        dp_process_group=None,
        expert_parallel_group=None,
        expert_data_parallel_group=None,
        reduce_scatter=True,      # True = Stage 2, False = Stage 1
        overlap_comm=False,
        cpu_offload=False,
        mpu=None,
        clip_grad=0.0,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
    ):
        super().__init__()

        self.optimizer = init_optimizer
        self.reduce_scatter = reduce_scatter
        self.overlap_comm = overlap_comm

        # 分布式设置
        self.dp_process_group = dp_process_group
        self.dp_world_size = dist.get_world_size(group=dp_process_group)
        self.dp_rank = dist.get_rank(group=dp_process_group)

        # 参数分组
        self._partition_parameters()

        # 通信bucket
        self.reduce_bucket_size = reduce_bucket_size
        self.allgather_bucket_size = allgather_bucket_size

    def _partition_parameters(self):
        """将优化器状态分片到各rank"""
        all_params = []
        for param_group in self.optimizer.param_groups:
            all_params.extend(param_group['params'])

        # 按rank分配参数
        self.param_partitions = [[] for _ in range(self.dp_world_size)]
        for i, param in enumerate(all_params):
            partition_id = i % self.dp_world_size
            self.param_partitions[partition_id].append(param)

        # 本rank负责的参数
        self.local_params = self.param_partitions[self.dp_rank]

        # 为本地参数创建优化器状态
        self._create_local_optimizer_states()

    def _create_local_optimizer_states(self):
        """只为本rank负责的参数创建优化器状态"""
        # 重置优化器的param_groups
        new_param_groups = []
        for param_group in self.optimizer.param_groups:
            new_group = {k: v for k, v in param_group.items() if k != 'params'}
            new_group['params'] = [p for p in param_group['params']
                                   if p in self.local_params]
            new_param_groups.append(new_group)

        self.optimizer.param_groups = new_param_groups

    def backward(self, loss, retain_graph=False):
        """
        执行backward并同步梯度

        Stage 1: AllReduce梯度, 只在本地参数上更新
        Stage 2: ReduceScatter梯度, 每个rank只保留负责的参数梯度
        """
        # 1. 标准backward
        loss.backward(retain_graph=retain_graph)

        # 2. 梯度同步
        if self.reduce_scatter:
            self._reduce_scatter_gradients()  # Stage 2
        else:
            self._allreduce_gradients()       # Stage 1

    def _allreduce_gradients(self):
        """Stage 1: AllReduce所有梯度"""
        # 收集梯度到bucket
        buckets = self._build_grad_buckets()

        for bucket in buckets:
            # AllReduce
            dist.all_reduce(
                bucket.buffer,
                group=self.dp_process_group
            )

            # 平均
            bucket.buffer.div_(self.dp_world_size)

            # 拷贝回参数梯度
            bucket.copy_back_to_grads()

    def _reduce_scatter_gradients(self):
        """Stage 2: ReduceScatter梯度"""
        # 收集所有参数的梯度
        flat_grads = self._flatten_gradients()

        # ReduceScatter
        # 每个rank只得到自己负责的那部分梯度
        chunk_size = flat_grads.numel() // self.dp_world_size
        output = torch.empty(chunk_size, dtype=flat_grads.dtype,
                           device=flat_grads.device)

        dist.reduce_scatter_tensor(
            output,
            flat_grads,
            group=self.dp_process_group
        )

        # 将分片梯度拷贝到本地参数
        self._copy_grad_partitions(output)

    def step(self):
        """
        执行优化器step

        Stage 1/2: 只在本地参数上执行step, 然后AllGather更新后的参数
        """
        # 1. 梯度裁剪 (在本地分片上)
        if self.clip_grad > 0:
            self._clip_grad_norm()

        # 2. 优化器step (只更新本地参数)
        self.optimizer.step()

        # 3. AllGather更新后的参数
        self._allgather_parameters()

    def _allgather_parameters(self):
        """AllGather所有分片的参数"""
        for param_group in self.original_param_groups:
            for param in param_group['params']:
                # 确定该参数属于哪个rank
                owner_rank = self._get_param_owner(param)

                if owner_rank == self.dp_rank:
                    # 本rank拥有该参数,广播给其他rank
                    dist.broadcast(
                        param.data,
                        src=owner_rank,
                        group=self.dp_process_group
                    )
                else:
                    # 接收其他rank的参数
                    dist.broadcast(
                        param.data,
                        src=owner_rank,
                        group=self.dp_process_group
                    )

ZeRO Stage 3 实现

# deepspeed/runtime/zero/stage3.py

class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
    """
    ZeRO Stage 3: 参数、梯度、优化器状态全分片
    """

    def __init__(
        self,
        module,
        init_optimizer,
        timers,
        ds_config,
        static_loss_scale=1,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        dp_process_group=None,
        reduce_scatter=True,
        overlap_comm=False,
        offload_optimizer_config=None,
        offload_param_config=None,
        sub_group_size=1000000000000,
        mpu=None,
        clip_grad=0.0,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
    ):
        super().__init__()

        self.module = module
        self.optimizer = init_optimizer

        # 分布式设置
        self.dp_process_group = dp_process_group
        self.dp_world_size = dist.get_world_size(group=dp_process_group)
        self.dp_rank = dist.get_rank(group=dp_process_group)

        # 参数分片
        self._partition_all_parameters()

        # 预取配置
        self.prefetch_bucket_size = prefetch_bucket_size
        self.max_reuse_distance = max_reuse_distance

        # CPU Offload配置
        self.offload_optimizer = offload_optimizer_config is not None
        self.offload_param = offload_param_config is not None

        # 注册前向/后向钩子
        self._register_hooks()

    def _partition_all_parameters(self):
        """将所有参数分片"""
        self.param_handles = {}

        for name, param in self.module.named_parameters():
            # 1. 展平参数
            flat_param = param.data.flatten()

            # 2. 计算每个rank的分片大小
            padded_size = self._get_padded_size(flat_param.numel())
            chunk_size = padded_size // self.dp_world_size

            # 3. 只保留本rank的分片
            start = self.dp_rank * chunk_size
            end = start + chunk_size

            if start < flat_param.numel():
                # 有效分片
                local_end = min(end, flat_param.numel())
                local_param = flat_param[start:local_end].clone()

                # 填充到chunk_size
                if local_param.numel() < chunk_size:
                    local_param = F.pad(local_param,
                                       (0, chunk_size - local_param.numel()))
            else:
                # 纯填充分片
                local_param = torch.zeros(chunk_size, dtype=param.dtype,
                                        device=param.device)

            # 4. 创建分片参数Handle
            handle = PartitionedParameterHandle(
                param=param,
                param_name=name,
                flat_param=flat_param,
                local_param=local_param,
                dp_rank=self.dp_rank,
                dp_world_size=self.dp_world_size,
            )
            self.param_handles[name] = handle

            # 5. 释放原始参数
            param.data = torch.empty(0)

    def _register_hooks(self):
        """注册前向/后向钩子"""

        def _pre_forward_hook(module, inputs):
            """Forward前: AllGather该层的参数"""
            for name, param in module.named_parameters(recurse=False):
                if name in self.param_handles:
                    self._allgather_param(name)

        def _post_forward_hook(module, inputs, outputs):
            """Forward后: 释放非持久化参数"""
            for name, param in module.named_parameters(recurse=False):
                if name in self.param_handles:
                    handle = self.param_handles[name]
                    if not handle.persistent:
                        self._release_param(name)

        def _pre_backward_hook(module, grad_output):
            """Backward前: AllGather该层的参数"""
            for name, param in module.named_parameters(recurse=False):
                if name in self.param_handles:
                    self._allgather_param(name)

        def _post_backward_hook(param):
            """Backward后: ReduceScatter梯度"""
            handle = self.param_handles.get(param._name)
            if handle:
                self._reduce_scatter_grad(param._name)
                self._release_param(param._name)

        # 为每个子模块注册钩子
        for module in self.module.modules():
            module.register_forward_pre_hook(_pre_forward_hook)
            module.register_forward_hook(_post_forward_hook)
            module.register_full_backward_pre_hook(_pre_backward_hook)

        # 为每个参数注册梯度钩子
        for name, param in self.module.named_parameters():
            param._name = name
            param.register_hook(_post_backward_hook)

    def _allgather_param(self, param_name):
        """AllGather收集完整参数"""
        handle = self.param_handles[param_name]

        # 1. 分配完整参数buffer
        full_param = torch.empty(
            handle.padded_size,
            dtype=handle.local_param.dtype,
            device=handle.local_param.device
        )

        # 2. AllGather
        dist.all_gather_into_tensor(
            full_param,
            handle.local_param,
            group=self.dp_process_group
        )

        # 3. 去除padding, reshape回原始形状
        full_param = full_param[:handle.original_numel]
        full_param = full_param.view(handle.original_shape)

        # 4. 更新参数
        handle.param.data = full_param

    def _release_param(self, param_name):
        """释放完整参数,只保留本地分片"""
        handle = self.param_handles[param_name]
        handle.param.data = torch.empty(0)

    def _reduce_scatter_grad(self, param_name):
        """ReduceScatter梯度到分片"""
        handle = self.param_handles[param_name]

        # 1. 展平梯度
        flat_grad = handle.param.grad.flatten()

        # 2. 填充到可整除大小
        if flat_grad.numel() < handle.padded_size:
            flat_grad = F.pad(flat_grad,
                             (0, handle.padded_size - flat_grad.numel()))

        # 3. ReduceScatter
        local_grad = torch.empty(
            handle.chunk_size,
            dtype=flat_grad.dtype,
            device=flat_grad.device
        )

        dist.reduce_scatter_tensor(
            local_grad,
            flat_grad,
            group=self.dp_process_group
        )

        # 4. 保存本地梯度
        handle.local_grad = local_grad

    def step(self):
        """执行优化器step"""
        # 1. 在CPU上执行优化器step (如果CPU offload)
        if self.offload_optimizer:
            self._step_on_cpu()
        else:
            self._step_on_gpu()

    def _step_on_gpu(self):
        """在GPU上执行step"""
        for name, handle in self.param_handles.items():
            # 使用本地分片的参数和梯度
            local_param_fp32 = handle.local_param.float()
            local_grad_fp32 = handle.local_grad.float()

            # 更新优化器状态
            state = self.optimizer.state[handle.param]

            # Adam更新
            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(local_param_fp32)
                state['exp_avg_sq'] = torch.zeros_like(local_param_fp32)
                state['step'] = 0

            state['step'] += 1
            beta1, beta2 = self.optimizer.param_groups[0]['betas']
            lr = self.optimizer.param_groups[0]['lr']
            eps = self.optimizer.param_groups[0]['eps']

            # 更新momentum和variance
            state['exp_avg'].mul_(beta1).add_(local_grad_fp32, alpha=1 - beta1)
            state['exp_avg_sq'].mul_(beta2).addcmul_(
                local_grad_fp32, local_grad_fp32, value=1 - beta2
            )

            # 偏置修正
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            # 参数更新
            denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
            step_size = lr / bias_correction1
            local_param_fp32.addcdiv_(state['exp_avg'], denom, value=-step_size)

            # 转回FP16
            handle.local_param.copy_(local_param_fp32.half())


class PartitionedParameterHandle:
    """管理分片参数"""

    def __init__(self, param, param_name, flat_param, local_param,
                 dp_rank, dp_world_size):
        self.param = param
        self.param_name = param_name
        self.flat_param = flat_param
        self.local_param = local_param
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size

        # 原始形状和大小
        self.original_shape = param.shape
        self.original_numel = param.numel()

        # 填充后的大小
        self.padded_size = self._get_padded_size()
        self.chunk_size = self.padded_size // dp_world_size

        # 梯度分片
        self.local_grad = None

        # 是否持久化 (小参数可以不释放)
        self.persistent = param.numel() < 100000  # 100K以下持久化

    def _get_padded_size(self):
        """计算填充后的大小 (可被world_size整除)"""
        size = self.original_numel
        if size % self.dp_world_size != 0:
            size = ((size // self.dp_world_size) + 1) * self.dp_world_size
        return size

CPU Offload 实现

# deepspeed/runtime/zero/stage3.py (CPU Offload部分)

class DeepSpeedZeroOptimizer_Stage3:

    def _setup_cpu_offload(self):
        """设置CPU Offload"""
        if self.offload_optimizer:
            # 将优化器状态移到CPU
            for handle in self.param_handles.values():
                state = self.optimizer.state.get(handle.param, {})
                for key, val in state.items():
                    if torch.is_tensor(val):
                        state[key] = val.cpu().pin_memory()

        if self.offload_param:
            # 将参数分片移到CPU
            for handle in self.param_handles.values():
                handle.local_param = handle.local_param.cpu().pin_memory()

    def _step_on_cpu(self):
        """在CPU上执行优化器step"""
        # 1. 将梯度移到CPU
        for handle in self.param_handles.values():
            if handle.local_grad is not None:
                handle.local_grad_cpu = handle.local_grad.cpu()

        # 2. 在CPU上执行Adam更新
        # 使用numpy或直接PyTorch CPU计算
        for name, handle in self.param_handles.items():
            local_param_fp32 = handle.local_param.float()
            local_grad_fp32 = handle.local_grad_cpu.float()

            state = self.optimizer.state[handle.param]

            # Adam更新 (CPU)
            # ... (与GPU版本相同的逻辑)

            # 更新本地参数
            handle.local_param.copy_(local_param_fp32.half())

        # 3. 异步拷贝回GPU
        for handle in self.param_handles.values():
            handle.local_param_gpu = handle.local_param.cuda(non_blocking=True)

    def _prefetch_params_to_gpu(self, layer_id):
        """预取参数到GPU"""
        # 在执行当前层时,提前将下一层的参数从CPU移到GPU
        next_layer_params = self._get_layer_params(layer_id + 1)

        for param_name in next_layer_params:
            handle = self.param_handles.get(param_name)
            if handle and handle.local_param.device.type == 'cpu':
                # 异步拷贝到GPU
                handle.local_param_gpu = handle.local_param.cuda(non_blocking=True)

Pipeline Parallel 源码

Pipeline Engine 实现

# deepspeed/runtime/pipe/engine.py

class PipelineEngine(DeepSpeedEngine):
    """
    Pipeline Parallel Engine

    实现 GPipe 和 1F1B 调度
    """

    def __init__(
        self,
        model,
        config,
        mpu,
        **kwargs
    ):
        super().__init__(model, config, **kwargs)

        # Pipeline并行设置
        self.pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
        self.pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
        self.num_stages = self.pipeline_parallel_size

        # micro-batch设置
        self.micro_batches = config.gradient_accumulation_steps

        # 通信buffer
        self._allocate_buffers()

        # 调度器
        self.schedule = self._build_schedule()

    def _build_schedule(self):
        """构建Pipeline调度表"""
        if self.config.pipeline_schedule == "1f1b":
            return self._build_1f1b_schedule()
        else:
            return self._build_gpipe_schedule()

    def _build_1f1b_schedule(self):
        """
        构建 1F1B 调度

        1F1B (One Forward One Backward):
        - 先执行足够多的forward填满pipeline
        - 然后交替执行forward和backward
        - 最后执行剩余的backward
        """
        schedule = []
        num_warmup_microbatches = min(
            self.num_stages - self.pipeline_parallel_rank - 1,
            self.micro_batches
        )
        num_microbatches_remaining = self.micro_batches - num_warmup_microbatches

        # Warmup phase: 只有forward
        for i in range(num_warmup_microbatches):
            schedule.append(('forward', i))

        # Steady state: 1F1B
        for i in range(num_microbatches_remaining):
            schedule.append(('forward', num_warmup_microbatches + i))
            schedule.append(('backward', i))

        # Cooldown phase: 只有backward
        for i in range(num_warmup_microbatches):
            schedule.append(('backward', num_microbatches_remaining + i))

        return schedule

    def _build_gpipe_schedule(self):
        """
        构建 GPipe 调度

        GPipe: 先执行所有forward, 再执行所有backward
        """
        schedule = []

        # 所有forward
        for i in range(self.micro_batches):
            schedule.append(('forward', i))

        # 所有backward
        for i in range(self.micro_batches):
            schedule.append(('backward', i))

        return schedule

    def train_batch(self, data_iter):
        """执行一个batch的训练"""
        # 1. 准备micro-batches
        micro_batches = self._prepare_micro_batches(data_iter)

        # 2. 执行调度
        losses = []
        for action, micro_batch_id in self.schedule:
            if action == 'forward':
                loss = self._exec_forward_pass(micro_batches[micro_batch_id])
                losses.append(loss)
            else:  # backward
                self._exec_backward_pass(micro_batch_id)

        # 3. 梯度同步
        if self.is_data_parallel:
            self._sync_gradients()

        # 4. 优化器step
        self.optimizer.step()
        self.optimizer.zero_grad()

        return sum(losses) / len(losses)

    def _exec_forward_pass(self, micro_batch):
        """执行forward pass"""
        # 1. 接收来自上一个stage的输入
        if self.pipeline_parallel_rank > 0:
            input_tensor = self._recv_forward()
        else:
            input_tensor = micro_batch

        # 2. 执行本stage的forward
        with torch.cuda.amp.autocast(enabled=self.fp16_enabled):
            output_tensor = self.module(input_tensor)

        # 3. 发送输出到下一个stage
        if self.pipeline_parallel_rank < self.num_stages - 1:
            self._send_forward(output_tensor)

        # 4. 保存用于backward
        self._save_activation(output_tensor)

        # 5. 计算loss (只有最后一个stage)
        if self.is_last_stage:
            loss = self.loss_fn(output_tensor, micro_batch['labels'])
            return loss
        return None

    def _exec_backward_pass(self, micro_batch_id):
        """执行backward pass"""
        # 1. 接收来自下一个stage的梯度
        if self.pipeline_parallel_rank < self.num_stages - 1:
            output_grad = self._recv_backward()
        else:
            # 最后一个stage: 从loss计算梯度
            output_grad = None

        # 2. 加载保存的activation
        output_tensor = self._load_activation(micro_batch_id)

        # 3. 执行backward
        if output_grad is not None:
            output_tensor.backward(output_grad)
        else:
            output_tensor.backward()

        # 4. 发送梯度到上一个stage
        if self.pipeline_parallel_rank > 0:
            input_grad = self._get_input_grad()
            self._send_backward(input_grad)

    def _recv_forward(self):
        """接收来自上一个stage的activation"""
        src_rank = self.pipeline_parallel_rank - 1
        recv_buffer = self._get_recv_buffer()

        dist.recv(recv_buffer, src=src_rank, group=self.pipeline_group)
        return recv_buffer

    def _send_forward(self, tensor):
        """发送activation到下一个stage"""
        dst_rank = self.pipeline_parallel_rank + 1
        dist.send(tensor, dst=dst_rank, group=self.pipeline_group)

    def _recv_backward(self):
        """接收来自下一个stage的梯度"""
        src_rank = self.pipeline_parallel_rank + 1
        recv_buffer = self._get_recv_buffer()

        dist.recv(recv_buffer, src=src_rank, group=self.pipeline_group)
        return recv_buffer

    def _send_backward(self, tensor):
        """发送梯度到上一个stage"""
        dst_rank = self.pipeline_parallel_rank - 1
        dist.send(tensor, dst=dst_rank, group=self.pipeline_group)

1F1B 调度图解

┌─────────────────────────────────────────────────────────────────────────┐
│                        1F1B Pipeline 调度                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  假设: 4个Stage, 8个micro-batch                                         │
│                                                                          │
│  Time →                                                                  │
│                                                                          │
│  Stage 0: F0 F1 F2 F3 F4 B0 F5 B1 F6 B2 F7 B3    B4 B5 B6 B7           │
│  Stage 1:    F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7           │
│  Stage 2:       F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 F6 B4 F7 B5 B6 B7        │
│  Stage 3:          F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5 F6 B6 F7 B7     │
│                                                                          │
│            ╔═══════╗╔═══════════════════════╗╔═══════╗                  │
│            ║Warmup ║║   Steady State 1F1B   ║║Cooldown║                  │
│            ╚═══════╝╚═══════════════════════╝╚═══════╝                  │
│                                                                          │
│  F = Forward, B = Backward                                              │
│                                                                          │
│  特点:                                                                   │
│  • Warmup: 填满pipeline                                                  │
│  • Steady State: 每个stage同时执行forward和backward                      │
│  • Cooldown: 清空pipeline                                                │
│  • Bubble最小化: 只有warmup和cooldown有bubble                            │
│                                                                          │
│  Bubble比例 ≈ (p-1) / m                                                  │
│  其中 p = pipeline stages, m = micro-batches                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Activation Checkpointing 源码

# deepspeed/runtime/activation_checkpointing/checkpointing.py

class CheckpointFunction(torch.autograd.Function):
    """
    激活检查点实现

    Forward时不保存中间激活,只保存输入
    Backward时重新计算forward得到中间激活
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        # 保存计算函数
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # 保存输入 (用于backward时重计算)
        ctx.save_for_backward(*args)

        # 保存RNG状态 (确保重计算结果一致)
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.fwd_gpu_devices = [torch.cuda.current_device()]
            ctx.fwd_gpu_states = []
            for device in ctx.fwd_gpu_devices:
                ctx.fwd_gpu_states.append(torch.cuda.get_rng_state(device))

        # 在no_grad下执行forward (不保存计算图)
        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *output_grads):
        # 恢复输入
        inputs = ctx.saved_tensors

        # 恢复RNG状态
        if ctx.preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_state)
            for device, state in zip(ctx.fwd_gpu_devices, ctx.fwd_gpu_states):
                torch.cuda.set_rng_state(state, device)

        # 重新计算forward (这次保存计算图)
        with torch.enable_grad():
            # detach输入并启用梯度
            inputs_with_grad = [inp.detach().requires_grad_(inp.requires_grad)
                               for inp in inputs]
            outputs = ctx.run_function(*inputs_with_grad)

        # 如果outputs不是tuple,转换为tuple
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # 计算梯度
        input_grads = torch.autograd.grad(
            outputs,
            inputs_with_grad,
            output_grads,
            allow_unused=True
        )

        return (None, None) + input_grads


def checkpoint(function, *args, **kwargs):
    """
    使用激活检查点包装函数

    Usage:
        output = checkpoint(transformer_layer, hidden_states, attention_mask)
    """
    preserve_rng_state = kwargs.get('preserve_rng_state', True)
    return CheckpointFunction.apply(function, preserve_rng_state, *args)


# 高级API: 分段检查点
def checkpoint_sequential(functions, segments, input, **kwargs):
    """
    将多个函数分段检查点

    Args:
        functions: 函数列表 (如transformer layers)
        segments: 分段数量
        input: 输入tensor
    """

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end):
                input = functions[j](input)
            return input
        return forward

    # 计算每段的函数数量
    segment_size = len(functions) // segments

    # 对每段应用checkpoint
    for start in range(0, len(functions), segment_size):
        end = min(start + segment_size, len(functions))
        input = checkpoint(run_function(start, end, functions), input, **kwargs)

    return input

DeepSpeed 初始化流程

# deepspeed/__init__.py

def initialize(
    args=None,
    model=None,
    optimizer=None,
    model_parameters=None,
    training_data=None,
    lr_scheduler=None,
    mpu=None,
    dist_init_required=True,
    collate_fn=None,
    config=None,
    config_params=None,
):
    """
    DeepSpeed初始化入口

    Returns:
        tuple: (engine, optimizer, dataloader, lr_scheduler)
    """
    # 1. 解析配置
    ds_config = DeepSpeedConfig(config, config_params)

    # 2. 初始化分布式环境
    if dist_init_required:
        init_distributed()

    # 3. 创建模型引擎
    if ds_config.pipeline_enabled:
        # Pipeline Parallel
        engine = PipelineEngine(
            model=model,
            config=ds_config,
            optimizer=optimizer,
            model_parameters=model_parameters,
            mpu=mpu,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
        )
    elif ds_config.zero_enabled:
        # ZeRO
        if ds_config.zero_stage == 3:
            engine = DeepSpeedEngine(
                model=model,
                config=ds_config,
                optimizer=DeepSpeedZeroOptimizer_Stage3(...),
                ...
            )
        else:
            engine = DeepSpeedEngine(
                model=model,
                config=ds_config,
                optimizer=DeepSpeedZeroOptimizer(...),
                ...
            )
    else:
        # 基础DeepSpeed
        engine = DeepSpeedEngine(
            model=model,
            config=ds_config,
            optimizer=optimizer,
            model_parameters=model_parameters,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
        )

    # 4. 创建数据加载器
    dataloader = engine.create_data_loader(training_data, collate_fn)

    return engine, engine.optimizer, dataloader, engine.lr_scheduler


def init_distributed():
    """初始化分布式环境"""
    if not dist.is_initialized():
        # 从环境变量获取配置
        backend = os.environ.get('DISTRIBUTED_BACKEND', 'nccl')
        init_method = os.environ.get('INIT_METHOD', 'env://')

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
        )

        # 设置当前设备
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        torch.cuda.set_device(local_rank)

总结

DeepSpeed 核心组件

┌─────────────────────────────────────────────────────────────────────────┐
│                     DeepSpeed 核心组件                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ZeRO Optimizer                                                          │
│  ──────────────                                                          │
│  • Stage 1: 优化器状态分片                                               │
│  • Stage 2: + 梯度分片 (ReduceScatter)                                  │
│  • Stage 3: + 参数分片 (AllGather forward/backward)                     │
│  • CPU Offload: 将状态卸载到CPU                                          │
│                                                                          │
│  Pipeline Engine                                                         │
│  ───────────────                                                         │
│  • GPipe: 先所有F, 后所有B                                               │
│  • 1F1B: 最小化bubble                                                    │
│  • P2P通信: send/recv activation                                        │
│                                                                          │
│  Activation Checkpointing                                                │
│  ────────────────────────                                                │
│  • 不保存中间激活                                                        │
│  • Backward时重计算                                                      │
│  • 显存换计算时间                                                        │
│                                                                          │
│  混合精度训练                                                            │
│  ────────────                                                            │
│  • FP16/BF16计算                                                         │
│  • 动态Loss Scaling                                                      │
│  • FP32优化器状态                                                        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

面试高频问题

  1. ZeRO Stage 1/2/3 的区别和通信开销?
  2. 为什么Stage 3用ReduceScatter而不是AllReduce?
  3. 1F1B调度如何最小化pipeline bubble?
  4. Activation Checkpointing如何在backward时重计算?
  5. CPU Offload是如何实现的?
Prev
01-PyTorch分布式源码解析
Next
03-Megatron-LM源码解析