02-AllReduce 算法实现
概述
AllReduce 是分布式训练中最核心的集合通信操作,用于在所有参与者之间同步梯度。本文深入分析各种 AllReduce 算法的原理、实现和性能特点。
AllReduce 基础
操作定义
┌─────────────────────────────────────────────────────────────────────────┐
│ AllReduce 操作 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 输入:每个节点有一份数据 x_i │
│ 输出:每个节点得到 reduce(x_0, x_1, ..., x_{n-1}) │
│ │
│ 示例 (4个节点,Sum 操作): │
│ │
│ Before: After: │
│ Node 0: [1, 2, 3, 4] Node 0: [10, 20, 30, 40] │
│ Node 1: [2, 4, 6, 8] Node 1: [10, 20, 30, 40] │
│ Node 2: [3, 6, 9, 12] Node 2: [10, 20, 30, 40] │
│ Node 3: [4, 8, 12, 16] Node 3: [10, 20, 30, 40] │
│ │
│ 支持的归约操作: │
│ ├── Sum: 求和 │
│ ├── Prod: 乘积 │
│ ├── Max: 最大值 │
│ ├── Min: 最小值 │
│ └── Avg: 平均值 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
复杂度分析
┌─────────────────────────────────────────────────────────────────────────┐
│ AllReduce 算法复杂度对比 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 设: N = 节点数, M = 数据大小, α = 启动延迟, β = 传输时间/字节 │
│ │
│ 算法 │ 延迟 │ 带宽 │ 总时间 │
│ ───────────────┼─────────────────┼───────────────┼─────────────────── │
│ Naive │ O(N) │ O(N) │ (N-1)(α + Mβ) │
│ Ring │ O(N) │ O(1) │ 2(N-1)α + 2M(N-1)/Nβ│
│ Tree │ O(log N) │ O(log N) │ 2log(N)α + 2Mβ │
│ Recursive HD │ O(log N) │ O(log N) │ log(N)α + 2Mβ │
│ Bucket │ O(N) │ O(1) │ 2(N-1)α + Mβ │
│ │
│ 选择指南: │
│ ├── 小消息 + 高延迟网络 → Tree / Recursive Halving-Doubling │
│ ├── 大消息 + 高带宽网络 → Ring │
│ └── 混合场景 → 分层算法 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Ring AllReduce
算法原理
┌─────────────────────────────────────────────────────────────────────────┐
│ Ring AllReduce 算法 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 阶段 1: Reduce-Scatter (N-1 轮) │
│ │
│ Step 0: Step 1: │
│ ┌───────────────────────┐ ┌───────────────────────┐ │
│ │ 0:[A0,B0,C0,D0] │ │ 0:[A0,B0,C0,D0+D3] │ │
│ │ ↓ D0 │ │ ↓ C0+C3 │ │
│ │ 1:[A1,B1,C1,D1] │ │ 1:[A1,B1,C1,D1+D0] │ │
│ │ ↓ A1 │ │ ↓ D1+D0 │ │
│ │ 2:[A2,B2,C2,D2] │ │ 2:[A2,B2,C2+C1,D2] │ │
│ │ ↓ B2 │ │ ↓ A2+A1 │ │
│ │ 3:[A3,B3,C3,D3] │ │ 3:[A3,B3+B2,C3,D3] │ │
│ │ ↓ C3 │ │ ↓ B3+B2 │ │
│ └───────────────────────┘ └───────────────────────┘ │
│ │
│ Step 2: After Reduce-Scatter: │
│ ┌───────────────────────┐ ┌───────────────────────┐ │
│ │ 0:[A*,B0,C0,D0] │ │ 0:[A*, -, -, -] │ │
│ │ 1:[-,B*,C1,D1] │ │ 1:[ -,B*, -, -] │ │
│ │ 2:[-,-,C*,D2] │ │ 2:[ -, -,C*, -] │ │
│ │ 3:[-,-,-,D*] │ │ 3:[ -, -, -,D*] │ │
│ └───────────────────────┘ └───────────────────────┘ │
│ │
│ * 表示完全归约的结果 │
│ │
│ 阶段 2: AllGather (N-1 轮) │
│ │
│ Step 0: Final Result: │
│ ┌───────────────────────┐ ┌───────────────────────┐ │
│ │ 0:[A*,B*,-, -] │ │ 0:[A*,B*,C*,D*] │ │
│ │ 1:[-,B*,C*,-] │ │ 1:[A*,B*,C*,D*] │ │
│ │ 2:[-,-,C*,D*] │ │ 2:[A*,B*,C*,D*] │ │
│ │ 3:[A*,-,-,D*] │ │ 3:[A*,B*,C*,D*] │ │
│ └───────────────────────┘ └───────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
完整实现
#include <mpi.h>
#include <cuda_runtime.h>
#include <vector>
// Ring AllReduce 实现
template<typename T>
class RingAllReduce {
public:
RingAllReduce(int rank, int worldSize, size_t count)
: rank_(rank), worldSize_(worldSize), count_(count) {
// 计算每个块的大小
chunkSize_ = (count + worldSize - 1) / worldSize;
// 分配发送/接收缓冲区
cudaMalloc(&sendBuff_, chunkSize_ * sizeof(T));
cudaMalloc(&recvBuff_, chunkSize_ * sizeof(T));
// 计算前驱和后继
prevRank_ = (rank - 1 + worldSize) % worldSize;
nextRank_ = (rank + 1) % worldSize;
}
~RingAllReduce() {
cudaFree(sendBuff_);
cudaFree(recvBuff_);
}
void execute(T* data, cudaStream_t stream) {
// Phase 1: Reduce-Scatter
reduceScatter(data, stream);
// Phase 2: AllGather
allGather(data, stream);
}
private:
void reduceScatter(T* data, cudaStream_t stream) {
for (int step = 0; step < worldSize_ - 1; step++) {
// 计算发送和接收的块索引
int sendChunk = (rank_ - step + worldSize_) % worldSize_;
int recvChunk = (rank_ - step - 1 + worldSize_) % worldSize_;
size_t sendOffset = sendChunk * chunkSize_;
size_t recvOffset = recvChunk * chunkSize_;
size_t thisChunkSize = getChunkSize(sendChunk);
// 复制数据到发送缓冲区
cudaMemcpyAsync(sendBuff_, data + sendOffset,
thisChunkSize * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
// 异步发送
MPI_Request sendReq, recvReq;
MPI_Isend(sendBuff_, thisChunkSize * sizeof(T), MPI_BYTE,
nextRank_, 0, MPI_COMM_WORLD, &sendReq);
// 异步接收
MPI_Irecv(recvBuff_, getChunkSize(recvChunk) * sizeof(T), MPI_BYTE,
prevRank_, 0, MPI_COMM_WORLD, &recvReq);
// 等待接收完成
MPI_Wait(&recvReq, MPI_STATUS_IGNORE);
// 归约:data[recvChunk] += recvBuff
reduceKernel<<<gridSize_, blockSize_, 0, stream>>>(
data + recvOffset, recvBuff_, getChunkSize(recvChunk));
// 等待发送完成
MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
cudaStreamSynchronize(stream);
}
}
void allGather(T* data, cudaStream_t stream) {
for (int step = 0; step < worldSize_ - 1; step++) {
// 计算发送和接收的块索引
int sendChunk = (rank_ - step + 1 + worldSize_) % worldSize_;
int recvChunk = (rank_ - step + worldSize_) % worldSize_;
size_t sendOffset = sendChunk * chunkSize_;
size_t recvOffset = recvChunk * chunkSize_;
// 复制数据到发送缓冲区
cudaMemcpyAsync(sendBuff_, data + sendOffset,
getChunkSize(sendChunk) * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
MPI_Request sendReq, recvReq;
MPI_Isend(sendBuff_, getChunkSize(sendChunk) * sizeof(T), MPI_BYTE,
nextRank_, 1, MPI_COMM_WORLD, &sendReq);
MPI_Irecv(recvBuff_, getChunkSize(recvChunk) * sizeof(T), MPI_BYTE,
prevRank_, 1, MPI_COMM_WORLD, &recvReq);
MPI_Wait(&recvReq, MPI_STATUS_IGNORE);
// 直接复制(不需要归约)
cudaMemcpyAsync(data + recvOffset, recvBuff_,
getChunkSize(recvChunk) * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
cudaStreamSynchronize(stream);
}
}
size_t getChunkSize(int chunk) {
if (chunk == worldSize_ - 1) {
return count_ - chunk * chunkSize_;
}
return chunkSize_;
}
private:
int rank_, worldSize_;
int prevRank_, nextRank_;
size_t count_, chunkSize_;
T *sendBuff_, *recvBuff_;
int gridSize_ = 256, blockSize_ = 256;
};
// 归约 Kernel
template<typename T>
__global__ void reduceKernel(T* dst, const T* src, size_t count) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < count) {
dst[idx] += src[idx];
}
}
Tree AllReduce
算法原理
┌─────────────────────────────────────────────────────────────────────────┐
│ Tree AllReduce 算法 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 二叉树结构 (8 个节点): │
│ │
│ 0 (root) │
│ / \ │
│ 1 2 │
│ / \ / \ │
│ 3 4 5 6 │
│ / │
│ 7 │
│ │
│ 阶段 1: Reduce (叶子 → 根) │
│ │
│ Step 1: 7→3, 4→1, 5→2, 6→2 │
│ Step 2: 3→1, 2→0 │
│ Step 3: 1→0 │
│ │
│ 每个非叶子节点:result = local + sum(children) │
│ │
│ 阶段 2: Broadcast (根 → 叶子) │
│ │
│ Step 1: 0→1, 0→2 │
│ Step 2: 1→3, 1→4, 2→5, 2→6 │
│ Step 3: 3→7 │
│ │
│ 优点:延迟 O(log N) │
│ 缺点:带宽利用率较低,根节点成为瓶颈 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
实现
// Tree AllReduce 实现
template<typename T>
class TreeAllReduce {
public:
TreeAllReduce(int rank, int worldSize, size_t count)
: rank_(rank), worldSize_(worldSize), count_(count) {
buildTree();
cudaMalloc(&buffer_, count * sizeof(T));
}
void execute(T* data, cudaStream_t stream) {
// Phase 1: Reduce (up)
reduce(data, stream);
// Phase 2: Broadcast (down)
broadcast(data, stream);
}
private:
void buildTree() {
// 构建二叉树
// 对于 rank r:
// - 父节点: (r - 1) / 2 (if r > 0)
// - 左子节点: 2r + 1
// - 右子节点: 2r + 2
if (rank_ == 0) {
parent_ = -1;
} else {
parent_ = (rank_ - 1) / 2;
}
int leftChild = 2 * rank_ + 1;
int rightChild = 2 * rank_ + 2;
if (leftChild < worldSize_) {
children_.push_back(leftChild);
}
if (rightChild < worldSize_) {
children_.push_back(rightChild);
}
// 计算树的深度
depth_ = 0;
int nodes = 1;
while (nodes < worldSize_) {
depth_++;
nodes = nodes * 2 + 1;
}
}
void reduce(T* data, cudaStream_t stream) {
// 从叶子节点开始,自底向上
// 叶子节点没有子节点,直接发送给父节点
// 接收所有子节点的数据并归约
for (int child : children_) {
MPI_Recv(buffer_, count_ * sizeof(T), MPI_BYTE,
child, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
// 归约
reduceKernel<<<256, 256, 0, stream>>>(data, buffer_, count_);
cudaStreamSynchronize(stream);
}
// 发送给父节点(非根节点)
if (parent_ >= 0) {
cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);
MPI_Send(buffer_, count_ * sizeof(T), MPI_BYTE,
parent_, 0, MPI_COMM_WORLD);
}
}
void broadcast(T* data, cudaStream_t stream) {
// 从根节点开始,自顶向下
// 接收父节点的数据(非根节点)
if (parent_ >= 0) {
MPI_Recv(buffer_, count_ * sizeof(T), MPI_BYTE,
parent_, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
cudaMemcpy(data, buffer_, count_ * sizeof(T), cudaMemcpyHostToDevice);
}
// 发送给所有子节点
cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);
for (int child : children_) {
MPI_Send(buffer_, count_ * sizeof(T), MPI_BYTE,
child, 1, MPI_COMM_WORLD);
}
}
private:
int rank_, worldSize_;
size_t count_;
int parent_;
std::vector<int> children_;
int depth_;
T* buffer_;
};
Recursive Halving-Doubling
算法原理
┌─────────────────────────────────────────────────────────────────────────┐
│ Recursive Halving-Doubling AllReduce │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 适用于节点数为 2^k 的情况 │
│ │
│ 阶段 1: Reduce-Scatter (Recursive Halving) │
│ │
│ Step 1: 距离 = 4 │
│ 0 ←→ 4: 交换并归约半边数据 │
│ 1 ←→ 5 │
│ 2 ←→ 6 │
│ 3 ←→ 7 │
│ │
│ 0: [A0+A4, B0+B4, -, -] 4: [-, -, C0+C4, D0+D4] │
│ 1: [A1+A5, B1+B5, -, -] 5: [-, -, C1+C5, D1+D5] │
│ 2: [A2+A6, B2+B6, -, -] 6: [-, -, C2+C6, D2+D6] │
│ 3: [A3+A7, B3+B7, -, -] 7: [-, -, C3+C7, D3+D7] │
│ │
│ Step 2: 距离 = 2 │
│ 0 ←→ 2: 0 得到 [A*, -, -, -] │
│ 1 ←→ 3: 1 得到 [-, B*, -, -] │
│ 4 ←→ 6: 4 得到 [-, -, C*, -] │
│ 5 ←→ 7: 5 得到 [-, -, -, D*] │
│ │
│ 阶段 2: AllGather (Recursive Doubling) │
│ │
│ Step 1: 距离 = 1 │
│ 0 ←→ 1: 交换结果 │
│ ... │
│ │
│ Step 2: 距离 = 2 │
│ 0 ←→ 2, 1 ←→ 3, ... │
│ │
│ Step 3: 距离 = 4 │
│ 0 ←→ 4, 1 ←→ 5, ... │
│ │
│ 优点:通信轮次 O(log N),带宽效率高 │
│ 缺点:实现复杂,节点数需要是 2^k │
│ │
└─────────────────────────────────────────────────────────────────────────┘
实现
template<typename T>
class RecursiveHalvingDoublingAllReduce {
public:
RecursiveHalvingDoublingAllReduce(int rank, int worldSize, size_t count)
: rank_(rank), worldSize_(worldSize), count_(count) {
// 检查是否是 2 的幂
assert((worldSize & (worldSize - 1)) == 0);
cudaMalloc(&sendBuff_, count * sizeof(T));
cudaMalloc(&recvBuff_, count * sizeof(T));
}
void execute(T* data, cudaStream_t stream) {
// Phase 1: Reduce-Scatter with Recursive Halving
recursiveHalving(data, stream);
// Phase 2: AllGather with Recursive Doubling
recursiveDoubling(data, stream);
}
private:
void recursiveHalving(T* data, cudaStream_t stream) {
size_t currentCount = count_;
// log2(worldSize) 轮
for (int mask = worldSize_ >> 1; mask > 0; mask >>= 1) {
int peer = rank_ ^ mask;
// 确定发送/接收哪半边
size_t halfCount = currentCount / 2;
size_t sendOffset, recvOffset;
if (rank_ < peer) {
// 发送后半部分,接收前半部分的归约结果
sendOffset = halfCount;
recvOffset = 0;
} else {
// 发送前半部分,接收后半部分的归约结果
sendOffset = 0;
recvOffset = halfCount;
}
// 发送
cudaMemcpyAsync(sendBuff_, data + sendOffset,
halfCount * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
cudaStreamSynchronize(stream);
MPI_Request sendReq, recvReq;
MPI_Isend(sendBuff_, halfCount * sizeof(T), MPI_BYTE,
peer, 0, MPI_COMM_WORLD, &sendReq);
MPI_Irecv(recvBuff_, halfCount * sizeof(T), MPI_BYTE,
peer, 0, MPI_COMM_WORLD, &recvReq);
MPI_Wait(&recvReq, MPI_STATUS_IGNORE);
// 归约到接收区域
reduceKernel<<<256, 256, 0, stream>>>(
data + recvOffset, recvBuff_, halfCount);
cudaStreamSynchronize(stream);
MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
// 更新有效数据区域
if (rank_ < peer) {
// 有效数据在前半部分
currentCount = halfCount;
} else {
// 有效数据在后半部分,移动到开头
cudaMemcpyAsync(data, data + halfCount,
halfCount * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
currentCount = halfCount;
}
}
}
void recursiveDoubling(T* data, cudaStream_t stream) {
size_t currentCount = count_ / worldSize_;
// log2(worldSize) 轮
for (int mask = 1; mask < worldSize_; mask <<= 1) {
int peer = rank_ ^ mask;
// 计算发送/接收位置
size_t sendOffset = 0;
size_t recvOffset = currentCount;
if (rank_ > peer) {
std::swap(sendOffset, recvOffset);
}
// 交换数据
cudaMemcpyAsync(sendBuff_, data + sendOffset,
currentCount * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
cudaStreamSynchronize(stream);
MPI_Request sendReq, recvReq;
MPI_Isend(sendBuff_, currentCount * sizeof(T), MPI_BYTE,
peer, 1, MPI_COMM_WORLD, &sendReq);
MPI_Irecv(recvBuff_, currentCount * sizeof(T), MPI_BYTE,
peer, 1, MPI_COMM_WORLD, &recvReq);
MPI_Wait(&recvReq, MPI_STATUS_IGNORE);
cudaMemcpyAsync(data + recvOffset, recvBuff_,
currentCount * sizeof(T),
cudaMemcpyDeviceToDevice, stream);
cudaStreamSynchronize(stream);
MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
currentCount *= 2;
}
}
private:
int rank_, worldSize_;
size_t count_;
T *sendBuff_, *recvBuff_;
};
分层 AllReduce (Hierarchical)
算法原理
┌─────────────────────────────────────────────────────────────────────────┐
│ 分层 AllReduce 算法 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 适用于多节点多卡环境 │
│ │
│ Node 0 Node 1 Node 2 Node 3 │
│ ┌───────┐ ┌───────┐ ┌───────┐ ┌───────┐ │
│ │GPU0 1 │ │GPU0 1 │ │GPU0 1 │ │GPU0 1 │ │
│ │ 2 3 │ │ 2 3 │ │ 2 3 │ │ 2 3 │ │
│ └───────┘ └───────┘ └───────┘ └───────┘ │
│ │
│ 阶段 1: 节点内 AllReduce (使用 NVLink/PCIe, Ring) │
│ - 每个节点内 4 个 GPU 做 Ring AllReduce │
│ - 带宽: NVLink 600GB/s >> 网络 100GB/s │
│ │
│ 阶段 2: 节点间 AllReduce (使用网络, Ring/Tree) │
│ - 每个节点选择一个 GPU (通常 GPU0) 作为代表 │
│ - 4 个节点的代表 GPU 做 AllReduce │
│ │
│ 阶段 3: 节点内 Broadcast (使用 NVLink/PCIe) │
│ - 每个节点内,代表 GPU 广播结果给其他 GPU │
│ │
│ 优化变体: │
│ ├── Two-level Ring: 节点内 Ring + 节点间 Ring │
│ ├── NCCL Hierarchical: 自动检测拓扑选择最优策略 │
│ └── Bucket: 节点间使用 Reduce-Scatter + AllGather │
│ │
└─────────────────────────────────────────────────────────────────────────┘
实现
template<typename T>
class HierarchicalAllReduce {
public:
HierarchicalAllReduce(
int globalRank,
int worldSize,
int localRank,
int localSize,
size_t count
) : globalRank_(globalRank),
worldSize_(worldSize),
localRank_(localRank),
localSize_(localSize),
count_(count) {
nodeId_ = globalRank / localSize;
numNodes_ = worldSize / localSize;
// 创建节点内通信域
MPI_Comm_split(MPI_COMM_WORLD, nodeId_, localRank, &localComm_);
// 创建节点间通信域(仅 local rank 0 参与)
int interColor = (localRank == 0) ? 0 : MPI_UNDEFINED;
MPI_Comm_split(MPI_COMM_WORLD, interColor, nodeId_, &interComm_);
cudaMalloc(&buffer_, count * sizeof(T));
}
void execute(T* data, cudaStream_t stream) {
// Phase 1: 节点内 AllReduce
intraNodeAllReduce(data, stream);
// Phase 2: 节点间 AllReduce (仅 local rank 0)
if (localRank_ == 0) {
interNodeAllReduce(data, stream);
}
// Phase 3: 节点内 Broadcast
intraNodeBroadcast(data, stream);
}
private:
void intraNodeAllReduce(T* data, cudaStream_t stream) {
// 使用 NCCL 进行节点内 AllReduce
// 利用 NVLink 高带宽
ncclAllReduce(data, data, count_, getNcclDataType<T>(),
ncclSum, localNcclComm_, stream);
cudaStreamSynchronize(stream);
}
void interNodeAllReduce(T* data, cudaStream_t stream) {
// 节点间使用 MPI 或网络 NCCL
cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);
// 使用 Ring AllReduce
T* sendBuff = buffer_;
T* recvBuff = new T[count_];
int prevNode = (nodeId_ - 1 + numNodes_) % numNodes_;
int nextNode = (nodeId_ + 1) % numNodes_;
size_t chunkSize = count_ / numNodes_;
// Reduce-Scatter
for (int step = 0; step < numNodes_ - 1; step++) {
int sendChunk = (nodeId_ - step + numNodes_) % numNodes_;
int recvChunk = (nodeId_ - step - 1 + numNodes_) % numNodes_;
MPI_Request sendReq, recvReq;
MPI_Isend(sendBuff + sendChunk * chunkSize,
chunkSize * sizeof(T), MPI_BYTE,
nextNode * localSize_, 0, MPI_COMM_WORLD, &sendReq);
MPI_Irecv(recvBuff, chunkSize * sizeof(T), MPI_BYTE,
prevNode * localSize_, 0, MPI_COMM_WORLD, &recvReq);
MPI_Wait(&recvReq, MPI_STATUS_IGNORE);
// 归约
for (size_t i = 0; i < chunkSize; i++) {
sendBuff[recvChunk * chunkSize + i] += recvBuff[i];
}
MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
}
// AllGather (类似)
// ...
cudaMemcpy(data, buffer_, count_ * sizeof(T), cudaMemcpyHostToDevice);
delete[] recvBuff;
}
void intraNodeBroadcast(T* data, cudaStream_t stream) {
// 使用 NCCL Broadcast
ncclBroadcast(data, data, count_, getNcclDataType<T>(),
0, localNcclComm_, stream);
cudaStreamSynchronize(stream);
}
private:
int globalRank_, worldSize_;
int localRank_, localSize_;
int nodeId_, numNodes_;
size_t count_;
MPI_Comm localComm_, interComm_;
ncclComm_t localNcclComm_;
T* buffer_;
};
性能分析与选择
算法选择策略
def select_allreduce_algorithm(
num_nodes: int,
gpus_per_node: int,
message_size: int,
network_bandwidth: float, # GB/s
nvlink_bandwidth: float, # GB/s
network_latency: float, # us
):
"""
选择最优 AllReduce 算法
参数:
num_nodes: 节点数
gpus_per_node: 每节点 GPU 数
message_size: 消息大小 (bytes)
network_bandwidth: 网络带宽
nvlink_bandwidth: NVLink 带宽
network_latency: 网络延迟
"""
total_gpus = num_nodes * gpus_per_node
# 计算各算法的预估时间
alpha = network_latency * 1e-6 # 转换为秒
beta = 1 / (network_bandwidth * 1e9) # 秒/字节
# Ring AllReduce
ring_time = 2 * (total_gpus - 1) * alpha + \
2 * message_size * (total_gpus - 1) / total_gpus * beta
# Tree AllReduce
import math
tree_depth = math.ceil(math.log2(total_gpus))
tree_time = 2 * tree_depth * alpha + 2 * message_size * beta
# Hierarchical AllReduce
hier_intra = 2 * (gpus_per_node - 1) * (1e-6) + \
2 * message_size * (gpus_per_node - 1) / gpus_per_node / (nvlink_bandwidth * 1e9)
hier_inter = 2 * (num_nodes - 1) * alpha + \
2 * message_size * (num_nodes - 1) / num_nodes * beta
hier_time = hier_intra + hier_inter
# 选择最优
times = {
'ring': ring_time,
'tree': tree_time,
'hierarchical': hier_time
}
best = min(times, key=times.get)
# 特殊情况处理
if message_size < 1024 * 1024: # < 1MB
# 小消息优先考虑延迟
if tree_time < ring_time * 0.8:
return 'tree'
if num_nodes > 1 and gpus_per_node > 1:
# 多节点多卡优先考虑分层
if hier_time < ring_time * 1.2:
return 'hierarchical'
return best
总结
算法对比
| 算法 | 延迟 | 带宽效率 | 适用场景 |
|---|---|---|---|
| Ring | O(N) | 最优 | 大消息、高带宽 |
| Tree | O(log N) | 较低 | 小消息、低延迟 |
| Recursive HD | O(log N) | 高 | 2^k 节点 |
| Hierarchical | 分层 | 高 | 多节点多卡 |
实践建议
□ 大消息 (> 1MB) 使用 Ring
□ 小消息 (< 1MB) 使用 Tree
□ 多节点多卡使用 Hierarchical
□ 利用 NCCL 自动选择算法
□ 考虑通信与计算重叠