02 - DeepSpeed Source Code Deep Dive
Overview
DeepSpeed is a deep-learning optimization library from Microsoft whose core technologies include ZeRO, 3D parallelism, and inference optimizations. This chapter dissects the key parts of the DeepSpeed source to explain the mechanisms that underlie large-scale model training.
DeepSpeed Architecture
Overall Architecture
┌─────────────────────────────────────────────────────────────────────────┐
│                          DeepSpeed Architecture                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────┐     │
│  │                            User API                            │     │
│  │                                                                 │     │
│  │   deepspeed.initialize()           model_engine.backward()     │     │
│  │   deepspeed.config                 model_engine.step()         │     │
│  └────────────────────────────────┬───────────────────────────────┘     │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────┐     │
│  │                         DeepSpeed Engine                       │     │
│  │                                                                 │     │
│  │  ┌────────────┐ ┌───────────────┐ ┌───────────────┐ ┌────────┐ │     │
│  │  │    ZeRO    │ │Mixed Precision│ │  Activation   │ │  Comm  │ │     │
│  │  │ Optimizer  │ │  FP16/BF16    │ │ Checkpointing │ │Backend │ │     │
│  │  │ Stage 1-3  │ │ Loss Scaling  │ │               │ │        │ │     │
│  │  └────────────┘ └───────────────┘ └───────────────┘ └────────┘ │     │
│  └────────────────────────────────┬───────────────────────────────┘     │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────┐     │
│  │                      Parallelism Strategies                    │     │
│  │                                                                 │     │
│  │  ┌─────────────┐ ┌───────────────┐ ┌─────────────┐ ┌─────────┐ │     │
│  │  │Data Parallel│ │Tensor Parallel│ │  Pipeline   │ │ Expert  │ │     │
│  │  │    (DP)     │ │     (TP)      │ │Parallel (PP)│ │Parallel │ │     │
│  │  └─────────────┘ └───────────────┘ └─────────────┘ └─────────┘ │     │
│  └────────────────────────────────┬───────────────────────────────┘     │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────┐     │
│  │                  Communication Infrastructure                  │     │
│  │                                                                 │     │
│  │        NCCL           torch.distributed           MPI          │     │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Source Directory Layout
deepspeed/
├── __init__.py                      # entry point (deepspeed.initialize)
├── runtime/
│   ├── engine.py                    # DeepSpeedEngine
│   ├── zero/
│   │   ├── stage_1_and_2.py         # ZeRO Stage 1 & 2
│   │   ├── stage3.py                # ZeRO Stage 3
│   │   └── partition_parameters.py
│   ├── pipe/
│   │   └── engine.py                # Pipeline engine
│   ├── activation_checkpointing/
│   │   └── checkpointing.py
│   └── fp16/
│       └── fused_optimizer.py
├── ops/
│   ├── adam/                        # fused Adam
│   ├── transformer/                 # fused Transformer kernels
│   └── sparse_attention/            # sparse attention
├── moe/                             # Mixture of Experts
├── inference/                       # inference optimizations
└── comm/                            # communication utilities
ZeRO Optimizer Source
ZeRO Refresher
┌─────────────────────────────────────────────────────────────────────────┐
│                         ZeRO Optimization Stages                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Memory footprint (Adam, model with Ψ parameters):                       │
│  ═════════════════════════════════════════════════                       │
│                                                                          │
│  • Parameters (FP16):   2Ψ bytes                                         │
│  • Gradients (FP16):    2Ψ bytes                                         │
│  • Optimizer states (FP32):                                              │
│    - Adam: FP32 param copy (4Ψ) + momentum (4Ψ) + variance (4Ψ) = 12Ψ    │
│                                                                          │
│  Total: 2Ψ + 2Ψ + 12Ψ = 16Ψ bytes per GPU                                 │
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────┐     │
│  │                                                                 │     │
│  │  Stage 0 (no ZeRO):                                             │     │
│  │    per GPU: params (2Ψ) + grads (2Ψ) + optimizer (12Ψ) = 16Ψ    │     │
│  │                                                                 │     │
│  │  Stage 1 (partition optimizer states):                          │     │
│  │    per GPU: params (2Ψ) + grads (2Ψ) + optimizer (12Ψ/N)        │     │
│  │             = 4Ψ + 12Ψ/N                                        │     │
│  │                                                                 │     │
│  │  Stage 2 (+ partition gradients):                               │     │
│  │    per GPU: params (2Ψ) + grads (2Ψ/N) + optimizer (12Ψ/N)      │     │
│  │             = 2Ψ + 14Ψ/N                                        │     │
│  │                                                                 │     │
│  │  Stage 3 (+ partition parameters):                              │     │
│  │    per GPU: params (2Ψ/N) + grads (2Ψ/N) + optimizer (12Ψ/N)    │     │
│  │             = 16Ψ/N                                             │     │
│  │                                                                 │     │
│  └────────────────────────────────────────────────────────────────┘     │
│                                                                          │
│  Communication cost (N = data-parallel world size):                      │
│  ──────────────────                                                      │
│  • Stage 1: same volume as DDP (AllReduce of gradients)                  │
│  • Stage 2: ReduceScatter gradients + AllGather updated parameters       │
│  • Stage 3: additionally AllGather parameters layer by layer during      │
│             forward and backward                                         │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
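To make the table concrete, here is a small stand-alone calculator (plain Python, not part of DeepSpeed) that reproduces the per-GPU numbers for a given parameter count Ψ and data-parallel degree N:

def zero_memory_gb(psi, n):
    """Per-GPU memory (decimal GB) for FP16 + Adam training under each ZeRO stage."""
    GB = 1e9
    params, grads, optim = 2 * psi, 2 * psi, 12 * psi  # bytes: FP16 params, FP16 grads, FP32 states
    return {
        "stage0": (params + grads + optim) / GB,
        "stage1": (params + grads + optim / n) / GB,
        "stage2": (params + grads / n + optim / n) / GB,
        "stage3": (params / n + grads / n + optim / n) / GB,
    }

# 7B parameters on 64 GPUs: 112.0 / 29.3 / 15.5 / 1.75 GB per GPU
for stage, gb in zero_memory_gb(7e9, 64).items():
    print(f"{stage}: {gb:.1f} GB per GPU")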
ZeRO Stage 1 & 2 Implementation
# deepspeed/runtime/zero/stage_1_and_2.py (simplified; helper methods elided)
import torch
import torch.distributed as dist

class DeepSpeedZeroOptimizer(ZeROOptimizer):
    """
    ZeRO Stage 1 / Stage 2 optimizer.
    Stage 1 partitions optimizer states; Stage 2 additionally partitions gradients.
    """
    def __init__(
        self,
        init_optimizer,                     # wrapped base optimizer (e.g. Adam)
        param_names,                        # parameter names
        timers,
        static_loss_scale=1.0,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,       # bucket granularity (~5e8 elements)
        allgather_bucket_size=500000000,
        dp_process_group=None,
        expert_parallel_group=None,
        expert_data_parallel_group=None,
        reduce_scatter=True,                # True = Stage 2, False = Stage 1
        overlap_comm=False,
        cpu_offload=False,
        mpu=None,
        clip_grad=0.0,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
    ):
        super().__init__()
        self.optimizer = init_optimizer
        self.reduce_scatter = reduce_scatter
        self.overlap_comm = overlap_comm
        self.clip_grad = clip_grad
        # Distributed setup
        self.dp_process_group = dp_process_group
        self.dp_world_size = dist.get_world_size(group=dp_process_group)
        self.dp_rank = dist.get_rank(group=dp_process_group)
        # Partition parameter ownership across data-parallel ranks
        self._partition_parameters()
        # Communication bucket sizes
        self.reduce_bucket_size = reduce_bucket_size
        self.allgather_bucket_size = allgather_bucket_size
    def _partition_parameters(self):
        """Partition optimizer-state ownership across ranks (round-robin for simplicity)."""
        all_params = []
        for param_group in self.optimizer.param_groups:
            all_params.extend(param_group['params'])
        # Assign each parameter to one owning rank
        self.param_partitions = [[] for _ in range(self.dp_world_size)]
        for i, param in enumerate(all_params):
            partition_id = i % self.dp_world_size
            self.param_partitions[partition_id].append(param)
        # Parameters owned by this rank
        self.local_params = self.param_partitions[self.dp_rank]
        # Create optimizer state only for locally owned parameters
        self._create_local_optimizer_states()

    def _create_local_optimizer_states(self):
        """Keep optimizer state only for the parameters this rank owns."""
        # Remember the full groups; step() needs them to re-gather updated parameters
        self.original_param_groups = self.optimizer.param_groups
        local_ids = {id(p) for p in self.local_params}
        new_param_groups = []
        for param_group in self.optimizer.param_groups:
            new_group = {k: v for k, v in param_group.items() if k != 'params'}
            new_group['params'] = [p for p in param_group['params']
                                   if id(p) in local_ids]
            new_param_groups.append(new_group)
        self.optimizer.param_groups = new_param_groups
    def backward(self, loss, retain_graph=False):
        """
        Run backward, then synchronize gradients.
        Stage 1: AllReduce gradients; every rank sees the full gradients but
                 only updates the parameters it owns.
        Stage 2: ReduceScatter gradients; each rank keeps only the gradient
                 shard for the parameters it owns.
        """
        # 1. Standard backward
        loss.backward(retain_graph=retain_graph)
        # 2. Gradient synchronization
        if self.reduce_scatter:
            self._reduce_scatter_gradients()   # Stage 2
        else:
            self._allreduce_gradients()        # Stage 1

    def _allreduce_gradients(self):
        """Stage 1: AllReduce all gradients."""
        # Pack gradients into fixed-size buckets
        buckets = self._build_grad_buckets()
        for bucket in buckets:
            # AllReduce the flattened bucket
            dist.all_reduce(
                bucket.buffer,
                group=self.dp_process_group
            )
            # Average across data-parallel ranks
            bucket.buffer.div_(self.dp_world_size)
            # Copy the reduced values back into param.grad
            bucket.copy_back_to_grads()

    def _reduce_scatter_gradients(self):
        """Stage 2: ReduceScatter gradients."""
        # Flatten all gradients into one contiguous buffer
        # (assumed padded so it divides evenly across ranks)
        flat_grads = self._flatten_gradients()
        # ReduceScatter: each rank receives only the shard it owns
        chunk_size = flat_grads.numel() // self.dp_world_size
        output = torch.empty(chunk_size, dtype=flat_grads.dtype,
                             device=flat_grads.device)
        dist.reduce_scatter_tensor(
            output,
            flat_grads,
            group=self.dp_process_group
        )
        # Copy the gradient shard back onto the locally owned parameters
        self._copy_grad_partitions(output)
    def step(self):
        """
        Optimizer step.
        Stage 1/2: step only on locally owned parameters, then re-gather the
        updated parameters so every rank holds a full copy again.
        """
        # 1. Gradient clipping (on the local shard)
        if self.clip_grad > 0:
            self._clip_grad_norm()
        # 2. Optimizer step (updates locally owned parameters only)
        self.optimizer.step()
        # 3. Re-gather updated parameters from their owning ranks
        self._allgather_parameters()

    def _allgather_parameters(self):
        """Broadcast each parameter from the rank that owns (and just updated) it."""
        for param_group in self.original_param_groups:
            for param in param_group['params']:
                # broadcast is a collective: the owner sends, every other rank receives
                owner_rank = self._get_param_owner(param)
                dist.broadcast(
                    param.data,
                    src=owner_rank,
                    group=self.dp_process_group
                )
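In practice these optimizer classes are not constructed by hand; they are selected through the JSON config handed to deepspeed.initialize(). A minimal Stage 2 configuration sketch follows; the field names come from the public DeepSpeed config schema, while the batch sizes, bucket sizes, and the toy model are placeholder choices:

import torch.nn as nn
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": True, "loss_scale": 0},   # loss_scale 0 = dynamic loss scaling
    "zero_optimization": {
        "stage": 2,                               # 1 = optimizer-state sharding, 2 = + gradient sharding
        "overlap_comm": True,                     # overlap reduce-scatter with backward
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e8,
        "allgather_bucket_size": 5e8,
    },
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
# deepspeed.initialize wraps the base Adam in the ZeRO optimizer internally;
# launch with the `deepspeed` launcher so the distributed environment is set up.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)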
ZeRO Stage 3 Implementation
# deepspeed/runtime/zero/stage3.py (simplified)
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F

class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
    """
    ZeRO Stage 3: parameters, gradients, and optimizer states are all partitioned.
    """
    def __init__(
        self,
        module,
        init_optimizer,
        timers,
        ds_config,
        static_loss_scale=1,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        dp_process_group=None,
        reduce_scatter=True,
        overlap_comm=False,
        offload_optimizer_config=None,
        offload_param_config=None,
        sub_group_size=1000000000000,
        mpu=None,
        clip_grad=0.0,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
    ):
        super().__init__()
        self.module = module
        self.optimizer = init_optimizer
        # Distributed setup
        self.dp_process_group = dp_process_group
        self.dp_world_size = dist.get_world_size(group=dp_process_group)
        self.dp_rank = dist.get_rank(group=dp_process_group)
        # Partition every parameter across ranks
        self._partition_all_parameters()
        # Prefetch configuration
        self.prefetch_bucket_size = prefetch_bucket_size
        self.max_reuse_distance = max_reuse_distance
        # CPU offload configuration
        self.offload_optimizer = offload_optimizer_config is not None
        self.offload_param = offload_param_config is not None
        # Register forward/backward hooks that gather and release parameters
        self._register_hooks()
    def _partition_all_parameters(self):
        """Shard every parameter of the module across data-parallel ranks."""
        self.param_handles = {}
        for name, param in self.module.named_parameters():
            # 1. Flatten the parameter
            flat_param = param.data.flatten()
            # 2. Pad to a multiple of world size and compute the per-rank chunk
            padded_size = math.ceil(flat_param.numel() / self.dp_world_size) * self.dp_world_size
            chunk_size = padded_size // self.dp_world_size
            # 3. Keep only this rank's shard
            start = self.dp_rank * chunk_size
            end = start + chunk_size
            if start < flat_param.numel():
                # Shard contains real data
                local_end = min(end, flat_param.numel())
                local_param = flat_param[start:local_end].clone()
                # Pad the tail shard up to chunk_size
                if local_param.numel() < chunk_size:
                    local_param = F.pad(local_param,
                                        (0, chunk_size - local_param.numel()))
            else:
                # Shard is pure padding
                local_param = torch.zeros(chunk_size, dtype=param.dtype,
                                          device=param.device)
            # 4. Wrap the shard in a handle
            handle = PartitionedParameterHandle(
                param=param,
                param_name=name,
                flat_param=flat_param,
                local_param=local_param,
                dp_rank=self.dp_rank,
                dp_world_size=self.dp_world_size,
            )
            self.param_handles[name] = handle
            # 5. Free the full parameter storage
            param.data = torch.empty(0)
    def _register_hooks(self):
        """Register forward/backward hooks that gather and release parameters."""
        # Module hooks see locally scoped names, so map parameter objects back to
        # the fully qualified names used as keys in self.param_handles.
        param_to_name = {id(p): n for n, p in self.module.named_parameters()}

        def _pre_forward_hook(module, inputs):
            """Before forward: AllGather this module's own parameters."""
            for param in module.parameters(recurse=False):
                name = param_to_name.get(id(param))
                if name in self.param_handles:
                    self._allgather_param(name)

        def _post_forward_hook(module, inputs, outputs):
            """After forward: release non-persistent parameters."""
            for param in module.parameters(recurse=False):
                name = param_to_name.get(id(param))
                if name in self.param_handles and not self.param_handles[name].persistent:
                    self._release_param(name)

        def _pre_backward_hook(module, grad_output):
            """Before backward: AllGather this module's parameters again."""
            for param in module.parameters(recurse=False):
                name = param_to_name.get(id(param))
                if name in self.param_handles:
                    self._allgather_param(name)

        def _make_grad_hook(name):
            def _post_backward_hook(grad):
                """Once the gradient exists: ReduceScatter it, then drop the full parameter."""
                self._reduce_scatter_grad(name, grad)
                self._release_param(name)
                return grad
            return _post_backward_hook

        # Register the module-level hooks on every submodule
        for module in self.module.modules():
            module.register_forward_pre_hook(_pre_forward_hook)
            module.register_forward_hook(_post_forward_hook)
            module.register_full_backward_pre_hook(_pre_backward_hook)
        # Register a gradient hook on every parameter
        for name, param in self.module.named_parameters():
            param.register_hook(_make_grad_hook(name))
    def _allgather_param(self, param_name):
        """AllGather the full parameter from its shards."""
        handle = self.param_handles[param_name]
        # 1. Allocate a buffer for the full (padded) parameter
        full_param = torch.empty(
            handle.padded_size,
            dtype=handle.local_param.dtype,
            device=handle.local_param.device
        )
        # 2. AllGather the shards from all ranks
        dist.all_gather_into_tensor(
            full_param,
            handle.local_param,
            group=self.dp_process_group
        )
        # 3. Drop padding and reshape to the original shape
        full_param = full_param[:handle.original_numel]
        full_param = full_param.view(handle.original_shape)
        # 4. Re-attach the gathered data to the parameter
        handle.param.data = full_param

    def _release_param(self, param_name):
        """Release the full parameter; keep only the local shard."""
        handle = self.param_handles[param_name]
        handle.param.data = torch.empty(0)

    def _reduce_scatter_grad(self, param_name, grad=None):
        """ReduceScatter the gradient so each rank keeps only its shard."""
        handle = self.param_handles[param_name]
        # 1. Flatten the gradient (either the hook's grad or param.grad)
        flat_grad = (grad if grad is not None else handle.param.grad).flatten()
        # 2. Pad so it divides evenly across ranks
        if flat_grad.numel() < handle.padded_size:
            flat_grad = F.pad(flat_grad,
                              (0, handle.padded_size - flat_grad.numel()))
        # 3. ReduceScatter: sum across ranks, keep only the local shard
        local_grad = torch.empty(
            handle.chunk_size,
            dtype=flat_grad.dtype,
            device=flat_grad.device
        )
        dist.reduce_scatter_tensor(
            local_grad,
            flat_grad,
            group=self.dp_process_group
        )
        # 4. Keep the local gradient shard
        handle.local_grad = local_grad
    def step(self):
        """Optimizer step over the local shards."""
        # Run the step on CPU when optimizer state is offloaded, otherwise on GPU
        if self.offload_optimizer:
            self._step_on_cpu()
        else:
            self._step_on_gpu()

    def _step_on_gpu(self):
        """Adam update on the GPU, over the local shards only."""
        for name, handle in self.param_handles.items():
            if handle.local_grad is None:
                continue
            # Work in FP32 on the local parameter and gradient shards
            local_param_fp32 = handle.local_param.float()
            local_grad_fp32 = handle.local_grad.float()
            # Per-parameter optimizer state
            state = self.optimizer.state[handle.param]
            # Adam state initialization
            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(local_param_fp32)
                state['exp_avg_sq'] = torch.zeros_like(local_param_fp32)
                state['step'] = 0
            state['step'] += 1
            beta1, beta2 = self.optimizer.param_groups[0]['betas']
            lr = self.optimizer.param_groups[0]['lr']
            eps = self.optimizer.param_groups[0]['eps']
            # Update first and second moments
            state['exp_avg'].mul_(beta1).add_(local_grad_fp32, alpha=1 - beta1)
            state['exp_avg_sq'].mul_(beta2).addcmul_(
                local_grad_fp32, local_grad_fp32, value=1 - beta2
            )
            # Bias correction
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']
            # Parameter update
            denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
            step_size = lr / bias_correction1
            local_param_fp32.addcdiv_(state['exp_avg'], denom, value=-step_size)
            # Cast back to the FP16 shard
            handle.local_param.copy_(local_param_fp32.half())
class PartitionedParameterHandle:
    """Bookkeeping for one partitioned parameter."""
    def __init__(self, param, param_name, flat_param, local_param,
                 dp_rank, dp_world_size):
        self.param = param
        self.param_name = param_name
        self.flat_param = flat_param
        self.local_param = local_param
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size
        # Original shape and element count
        self.original_shape = param.shape
        self.original_numel = param.numel()
        # Padded size and per-rank chunk size
        self.padded_size = self._get_padded_size()
        self.chunk_size = self.padded_size // dp_world_size
        # Local gradient shard (filled in by ReduceScatter)
        self.local_grad = None
        # Persistent (small) parameters are never released after gathering
        self.persistent = param.numel() < 100000   # keep parameters under 100K elements resident

    def _get_padded_size(self):
        """Smallest size >= original_numel that is divisible by the world size."""
        size = self.original_numel
        if size % self.dp_world_size != 0:
            size = ((size // self.dp_world_size) + 1) * self.dp_world_size
        return size
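The pad-shard-gather round trip implemented by _partition_all_parameters and _allgather_param can be sanity-checked in a single process without torch.distributed; the padding formula below mirrors _get_padded_size, and the function names are ours, for illustration only:

import torch
import torch.nn.functional as F

def shard(flat, world_size):
    """Pad the flattened parameter to a multiple of world_size, then split into per-rank chunks."""
    padded = ((flat.numel() + world_size - 1) // world_size) * world_size
    flat = F.pad(flat, (0, padded - flat.numel()))
    return list(flat.chunk(world_size))

def gather(chunks, original_numel, shape):
    """Simulate all_gather_into_tensor: concatenate the chunks, drop the padding, restore the shape."""
    return torch.cat(chunks)[:original_numel].view(shape)

w = torch.randn(10, 7)                        # 70 elements, not divisible by 4 ranks
chunks = shard(w.flatten(), world_size=4)     # each "rank" keeps one 18-element chunk
restored = gather(chunks, w.numel(), w.shape)
assert torch.equal(w, restored)               # the round trip is lossless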
CPU Offload Implementation
# deepspeed/runtime/zero/stage3.py (CPU offload path, simplified)
class DeepSpeedZeroOptimizer_Stage3:
    def _setup_cpu_offload(self):
        """Move optimizer state and/or parameter shards to pinned CPU memory."""
        if self.offload_optimizer:
            # Optimizer state lives on the CPU
            for handle in self.param_handles.values():
                state = self.optimizer.state.get(handle.param, {})
                for key, val in state.items():
                    if torch.is_tensor(val):
                        state[key] = val.cpu().pin_memory()
        if self.offload_param:
            # Parameter shards live on the CPU as well
            for handle in self.param_handles.values():
                handle.local_param = handle.local_param.cpu().pin_memory()

    def _step_on_cpu(self):
        """Run the optimizer step on the CPU."""
        # 1. Copy gradient shards to the CPU
        for handle in self.param_handles.values():
            if handle.local_grad is not None:
                handle.local_grad_cpu = handle.local_grad.cpu()
        # 2. Adam update on the CPU (same math as the GPU path)
        for name, handle in self.param_handles.items():
            local_param_fp32 = handle.local_param.float()
            local_grad_fp32 = handle.local_grad_cpu.float()
            state = self.optimizer.state[handle.param]
            # ... same update as _step_on_gpu, executed on CPU tensors ...
            # Write the updated shard back
            handle.local_param.copy_(local_param_fp32.half())
        # 3. Copy the updated shards back to the GPU asynchronously
        for handle in self.param_handles.values():
            handle.local_param_gpu = handle.local_param.cuda(non_blocking=True)

    def _prefetch_params_to_gpu(self, layer_id):
        """Prefetch the next layer's parameter shards to the GPU."""
        # While the current layer is computing, move the next layer's shards CPU -> GPU
        next_layer_params = self._get_layer_params(layer_id + 1)
        for param_name in next_layer_params:
            handle = self.param_handles.get(param_name)
            if handle and handle.local_param.device.type == 'cpu':
                # Asynchronous host-to-device copy (requires pinned memory)
                handle.local_param_gpu = handle.local_param.cuda(non_blocking=True)
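The CPU copies above are pinned because only page-locked host memory allows the host-to-device copy to run asynchronously and overlap with compute, which is what makes the prefetch in _prefetch_params_to_gpu worthwhile. A minimal sketch of that overlap pattern (tensor sizes are arbitrary; the block is guarded so it only runs when CUDA is available):

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()

    # CPU-resident shard, pinned so the host-to-device copy can be asynchronous
    cpu_shard = torch.randn(10_000_000).pin_memory()
    gpu_buffer = torch.empty_like(cpu_shard, device="cuda")

    with torch.cuda.stream(copy_stream):
        # Prefetch the next layer's shard on a side stream ...
        gpu_buffer.copy_(cpu_shard, non_blocking=True)

    # ... while the default stream keeps computing here ...

    # Before using the prefetched data, make the default stream wait for the copy
    torch.cuda.current_stream().wait_stream(copy_stream)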
Pipeline Parallel Source
Pipeline Engine Implementation
# deepspeed/runtime/pipe/engine.py (simplified)
import torch
import torch.distributed as dist

class PipelineEngine(DeepSpeedEngine):
    """
    Pipeline-parallel engine implementing the GPipe and 1F1B schedules.
    """
    def __init__(
        self,
        model,
        config,
        mpu,
        **kwargs
    ):
        super().__init__(model, config, **kwargs)
        # Pipeline-parallel topology
        self.pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
        self.pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
        self.num_stages = self.pipeline_parallel_size
        # Number of micro-batches per optimizer step
        self.micro_batches = config.gradient_accumulation_steps
        # Point-to-point communication buffers
        self._allocate_buffers()
        # Static schedule for this stage
        self.schedule = self._build_schedule()
    def _build_schedule(self):
        """Build the pipeline schedule for this stage."""
        if self.config.pipeline_schedule == "1f1b":
            return self._build_1f1b_schedule()
        else:
            return self._build_gpipe_schedule()

    def _build_1f1b_schedule(self):
        """
        1F1B (one forward, one backward):
        - warmup: run enough forwards to fill the pipeline
        - steady state: alternate one forward with one backward
        - cooldown: drain the remaining backwards
        """
        schedule = []
        num_warmup_microbatches = min(
            self.num_stages - self.pipeline_parallel_rank - 1,
            self.micro_batches
        )
        num_microbatches_remaining = self.micro_batches - num_warmup_microbatches
        # Warmup phase: forwards only
        for i in range(num_warmup_microbatches):
            schedule.append(('forward', i))
        # Steady state: 1F1B
        for i in range(num_microbatches_remaining):
            schedule.append(('forward', num_warmup_microbatches + i))
            schedule.append(('backward', i))
        # Cooldown phase: backwards only
        for i in range(num_warmup_microbatches):
            schedule.append(('backward', num_microbatches_remaining + i))
        return schedule

    def _build_gpipe_schedule(self):
        """
        GPipe: run all forwards first, then all backwards.
        """
        schedule = []
        # All forwards
        for i in range(self.micro_batches):
            schedule.append(('forward', i))
        # All backwards
        for i in range(self.micro_batches):
            schedule.append(('backward', i))
        return schedule
    def train_batch(self, data_iter):
        """Train on one global batch."""
        # 1. Split the batch into micro-batches
        micro_batches = self._prepare_micro_batches(data_iter)
        # 2. Execute the schedule
        losses = []
        for action, micro_batch_id in self.schedule:
            if action == 'forward':
                loss = self._exec_forward_pass(micro_batches[micro_batch_id])
                if loss is not None:   # only the last stage produces a loss
                    losses.append(loss)
            else:  # backward
                self._exec_backward_pass(micro_batch_id)
        # 3. Data-parallel gradient synchronization
        if self.is_data_parallel:
            self._sync_gradients()
        # 4. Optimizer step
        self.optimizer.step()
        self.optimizer.zero_grad()
        return sum(losses) / len(losses) if losses else None

    def _exec_forward_pass(self, micro_batch):
        """Forward pass for one micro-batch."""
        # 1. Receive activations from the previous stage
        if self.pipeline_parallel_rank > 0:
            input_tensor = self._recv_forward()
        else:
            input_tensor = micro_batch
        # 2. Run this stage's layers
        with torch.cuda.amp.autocast(enabled=self.fp16_enabled):
            output_tensor = self.module(input_tensor)
        # 3. Send activations to the next stage
        if self.pipeline_parallel_rank < self.num_stages - 1:
            self._send_forward(output_tensor)
        # 4. Save the output for the backward pass
        self._save_activation(output_tensor)
        # 5. Compute the loss (last stage only)
        if self.is_last_stage:
            loss = self.loss_fn(output_tensor, micro_batch['labels'])
            return loss
        return None
    def _exec_backward_pass(self, micro_batch_id):
        """Backward pass for one micro-batch."""
        # 1. Receive the output gradient from the next stage
        if self.pipeline_parallel_rank < self.num_stages - 1:
            output_grad = self._recv_backward()
        else:
            # Last stage: the saved activation is the loss, backward starts from it
            output_grad = None
        # 2. Load the activation saved during the forward pass
        output_tensor = self._load_activation(micro_batch_id)
        # 3. Run backward
        if output_grad is not None:
            output_tensor.backward(output_grad)
        else:
            output_tensor.backward()
        # 4. Send the input gradient to the previous stage
        if self.pipeline_parallel_rank > 0:
            input_grad = self._get_input_grad()
            self._send_backward(input_grad)

    def _recv_forward(self):
        """Receive activations from the previous stage."""
        src_rank = self.pipeline_parallel_rank - 1
        recv_buffer = self._get_recv_buffer()
        dist.recv(recv_buffer, src=src_rank, group=self.pipeline_group)
        return recv_buffer

    def _send_forward(self, tensor):
        """Send activations to the next stage."""
        dst_rank = self.pipeline_parallel_rank + 1
        dist.send(tensor, dst=dst_rank, group=self.pipeline_group)

    def _recv_backward(self):
        """Receive the output gradient from the next stage."""
        src_rank = self.pipeline_parallel_rank + 1
        recv_buffer = self._get_recv_buffer()
        dist.recv(recv_buffer, src=src_rank, group=self.pipeline_group)
        return recv_buffer

    def _send_backward(self, tensor):
        """Send the input gradient to the previous stage."""
        dst_rank = self.pipeline_parallel_rank - 1
        dist.send(tensor, dst=dst_rank, group=self.pipeline_group)
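From the user's side, this engine is driven through deepspeed.pipe.PipelineModule: the model is expressed as a flat list of layers that DeepSpeed partitions into stages. A rough usage sketch, where the layer sizes and stage count are placeholder values:

import torch.nn as nn
from deepspeed.pipe import PipelineModule

# Express the network as a flat sequence of layers so the engine can cut it into stages
layers = [nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
          for _ in range(24)]

model = PipelineModule(
    layers=layers,
    num_stages=4,                      # pipeline_parallel_size
    partition_method="parameters",     # balance stages by parameter count
)
# deepspeed.initialize() returns a PipelineEngine for PipelineModule models;
# training then goes through engine.train_batch(data_iter) instead of
# separate forward / backward / step calls.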
1F1B Schedule Illustrated
┌─────────────────────────────────────────────────────────────────────────┐
│                          1F1B Pipeline Schedule                          │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Example: 4 stages, 8 micro-batches                                      │
│  (per-stage operation order, not a wall-clock timeline;                  │
│   warmup forwards per stage = p - rank - 1, matching the code above)     │
│                                                                          │
│  Stage 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7                │
│  Stage 1: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 F6 B4 F7 B5 B6 B7                │
│  Stage 2: F0 F1 B0 F2 B1 F3 B2 F4 B3 F5 B4 F6 B5 F7 B6 B7                │
│  Stage 3: F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5 F6 B6 F7 B7                │
│                                                                          │
│           ╔═══════╗╔═══════════════════════════╗╔════════╗              │
│           ║Warmup ║║   Steady state (1F1B)     ║║Cooldown║              │
│           ╚═══════╝╚═══════════════════════════╝╚════════╝              │
│                                                                          │
│  F = forward, B = backward                                               │
│                                                                          │
│  Properties:                                                             │
│  • Warmup: fill the pipeline with forwards                               │
│  • Steady state: each stage alternates one forward with one backward     │
│  • Cooldown: drain the remaining backwards                               │
│  • Bubbles occur only during warmup and cooldown                         │
│                                                                          │
│  Bubble fraction = (p-1) / (m + p - 1) ≈ (p-1) / m when m ≫ p            │
│  where p = number of pipeline stages, m = number of micro-batches        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
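The per-stage operation order in the figure follows directly from the warmup/steady/cooldown logic of _build_1f1b_schedule; this stand-alone sketch reproduces it:

def one_f_one_b(rank, num_stages, micro_batches):
    """Operation order (not wall-clock timing) of one pipeline stage under 1F1B."""
    warmup = min(num_stages - rank - 1, micro_batches)
    remaining = micro_batches - warmup
    sched = [f"F{i}" for i in range(warmup)]                  # warmup: forwards only
    for i in range(remaining):                                # steady state: one F, one B
        sched += [f"F{warmup + i}", f"B{i}"]
    sched += [f"B{remaining + i}" for i in range(warmup)]     # cooldown: backwards only
    return sched

for rank in range(4):
    print(f"Stage {rank}:", " ".join(one_f_one_b(rank, num_stages=4, micro_batches=8)))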
Activation Checkpointing Source
# deepspeed/runtime/activation_checkpointing/checkpointing.py (simplified)
import torch

class CheckpointFunction(torch.autograd.Function):
    """
    Activation checkpointing:
    forward saves only the inputs (no intermediate activations);
    backward re-runs the forward to recompute them.
    """
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        # Remember the function to re-run during backward
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Save the inputs (all that is needed for recomputation)
        ctx.save_for_backward(*args)
        # Save the RNG state so recomputation (e.g. dropout) is bit-identical
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.fwd_gpu_devices = [torch.cuda.current_device()]
            ctx.fwd_gpu_states = []
            for device in ctx.fwd_gpu_devices:
                ctx.fwd_gpu_states.append(torch.cuda.get_rng_state(device))
        # Run forward under no_grad so no graph (and no activations) are kept
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs
    @staticmethod
    def backward(ctx, *output_grads):
        # Recover the saved inputs
        inputs = ctx.saved_tensors
        # Restore the RNG state captured during forward
        if ctx.preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_state)
            for device, state in zip(ctx.fwd_gpu_devices, ctx.fwd_gpu_states):
                torch.cuda.set_rng_state(state, device)
        # Re-run the forward, this time building a graph
        with torch.enable_grad():
            # Detach inputs and re-enable grad where the originals required it
            inputs_with_grad = [inp.detach().requires_grad_(inp.requires_grad)
                                for inp in inputs]
            outputs = ctx.run_function(*inputs_with_grad)
        # Normalize outputs to a tuple
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # Backprop through the recomputed graph to get input gradients
        input_grads = torch.autograd.grad(
            outputs,
            inputs_with_grad,
            output_grads,
            allow_unused=True
        )
        # The first two forward arguments (run_function, preserve_rng_state) get no gradient
        return (None, None) + input_grads
def checkpoint(function, *args, **kwargs):
    """
    Wrap a function with activation checkpointing.
    Usage:
        output = checkpoint(transformer_layer, hidden_states, attention_mask)
    """
    preserve_rng_state = kwargs.get('preserve_rng_state', True)
    return CheckpointFunction.apply(function, preserve_rng_state, *args)

# Higher-level API: checkpoint a sequence of functions in segments
def checkpoint_sequential(functions, segments, input, **kwargs):
    """
    Checkpoint a list of functions (e.g. transformer layers) segment by segment.
    Args:
        functions: list of callables (e.g. the layers)
        segments: number of checkpointed segments
        input: input tensor
    """
    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end):
                input = functions[j](input)
            return input
        return forward

    # Number of layers per segment
    segment_size = len(functions) // segments
    # Apply checkpointing segment by segment
    for start in range(0, len(functions), segment_size):
        end = min(start + segment_size, len(functions))
        input = checkpoint(run_function(start, end, functions), input, **kwargs)
    return input
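Usage has the same shape as PyTorch's built-in torch.utils.checkpoint, shown below with a stack of feed-forward blocks so the sketch runs without DeepSpeed; each block's internal activations are recomputed during backward instead of being stored:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
    for _ in range(12)
)

x = torch.randn(8, 128, 512, requires_grad=True)
h = x
for block in blocks:
    # Forward runs without storing block-internal activations; they are
    # recomputed from `h` during backward.
    h = checkpoint(block, h, use_reentrant=False)

h.sum().backward()
print(x.grad.shape)   # gradients flow as usual, at the cost of one extra forward per block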
DeepSpeed Initialization Flow
# deepspeed/__init__.py (simplified)
import os
import torch
import torch.distributed as dist

def initialize(
    args=None,
    model=None,
    optimizer=None,
    model_parameters=None,
    training_data=None,
    lr_scheduler=None,
    mpu=None,
    dist_init_required=True,
    collate_fn=None,
    config=None,
    config_params=None,
):
    """
    DeepSpeed entry point.
    Returns:
        tuple: (engine, optimizer, dataloader, lr_scheduler)
    """
    # 1. Parse the configuration
    ds_config = DeepSpeedConfig(config, config_params)
    # 2. Initialize the distributed environment
    if dist_init_required:
        init_distributed()
    # 3. Build the model engine
    if ds_config.pipeline_enabled:
        # Pipeline parallelism
        engine = PipelineEngine(
            model=model,
            config=ds_config,
            optimizer=optimizer,
            model_parameters=model_parameters,
            mpu=mpu,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
        )
    elif ds_config.zero_enabled:
        # ZeRO
        if ds_config.zero_stage == 3:
            engine = DeepSpeedEngine(
                model=model,
                config=ds_config,
                optimizer=DeepSpeedZeroOptimizer_Stage3(...),
                ...
            )
        else:
            engine = DeepSpeedEngine(
                model=model,
                config=ds_config,
                optimizer=DeepSpeedZeroOptimizer(...),
                ...
            )
    else:
        # Plain DeepSpeed engine
        engine = DeepSpeedEngine(
            model=model,
            config=ds_config,
            optimizer=optimizer,
            model_parameters=model_parameters,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
        )
    # 4. Build the data loader
    dataloader = engine.create_data_loader(training_data, collate_fn)
    return engine, engine.optimizer, dataloader, engine.lr_scheduler
def init_distributed():
    """Initialize the distributed environment."""
    if not dist.is_initialized():
        # Read backend and rendezvous method from environment variables
        backend = os.environ.get('DISTRIBUTED_BACKEND', 'nccl')
        init_method = os.environ.get('INIT_METHOD', 'env://')
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
        )
    # Bind this process to its local GPU
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
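Putting the pieces together, a skeleton training loop against the public API; the config can be the Stage 2 config from the ZeRO section, and the model, dataset, and loss here are placeholders:

import deepspeed

def train(model, dataset, ds_config):
    # initialize() builds the engine, wraps the optimizer (ZeRO / FP16), and
    # returns a distributed dataloader when training_data is given.
    engine, optimizer, dataloader, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=dataset,
        config=ds_config,
    )
    for step, batch in enumerate(dataloader):
        inputs = batch[0].to(engine.device)
        loss = engine(inputs).pow(2).mean()   # placeholder loss
        engine.backward(loss)                 # handles loss scaling and gradient partitioning
        engine.step()                         # optimizer step + gradient zeroing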
Summary
DeepSpeed Core Components
┌─────────────────────────────────────────────────────────────────────────┐
│                        DeepSpeed Core Components                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ZeRO Optimizer                                                          │
│  ──────────────                                                          │
│  • Stage 1: partition optimizer states                                   │
│  • Stage 2: + partition gradients (ReduceScatter)                        │
│  • Stage 3: + partition parameters (AllGather in forward/backward)       │
│  • CPU offload: move states and parameter shards to CPU memory           │
│                                                                          │
│  Pipeline Engine                                                         │
│  ───────────────                                                         │
│  • GPipe: all forwards first, then all backwards                         │
│  • 1F1B: minimizes the pipeline bubble                                   │
│  • P2P communication: send/recv of activations and gradients             │
│                                                                          │
│  Activation Checkpointing                                                │
│  ────────────────────────                                                │
│  • Do not store intermediate activations                                 │
│  • Recompute them during backward                                        │
│  • Trades extra compute for memory                                       │
│                                                                          │
│  Mixed-Precision Training                                                │
│  ────────────────────────                                                │
│  • FP16/BF16 compute                                                     │
│  • Dynamic loss scaling                                                  │
│  • FP32 optimizer states                                                 │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Frequently Asked Interview Questions
- What distinguishes ZeRO Stage 1/2/3, and what is the communication cost of each?
- Why does Stage 3 use ReduceScatter for gradients instead of AllReduce?
- How does the 1F1B schedule minimize the pipeline bubble?
- How does activation checkpointing recompute activations during backward?
- How is CPU offload implemented?