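Distributed Training Frameworks

Overview

Model sizes have grown explosively (GPT-4 is estimated to exceed 1.7 trillion parameters, although that figure has never been officially confirmed), and training on a single machine can no longer keep up. Distributed training spreads the model and the data across multiple GPUs and nodes so they can be processed in parallel, and it is the core technique for training large-scale AI models. This chapter covers the principles of distributed training, how the mainstream frameworks implement them, and how to deploy them on Kubernetes.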

1. Distributed Training Basics

1.1 Parallelism Strategies

┌──────────────────────────────────────────────────────────────────────────┐
│                     Distributed Training Strategies                       │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                           │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                        Data Parallelism                             │ │
│  │                                                                     │ │
│  │    Model Copy       Model Copy       Model Copy       Model Copy    │ │
│  │   ┌─────────┐      ┌─────────┐      ┌─────────┐      ┌─────────┐   │ │
│  │   │  GPU 0  │      │  GPU 1  │      │  GPU 2  │      │  GPU 3  │   │ │
│  │   └────┬────┘      └────┬────┘      └────┬────┘      └────┬────┘   │ │
│  │        │                │                │                │        │ │
│  │   Data Batch 0     Data Batch 1     Data Batch 2     Data Batch 3  │ │
│  │                         │                                          │ │
│  │                    AllReduce Gradients                             │ │
│  └─────────────────────────────────────────────────────────────────────┘ │
│                                                                           │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                       Model Parallelism                             │ │
│  │                                                                     │ │
│  │  ┌──────────────────────────────────────────────────────────────┐  │ │
│  │  │                    Pipeline Parallelism                       │  │ │
│  │  │                                                               │  │ │
│  │  │   GPU 0        GPU 1        GPU 2        GPU 3               │  │ │
│  │  │  ┌─────┐      ┌─────┐      ┌─────┐      ┌─────┐              │  │ │
│  │  │  │Layer│ ───▶ │Layer│ ───▶ │Layer│ ───▶ │Layer│              │  │ │
│  │  │  │ 1-6 │      │7-12 │      │13-18│      │19-24│              │  │ │
│  │  │  └─────┘      └─────┘      └─────┘      └─────┘              │  │ │
│  │  │                                                               │  │ │
│  │  │  Stage 0      Stage 1      Stage 2      Stage 3              │  │ │
│  │  └──────────────────────────────────────────────────────────────┘  │ │
│  │                                                                     │ │
│  │  ┌──────────────────────────────────────────────────────────────┐  │ │
│  │  │                    Tensor Parallelism                         │  │ │
│  │  │                                                               │  │ │
│  │  │        Matrix A (4096 x 4096)                                 │  │ │
│  │  │   ┌──────────┬──────────┬──────────┬──────────┐              │  │ │
│  │  │   │   GPU 0  │   GPU 1  │   GPU 2  │   GPU 3  │              │  │ │
│  │  │   │ Col 0-1K │ Col 1-2K │ Col 2-3K │ Col 3-4K │              │  │ │
│  │  │   └──────────┴──────────┴──────────┴──────────┘              │  │ │
│  │  │                                                               │  │ │
│  │  │        AllGather / ReduceScatter                              │  │ │
│  │  └──────────────────────────────────────────────────────────────┘  │ │
│  └─────────────────────────────────────────────────────────────────────┘ │
│                                                                           │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                    3D/4D Hybrid Parallelism                         │ │
│  │                                                                     │ │
│  │   ┌─────────────────────────────────────────────────────────────┐  │ │
│  │   │  Data Parallel Group 0           Data Parallel Group 1      │  │ │
│  │   │  ┌───────────────────┐          ┌───────────────────┐       │  │ │
│  │   │  │ Pipeline Stage 0  │          │ Pipeline Stage 0  │       │  │ │
│  │   │  │ TP: GPU 0,1       │          │ TP: GPU 4,5       │       │  │ │
│  │   │  ├───────────────────┤          ├───────────────────┤       │  │ │
│  │   │  │ Pipeline Stage 1  │          │ Pipeline Stage 1  │       │  │ │
│  │   │  │ TP: GPU 2,3       │          │ TP: GPU 6,7       │       │  │ │
│  │   │  └───────────────────┘          └───────────────────┘       │  │ │
│  │   └─────────────────────────────────────────────────────────────┘  │ │
│  └─────────────────────────────────────────────────────────────────────┘ │
│                                                                           │
└──────────────────────────────────────────────────────────────────────────┘
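The data-parallel panel of the diagram rests on one fact: when every replica holds the same weights and an equal-sized shard of the batch, the average of the per-replica gradients equals the gradient of the full batch. The sketch below checks this for a one-parameter least-squares model in plain Python. The function names and the toy data are illustrative assumptions, not part of any framework.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, num_replicas):
    """Average of per-replica gradients, i.e. what an AllReduce(mean) yields."""
    assert len(xs) % num_replicas == 0, "sketch assumes equal-sized shards"
    shard = len(xs) // num_replicas
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(num_replicas)
    ]
    return sum(local_grads) / num_replicas


# Toy batch of 8 samples, split across 4 hypothetical GPUs
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

full = grad_mse(0.5, xs, ys)
sharded = data_parallel_grad(0.5, xs, ys, num_replicas=4)
print(abs(full - sharded) < 1e-9)
```

If the shards had different sizes, a plain mean of the local gradients would no longer match the full-batch gradient, which is one reason data-parallel samplers pad the dataset to a multiple of the replica count.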

1.2 Communication Primitives

// pkg/distributed/collective.go
package distributed

import (
    "math"
)

// CollectiveOp identifies a collective communication operation.
type CollectiveOp string

const (
    // OpAllReduce: every process ends up with the same reduced result.
    OpAllReduce CollectiveOp = "AllReduce"
    // OpAllGather: every process ends up with the concatenation of all data.
    OpAllGather CollectiveOp = "AllGather"
    // OpReduceScatter: reduce, then scatter the result across processes.
    OpReduceScatter CollectiveOp = "ReduceScatter"
    // OpBroadcast: one process sends, all processes receive.
    OpBroadcast CollectiveOp = "Broadcast"
    // OpReduce: all processes send, one process receives the reduced result.
    OpReduce CollectiveOp = "Reduce"
    // OpScatter: one process sends a different piece of data to each process.
    OpScatter CollectiveOp = "Scatter"
    // OpGather: all processes send, one process receives the concatenation.
    OpGather CollectiveOp = "Gather"
    // OpAllToAll: every process exchanges data with every other process.
    OpAllToAll CollectiveOp = "AllToAll"
)

// CollectiveAlgorithm identifies a collective communication algorithm.
type CollectiveAlgorithm string

const (
    // AlgoRing is the ring algorithm.
    AlgoRing CollectiveAlgorithm = "Ring"
    // AlgoTree is the tree algorithm.
    AlgoTree CollectiveAlgorithm = "Tree"
    // AlgoRecursiveHalving is recursive halving.
    AlgoRecursiveHalving CollectiveAlgorithm = "RecursiveHalving"
    // AlgoRecursiveDoubling is recursive doubling.
    AlgoRecursiveDoubling CollectiveAlgorithm = "RecursiveDoubling"
    // AlgoBinaryBlock is the binary-blocks algorithm.
    AlgoBinaryBlock CollectiveAlgorithm = "BinaryBlock"
)

// CommunicationCost is an alpha-beta-gamma communication cost model.
type CommunicationCost struct {
    // Startup latency per message (seconds)
    Alpha float64
    // Transfer time per byte (seconds/byte)
    Beta float64
    // Reduction compute time per byte (seconds/byte)
    Gamma float64
}

// CalculateRingAllReduceTime estimates the time of a ring AllReduce.
// Parameters: n = number of processes, m = data size in bytes.
func (c *CommunicationCost) CalculateRingAllReduceTime(n int, m int64) float64 {
    // Ring AllReduce has two phases:
    // 1. Reduce-Scatter: (n-1) steps, each transferring m/n bytes
    // 2. All-Gather: (n-1) steps, each transferring m/n bytes
    // Total time = 2(n-1) * (α + m/n * β) + (n-1) * m/n * γ

    steps := float64(n - 1)
    chunkSize := float64(m) / float64(n)

    // Communication time
    commTime := 2 * steps * (c.Alpha + chunkSize*c.Beta)
    // Compute time (the reduction itself, Reduce-Scatter phase only)
    computeTime := steps * chunkSize * c.Gamma

    return commTime + computeTime
}

// CalculateTreeAllReduceTime estimates the time of a tree AllReduce.
func (c *CommunicationCost) CalculateTreeAllReduceTime(n int, m int64) float64 {
    // Tree AllReduce:
    // 1. Reduce: log2(n) steps, each transferring m bytes
    // 2. Broadcast: log2(n) steps, each transferring m bytes
    // Total time = 2 * log2(n) * (α + m * β) + log2(n) * m * γ

    steps := math.Log2(float64(n))
    dataSize := float64(m)

    commTime := 2 * steps * (c.Alpha + dataSize*c.Beta)
    computeTime := steps * dataSize * c.Gamma

    return commTime + computeTime
}

// SelectBestAlgorithm picks the algorithm with the lowest estimated cost.
func (c *CommunicationCost) SelectBestAlgorithm(op CollectiveOp, n int, m int64) CollectiveAlgorithm {
    switch op {
    case OpAllReduce:
        ringTime := c.CalculateRingAllReduceTime(n, m)
        treeTime := c.CalculateTreeAllReduceTime(n, m)

        if ringTime < treeTime {
            return AlgoRing
        }
        return AlgoTree

    case OpBroadcast, OpReduce:
        // Tree algorithms are more efficient for broadcast/reduce
        return AlgoTree

    case OpAllGather, OpReduceScatter:
        // Ring algorithms achieve higher bandwidth utilization
        return AlgoRing

    default:
        return AlgoRing
    }
}
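The cost model above counts steps but does not show what moves in each one. As a rough illustration, the following pure-Python simulation runs the two phases of a ring AllReduce (Reduce-Scatter, then All-Gather) on in-memory lists. It is a sketch under stated assumptions: `ring_allreduce` is a hypothetical helper, the reduction is a sum, and the buffer length is assumed to be divisible by the number of ranks.

```python
def ring_allreduce(buffers):
    """Simulate a sum ring AllReduce over per-rank lists of equal length."""
    n = len(buffers)
    length = len(buffers[0])
    assert length % n == 0, "sketch assumes the buffer splits into n equal chunks"
    chunk = length // n
    data = [list(b) for b in buffers]  # work on copies

    def span(idx):
        return range(idx * chunk, (idx + 1) * chunk)

    # Phase 1: Reduce-Scatter, (n-1) steps. In step s, rank r sends chunk
    # (r - s) mod n to its right neighbor, which accumulates it.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            idx = (r - step) % n
            msgs.append(((r + 1) % n, idx, [data[r][i] for i in span(idx)]))
        for dst, idx, payload in msgs:  # deliver after all sends are captured
            for off, i in enumerate(span(idx)):
                data[dst][i] += payload[off]

    # Rank r now owns the fully reduced chunk (r + 1) mod n.
    # Phase 2: All-Gather, (n-1) steps. In step s, rank r forwards chunk
    # (r + 1 - s) mod n to its right neighbor, which overwrites its copy.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            idx = (r + 1 - step) % n
            msgs.append(((r + 1) % n, idx, [data[r][i] for i in span(idx)]))
        for dst, idx, payload in msgs:
            for off, i in enumerate(span(idx)):
                data[dst][i] = payload[off]

    return data


bufs = [[r * 10 + i for i in range(8)] for r in range(4)]
result = ring_allreduce(bufs)
expected = [sum(col) for col in zip(*bufs)]
print(all(row == expected for row in result))
```

Each rank sends one chunk of m/n elements per step for 2(n-1) steps, which is where the `2(n-1) * (α + m/n * β)` term in `CalculateRingAllReduceTime` comes from.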

1.3 Data Parallelism Implementation

# distributed_data_parallel.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

class DataParallelTrainer:
    """Data-parallel trainer."""

    def __init__(self, model, optimizer_class, lr, backend='nccl'):
        self.backend = backend
        self.lr = lr
        self.optimizer_class = optimizer_class

        # Initialize the distributed environment
        self._init_distributed()

        # Wrap the model
        self.model = self._setup_model(model)

        # Create the optimizer
        self.optimizer = optimizer_class(self.model.parameters(), lr=lr)

    def _init_distributed(self):
        """Initialize the distributed environment."""
        # Read the configuration from environment variables (set by torchrun)
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))

        # Initialize the process group
        if not dist.is_initialized():
            dist.init_process_group(
                backend=self.backend,
                init_method='env://'
            )

        # Select the device for this process
        torch.cuda.set_device(self.local_rank)

        print(f"Initialized rank {self.rank}/{self.world_size}, "
              f"local_rank {self.local_rank}")

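    def _setup_model(self, model):
        """Set up the distributed model."""
        # Move the model to this process's GPU
        model = model.cuda(self.local_rank)

        # Wrap with DDP
        model = DDP(
            model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            # Gradient bucket size (MB)
            bucket_cap_mb=25,
            # Whether to search for parameters that received no gradient
            find_unused_parameters=False,
            # Make gradients views into the buckets (saves memory)
            gradient_as_bucket_view=True,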
        )

        return model

    def train_step(self, batch):
        """Run a single training step."""
        self.model.train()

        inputs, targets = batch
        inputs = inputs.cuda(self.local_rank, non_blocking=True)
        targets = targets.cuda(self.local_rank, non_blocking=True)

        # Forward pass
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = nn.functional.cross_entropy(outputs, targets)

        # Backward pass (DDP all-reduces the gradients automatically)
        loss.backward()

        # Update the parameters
        self.optimizer.step()

        return loss.item()

    def save_checkpoint(self, path, epoch):
        """Save a checkpoint (rank 0 only)."""
        if self.rank != 0:
            return

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.module.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path):
        """Load a checkpoint."""
        # Use map_location so tensors land on this process's device
        checkpoint = torch.load(
            path,
            map_location=f'cuda:{self.local_rank}'
        )

        self.model.module.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        return checkpoint['epoch']


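class ZeRODataParallel:
    """ZeRO-optimized data parallelism."""

    def __init__(self, model, optimizer_class, lr, zero_stage=2):
        """
        ZeRO stages:
        - Stage 1: shard optimizer states
        - Stage 2: shard optimizer states + gradients
        - Stage 3: shard optimizer states + gradients + parameters
        """
        self.zero_stage = zero_stage
        self.optimizer_class = optimizer_class
        self.lr = lr
        self._init_distributed()

        # Implement ZeRO with DeepSpeed or FSDP
        if zero_stage == 3:
            self.model = self._setup_fsdp(model)
            # Build the optimizer after wrapping so it sees the sharded parameters
            self.optimizer = optimizer_class(self.model.parameters(), lr=lr)
        else:
            self.model = self._setup_deepspeed(model, zero_stage)

    def _init_distributed(self):
        """Read the torchrun environment and join the process group."""
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', init_method='env://')
        torch.cuda.set_device(self.local_rank)

    def _setup_fsdp(self, model):
        """Use PyTorch FSDP (ZeRO-3)."""
        import functools
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            MixedPrecision,
            BackwardPrefetch,
            ShardingStrategy,
        )
        from torch.distributed.fsdp.wrap import (
            size_based_auto_wrap_policy,
            transformer_auto_wrap_policy,
        )

        # Mixed-precision configuration
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )

        # Auto-wrap policy. FSDP calls the policy once per module, so bind the
        # threshold with functools.partial rather than calling the policy here.
        auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=int(1e6),  # shard modules with at least 1M parameters
        )

        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            auto_wrap_policy=auto_wrap_policy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            device_id=self.local_rank,
        )

        return model

    def _setup_deepspeed(self, model, stage):
        """Use DeepSpeed (ZeRO-1/2)."""
        import deepspeed

        ds_config = {
            "train_batch_size": 32 * self.world_size,
            "gradient_accumulation_steps": 1,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16,
            },
            "zero_optimization": {
                "stage": stage,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
            },
        }

        # ZeRO shards optimizer state, so the engine needs an optimizer.
        # Optimizers outside DeepSpeed's supported list may also require
        # "zero_allow_untested_optimizer": True in the config.
        optimizer = self.optimizer_class(model.parameters(), lr=self.lr)

        model_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            config=ds_config,
        )

        return model_engine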
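To make the three ZeRO stages concrete, the per-GPU memory for model states can be estimated with simple arithmetic. The sketch below assumes mixed-precision Adam as analyzed in the ZeRO paper: 2 bytes per parameter for fp16 weights, 2 bytes for fp16 gradients, and 12 bytes for fp32 optimizer states (master weights, momentum, variance). Activations and temporary buffers are ignored, and `zero_model_state_bytes` is a hypothetical helper name.

```python
def zero_model_state_bytes(num_params, num_gpus, stage):
    """Per-GPU bytes of model states under mixed-precision Adam.

    stage 0 = plain data parallelism (everything replicated),
    stage 1/2/3 = the ZeRO stages listed in the docstring above.
    """
    params = 2 * num_params   # fp16 weights
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 master weights + Adam momentum + variance
    if stage >= 1:
        optim /= num_gpus     # shard optimizer states
    if stage >= 2:
        grads /= num_gpus     # also shard gradients
    if stage >= 3:
        params /= num_gpus    # also shard parameters
    return params + grads + optim


# A 7.5B-parameter model on 64 GPUs
for stage in range(4):
    gb = zero_model_state_bytes(7.5e9, 64, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU")
```

For this configuration the estimate is about 120 GB without ZeRO, 31.4 GB at stage 1, 16.6 GB at stage 2, and 1.9 GB at stage 3, which is why stage 3 (or FSDP `FULL_SHARD`) is the usual choice once the parameters alone no longer fit on one device.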

2. Distributed Training with PyTorch

2.1 Launching with torchrun

#!/bin/bash
# launch_distributed.sh

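# Single node, multiple GPUs
torchrun \
    --standalone \
    --nproc_per_node=8 \
    train.py --config config.yaml

# Multiple nodes, static rendezvous (run on every node with its own NODE_RANK)
torchrun \
    --nnodes=4 \
    --nproc_per_node=8 \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=29500 \
    train.py --config config.yaml

# Multiple nodes, c10d rendezvous backend (the mode used for elastic jobs).
# --node_rank / --master_addr / --master_port are not needed in this mode.
torchrun \
    --nnodes=4 \
    --nproc_per_node=8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py --config config.yaml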
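Each worker started by torchrun reads its position from environment variables, which is what the training classes in this chapter do with `WORLD_SIZE`, `RANK`, and `LOCAL_RANK`. The sketch below shows how those values relate for the 4-node, 8-GPU launch above. It assumes ranks are assigned node by node, as with a static rendezvous; with an elastic rendezvous the node order is decided at runtime. `torchrun_env` is a hypothetical helper, and real environment variables are strings rather than integers.

```python
def torchrun_env(nnodes, nproc_per_node, node_rank, local_rank):
    """Values a worker would see, assuming ranks are assigned node by node."""
    assert 0 <= node_rank < nnodes
    assert 0 <= local_rank < nproc_per_node
    return {
        "WORLD_SIZE": nnodes * nproc_per_node,            # total worker count
        "RANK": node_rank * nproc_per_node + local_rank,  # global rank
        "LOCAL_RANK": local_rank,                         # GPU index on the node
        "LOCAL_WORLD_SIZE": nproc_per_node,               # workers on this node
    }


# Fourth GPU (local_rank=3) on the third node (node_rank=2)
env = torchrun_env(nnodes=4, nproc_per_node=8, node_rank=2, local_rank=3)
print(env["RANK"], env["WORLD_SIZE"])
```

`LOCAL_RANK` is the value to pass to `torch.cuda.set_device`, while `RANK` is the one that decides which process writes checkpoints.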

2.2 Elastic Training Implementation

# elastic_training.py
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
import os

class ElasticTrainingManager:
    """Elastic training manager."""

    def __init__(self):
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))

        # Initial state
        self.initial_world_size = self.world_size
        self.checkpoint_path = os.environ.get('CHECKPOINT_PATH', '/checkpoints')

    @record
    def train(self, model, train_dataset, epochs, batch_size):
        """Main loop for elastic training."""
        # Initialize the distributed environment
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(self.local_rank)

        # Move the model to the GPU and wrap it with DDP
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank]
        )

        # Create the elastic sampler
        sampler = ElasticDistributedSampler(
            train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
        )

        dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=4,
            pin_memory=True,
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

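        # Load the latest checkpoint, if any
        start_epoch = self._load_checkpoint(model, optimizer)

        for epoch in range(start_epoch, epochs):
            # Update the sampler's epoch (so each epoch is shuffled differently)
            sampler.set_epoch(epoch)

            # Check for a world_size change. Note: torchrun restarts every
            # worker when membership changes, so within one process this only
            # differs if the process group has been re-created.
            current_world_size = dist.get_world_size()
            if current_world_size != self.world_size:
                self._handle_world_size_change(
                    model, optimizer, sampler, current_world_size
                )

            # Train for one epoch
            self._train_epoch(model, dataloader, optimizer, epoch)

            # Save a checkpoint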
            self._save_checkpoint(model, optimizer, epoch)

        dist.destroy_process_group()

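    def _handle_world_size_change(self, model, optimizer, sampler, new_world_size):
        """Handle a change in world_size."""
        print(f"World size changed: {self.world_size} -> {new_world_size}")

        old_world_size = self.world_size
        self.world_size = new_world_size
        self.rank = dist.get_rank()

        # Update the sampler. ElasticDistributedSampler exposes no setters for
        # these fields, so assign the attributes and recompute the derived sizes.
        sampler.num_replicas = new_world_size
        sampler.rank = self.rank
        remaining = len(sampler.dataset) - sampler.start_index
        sampler.num_samples = -(-remaining // new_world_size)  # ceiling division
        sampler.total_size = sampler.num_samples * new_world_size

        # Rescale the learning rate (optional, linear scaling rule)
        scale_factor = new_world_size / old_world_size
        for param_group in optimizer.param_groups:
            param_group['lr'] *= scale_factor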

    def _train_epoch(self, model, dataloader, optimizer, epoch):
        """Train for one epoch."""
        model.train()
        total_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.cuda(self.local_rank, non_blocking=True)
            targets = targets.cuda(self.local_rank, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0 and self.rank == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        return total_loss / len(dataloader)

    def _save_checkpoint(self, model, optimizer, epoch):
        """Save a checkpoint (written by rank 0, all ranks synchronize)."""
        if self.rank != 0:
            # Wait until rank 0 has finished writing
            dist.barrier()
            return

        checkpoint = {
            'epoch': epoch + 1,  # the epoch to resume from
            'model_state_dict': model.module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'world_size': self.world_size,
        }

        path = os.path.join(self.checkpoint_path, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, path)

        # Update the "latest" symlink; lexists also catches a dangling link
        latest_path = os.path.join(self.checkpoint_path, 'latest.pt')
        if os.path.lexists(latest_path):
            os.remove(latest_path)
        os.symlink(path, latest_path)

        dist.barrier()

    def _load_checkpoint(self, model, optimizer):
        """Load the latest checkpoint and return the epoch to resume from."""
        latest_path = os.path.join(self.checkpoint_path, 'latest.pt')

        if not os.path.exists(latest_path):
            return 0

        checkpoint = torch.load(
            latest_path,
            map_location=f'cuda:{self.local_rank}'
        )

        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint['epoch']
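When the number of workers changes, two things in the code above have to follow: the slice of the dataset each rank reads, and (optionally) the learning rate. The sketch below reproduces the strided partitioning that PyTorch's distributed samplers use, without shuffling, together with the linear scaling rule. Both function names are hypothetical, and the sketch assumes the dataset has at least as many samples as there are replicas.

```python
def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank reads; the tail is padded by wrapping around."""
    num_samples = -(-dataset_len // num_replicas)  # ceiling division
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]  # pad to a multiple
    return indices[rank:total_size:num_replicas]    # strided slice per rank


def rescale_lr(lr, old_world_size, new_world_size):
    """Linear scaling rule: keep lr proportional to the global batch size."""
    return lr * new_world_size / old_world_size


# 10 samples, shrinking from 4 workers to 2
before = [shard_indices(10, 4, r) for r in range(4)]
after = [shard_indices(10, 2, r) for r in range(2)]
print(before)
print(after)
print(rescale_lr(0.001, old_world_size=4, new_world_size=2))
```

With 4 workers, two samples are read twice because of the padding; with 2 workers the split is exact. Whether to rescale the learning rate at all depends on whether the per-GPU batch size is kept fixed, which is the case in `ElasticTrainingManager`.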

2.3 Pipeline Parallelism

# pipeline_parallel.py
import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe
from torch.distributed import rpc

class PipelineParallelModel(nn.Module):
    """Pipeline-parallel model."""

    def __init__(self, num_stages, hidden_size, num_layers):
        super().__init__()

        self.num_stages = num_stages
        layers_per_stage = num_layers // num_stages

        # Create the layers for each stage
        self.stages = nn.ModuleList()
        for stage_id in range(num_stages):
            layers = []
            for _ in range(layers_per_stage):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())
            self.stages.append(nn.Sequential(*layers))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


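class GPipeTrainer:
    """GPipe-style pipeline-parallel training."""

    def __init__(self, model, num_stages, num_microbatches):
        self.num_stages = num_stages
        self.num_microbatches = num_microbatches

        # Partition the model
        self.model = self._partition_model(model)

        # Wrap with Pipe. Pipe depends on the RPC framework, so rpc.init_rpc(...)
        # must already have been called, even in a single process.
        # Note: torch.distributed.pipeline.sync is deprecated in recent PyTorch
        # releases in favor of torch.distributed.pipelining.
        self.pipe_model = Pipe(
            self.model,
            chunks=num_microbatches,
            checkpoint='never',  # or 'always' to enable activation checkpointing
        )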

    def _partition_model(self, model):
        """Place the model's stages on separate devices."""
        # Assumes the model exposes a `stages` attribute
        partitions = []

        for stage_id, stage in enumerate(model.stages):
            # Put each stage on its own GPU
            device = torch.device(f'cuda:{stage_id}')
            stage = stage.to(device)
            partitions.append(stage)

        return nn.Sequential(*partitions)

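    def train_step(self, batch):
        """Run one pipelined training step."""
        inputs, targets = batch

        # Inputs go to the first device
        inputs = inputs.to(torch.device('cuda:0'))
        # Targets go to the last device
        targets = targets.to(torch.device(f'cuda:{self.num_stages-1}'))

        # Forward pass (pipelined automatically). Pipe returns an RRef,
        # so fetch the local tensor before computing the loss.
        outputs = self.pipe_model(inputs).local_value()

        # Compute the loss
        loss = nn.functional.cross_entropy(outputs, targets)

        # Backward pass
        loss.backward()

        return loss.item()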


class Interleaved1F1B:
    """1F1B pipeline schedule (simplified).

    Despite the class name, no virtual stages are modeled, so this is the
    plain (non-interleaved) 1F1B ordering, emitted as one global list.
    """

    def __init__(self, num_stages, num_microbatches):
        self.num_stages = num_stages
        self.num_microbatches = num_microbatches

    def generate_schedule(self):
        """Generate the pipeline schedule."""
        schedule = []
        # Cap warmup so that fewer micro-batches than stages still works
        num_warmup = min(self.num_stages - 1, self.num_microbatches)
        num_1f1b = self.num_microbatches - num_warmup

        # Warmup phase: forward passes only
        for mb in range(num_warmup):
            for stage in range(self.num_stages):
                schedule.append(('forward', stage, mb))

        # 1F1B steady-state phase
        for mb in range(num_1f1b):
            forward_mb = num_warmup + mb
            backward_mb = mb

            for stage in range(self.num_stages):
                schedule.append(('forward', stage, forward_mb))
                schedule.append(('backward', stage, backward_mb))

        # Cooldown phase: backward passes only
        for mb in range(num_1f1b, self.num_microbatches):
            for stage in range(self.num_stages):
                schedule.append(('backward', stage, mb))

        return schedule

    def visualize_schedule(self):
        """Print the schedule, one line per stage."""
        schedule = self.generate_schedule()

        # Build a per-stage timeline
        timeline = {}
        for stage in range(self.num_stages):
            timeline[stage] = []

        for op, stage, mb in schedule:
            symbol = f"F{mb}" if op == 'forward' else f"B{mb}"
            timeline[stage].append(symbol)

        # Print
        print("Pipeline Schedule:")
        for stage in range(self.num_stages):
            print(f"Stage {stage}: {' '.join(timeline[stage])}")
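The schedule class above gives every stage the same warmup length. In the non-interleaved 1F1B schedule as implemented in Megatron-LM, the warmup depends on the stage: stage `s` of `p` runs `p - s - 1` forward passes before it starts alternating. The sketch below generates that per-stage order and also computes the pipeline bubble fraction `(p - 1) / (m + p - 1)`, which assumes every micro-batch forward and backward pass takes the same time on every stage. Both function names are hypothetical.

```python
def one_f_one_b(stage, num_stages, num_microbatches):
    """Operation order for one stage under non-interleaved 1F1B."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    steady = num_microbatches - warmup

    ops = [f"F{mb}" for mb in range(warmup)]     # warmup: forwards only
    for i in range(steady):                      # steady state: 1 forward, 1 backward
        ops.append(f"F{warmup + i}")
        ops.append(f"B{i}")
    ops += [f"B{mb}" for mb in range(steady, num_microbatches)]  # cooldown
    return ops


def bubble_fraction(num_stages, num_microbatches):
    """Idle share of one step when all passes take equal time."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for s in range(4):
    print(f"Stage {s}: {' '.join(one_f_one_b(s, 4, 8))}")
print(f"bubble: {bubble_fraction(4, 8):.3f}")
```

The last stage alternates from the first micro-batch, so it never holds more than one set of activations at a time, while stage 0 holds at most `p` of them. Raising the number of micro-batches shrinks the bubble but does not change that activation bound, which is the main advantage of 1F1B over GPipe's all-forward-then-all-backward order.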

3. Megatron-LM Deep Dive

3.1 Tensor Parallelism Implementation

# megatron_tensor_parallel.py
import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Column-parallel linear layer.

    Splits the weight matrix by columns:
    Y = X * W, where W is split into [W1, W2, ..., Wn]
    Each GPU computes Yi = X * Wi
    """

    def __init__(self, input_size, output_size, bias=True,
                 gather_output=True, init_method=None):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output

        # Get the tensor-parallel world size and rank
        self.tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Output size held by each GPU
        self.output_size_per_partition = output_size // self.tp_world_size

        # Create the sharded weight (nn.Linear layout: [out, in])
        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partition,
            self.input_size,
            device=torch.cuda.current_device(),
            dtype=torch.float32,
        ))

        if bias:
            self.bias = nn.Parameter(torch.empty(
                self.output_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=torch.float32,
            ))
        else:
            self.register_parameter('bias', None)

        # Initialize the weight
        if init_method is not None:
            init_method(self.weight)
        else:
            nn.init.xavier_uniform_(self.weight)

        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input_):
        # Forward pass
        # input_: [batch, seq_len, hidden]
        # weight: [output_per_gpu, hidden]
        # output: [batch, seq_len, output_per_gpu]

        # Copy the input to every TP rank
        input_parallel = copy_to_tensor_model_parallel_region(input_)

        # Linear transform
        output_parallel = torch.nn.functional.linear(
            input_parallel, self.weight, self.bias
        )

        if self.gather_output:
            # AllGather the output shards
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel

        return output


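class RowParallelLinear(nn.Module):
    """Row-parallel linear layer.

    Splits the weight matrix by rows:
    Y = X * W, where W is split by rows into [W1; W2; ...; Wn]
    and X is split along its last dimension into [X1, X2, ..., Xn]
    Each GPU computes Yi = Xi * Wi
    A final AllReduce produces Y = sum(Yi)
    """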

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False, init_method=None):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel

        self.tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Input size held by each GPU
        self.input_size_per_partition = input_size // self.tp_world_size

        # Create the sharded weight (nn.Linear layout: [out, in])
        self.weight = nn.Parameter(torch.empty(
            self.output_size,
            self.input_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=torch.float32,
        ))

        if bias:
            self.bias = nn.Parameter(torch.empty(
                self.output_size,
                device=torch.cuda.current_device(),
                dtype=torch.float32,
            ))
        else:
            self.register_parameter('bias', None)

        # Initialize the weight
        if init_method is not None:
            init_method(self.weight)
        else:
            nn.init.xavier_uniform_(self.weight)

        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input_):
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Split the input along its last dimension
            input_parallel = scatter_to_tensor_model_parallel_region(input_)

        # Linear transform (bias is added once, after the reduction)
        output_parallel = torch.nn.functional.linear(input_parallel, self.weight)

        # AllReduce the partial sums
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)

        # Add the bias
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_

        return output


class ParallelTransformerLayer(nn.Module):
    """Tensor-parallel Transformer layer."""

    def __init__(self, hidden_size, num_attention_heads, ffn_hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        # Get the TP size
        tp_size = get_tensor_model_parallel_world_size()

        # The attention heads must divide evenly across TP ranks
        assert num_attention_heads % tp_size == 0

        # Self-Attention
        # QKV uses column parallelism (each GPU computes a subset of the heads)
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=True,
            gather_output=False,  # keep the output sharded
        )

        # The output projection uses row parallelism
        self.output_proj = RowParallelLinear(
            hidden_size,
            hidden_size,
            bias=True,
            input_is_parallel=True,  # the input is already sharded
        )

        # FFN
        # First linear layer: column parallel
        self.ffn_up = ColumnParallelLinear(
            hidden_size,
            ffn_hidden_size,
            bias=True,
            gather_output=False,
        )

        # Second linear layer: row parallel
        self.ffn_down = RowParallelLinear(
            ffn_hidden_size,
            hidden_size,
            bias=True,
            input_is_parallel=True,
        )

        # LayerNorm
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        # Self-Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # QKV projection
        qkv = self.qkv_proj(hidden_states)

        # Split into Q, K, V (each rank's local weight is laid out as [Q; K; V])
        # [batch, seq, 3 * hidden_per_tp] -> 3 * [batch, seq, hidden_per_tp]
        q, k, v = qkv.chunk(3, dim=-1)

        # Compute attention (each GPU handles a subset of the heads)
        attn_output = self._compute_attention(q, k, v, attention_mask)

        # Output projection (includes an AllReduce)
        hidden_states = self.output_proj(attn_output)
        hidden_states = residual + hidden_states

        # FFN
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # FFN up (column parallel)
        hidden_states = self.ffn_up(hidden_states)
        hidden_states = torch.nn.functional.gelu(hidden_states)

        # FFN down (row parallel, includes an AllReduce)
        hidden_states = self.ffn_down(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

    def _compute_attention(self, q, k, v, mask):
        """Compute attention (simplified)."""
        batch_size, seq_len, hidden_per_tp = q.shape
        tp_size = get_tensor_model_parallel_world_size()

        # Number of heads on each GPU
        num_heads_per_tp = self.num_attention_heads // tp_size
        head_dim = hidden_per_tp // num_heads_per_tp

        # Reshape: [batch, seq, hidden] -> [batch, num_heads, seq, head_dim]
        q = q.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)

        if mask is not None:
            scores = scores + mask

        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back: [batch, num_heads, seq, head_dim] -> [batch, seq, hidden]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, hidden_per_tp)

        return attn_output


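# Helper functions
def get_tensor_model_parallel_world_size():
    """Return the tensor-parallel world size."""
    return dist.get_world_size(group=get_tensor_model_parallel_group())

def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor-parallel group."""
    return dist.get_rank(group=get_tensor_model_parallel_group())

def get_tensor_model_parallel_group():
    """Return the tensor-parallel process group."""
    # A real implementation must create this group during initialization.
    # Returning None makes torch.distributed fall back to the default (world) group.
    return None  # placeholder

def copy_to_tensor_model_parallel_region(input_):
    """Copy into the tensor-parallel region (forward: no-op, backward: AllReduce)."""
    return input_  # simplified: the backward AllReduce is omitted

def scatter_to_tensor_model_parallel_region(input_):
    """Scatter into the tensor-parallel region (forward: split, backward: AllGather)."""
    return input_  # simplified: no split is performed

def gather_from_tensor_model_parallel_region(input_):
    """Gather from the tensor-parallel region (forward: AllGather, backward: split)."""
    return input_  # simplified: no gather is performed

def reduce_from_tensor_model_parallel_region(input_):
    """Reduce from the tensor-parallel region (forward: AllReduce, backward: no-op)."""
    return input_  # simplified: the forward AllReduce is omitted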
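
The stubbed helpers above hide the property that makes row-parallel layers work: each rank multiplies its slice of the input by its slice of the weight rows, and summing the partial results (the AllReduce) reproduces the full matrix product. The sketch below checks this with plain Python lists and no process group. `matmul` and `row_parallel_forward` are hypothetical names used only for illustration, and the "ranks" are simulated in a loop.

```python
def matmul(x, w):
    """Multiply x [m, k] by w [k, n]; both are lists of rows."""
    return [[sum(x[i][t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]


def row_parallel_forward(x, w, tp_size):
    """Simulate a row-parallel linear layer across tp_size ranks.

    Each simulated rank owns a contiguous block of the rows of w and the
    matching block of the columns of x. The partial products are then
    summed element-wise, which is what the AllReduce does.
    """
    k = len(w)
    assert k % tp_size == 0, "inner dimension must divide evenly across ranks"
    chunk = k // tp_size
    partials = []
    for rank in range(tp_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        x_shard = [row[lo:hi] for row in x]  # scatter along the last dimension
        w_shard = w[lo:hi]                   # this rank's rows of w
        partials.append(matmul(x_shard, w_shard))
    # AllReduce (sum): every rank would hold this same result afterwards.
    rows, cols = len(x), len(w[0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [2, 0], [0, 2]]
full = matmul(x, w)
sharded = row_parallel_forward(x, w, tp_size=2)
print(full)     # [[7, 10], [19, 22]]
print(sharded)  # [[7, 10], [19, 22]]
```

A column-parallel layer is the mirror image: the weight is split by columns, each rank produces a slice of the output, and no reduction is needed until a later row-parallel layer.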

3.2 3D Parallelism Configuration

# megatron_3d_parallel.py
import torch
import torch.distributed as dist
import torch.nn as nn  # needed for nn.Sequential in _create_model

class ParallelConfig:
    """3D parallelism configuration."""

    def __init__(self, tensor_parallel_size, pipeline_parallel_size,
                 data_parallel_size, world_size):
        """
        Args:
            tensor_parallel_size: tensor-parallel degree (TP)
            pipeline_parallel_size: pipeline-parallel degree (PP)
            data_parallel_size: data-parallel degree (DP)
            world_size: total number of processes
        """
        assert (tensor_parallel_size * pipeline_parallel_size *
                data_parallel_size == world_size), \
            "TP * PP * DP must equal world_size"

        self.tp_size = tensor_parallel_size
        self.pp_size = pipeline_parallel_size
        self.dp_size = data_parallel_size
        self.world_size = world_size

        # Process groups (filled in by initialize_parallel_groups)
        self.tp_group = None
        self.pp_group = None
        self.dp_group = None

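    def initialize_parallel_groups(self, rank):
        """Initialize the parallel process groups."""
        # Compute this rank's coordinate along each dimension.
        # Assumed GPU layout: [DP, PP, TP], with TP varying fastest,
        # so every TP_SIZE adjacent GPUs form one TP group.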

        tp_rank = rank % self.tp_size
        pp_rank = (rank // self.tp_size) % self.pp_size
        dp_rank = rank // (self.tp_size * self.pp_size)

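        # Create the tensor-parallel groups:
        # GPUs within the same PP stage and the same DP replica.
        # Every rank must call dist.new_group for every group, in the same order,
        # including groups it does not belong to.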
        for dp in range(self.dp_size):
            for pp in range(self.pp_size):
                ranks = []
                for tp in range(self.tp_size):
                    r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
                    ranks.append(r)
                group = dist.new_group(ranks)
                if dp_rank == dp and pp_rank == pp:
                    self.tp_group = group

        # Create the pipeline-parallel groups:
        # GPUs with the same TP rank within the same DP replica.
        for dp in range(self.dp_size):
            for tp in range(self.tp_size):
                ranks = []
                for pp in range(self.pp_size):
                    r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
                    ranks.append(r)
                group = dist.new_group(ranks)
                if dp_rank == dp and tp_rank == tp:
                    self.pp_group = group

        # Create the data-parallel groups:
        # GPUs with the same PP stage and TP rank across different DP replicas.
        for pp in range(self.pp_size):
            for tp in range(self.tp_size):
                ranks = []
                for dp in range(self.dp_size):
                    r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
                    ranks.append(r)
                group = dist.new_group(ranks)
                if pp_rank == pp and tp_rank == tp:
                    self.dp_group = group

        # Store the rank coordinates
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank

        return {
            'tp_rank': tp_rank,
            'pp_rank': pp_rank,
            'dp_rank': dp_rank,
            'tp_group': self.tp_group,
            'pp_group': self.pp_group,
            'dp_group': self.dp_group,
        }


class Megatron3DParallel:
    """Megatron-style 3D parallel training."""

    def __init__(self, model_config, parallel_config: ParallelConfig):
        self.model_config = model_config
        self.parallel_config = parallel_config

        # Initialize the parallel groups
        self.rank = dist.get_rank()
        self.parallel_info = parallel_config.initialize_parallel_groups(self.rank)

        # Build this rank's portion of the model
        self.model = self._create_model()

    def _create_model(self):
        """Build the distributed model."""
        # Create only the layers that belong to this PP rank
        num_layers = self.model_config['num_layers']
        layers_per_stage = num_layers // self.parallel_config.pp_size

        start_layer = self.parallel_info['pp_rank'] * layers_per_stage
        end_layer = start_layer + layers_per_stage

        layers = []
        for layer_id in range(start_layer, end_layer):
            layer = ParallelTransformerLayer(
                hidden_size=self.model_config['hidden_size'],
                num_attention_heads=self.model_config['num_attention_heads'],
                ffn_hidden_size=self.model_config['ffn_hidden_size'],
            )
            layers.append(layer)

        # The first stage also holds the embedding
        if self.parallel_info['pp_rank'] == 0:
            embedding = self._create_embedding()
            layers.insert(0, embedding)

        # The last stage also holds the output layer
        if self.parallel_info['pp_rank'] == self.parallel_config.pp_size - 1:
            output_layer = self._create_output_layer()
            layers.append(output_layer)

        return nn.Sequential(*layers)

    def train_step(self, batch):
        """Run one training iteration (single micro-batch, no interleaving)."""
        pp_rank = self.parallel_info['pp_rank']
        is_first = pp_rank == 0
        is_last = pp_rank == self.parallel_config.pp_size - 1

        # Forward input: raw batch on the first stage, received activations otherwise
        if is_first:
            hidden_states = batch
        else:
            hidden_states = self._recv_forward()
            # Received tensors are leaves; enable grad so .grad is populated in backward
            hidden_states.requires_grad_(True)

        # Run every layer owned by this stage, not just one of them
        output = self.model(hidden_states)

        loss = None
        if is_last:
            # Last stage: compute the loss and start the backward pass
            loss = self._compute_loss(output, batch)
            loss.backward()
        else:
            # Other stages: pass activations on, then wait for the gradient to come back
            self._send_forward(output)
            grad = self._recv_backward()
            output.backward(grad)

        # Every stage except the first returns the input gradient to the previous stage
        if not is_first:
            self._send_backward(hidden_states.grad)

        return loss

    def _send_forward(self, tensor):
        """Send forward activations to the next pipeline stage."""
        # In the [DP, PP, TP] layout the next stage is tp_size global ranks away
        next_rank = self.rank + self.parallel_config.tp_size
        dist.send(tensor, next_rank, group=self.parallel_config.pp_group)

    def _recv_forward(self):
        """Receive forward activations from the previous pipeline stage."""
        prev_rank = self.rank - self.parallel_config.tp_size
        tensor = torch.empty(...)  # placeholder: the shape must be known in advance
        dist.recv(tensor, prev_rank, group=self.parallel_config.pp_group)
        return tensor

    def _send_backward(self, grad):
        """Send the backward gradient to the previous pipeline stage."""
        prev_rank = self.rank - self.parallel_config.tp_size
        dist.send(grad, prev_rank, group=self.parallel_config.pp_group)

    def _recv_backward(self):
        """Receive the backward gradient from the next pipeline stage."""
        next_rank = self.rank + self.parallel_config.tp_size
        grad = torch.empty(...)  # placeholder: the shape must be known in advance
        dist.recv(grad, next_rank, group=self.parallel_config.pp_group)
        return grad
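
The rank arithmetic in `initialize_parallel_groups` is easier to check without a cluster. The sketch below reproduces the same [DP, PP, TP] index formula in plain Python and lists the members of each group for one rank. `parallel_groups` is a hypothetical helper written for this illustration; it does not create any process groups.

```python
def parallel_groups(rank, tp_size, pp_size, dp_size):
    """Return a rank's coordinates and group members for a [DP, PP, TP] layout."""
    world_size = tp_size * pp_size * dp_size
    assert 0 <= rank < world_size, "rank out of range"

    # Same decomposition as initialize_parallel_groups: TP varies fastest.
    tp_rank = rank % tp_size
    pp_rank = (rank // tp_size) % pp_size
    dp_rank = rank // (tp_size * pp_size)

    def global_rank(dp, pp, tp):
        return dp * pp_size * tp_size + pp * tp_size + tp

    return {
        "coords": (dp_rank, pp_rank, tp_rank),
        # Vary one coordinate at a time while holding the other two fixed.
        "tp_group": [global_rank(dp_rank, pp_rank, tp) for tp in range(tp_size)],
        "pp_group": [global_rank(dp_rank, pp, tp_rank) for pp in range(pp_size)],
        "dp_group": [global_rank(dp, pp_rank, tp_rank) for dp in range(dp_size)],
    }


info = parallel_groups(rank=5, tp_size=2, pp_size=2, dp_size=2)
print(info["coords"])    # (1, 0, 1)
print(info["tp_group"])  # [4, 5]
print(info["pp_group"])  # [5, 7]
print(info["dp_group"])  # [1, 5]
```

Two things follow from the layout. TP groups are made of adjacent ranks, which usually places them on the same node where the fastest interconnect is available. Pipeline neighbours are exactly `tp_size` ranks apart, which is why `_send_forward` and `_recv_forward` add or subtract `tp_size`.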

4. Distributed Training on Kubernetes

4.1 PyTorchJob CRD

# pytorchjob.yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: llm-training
  namespace: ml-workloads
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        metadata:
          annotations:
            gpu.topology/policy: nvlink
            gpu.topology/min-bandwidth: "200"
        spec:
          schedulerName: gpu-scheduler
          containers:
            - name: pytorch
              image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
              imagePullPolicy: Always
              command:
                - python
                - -m
                - torch.distributed.run
                - --nproc_per_node=8
                - --nnodes=$(WORLD_SIZE)
                - --node_rank=$(RANK)
                - --master_addr=$(MASTER_ADDR)
                - --master_port=$(MASTER_PORT)
                - train.py
                - --config=config/llm.yaml
              resources:
                limits:
                  nvidia.com/gpu: 8
                  memory: 512Gi
                requests:
                  nvidia.com/gpu: 8
                  cpu: "64"
                  memory: 512Gi
              env:
                - name: NCCL_DEBUG
                  value: INFO
                - name: NCCL_IB_DISABLE
                  value: "0"
                - name: NCCL_NET_GDR_LEVEL
                  value: "5"
              volumeMounts:
                - name: dataset
                  mountPath: /data
                - name: checkpoints
                  mountPath: /checkpoints
                - name: shm
                  mountPath: /dev/shm
          volumes:
            - name: dataset
              persistentVolumeClaim:
                claimName: training-dataset
            - name: checkpoints
              persistentVolumeClaim:
                claimName: checkpoints
            - name: shm
              emptyDir:
                medium: Memory
                sizeLimit: 64Gi
          affinity:
            podAntiAffinity:
              requiredDuringSchedulingIgnoredDuringExecution:
                - labelSelector:
                    matchLabels:
                      training.kubeflow.org/job-name: llm-training
                  topologyKey: kubernetes.io/hostname

    Worker:
      replicas: 3
      restartPolicy: OnFailure
      template:
        metadata:
          annotations:
            gpu.topology/policy: nvlink
        spec:
          schedulerName: gpu-scheduler
          containers:
            - name: pytorch
              image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
              command:
                - python
                - -m
                - torch.distributed.run
                - --nproc_per_node=8
                - --nnodes=$(WORLD_SIZE)
                - --node_rank=$(RANK)
                - --master_addr=$(MASTER_ADDR)
                - --master_port=$(MASTER_PORT)
                - train.py
                - --config=config/llm.yaml
              resources:
                limits:
                  nvidia.com/gpu: 8
                  memory: 512Gi
                requests:
                  nvidia.com/gpu: 8
                  cpu: "64"
                  memory: 512Gi
              env:
                - name: NCCL_DEBUG
                  value: INFO
                - name: NCCL_IB_DISABLE
                  value: "0"
                - name: NCCL_NET_GDR_LEVEL
                  value: "5"
              volumeMounts:
                - name: dataset
                  mountPath: /data
                - name: checkpoints
                  mountPath: /checkpoints
                - name: shm
                  mountPath: /dev/shm
          volumes:
            - name: dataset
              persistentVolumeClaim:
                claimName: training-dataset
                readOnly: true
            - name: checkpoints
              persistentVolumeClaim:
                claimName: checkpoints
            - name: shm
              emptyDir:
                medium: Memory
                sizeLimit: 64Gi
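
In the manifest above, `--nnodes` and `--node_rank` describe Pods, while the training script sees per-process ranks. A minimal sketch of that mapping follows, assuming the conventional static-launch rule `global_rank = node_rank * nproc_per_node + local_rank`. `worker_env` is a hypothetical helper, not part of the launcher.

```python
def worker_env(node_rank, nnodes, nproc_per_node):
    """Rank variables for each process on one node.

    Assumes global_rank = node_rank * nproc_per_node + local_rank.
    """
    world_size = nnodes * nproc_per_node
    return [
        {
            "RANK": node_rank * nproc_per_node + local_rank,
            "LOCAL_RANK": local_rank,
            "WORLD_SIZE": world_size,
        }
        for local_rank in range(nproc_per_node)
    ]


# 1 Master + 3 Workers with 8 GPUs each, as in the manifest above
envs = [worker_env(node, nnodes=4, nproc_per_node=8) for node in range(4)]
print(envs[0][0])  # {'RANK': 0, 'LOCAL_RANK': 0, 'WORLD_SIZE': 32}
print(envs[3][7])  # {'RANK': 31, 'LOCAL_RANK': 7, 'WORLD_SIZE': 32}
```

Note the two meanings of `WORLD_SIZE`: in the manifest it is the Pod count passed to `--nnodes` (4 here), while inside the training processes it is the total process count (32 here).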

4.2 Training Operator Implementation

// pkg/controller/pytorchjob_controller.go
package controller

import (
    "context"
    "fmt"
    "strings"

    kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    corev1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
    "sigs.k8s.io/controller-runtime/pkg/client"
    "sigs.k8s.io/controller-runtime/pkg/reconcile"
)

// PyTorchJobReconciler reconciles a PyTorchJob object
type PyTorchJobReconciler struct {
    client.Client
    Scheme *runtime.Scheme
}

// Reconcile drives a PyTorchJob toward its desired state
func (r *PyTorchJobReconciler) Reconcile(ctx context.Context,
    req reconcile.Request) (reconcile.Result, error) {

    // Fetch the PyTorchJob; a missing object means it was deleted, which is not an error
    job := &kubeflowv1.PyTorchJob{}
    if err := r.Get(ctx, req.NamespacedName, job); err != nil {
        return reconcile.Result{}, client.IgnoreNotFound(err)
    }

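    // Total replica count across all replica types
    replicas := r.calculateReplicas(job)

    // Create any Pods that do not exist yet
    for replicaType, spec := range job.Spec.PyTorchReplicaSpecs {
        // Guard against a missing spec or an unset Replicas pointer
        if spec == nil || spec.Replicas == nil {
            continue
        }
        for i := int32(0); i < *spec.Replicas; i++ {
            podName := fmt.Sprintf("%s-%s-%d", job.Name, strings.ToLower(string(replicaType)), i)

            // Check whether the Pod already exists
            existingPod := &corev1.Pod{}
            err := r.Get(ctx, client.ObjectKey{
                Namespace: job.Namespace,
                Name:      podName,
            }, existingPod)
            if err == nil {
                continue // Pod already exists
            }
            if client.IgnoreNotFound(err) != nil {
                // A real API error rather than "not found": return it so the request is retried
                return reconcile.Result{}, err
            }

            // Not found: create the Pod
            pod := r.createPod(job, replicaType, i, replicas)
            if err := r.Create(ctx, pod); err != nil {
                return reconcile.Result{}, err
            }
        }
    }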

    // Update the job status
    if err := r.updateStatus(ctx, job); err != nil {
        return reconcile.Result{}, err
    }

    return reconcile.Result{}, nil
}

// calculateReplicas returns the total number of replicas across all replica types
func (r *PyTorchJobReconciler) calculateReplicas(job *kubeflowv1.PyTorchJob) int32 {
    total := int32(0)
    for _, spec := range job.Spec.PyTorchReplicaSpecs {
        if spec.Replicas != nil {
            total += *spec.Replicas
        }
    }
    return total
}

// createPod builds the Pod object for one replica
func (r *PyTorchJobReconciler) createPod(job *kubeflowv1.PyTorchJob,
    replicaType kubeflowv1.ReplicaType, index int32, totalReplicas int32) *corev1.Pod {

    spec := job.Spec.PyTorchReplicaSpecs[replicaType]
    podName := fmt.Sprintf("%s-%s-%d", job.Name, strings.ToLower(string(replicaType)), index)

    // Compute this replica's rank
    rank := r.calculateRank(job, replicaType, index)

    // Master address; this name only resolves if a Service called <job>-master-0 exists
    masterAddr := fmt.Sprintf("%s-master-0", job.Name)

    pod := &corev1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Name:      podName,
            Namespace: job.Namespace,
            Labels: map[string]string{
                "training.kubeflow.org/job-name":      job.Name,
                "training.kubeflow.org/replica-type":  string(replicaType),
                "training.kubeflow.org/replica-index": fmt.Sprintf("%d", index),
            },
            OwnerReferences: []metav1.OwnerReference{
                *metav1.NewControllerRef(job, kubeflowv1.SchemeGroupVersion.WithKind("PyTorchJob")),
            },
        },
        Spec: *spec.Template.Spec.DeepCopy(),
    }

    // Inject the rendezvous environment variables into every container
    for i := range pod.Spec.Containers {
        pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env,
            corev1.EnvVar{Name: "WORLD_SIZE", Value: fmt.Sprintf("%d", totalReplicas)},
            corev1.EnvVar{Name: "RANK", Value: fmt.Sprintf("%d", rank)},
            corev1.EnvVar{Name: "MASTER_ADDR", Value: masterAddr},
            corev1.EnvVar{Name: "MASTER_PORT", Value: "29500"},
        )
    }

    return pod
}

// calculateRank returns the rank assigned to a replica
func (r *PyTorchJobReconciler) calculateRank(job *kubeflowv1.PyTorchJob,
    replicaType kubeflowv1.ReplicaType, index int32) int32 {

    // The Master always has rank 0
    if replicaType == kubeflowv1.PyTorchJobReplicaTypeMaster {
        return 0
    }

    // Worker ranks follow the Master replicas (they start at 1 with a single Master)
    masterReplicas := int32(0)
    if spec, ok := job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
        if spec.Replicas != nil {
            masterReplicas = *spec.Replicas
        }
    }

    return masterReplicas + index
}

// updateStatus refreshes the job status from the state of its Pods
func (r *PyTorchJobReconciler) updateStatus(ctx context.Context,
    job *kubeflowv1.PyTorchJob) error {

    // Collect every Pod that belongs to this job
    podList := &corev1.PodList{}
    if err := r.List(ctx, podList, client.InNamespace(job.Namespace),
        client.MatchingLabels{"training.kubeflow.org/job-name": job.Name}); err != nil {
        return err
    }

    // Count Pods by phase
    running := 0
    succeeded := 0
    failed := 0

    for _, pod := range podList.Items {
        switch pod.Status.Phase {
        case corev1.PodRunning:
            running++
        case corev1.PodSucceeded:
            succeeded++
        case corev1.PodFailed:
            failed++
        }
    }

    // Derive the job condition: any failure wins, then all-succeeded, then all-running
    totalReplicas := r.calculateReplicas(job)

    var cond *kubeflowv1.JobCondition
    switch {
    case failed > 0:
        cond = &kubeflowv1.JobCondition{
            Type:    kubeflowv1.JobFailed,
            Status:  corev1.ConditionTrue,
            Reason:  "PodFailed",
            Message: fmt.Sprintf("%d pods failed", failed),
        }
    case int32(succeeded) == totalReplicas:
        cond = &kubeflowv1.JobCondition{
            Type:    kubeflowv1.JobSucceeded,
            Status:  corev1.ConditionTrue,
            Reason:  "AllPodsSucceeded",
            Message: "All pods succeeded",
        }
    case int32(running) == totalReplicas:
        cond = &kubeflowv1.JobCondition{
            Type:    kubeflowv1.JobRunning,
            Status:  corev1.ConditionTrue,
            Reason:  "AllPodsRunning",
            Message: "All pods are running",
        }
    }
    if cond == nil {
        return nil
    }

    // Reconcile runs repeatedly: do not append the same condition on every pass
    if n := len(job.Status.Conditions); n > 0 && job.Status.Conditions[n-1].Type == cond.Type {
        return nil
    }
    job.Status.Conditions = append(job.Status.Conditions, *cond)

    return r.Status().Update(ctx, job)
}
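
The precedence used by `updateStatus` is worth spelling out, because mixed states fall through without producing any condition. The Python sketch below mirrors only that decision logic. `job_phase` is a hypothetical helper for illustration, and the phase strings follow the Kubernetes Pod phase names.

```python
from collections import Counter


def job_phase(pod_phases, total_replicas):
    """Derive a job-level phase from a list of Pod phases.

    Same precedence as the Go code above: any failure wins, then
    all-succeeded, then all-running. Anything else yields None.
    """
    counts = Counter(pod_phases)
    if counts["Failed"] > 0:
        return "Failed"
    if counts["Succeeded"] == total_replicas:
        return "Succeeded"
    if counts["Running"] == total_replicas:
        return "Running"
    return None


print(job_phase(["Running"] * 4, 4))                                    # Running
print(job_phase(["Running", "Running", "Running", "Pending"], 4))       # None
print(job_phase(["Succeeded", "Running", "Running", "Running"], 4))     # None
print(job_phase(["Running", "Failed", "Running", "Running"], 4))        # Failed
```

The third case shows a gap in this simplified logic: once one Pod has finished while others are still running, the job matches neither "all running" nor "all succeeded", so no new condition is recorded until the remaining Pods complete.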

5. Performance Optimization

5.1 Communication Optimization

# communication_optimization.py
import torch
import torch.distributed as dist

class CommunicationOptimizer:
    """Communication optimizer."""

    def __init__(self, world_size, bucket_size_mb=25):
        self.world_size = world_size
        self.bucket_size = bucket_size_mb * 1024 * 1024  # convert MB to bytes

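    def optimize_allreduce(self, gradients):
        """Bucketed AllReduce that averages the gradients in place."""
        # 1. Group the gradients into buckets and flatten each bucket into one buffer
        buckets = self._bucket_gradients(gradients)
        flat_buffers = [torch.cat([g.flatten() for g in bucket]) for bucket in buckets]

        # 2. Launch one asynchronous AllReduce per bucket
        handles = [dist.all_reduce(buf, async_op=True) for buf in flat_buffers]

        # 3. Wait, average, and copy the reduced values back into the original gradients.
        #    torch.cat creates a copy, so without this step the result would be discarded.
        for handle, buf, bucket in zip(handles, flat_buffers, buckets):
            handle.wait()
            buf.div_(self.world_size)
            offset = 0
            for grad in bucket:
                n = grad.numel()
                grad.copy_(buf[offset:offset + n].view_as(grad))
                offset += n

    def _bucket_gradients(self, gradients):
        """Group gradients into lists whose total size stays within bucket_size."""
        buckets = []
        current_bucket = []
        current_size = 0

        for grad in gradients:
            grad_size = grad.numel() * grad.element_size()

            # The current bucket is full: close it and start a new one
            if current_bucket and current_size + grad_size > self.bucket_size:
                buckets.append(current_bucket)
                current_bucket = []
                current_size = 0

            current_bucket.append(grad)
            current_size += grad_size

        # Close the last bucket
        if current_bucket:
            buckets.append(current_bucket)

        return buckets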

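    def gradient_compression(self, gradient, ratio=0.01):
        """Top-K gradient compression: keep only the largest-magnitude entries."""
        flat_grad = gradient.flatten()
        k = max(1, int(flat_grad.numel() * ratio))

        # Select the k entries with the largest absolute value
        _, indices = torch.topk(flat_grad.abs(), k)

        # Return the sparse form: signed values, their positions, and the original shape
        return {
            'values': flat_grad[indices],
            'indices': indices,
            'shape': gradient.shape,
        }

    def decompress_gradient(self, compressed, device):
        """Rebuild a dense gradient; entries that were dropped become zero."""
        values = compressed['values'].to(device)
        indices = compressed['indices'].to(device)
        # Match the dtype of the compressed values (for example fp16)
        gradient = torch.zeros(compressed['shape'], dtype=values.dtype, device=device)
        # view(-1) aliases the same storage, so the write lands in `gradient`
        gradient.view(-1)[indices] = values
        return gradient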


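class OverlapCommunicationComputation:
    """Overlap gradient communication with backward computation."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.handles = []

    def backward_with_overlap(self, loss):
        """Backward pass that launches an AllReduce as soon as each gradient is ready."""
        # Register hooks that fire after a parameter's .grad has been accumulated.
        # register_post_accumulate_grad_hook requires PyTorch >= 2.1. A plain
        # register_hook sees the gradient before accumulation, so an in-place
        # asynchronous AllReduce there can race with the write to .grad.
        hooks = []
        for param in self.model.parameters():
            if param.requires_grad:
                hook = param.register_post_accumulate_grad_hook(self._allreduce_hook)
                hooks.append(hook)

        # Run the backward pass; communication starts while it is still running
        loss.backward()

        # Wait for all outstanding communication
        for handle in self.handles:
            handle.wait()
        self.handles.clear()

        # Remove the hooks
        for hook in hooks:
            hook.remove()

    def _allreduce_hook(self, param):
        # Pre-divide so that the summed result is the average across ranks
        param.grad.div_(dist.get_world_size())
        # Launch the AllReduce asynchronously on the final .grad tensor
        handle = dist.all_reduce(param.grad, async_op=True)
        self.handles.append(handle)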
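
To see what Top-K compression keeps and what it throws away, here is a round trip on a plain Python list. It follows the same idea as `gradient_compression` and `decompress_gradient` but without torch. `topk_compress` and `topk_decompress` are hypothetical names used only in this sketch.

```python
def topk_compress(grad, ratio=0.01):
    """Keep the k largest-magnitude entries of a flat gradient."""
    k = max(1, int(len(grad) * ratio))
    # Indices sorted by absolute value, largest first; keep the first k
    order = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)
    indices = order[:k]
    return {
        "values": [grad[i] for i in indices],  # signed values
        "indices": indices,
        "size": len(grad),
    }


def topk_decompress(compressed):
    """Rebuild the dense gradient; dropped entries become 0.0."""
    dense = [0.0] * compressed["size"]
    for i, v in zip(compressed["indices"], compressed["values"]):
        dense[i] = v
    return dense


grad = [0.1, -2.0, 0.05, 1.5, -0.3, 0.0, 0.7, -0.01]
packed = topk_compress(grad, ratio=0.25)
restored = topk_decompress(packed)
print(packed["indices"])  # [1, 3]
print(restored)           # [0.0, -2.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0]
```

Two caveats apply to real use. The dropped entries are lost unless each rank keeps the residual locally and adds it to the next step's gradient, which is a common companion technique and is not shown in the code above. Also, each rank generally selects different indices, so the sparse results cannot be summed with a plain AllReduce; they are usually exchanged with a gather-style collective instead.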

5.2 Memory Optimization

# memory_optimization.py
import torch
from torch.utils.checkpoint import checkpoint

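class MemoryOptimizer:
    """Memory optimizer."""

    @staticmethod
    def activation_checkpointing(model, segments=4):
        """Activation checkpointing: recompute activations in backward instead of storing them."""
        class CheckpointedModule(torch.nn.Module):
            def __init__(self, module, num_segments):
                super().__init__()
                self.module = module
                self.num_segments = num_segments

            def forward(self, x):
                # Split the child layers into at most num_segments segments.
                # Ceiling division with a floor of 1 avoids a zero step when
                # there are fewer layers than segments.
                layers = list(self.module.children())
                segment_size = max(1, -(-len(layers) // self.num_segments))

                for i in range(0, len(layers), segment_size):
                    segment = torch.nn.Sequential(*layers[i:i+segment_size])
                    # Only the segment input is kept; the rest is recomputed in backward
                    x = checkpoint(segment, x, use_reentrant=False)

                return x

        return CheckpointedModule(model, segments)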

    @staticmethod
    def gradient_accumulation(model, dataloader, accumulation_steps,
                            optimizer, loss_fn):
        """Gradient accumulation."""
        model.train()
        optimizer.zero_grad()

        for i, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.cuda()
            targets = targets.cuda()

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Scale the loss so the accumulated gradient matches one larger batch
            loss = loss / accumulation_steps
            loss.backward()

            # Update the weights once every accumulation_steps micro-batches
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

    @staticmethod
    def offload_optimizer_states(optimizer):
        """Offload the optimizer state tensors to the CPU."""
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cpu()

    @staticmethod
    def load_optimizer_states_to_gpu(optimizer, device):
        """Move the optimizer state tensors back to the GPU."""
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)


class MixedPrecisionTraining:
    """Mixed-precision training."""

    def __init__(self, model, optimizer, loss_scale='dynamic'):
        self.model = model
        self.optimizer = optimizer

        # Create the GradScaler (dynamic loss scaling)
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=65536.0,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=2000,
            enabled=True,
        )

    def train_step(self, inputs, targets, loss_fn):
        """One mixed-precision training step."""
        self.optimizer.zero_grad()

        # Run the forward pass under autocast
        with torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = self.model(inputs)
            loss = loss_fn(outputs, targets)

        # Scale the loss and run the backward pass
        self.scaler.scale(loss).backward()

        # Unscale the gradients and step; the step is skipped if inf/NaN gradients are found
        self.scaler.step(self.optimizer)

        # Update the scale factor
        self.scaler.update()

        return loss.item()
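
The four `GradScaler` arguments above define a small state machine: back off when an overflow is seen, grow after a run of clean steps. The following pure-Python sketch simulates that rule so the behaviour can be traced by hand. `LossScaleSim` is a hypothetical class written for illustration and is not the PyTorch implementation.

```python
class LossScaleSim:
    """Simulates dynamic loss scaling with the same four knobs as above."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf):
        """Advance one step; return True if the optimizer step is applied."""
        if found_inf:
            # Overflow: shrink the scale, reset the streak, skip the step
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A full streak of clean steps: try a larger scale
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


# A short growth_interval makes the behaviour visible in a few steps
sim = LossScaleSim(growth_interval=3)
history = []
for found_inf in [False, False, False, True, False]:
    applied = sim.update(found_inf)
    history.append((applied, sim.scale))
print(history)
# [(True, 65536.0), (True, 65536.0), (True, 131072.0), (False, 65536.0), (True, 65536.0)]
```

With the default `growth_interval=2000`, the scale doubles at most once every 2000 clean steps, while a single overflow halves it immediately, so the scale settles just below the largest value that does not overflow.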

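Summary

This chapter covered the core techniques behind distributed training frameworks:

  1. Parallelism strategies: data parallelism, tensor parallelism, pipeline parallelism, and their combinations
  2. Communication primitives: collective operations such as AllReduce and AllGather, and how to choose among algorithms
  3. PyTorch distributed: how DDP, FSDP, and elastic training are implemented
  4. Megatron-LM: an in-depth implementation of tensor parallelism and 3D parallelism
  5. Kubernetes integration: the PyTorchJob CRD and the training Operator
  6. Performance optimization: communication optimization, memory optimization, and mixed-precision training

Distributed training is core infrastructure in the era of large models, and a solid grasp of its principles is essential for improving training efficiency.

The next chapter covers training job scheduling: how to manage large-scale training jobs efficiently on Kubernetes.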
