01 - PyTorch Distributed Source Code Walkthrough
Overview
PyTorch is one of the most popular deep learning frameworks, and its distributed training stack has matured over many years of development. This chapter walks through the core source code behind PyTorch distributed training: torch.distributed, DistributedDataParallel (DDP), and the implementation principles of FSDP.
PyTorch Distributed Architecture
Overall Architecture
┌──────────────────────────────────────────────────────────────────────────┐
│                     PyTorch Distributed Architecture                     │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                          User API layer                            │  │
│  │                                                                    │  │
│  │  torch.distributed        DistributedDataParallel     FSDP        │  │
│  │  ├─ init_process_group    ├─ DDP(model)                ├─ FSDP()   │  │
│  │  ├─ all_reduce            ├─ forward/backward          ├─ shard    │  │
│  │  ├─ broadcast             └─ gradient sync             └─ gather   │  │
│  │  └─ all_gather                                                     │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────────┐  │
│  │                         ProcessGroup layer                         │  │
│  │                                                                    │  │
│  │  ProcessGroup           ProcessGroupNCCL       ProcessGroupGloo    │  │
│  │  ├─ size()              ├─ NCCL backend        ├─ Gloo backend     │  │
│  │  ├─ rank()              ├─ GPU collectives     ├─ CPU collectives  │  │
│  │  └─ work()              └─ high performance    └─ general purpose  │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────────┐  │
│  │                     Communication backend layer                    │  │
│  │                                                                    │  │
│  │  NCCL            Gloo             MPI             UCC              │  │
│  │  ├─ Ring         ├─ Allreduce     ├─ OpenMPI      ├─ UCX           │  │
│  │  ├─ Tree         ├─ Broadcast     └─ MPICH        └─ heterogeneous │  │
│  │  └─ P2P          └─ P2P                                            │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                   │                                      │
│  ┌────────────────────────────────┴───────────────────────────────────┐  │
│  │                      Network / hardware layer                      │  │
│  │                                                                    │  │
│  │  NVLink/NVSwitch         InfiniBand              Ethernet          │  │
│  │  (intra-node GPUs)       (inter-node, fast)      (commodity)       │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
Source Tree Layout
torch/
├── distributed/
│   ├── __init__.py               # distributed API entry point
│   ├── distributed_c10d.py       # Python layer over the c10d bindings
│   ├── rendezvous.py             # process rendezvous
│   ├── launch.py                 # legacy launcher (torchrun lives in run.py)
│   │
│   ├── fsdp/                     # FSDP implementation
│   │   └── fully_sharded_data_parallel.py
│   │
│   └── algorithms/
│       └── ddp_comm_hooks/       # DDP communication hooks
│
├── nn/
│   └── parallel/
│       ├── distributed.py        # DDP implementation
│       └── _functions.py         # distributed autograd functions
│
├── csrc/
│   └── distributed/
│       ├── c10d/
│       │   ├── ProcessGroup.hpp       # ProcessGroup base class
│       │   ├── ProcessGroupNCCL.hpp   # NCCL backend
│       │   ├── ProcessGroupGloo.hpp   # Gloo backend
│       │   └── Store.hpp              # distributed KV store
│       │
│       └── rpc/                  # RPC implementation
torch.distributed Core Implementation
init_process_group Source
# torch/distributed/distributed_c10d.py (simplified, annotated excerpt)
def init_process_group(
    backend: str,
    init_method: Optional[str] = None,
    timeout: timedelta = default_pg_timeout,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: str = "",
    pg_options: Optional[ProcessGroupOptions] = None,
):
    """
    Initialize the distributed process group.

    Args:
        backend: communication backend ("nccl", "gloo", "mpi", "ucc")
        init_method: initialization method ("env://", "tcp://", "file://")
        world_size: total number of processes
        rank: rank of the current process
        store: distributed key-value store
    """
    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    # 1. Validate arguments
    if backend not in ["nccl", "gloo", "mpi", "ucc"]:
        raise ValueError(f"Invalid backend: {backend}")

    # 2. Create or reuse a Store (used for inter-process coordination)
    if store is None:
        if init_method is None:
            init_method = "env://"  # default: read environment variables

        # Build the Store according to init_method
        if init_method == "env://":
            # Read MASTER_ADDR / MASTER_PORT from the environment
            store = _create_store_from_env()
        elif init_method.startswith("tcp://"):
            # TCP rendezvous: tcp://master_ip:port
            store = TCPStore(...)
        elif init_method.startswith("file://"):
            # File rendezvous: file:///path/to/file
            store = FileStore(...)

    # 3. If world_size / rank were not given, fall back to environment variables
    if world_size == -1:
        world_size = int(os.environ.get("WORLD_SIZE", -1))
    if rank == -1:
        rank = int(os.environ.get("RANK", -1))

    # 4. Create the ProcessGroup
    if backend == "nccl":
        pg = ProcessGroupNCCL(store, rank, world_size, pg_options)
    elif backend == "gloo":
        pg = ProcessGroupGloo(store, rank, world_size, pg_options)
    elif backend == "mpi":
        pg = ProcessGroupMPI(rank, world_size)

    # 5. Install it as the default ProcessGroup
    _set_default_pg(pg)
    _backend = backend
    _default_pg_init_method = init_method

    # 6. Synchronization barrier at the end of initialization
    barrier()
    return pg


def _create_store_from_env():
    """Create a Store from environment variables."""
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = int(os.environ.get("MASTER_PORT", "29500"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))

    # Create a TCPStore: rank 0 acts as the server, every other rank is a client
    is_master = (rank == 0)
    store = TCPStore(
        host_name=master_addr,
        port=master_port,
        world_size=world_size,
        is_master=is_master,
        timeout=timedelta(seconds=300),
    )
    return store
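For reference, the simplest way to drive this initialization path end to end is the env:// method together with torchrun, which exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for every process. A minimal sketch (the file name train_min.py is just a placeholder):

# train_min.py -- minimal sketch of env:// initialization (hypothetical file name)
# Launch with, e.g.:  torchrun --nnodes=1 --nproc_per_node=4 train_min.py
import os
import torch
import torch.distributed as dist

def main():
    # torchrun already set the rendezvous environment variables, so only the
    # backend and init_method need to be passed here.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    x = torch.ones(1, device=local_rank) * dist.get_rank()
    dist.all_reduce(x)          # sums the ranks: 0 + 1 + ... + (world_size - 1)
    print(f"rank {dist.get_rank()}: {x.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()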
ProcessGroup Implementation
// torch/csrc/distributed/c10d/ProcessGroup.hpp (simplified)
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  // Basic properties of the process group
  int rank_;                         // rank of the current process
  int size_;                         // size of the process group
  c10::intrusive_ptr<Store> store_;  // key-value store

  // Collective communication interface
  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) = 0;

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) = 0;

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) = 0;

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0;

  // Point-to-point communication
  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) = 0;

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) = 0;

  // Barrier synchronization
  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) = 0;
};

// ProcessGroupNCCL implementation
// torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp (simplified)
class ProcessGroupNCCL : public ProcessGroup {
 private:
  // NCCL communicators (one per device)
  std::unordered_map<std::string, std::vector<ncclComm_t>> ncclComms_;
  // CUDA streams (for asynchronous communication)
  std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>> ncclStreams_;
  // CUDA events (for synchronization)
  std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;

 public:
  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts) override {
    // 1. Check that the tensors live on GPUs
    check_gpu_tensors(tensors);

    // 2. Get or create the NCCL communicators
    auto devices = getDevices(tensors);
    auto key = getKeyFromDevices(devices);
    auto& ncclComms = getNCCLComms(key, devices);

    // 3. Create the Work object
    auto work = c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>(
        devices, rank_, opts.opType);

    // 4. Launch AllReduce on the dedicated NCCL streams
    for (size_t i = 0; i < tensors.size(); i++) {
      auto& tensor = tensors[i];
      auto& ncclComm = ncclComms[i];
      auto& ncclStream = ncclStreams_[key][i];

      // Switch to the NCCL stream
      c10::cuda::CUDAStreamGuard guard(ncclStream);

      // Record the start event
      work->ncclStartEvents_[i].record(ncclStream);

      // Call NCCL AllReduce (in place: input and output buffers are the same)
      NCCL_CHECK(ncclAllReduce(
          tensor.data_ptr(),
          tensor.data_ptr(),
          tensor.numel(),
          getNcclDataType(tensor.scalar_type()),
          getNcclReduceOp(opts.reduceOp),
          ncclComm,
          ncclStream.stream()));

      // Record the end event
      work->ncclEndEvents_[i].record(ncclStream);
    }
    return work;
  }
};
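From Python, these C++ ProcessGroup objects are reached through torch.distributed: init_process_group creates the default group, and dist.new_group creates additional groups backed by their own communicators. A small sketch of issuing a collective on a subgroup, assuming a 4-process job has already been initialized:

# Sketch: collectives on an explicit subgroup (assumes world_size == 4, group initialized)
import torch
import torch.distributed as dist

rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())

# new_group must be called by *all* ranks, with the same arguments in the same
# order, so both subgroups are created everywhere and each rank picks its own.
group_a = dist.new_group([0, 1])
group_b = dist.new_group([2, 3])
group = group_a if rank < 2 else group_b

t = torch.full((4,), float(rank), device=device)
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group, async_op=True)
work.wait()   # the returned Work wraps the backend's WorkNCCL object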
The all_reduce Function
# torch/distributed/distributed_c10d.py (simplified)
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Run AllReduce on a tensor.

    Args:
        tensor: input/output tensor (the operation is in place)
        op: reduction op (SUM, PRODUCT, MIN, MAX, etc.)
        group: ProcessGroup (defaults to the global default group)
        async_op: whether to run asynchronously

    Returns:
        A Work object for synchronization if async_op=True, otherwise None.
    """
    # Resolve the ProcessGroup
    if group is None:
        group = _get_default_group()

    # Validate arguments
    if not tensor.is_cuda and get_backend(group) == "nccl":
        raise ValueError("NCCL backend requires CUDA tensors")

    # Run AllReduce
    opts = AllreduceOptions()
    opts.reduceOp = op
    work = group.allreduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()
        return None


# Example: asynchronous AllReduce
async_work = dist.all_reduce(tensor, async_op=True)
# ... do other computation ...
async_work.wait()  # wait for completion
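Because the reduction is a plain SUM, averaging across ranks is the caller's job. A small sketch, assuming a process group has already been initialized:

# Sketch: averaging a scalar metric across ranks (process group already initialized)
import torch
import torch.distributed as dist

def global_mean(value: float, device: torch.device) -> float:
    t = torch.tensor([value], device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)   # in-place sum over all ranks
    t /= dist.get_world_size()                 # divide to get the mean
    return t.item()

# e.g. average the per-rank loss for logging:
# mean_loss = global_mean(loss.item(), torch.device("cuda", local_rank))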
DistributedDataParallel (DDP) Source
DDP Initialization
# torch/nn/parallel/distributed.py (simplified, annotated excerpt)
class DistributedDataParallel(Module):
    """
    Distributed data parallelism.

    The model is replicated on every GPU; each GPU processes a different
    data batch, and gradients are synchronized via AllReduce.
    """

    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        find_unused_parameters=False,
        check_reduction=False,
        gradient_as_bucket_view=False,
        static_graph=False,
    ):
        super(DistributedDataParallel, self).__init__()

        # 1. Keep a reference to the wrapped module
        self.module = module

        # 2. Resolve the ProcessGroup
        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        # 3. Set up devices
        if device_ids is None:
            device_ids = [torch.cuda.current_device()]
        self.device_ids = device_ids
        self.output_device = output_device or device_ids[0]

        # 4. Synchronize model state (broadcast from rank 0)
        self._sync_params_and_buffers(authoritative_rank=0)

        # 5. Build the Reducer
        self._build_reducer(
            bucket_cap_mb=bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # 6. Register gradient hooks
        self._register_grad_hooks()

    def _sync_params_and_buffers(self, authoritative_rank=0):
        """Synchronize model parameters across all ranks."""
        # Collect all parameters and buffers
        module_states = []
        for param in self.module.parameters():
            module_states.append(param.detach())
        for buffer in self.module.buffers():
            module_states.append(buffer.detach())

        # Broadcast from authoritative_rank to every other rank
        for state in module_states:
            dist.broadcast(state, src=authoritative_rank,
                           group=self.process_group)

    def _build_reducer(self, bucket_cap_mb, find_unused_parameters,
                       gradient_as_bucket_view):
        """Build the Reducer that performs gradient synchronization."""
        # Collect the parameters that need synchronization
        parameters = list(self.module.parameters())
        params_for_reduction = [p for p in parameters if p.requires_grad]

        # Build gradient buckets: parameters are grouped by size so that
        # fewer, larger collectives are issued.
        self.reducer = dist.Reducer(
            params_for_reduction,
            list(range(len(params_for_reduction))),  # parameter indices
            self.process_group,
            [[]],                                    # expect-sparse-gradient flags
            bucket_cap_mb * 1024 * 1024,             # bucket size (bytes)
            find_unused_parameters,
            gradient_as_bucket_view,
            self.module.parameters,                  # parameter names (simplified here)
        )

    def _register_grad_hooks(self):
        """Register a gradient hook on every trainable parameter."""
        for param in self.module.parameters():
            if param.requires_grad:
                # Register the backward hook
                param.register_hook(self._make_param_hook(param))

    def _make_param_hook(self, param):
        """Create the per-parameter hook."""
        def hook(grad):
            # Tell the Reducer that this parameter's gradient is ready
            self.reducer.mark_variable_ready(param)
            return grad
        return hook
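Putting these pieces together, a typical DDP training script looks roughly like the sketch below (the model, dataset, and hyperparameters are placeholders). The key points are one process per GPU, a DistributedSampler to shard the data, and set_epoch for consistent shuffling.

# Sketch: minimal DDP training loop (launched with torchrun; names are placeholders)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).cuda(local_rank)           # placeholder model
model = DDP(model, device_ids=[local_rank], bucket_cap_mb=25)

dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
sampler = DistributedSampler(dataset)                        # shards data across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(3):
    sampler.set_epoch(epoch)          # re-shuffle consistently across ranks
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)
        loss.backward()               # gradient hooks fire; buckets are all-reduced
        optimizer.step()

dist.destroy_process_group()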
DDP Forward Implementation
# torch/nn/parallel/distributed.py (continued)
class DistributedDataParallel(Module):
    def forward(self, *inputs, **kwargs):
        """
        DDP forward pass:
        1. Synchronize buffers (if configured)
        2. Prepare the Reducer
        3. Run the wrapped module's forward
        """
        # 1. If broadcast_buffers is enabled, synchronize buffers
        if self.broadcast_buffers:
            self._sync_buffers()

        # 2. Prepare the Reducer (reset its state)
        if self.require_backward_grad_sync:
            self.reducer.prepare_for_backward([])

        # 3. Move inputs onto the target device
        if self.device_ids:
            inputs = self._move_inputs_to_device(inputs, self.device_ids[0])
            kwargs = self._move_kwargs_to_device(kwargs, self.device_ids[0])

        # 4. Call the original module's forward
        output = self.module(*inputs, **kwargs)

        # 5. On the first iteration, record the output tensors of the autograd graph
        if self._is_first_iteration():
            self.reducer.prepare_for_backward(
                list(_find_tensors(output))
            )
        return output

    def _sync_buffers(self):
        """Synchronize module buffers (e.g. BatchNorm running_mean)."""
        buffers = [b.data for b in self.module.buffers()]
        if buffers:
            # Broadcast from rank 0
            dist.broadcast_coalesced(buffers, group=self.process_group)
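One practical consequence of the require_backward_grad_sync flag above is gradient accumulation: DDP exposes it through the no_sync() context manager, which skips the bucket AllReduce on intermediate micro-batches. A short sketch, assuming model is already wrapped in DDP and loader/optimizer/loss_fn exist as in the earlier training loop:

# Sketch: gradient accumulation with DDP's no_sync() (model already wrapped in DDP)
accum_steps = 4
for step, (x, y) in enumerate(loader):
    x, y = x.cuda(local_rank), y.cuda(local_rank)
    is_last_micro_batch = (step + 1) % accum_steps == 0

    if is_last_micro_batch:
        # Normal backward: bucket AllReduce overlaps with gradient computation
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    else:
        # no_sync() disables gradient synchronization for this backward,
        # so gradients only accumulate locally.
        with model.no_sync():
            loss = loss_fn(model(x), y) / accum_steps
            loss.backward()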
Reducer Implementation (C++)
// torch/csrc/distributed/c10d/reducer.cpp (simplified)
class Reducer {
 public:
  // A gradient bucket
  struct Bucket {
    std::vector<at::Tensor> gradients;
    at::Tensor flat_buffer;   // flattened gradient buffer
    size_t size;              // bucket size (bytes)
    bool ready;               // whether communication can start
  };

  std::vector<Bucket> buckets_;

  // Work objects for the in-flight asynchronous AllReduces
  std::vector<c10::intrusive_ptr<c10d::Work>> pending_works_;

  void mark_variable_ready(size_t variable_index) {
    /*
      Called when the gradient of one parameter has finished computing.

      This is DDP's core optimization: overlap communication with computation.
      - Do not wait for *all* gradients before starting AllReduce.
      - As soon as a bucket is full, launch that bucket's AllReduce.
      - The collective overlaps with the backward of the remaining layers.
    */
    // 1. Find the bucket this parameter belongs to
    auto& bucket = buckets_[variable_to_bucket_[variable_index]];

    // 2. Decrement the pending count
    bucket.pending--;

    // 3. If every gradient in the bucket is ready, launch its AllReduce
    if (bucket.pending == 0) {
      start_bucket_allreduce(bucket);
    }
  }

  void start_bucket_allreduce(Bucket& bucket) {
    // 1. Copy gradients into the contiguous buffer
    //    (skipped when gradient_as_bucket_view is enabled)
    if (!gradient_as_bucket_view_) {
      copy_grads_to_bucket(bucket);
    }

    // 2. Asynchronous AllReduce
    auto work = process_group_->allreduce(
        {bucket.flat_buffer},
        AllreduceOptions());
    pending_works_.push_back(work);
  }

  void finalize_backward() {
    /*
      Called once backward has finished: wait for all AllReduces to complete.
    */
    // Wait for every pending AllReduce
    for (auto& work : pending_works_) {
      work->wait();
    }
    pending_works_.clear();

    // If gradient_as_bucket_view is off, copy results back into param.grad
    if (!gradient_as_bucket_view_) {
      copy_bucket_to_grads();
    }

    // Divide by world_size to obtain the averaged gradient
    // (NCCL AllReduce defaults to SUM, so the division is explicit)
    for (auto& param : params_) {
      param.grad().div_(world_size_);
    }
  }
};
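The overlap logic is small enough to mimic in plain Python. The sketch below is not DDP's code; it is a toy "reducer" built on post-accumulate-grad hooks (PyTorch ≥ 2.1) and asynchronous all_reduce that shows the pending-count / launch-on-full-bucket idea, assuming an initialized process group.

# Sketch: toy bucket reducer illustrating mark_variable_ready (not DDP's implementation)
import torch
import torch.distributed as dist

class ToyReducer:
    def __init__(self, buckets):
        # buckets: list of lists of parameters, in expected backward order
        self.buckets = buckets
        self.pending = [len(b) for b in buckets]
        self.param_to_bucket = {p: i for i, b in enumerate(buckets) for p in b}
        self.works = []
        for bucket in buckets:
            for p in bucket:
                # Fires once this parameter's gradient has been accumulated
                p.register_post_accumulate_grad_hook(self._mark_ready)

    def _mark_ready(self, param):
        i = self.param_to_bucket[param]
        self.pending[i] -= 1
        if self.pending[i] == 0:
            # Whole bucket is ready: launch one async AllReduce per parameter
            # (real DDP flattens the bucket into a single buffer first)
            for p in self.buckets[i]:
                self.works.append(dist.all_reduce(p.grad, async_op=True))

    def finalize(self):
        for w in self.works:
            w.wait()
        self.works.clear()
        self.pending = [len(b) for b in self.buckets]
        world = dist.get_world_size()
        for b in self.buckets:
            for p in b:
                p.grad.div_(world)   # SUM -> mean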
Gradient Bucket Mechanism
┌──────────────────────────────────────────────────────────────────────────┐
│                            DDP Bucket Mechanism                          │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Problem: many parameters + one AllReduce each = too many small calls    │
│                                                                          │
│  Solution: group parameters into buckets and communicate per bucket      │
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                          Parameter layout                          │  │
│  │                                                                    │  │
│  │  Layer 1          Layer 2          Layer 3          Layer 4        │  │
│  │  [weight, bias]   [weight, bias]   [weight, bias]   [weight]       │  │
│  │       ↓                ↓                ↓               ↓          │  │
│  │  [1MB, 0.1MB]     [10MB, 1MB]      [5MB, 0.5MB]      [8MB]         │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │             Bucket assignment (25MB cap per bucket)                │  │
│  │                                                                    │  │
│  │  Bucket 0 (13.5MB)           Bucket 1 (12.1MB)                     │  │
│  │  ┌──────────────────┐        ┌──────────────────┐                  │  │
│  │  │ Layer 4 weight   │        │ Layer 2 weight   │                  │  │
│  │  │ Layer 3 weight   │        │ Layer 2 bias     │                  │  │
│  │  │ Layer 3 bias     │        │ Layer 1 weight   │                  │  │
│  │  │                  │        │ Layer 1 bias     │                  │  │
│  │  └──────────────────┘        └──────────────────┘                  │  │
│  │                                                                    │  │
│  │  Note: parameters are ordered by backward order, so Layer 4 is     │  │
│  │  ready first. (Illustrative split; the real bucketer fills each    │  │
│  │  bucket greedily up to the cap.)                                   │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │              Overlapping communication with computation            │  │
│  │                                                                    │  │
│  │  Backward:   [Layer4][Layer3][   Layer2   ][Layer1]                │  │
│  │                         ↓                      ↓                   │  │
│  │  AllReduce:          [Bucket0 AllReduce]  [Bucket1 AllReduce]      │  │
│  │                                                                    │  │
│  │  Bucket 0 holds Layer 4 and Layer 3.                               │  │
│  │  As soon as Layer 3's backward finishes, Bucket 0's AllReduce      │  │
│  │  starts and runs in parallel with Layer 2 / Layer 1's backward.    │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
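The bucket assignment itself is essentially a greedy pass over the parameters in reverse registration order (an approximation of backward order). The sketch below is a simplified stand-in for DDP's internal bucket-assignment step, not the real implementation:

# Sketch: greedy bucket assignment by size cap (simplified, not DDP's internal code)
import torch

def assign_buckets(params, cap_mb=25):
    cap_bytes = cap_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    # Reverse registration order roughly matches the order gradients become ready
    for p in reversed(list(params)):
        nbytes = p.numel() * p.element_size()
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(p)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

# Example: a toy 4-layer model with a small cap to force multiple buckets
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.Linear(512, 2048),
    torch.nn.Linear(2048, 1024), torch.nn.Linear(1024, 2048),
)
for i, b in enumerate(assign_buckets(model.parameters(), cap_mb=4)):
    size_mb = sum(p.numel() * p.element_size() for p in b) / 2**20
    print(f"bucket {i}: {len(b)} tensors, {size_mb:.1f} MiB")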
DDP Communication Hooks
# torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py (simplified)
def allreduce_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Default AllReduce hook.

    Args:
        process_group: the communication group
        bucket: the gradient bucket (holds the flattened gradient tensor)

    Returns:
        A Future that resolves to the reduced gradient tensor.
    """
    # A comm hook replaces DDP's built-in averaging, so the hook itself
    # divides by world_size before the (SUM) AllReduce.
    world_size = process_group.size() if process_group else dist.get_world_size()
    tensor = bucket.buffer().div_(world_size)
    return (
        dist.all_reduce(tensor, group=process_group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    FP16 compression hook.

    Casts gradients to FP16 for transport, halving communication volume.
    """
    world_size = process_group.size() if process_group else dist.get_world_size()
    # Cast to FP16 and fold in the averaging
    compressed = bucket.buffer().half().div_(world_size)

    # AllReduce
    future = dist.all_reduce(compressed, group=process_group, async_op=True).get_future()

    # Cast back to FP32
    def decompress(fut):
        return fut.value()[0].float()

    return future.then(decompress)


def powerSGD_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
    state: PowerSGDState
) -> torch.futures.Future[torch.Tensor]:
    """
    PowerSGD compression hook (heavily simplified sketch).

    Compresses the gradient via a low-rank factorization; only the low-rank
    factors are communicated.
    """
    input_tensor = bucket.buffer()
    device = input_tensor.device

    # 1. Reshape into a matrix
    matrix = input_tensor.view(state.matrix_approximation_rank, -1)

    # 2. Error feedback: add back the compression error from the previous step
    if state.use_error_feedback:
        matrix.add_(state.error_dict[bucket.index()])

    # 3. Low-rank factorization
    p, q = torch.linalg.qr(matrix)
    # Only p and q are communicated, greatly reducing traffic

    # 4. AllReduce p and q
    dist.all_reduce(p, group=process_group)
    dist.all_reduce(q, group=process_group)

    # 5. Reconstruct the gradient
    reconstructed = torch.mm(p, q.t())

    # 6. Store the new compression error for the next step
    if state.use_error_feedback:
        state.error_dict[bucket.index()] = matrix - reconstructed

    fut = torch.futures.Future()
    fut.set_result(reconstructed.view_as(input_tensor))
    return fut


# Registering a communication hook
model = DDP(model)
model.register_comm_hook(process_group, fp16_compress_hook)
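Custom hooks follow the same contract: they receive the registered state (here a process group) and a GradBucket, and must return a Future that resolves to a tensor with the same shape and dtype as bucket.buffer(). Below is a sketch of a custom hook that transports gradients in bfloat16, mirroring the FP16 hook above (illustrative; recent PyTorch releases also ship a built-in bf16 variant):

# Sketch: a custom BF16 compression hook following the comm-hook contract
import torch
import torch.distributed as dist

def bf16_compress_hook(process_group, bucket):
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)

    # Cast to bfloat16 and pre-divide so the SUM allreduce yields a mean
    compressed = bucket.buffer().to(torch.bfloat16).div_(world_size)

    fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()

    def decompress(f):
        # Cast back to the bucket's original dtype before DDP copies it out
        return f.value()[0].to(bucket.buffer().dtype)

    return fut.then(decompress)

# Usage, with model already wrapped in DDP:
# model.register_comm_hook(process_group, bf16_compress_hook)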
FSDP Source Walkthrough
FSDP Fundamentals
┌──────────────────────────────────────────────────────────────────────────┐
│                               FSDP vs DDP                                │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  DDP (DistributedDataParallel):                                          │
│  ══════════════════════════════                                          │
│                                                                          │
│   Rank 0           Rank 1           Rank 2           Rank 3              │
│   ┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐     │
│   │ full params │  │ full params │  │ full params │  │ full params │     │
│   │ full grads  │  │ full grads  │  │ full grads  │  │ full grads  │     │
│   │ full optim  │  │ full optim  │  │ full optim  │  │ full optim  │     │
│   └─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘     │
│                                                                          │
│   Memory per rank = params + grads + optimizer state                     │
│                   ≈ 16 bytes/param (FP32 params + grads + Adam m, v)     │
│                                                                          │
│  FSDP (FullyShardedDataParallel):                                        │
│  ═════════════════════════════════                                       │
│                                                                          │
│   Rank 0           Rank 1           Rank 2           Rank 3              │
│   ┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐     │
│   │ param shard0│  │ param shard1│  │ param shard2│  │ param shard3│     │
│   │ grad shard0 │  │ grad shard1 │  │ grad shard2 │  │ grad shard3 │     │
│   │ optim shard0│  │ optim shard1│  │ optim shard2│  │ optim shard3│     │
│   └─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘     │
│                                                                          │
│   Memory per rank = (params + grads + optimizer state) / N               │
│                   ≈ 16 bytes/param / N                                   │
│                                                                          │
│   Forward:                                                               │
│     When a layer's parameters are needed, AllGather the full             │
│     parameters, use them, then discard the non-local shards.             │
│                                                                          │
│   Backward:                                                              │
│     AllGather the full parameters again, compute the full gradient,      │
│     then ReduceScatter so each rank keeps only its own gradient shard.   │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
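For reference, wrapping a model in FSDP at the user level looks roughly like the sketch below; the auto_wrap_policy controls which submodules become their own shard/unshard units. The model and thresholds here are placeholders.

# Sketch: wrapping a model with FSDP (placeholder model; launched with torchrun)
import os
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

model = FSDP(
    model.cuda(local_rank),
    sharding_strategy=ShardingStrategy.FULL_SHARD,     # shard params, grads, optim state
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=1_000_000
    ),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                   reduce_dtype=torch.bfloat16),
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # sees only local shards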
FSDP Core Implementation
# torch/distributed/fsdp/fully_sharded_data_parallel.py (heavily simplified sketch)
class FullyShardedDataParallel(nn.Module):
    """
    Fully sharded data parallelism.

    Shards model parameters, gradients, and optimizer state across GPUs.
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: ProcessGroup = None,
        sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
        cpu_offload: CPUOffload = None,
        auto_wrap_policy: Optional[Callable] = None,
        backward_prefetch: BackwardPrefetch = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[nn.Module]] = None,
        param_init_fn: Optional[Callable] = None,
        sync_module_states: bool = False,
    ):
        super().__init__()
        self.process_group = process_group or dist.distributed_c10d._get_default_group()
        self.rank = dist.get_rank(self.process_group)
        self.world_size = dist.get_world_size(self.process_group)

        # Sharding strategy
        self.sharding_strategy = sharding_strategy
        # CPU offload
        self.cpu_offload = cpu_offload
        # Mixed precision
        self.mixed_precision = mixed_precision

        # 1. Optionally auto-wrap submodules
        if auto_wrap_policy is not None:
            module = self._auto_wrap(module, auto_wrap_policy)
        self._fsdp_wrapped_module = module

        # 2. Collect and shard the parameters
        self._shard_parameters()

        # 3. Register forward/backward hooks
        self._register_hooks()

    def _shard_parameters(self):
        """Shard the model parameters."""
        self._handles = []
        for param_name, param in self._fsdp_wrapped_module.named_parameters():
            # 1. Create a FlatParameter (flattened parameter)
            flat_param = FlatParameter(param)

            # 2. Compute the per-rank shard size
            param_numel = flat_param.numel()
            chunk_size = (param_numel + self.world_size - 1) // self.world_size

            # 3. Keep only this rank's shard
            start_idx = self.rank * chunk_size
            end_idx = min(start_idx + chunk_size, param_numel)

            # 4. Materialize the local shard
            sharded_param = flat_param.data[start_idx:end_idx].clone()

            # 5. Replace the original parameter
            setattr(self._fsdp_wrapped_module, param_name,
                    nn.Parameter(sharded_param))

            # 6. Create a handle that manages this shard
            handle = FlatParamHandle(
                flat_param=flat_param,
                sharded_param=sharded_param,
                param_name=param_name,
            )
            self._handles.append(handle)

    def _unshard_params(self, handles=None):
        """Unshard: gather the full parameters."""
        if handles is None:
            handles = self._handles
        for handle in handles:
            # Allocate the full parameter buffer
            full_param = torch.empty(
                handle.flat_param.shape,
                dtype=handle.sharded_param.dtype,
                device=handle.sharded_param.device,
            )
            # AllGather the shards from every rank
            dist.all_gather_into_tensor(
                full_param,
                handle.sharded_param,
                group=self.process_group,
            )
            # Point the model parameter at the full tensor
            handle.flat_param.data = full_param

    def _reshard_params(self, handles=None):
        """Reshard: free the non-local parameter data."""
        if handles is None:
            handles = self._handles
        for handle in handles:
            # Keep only the local shard
            handle.flat_param.data = handle.sharded_param

    def forward(self, *args, **kwargs):
        """
        FSDP forward pass:
        1. AllGather the full parameters
        2. Run the forward
        3. Free the full parameters (reshard)
        """
        # 1. Unshard parameters
        self._unshard_params()

        # 2. Forward
        output = self._fsdp_wrapped_module(*args, **kwargs)

        # 3. Reshard (unless the strategy keeps params around for backward)
        if self.sharding_strategy == ShardingStrategy.FULL_SHARD:
            self._reshard_params()

        return output
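The shard / unshard round trip in _shard_parameters and _unshard_params boils down to two primitives from torch.distributed. A standalone sketch of that round trip on a toy tensor, assuming an initialized process group and a tensor size divisible by world_size for simplicity:

# Sketch: manual shard -> all_gather round trip (toy; assumes numel % world_size == 0)
import torch
import torch.distributed as dist

def shard(full: torch.Tensor) -> torch.Tensor:
    world_size, rank = dist.get_world_size(), dist.get_rank()
    chunk = full.numel() // world_size
    # Each rank keeps a contiguous 1-D slice of the flattened parameter
    return full.reshape(-1)[rank * chunk:(rank + 1) * chunk].clone()

def unshard(local: torch.Tensor, full_shape) -> torch.Tensor:
    world_size = dist.get_world_size()
    full = torch.empty(local.numel() * world_size,
                       dtype=local.dtype, device=local.device)
    # Gather every rank's shard into one contiguous buffer
    dist.all_gather_into_tensor(full, local)
    return full.reshape(full_shape)

# full_param = torch.randn(1024, 1024, device="cuda")  # identical on all ranks
# local = shard(full_param)                            # keep 1/world_size of it
# restored = unshard(local, full_param.shape)          # AllGather before use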
FSDP Backward
# torch/distributed/fsdp/fully_sharded_data_parallel.py (continued, simplified)
class FullyShardedDataParallel(nn.Module):
    def _register_hooks(self):
        """Register forward/backward hooks."""

        def _pre_forward_hook(module, args):
            """Before forward: unshard the parameters."""
            if module._is_root:
                self._unshard_params()

        def _post_forward_hook(module, args, output):
            """After forward: reshard the parameters."""
            if self.sharding_strategy == ShardingStrategy.FULL_SHARD:
                self._reshard_params()

        def _pre_backward_hook(module, grad_output):
            """Before backward: unshard again (params were resharded after forward)."""
            self._unshard_params()
            # Prefetch the next layer's parameters (optimization)
            if self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                self._prefetch_next_layer()

        def _post_backward_hook(param, grad):
            """
            After backward: ReduceScatter the gradient.

            Not AllReduce: each rank only needs its own shard of the gradient.
            """
            # 1. ReduceScatter: each rank receives a different gradient shard
            grad_sharded = torch.empty(
                self._chunk_size,
                dtype=grad.dtype,
                device=grad.device,
            )
            dist.reduce_scatter_tensor(
                grad_sharded,
                grad,
                group=self.process_group,
            )
            # 2. Store the local gradient shard
            param.grad = grad_sharded
            # 3. Reshard the parameters
            self._reshard_params()

        # Register the hooks
        self.register_forward_pre_hook(_pre_forward_hook)
        self.register_forward_hook(_post_forward_hook)
        # The backward hooks are registered through autograd hooks


class FlatParamHandle:
    """
    Handle that manages one sharded flat parameter.
    """

    def __init__(self, flat_param, sharded_param, param_name):
        self.flat_param = flat_param        # the original, full parameter
        self.sharded_param = sharded_param  # this rank's shard
        self.param_name = param_name

        # Gradient accumulation
        self.grad_accumulator = None

        # Communication handles (for async ops)
        self._all_gather_work = None
        self._reduce_scatter_work = None

    def all_gather_params(self, process_group, async_op=True):
        """Asynchronously AllGather the parameters."""
        full_param = torch.empty(
            self.flat_param.shape,
            dtype=self.sharded_param.dtype,
            device=self.sharded_param.device,
        )
        work = dist.all_gather_into_tensor(
            full_param,
            self.sharded_param,
            group=process_group,
            async_op=async_op,
        )
        if async_op:
            self._all_gather_work = work
            self._full_param_buffer = full_param
        else:
            self.flat_param.data = full_param

    def wait_all_gather(self):
        """Wait for the AllGather to finish."""
        if self._all_gather_work is not None:
            self._all_gather_work.wait()
            self.flat_param.data = self._full_param_buffer
            self._all_gather_work = None

    def reduce_scatter_grad(self, process_group, async_op=True):
        """ReduceScatter the gradient."""
        grad_sharded = torch.empty(
            self.sharded_param.shape,
            dtype=self.flat_param.grad.dtype,
            device=self.flat_param.grad.device,
        )
        work = dist.reduce_scatter_tensor(
            grad_sharded,
            self.flat_param.grad,
            group=process_group,
            async_op=async_op,
        )
        if async_op:
            self._reduce_scatter_work = work
            self._grad_sharded_buffer = grad_sharded
        else:
            self.sharded_param.grad = grad_sharded

    def wait_reduce_scatter(self):
        """Wait for the ReduceScatter to finish."""
        if self._reduce_scatter_work is not None:
            self._reduce_scatter_work.wait()
            self.sharded_param.grad = self._grad_sharded_buffer
            self._reduce_scatter_work = None
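Symmetrically to the AllGather sketch above, the gradient path reduces the full gradient and scatters the shards in a single collective. A toy sketch, again assuming an initialized group and a divisible tensor size; as in DDP, the caller divides by world_size to turn the SUM into a mean:

# Sketch: sharding a full gradient with reduce_scatter_tensor (toy example)
import torch
import torch.distributed as dist

def reduce_scatter_grad(full_grad: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    flat = full_grad.reshape(-1)
    assert flat.numel() % world_size == 0
    local = torch.empty(flat.numel() // world_size,
                        dtype=flat.dtype, device=flat.device)
    # Every rank contributes the full gradient; each receives the SUM of
    # one distinct 1/world_size slice.
    dist.reduce_scatter_tensor(local, flat)
    return local.div_(world_size)   # SUM -> mean, matching data-parallel semantics

# grad_shard = reduce_scatter_grad(full_param.grad)
# the optimizer step then only touches grad_shard and the local parameter shard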
Source-Level Debugging Tips
Enabling Debug Logging
import os
import torch.distributed as dist

# Enable PyTorch distributed debug logging (set before init_process_group)
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
# ...or set it programmatically
dist.set_debug_level(dist.DebugLevel.DETAIL)

# NCCL debugging
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"

# Inspect DDP internals (private attributes; exact names vary across PyTorch versions)
model = DDP(model)
print(model._module_copies)
print(model.reducer.buckets)
Common Debugging Helpers
# Check the communication state
def check_dist_state():
    print(f"Rank: {dist.get_rank()}")
    print(f"World Size: {dist.get_world_size()}")
    print(f"Backend: {dist.get_backend()}")
    print(f"Is Initialized: {dist.is_initialized()}")


# Check DDP state
def check_ddp_state(model):
    print(f"Device IDs: {model.device_ids}")
    print(f"Output Device: {model.output_device}")
    print(f"Bucket Cap: {model.bucket_bytes_cap}")
    print(f"Find Unused Params: {model.find_unused_parameters}")


# Check FSDP state
def check_fsdp_state(model):
    print(f"Sharding Strategy: {model.sharding_strategy}")
    print(f"CPU Offload: {model.cpu_offload}")
    print(f"Mixed Precision: {model.mixed_precision}")


# Synchronized checkpoint logging
def sync_check(msg):
    rank = dist.get_rank()
    print(f"[Rank {rank}] {msg}")
    dist.barrier()
Summary
PyTorch Distributed Core Components
┌──────────────────────────────────────────────────────────────────────────┐
│                    PyTorch Distributed Core Components                   │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  torch.distributed                                                       │
│  ─────────────────                                                       │
│  • init_process_group: initialize the distributed environment            │
│  • ProcessGroup: abstraction over a communication group                  │
│  • Collective primitives: all_reduce, broadcast, all_gather, ...         │
│                                                                          │
│  DistributedDataParallel (DDP)                                           │
│  ─────────────────────────────                                           │
│  • Every rank holds a full model replica                                 │
│  • The Reducer manages gradient synchronization                          │
│  • The bucket mechanism cuts the number of collectives                   │
│  • Communication hooks enable gradient compression                       │
│                                                                          │
│  FullyShardedDataParallel (FSDP)                                         │
│  ────────────────────────────────                                        │
│  • Parameters, gradients, and optimizer state are fully sharded          │
│  • AllGather on forward, ReduceScatter on backward                       │
│  • Supports CPU offload                                                  │
│  • Roughly N× less model-state memory per rank than DDP                  │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
Frequently Asked Interview Questions
- What is the core difference between DDP and FSDP?
- How does DDP's bucket mechanism overlap communication with computation?
- Why does FSDP use ReduceScatter instead of AllReduce?
- How does ProcessGroupNCCL invoke NCCL?
- How do you implement gradient compression with a custom DDP communication hook?