HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

04-算子融合与优化技术

概述

算子融合(Operator Fusion)是深度学习性能优化的核心技术之一。通过将多个算子合并为一个 Kernel,可以减少内存访问、降低 Kernel 启动开销,显著提升计算效率。

算子融合原理

为什么需要算子融合

┌─────────────────────────────────────────────────────────────────────────┐
│                      算子融合动机分析                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  未融合的算子链:MatMul → Add Bias → GELU                               │
│                                                                         │
│  ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐   │
│  │ Input  │ →  │ MatMul │ →  │ Global │ →  │  Add   │ →  │ Global │   │
│  │ (HBM)  │    │ Kernel │    │ Memory │    │  Bias  │    │ Memory │   │
│  └────────┘    └────────┘    └────────┘    └────────┘    └────────┘   │
│                                    ↓                           ↓       │
│                              ┌────────┐                 ┌────────┐     │
│                              │  GELU  │                 │ Output │     │
│                              │ Kernel │                 │ (HBM)  │     │
│                              └────────┘                 └────────┘     │
│                                                                         │
│  问题:                                                                  │
│  ├── 3 次 Kernel 启动 (每次 ~5-10 μs)                                   │
│  ├── 2 次中间结果写入 HBM                                               │
│  ├── 2 次中间结果读取 HBM                                               │
│  └── 带宽瓶颈:HBM 带宽有限 (A100: 2TB/s)                               │
│                                                                         │
│  融合后:MatMul + Add Bias + GELU                                       │
│                                                                         │
│  ┌────────┐    ┌────────────────────────┐    ┌────────┐               │
│  │ Input  │ →  │    Fused Kernel        │ →  │ Output │               │
│  │ (HBM)  │    │ MatMul+Bias+GELU       │    │ (HBM)  │               │
│  └────────┘    │ (寄存器中完成Bias+GELU) │    └────────┘               │
│                └────────────────────────┘                               │
│                                                                         │
│  收益:                                                                  │
│  ├── 1 次 Kernel 启动                                                   │
│  ├── 0 次中间结果写入/读取 HBM                                          │
│  └── 性能提升:1.5x - 3x (取决于算子特点)                                │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

融合类型分类

┌─────────────────────────────────────────────────────────────────────────┐
│                      算子融合类型                                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  1. Element-wise 融合                                                   │
│     ├── 最简单,多个逐元素操作合并                                       │
│     ├── 例:Add + ReLU, Mul + Add + Sigmoid                            │
│     └── 收益:减少内存访问次数                                           │
│                                                                         │
│  2. Epilogue 融合                                                       │
│     ├── 将后处理操作融合到主算子中                                       │
│     ├── 例:GEMM + Bias + Activation                                   │
│     └── 收益:避免中间结果写回 HBM                                       │
│                                                                         │
│  3. Reduce 融合                                                         │
│     ├── 将 reduce 与其他操作融合                                        │
│     ├── 例:Softmax (max + exp + sum + div)                            │
│     └── 收益:减少全局内存往返                                           │
│                                                                         │
│  4. 跨层融合                                                            │
│     ├── 融合多个神经网络层                                              │
│     ├── 例:Conv + BN + ReLU, Attention Block                          │
│     └── 收益:最大化数据复用                                             │
│                                                                         │
│  5. 动态融合 (JIT)                                                      │
│     ├── 运行时根据计算图动态生成融合 Kernel                              │
│     ├── 例:TorchScript, XLA, TVM                                      │
│     └── 收益:灵活适应不同模型结构                                       │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Element-wise 融合

基础融合实现

// 未融合版本
__global__ void add_kernel(float* c, const float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) c[idx] = a[idx] + b[idx];
}

__global__ void relu_kernel(float* y, const float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) y[idx] = fmaxf(x[idx], 0.0f);
}

// 融合版本:Add + ReLU
__global__ void fused_add_relu(
    float* __restrict__ output,
    const float* __restrict__ a,
    const float* __restrict__ b,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = a[idx] + b[idx];
        output[idx] = fmaxf(val, 0.0f);
    }
}

// 更复杂的融合:Mul + Add + Sigmoid
__global__ void fused_mul_add_sigmoid(
    float* __restrict__ output,
    const float* __restrict__ x,
    const float* __restrict__ w,
    const float* __restrict__ b,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = x[idx] * w[idx] + b[idx];
        output[idx] = 1.0f / (1.0f + expf(-val));
    }
}

向量化融合

// 使用 float4 向量化的融合 Kernel
__global__ void fused_add_relu_vec4(
    float4* __restrict__ output,
    const float4* __restrict__ a,
    const float4* __restrict__ b,
    int n  // n 是 float4 的数量
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < n) {
        float4 va = a[idx];
        float4 vb = b[idx];

        float4 result;
        result.x = fmaxf(va.x + vb.x, 0.0f);
        result.y = fmaxf(va.y + vb.y, 0.0f);
        result.z = fmaxf(va.z + vb.z, 0.0f);
        result.w = fmaxf(va.w + vb.w, 0.0f);

        output[idx] = result;
    }
}

// LayerNorm + Dropout + Add (残差连接) 融合
__global__ void fused_layernorm_dropout_add(
    float* __restrict__ output,
    const float* __restrict__ input,
    const float* __restrict__ residual,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ dropout_mask,
    int batch_size,
    int hidden_size,
    float dropout_prob,
    float eps,
    unsigned long long seed
) {
    int batch_idx = blockIdx.x;
    if (batch_idx >= batch_size) return;

    const float* in_row = input + batch_idx * hidden_size;
    const float* res_row = residual + batch_idx * hidden_size;
    float* out_row = output + batch_idx * hidden_size;
    float* mask_row = dropout_mask + batch_idx * hidden_size;

    // Step 1: Compute mean
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        local_sum += in_row[i];
    }
    float mean = block_reduce_sum<256>(local_sum) / hidden_size;

    __shared__ float s_mean;
    if (threadIdx.x == 0) s_mean = mean;
    __syncthreads();
    mean = s_mean;

    // Step 2: Compute variance
    float local_var = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float diff = in_row[i] - mean;
        local_var += diff * diff;
    }
    float var = block_reduce_sum<256>(local_var) / hidden_size;

    __shared__ float s_rstd;
    if (threadIdx.x == 0) s_rstd = rsqrtf(var + eps);
    __syncthreads();
    float rstd = s_rstd;

    // Step 3: Normalize + Dropout + Add residual
    // 使用 curand 生成 dropout mask
    curandStatePhilox4_32_10_t state;
    curand_init(seed, batch_idx * hidden_size + threadIdx.x, 0, &state);

    float scale = 1.0f / (1.0f - dropout_prob);

    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        // LayerNorm
        float normalized = (in_row[i] - mean) * rstd;
        float ln_out = normalized * gamma[i] + beta[i];

        // Dropout
        float rand = curand_uniform(&state);
        float mask = (rand > dropout_prob) ? 1.0f : 0.0f;
        mask_row[i] = mask;
        float dropout_out = ln_out * mask * scale;

        // Add residual
        out_row[i] = dropout_out + res_row[i];
    }
}

GEMM Epilogue 融合

GEMM + Bias + Activation

// GEMM 结果的后处理融合
// D = activation(alpha * A @ B + beta * C + bias)

template<typename Activation>
__global__ void gemm_epilogue_fused(
    const float* __restrict__ gemm_output,  // GEMM 原始输出
    const float* __restrict__ bias,
    float* __restrict__ output,
    int M, int N,
    float alpha,
    float beta
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        int idx = row * N + col;
        float val = alpha * gemm_output[idx] + bias[col];
        output[idx] = Activation::apply(val);
    }
}

// 激活函数仿函数
struct ReLU {
    __device__ __forceinline__ static float apply(float x) {
        return fmaxf(x, 0.0f);
    }
};

struct GELU {
    __device__ __forceinline__ static float apply(float x) {
        // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
        const float c = 0.7978845608f;  // sqrt(2/π)
        const float k = 0.044715f;
        float x3 = x * x * x;
        return 0.5f * x * (1.0f + tanhf(c * (x + k * x3)));
    }
};

struct SiLU {
    __device__ __forceinline__ static float apply(float x) {
        return x / (1.0f + expf(-x));
    }
};

// 在 CUTLASS 中定义 Epilogue
#include <cutlass/epilogue/thread/linear_combination_relu.h>
#include <cutlass/epilogue/thread/linear_combination_gelu.h>

using EpilogueWithReLU = cutlass::epilogue::thread::LinearCombinationRelu<
    float,                                  // Element type
    128 / cutlass::sizeof_bits<float>::value,  // Elements per access
    float,                                  // Accumulator type
    float                                   // Compute type
>;

using EpilogueWithGELU = cutlass::epilogue::thread::LinearCombinationGELU<
    float,
    128 / cutlass::sizeof_bits<float>::value,
    float,
    float
>;

完整的融合 GEMM Kernel

/*
 * 融合 GEMM + Bias + GELU 的完整实现
 * 使用 Tensor Core
 */

#define BM 128
#define BN 128
#define BK 32
#define TM 8
#define TN 8

__global__ void fused_gemm_bias_gelu(
    const half* __restrict__ A,
    const half* __restrict__ B,
    const half* __restrict__ bias,
    half* __restrict__ C,
    int M, int N, int K
) {
    __shared__ half As[BK][BM];
    __shared__ half Bs[BK][BN];

    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;
    int tid = ty * blockDim.x + tx;

    // WMMA fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag[2];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag[2][2];

    // 初始化累加器
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            wmma::fill_fragment(c_frag[i][j], 0.0f);
        }
    }

    int warp_id = tid / 32;
    int warp_row = warp_id / 2;
    int warp_col = warp_id % 2;

    // K 维度迭代
    for (int bk = 0; bk < K; bk += BK) {
        // 加载 A, B 到共享内存
        load_shared_a(A, As, by, bk, tid, M, K);
        load_shared_b(B, Bs, bx, bk, tid, K, N);
        __syncthreads();

        // WMMA 计算
        #pragma unroll
        for (int k = 0; k < BK; k += 16) {
            #pragma unroll
            for (int m = 0; m < 2; m++) {
                wmma::load_matrix_sync(a_frag[m],
                    &As[k][warp_row * 64 + m * 16], BM);
            }
            #pragma unroll
            for (int n = 0; n < 2; n++) {
                wmma::load_matrix_sync(b_frag[n],
                    &Bs[k][warp_col * 64 + n * 16], BN);
            }

            #pragma unroll
            for (int m = 0; m < 2; m++) {
                #pragma unroll
                for (int n = 0; n < 2; n++) {
                    wmma::mma_sync(c_frag[m][n], a_frag[m], b_frag[n], c_frag[m][n]);
                }
            }
        }
        __syncthreads();
    }

    // Epilogue: Bias + GELU
    int c_row_base = by * BM + warp_row * 64;
    int c_col_base = bx * BN + warp_col * 64;

    #pragma unroll
    for (int m = 0; m < 2; m++) {
        #pragma unroll
        for (int n = 0; n < 2; n++) {
            // 应用 bias 和 GELU 到每个 fragment 元素
            #pragma unroll
            for (int i = 0; i < c_frag[m][n].num_elements; i++) {
                // 计算该元素的列索引
                int elem_col = c_col_base + n * 16 + (i % 16);
                if (elem_col < N) {
                    float val = c_frag[m][n].x[i];
                    // Add bias
                    val += __half2float(bias[elem_col]);
                    // GELU
                    val = gelu_forward(val);
                    c_frag[m][n].x[i] = val;
                }
            }

            // 转换为 half 并存储
            wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_half_frag;
            #pragma unroll
            for (int i = 0; i < c_frag[m][n].num_elements; i++) {
                c_half_frag.x[i] = __float2half(c_frag[m][n].x[i]);
            }

            int c_row = c_row_base + m * 16;
            int c_col = c_col_base + n * 16;
            if (c_row < M && c_col < N) {
                wmma::store_matrix_sync(&C[c_row * N + c_col], c_half_frag, N,
                    wmma::mem_row_major);
            }
        }
    }
}

__device__ __forceinline__ float gelu_forward(float x) {
    const float c = 0.7978845608f;
    const float k = 0.044715f;
    return 0.5f * x * (1.0f + tanhf(c * (x + k * x * x * x)));
}

Attention 融合 (Flash Attention)

Flash Attention 原理

┌─────────────────────────────────────────────────────────────────────────┐
│                      Flash Attention 原理                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  标准 Attention:                                                        │
│  O = softmax(QK^T / √d) V                                              │
│                                                                         │
│  标准实现的内存占用:                                                    │
│  ├── S = QK^T: O(N²) 中间矩阵                                          │
│  ├── P = softmax(S): O(N²) 中间矩阵                                    │
│  └── O = PV: 最终输出                                                   │
│  总计:O(N²) 显存,对长序列不可行                                        │
│                                                                         │
│  Flash Attention 核心思想:                                             │
│  ├── 分块计算:将 Q, K, V 分成小块                                      │
│  ├── Online Softmax:边计算边更新 softmax 统计量                        │
│  └── 不存储中间 S, P 矩阵                                               │
│                                                                         │
│  分块策略:                                                              │
│  ┌─────────────────────────────────────────────────────────────┐       │
│  │   Q [N, d]     K [N, d]     V [N, d]     O [N, d]           │       │
│  │   ┌───┐        ┌───┐        ┌───┐        ┌───┐              │       │
│  │   │Q_1│        │K_1│        │V_1│        │O_1│              │       │
│  │   ├───┤  ───>  ├───┤  ───>  ├───┤  ───>  ├───┤              │       │
│  │   │Q_2│        │K_2│        │V_2│        │O_2│              │       │
│  │   ├───┤        ├───┤        ├───┤        ├───┤              │       │
│  │   │...│        │...│        │...│        │...│              │       │
│  │   └───┘        └───┘        └───┘        └───┘              │       │
│  │                                                              │       │
│  │   每次只加载一块 Q_i,遍历所有 K_j, V_j 块                    │       │
│  │   在 SRAM 中完成计算,避免 HBM 往返                           │       │
│  └─────────────────────────────────────────────────────────────┘       │
│                                                                         │
│  Online Softmax 更新:                                                  │
│  对于 Q_i,遍历 K_j 时:                                                │
│  ├── m_new = max(m_old, max(S_ij))                                     │
│  ├── l_new = l_old * exp(m_old - m_new) + sum(exp(S_ij - m_new))      │
│  └── O_new = O_old * (l_old/l_new) * exp(m_old-m_new) + softmax(S_ij)*V│
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Flash Attention 简化实现

/*
 * Flash Attention 简化实现
 * 仅用于说明原理,实际实现更复杂
 */

#define BR 64   // Q block size
#define BC 64   // K/V block size
#define d  64   // head dimension

__global__ void flash_attention_kernel(
    const float* __restrict__ Q,  // [N, d]
    const float* __restrict__ K,  // [N, d]
    const float* __restrict__ V,  // [N, d]
    float* __restrict__ O,        // [N, d]
    float* __restrict__ L,        // [N] softmax 归一化因子
    float* __restrict__ M,        // [N] 行最大值
    int N,
    float scale
) {
    // 共享内存
    __shared__ float Qi[BR][d];   // Q 块
    __shared__ float Kj[BC][d];   // K 块
    __shared__ float Vj[BC][d];   // V 块
    __shared__ float Sij[BR][BC]; // 注意力分数块

    int batch_head = blockIdx.x;
    int q_block = blockIdx.y;

    int q_start = q_block * BR;

    // 加载 Q 块到共享内存
    load_block(Q + batch_head * N * d, Qi, q_start, N, d);
    __syncthreads();

    // 初始化输出
    float Oi[d] = {0.0f};
    float mi = -INFINITY;
    float li = 0.0f;

    int thread_row = threadIdx.x / BC;
    int thread_col = threadIdx.x % BC;

    // 遍历 K, V 块
    for (int j = 0; j < (N + BC - 1) / BC; j++) {
        int kv_start = j * BC;

        // 加载 K, V 块
        load_block(K + batch_head * N * d, Kj, kv_start, N, d);
        load_block(V + batch_head * N * d, Vj, kv_start, N, d);
        __syncthreads();

        // 计算 S_ij = Q_i @ K_j^T * scale
        if (thread_row < BR && thread_col < BC) {
            float sum = 0.0f;
            #pragma unroll
            for (int k = 0; k < d; k++) {
                sum += Qi[thread_row][k] * Kj[thread_col][k];
            }
            Sij[thread_row][thread_col] = sum * scale;
        }
        __syncthreads();

        // 对于该线程负责的 Q 行,计算 online softmax
        if (threadIdx.x < BR) {
            int i = threadIdx.x;

            // 找该行最大值
            float row_max = -INFINITY;
            #pragma unroll
            for (int jj = 0; jj < BC && kv_start + jj < N; jj++) {
                row_max = fmaxf(row_max, Sij[i][jj]);
            }

            // Online softmax 更新
            float mi_new = fmaxf(mi, row_max);
            float li_new = li * expf(mi - mi_new);

            // 更新 O
            float scale_old = expf(mi - mi_new);
            #pragma unroll
            for (int k = 0; k < d; k++) {
                Oi[k] *= scale_old;
            }

            // 累加当前块的贡献
            #pragma unroll
            for (int jj = 0; jj < BC && kv_start + jj < N; jj++) {
                float pij = expf(Sij[i][jj] - mi_new);
                li_new += pij;

                #pragma unroll
                for (int k = 0; k < d; k++) {
                    Oi[k] += pij * Vj[jj][k];
                }
            }

            mi = mi_new;
            li = li_new;
        }
        __syncthreads();
    }

    // 归一化并写回
    if (threadIdx.x < BR && q_start + threadIdx.x < N) {
        int row = q_start + threadIdx.x;
        float inv_l = 1.0f / li;

        #pragma unroll
        for (int k = 0; k < d; k++) {
            O[(batch_head * N + row) * d + k] = Oi[k] * inv_l;
        }

        L[batch_head * N + row] = li;
        M[batch_head * N + row] = mi;
    }
}

Flash Attention 2 优化

/*
 * Flash Attention 2 优化点:
 * 1. 更好的并行策略:在 sequence 维度并行
 * 2. 减少非 matmul 操作
 * 3. 更好的 warp 专业化
 */

template<int BR, int BC, int d>
__global__ void flash_attention_2_kernel(
    const half* __restrict__ Q,
    const half* __restrict__ K,
    const half* __restrict__ V,
    half* __restrict__ O,
    int N,
    float scale
) {
    // Thread block 处理一个 Q 块
    // 使用 Tensor Core

    __shared__ half Qi[BR][d];
    __shared__ half Kj[BC][d];
    __shared__ half Vj[BC][d];

    // WMMA fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> v_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> s_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> o_frag;

    wmma::fill_fragment(o_frag, 0.0f);

    int warp_id = threadIdx.x / 32;
    int q_block = blockIdx.y;
    int batch_head = blockIdx.x;

    float mi = -INFINITY;
    float li = 0.0f;

    // 加载 Q 块 (只加载一次)
    load_q_block(Q + batch_head * N * d, Qi, q_block * BR, N, d);
    __syncthreads();

    // 遍历 K/V 块
    for (int kv_block = 0; kv_block < (N + BC - 1) / BC; kv_block++) {
        // 加载 K, V
        load_kv_block(K, V, Kj, Vj, batch_head, kv_block * BC, N, d);
        __syncthreads();

        // 使用 Tensor Core 计算 S = Q @ K^T
        wmma::fill_fragment(s_frag, 0.0f);

        for (int k = 0; k < d; k += 16) {
            wmma::load_matrix_sync(q_frag, &Qi[warp_id * 16][k], d);
            wmma::load_matrix_sync(k_frag, &Kj[0][k], d);
            wmma::mma_sync(s_frag, q_frag, k_frag, s_frag);
        }

        // Scale
        #pragma unroll
        for (int i = 0; i < s_frag.num_elements; i++) {
            s_frag.x[i] *= scale;
        }

        // Causal mask (如果需要)
        apply_causal_mask(s_frag, q_block * BR + warp_id * 16, kv_block * BC);

        // Online softmax
        float new_max = reduce_max(s_frag);
        float mi_new = fmaxf(mi, new_max);

        // Rescale previous output
        float scale_factor = expf(mi - mi_new);
        #pragma unroll
        for (int i = 0; i < o_frag.num_elements; i++) {
            o_frag.x[i] *= scale_factor;
        }

        // Compute softmax and accumulate
        float row_sum = 0.0f;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> p_frag;

        #pragma unroll
        for (int i = 0; i < s_frag.num_elements; i++) {
            float p = expf(s_frag.x[i] - mi_new);
            row_sum += p;
            p_frag.x[i] = __float2half(p);
        }

        // O += P @ V
        wmma::load_matrix_sync(v_frag, &Vj[0][0], d);
        wmma::mma_sync(o_frag, p_frag, v_frag, o_frag);

        // Update stats
        li = li * scale_factor + row_sum;
        mi = mi_new;

        __syncthreads();
    }

    // Normalize and store
    float inv_li = 1.0f / li;
    #pragma unroll
    for (int i = 0; i < o_frag.num_elements; i++) {
        o_frag.x[i] *= inv_li;
    }

    wmma::store_matrix_sync(
        &O[(batch_head * N + q_block * BR + warp_id * 16) * d],
        o_frag, d, wmma::mem_row_major
    );
}

动态融合与 JIT 编译

运行时 Kernel 生成

# 使用 Triton 进行动态融合
import triton
import triton.language as tl

@triton.jit
def fused_elementwise_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    OP1: tl.constexpr,  # 运行时指定操作
    OP2: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 动态选择操作
    if OP1 == 0:  # Add
        result = x + y
    elif OP1 == 1:  # Mul
        result = x * y
    elif OP1 == 2:  # Sub
        result = x - y

    if OP2 == 0:  # ReLU
        result = tl.maximum(result, 0)
    elif OP2 == 1:  # GELU
        result = 0.5 * result * (1 + tl.math.tanh(
            0.7978845608 * (result + 0.044715 * result * result * result)))
    elif OP2 == 2:  # Sigmoid
        result = 1 / (1 + tl.exp(-result))

    tl.store(output_ptr + offsets, result, mask=mask)

# 使用
def fused_op(x, y, op1='add', op2='relu'):
    op1_map = {'add': 0, 'mul': 1, 'sub': 2}
    op2_map = {'relu': 0, 'gelu': 1, 'sigmoid': 2}

    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    fused_elementwise_kernel[grid](
        x, y, output, n_elements,
        BLOCK_SIZE=1024,
        OP1=op1_map[op1],
        OP2=op2_map[op2],
    )
    return output

TorchScript Fusion

import torch
from torch import nn
import torch.jit

# TorchScript 会自动融合兼容的操作
class FusedMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        # 这些操作会被 TorchScript 融合
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc2(x)
        return x

# 编译为 TorchScript
model = FusedMLP(768, 3072).cuda()
scripted = torch.jit.script(model)

# 查看融合情况
print(scripted.graph)

# 使用 torch.compile (PyTorch 2.0+) 进行更激进的融合
@torch.compile(mode="reduce-overhead")
def optimized_forward(model, x):
    return model(x)

融合优化最佳实践

融合决策指南

┌─────────────────────────────────────────────────────────────────────────┐
│                      融合决策指南                                        │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  应该融合的情况:                                                        │
│  ├── Element-wise 操作链                                                │
│  ├── 算子间有数据依赖且中间结果不需要保存                                 │
│  ├── 减少内存带宽瓶颈的算子                                              │
│  └── Kernel 启动开销显著的小算子                                         │
│                                                                         │
│  不应该融合的情况:                                                      │
│  ├── 融合后 Kernel 过于复杂,导致寄存器溢出                              │
│  ├── 融合后占用率过低                                                   │
│  ├── 中间结果需要被多个后续算子使用                                      │
│  └── 算子已经足够高效(如 cuBLAS GEMM)                                  │
│                                                                         │
│  性能评估:                                                              │
│  ├── 使用 Nsight 分析内存带宽和计算利用率                                │
│  ├── 对比融合前后的端到端性能                                            │
│  └── 注意 edge case 的正确性                                            │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

常见融合模式

融合模式算子组合收益
Bias + ActivationGEMM + Bias + ReLU/GELU1.5-2x
LayerNorm 融合Mean + Var + Normalize + Scale2-3x
Softmax 融合Max + Exp + Sum + Div2-4x
AttentionQK^T + Softmax + V2-4x (Flash Attention)
ResidualAdd + LayerNorm1.5x

总结

算子融合核心要点

□ 识别内存密集型算子链
□ 使用共享内存/寄存器存储中间结果
□ 利用 Online 算法避免多次遍历
□ 考虑 Tensor Core 的 epilogue 融合
□ 平衡融合粒度与 Kernel 复杂度
□ 使用 profiler 验证融合效果
Prev
03-Tensor Core 与矩阵运算
Next
05-Triton 编程入门