分布式训练框架
概述
随着模型规模的爆发式增长(GPT-4 据估计超过 1.7 万亿参数),单机训练已无法满足需求。分布式训练将模型和数据切分到多个 GPU/节点上并行处理,是训练大规模 AI 模型的核心技术。本章深入讲解分布式训练的原理、主流框架的实现机制,以及在 Kubernetes 上的部署实践。
1. 分布式训练基础
1.1 并行策略
┌──────────────────────────────────────────────────────────────────────────┐
│ Distributed Training Strategies │
├──────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Data Parallelism │ │
│ │ │ │
│ │ Model Copy Model Copy Model Copy Model Copy │ │
│ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │
│ │ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │ GPU 3 │ │ │
│ │ └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘ │ │
│ │ │ │ │ │ │ │
│ │ Data Batch 0 Data Batch 1 Data Batch 2 Data Batch 3 │ │
│ │ │ │ │
│ │ AllReduce Gradients │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Model Parallelism │ │
│ │ │ │
│ │ ┌──────────────────────────────────────────────────────────────┐ │ │
│ │ │ Pipeline Parallelism │ │ │
│ │ │ │ │ │
│ │ │ GPU 0 GPU 1 GPU 2 GPU 3 │ │ │
│ │ │ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │ │ │
│ │ │ │Layer│ ───▶ │Layer│ ───▶ │Layer│ ───▶ │Layer│ │ │ │
│ │ │ │ 1-6 │ │7-12 │ │13-18│ │19-24│ │ │ │
│ │ │ └─────┘ └─────┘ └─────┘ └─────┘ │ │ │
│ │ │ │ │ │
│ │ │ Stage 0 Stage 1 Stage 2 Stage 3 │ │ │
│ │ └──────────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ ┌──────────────────────────────────────────────────────────────┐ │ │
│ │ │ Tensor Parallelism │ │ │
│ │ │ │ │ │
│ │ │ Matrix A (4096 x 4096) │ │ │
│ │ │ ┌──────────┬──────────┬──────────┬──────────┐ │ │ │
│ │ │ │ GPU 0 │ GPU 1 │ GPU 2 │ GPU 3 │ │ │ │
│ │ │ │ Col 0-1K │ Col 1-2K │ Col 2-3K │ Col 3-4K │ │ │ │
│ │ │ └──────────┴──────────┴──────────┴──────────┘ │ │ │
│ │ │ │ │ │
│ │ │ AllGather / ReduceScatter │ │ │
│ │ └──────────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ 3D/4D Hybrid Parallelism │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────────────────────┐ │ │
│ │ │ Data Parallel Group 0 Data Parallel Group 1 │ │ │
│ │ │ ┌───────────────────┐ ┌───────────────────┐ │ │ │
│ │ │ │ Pipeline Stage 0 │ │ Pipeline Stage 0 │ │ │ │
│ │ │ │ TP: GPU 0,1 │ │ TP: GPU 4,5 │ │ │ │
│ │ │ ├───────────────────┤ ├───────────────────┤ │ │ │
│ │ │ │ Pipeline Stage 1 │ │ Pipeline Stage 1 │ │ │ │
│ │ │ │ TP: GPU 2,3 │ │ TP: GPU 6,7 │ │ │ │
│ │ │ └───────────────────┘ └───────────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────┘
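在选择并行策略之前,通常会先粗略估算单卡需要常驻的"模型状态"显存(参数、梯度、优化器状态),再决定 TP/PP/ZeRO 的组合。下面给出一个极简的估算脚本作为示意:假设使用 fp16 参数与梯度、fp32 的 Adam 动量与方差,模型规模和并行度均为示例数字,并忽略激活值、fp32 主权重与通信缓冲区。
# parallel_memory_estimate.py —— 不同并行组合下单卡模型状态显存的粗略估算(示意)
def per_gpu_state_gb(num_params: float, tp: int, pp: int, zero_stage: int, dp: int) -> float:
    params_per_gpu = num_params / (tp * pp)      # TP/PP 先切分参数
    bytes_per_param = 2 + 2 + 8                  # fp16 参数 + fp16 梯度 + fp32 Adam 状态
    if zero_stage >= 1:                          # ZeRO-1:优化器状态按 DP 切分
        bytes_per_param = 2 + 2 + 8 / dp
    if zero_stage >= 2:                          # ZeRO-2:梯度也按 DP 切分
        bytes_per_param = 2 + 2 / dp + 8 / dp
    if zero_stage >= 3:                          # ZeRO-3:参数也按 DP 切分
        bytes_per_param = (2 + 2 + 8) / dp
    return params_per_gpu * bytes_per_param / 1024 ** 3
if __name__ == "__main__":
    n = 13e9  # 示例:13B 参数模型
    for tp, pp, zero in [(1, 1, 0), (8, 1, 0), (2, 4, 0), (1, 1, 3)]:
        gb = per_gpu_state_gb(n, tp, pp, zero, dp=8)
        print(f"TP={tp} PP={pp} ZeRO={zero}: ~{gb:.1f} GB / GPU")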
1.2 通信原语
// pkg/distributed/collective.go
package distributed
import (
	"math"
)
// CollectiveOp 集合通信操作
type CollectiveOp string
const (
// OpAllReduce 全规约:所有进程得到相同的规约结果
OpAllReduce CollectiveOp = "AllReduce"
// OpAllGather 全收集:所有进程得到所有数据的拼接
OpAllGather CollectiveOp = "AllGather"
// OpReduceScatter 规约分散:规约后分散到各进程
OpReduceScatter CollectiveOp = "ReduceScatter"
// OpBroadcast 广播:一个进程发送,所有进程接收
OpBroadcast CollectiveOp = "Broadcast"
// OpReduce 规约:所有进程发送,一个进程接收规约结果
OpReduce CollectiveOp = "Reduce"
// OpScatter 分散:一个进程发送不同数据到各进程
OpScatter CollectiveOp = "Scatter"
// OpGather 收集:所有进程发送,一个进程接收拼接
OpGather CollectiveOp = "Gather"
// OpAllToAll 全交换:所有进程交换数据
OpAllToAll CollectiveOp = "AllToAll"
)
// CollectiveAlgorithm 集合通信算法
type CollectiveAlgorithm string
const (
// AlgoRing 环形算法
AlgoRing CollectiveAlgorithm = "Ring"
// AlgoTree 树形算法
AlgoTree CollectiveAlgorithm = "Tree"
// AlgoRecursiveHalving 递归减半
AlgoRecursiveHalving CollectiveAlgorithm = "RecursiveHalving"
// AlgoRecursiveDoubling 递归加倍
AlgoRecursiveDoubling CollectiveAlgorithm = "RecursiveDoubling"
// AlgoBinaryBlock 二进制分块
AlgoBinaryBlock CollectiveAlgorithm = "BinaryBlock"
)
// CommunicationCost 通信开销模型
type CommunicationCost struct {
// 启动延迟 (秒)
Alpha float64
// 传输时间每字节 (秒/字节)
Beta float64
// 计算时间每字节 (秒/字节)
Gamma float64
}
// CalculateRingAllReduceTime 计算 Ring AllReduce 时间
// 参数:n=进程数, m=数据大小(字节)
func (c *CommunicationCost) CalculateRingAllReduceTime(n int, m int64) float64 {
// Ring AllReduce 分为两个阶段:
// 1. Reduce-Scatter: (n-1) 步,每步传输 m/n 数据
// 2. All-Gather: (n-1) 步,每步传输 m/n 数据
// 总时间 = 2(n-1) * (α + m/n * β) + (n-1) * m/n * γ
steps := float64(n - 1)
chunkSize := float64(m) / float64(n)
// 通信时间
commTime := 2 * steps * (c.Alpha + chunkSize*c.Beta)
// 计算时间(规约操作)
computeTime := steps * chunkSize * c.Gamma
return commTime + computeTime
}
// CalculateTreeAllReduceTime 计算 Tree AllReduce 时间
func (c *CommunicationCost) CalculateTreeAllReduceTime(n int, m int64) float64 {
// Tree AllReduce:
// 1. Reduce: log(n) 步,每步传输 m 数据
// 2. Broadcast: log(n) 步,每步传输 m 数据
// 总时间 = 2 * log(n) * (α + m * β) + log(n) * m * γ
steps := math.Log2(float64(n))
dataSize := float64(m)
commTime := 2 * steps * (c.Alpha + dataSize*c.Beta)
computeTime := steps * dataSize * c.Gamma
return commTime + computeTime
}
// SelectBestAlgorithm 选择最优算法
func (c *CommunicationCost) SelectBestAlgorithm(op CollectiveOp, n int, m int64) CollectiveAlgorithm {
switch op {
case OpAllReduce:
ringTime := c.CalculateRingAllReduceTime(n, m)
treeTime := c.CalculateTreeAllReduceTime(n, m)
if ringTime < treeTime {
return AlgoRing
}
return AlgoTree
case OpBroadcast, OpReduce:
// 树形算法对广播/规约更高效
return AlgoTree
case OpAllGather, OpReduceScatter:
// Ring 算法带宽利用率高
return AlgoRing
default:
return AlgoRing
}
}
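借助上面的 α-β 成本模型,可以直观比较 Ring 与 Tree AllReduce 在不同消息大小下的表现:Ring 的带宽项基本与进程数无关,但有 2(n-1) 次启动延迟;Tree 只有 2·log₂(n) 次延迟,却要在每步传输完整数据。下面用一个小的 Python 数值例子演示这一取舍(公式与上面的 Go 实现一致,α、β 取值仅为假设,忽略规约计算项 γ):
# allreduce_cost_demo.py —— Ring vs Tree AllReduce 的 α-β 成本对比(数值仅作示意)
import math
ALPHA = 5e-6          # 假设:每次消息 5 微秒启动延迟
BETA = 1 / 100e9      # 假设:100 GB/s 链路带宽,即每字节 1e-11 秒
def ring_allreduce(n: int, m: int) -> float:
    chunk = m / n
    return 2 * (n - 1) * (ALPHA + chunk * BETA)
def tree_allreduce(n: int, m: int) -> float:
    return 2 * math.log2(n) * (ALPHA + m * BETA)
if __name__ == "__main__":
    n = 64
    for m in (1 << 10, 1 << 20, 1 << 30):   # 1 KB / 1 MB / 1 GB
        print(f"{m / 2**20:10.3f} MB: ring={ring_allreduce(n, m) * 1e3:8.3f} ms, "
              f"tree={tree_allreduce(n, m) * 1e3:8.3f} ms")
可以看到小消息时 Tree 更优(延迟项少),大消息时 Ring 更优(带宽利用率高),这正是上面 SelectBestAlgorithm 按消息大小和进程数做选择的依据。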
1.3 数据并行实现
# distributed_data_parallel.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
class DataParallelTrainer:
"""数据并行训练器"""
def __init__(self, model, optimizer_class, lr, backend='nccl'):
self.backend = backend
self.lr = lr
self.optimizer_class = optimizer_class
# 初始化分布式环境
self._init_distributed()
# 包装模型
self.model = self._setup_model(model)
# 创建优化器
self.optimizer = optimizer_class(self.model.parameters(), lr=lr)
def _init_distributed(self):
"""初始化分布式环境"""
# 从环境变量获取配置
self.world_size = int(os.environ.get('WORLD_SIZE', 1))
self.rank = int(os.environ.get('RANK', 0))
self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
# 初始化进程组
if not dist.is_initialized():
dist.init_process_group(
backend=self.backend,
init_method='env://'
)
# 设置当前设备
torch.cuda.set_device(self.local_rank)
print(f"Initialized rank {self.rank}/{self.world_size}, "
f"local_rank {self.local_rank}")
def _setup_model(self, model):
"""设置分布式模型"""
# 移动模型到 GPU
model = model.cuda(self.local_rank)
# 使用 DDP 包装
model = DDP(
model,
device_ids=[self.local_rank],
output_device=self.local_rank,
# 梯度桶大小(MB)
bucket_cap_mb=25,
# 是否查找未使用的参数
find_unused_parameters=False,
# 梯度作为桶视图(内存优化)
gradient_as_bucket_view=True,
)
return model
def train_step(self, batch):
"""单步训练"""
self.model.train()
inputs, targets = batch
inputs = inputs.cuda(self.local_rank, non_blocking=True)
targets = targets.cuda(self.local_rank, non_blocking=True)
# 前向传播
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = nn.functional.cross_entropy(outputs, targets)
# 反向传播(DDP 自动同步梯度)
loss.backward()
# 更新参数
self.optimizer.step()
return loss.item()
def save_checkpoint(self, path, epoch):
"""保存检查点(仅 rank 0)"""
if self.rank != 0:
return
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.module.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}
torch.save(checkpoint, path)
def load_checkpoint(self, path):
"""加载检查点"""
# 使用 map_location 加载到正确的设备
checkpoint = torch.load(
path,
map_location=f'cuda:{self.local_rank}'
)
self.model.module.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
class ZeRODataParallel:
"""ZeRO 优化的数据并行"""
def __init__(self, model, optimizer_class, lr, zero_stage=2):
"""
ZeRO 阶段:
- Stage 1: 优化器状态分片
- Stage 2: 优化器状态 + 梯度分片
- Stage 3: 优化器状态 + 梯度 + 参数分片
"""
        self.zero_stage = zero_stage
        self.lr = lr
        # 初始化分布式环境(与 DataParallelTrainer 相同的逻辑)
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', init_method='env://')
        torch.cuda.set_device(self.local_rank)
# 使用 DeepSpeed 或 FSDP 实现 ZeRO
if zero_stage == 3:
self.model = self._setup_fsdp(model)
else:
self.model = self._setup_deepspeed(model, zero_stage)
def _setup_fsdp(self, model):
"""使用 PyTorch FSDP (ZeRO-3)"""
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
)
        import functools
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
# 混合精度配置
mixed_precision = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
        # 自动包装策略:参数量超过 100 万的子模块自动成为独立的 FSDP 分片单元
        # size_based_auto_wrap_policy 需要用 functools.partial 绑定阈值后再传给 FSDP
        auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=int(1e6),
        )
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
device_id=self.local_rank,
)
return model
def _setup_deepspeed(self, model, stage):
"""使用 DeepSpeed (ZeRO-1/2)"""
import deepspeed
ds_config = {
"train_batch_size": 32 * self.world_size,
"gradient_accumulation_steps": 1,
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16,
},
"zero_optimization": {
"stage": stage,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8,
            },
            "optimizer": {
                "type": "AdamW",
                "params": {"lr": self.lr},
            },
        }
        # 传入 model_parameters 后,DeepSpeed 会按配置构建优化器并应用 ZeRO 分片
        model_engine, optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config,
        )
return model_engine
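下面给出一个使用上文 DataParallelTrainer 的最小训练脚本示意,假设通过 torchrun 启动(每个进程自动获得 RANK/LOCAL_RANK/WORLD_SIZE);其中的随机数据、模型结构、模块名 distributed_data_parallel 和检查点路径都只是示例:
# train_ddp_example.py —— DataParallelTrainer 的最小使用示意(假设由 torchrun 启动)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from distributed_data_parallel import DataParallelTrainer  # 假设模块名与上文文件名一致
def main():
    # 随机数据仅作演示
    dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
    trainer = DataParallelTrainer(model, torch.optim.AdamW, lr=1e-3)
    # DistributedSampler 保证各 rank 读取互不重叠的数据分片
    sampler = DistributedSampler(dataset, num_replicas=trainer.world_size, rank=trainer.rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=2)
    for epoch in range(3):
        sampler.set_epoch(epoch)
        for batch in loader:
            loss = trainer.train_step(batch)
        if trainer.rank == 0:
            print(f"epoch {epoch} loss {loss:.4f}")
        trainer.save_checkpoint(f"/tmp/ckpt_{epoch}.pt", epoch)
if __name__ == "__main__":
    main()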
2. PyTorch 分布式训练
2.1 torchrun 启动方式
#!/bin/bash
# launch_distributed.sh
# 单机多卡
torchrun \
--standalone \
--nproc_per_node=8 \
train.py --config config.yaml
# 多机多卡
torchrun \
--nnodes=4 \
--nproc_per_node=8 \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=29500 \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:29400 \
train.py --config config.yaml
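torchrun 会为每个进程注入 RANK、LOCAL_RANK、WORLD_SIZE、MASTER_ADDR 等环境变量,训练脚本只需读取它们并初始化进程组。下面是一个与上述启动命令配套的 train.py 最小骨架示意(配置解析与训练循环留空,均为假设性示例):
# train.py —— 与 torchrun 配套的最小骨架(示意)
import argparse
import os
import torch
import torch.distributed as dist
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")
    args = parser.parse_args()
    # torchrun 注入的环境变量
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    rank, world_size = dist.get_rank(), dist.get_world_size()
    if rank == 0:
        print(f"config={args.config}, world_size={world_size}")
    # ...... 在这里构建模型、DDP 包装、数据加载并进入训练循环 ......
    dist.destroy_process_group()
if __name__ == "__main__":
    main()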
2.2 弹性训练实现
# elastic_training.py
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
import os
class ElasticTrainingManager:
"""弹性训练管理器"""
def __init__(self):
self.world_size = int(os.environ.get('WORLD_SIZE', 1))
self.rank = int(os.environ.get('RANK', 0))
self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
# 初始状态
self.initial_world_size = self.world_size
self.checkpoint_path = os.environ.get('CHECKPOINT_PATH', '/checkpoints')
@record
def train(self, model, train_dataset, epochs, batch_size):
"""弹性训练主循环"""
# 初始化分布式
dist.init_process_group(backend='nccl')
torch.cuda.set_device(self.local_rank)
# 创建模型
model = model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[self.local_rank]
)
# 创建弹性采样器
sampler = ElasticDistributedSampler(
train_dataset,
num_replicas=self.world_size,
rank=self.rank,
)
dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=4,
pin_memory=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 加载检查点
start_epoch = self._load_checkpoint(model, optimizer)
for epoch in range(start_epoch, epochs):
# 更新采样器 epoch(保证数据随机性)
sampler.set_epoch(epoch)
# 检查 world_size 变化
current_world_size = dist.get_world_size()
if current_world_size != self.world_size:
self._handle_world_size_change(
model, optimizer, sampler, current_world_size
)
# 训练一个 epoch
self._train_epoch(model, dataloader, optimizer, epoch)
# 保存检查点
self._save_checkpoint(model, optimizer, epoch)
dist.destroy_process_group()
def _handle_world_size_change(self, model, optimizer, sampler, new_world_size):
"""处理 world_size 变化"""
print(f"World size changed: {self.world_size} -> {new_world_size}")
old_world_size = self.world_size
self.world_size = new_world_size
self.rank = dist.get_rank()
        # 更新采样器(DistributedSampler 没有对应的 setter,直接更新属性)
        sampler.num_replicas = new_world_size
        sampler.rank = self.rank
# 调整学习率(可选)
scale_factor = new_world_size / old_world_size
for param_group in optimizer.param_groups:
param_group['lr'] *= scale_factor
def _train_epoch(self, model, dataloader, optimizer, epoch):
"""训练一个 epoch"""
model.train()
total_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(dataloader):
inputs = inputs.cuda(self.local_rank, non_blocking=True)
targets = targets.cuda(self.local_rank, non_blocking=True)
optimizer.zero_grad()
outputs = model(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch_idx % 100 == 0 and self.rank == 0:
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
return total_loss / len(dataloader)
def _save_checkpoint(self, model, optimizer, epoch):
"""保存检查点"""
if self.rank != 0:
dist.barrier()
return
checkpoint = {
'epoch': epoch + 1,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'world_size': self.world_size,
}
path = os.path.join(self.checkpoint_path, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)
# 更新 latest 链接
latest_path = os.path.join(self.checkpoint_path, 'latest.pt')
if os.path.exists(latest_path):
os.remove(latest_path)
os.symlink(path, latest_path)
dist.barrier()
def _load_checkpoint(self, model, optimizer):
"""加载检查点"""
latest_path = os.path.join(self.checkpoint_path, 'latest.pt')
if not os.path.exists(latest_path):
return 0
checkpoint = torch.load(
latest_path,
map_location=f'cuda:{self.local_rank}'
)
model.module.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
return checkpoint['epoch']
2.3 流水线并行
# pipeline_parallel.py
import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe
from torch.distributed import rpc
class PipelineParallelModel(nn.Module):
"""流水线并行模型"""
def __init__(self, num_stages, hidden_size, num_layers):
super().__init__()
self.num_stages = num_stages
layers_per_stage = num_layers // num_stages
# 创建每个阶段的层
self.stages = nn.ModuleList()
for stage_id in range(num_stages):
layers = []
for _ in range(layers_per_stage):
layers.append(nn.Linear(hidden_size, hidden_size))
layers.append(nn.ReLU())
self.stages.append(nn.Sequential(*layers))
def forward(self, x):
for stage in self.stages:
x = stage(x)
return x
class GPipeTrainer:
"""GPipe 风格的流水线并行训练"""
def __init__(self, model, num_stages, num_microbatches):
self.num_stages = num_stages
self.num_microbatches = num_microbatches
# 分割模型
self.model = self._partition_model(model)
        # 使用 Pipe 包装
        # 注意:torch 的 Pipe 依赖 RPC 框架,构造前需要先初始化,
        # 例如单进程场景下调用 rpc.init_rpc("worker", rank=0, world_size=1)
        self.pipe_model = Pipe(
            self.model,
            chunks=num_microbatches,
            checkpoint='never',  # 或 'always' 启用激活检查点以节省显存
        )
def _partition_model(self, model):
"""将模型分割到多个设备"""
# 假设模型有 stages 属性
partitions = []
for stage_id, stage in enumerate(model.stages):
# 将每个阶段放到对应的 GPU
device = torch.device(f'cuda:{stage_id}')
stage = stage.to(device)
partitions.append(stage)
return nn.Sequential(*partitions)
def train_step(self, batch):
"""流水线训练步骤"""
inputs, targets = batch
# 输入放到第一个设备
inputs = inputs.to(torch.device('cuda:0'))
# 目标放到最后一个设备
targets = targets.to(torch.device(f'cuda:{self.num_stages-1}'))
# 前向传播(自动流水线)
outputs = self.pipe_model(inputs)
# 计算损失
loss = nn.functional.cross_entropy(outputs, targets)
# 反向传播
loss.backward()
return loss.item()
class Interleaved1F1B:
    """1F1B 流水线调度(简化示意:未实现交错式虚拟流水线,也未区分各 stage 的 warmup 深度)"""
def __init__(self, num_stages, num_microbatches):
self.num_stages = num_stages
self.num_microbatches = num_microbatches
def generate_schedule(self):
"""生成流水线调度"""
schedule = []
num_warmup = self.num_stages - 1
num_1f1b = self.num_microbatches - num_warmup
# Warmup 阶段:只有前向传播
for mb in range(num_warmup):
for stage in range(self.num_stages):
schedule.append(('forward', stage, mb))
# 1F1B 稳态阶段
for mb in range(num_1f1b):
forward_mb = num_warmup + mb
backward_mb = mb
for stage in range(self.num_stages):
schedule.append(('forward', stage, forward_mb))
schedule.append(('backward', stage, backward_mb))
# Cooldown 阶段:只有反向传播
for mb in range(num_1f1b, self.num_microbatches):
for stage in range(self.num_stages):
schedule.append(('backward', stage, mb))
return schedule
def visualize_schedule(self):
"""可视化调度"""
schedule = self.generate_schedule()
# 创建时间线
timeline = {}
for stage in range(self.num_stages):
timeline[stage] = []
for op, stage, mb in schedule:
symbol = f"F{mb}" if op == 'forward' else f"B{mb}"
timeline[stage].append(symbol)
# 打印
print("Pipeline Schedule:")
for stage in range(self.num_stages):
print(f"Stage {stage}: {' '.join(timeline[stage])}")
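以 4 个 stage、8 个 microbatch 为例调用上面的调度器,可以直观看到 warmup、1F1B 稳态与 cooldown 三个阶段(以下调用方式仅作演示):
if __name__ == "__main__":
    scheduler = Interleaved1F1B(num_stages=4, num_microbatches=8)
    scheduler.visualize_schedule()
    # 每个 stage 先执行若干次前向(warmup),随后进入"一前向一反向"的稳态,
    # 最后只剩反向(cooldown);气泡开销约为 (num_stages - 1) 个 microbatch 的时间。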
3. Megatron-LM 深度解析
3.1 张量并行实现
# megatron_tensor_parallel.py
import torch
import torch.nn as nn
import torch.distributed as dist
class ColumnParallelLinear(nn.Module):
"""列并行线性层
将权重矩阵按列分割:
Y = X * W,其中 W 被分割为 [W1, W2, ..., Wn]
每个 GPU 计算 Yi = X * Wi
"""
def __init__(self, input_size, output_size, bias=True,
gather_output=True, init_method=None):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# 获取张量并行大小
self.tp_world_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# 每个 GPU 的输出大小
self.output_size_per_partition = output_size // self.tp_world_size
# 创建分片权重
self.weight = nn.Parameter(torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=torch.float32,
))
if bias:
self.bias = nn.Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=torch.float32,
))
else:
self.register_parameter('bias', None)
# 初始化权重
if init_method is not None:
init_method(self.weight)
else:
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, input_):
# 前向传播
# input_: [batch, seq_len, hidden]
# weight: [output_per_gpu, hidden]
# output: [batch, seq_len, output_per_gpu]
# 复制输入到所有 TP rank
input_parallel = copy_to_tensor_model_parallel_region(input_)
# 线性变换
output_parallel = torch.nn.functional.linear(
input_parallel, self.weight, self.bias
)
if self.gather_output:
# AllGather 输出
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
return output
class RowParallelLinear(nn.Module):
"""行并行线性层
将权重矩阵按行分割:
Y = X * W,其中 X 被分割为 [X1, X2, ..., Xn]
每个 GPU 计算 Yi = Xi * Wi
最后 AllReduce 得到 Y = sum(Yi)
"""
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False, init_method=None):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.tp_world_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# 每个 GPU 的输入大小
self.input_size_per_partition = input_size // self.tp_world_size
# 创建分片权重
self.weight = nn.Parameter(torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=torch.float32,
))
if bias:
self.bias = nn.Parameter(torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=torch.float32,
))
else:
self.register_parameter('bias', None)
# 初始化
if init_method is not None:
init_method(self.weight)
else:
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, input_):
if self.input_is_parallel:
input_parallel = input_
else:
# 分割输入
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# 线性变换
output_parallel = torch.nn.functional.linear(input_parallel, self.weight)
# AllReduce
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
# 添加 bias
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
class ParallelTransformerLayer(nn.Module):
"""并行 Transformer 层"""
def __init__(self, hidden_size, num_attention_heads, ffn_hidden_size):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
# 获取 TP 大小
tp_size = get_tensor_model_parallel_world_size()
# 确保 attention heads 可以被 TP 整除
assert num_attention_heads % tp_size == 0
# Self-Attention
# QKV 使用列并行(每个 GPU 计算部分 heads)
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=True,
gather_output=False, # 保持并行
)
# Output projection 使用行并行
self.output_proj = RowParallelLinear(
hidden_size,
hidden_size,
bias=True,
input_is_parallel=True, # 输入已经是并行的
)
# FFN
# 第一个线性层:列并行
self.ffn_up = ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
bias=True,
gather_output=False,
)
# 第二个线性层:行并行
self.ffn_down = RowParallelLinear(
ffn_hidden_size,
hidden_size,
bias=True,
input_is_parallel=True,
)
# LayerNorm
self.input_layernorm = nn.LayerNorm(hidden_size)
self.post_attention_layernorm = nn.LayerNorm(hidden_size)
def forward(self, hidden_states, attention_mask=None):
# Self-Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# QKV projection
qkv = self.qkv_proj(hidden_states)
# 分割 Q, K, V
# [batch, seq, 3 * hidden_per_tp] -> 3 * [batch, seq, hidden_per_tp]
q, k, v = qkv.chunk(3, dim=-1)
# 计算注意力(每个 GPU 计算部分 heads)
attn_output = self._compute_attention(q, k, v, attention_mask)
# Output projection(包含 AllReduce)
hidden_states = self.output_proj(attn_output)
hidden_states = residual + hidden_states
# FFN
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# FFN up(列并行)
hidden_states = self.ffn_up(hidden_states)
hidden_states = torch.nn.functional.gelu(hidden_states)
# FFN down(行并行,包含 AllReduce)
hidden_states = self.ffn_down(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def _compute_attention(self, q, k, v, mask):
"""计算注意力(简化版)"""
batch_size, seq_len, hidden_per_tp = q.shape
tp_size = get_tensor_model_parallel_world_size()
# 每个 GPU 的 heads 数
num_heads_per_tp = self.num_attention_heads // tp_size
head_dim = hidden_per_tp // num_heads_per_tp
# Reshape: [batch, seq, hidden] -> [batch, num_heads, seq, head_dim]
q = q.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, num_heads_per_tp, head_dim).transpose(1, 2)
# 缩放点积注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
if mask is not None:
scores = scores + mask
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# Reshape back: [batch, num_heads, seq, head_dim] -> [batch, seq, hidden]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, hidden_per_tp)
return attn_output
# 辅助函数
def get_tensor_model_parallel_world_size():
"""获取张量并行大小"""
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""获取张量并行 rank"""
return dist.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_group():
"""获取张量并行进程组"""
# 实际实现中需要初始化进程组
return None # placeholder
def copy_to_tensor_model_parallel_region(input_):
"""复制到张量并行区域(前向无操作,反向 AllReduce)"""
return input_ # 简化
def scatter_to_tensor_model_parallel_region(input_):
"""分散到张量并行区域"""
return input_ # 简化
def gather_from_tensor_model_parallel_region(input_):
"""从张量并行区域收集"""
return input_ # 简化
def reduce_from_tensor_model_parallel_region(input_):
"""从张量并行区域规约"""
return input_ # 简化
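列并行与行并行组合(如上面的 FFN:up 用列并行、down 用行并行)之所以整层只需要一次 AllReduce,可以用一个单进程的小脚本从数学上验证:把 W1 按列切、W2 按行切,各分片独立计算后对部分结果求和,与不切分的结果完全一致。下面的脚本纯属示意,不涉及真实通信;激活函数用 ReLU 代替 GeLU,结论相同,因为二者都是逐元素运算:
# tp_math_check.py —— 单进程验证列并行 + 行并行的数学等价性(示意)
import torch
torch.manual_seed(0)
tp = 4
x = torch.randn(2, 8, 1024, dtype=torch.float64)    # [batch, seq, hidden]
w1 = torch.randn(4096, 1024, dtype=torch.float64)   # FFN up:   y1 = x @ w1.T
w2 = torch.randn(1024, 4096, dtype=torch.float64)   # FFN down: y  = y1 @ w2.T
# 基准:不切分
ref = torch.relu(x @ w1.T) @ w2.T
# 模拟张量并行:w1 按输出维切分(列并行),w2 按输入维切分(行并行)
w1_shards = w1.chunk(tp, dim=0)
w2_shards = w2.chunk(tp, dim=1)
partial_sum = torch.zeros_like(ref)
for w1_i, w2_i in zip(w1_shards, w2_shards):
    y1_i = torch.relu(x @ w1_i.T)     # 每个"GPU"只持有部分中间激活
    partial_sum += y1_i @ w2_i.T      # 行并行输出求和,对应真实场景中的一次 AllReduce
print("max abs diff:", (ref - partial_sum).abs().max().item())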
3.2 3D 并行配置
# megatron_3d_parallel.py
import torch
import torch.nn as nn
import torch.distributed as dist
# ParallelTransformerLayer 定义见上文 megatron_tensor_parallel.py
from megatron_tensor_parallel import ParallelTransformerLayer
class ParallelConfig:
"""3D 并行配置"""
def __init__(self, tensor_parallel_size, pipeline_parallel_size,
data_parallel_size, world_size):
"""
Args:
tensor_parallel_size: 张量并行大小
pipeline_parallel_size: 流水线并行大小
data_parallel_size: 数据并行大小
world_size: 总进程数
"""
assert (tensor_parallel_size * pipeline_parallel_size *
data_parallel_size == world_size), \
"TP * PP * DP must equal world_size"
self.tp_size = tensor_parallel_size
self.pp_size = pipeline_parallel_size
self.dp_size = data_parallel_size
self.world_size = world_size
# 进程组
self.tp_group = None
self.pp_group = None
self.dp_group = None
def initialize_parallel_groups(self, rank):
"""初始化并行进程组"""
# 计算各维度的 rank
# 假设 GPU 排列方式:[DP, PP, TP]
# 即相邻 TP_SIZE 个 GPU 组成一个 TP 组
tp_rank = rank % self.tp_size
pp_rank = (rank // self.tp_size) % self.pp_size
dp_rank = rank // (self.tp_size * self.pp_size)
# 创建张量并行组
# 同一个 PP stage 和 DP replica 内的 GPU
for dp in range(self.dp_size):
for pp in range(self.pp_size):
ranks = []
for tp in range(self.tp_size):
r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
ranks.append(r)
group = dist.new_group(ranks)
if dp_rank == dp and pp_rank == pp:
self.tp_group = group
# 创建流水线并行组
# 同一个 DP replica 内,相同 TP rank 的 GPU
for dp in range(self.dp_size):
for tp in range(self.tp_size):
ranks = []
for pp in range(self.pp_size):
r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
ranks.append(r)
group = dist.new_group(ranks)
if dp_rank == dp and tp_rank == tp:
self.pp_group = group
# 创建数据并行组
# 不同 DP replica 内,相同 PP stage 和 TP rank 的 GPU
for pp in range(self.pp_size):
for tp in range(self.tp_size):
ranks = []
for dp in range(self.dp_size):
r = dp * self.pp_size * self.tp_size + pp * self.tp_size + tp
ranks.append(r)
group = dist.new_group(ranks)
if pp_rank == pp and tp_rank == tp:
self.dp_group = group
# 存储 rank 信息
self.tp_rank = tp_rank
self.pp_rank = pp_rank
self.dp_rank = dp_rank
return {
'tp_rank': tp_rank,
'pp_rank': pp_rank,
'dp_rank': dp_rank,
'tp_group': self.tp_group,
'pp_group': self.pp_group,
'dp_group': self.dp_group,
}
class Megatron3DParallel:
"""Megatron 风格的 3D 并行训练"""
def __init__(self, model_config, parallel_config: ParallelConfig):
self.model_config = model_config
self.parallel_config = parallel_config
# 初始化并行组
self.rank = dist.get_rank()
self.parallel_info = parallel_config.initialize_parallel_groups(self.rank)
# 创建模型
self.model = self._create_model()
def _create_model(self):
"""创建分布式模型"""
# 根据 PP rank 创建对应的层
num_layers = self.model_config['num_layers']
layers_per_stage = num_layers // self.parallel_config.pp_size
start_layer = self.parallel_info['pp_rank'] * layers_per_stage
end_layer = start_layer + layers_per_stage
layers = []
for layer_id in range(start_layer, end_layer):
layer = ParallelTransformerLayer(
hidden_size=self.model_config['hidden_size'],
num_attention_heads=self.model_config['num_attention_heads'],
ffn_hidden_size=self.model_config['ffn_hidden_size'],
)
layers.append(layer)
# 如果是第一个 stage,添加 embedding
if self.parallel_info['pp_rank'] == 0:
embedding = self._create_embedding()
layers.insert(0, embedding)
# 如果是最后一个 stage,添加 output layer
if self.parallel_info['pp_rank'] == self.parallel_config.pp_size - 1:
output_layer = self._create_output_layer()
layers.append(output_layer)
return nn.Sequential(*layers)
def train_step(self, batch):
"""一次训练迭代"""
# 根据 PP rank 执行对应的流水线阶段
        if self.parallel_info['pp_rank'] == 0:
            # 第一个 stage:embedding + 本 stage 的 transformer 层
            hidden_states = self.model(batch)
            # 发送到下一个 stage
            self._send_forward(hidden_states)
            # 接收来自下一个 stage 的梯度并完成本地反向传播
            grad = self._recv_backward()
            hidden_states.backward(grad)
        elif self.parallel_info['pp_rank'] == self.parallel_config.pp_size - 1:
            # 最后一个 stage:本 stage 的层(含输出层),然后计算 loss
            hidden_states = self._recv_forward()
            hidden_states.requires_grad_()
            output = self.model(hidden_states)
            loss = self._compute_loss(output, batch)
            # 先反向传播,再把输入激活的梯度发送给上一个 stage
            loss.backward()
            self._send_backward(hidden_states.grad)
            return loss
        else:
            # 中间 stage
            hidden_states = self._recv_forward()
            hidden_states.requires_grad_()
            output = self.model(hidden_states)
            self._send_forward(output)
            grad = self._recv_backward()
            output.backward(grad)
            self._send_backward(hidden_states.grad)
def _send_forward(self, tensor):
"""发送前向激活"""
next_rank = self.rank + self.parallel_config.tp_size
dist.send(tensor, next_rank, group=self.parallel_config.pp_group)
def _recv_forward(self):
"""接收前向激活"""
prev_rank = self.rank - self.parallel_config.tp_size
tensor = torch.empty(...) # 需要知道形状
dist.recv(tensor, prev_rank, group=self.parallel_config.pp_group)
return tensor
def _send_backward(self, grad):
"""发送反向梯度"""
prev_rank = self.rank - self.parallel_config.tp_size
dist.send(grad, prev_rank, group=self.parallel_config.pp_group)
def _recv_backward(self):
"""接收反向梯度"""
next_rank = self.rank + self.parallel_config.tp_size
grad = torch.empty(...) # 需要知道形状
dist.recv(grad, next_rank, group=self.parallel_config.pp_group)
return grad
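为了直观理解上面"[DP, PP, TP] 排列"的 rank 划分方式,可以用一个不创建任何进程组的纯 Python 小脚本,打印 world_size=8、TP=2、PP=2、DP=2 时每个 rank 的坐标及其所属的三个组(公式与 initialize_parallel_groups 中完全一致):
# rank_layout_demo.py —— 打印 3D 并行下每个 rank 的坐标与分组(纯演示)
TP, PP, DP = 2, 2, 2
WORLD = TP * PP * DP
for rank in range(WORLD):
    tp_rank = rank % TP
    pp_rank = (rank // TP) % PP
    dp_rank = rank // (TP * PP)
    tp_group = [dp_rank * PP * TP + pp_rank * TP + t for t in range(TP)]
    pp_group = [dp_rank * PP * TP + p * TP + tp_rank for p in range(PP)]
    dp_group = [d * PP * TP + pp_rank * TP + tp_rank for d in range(DP)]
    print(f"rank {rank}: dp={dp_rank} pp={pp_rank} tp={tp_rank} "
          f"tp_group={tp_group} pp_group={pp_group} dp_group={dp_group}")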
4. Kubernetes 上的分布式训练
4.1 PyTorchJob CRD
# pytorchjob.yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
name: llm-training
namespace: ml-workloads
spec:
pytorchReplicaSpecs:
Master:
replicas: 1
restartPolicy: OnFailure
template:
metadata:
annotations:
gpu.topology/policy: nvlink
gpu.topology/min-bandwidth: "200"
spec:
schedulerName: gpu-scheduler
containers:
- name: pytorch
image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
imagePullPolicy: Always
command:
- python
- -m
- torch.distributed.run
- --nproc_per_node=8
- --nnodes=$(WORLD_SIZE)
- --node_rank=$(RANK)
- --master_addr=$(MASTER_ADDR)
- --master_port=$(MASTER_PORT)
- train.py
- --config=config/llm.yaml
resources:
limits:
nvidia.com/gpu: 8
memory: 512Gi
requests:
nvidia.com/gpu: 8
cpu: "64"
memory: 512Gi
env:
- name: NCCL_DEBUG
value: INFO
- name: NCCL_IB_DISABLE
value: "0"
- name: NCCL_NET_GDR_LEVEL
value: "5"
volumeMounts:
- name: dataset
mountPath: /data
- name: checkpoints
mountPath: /checkpoints
- name: shm
mountPath: /dev/shm
volumes:
- name: dataset
persistentVolumeClaim:
claimName: training-dataset
- name: checkpoints
persistentVolumeClaim:
claimName: checkpoints
- name: shm
emptyDir:
medium: Memory
sizeLimit: 64Gi
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchLabels:
training.kubeflow.org/job-name: llm-training
topologyKey: kubernetes.io/hostname
Worker:
replicas: 3
restartPolicy: OnFailure
template:
metadata:
annotations:
gpu.topology/policy: nvlink
spec:
schedulerName: gpu-scheduler
containers:
- name: pytorch
image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
command:
- python
- -m
- torch.distributed.run
- --nproc_per_node=8
- --nnodes=$(WORLD_SIZE)
- --node_rank=$(RANK)
- --master_addr=$(MASTER_ADDR)
- --master_port=$(MASTER_PORT)
- train.py
- --config=config/llm.yaml
resources:
limits:
nvidia.com/gpu: 8
memory: 512Gi
requests:
nvidia.com/gpu: 8
cpu: "64"
memory: 512Gi
env:
- name: NCCL_DEBUG
value: INFO
volumeMounts:
- name: dataset
mountPath: /data
- name: checkpoints
mountPath: /checkpoints
- name: shm
mountPath: /dev/shm
volumes:
- name: dataset
persistentVolumeClaim:
claimName: training-dataset
readOnly: true
- name: checkpoints
persistentVolumeClaim:
claimName: checkpoints
- name: shm
emptyDir:
medium: Memory
sizeLimit: 64Gi
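除了 kubectl apply,也可以用 Kubernetes Python 客户端把上面的 PyTorchJob 作为自定义资源提交。下面是一个示意脚本,假设本地已安装 kubernetes 与 pyyaml 包、集群中已部署 training-operator,且 pytorchjob.yaml 即上文清单:
# submit_pytorchjob.py —— 通过 CustomObjectsApi 提交 PyTorchJob(示意)
import yaml
from kubernetes import client, config
def submit(manifest_path: str = "pytorchjob.yaml"):
    config.load_kube_config()   # 集群内运行可改用 config.load_incluster_config()
    with open(manifest_path) as f:
        job = yaml.safe_load(f)
    api = client.CustomObjectsApi()
    api.create_namespaced_custom_object(
        group="kubeflow.org",
        version="v1",
        namespace=job["metadata"].get("namespace", "default"),
        plural="pytorchjobs",
        body=job,
    )
    print(f"submitted PyTorchJob {job['metadata']['name']}")
if __name__ == "__main__":
    submit()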
4.2 训练 Operator 实现
// pkg/controller/pytorchjob_controller.go
package controller
import (
	"context"
	"fmt"
	"strings"
	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	corev1 "k8s.io/api/core/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/reconcile"
)
// PyTorchJobReconciler reconciles a PyTorchJob object
type PyTorchJobReconciler struct {
client.Client
Scheme *runtime.Scheme
}
// Reconcile 协调 PyTorchJob
func (r *PyTorchJobReconciler) Reconcile(ctx context.Context,
req reconcile.Request) (reconcile.Result, error) {
// 获取 PyTorchJob
job := &kubeflowv1.PyTorchJob{}
if err := r.Get(ctx, req.NamespacedName, job); err != nil {
return reconcile.Result{}, client.IgnoreNotFound(err)
}
// 计算副本数
replicas := r.calculateReplicas(job)
// 创建或更新 Pods
for replicaType, spec := range job.Spec.PyTorchReplicaSpecs {
for i := int32(0); i < *spec.Replicas; i++ {
podName := fmt.Sprintf("%s-%s-%d", job.Name, strings.ToLower(string(replicaType)), i)
// 检查 Pod 是否存在
existingPod := &corev1.Pod{}
err := r.Get(ctx, client.ObjectKey{
Namespace: job.Namespace,
Name: podName,
}, existingPod)
			if err != nil {
				if !apierrors.IsNotFound(err) {
					return reconcile.Result{}, err
				}
				// Pod 不存在,创建
				pod := r.createPod(job, replicaType, i, replicas)
				if err := r.Create(ctx, pod); err != nil {
					return reconcile.Result{}, err
				}
			}
}
}
// 更新状态
if err := r.updateStatus(ctx, job); err != nil {
return reconcile.Result{}, err
}
return reconcile.Result{}, nil
}
// calculateReplicas 计算总副本数
func (r *PyTorchJobReconciler) calculateReplicas(job *kubeflowv1.PyTorchJob) int32 {
total := int32(0)
for _, spec := range job.Spec.PyTorchReplicaSpecs {
if spec.Replicas != nil {
total += *spec.Replicas
}
}
return total
}
// createPod 创建 Pod
func (r *PyTorchJobReconciler) createPod(job *kubeflowv1.PyTorchJob,
replicaType kubeflowv1.ReplicaType, index int32, totalReplicas int32) *corev1.Pod {
spec := job.Spec.PyTorchReplicaSpecs[replicaType]
podName := fmt.Sprintf("%s-%s-%d", job.Name, strings.ToLower(string(replicaType)), index)
// 计算 rank
rank := r.calculateRank(job, replicaType, index)
// 获取 master 地址
masterAddr := fmt.Sprintf("%s-master-0", job.Name)
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
Namespace: job.Namespace,
Labels: map[string]string{
"training.kubeflow.org/job-name": job.Name,
"training.kubeflow.org/replica-type": string(replicaType),
"training.kubeflow.org/replica-index": fmt.Sprintf("%d", index),
},
OwnerReferences: []metav1.OwnerReference{
*metav1.NewControllerRef(job, kubeflowv1.SchemeGroupVersion.WithKind("PyTorchJob")),
},
},
Spec: *spec.Template.Spec.DeepCopy(),
}
// 注入环境变量
for i := range pod.Spec.Containers {
pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env,
corev1.EnvVar{Name: "WORLD_SIZE", Value: fmt.Sprintf("%d", totalReplicas)},
corev1.EnvVar{Name: "RANK", Value: fmt.Sprintf("%d", rank)},
corev1.EnvVar{Name: "MASTER_ADDR", Value: masterAddr},
corev1.EnvVar{Name: "MASTER_PORT", Value: "29500"},
)
}
return pod
}
// calculateRank 计算 rank
func (r *PyTorchJobReconciler) calculateRank(job *kubeflowv1.PyTorchJob,
replicaType kubeflowv1.ReplicaType, index int32) int32 {
// Master rank 为 0
if replicaType == kubeflowv1.PyTorchJobReplicaTypeMaster {
return 0
}
// Worker rank 从 1 开始
masterReplicas := int32(0)
if spec, ok := job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
if spec.Replicas != nil {
masterReplicas = *spec.Replicas
}
}
return masterReplicas + index
}
// updateStatus 更新状态
func (r *PyTorchJobReconciler) updateStatus(ctx context.Context,
job *kubeflowv1.PyTorchJob) error {
// 收集所有 Pod 状态
podList := &corev1.PodList{}
if err := r.List(ctx, podList, client.InNamespace(job.Namespace),
client.MatchingLabels{"training.kubeflow.org/job-name": job.Name}); err != nil {
return err
}
// 统计状态
running := 0
succeeded := 0
failed := 0
for _, pod := range podList.Items {
switch pod.Status.Phase {
case corev1.PodRunning:
running++
case corev1.PodSucceeded:
succeeded++
case corev1.PodFailed:
failed++
}
}
// 更新 job 状态
totalReplicas := r.calculateReplicas(job)
if failed > 0 {
job.Status.Conditions = append(job.Status.Conditions, kubeflowv1.JobCondition{
Type: kubeflowv1.JobFailed,
Status: corev1.ConditionTrue,
Reason: "PodFailed",
Message: fmt.Sprintf("%d pods failed", failed),
})
} else if int32(succeeded) == totalReplicas {
job.Status.Conditions = append(job.Status.Conditions, kubeflowv1.JobCondition{
Type: kubeflowv1.JobSucceeded,
Status: corev1.ConditionTrue,
Reason: "AllPodsSucceeded",
Message: "All pods succeeded",
})
} else if int32(running) == totalReplicas {
job.Status.Conditions = append(job.Status.Conditions, kubeflowv1.JobCondition{
Type: kubeflowv1.JobRunning,
Status: corev1.ConditionTrue,
Reason: "AllPodsRunning",
Message: "All pods are running",
})
}
return r.Status().Update(ctx, job)
}
5. 性能优化
5.1 通信优化
# communication_optimization.py
import torch
import torch.distributed as dist
class CommunicationOptimizer:
"""通信优化器"""
def __init__(self, world_size, bucket_size_mb=25):
self.world_size = world_size
self.bucket_size = bucket_size_mb * 1024 * 1024 # 转换为字节
    def optimize_allreduce(self, gradients):
        """优化 AllReduce:分桶 + 异步通信,完成后把平均结果写回原梯度"""
        # 1. 梯度分桶(保留每个桶内的梯度引用,便于写回)
        buckets = self._bucket_gradients(gradients)
        # 2. 对每个桶发起异步 AllReduce
        pending = []
        for bucket_grads in buckets:
            flat = torch.cat([g.flatten() for g in bucket_grads])
            handle = dist.all_reduce(flat, async_op=True)
            pending.append((handle, flat, bucket_grads))
        # 3. 等待完成,平均后按原形状写回各梯度
        for handle, flat, bucket_grads in pending:
            handle.wait()
            flat.div_(self.world_size)
            offset = 0
            for g in bucket_grads:
                n = g.numel()
                g.copy_(flat[offset:offset + n].view_as(g))
                offset += n
    def _bucket_gradients(self, gradients):
        """按大小阈值将梯度分桶,返回每个桶包含的梯度列表"""
        buckets = []
        current_bucket = []
        current_size = 0
        for grad in gradients:
            grad_size = grad.numel() * grad.element_size()
            if current_bucket and current_size + grad_size > self.bucket_size:
                # 当前桶已满,开启新桶
                buckets.append(current_bucket)
                current_bucket = []
                current_size = 0
            current_bucket.append(grad)
            current_size += grad_size
        # 处理最后一个桶
        if current_bucket:
            buckets.append(current_bucket)
        return buckets
def gradient_compression(self, gradient, ratio=0.01):
"""梯度压缩(TopK)"""
flat_grad = gradient.flatten()
k = max(1, int(flat_grad.numel() * ratio))
# 选择 TopK
values, indices = torch.topk(flat_grad.abs(), k)
signs = torch.sign(flat_grad[indices])
# 返回压缩后的梯度
return {
'values': values * signs,
'indices': indices,
'shape': gradient.shape,
}
    def decompress_gradient(self, compressed, device):
        """解压梯度:把 TopK 稀疏值按索引写回全零张量"""
        flat = torch.zeros(compressed['shape'], device=device).view(-1)
        flat[compressed['indices']] = compressed['values']
        return flat.view(compressed['shape'])
class OverlapCommunicationComputation:
"""通信计算重叠"""
def __init__(self, model, optimizer):
self.model = model
self.optimizer = optimizer
self.handles = []
def backward_with_overlap(self, loss):
"""带重叠的反向传播"""
# 注册 hook 在反向传播时启动通信
hooks = []
for param in self.model.parameters():
if param.requires_grad:
hook = param.register_hook(self._create_allreduce_hook(param))
hooks.append(hook)
# 执行反向传播
loss.backward()
# 等待所有通信完成
for handle in self.handles:
handle.wait()
self.handles.clear()
# 移除 hooks
for hook in hooks:
hook.remove()
def _create_allreduce_hook(self, param):
def hook(grad):
# 异步启动 AllReduce
handle = dist.all_reduce(grad, async_op=True)
self.handles.append(handle)
return grad
return hook
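下面用一个单机小例子演示上面 TopK 压缩/解压的往返过程,便于直观感受压缩比与信息损失(模块名 communication_optimization 假设与上文文件名一致;随机梯度只保留幅值最大的 1% 元素):
# topk_compression_demo.py —— TopK 梯度压缩的往返示例(单机演示,不涉及通信)
import torch
from communication_optimization import CommunicationOptimizer  # 假设模块名与上文一致
opt = CommunicationOptimizer(world_size=1)
grad = torch.randn(1024, 1024)
compressed = opt.gradient_compression(grad, ratio=0.01)
restored = opt.decompress_gradient(compressed, device=grad.device)
kept = compressed['values'].numel()
print(f"kept {kept}/{grad.numel()} elements ({kept / grad.numel():.1%}), "
      f"relative L2 error = {(grad - restored).norm() / grad.norm():.3f}")
实际使用中通常还会配合误差反馈(error feedback)累积被丢弃的残差,否则压缩带来的信息损失会影响收敛。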
5.2 内存优化
# memory_optimization.py
import torch
from torch.utils.checkpoint import checkpoint
class MemoryOptimizer:
"""内存优化器"""
@staticmethod
def activation_checkpointing(model, segments=4):
"""激活检查点"""
class CheckpointedModule(torch.nn.Module):
def __init__(self, module, num_segments):
super().__init__()
self.module = module
self.num_segments = num_segments
def forward(self, x):
# 将模块分成多个段
layers = list(self.module.children())
                segment_size = max(1, len(layers) // self.num_segments)  # 防止段数大于层数时步长为 0
for i in range(0, len(layers), segment_size):
segment = torch.nn.Sequential(*layers[i:i+segment_size])
# 使用检查点
x = checkpoint(segment, x, use_reentrant=False)
return x
return CheckpointedModule(model, segments)
@staticmethod
def gradient_accumulation(model, dataloader, accumulation_steps,
optimizer, loss_fn):
"""梯度累积"""
model.train()
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
inputs = inputs.cuda()
targets = targets.cuda()
# 前向传播
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# 缩放 loss(模拟更大的 batch)
loss = loss / accumulation_steps
loss.backward()
# 每 accumulation_steps 步更新一次
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
@staticmethod
def offload_optimizer_states(optimizer):
"""将优化器状态卸载到 CPU"""
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cpu()
@staticmethod
def load_optimizer_states_to_gpu(optimizer, device):
"""将优化器状态加载到 GPU"""
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
class MixedPrecisionTraining:
"""混合精度训练"""
def __init__(self, model, optimizer, loss_scale='dynamic'):
self.model = model
self.optimizer = optimizer
# 创建 GradScaler
self.scaler = torch.cuda.amp.GradScaler(
init_scale=65536.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True,
)
def train_step(self, inputs, targets, loss_fn):
"""混合精度训练步骤"""
self.optimizer.zero_grad()
# 使用 autocast 进行前向传播
with torch.cuda.amp.autocast(dtype=torch.float16):
outputs = self.model(inputs)
loss = loss_fn(outputs, targets)
# 缩放 loss 并反向传播
self.scaler.scale(loss).backward()
# 反缩放梯度并更新
self.scaler.step(self.optimizer)
# 更新缩放因子
self.scaler.update()
return loss.item()
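MixedPrecisionTraining 的典型用法如下(需要 GPU 环境;模型、数据与模块名 memory_optimization 均为示例假设)。GradScaler 在检测到梯度溢出时会跳过该步的参数更新并自动缩小缩放因子:
# amp_example.py —— 混合精度训练器的使用示意
import torch
import torch.nn as nn
from memory_optimization import MixedPrecisionTraining  # 假设模块名与上文文件名一致
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
trainer = MixedPrecisionTraining(model, optimizer)
loss_fn = nn.CrossEntropyLoss()
for step in range(10):
    inputs = torch.randn(32, 512, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")
    loss = trainer.train_step(inputs, targets, loss_fn)
    if step % 5 == 0:
        print(f"step {step}: loss={loss:.4f}, scale={trainer.scaler.get_scale():.0f}")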
总结
本章深入讲解了分布式训练框架的核心技术:
- 并行策略:数据并行、张量并行、流水线并行及其组合
- 通信原语:AllReduce、AllGather 等集合通信操作及算法选择
- PyTorch 分布式:DDP、FSDP、弹性训练的实现
- Megatron-LM:张量并行和 3D 并行的深度实现
- Kubernetes 集成:PyTorchJob CRD 和训练 Operator
- 性能优化:通信优化、内存优化、混合精度训练
分布式训练是大模型时代的核心基础设施,深入理解其原理对于优化训练效率至关重要。
下一章我们将探讨 训练任务调度,讲解如何在 Kubernetes 上高效管理大规模训练任务。