HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

02-AllReduce 算法实现

概述

AllReduce 是分布式训练中最核心的集合通信操作,用于在所有参与者之间同步梯度。本文深入分析各种 AllReduce 算法的原理、实现和性能特点。

AllReduce 基础

操作定义

┌─────────────────────────────────────────────────────────────────────────┐
│                      AllReduce 操作                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  输入:每个节点有一份数据 x_i                                             │
│  输出:每个节点得到 reduce(x_0, x_1, ..., x_{n-1})                       │
│                                                                         │
│  示例 (4个节点,Sum 操作):                                               │
│                                                                         │
│  Before:                     After:                                     │
│  Node 0: [1, 2, 3, 4]       Node 0: [10, 20, 30, 40]                   │
│  Node 1: [2, 4, 6, 8]       Node 1: [10, 20, 30, 40]                   │
│  Node 2: [3, 6, 9, 12]      Node 2: [10, 20, 30, 40]                   │
│  Node 3: [4, 8, 12, 16]     Node 3: [10, 20, 30, 40]                   │
│                                                                         │
│  支持的归约操作:                                                        │
│  ├── Sum: 求和                                                          │
│  ├── Prod: 乘积                                                         │
│  ├── Max: 最大值                                                        │
│  ├── Min: 最小值                                                        │
│  └── Avg: 平均值                                                        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

复杂度分析

┌─────────────────────────────────────────────────────────────────────────┐
│                    AllReduce 算法复杂度对比                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  设: N = 节点数, M = 数据大小, α = 启动延迟, β = 传输时间/字节           │
│                                                                         │
│  算法            │ 延迟            │ 带宽           │ 总时间              │
│  ───────────────┼─────────────────┼───────────────┼─────────────────── │
│  Naive          │ O(N)            │ O(N)          │ (N-1)(α + Mβ)      │
│  Ring           │ O(N)            │ O(1)          │ 2(N-1)α + 2M(N-1)/Nβ│
│  Tree           │ O(log N)        │ O(log N)      │ 2log(N)α + 2Mβ     │
│  Recursive HD   │ O(log N)        │ O(log N)      │ log(N)α + 2Mβ      │
│  Bucket         │ O(N)            │ O(1)          │ 2(N-1)α + Mβ       │
│                                                                         │
│  选择指南:                                                              │
│  ├── 小消息 + 高延迟网络 → Tree / Recursive Halving-Doubling           │
│  ├── 大消息 + 高带宽网络 → Ring                                         │
│  └── 混合场景 → 分层算法                                                │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Ring AllReduce

算法原理

┌─────────────────────────────────────────────────────────────────────────┐
│                    Ring AllReduce 算法                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  阶段 1: Reduce-Scatter (N-1 轮)                                        │
│                                                                         │
│  Step 0:                          Step 1:                              │
│  ┌───────────────────────┐       ┌───────────────────────┐            │
│  │ 0:[A0,B0,C0,D0]       │       │ 0:[A0,B0,C0,D0+D3]    │            │
│  │        ↓ D0           │       │        ↓ C0+C3        │            │
│  │ 1:[A1,B1,C1,D1]       │       │ 1:[A1,B1,C1,D1+D0]    │            │
│  │        ↓ A1           │       │        ↓ D1+D0        │            │
│  │ 2:[A2,B2,C2,D2]       │       │ 2:[A2,B2,C2+C1,D2]    │            │
│  │        ↓ B2           │       │        ↓ A2+A1        │            │
│  │ 3:[A3,B3,C3,D3]       │       │ 3:[A3,B3+B2,C3,D3]    │            │
│  │        ↓ C3           │       │        ↓ B3+B2        │            │
│  └───────────────────────┘       └───────────────────────┘            │
│                                                                         │
│  Step 2:                          After Reduce-Scatter:                │
│  ┌───────────────────────┐       ┌───────────────────────┐            │
│  │ 0:[A*,B0,C0,D0]       │       │ 0:[A*, -, -, -]       │            │
│  │ 1:[-,B*,C1,D1]        │       │ 1:[ -,B*, -, -]       │            │
│  │ 2:[-,-,C*,D2]         │       │ 2:[ -, -,C*, -]       │            │
│  │ 3:[-,-,-,D*]          │       │ 3:[ -, -, -,D*]       │            │
│  └───────────────────────┘       └───────────────────────┘            │
│                                                                         │
│  * 表示完全归约的结果                                                    │
│                                                                         │
│  阶段 2: AllGather (N-1 轮)                                             │
│                                                                         │
│  Step 0:                          Final Result:                        │
│  ┌───────────────────────┐       ┌───────────────────────┐            │
│  │ 0:[A*,B*,-, -]        │       │ 0:[A*,B*,C*,D*]       │            │
│  │ 1:[-,B*,C*,-]         │       │ 1:[A*,B*,C*,D*]       │            │
│  │ 2:[-,-,C*,D*]         │       │ 2:[A*,B*,C*,D*]       │            │
│  │ 3:[A*,-,-,D*]         │       │ 3:[A*,B*,C*,D*]       │            │
│  └───────────────────────┘       └───────────────────────┘            │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

完整实现

#include <mpi.h>
#include <cuda_runtime.h>
#include <vector>

// Ring AllReduce 实现
template<typename T>
class RingAllReduce {
public:
    RingAllReduce(int rank, int worldSize, size_t count)
        : rank_(rank), worldSize_(worldSize), count_(count) {
        // 计算每个块的大小
        chunkSize_ = (count + worldSize - 1) / worldSize;

        // 分配发送/接收缓冲区
        cudaMalloc(&sendBuff_, chunkSize_ * sizeof(T));
        cudaMalloc(&recvBuff_, chunkSize_ * sizeof(T));

        // 计算前驱和后继
        prevRank_ = (rank - 1 + worldSize) % worldSize;
        nextRank_ = (rank + 1) % worldSize;
    }

    ~RingAllReduce() {
        cudaFree(sendBuff_);
        cudaFree(recvBuff_);
    }

    void execute(T* data, cudaStream_t stream) {
        // Phase 1: Reduce-Scatter
        reduceScatter(data, stream);

        // Phase 2: AllGather
        allGather(data, stream);
    }

private:
    void reduceScatter(T* data, cudaStream_t stream) {
        for (int step = 0; step < worldSize_ - 1; step++) {
            // 计算发送和接收的块索引
            int sendChunk = (rank_ - step + worldSize_) % worldSize_;
            int recvChunk = (rank_ - step - 1 + worldSize_) % worldSize_;

            size_t sendOffset = sendChunk * chunkSize_;
            size_t recvOffset = recvChunk * chunkSize_;
            size_t thisChunkSize = getChunkSize(sendChunk);

            // 复制数据到发送缓冲区
            cudaMemcpyAsync(sendBuff_, data + sendOffset,
                           thisChunkSize * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);

            // 异步发送
            MPI_Request sendReq, recvReq;
            MPI_Isend(sendBuff_, thisChunkSize * sizeof(T), MPI_BYTE,
                     nextRank_, 0, MPI_COMM_WORLD, &sendReq);

            // 异步接收
            MPI_Irecv(recvBuff_, getChunkSize(recvChunk) * sizeof(T), MPI_BYTE,
                     prevRank_, 0, MPI_COMM_WORLD, &recvReq);

            // 等待接收完成
            MPI_Wait(&recvReq, MPI_STATUS_IGNORE);

            // 归约:data[recvChunk] += recvBuff
            reduceKernel<<<gridSize_, blockSize_, 0, stream>>>(
                data + recvOffset, recvBuff_, getChunkSize(recvChunk));

            // 等待发送完成
            MPI_Wait(&sendReq, MPI_STATUS_IGNORE);

            cudaStreamSynchronize(stream);
        }
    }

    void allGather(T* data, cudaStream_t stream) {
        for (int step = 0; step < worldSize_ - 1; step++) {
            // 计算发送和接收的块索引
            int sendChunk = (rank_ - step + 1 + worldSize_) % worldSize_;
            int recvChunk = (rank_ - step + worldSize_) % worldSize_;

            size_t sendOffset = sendChunk * chunkSize_;
            size_t recvOffset = recvChunk * chunkSize_;

            // 复制数据到发送缓冲区
            cudaMemcpyAsync(sendBuff_, data + sendOffset,
                           getChunkSize(sendChunk) * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);

            MPI_Request sendReq, recvReq;
            MPI_Isend(sendBuff_, getChunkSize(sendChunk) * sizeof(T), MPI_BYTE,
                     nextRank_, 1, MPI_COMM_WORLD, &sendReq);

            MPI_Irecv(recvBuff_, getChunkSize(recvChunk) * sizeof(T), MPI_BYTE,
                     prevRank_, 1, MPI_COMM_WORLD, &recvReq);

            MPI_Wait(&recvReq, MPI_STATUS_IGNORE);

            // 直接复制(不需要归约)
            cudaMemcpyAsync(data + recvOffset, recvBuff_,
                           getChunkSize(recvChunk) * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);

            MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
            cudaStreamSynchronize(stream);
        }
    }

    size_t getChunkSize(int chunk) {
        if (chunk == worldSize_ - 1) {
            return count_ - chunk * chunkSize_;
        }
        return chunkSize_;
    }

private:
    int rank_, worldSize_;
    int prevRank_, nextRank_;
    size_t count_, chunkSize_;
    T *sendBuff_, *recvBuff_;
    int gridSize_ = 256, blockSize_ = 256;
};

// 归约 Kernel
template<typename T>
__global__ void reduceKernel(T* dst, const T* src, size_t count) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < count) {
        dst[idx] += src[idx];
    }
}

Tree AllReduce

算法原理

┌─────────────────────────────────────────────────────────────────────────┐
│                    Tree AllReduce 算法                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  二叉树结构 (8 个节点):                                                  │
│                                                                         │
│                        0 (root)                                         │
│                       / \                                               │
│                      1   2                                              │
│                     / \ / \                                             │
│                    3  4 5  6                                            │
│                   /                                                     │
│                  7                                                      │
│                                                                         │
│  阶段 1: Reduce (叶子 → 根)                                             │
│                                                                         │
│  Step 1: 7→3, 4→1, 5→2, 6→2                                            │
│  Step 2: 3→1, 2→0                                                       │
│  Step 3: 1→0                                                            │
│                                                                         │
│  每个非叶子节点:result = local + sum(children)                          │
│                                                                         │
│  阶段 2: Broadcast (根 → 叶子)                                          │
│                                                                         │
│  Step 1: 0→1, 0→2                                                       │
│  Step 2: 1→3, 1→4, 2→5, 2→6                                            │
│  Step 3: 3→7                                                            │
│                                                                         │
│  优点:延迟 O(log N)                                                    │
│  缺点:带宽利用率较低,根节点成为瓶颈                                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

实现

// Tree AllReduce 实现
template<typename T>
class TreeAllReduce {
public:
    TreeAllReduce(int rank, int worldSize, size_t count)
        : rank_(rank), worldSize_(worldSize), count_(count) {
        buildTree();
        cudaMalloc(&buffer_, count * sizeof(T));
    }

    void execute(T* data, cudaStream_t stream) {
        // Phase 1: Reduce (up)
        reduce(data, stream);

        // Phase 2: Broadcast (down)
        broadcast(data, stream);
    }

private:
    void buildTree() {
        // 构建二叉树
        // 对于 rank r:
        // - 父节点: (r - 1) / 2 (if r > 0)
        // - 左子节点: 2r + 1
        // - 右子节点: 2r + 2

        if (rank_ == 0) {
            parent_ = -1;
        } else {
            parent_ = (rank_ - 1) / 2;
        }

        int leftChild = 2 * rank_ + 1;
        int rightChild = 2 * rank_ + 2;

        if (leftChild < worldSize_) {
            children_.push_back(leftChild);
        }
        if (rightChild < worldSize_) {
            children_.push_back(rightChild);
        }

        // 计算树的深度
        depth_ = 0;
        int nodes = 1;
        while (nodes < worldSize_) {
            depth_++;
            nodes = nodes * 2 + 1;
        }
    }

    void reduce(T* data, cudaStream_t stream) {
        // 从叶子节点开始,自底向上
        // 叶子节点没有子节点,直接发送给父节点

        // 接收所有子节点的数据并归约
        for (int child : children_) {
            MPI_Recv(buffer_, count_ * sizeof(T), MPI_BYTE,
                    child, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

            // 归约
            reduceKernel<<<256, 256, 0, stream>>>(data, buffer_, count_);
            cudaStreamSynchronize(stream);
        }

        // 发送给父节点(非根节点)
        if (parent_ >= 0) {
            cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);
            MPI_Send(buffer_, count_ * sizeof(T), MPI_BYTE,
                    parent_, 0, MPI_COMM_WORLD);
        }
    }

    void broadcast(T* data, cudaStream_t stream) {
        // 从根节点开始,自顶向下

        // 接收父节点的数据(非根节点)
        if (parent_ >= 0) {
            MPI_Recv(buffer_, count_ * sizeof(T), MPI_BYTE,
                    parent_, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            cudaMemcpy(data, buffer_, count_ * sizeof(T), cudaMemcpyHostToDevice);
        }

        // 发送给所有子节点
        cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);
        for (int child : children_) {
            MPI_Send(buffer_, count_ * sizeof(T), MPI_BYTE,
                    child, 1, MPI_COMM_WORLD);
        }
    }

private:
    int rank_, worldSize_;
    size_t count_;
    int parent_;
    std::vector<int> children_;
    int depth_;
    T* buffer_;
};

Recursive Halving-Doubling

算法原理

┌─────────────────────────────────────────────────────────────────────────┐
│              Recursive Halving-Doubling AllReduce                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  适用于节点数为 2^k 的情况                                               │
│                                                                         │
│  阶段 1: Reduce-Scatter (Recursive Halving)                             │
│                                                                         │
│  Step 1: 距离 = 4                                                       │
│  0 ←→ 4: 交换并归约半边数据                                             │
│  1 ←→ 5                                                                 │
│  2 ←→ 6                                                                 │
│  3 ←→ 7                                                                 │
│                                                                         │
│  0: [A0+A4, B0+B4, -, -]    4: [-, -, C0+C4, D0+D4]                     │
│  1: [A1+A5, B1+B5, -, -]    5: [-, -, C1+C5, D1+D5]                     │
│  2: [A2+A6, B2+B6, -, -]    6: [-, -, C2+C6, D2+D6]                     │
│  3: [A3+A7, B3+B7, -, -]    7: [-, -, C3+C7, D3+D7]                     │
│                                                                         │
│  Step 2: 距离 = 2                                                       │
│  0 ←→ 2: 0 得到 [A*, -, -, -]                                          │
│  1 ←→ 3: 1 得到 [-, B*, -, -]                                          │
│  4 ←→ 6: 4 得到 [-, -, C*, -]                                          │
│  5 ←→ 7: 5 得到 [-, -, -, D*]                                          │
│                                                                         │
│  阶段 2: AllGather (Recursive Doubling)                                 │
│                                                                         │
│  Step 1: 距离 = 1                                                       │
│  0 ←→ 1: 交换结果                                                       │
│  ...                                                                    │
│                                                                         │
│  Step 2: 距离 = 2                                                       │
│  0 ←→ 2, 1 ←→ 3, ...                                                   │
│                                                                         │
│  Step 3: 距离 = 4                                                       │
│  0 ←→ 4, 1 ←→ 5, ...                                                   │
│                                                                         │
│  优点:通信轮次 O(log N),带宽效率高                                     │
│  缺点:实现复杂,节点数需要是 2^k                                        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

实现

template<typename T>
class RecursiveHalvingDoublingAllReduce {
public:
    RecursiveHalvingDoublingAllReduce(int rank, int worldSize, size_t count)
        : rank_(rank), worldSize_(worldSize), count_(count) {
        // 检查是否是 2 的幂
        assert((worldSize & (worldSize - 1)) == 0);

        cudaMalloc(&sendBuff_, count * sizeof(T));
        cudaMalloc(&recvBuff_, count * sizeof(T));
    }

    void execute(T* data, cudaStream_t stream) {
        // Phase 1: Reduce-Scatter with Recursive Halving
        recursiveHalving(data, stream);

        // Phase 2: AllGather with Recursive Doubling
        recursiveDoubling(data, stream);
    }

private:
    void recursiveHalving(T* data, cudaStream_t stream) {
        size_t currentCount = count_;

        // log2(worldSize) 轮
        for (int mask = worldSize_ >> 1; mask > 0; mask >>= 1) {
            int peer = rank_ ^ mask;

            // 确定发送/接收哪半边
            size_t halfCount = currentCount / 2;
            size_t sendOffset, recvOffset;

            if (rank_ < peer) {
                // 发送后半部分,接收前半部分的归约结果
                sendOffset = halfCount;
                recvOffset = 0;
            } else {
                // 发送前半部分,接收后半部分的归约结果
                sendOffset = 0;
                recvOffset = halfCount;
            }

            // 发送
            cudaMemcpyAsync(sendBuff_, data + sendOffset,
                           halfCount * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);
            cudaStreamSynchronize(stream);

            MPI_Request sendReq, recvReq;
            MPI_Isend(sendBuff_, halfCount * sizeof(T), MPI_BYTE,
                     peer, 0, MPI_COMM_WORLD, &sendReq);
            MPI_Irecv(recvBuff_, halfCount * sizeof(T), MPI_BYTE,
                     peer, 0, MPI_COMM_WORLD, &recvReq);

            MPI_Wait(&recvReq, MPI_STATUS_IGNORE);

            // 归约到接收区域
            reduceKernel<<<256, 256, 0, stream>>>(
                data + recvOffset, recvBuff_, halfCount);
            cudaStreamSynchronize(stream);

            MPI_Wait(&sendReq, MPI_STATUS_IGNORE);

            // 更新有效数据区域
            if (rank_ < peer) {
                // 有效数据在前半部分
                currentCount = halfCount;
            } else {
                // 有效数据在后半部分,移动到开头
                cudaMemcpyAsync(data, data + halfCount,
                               halfCount * sizeof(T),
                               cudaMemcpyDeviceToDevice, stream);
                currentCount = halfCount;
            }
        }
    }

    void recursiveDoubling(T* data, cudaStream_t stream) {
        size_t currentCount = count_ / worldSize_;

        // log2(worldSize) 轮
        for (int mask = 1; mask < worldSize_; mask <<= 1) {
            int peer = rank_ ^ mask;

            // 计算发送/接收位置
            size_t sendOffset = 0;
            size_t recvOffset = currentCount;

            if (rank_ > peer) {
                std::swap(sendOffset, recvOffset);
            }

            // 交换数据
            cudaMemcpyAsync(sendBuff_, data + sendOffset,
                           currentCount * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);
            cudaStreamSynchronize(stream);

            MPI_Request sendReq, recvReq;
            MPI_Isend(sendBuff_, currentCount * sizeof(T), MPI_BYTE,
                     peer, 1, MPI_COMM_WORLD, &sendReq);
            MPI_Irecv(recvBuff_, currentCount * sizeof(T), MPI_BYTE,
                     peer, 1, MPI_COMM_WORLD, &recvReq);

            MPI_Wait(&recvReq, MPI_STATUS_IGNORE);

            cudaMemcpyAsync(data + recvOffset, recvBuff_,
                           currentCount * sizeof(T),
                           cudaMemcpyDeviceToDevice, stream);
            cudaStreamSynchronize(stream);

            MPI_Wait(&sendReq, MPI_STATUS_IGNORE);

            currentCount *= 2;
        }
    }

private:
    int rank_, worldSize_;
    size_t count_;
    T *sendBuff_, *recvBuff_;
};

分层 AllReduce (Hierarchical)

算法原理

┌─────────────────────────────────────────────────────────────────────────┐
│                    分层 AllReduce 算法                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  适用于多节点多卡环境                                                    │
│                                                                         │
│  Node 0           Node 1           Node 2           Node 3             │
│  ┌───────┐       ┌───────┐       ┌───────┐       ┌───────┐            │
│  │GPU0 1 │       │GPU0 1 │       │GPU0 1 │       │GPU0 1 │            │
│  │  2  3 │       │  2  3 │       │  2  3 │       │  2  3 │            │
│  └───────┘       └───────┘       └───────┘       └───────┘            │
│                                                                         │
│  阶段 1: 节点内 AllReduce (使用 NVLink/PCIe, Ring)                      │
│  - 每个节点内 4 个 GPU 做 Ring AllReduce                                │
│  - 带宽: NVLink 600GB/s >> 网络 100GB/s                                │
│                                                                         │
│  阶段 2: 节点间 AllReduce (使用网络, Ring/Tree)                         │
│  - 每个节点选择一个 GPU (通常 GPU0) 作为代表                            │
│  - 4 个节点的代表 GPU 做 AllReduce                                      │
│                                                                         │
│  阶段 3: 节点内 Broadcast (使用 NVLink/PCIe)                            │
│  - 每个节点内,代表 GPU 广播结果给其他 GPU                              │
│                                                                         │
│  优化变体:                                                              │
│  ├── Two-level Ring: 节点内 Ring + 节点间 Ring                         │
│  ├── NCCL Hierarchical: 自动检测拓扑选择最优策略                        │
│  └── Bucket: 节点间使用 Reduce-Scatter + AllGather                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

实现

template<typename T>
class HierarchicalAllReduce {
public:
    HierarchicalAllReduce(
        int globalRank,
        int worldSize,
        int localRank,
        int localSize,
        size_t count
    ) : globalRank_(globalRank),
        worldSize_(worldSize),
        localRank_(localRank),
        localSize_(localSize),
        count_(count) {

        nodeId_ = globalRank / localSize;
        numNodes_ = worldSize / localSize;

        // 创建节点内通信域
        MPI_Comm_split(MPI_COMM_WORLD, nodeId_, localRank, &localComm_);

        // 创建节点间通信域(仅 local rank 0 参与)
        int interColor = (localRank == 0) ? 0 : MPI_UNDEFINED;
        MPI_Comm_split(MPI_COMM_WORLD, interColor, nodeId_, &interComm_);

        cudaMalloc(&buffer_, count * sizeof(T));
    }

    void execute(T* data, cudaStream_t stream) {
        // Phase 1: 节点内 AllReduce
        intraNodeAllReduce(data, stream);

        // Phase 2: 节点间 AllReduce (仅 local rank 0)
        if (localRank_ == 0) {
            interNodeAllReduce(data, stream);
        }

        // Phase 3: 节点内 Broadcast
        intraNodeBroadcast(data, stream);
    }

private:
    void intraNodeAllReduce(T* data, cudaStream_t stream) {
        // 使用 NCCL 进行节点内 AllReduce
        // 利用 NVLink 高带宽
        ncclAllReduce(data, data, count_, getNcclDataType<T>(),
                     ncclSum, localNcclComm_, stream);
        cudaStreamSynchronize(stream);
    }

    void interNodeAllReduce(T* data, cudaStream_t stream) {
        // 节点间使用 MPI 或网络 NCCL
        cudaMemcpy(buffer_, data, count_ * sizeof(T), cudaMemcpyDeviceToHost);

        // 使用 Ring AllReduce
        T* sendBuff = buffer_;
        T* recvBuff = new T[count_];

        int prevNode = (nodeId_ - 1 + numNodes_) % numNodes_;
        int nextNode = (nodeId_ + 1) % numNodes_;
        size_t chunkSize = count_ / numNodes_;

        // Reduce-Scatter
        for (int step = 0; step < numNodes_ - 1; step++) {
            int sendChunk = (nodeId_ - step + numNodes_) % numNodes_;
            int recvChunk = (nodeId_ - step - 1 + numNodes_) % numNodes_;

            MPI_Request sendReq, recvReq;
            MPI_Isend(sendBuff + sendChunk * chunkSize,
                     chunkSize * sizeof(T), MPI_BYTE,
                     nextNode * localSize_, 0, MPI_COMM_WORLD, &sendReq);
            MPI_Irecv(recvBuff, chunkSize * sizeof(T), MPI_BYTE,
                     prevNode * localSize_, 0, MPI_COMM_WORLD, &recvReq);

            MPI_Wait(&recvReq, MPI_STATUS_IGNORE);

            // 归约
            for (size_t i = 0; i < chunkSize; i++) {
                sendBuff[recvChunk * chunkSize + i] += recvBuff[i];
            }

            MPI_Wait(&sendReq, MPI_STATUS_IGNORE);
        }

        // AllGather (类似)
        // ...

        cudaMemcpy(data, buffer_, count_ * sizeof(T), cudaMemcpyHostToDevice);
        delete[] recvBuff;
    }

    void intraNodeBroadcast(T* data, cudaStream_t stream) {
        // 使用 NCCL Broadcast
        ncclBroadcast(data, data, count_, getNcclDataType<T>(),
                     0, localNcclComm_, stream);
        cudaStreamSynchronize(stream);
    }

private:
    int globalRank_, worldSize_;
    int localRank_, localSize_;
    int nodeId_, numNodes_;
    size_t count_;
    MPI_Comm localComm_, interComm_;
    ncclComm_t localNcclComm_;
    T* buffer_;
};

性能分析与选择

算法选择策略

def select_allreduce_algorithm(
    num_nodes: int,
    gpus_per_node: int,
    message_size: int,
    network_bandwidth: float,  # GB/s
    nvlink_bandwidth: float,   # GB/s
    network_latency: float,    # us
):
    """
    选择最优 AllReduce 算法

    参数:
        num_nodes: 节点数
        gpus_per_node: 每节点 GPU 数
        message_size: 消息大小 (bytes)
        network_bandwidth: 网络带宽
        nvlink_bandwidth: NVLink 带宽
        network_latency: 网络延迟
    """
    total_gpus = num_nodes * gpus_per_node

    # 计算各算法的预估时间
    alpha = network_latency * 1e-6  # 转换为秒
    beta = 1 / (network_bandwidth * 1e9)  # 秒/字节

    # Ring AllReduce
    ring_time = 2 * (total_gpus - 1) * alpha + \
                2 * message_size * (total_gpus - 1) / total_gpus * beta

    # Tree AllReduce
    import math
    tree_depth = math.ceil(math.log2(total_gpus))
    tree_time = 2 * tree_depth * alpha + 2 * message_size * beta

    # Hierarchical AllReduce
    hier_intra = 2 * (gpus_per_node - 1) * (1e-6) + \
                 2 * message_size * (gpus_per_node - 1) / gpus_per_node / (nvlink_bandwidth * 1e9)
    hier_inter = 2 * (num_nodes - 1) * alpha + \
                 2 * message_size * (num_nodes - 1) / num_nodes * beta
    hier_time = hier_intra + hier_inter

    # 选择最优
    times = {
        'ring': ring_time,
        'tree': tree_time,
        'hierarchical': hier_time
    }

    best = min(times, key=times.get)

    # 特殊情况处理
    if message_size < 1024 * 1024:  # < 1MB
        # 小消息优先考虑延迟
        if tree_time < ring_time * 0.8:
            return 'tree'

    if num_nodes > 1 and gpus_per_node > 1:
        # 多节点多卡优先考虑分层
        if hier_time < ring_time * 1.2:
            return 'hierarchical'

    return best

总结

算法对比

算法延迟带宽效率适用场景
RingO(N)最优大消息、高带宽
TreeO(log N)较低小消息、低延迟
Recursive HDO(log N)高2^k 节点
Hierarchical分层高多节点多卡

实践建议

□ 大消息 (> 1MB) 使用 Ring
□ 小消息 (< 1MB) 使用 Tree
□ 多节点多卡使用 Hierarchical
□ 利用 NCCL 自动选择算法
□ 考虑通信与计算重叠
Prev
01-NCCL 源码深度解析
Next
03-RDMA与InfiniBand原理