HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

02-高性能 Kernel 开发实战

概述

本文通过实现深度学习中最核心的算子(GEMM、Reduce、Softmax、LayerNorm),深入讲解高性能 CUDA Kernel 开发的核心技术和优化策略。

GEMM (通用矩阵乘法)

朴素实现

// Naive matrix multiplication: C[M, N] = A[M, K] * B[K, N], all row-major.
// Launch: 2D grid/block; the x dimension covers columns of C, the y dimension rows.
// Each thread computes exactly one element of C.
__global__ void gemm_naive(
    const float* A,
    const float* B,
    float* C,
    int M, int N, int K
) {
    const int c_col = blockIdx.x * blockDim.x + threadIdx.x;
    const int c_row = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads that fall outside C have nothing to do.
    if (c_row >= M || c_col >= N) {
        return;
    }

    const float* a_row = A + c_row * K;  // start of row c_row of A
    float acc = 0.0f;
    for (int i = 0; i < K; ++i) {
        acc += a_row[i] * B[i * N + c_col];
    }
    C[c_row * N + c_col] = acc;
}

// Problems with this version:
// 1. Every thread independently streams one full row of A and one full column of B.
// 2. Heavy redundant global-memory traffic: each A element is requested by N threads,
//    each B element by M threads.
// 3. Very low arithmetic intensity (FLOPs per byte loaded).

Tiled GEMM (共享内存优化)

#define TILE_SIZE 32

// Shared-memory tiled GEMM: C[M, N] = A[M, K] * B[K, N], row-major.
// Launch: blockDim = (TILE_SIZE, TILE_SIZE),
//         gridDim  = (ceil(N / TILE_SIZE), ceil(M / TILE_SIZE)).
// The block walks K one TILE_SIZE x TILE_SIZE tile of A and B at a time, staging both
// in shared memory so every loaded element is reused TILE_SIZE times.
// Shared memory: 2 * TILE_SIZE * TILE_SIZE floats (8 KB for TILE_SIZE = 32).
__global__ void gemm_tiled(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    const int lx = threadIdx.x;
    const int ly = threadIdx.y;
    const int out_row = blockIdx.y * TILE_SIZE + ly;
    const int out_col = blockIdx.x * TILE_SIZE + lx;
    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;

    float acc = 0.0f;

    for (int tile = 0; tile < num_tiles; ++tile) {
        const int k_base = tile * TILE_SIZE;
        const int a_k = k_base + lx;  // column of A loaded by this thread
        const int b_k = k_base + ly;  // row of B loaded by this thread

        // Out-of-range elements are zero-filled so they contribute nothing.
        As[ly][lx] = (out_row < M && a_k < K) ? A[out_row * K + a_k] : 0.0f;
        Bs[ly][lx] = (b_k < K && out_col < N) ? B[b_k * N + out_col] : 0.0f;

        __syncthreads();  // both tiles fully written before anyone reads them

        #pragma unroll
        for (int kk = 0; kk < TILE_SIZE; ++kk) {
            acc += As[ly][kk] * Bs[kk][lx];
        }

        __syncthreads();  // everyone done reading before the next tile overwrites them
    }

    if (out_row < M && out_col < N) {
        C[out_row * N + out_col] = acc;
    }
}

寄存器分块 + 向量化 (高级优化)

/*
 * High-performance GEMM optimization strategy:
 * 1. Thread Block Tile: data reuse at block level (each block owns a BM x BN tile of C)
 * 2. Warp Tile: data reuse at warp level (not implemented separately in this kernel)
 * 3. Thread Tile: each thread computes TM x TN output elements in registers
 * 4. Double buffering: hides memory latency (see gemm_double_buffer)
 * 5. Vectorized loads: float4 accesses improve memory bandwidth utilization
 */

// Block Tile: 128x128 (BM x BN), K-step BK = 8
// Thread Tile: 8x8 (TM x TN)
// Threads per block: (BM / TM) * (BN / TN) = 16 * 16 = 256 (8 warps)
// Launch: dim3 block(16, 16) (any shape with 256 threads works),
//         dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM)

#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8

// Register-blocked, float4-vectorized SGEMM: C[M, N] = A[M, K] * B[K, N], row-major.
// Requirements:
//   - exactly (BM / TM) * (BN / TN) = 256 threads per block (the load and output
//     index maps below assume tid covers 0..255);
//   - A, B, C base pointers 16-byte aligned (true for cudaMalloc allocations).
// Arbitrary M, N, K are supported: float4 accesses are used only when the row stride
// (K for A, N for B and C) is a multiple of 4; otherwise, and at tile edges, scalar
// accesses are used and out-of-range shared-memory slots are zero-filled.
__global__ void __launch_bounds__((BM / TM) * (BN / TN)) gemm_optimized(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // A tile is stored transposed (As[k][m]) so the compute loop reads TM consecutive
    // floats along m; the B tile keeps its natural layout (Bs[k][n]).
    __shared__ float As[BK][BM];
    __shared__ float Bs[BK][BN];

    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;

    // Per-thread accumulator tile and operand fragments (registers).
    float reg_C[TM][TN] = {0.0f};
    float reg_A[TM];
    float reg_B[TN];

    // Each thread stages one 4-float chunk of the BM x BK A tile and one 4-float chunk
    // of the BK x BN B tile per K-step (256 threads * 4 = 1024 = BM * BK = BK * BN).
    const int load_a_row = tid / (BK / 4);
    const int load_a_col = (tid % (BK / 4)) * 4;
    const int load_b_row = tid / (BN / 4);
    const int load_b_col = (tid % (BN / 4)) * 4;

    // Top-left corner of this thread's TM x TN patch inside the block tile.
    const int thread_row = (tid / (BN / TN)) * TM;
    const int thread_col = (tid % (BN / TN)) * TN;

    // float4 accesses are only legal when every row starts on a 16-byte boundary.
    const bool a_vec_ok = (K % 4) == 0;
    const bool bc_vec_ok = (N % 4) == 0;

    for (int bk = 0; bk < K; bk += BK) {
        // ---- Stage the A tile (transposed) into shared memory ----
        {
            const int a_row = by * BM + load_a_row;
            const int a_col = bk + load_a_col;
            float v[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // zero-fill anything out of range
            if (a_row < M) {
                const size_t a_base = static_cast<size_t>(a_row) * K;
                if (a_vec_ok && a_col + 3 < K) {
                    const float4 t = *reinterpret_cast<const float4*>(&A[a_base + a_col]);
                    v[0] = t.x; v[1] = t.y; v[2] = t.z; v[3] = t.w;
                } else {
                    #pragma unroll
                    for (int i = 0; i < 4; ++i) {
                        if (a_col + i < K) v[i] = A[a_base + a_col + i];
                    }
                }
            }
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                As[load_a_col + i][load_a_row] = v[i];
            }
        }

        // ---- Stage the B tile into shared memory ----
        {
            const int b_row = bk + load_b_row;
            const int b_col = bx * BN + load_b_col;
            float v[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // zero-fill anything out of range
            if (b_row < K) {
                const size_t b_base = static_cast<size_t>(b_row) * N;
                if (bc_vec_ok && b_col + 3 < N) {
                    const float4 t = *reinterpret_cast<const float4*>(&B[b_base + b_col]);
                    v[0] = t.x; v[1] = t.y; v[2] = t.z; v[3] = t.w;
                } else {
                    #pragma unroll
                    for (int i = 0; i < 4; ++i) {
                        if (b_col + i < N) v[i] = B[b_base + b_col + i];
                    }
                }
            }
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                Bs[load_b_row][load_b_col + i] = v[i];
            }
        }

        __syncthreads();  // tiles complete before compute

        // Outer-product accumulation over the BK slice.
        #pragma unroll
        for (int k = 0; k < BK; k++) {
            #pragma unroll
            for (int m = 0; m < TM; m++) {
                reg_A[m] = As[k][thread_row + m];
            }
            #pragma unroll
            for (int n = 0; n < TN; n++) {
                reg_B[n] = Bs[k][thread_col + n];
            }

            #pragma unroll
            for (int m = 0; m < TM; m++) {
                #pragma unroll
                for (int n = 0; n < TN; n++) {
                    reg_C[m][n] += reg_A[m] * reg_B[n];
                }
            }
        }

        __syncthreads();  // compute done before the next K-step overwrites the tiles
    }

    // ---- Write back: float4 stores when aligned and fully in range, scalar otherwise ----
    const int c_row = by * BM + thread_row;
    const int c_col = bx * BN + thread_col;

    #pragma unroll
    for (int m = 0; m < TM; m++) {
        const int row = c_row + m;
        if (row < M) {
            const size_t c_base = static_cast<size_t>(row) * N;
            #pragma unroll
            for (int n = 0; n < TN; n += 4) {
                const int col = c_col + n;
                if (bc_vec_ok && col + 3 < N) {
                    const float4 t = make_float4(
                        reg_C[m][n + 0],
                        reg_C[m][n + 1],
                        reg_C[m][n + 2],
                        reg_C[m][n + 3]
                    );
                    *reinterpret_cast<float4*>(&C[c_base + col]) = t;
                } else {
                    #pragma unroll
                    for (int i = 0; i < 4; ++i) {
                        if (col + i < N) C[c_base + col + i] = reg_C[m][n + i];
                    }
                }
            }
        }
    }
}

双缓冲优化

/*
 * Double buffering:
 * Two sets of shared-memory buffers are used; while the current tile is being
 * computed, the next tile is prefetched into the other buffer, hiding global-memory
 * latency.
 *
 * NOTE(review): load_tile_async / compute_tile / store_result are helper routines
 * that are not shown here; this kernel illustrates the control flow only.
 */

#define BM 128
#define BN 128
#define BK 8

// Double-buffered GEMM skeleton: C[M, N] = A[M, K] * B[K, N].
// In every iteration, `read_stage` is the buffer filled in the previous iteration
// (consumed by compute_tile) and `write_stage` is the other buffer (receives the
// next K-tile).
__global__ void gemm_double_buffer(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // Two shared-memory stages for each operand.
    __shared__ float As[2][BK][BM];
    __shared__ float Bs[2][BK][BN];

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int tid = ty * blockDim.x + tx;
    int bx = blockIdx.x;
    int by = blockIdx.y;

    float reg_C[8][8] = {0.0f};

    // Prefetch the first K-tile into stage 0.
    int write_stage = 0;
    load_tile_async(A, B, As[write_stage], Bs[write_stage], by, bx, 0, tid, M, N, K);

    // NOTE(review): if load_tile_async issues cp.async copies, they must be waited on
    // (cp.async.wait_all / wait_group) before this barrier — TODO confirm.
    __syncthreads();

    for (int bk = BK; bk < K + BK; bk += BK) {
        // The stage that was just filled becomes the read stage. The other stage is
        // free: it was last read in the previous iteration, which ended with a barrier.
        // (Toggling read and write stages together, as before, kept them equal, so the
        // kernel computed on the buffer it was loading into and skipped tile 0.)
        const int read_stage = write_stage;
        write_stage = 1 - write_stage;

        // Prefetch the next K-tile, if any, into the free stage.
        if (bk < K) {
            load_tile_async(A, B, As[write_stage], Bs[write_stage],
                           by, bx, bk, tid, M, N, K);
        }

        // Consume the current K-tile.
        compute_tile(As[read_stage], Bs[read_stage], reg_C);

        // Publishes the prefetched tile and guarantees compute on read_stage has
        // finished before the next iteration starts writing into it.
        __syncthreads();
    }

    // Write the accumulated tile back to C.
    store_result(C, reg_C, by, bx, tid, M, N);
}

// Copy 16 bytes (4 floats) from global memory into shared memory.
// On sm_80+ this issues an asynchronous `cp.async` (CUDA 11+); on older architectures
// it falls back to a synchronous float4 copy.
//   src  : global-memory source, must be 16-byte aligned
//   dst  : shared-memory destination, must be 16-byte aligned
//   size : currently unused — exactly 16 bytes are always copied
// An async copy is only guaranteed complete after the issuing thread executes
// cp.async.wait_all (or wait_group); other threads see the data only after a
// subsequent __syncthreads().
__device__ void load_tile_async_v2(
    const float* __restrict__ src,
    float* __restrict__ dst,
    int size
) {
    (void)size;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // cp.async takes a 32-bit shared-state-space address, not a 64-bit generic pointer.
    const unsigned smem_addr = static_cast<unsigned>(__cvta_generic_to_shared(dst));
    asm volatile(
        "cp.async.ca.shared.global [%0], [%1], %2;\n"
        :
        : "r"(smem_addr), "l"(src), "n"(16)  // 16 bytes = 4 floats
        : "memory"  // keep the compiler from reordering shared-memory accesses around it
    );
#else
    // Fallback: plain synchronous 16-byte copy.
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
#endif
}

// Block the calling thread until all of its outstanding cp.async copies have landed
// in shared memory (sm_80+; compiles to nothing on older architectures).
// This only waits for the calling thread's own copies — follow it with
// __syncthreads() before other threads read the copied data.
__device__ void wait_async_copy() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // The "memory" clobber stops the compiler from hoisting shared-memory loads
    // above the wait.
    asm volatile("cp.async.wait_all;\n" ::: "memory");
#endif
}

Reduce 操作

Warp Shuffle Reduce

// Warp-wide sum using shuffle-down steps of 16, 8, 4, 2, 1 lanes.
// All 32 lanes must call it (full participation mask). Lane 0 returns the sum over
// the whole warp; the other lanes return partial sums.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    constexpr unsigned kAllLanes = 0xffffffffu;
    #pragma unroll
    for (int delta = 16; delta >= 1; delta >>= 1) {
        const float other = __shfl_down_sync(kAllLanes, val, delta);
        val = val + other;
    }
    return val;
}

// Warp-wide maximum using shuffle-down steps of 16, 8, 4, 2, 1 lanes.
// All 32 lanes must call it. Lane 0 returns the maximum over the whole warp;
// the other lanes return partial maxima.
__device__ __forceinline__ float warp_reduce_max(float val) {
    #pragma unroll
    for (int step = 16; step != 0; step >>= 1) {
        const float neighbor = __shfl_down_sync(0xffffffffu, val, step);
        val = fmaxf(val, neighbor);
    }
    return val;
}

// Block-level sum reduction using caller-provided scratch.
// Requirements: blockDim.x == BLOCK_SIZE, BLOCK_SIZE a multiple of 32 in [32, 1024],
// every thread of the block calls it (it contains __syncthreads()), and `shared`
// points to at least 32 floats of shared memory.
// Returns the full block sum in thread 0 only; other threads get partial values.
template<int BLOCK_SIZE>
__device__ float block_reduce_sum(float val, float* shared) {
    static_assert(BLOCK_SIZE % 32 == 0 && BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024,
                  "BLOCK_SIZE must be a multiple of 32 in [32, 1024]");

    const int lane = threadIdx.x % 32;
    const int wid = threadIdx.x / 32;

    // Step 1: reduce within each warp.
    val = warp_reduce_sum(val);

    // A previous reduction may still be reading `shared` in warp 0; wait before the
    // warp leaders overwrite it (makes back-to-back calls safe).
    __syncthreads();
    if (lane == 0) {
        shared[wid] = val;
    }
    __syncthreads();

    // Step 2: the first warp reduces the per-warp partial sums.
    val = (threadIdx.x < BLOCK_SIZE / 32) ? shared[lane] : 0.0f;
    if (wid == 0) {
        val = warp_reduce_sum(val);
    }

    return val;
}

// Block-level sum reduction with its own static shared scratch (one slot per warp).
// Same contract as the two-argument overload: the result is valid in thread 0 only.
template<int BLOCK_SIZE>
__device__ float block_reduce_sum(float val) {
    __shared__ float shared[32];  // 32 warps max
    return block_reduce_sum<BLOCK_SIZE>(val, shared);
}

// Block-level maximum using caller-provided scratch (>= 32 floats of shared memory).
// Same requirements as block_reduce_sum; the result is valid in thread 0 only.
template<int BLOCK_SIZE>
__device__ float block_reduce_max(float val, float* shared) {
    static_assert(BLOCK_SIZE % 32 == 0 && BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024,
                  "BLOCK_SIZE must be a multiple of 32 in [32, 1024]");

    const int lane = threadIdx.x % 32;
    const int wid = threadIdx.x / 32;

    val = warp_reduce_max(val);

    __syncthreads();  // protect `shared` from a still-running previous reduction
    if (lane == 0) {
        shared[wid] = val;
    }
    __syncthreads();

    // -INFINITY is the identity element for max.
    val = (threadIdx.x < BLOCK_SIZE / 32) ? shared[lane] : -INFINITY;
    if (wid == 0) {
        val = warp_reduce_max(val);
    }

    return val;
}

// Block-level maximum with its own static shared scratch; result valid in thread 0 only.
template<int BLOCK_SIZE>
__device__ float block_reduce_max(float val) {
    __shared__ float shared[32];  // 32 warps max
    return block_reduce_max<BLOCK_SIZE>(val, shared);
}

完整的 Reduce Kernel

// Block-wise sum: output[blockIdx.x] = sum of the input elements visited by this block.
// Launch with blockDim.x == BLOCK_SIZE (a power of two in [32, 1024]) and any gridDim.x;
// the gridDim.x partial sums in `output` still have to be added up (host or second pass).
template<int BLOCK_SIZE>
__global__ void reduce_sum_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    static_assert(BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024 &&
                  (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
                  "BLOCK_SIZE must be a power of two in [32, 1024]");

    __shared__ float sdata[BLOCK_SIZE];

    const int tid = threadIdx.x;
    // 64-bit indices: idx + BLOCK_SIZE and idx + grid_size must not overflow for n near INT_MAX.
    long long idx = static_cast<long long>(blockIdx.x) * BLOCK_SIZE * 2 + tid;
    const long long grid_size = static_cast<long long>(BLOCK_SIZE) * 2 * gridDim.x;

    // Grid-stride loop; each iteration adds two elements BLOCK_SIZE apart.
    float sum = 0.0f;
    while (idx < n) {
        sum += input[idx];
        if (idx + BLOCK_SIZE < n) {
            sum += input[idx + BLOCK_SIZE];
        }
        idx += grid_size;
    }

    sdata[tid] = sum;
    __syncthreads();

    // Tree reduction in shared memory down to 64 partial sums.
    if (BLOCK_SIZE >= 1024) {
        if (tid < 512) sdata[tid] += sdata[tid + 512];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 512) {
        if (tid < 256) sdata[tid] += sdata[tid + 256];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 256) {
        if (tid < 128) sdata[tid] += sdata[tid + 128];
        __syncthreads();
    }
    if (BLOCK_SIZE >= 128) {
        if (tid < 64) sdata[tid] += sdata[tid + 64];
        __syncthreads();
    }

    // Final warp. The classic "volatile shared memory, no sync" idiom relies on implicit
    // warp synchrony, which Volta+ independent thread scheduling does not guarantee;
    // it also read past sdata when BLOCK_SIZE == 32. Finish in registers with shuffles.
    if (tid < 32) {
        float v = sdata[tid];
        if (BLOCK_SIZE >= 64) v += sdata[tid + 32];
        #pragma unroll
        for (int offset = 16; offset > 0; offset /= 2) {
            v += __shfl_down_sync(0xffffffff, v, offset);
        }
        if (tid == 0) {
            output[blockIdx.x] = v;
        }
    }
}

// Whole-array sum accumulated into *output with one atomicAdd per block.
// Preconditions: *output is zeroed before launch (e.g. cudaMemsetAsync) and
// blockDim.x == BLOCK_SIZE. Because block totals are added atomically in arbitrary
// order, the last bits of the float result may vary between runs.
template<int BLOCK_SIZE>
__global__ void reduce_sum_shuffle(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    const int stride = BLOCK_SIZE * gridDim.x;
    float partial = 0.0f;

    // Each thread strides over the array by the total number of launched threads.
    int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    while (i < n) {
        partial += input[i];
        i += stride;
    }

    // Only thread 0 holds the complete block total after the block reduction.
    const float block_total = block_reduce_sum<BLOCK_SIZE>(partial);
    if (threadIdx.x == 0) {
        atomicAdd(output, block_total);
    }
}

多维 Reduce (Softmax 预处理)

// Per-row softmax statistics for a row-major [rows, cols] matrix
// (e.g. input: [batch, seq_len]):
//   row_max[r] = max_j input[r, j]
//   row_sum[r] = sum_j exp(input[r, j] - row_max[r])
// Launch: one block per row (gridDim.x >= rows), blockDim.x == BLOCK_SIZE.
// Each row is read twice: once for the max, once for the exp-sum.
template<int BLOCK_SIZE>
__global__ void row_reduce_max_sum(
    const float* __restrict__ input,
    float* __restrict__ row_max,
    float* __restrict__ row_sum,
    int rows, int cols
) {
    const int row = blockIdx.x;
    if (row >= rows) return;  // uniform across the block, so safe before the barriers

    const float* x = input + row * cols;

    // Pass 1: thread-local maximum, then block maximum (only thread 0's result is used).
    float thread_max = -INFINITY;
    for (int j = threadIdx.x; j < cols; j += BLOCK_SIZE) {
        thread_max = fmaxf(thread_max, x[j]);
    }
    __shared__ float smem_max[32];
    float max_val = block_reduce_max<BLOCK_SIZE>(thread_max, smem_max);

    // Thread 0 publishes the max to global memory and to the rest of the block.
    __shared__ float shared_max;
    if (threadIdx.x == 0) {
        shared_max = max_val;
        row_max[row] = max_val;
    }
    __syncthreads();
    max_val = shared_max;

    // Pass 2: sum of exp(x - max).
    float thread_sum = 0.0f;
    for (int j = threadIdx.x; j < cols; j += BLOCK_SIZE) {
        thread_sum += expf(x[j] - max_val);
    }
    __shared__ float smem_sum[32];
    const float sum_val = block_reduce_sum<BLOCK_SIZE>(thread_sum, smem_sum);

    if (threadIdx.x == 0) {
        row_sum[row] = sum_val;
    }
}

Softmax 实现

基础 Softmax

// Softmax: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x - max(x)))
// Row-wise over a row-major [rows, cols] matrix in three passes per row:
//   1) row max, 2) sum of exp(x - max), 3) write exp(x - max) / sum.
// Launch: one block per row (gridDim.x >= rows), blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE>
__global__ void softmax_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int rows, int cols
) {
    const int row = blockIdx.x;
    if (row >= rows) return;

    const float* src = input + row * cols;
    float* dst = output + row * cols;

    __shared__ float s_max;
    __shared__ float s_sum;

    // Pass 1: row maximum; thread 0's block result is broadcast through shared memory.
    float m = -INFINITY;
    for (int j = threadIdx.x; j < cols; j += BLOCK_SIZE) {
        m = fmaxf(m, src[j]);
    }
    m = block_reduce_max<BLOCK_SIZE>(m);
    if (threadIdx.x == 0) s_max = m;
    __syncthreads();
    const float row_max = s_max;

    // Pass 2: sum of shifted exponentials, broadcast the same way.
    float s = 0.0f;
    for (int j = threadIdx.x; j < cols; j += BLOCK_SIZE) {
        s += expf(src[j] - row_max);
    }
    s = block_reduce_sum<BLOCK_SIZE>(s);
    if (threadIdx.x == 0) s_sum = s;
    __syncthreads();
    const float inv_sum = 1.0f / s_sum;

    // Pass 3: normalize.
    for (int j = threadIdx.x; j < cols; j += BLOCK_SIZE) {
        dst[j] = expf(src[j] - row_max) * inv_sum;
    }
}

Online Softmax (单次遍历)

/*
 * Online softmax algorithm:
 * Traverse the data once, computing max and sum simultaneously.
 * Key idea: whenever a new maximum is found, rescale the sum accumulated so far.
 *
 * Recurrence:
 * m_new = max(m_old, x_i)
 * d_new = d_old * exp(m_old - m_new) + exp(x_i - m_new)
 */

// Partial softmax statistics over the elements seen so far:
//   max_val = running maximum, sum = sum of exp(x - max_val).
// The identity (no elements) is {-INFINITY, 0}.
struct SoftmaxState {
    float max_val;
    float sum;
};

// Merge two partial states; usable in tree/shuffle reductions.
// If both maxima are -INFINITY (empty states, or only -inf inputs such as masked
// attention scores), the generic formula would evaluate exp(-inf - (-inf)) = exp(NaN)
// and turn the whole reduction into NaN. In that case the sums are simply added; they
// are scaled to zero by exp(-inf - finite_max) once a finite element is merged in.
__host__ __device__ __forceinline__ SoftmaxState softmax_state_combine(
    SoftmaxState a, SoftmaxState b
) {
    const float max_new = fmaxf(a.max_val, b.max_val);
    if (max_new == -INFINITY) {
        return {max_new, a.sum + b.sum};
    }
    const float sum_new = a.sum * expf(a.max_val - max_new) +
                          b.sum * expf(b.max_val - max_new);
    return {max_new, sum_new};
}

// Warp-wide merge of SoftmaxState values via shuffle-down (offsets 16, 8, 4, 2, 1).
// All 32 lanes must participate; lane 0 ends up with the state for the whole warp.
__device__ __forceinline__ SoftmaxState warp_reduce_softmax(SoftmaxState state) {
    #pragma unroll
    for (int delta = 16; delta != 0; delta >>= 1) {
        const float peer_max = __shfl_down_sync(0xffffffff, state.max_val, delta);
        const float peer_sum = __shfl_down_sync(0xffffffff, state.sum, delta);
        const SoftmaxState peer = {peer_max, peer_sum};
        state = softmax_state_combine(state, peer);
    }
    return state;
}

// Row-wise softmax over a row-major [rows, cols] matrix using online (max, sum)
// statistics: one pass accumulates the statistics, a second pass writes
// exp(x - max) / sum.
// Launch: one block per row (gridDim.x >= rows), blockDim.x == BLOCK_SIZE,
// where BLOCK_SIZE is a multiple of 32 in [32, 1024].
// NOTE: threads with no elements and unused warp slots carry the identity
// {-INFINITY, 0}; softmax_state_combine must merge two such states without NaN.
template<int BLOCK_SIZE>
__global__ void softmax_online(
    const float* __restrict__ input,
    float* __restrict__ output,
    int rows, int cols
) {
    static_assert(BLOCK_SIZE % 32 == 0 && BLOCK_SIZE >= 32 && BLOCK_SIZE <= 1024,
                  "BLOCK_SIZE must be a multiple of 32 in [32, 1024]");

    const int row = blockIdx.x;
    if (row >= rows) return;  // uniform per block, safe before the barriers

    // 64-bit row offset: row * cols can exceed INT_MAX for large tensors.
    const size_t row_offset = static_cast<size_t>(row) * cols;
    const float* in_row = input + row_offset;
    float* out_row = output + row_offset;

    // Thread-local online reduction.
    SoftmaxState state = {-INFINITY, 0.0f};
    for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
        const SoftmaxState new_state = {in_row[i], 1.0f};
        state = softmax_state_combine(state, new_state);
    }

    // Block reduce: warp level first, lane 0 of each warp publishes its state.
    __shared__ float s_max[32];
    __shared__ float s_sum[32];

    const int lane = threadIdx.x % 32;
    const int wid = threadIdx.x / 32;

    state = warp_reduce_softmax(state);

    if (lane == 0) {
        s_max[wid] = state.max_val;
        s_sum[wid] = state.sum;
    }
    __syncthreads();

    // Warp 0 merges the per-warp states; slots beyond the last warp get the identity.
    if (threadIdx.x < BLOCK_SIZE / 32) {
        state.max_val = s_max[lane];
        state.sum = s_sum[lane];
    } else {
        state.max_val = -INFINITY;
        state.sum = 0.0f;
    }

    if (wid == 0) {
        state = warp_reduce_softmax(state);
    }

    // Broadcast the final statistics from thread 0.
    __shared__ float final_max, final_sum;
    if (threadIdx.x == 0) {
        final_max = state.max_val;
        final_sum = state.sum;
    }
    __syncthreads();

    // Normalize.
    const float row_max = final_max;
    const float inv_sum = 1.0f / final_sum;
    for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
        out_row[i] = expf(in_row[i] - row_max) * inv_sum;
    }
}

Fused Softmax (与其他操作融合)

// Fused attention-score softmax followed by multiplication with V:
//   output[b, h, q, :] = softmax(scores[b, h, q, :]) @ V[b, h, :, :]
// Tensors are dense row-major: scores [B, H, S, S], V and output [B, H, S, HEAD_DIM].
// Launch: gridDim.x = batch_size * num_heads, gridDim.y = seq_len (one block per
// query row), blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE, int HEAD_DIM>
__global__ void fused_softmax_v_kernel(
    const float* __restrict__ scores,  // [B, H, S, S]
    const float* __restrict__ V,       // [B, H, S, D]
    float* __restrict__ output,        // [B, H, S, D]
    int batch_size, int num_heads, int seq_len
) {
    const int batch_head = blockIdx.x;
    const int query_pos = blockIdx.y;

    const int b = batch_head / num_heads;
    const int h = batch_head % num_heads;

    if (b >= batch_size || query_pos >= seq_len) return;  // uniform per block

    // 64-bit offsets: B*H*S*S easily exceeds INT_MAX (e.g. 32*32*4096*4096).
    const size_t head_index = static_cast<size_t>(b) * num_heads + h;   // (b, h)
    const size_t row_index = head_index * seq_len + query_pos;          // (b, h, q)
    const float* score_row = scores + row_index * seq_len;              // [seq_len]
    const float* v_head = V + head_index * seq_len * HEAD_DIM;          // [seq_len, D]
    float* out_row = output + row_index * HEAD_DIM;                     // [D]

    // Online softmax statistics over the score row.
    SoftmaxState state = {-INFINITY, 0.0f};
    for (int k = threadIdx.x; k < seq_len; k += BLOCK_SIZE) {
        const SoftmaxState new_state = {score_row[k], 1.0f};
        state = softmax_state_combine(state, new_state);
    }

    // Block reduce (helper not shown here; only thread 0's result is consumed below).
    state = block_reduce_softmax<BLOCK_SIZE>(state);

    __shared__ float s_max, s_sum;
    if (threadIdx.x == 0) {
        s_max = state.max_val;
        s_sum = state.sum;
    }
    __syncthreads();

    const float row_max = s_max;
    const float inv_sum = 1.0f / s_sum;

    // Weighted sum output = softmax(scores) @ V: each thread produces output elements
    // d = threadIdx.x, threadIdx.x + BLOCK_SIZE, ...
    // NOTE: the softmax weight of each k is recomputed for every d.
    for (int d = threadIdx.x; d < HEAD_DIM; d += BLOCK_SIZE) {
        float acc = 0.0f;
        for (int k = 0; k < seq_len; k++) {
            const float attn = expf(score_row[k] - row_max) * inv_sum;
            const float v_val = v_head[static_cast<size_t>(k) * HEAD_DIM + d];
            acc += attn * v_val;
        }
        out_row[d] = acc;
    }
}

LayerNorm 实现

基础 LayerNorm

// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta
// Over the last dimension of a row-major [batch_size, hidden_size] tensor, with the
// biased variance computed in a separate (two-pass) sweep.
// Launch: one block per row (gridDim.x >= batch_size), blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE>
__global__ void layer_norm_kernel(
    const float* __restrict__ input,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ output,
    int batch_size, int hidden_size,
    float eps
) {
    const int batch_idx = blockIdx.x;
    if (batch_idx >= batch_size) return;

    const float* x = input + batch_idx * hidden_size;
    float* y = output + batch_idx * hidden_size;

    __shared__ float s_mean;
    __shared__ float s_var;

    // Pass 1: mean (thread 0's block result is broadcast via shared memory).
    float partial = 0.0f;
    for (int j = threadIdx.x; j < hidden_size; j += BLOCK_SIZE) {
        partial += x[j];
    }
    const float block_mean = block_reduce_sum<BLOCK_SIZE>(partial) / hidden_size;
    if (threadIdx.x == 0) s_mean = block_mean;
    __syncthreads();
    const float mean = s_mean;

    // Pass 2: variance around the mean.
    float sq = 0.0f;
    for (int j = threadIdx.x; j < hidden_size; j += BLOCK_SIZE) {
        const float centered = x[j] - mean;
        sq += centered * centered;
    }
    const float block_var = block_reduce_sum<BLOCK_SIZE>(sq) / hidden_size;
    if (threadIdx.x == 0) s_var = block_var;
    __syncthreads();

    // Pass 3: normalize, scale, shift.
    const float inv_std = rsqrtf(s_var + eps);
    for (int j = threadIdx.x; j < hidden_size; j += BLOCK_SIZE) {
        y[j] = (x[j] - mean) * inv_std * gamma[j] + beta[j];
    }
}

Welford 算法 (数值稳定)

/*
 * Welford's online algorithm for computing mean and variance.
 * Numerically more stable: avoids the precision loss of subtracting large,
 * nearly equal numbers.
 */

// Partial Welford statistics over `count` elements.
struct WelfordState {
    float mean;
    float m2;    // sum of squared deviations from the mean
    float count; // number of elements (stored as float)
};

// Merge two partial states (pairwise/parallel Welford update).
// Either side may be empty (count == 0) — e.g. threads that received no elements or
// padding lanes in a reduction. The other state is then returned unchanged instead of
// evaluating 0 / 0, which would poison the whole reduction with NaN.
__host__ __device__ __forceinline__ WelfordState welford_combine(
    WelfordState a, WelfordState b
) {
    if (b.count == 0.0f) return a;
    if (a.count == 0.0f) return b;

    const float count = a.count + b.count;
    const float delta = b.mean - a.mean;
    const float mean = a.mean + delta * b.count / count;
    const float m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count;
    return {mean, m2, count};
}

// Warp-wide Welford merge via shuffle-down (offsets 16, 8, 4, 2, 1).
// Every lane of the warp must participate; lane 0 ends up holding the merged state
// of all 32 lanes.
__device__ __forceinline__ WelfordState warp_reduce_welford(WelfordState state) {
    constexpr unsigned kFullWarp = 0xffffffffu;
    #pragma unroll
    for (int delta = 16; delta >= 1; delta >>= 1) {
        const float peer_mean  = __shfl_down_sync(kFullWarp, state.mean, delta);
        const float peer_m2    = __shfl_down_sync(kFullWarp, state.m2, delta);
        const float peer_count = __shfl_down_sync(kFullWarp, state.count, delta);
        WelfordState peer;
        peer.mean = peer_mean;
        peer.m2 = peer_m2;
        peer.count = peer_count;
        state = welford_combine(state, peer);
    }
    return state;
}

// LayerNorm over a row-major [batch_size, hidden_size] tensor. Mean and variance are
// computed in one Welford pass, followed by a normalization pass.
// mean_out / rstd_out are optional (pass nullptr to skip): they receive the per-row
// mean and 1 / sqrt(var + eps), with var the biased variance m2 / count.
// Launch: one block per row (gridDim.x >= batch_size), blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE>
__global__ void layer_norm_welford(
    const float* __restrict__ input,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ output,
    float* __restrict__ mean_out,      // optional: per-row mean
    float* __restrict__ rstd_out,      // optional: per-row reciprocal std
    int batch_size, int hidden_size,
    float eps
) {
    const int batch_idx = blockIdx.x;
    if (batch_idx >= batch_size) return;

    const float* x = input + batch_idx * hidden_size;
    float* y = output + batch_idx * hidden_size;

    // Per-thread Welford accumulation, one element at a time.
    WelfordState acc = {0.0f, 0.0f, 0.0f};
    for (int j = threadIdx.x; j < hidden_size; j += BLOCK_SIZE) {
        const WelfordState single = {x[j], 0.0f, 1.0f};
        acc = welford_combine(acc, single);
    }

    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;
    __shared__ float s_mean[32], s_m2[32], s_count[32];

    // Warp-level merge; lane 0 of each warp publishes its partial state.
    acc = warp_reduce_welford(acc);
    if (lane == 0) {
        s_mean[warp] = acc.mean;
        s_m2[warp] = acc.m2;
        s_count[warp] = acc.count;
    }
    __syncthreads();

    // Warp 0 merges the per-warp partials; slots past the last warp are empty states.
    const bool has_slot = threadIdx.x < BLOCK_SIZE / 32;
    acc.mean = has_slot ? s_mean[lane] : 0.0f;
    acc.m2 = has_slot ? s_m2[lane] : 0.0f;
    acc.count = has_slot ? s_count[lane] : 0.0f;

    if (warp == 0) {
        acc = warp_reduce_welford(acc);
    }

    // Thread 0 finalizes the statistics and broadcasts them through shared memory.
    __shared__ float final_mean, final_rstd;
    if (threadIdx.x == 0) {
        const float variance = acc.m2 / acc.count;
        final_mean = acc.mean;
        final_rstd = rsqrtf(variance + eps);

        if (mean_out != nullptr) mean_out[batch_idx] = final_mean;
        if (rstd_out != nullptr) rstd_out[batch_idx] = final_rstd;
    }
    __syncthreads();

    // Normalize, scale, shift.
    const float mean = final_mean;
    const float rstd = final_rstd;
    for (int j = threadIdx.x; j < hidden_size; j += BLOCK_SIZE) {
        y[j] = (x[j] - mean) * rstd * gamma[j] + beta[j];
    }
}

向量化 LayerNorm

// LayerNorm with float4-vectorized loads and stores.
// input/output are float4 views of row-major [batch_size, hidden_size] data;
// gamma/beta are float4 views of [hidden_size]. hidden_size is the original
// (scalar) width.
// Precondition: hidden_size % 4 == 0. Otherwise the trailing hidden_size % 4
// elements of each row are ignored, yet still counted in the mean/variance divisor.
// Launch: one block per row (gridDim.x >= batch_size), blockDim.x == BLOCK_SIZE.
template<int BLOCK_SIZE>
__global__ void layer_norm_vectorized(
    const float4* __restrict__ input,   // hidden_size / 4 per row
    const float4* __restrict__ gamma,
    const float4* __restrict__ beta,
    float4* __restrict__ output,
    int batch_size, int hidden_size,    // original hidden_size
    float eps
) {
    const int batch_idx = blockIdx.x;
    if (batch_idx >= batch_size) return;

    const int n_vec = hidden_size / 4;  // float4 elements per row
    const float4* x = input + batch_idx * n_vec;
    float4* y = output + batch_idx * n_vec;

    __shared__ float s_mean;
    __shared__ float s_rstd;

    // Pass 1: mean (thread 0's block result is broadcast via shared memory).
    float partial = 0.0f;
    for (int j = threadIdx.x; j < n_vec; j += BLOCK_SIZE) {
        const float4 v = x[j];
        partial += v.x + v.y + v.z + v.w;
    }
    const float block_mean = block_reduce_sum<BLOCK_SIZE>(partial) / hidden_size;
    if (threadIdx.x == 0) s_mean = block_mean;
    __syncthreads();
    const float mean = s_mean;

    // Pass 2: variance, turned into 1 / sqrt(var + eps) by thread 0.
    float sq = 0.0f;
    for (int j = threadIdx.x; j < n_vec; j += BLOCK_SIZE) {
        const float4 v = x[j];
        const float dx = v.x - mean, dy = v.y - mean;
        const float dz = v.z - mean, dw = v.w - mean;
        sq += dx*dx + dy*dy + dz*dz + dw*dw;
    }
    const float block_var = block_reduce_sum<BLOCK_SIZE>(sq) / hidden_size;
    if (threadIdx.x == 0) s_rstd = rsqrtf(block_var + eps);
    __syncthreads();
    const float rstd = s_rstd;

    // Pass 3: normalize, scale, shift, one float4 at a time.
    for (int j = threadIdx.x; j < n_vec; j += BLOCK_SIZE) {
        const float4 v = x[j];
        const float4 g = gamma[j];
        const float4 sh = beta[j];

        float4 r;
        r.x = (v.x - mean) * rstd * g.x + sh.x;
        r.y = (v.y - mean) * rstd * g.y + sh.y;
        r.z = (v.z - mean) * rstd * g.z + sh.z;
        r.w = (v.w - mean) * rstd * g.w + sh.w;

        y[j] = r;
    }
}

性能优化总结

优化策略对比

| 优化技术 | 适用场景 | 收益 |
|---|---|---|
| 共享内存 Tiling | 数据复用高的算子 | 减少全局内存访问 |
| 寄存器分块 | 计算密集型 | 最大化计算吞吐 |
| 向量化访问 | 内存密集型 | 提高带宽利用率 |
| Warp Shuffle | Reduce 类操作 | 减少同步开销 |
| 双缓冲 | 内存延迟敏感 | 隐藏访存延迟 |
| 循环展开 | 小循环体 | 减少指令开销 |

Kernel 开发检查清单

□ 内存访问是否合并?
□ 共享内存 Bank Conflict?
□ 是否有 Warp Divergence?
□ 寄存器压力是否过大?
□ 占用率是否足够?
□ 是否可以向量化?
□ 是否可以融合其他操作?
Prev
01-CUDA编程模型与内存层次
Next
03-Tensor Core 与矩阵运算