02-高性能 Kernel 开发实战
概述
本文通过实现深度学习中最核心的算子(GEMM、Reduce、Softmax、LayerNorm),深入讲解高性能 CUDA Kernel 开发的核心技术和优化策略。
GEMM (通用矩阵乘法)
朴素实现
// 最基础的矩阵乘法
// C[M, N] = A[M, K] * B[K, N]
__global__ void gemm_naive(
const float* A,
const float* B,
float* C,
int M, int N, int K
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; k++) {
sum += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = sum;
}
}
// 问题分析:
// 1. 每个线程独立加载 A 的一行和 B 的一列
// 2. 大量重复的全局内存访问
// 3. 计算访存比极低
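在不考虑缓存的前提下,朴素实现每个输出元素要做 2K 次浮点运算,却要从全局内存读取 2K 个 float(8K 字节),计算访存比仅约 0.25 FLOP/Byte。启动配置示意(d_A、d_B、d_C 为假设的已分配设备指针):
// 16x16 线程块,网格在 x 方向覆盖 N 列、y 方向覆盖 M 行
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
gemm_naive<<<grid, block>>>(d_A, d_B, d_C, M, N, K);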
Tiled GEMM (共享内存优化)
#define TILE_SIZE 32
__global__ void gemm_tiled(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
int row = by * TILE_SIZE + ty;
int col = bx * TILE_SIZE + tx;
float sum = 0.0f;
// 分块迭代
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
// 加载 A tile
if (row < M && t * TILE_SIZE + tx < K) {
As[ty][tx] = A[row * K + t * TILE_SIZE + tx];
} else {
As[ty][tx] = 0.0f;
}
// 加载 B tile
if (t * TILE_SIZE + ty < K && col < N) {
Bs[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
} else {
Bs[ty][tx] = 0.0f;
}
__syncthreads();
// 计算部分和
#pragma unroll
for (int k = 0; k < TILE_SIZE; k++) {
sum += As[ty][k] * Bs[k][tx];
}
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = sum;
}
}
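分块后 A、B 的每个元素在 block 内被 TILE_SIZE 个线程复用,全局内存访问量约降为朴素版的 1/TILE_SIZE;每个 block 占用 2 × 32 × 32 × 4 B = 8 KB 共享内存。启动示意(d_A 等同上为假设的设备指针):
dim3 block(TILE_SIZE, TILE_SIZE); // 32x32 = 1024 线程
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
gemm_tiled<<<grid, block>>>(d_A, d_B, d_C, M, N, K);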
寄存器分块 + 向量化 (高级优化)
/*
* 高性能 GEMM 优化策略:
* 1. Thread Block Tile: Block 级别的数据复用
* 2. Warp Tile: Warp 级别的数据复用
* 3. Thread Tile: 每个线程计算多个输出元素
* 4. 双缓冲: 隐藏内存延迟
* 5. 向量化加载: 提高内存带宽利用率
*/
// Block Tile: 128x128
// Warp Tile: 64x32
// Thread Tile: 8x8
// 每个 Block: (BM/TM) x (BN/TN) = 16 x 16 = 256 threads (8 warps)
#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8
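按上述参数,每个线程计算一个 TM × TN = 8×8 的输出子块,因此每个 block 需要 (BM/TM) × (BN/TN) = 16 × 16 = 256 个线程。启动示意(假设 N、K 为 4 的倍数以满足 float4 对齐,否则需退回标量路径):
dim3 block(16, 16); // 256 线程,kernel 内 tid = ty * 16 + tx
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
gemm_optimized<<<grid, block>>>(d_A, d_B, d_C, M, N, K);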
__global__ void gemm_optimized(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {
// 共享内存
__shared__ float As[BK][BM]; // 转置存储避免 bank conflict
__shared__ float Bs[BK][BN];
// 线程在 block 内的位置
int tx = threadIdx.x;
int ty = threadIdx.y;
int tid = ty * blockDim.x + tx;
// Block 在 grid 中的位置
int bx = blockIdx.x;
int by = blockIdx.y;
// 每个线程负责的寄存器 tile
float reg_C[TM][TN] = {0.0f};
float reg_A[TM];
float reg_B[TN];
// 计算加载位置
int load_a_row = tid / (BK / 4); // 每次加载 float4
int load_a_col = (tid % (BK / 4)) * 4;
int load_b_row = tid / (BN / 4);
int load_b_col = (tid % (BN / 4)) * 4;
// 计算输出位置
int thread_row = (tid / (BN / TN)) * TM;
int thread_col = (tid % (BN / TN)) * TN;
// 主循环
for (int bk = 0; bk < K; bk += BK) {
// 向量化加载 A 到共享内存(转置写入)
// float4 加载要求 16 字节对齐,这里假设 K、N 为 4 的倍数
int a_row = by * BM + load_a_row;
int a_col = bk + load_a_col;
if (a_row < M && a_col + 3 < K) {
float4 tmp = reinterpret_cast<const float4*>(
&A[a_row * K + a_col])[0];
As[load_a_col + 0][load_a_row] = tmp.x;
As[load_a_col + 1][load_a_row] = tmp.y;
As[load_a_col + 2][load_a_row] = tmp.z;
As[load_a_col + 3][load_a_row] = tmp.w;
} else {
// 越界部分补零,避免读到上一轮迭代的脏数据
As[load_a_col + 0][load_a_row] = 0.0f;
As[load_a_col + 1][load_a_row] = 0.0f;
As[load_a_col + 2][load_a_row] = 0.0f;
As[load_a_col + 3][load_a_row] = 0.0f;
}
// 向量化加载 B 到共享内存
int b_row = bk + load_b_row;
int b_col = bx * BN + load_b_col;
if (b_row < K && b_col + 3 < N) {
float4 tmp = reinterpret_cast<const float4*>(
&B[b_row * N + b_col])[0];
Bs[load_b_row][load_b_col + 0] = tmp.x;
Bs[load_b_row][load_b_col + 1] = tmp.y;
Bs[load_b_row][load_b_col + 2] = tmp.z;
Bs[load_b_row][load_b_col + 3] = tmp.w;
} else {
Bs[load_b_row][load_b_col + 0] = 0.0f;
Bs[load_b_row][load_b_col + 1] = 0.0f;
Bs[load_b_row][load_b_col + 2] = 0.0f;
Bs[load_b_row][load_b_col + 3] = 0.0f;
}
__syncthreads();
// 计算 thread tile
#pragma unroll
for (int k = 0; k < BK; k++) {
// 加载到寄存器
#pragma unroll
for (int m = 0; m < TM; m++) {
reg_A[m] = As[k][thread_row + m];
}
#pragma unroll
for (int n = 0; n < TN; n++) {
reg_B[n] = Bs[k][thread_col + n];
}
// 外积累加
#pragma unroll
for (int m = 0; m < TM; m++) {
#pragma unroll
for (int n = 0; n < TN; n++) {
reg_C[m][n] += reg_A[m] * reg_B[n];
}
}
}
__syncthreads();
}
// 写回结果(float4 写回同样假设 N 为 4 的倍数)
int c_row = by * BM + thread_row;
int c_col = bx * BN + thread_col;
#pragma unroll
for (int m = 0; m < TM; m++) {
#pragma unroll
for (int n = 0; n < TN; n += 4) {
if (c_row + m < M && c_col + n + 3 < N) {
float4 tmp = make_float4(
reg_C[m][n + 0],
reg_C[m][n + 1],
reg_C[m][n + 2],
reg_C[m][n + 3]
);
reinterpret_cast<float4*>(
&C[(c_row + m) * N + c_col + n])[0] = tmp;
}
}
}
}
双缓冲优化
/*
* 双缓冲技术:
* 使用两组共享内存缓冲区,在计算当前 tile 时预取下一个 tile
* 隐藏全局内存访问延迟
*/
#define BM 128
#define BN 128
#define BK 8
__global__ void gemm_double_buffer(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {
// 双缓冲共享内存
__shared__ float As[2][BK][BM];
__shared__ float Bs[2][BK][BN];
int tx = threadIdx.x;
int ty = threadIdx.y;
int tid = ty * blockDim.x + tx;
int bx = blockIdx.x;
int by = blockIdx.y;
float reg_C[8][8] = {0.0f};
int write_stage = 0;
// 预取第一个 tile 到缓冲区 0
// (load_tile_async / compute_tile / store_result 为示意性的辅助函数)
load_tile_async(A, B, As[write_stage], Bs[write_stage], by, bx, 0, tid, M, N, K);
__syncthreads();
// 主循环
for (int bk = BK; bk < K + BK; bk += BK) {
// 刚写满的缓冲区用于计算,另一组缓冲区接收预取数据
int read_stage = write_stage;
write_stage = 1 - write_stage;
// 预取下一个 tile(写入与计算使用不同缓冲区,互不冲突)
if (bk < K) {
load_tile_async(A, B, As[write_stage], Bs[write_stage],
by, bx, bk, tid, M, N, K);
}
// 计算当前 tile
compute_tile(As[read_stage], Bs[read_stage], reg_C);
__syncthreads();
}
// 写回结果
store_result(C, reg_C, by, bx, tid, M, N);
}
// 使用 CUDA 11+ 的 cp.async 异步拷贝
__device__ void load_tile_async_v2(
const float* __restrict__ src,
float* __restrict__ dst, // 必须指向共享内存
int size // 字节数,本示例固定拷贝 16B(1 个 float4)
) {
#if __CUDA_ARCH__ >= 800
// cp.async 的目标操作数必须是共享内存地址空间的 32 位地址
unsigned dst_smem = (unsigned)__cvta_generic_to_shared(dst);
asm volatile(
"cp.async.ca.shared.global [%0], [%1], %2;\n"
:
: "r"(dst_smem), "l"(src), "n"(16) // 16 bytes = 4 floats
);
#else
// Fallback:退化为同步的向量化拷贝
*reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
#endif
}
// 等待异步拷贝完成
__device__ void wait_async_copy() {
#if __CUDA_ARCH__ >= 800
// wait_all 等价于 cp.async.commit_group + cp.async.wait_group 0
asm volatile("cp.async.wait_all;");
#endif
}
Reduce 操作
Warp Shuffle Reduce
// Warp 内归约 - 使用 shuffle 指令
__device__ __forceinline__ float warp_reduce_sum(float val) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return val;
}
__device__ __forceinline__ float warp_reduce_max(float val) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
}
return val;
}
// Block 内归约
template<int BLOCK_SIZE>
__device__ float block_reduce_sum(float val) {
static __shared__ float shared[32]; // 32 warps max
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
// Warp 内归约
val = warp_reduce_sum(val);
// Warp leader 写入共享内存
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
// 第一个 warp 做最终归约
val = (threadIdx.x < BLOCK_SIZE / 32) ? shared[lane] : 0.0f;
if (wid == 0) {
val = warp_reduce_sum(val);
}
return val;
}
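下文多个 kernel(row_reduce_max_sum、softmax_kernel)还会用到 block 级的 max 归约。原文未单独给出,这里按 block_reduce_sum 的同样结构补充一份(空位用 -INFINITY 填充):
template<int BLOCK_SIZE>
__device__ float block_reduce_max(float val) {
static __shared__ float shared[32]; // 32 warps max
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
// Warp 内归约
val = warp_reduce_max(val);
// Warp leader 写入共享内存
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
// 第一个 warp 做最终归约
val = (threadIdx.x < BLOCK_SIZE / 32) ? shared[lane] : -INFINITY;
if (wid == 0) {
val = warp_reduce_max(val);
}
return val;
}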
完整的 Reduce Kernel
// 向量归约 sum
template<int BLOCK_SIZE>
__global__ void reduce_sum_kernel(
const float* __restrict__ input,
float* __restrict__ output,
int n
) {
__shared__ float sdata[BLOCK_SIZE];
int tid = threadIdx.x;
int idx = blockIdx.x * BLOCK_SIZE * 2 + threadIdx.x;
int grid_size = BLOCK_SIZE * 2 * gridDim.x;
// Grid-stride loop,每个线程累加多个元素
float sum = 0.0f;
while (idx < n) {
sum += input[idx];
if (idx + BLOCK_SIZE < n) {
sum += input[idx + BLOCK_SIZE];
}
idx += grid_size;
}
sdata[tid] = sum;
__syncthreads();
// Block 内归约
if (BLOCK_SIZE >= 512) {
if (tid < 256) sdata[tid] += sdata[tid + 256];
__syncthreads();
}
if (BLOCK_SIZE >= 256) {
if (tid < 128) sdata[tid] += sdata[tid + 128];
__syncthreads();
}
if (BLOCK_SIZE >= 128) {
if (tid < 64) sdata[tid] += sdata[tid + 64];
__syncthreads();
}
// 最后 32 个线程:Volta 及以后架构 warp 内也可能发散,
// 除 volatile 外还需 __syncwarp() 保证每一步的读写顺序
if (tid < 32) {
volatile float* smem = sdata;
if (BLOCK_SIZE >= 64) { smem[tid] += smem[tid + 32]; __syncwarp(); }
if (BLOCK_SIZE >= 32) { smem[tid] += smem[tid + 16]; __syncwarp(); }
if (BLOCK_SIZE >= 16) { smem[tid] += smem[tid + 8]; __syncwarp(); }
if (BLOCK_SIZE >= 8) { smem[tid] += smem[tid + 4]; __syncwarp(); }
if (BLOCK_SIZE >= 4) { smem[tid] += smem[tid + 2]; __syncwarp(); }
if (BLOCK_SIZE >= 2) { smem[tid] += smem[tid + 1]; __syncwarp(); }
}
if (tid == 0) {
output[blockIdx.x] = sdata[0];
}
}
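reduce_sum_kernel 每个 block 只输出一个部分和,需要二次归约才能得到标量结果。主机端两段式调用示意(d_partial 为假设的中间缓冲,至少容纳 blocks 个 float):
int blocks = (n + 256 * 2 - 1) / (256 * 2);
if (blocks > 1024) blocks = 1024; // 限制 block 数,kernel 内的 stride 循环会消化剩余数据
reduce_sum_kernel<256><<<blocks, 256>>>(d_input, d_partial, n);
reduce_sum_kernel<256><<<1, 256>>>(d_partial, d_output, blocks);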
// 使用 Warp Shuffle 的优化版本
template<int BLOCK_SIZE>
__global__ void reduce_sum_shuffle(
const float* __restrict__ input,
float* __restrict__ output,
int n
) {
float sum = 0.0f;
// Grid-stride loop
for (int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
i < n;
i += BLOCK_SIZE * gridDim.x) {
sum += input[i];
}
// Block reduce
sum = block_reduce_sum<BLOCK_SIZE>(sum);
if (threadIdx.x == 0) {
atomicAdd(output, sum); // 原子加到最终结果(启动前需将 *output 清零)
}
}
多维 Reduce (Softmax 预处理)
// 对矩阵每一行求最大值,以及 exp(x - max) 的和(即 softmax 分母)
// input: [batch, seq_len]
template<int BLOCK_SIZE>
__global__ void row_reduce_max_sum(
const float* __restrict__ input,
float* __restrict__ row_max,
float* __restrict__ row_sum,
int rows, int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* row_ptr = input + row * cols;
// 每个线程处理多个元素
float local_max = -INFINITY;
float local_sum = 0.0f;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
float val = row_ptr[i];
local_max = fmaxf(local_max, val);
}
// Block reduce max(使用上文补充的 block_reduce_max)
float max_val = block_reduce_max<BLOCK_SIZE>(local_max);
// Broadcast max to all threads
__shared__ float shared_max;
if (threadIdx.x == 0) {
shared_max = max_val;
row_max[row] = max_val;
}
__syncthreads();
max_val = shared_max;
// 计算 exp 并求和
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
float val = row_ptr[i];
local_sum += expf(val - max_val);
}
// Block reduce sum
float sum_val = block_reduce_sum<BLOCK_SIZE>(local_sum);
if (threadIdx.x == 0) {
row_sum[row] = sum_val;
}
}
Softmax 实现
基础 Softmax
// Softmax: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x - max(x)))
template<int BLOCK_SIZE>
__global__ void softmax_kernel(
const float* __restrict__ input,
float* __restrict__ output,
int rows, int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* in_row = input + row * cols;
float* out_row = output + row * cols;
// Step 1: Find max
float local_max = -INFINITY;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
local_max = fmaxf(local_max, in_row[i]);
}
float max_val = block_reduce_max<BLOCK_SIZE>(local_max);
__shared__ float s_max;
if (threadIdx.x == 0) s_max = max_val;
__syncthreads();
max_val = s_max;
// Step 2: Compute exp and sum
float local_sum = 0.0f;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
local_sum += expf(in_row[i] - max_val);
}
float sum_val = block_reduce_sum<BLOCK_SIZE>(local_sum);
__shared__ float s_sum;
if (threadIdx.x == 0) s_sum = sum_val;
__syncthreads();
sum_val = s_sum;
// Step 3: Normalize
float inv_sum = 1.0f / sum_val;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
out_row[i] = expf(in_row[i] - max_val) * inv_sum;
}
}
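一个常见的小改进:Step 2 把 expf 结果先写入 out_row 缓存,Step 3 只做缩放,可以省掉一半 expf 调用。按此思路改写的完整变体如下(示意):
template<int BLOCK_SIZE>
__global__ void softmax_kernel_cached(
const float* __restrict__ input,
float* __restrict__ output,
int rows, int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* in_row = input + row * cols;
float* out_row = output + row * cols;
// Step 1: 求最大值(同上)
float local_max = -INFINITY;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
local_max = fmaxf(local_max, in_row[i]);
}
float max_val = block_reduce_max<BLOCK_SIZE>(local_max);
__shared__ float s_max;
if (threadIdx.x == 0) s_max = max_val;
__syncthreads();
max_val = s_max;
// Step 2: exp 结果写入 out_row 并求和
float local_sum = 0.0f;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
float e = expf(in_row[i] - max_val);
out_row[i] = e;
local_sum += e;
}
float sum_val = block_reduce_sum<BLOCK_SIZE>(local_sum);
__shared__ float s_sum;
if (threadIdx.x == 0) s_sum = sum_val;
__syncthreads();
// Step 3: 只做缩放
float inv_sum = 1.0f / s_sum;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
out_row[i] *= inv_sum;
}
}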
Online Softmax (单次遍历)
/*
* Online Softmax 算法:
* 只需一次遍历数据,同时计算 max 和 sum
* 核心思想:当发现新的最大值时,调整之前的累加和
*
* 递推公式:
* m_new = max(m_old, x_i)
* d_new = d_old * exp(m_old - m_new) + exp(x_i - m_new)
*/
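把递推式直接写成单线程代码,有助于理解后面的并行版本(仅为示意的参考实现):
// 单线程参考实现:一次遍历同时维护 m 和按新 max 调整后的 d
__host__ __device__ inline void online_softmax_ref(const float* x, float* y, int n) {
float m = -INFINITY, d = 0.0f;
for (int i = 0; i < n; i++) {
float m_new = fmaxf(m, x[i]);
d = d * expf(m - m_new) + expf(x[i] - m_new);
m = m_new;
}
float inv = 1.0f / d;
for (int i = 0; i < n; i++) {
y[i] = expf(x[i] - m) * inv;
}
}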
struct SoftmaxState {
float max_val;
float sum;
};
__device__ __forceinline__ SoftmaxState softmax_state_combine(
SoftmaxState a, SoftmaxState b
) {
// 空状态 (max = -inf, sum = 0) 直接返回另一方,
// 避免 expf(-inf - (-inf)) = NaN
if (a.max_val == -INFINITY) return b;
if (b.max_val == -INFINITY) return a;
float max_new = fmaxf(a.max_val, b.max_val);
float sum_new = a.sum * expf(a.max_val - max_new) +
b.sum * expf(b.max_val - max_new);
return {max_new, sum_new};
}
__device__ __forceinline__ SoftmaxState warp_reduce_softmax(SoftmaxState state) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
SoftmaxState other;
other.max_val = __shfl_down_sync(0xffffffff, state.max_val, offset);
other.sum = __shfl_down_sync(0xffffffff, state.sum, offset);
state = softmax_state_combine(state, other);
}
return state;
}
template<int BLOCK_SIZE>
__global__ void softmax_online(
const float* __restrict__ input,
float* __restrict__ output,
int rows, int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* in_row = input + row * cols;
float* out_row = output + row * cols;
// Online reduction
SoftmaxState state = {-INFINITY, 0.0f};
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
float x = in_row[i];
SoftmaxState new_state = {x, 1.0f};
state = softmax_state_combine(state, new_state);
}
// Block reduce
__shared__ float s_max[32];
__shared__ float s_sum[32];
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
state = warp_reduce_softmax(state);
if (lane == 0) {
s_max[wid] = state.max_val;
s_sum[wid] = state.sum;
}
__syncthreads();
if (threadIdx.x < BLOCK_SIZE / 32) {
state.max_val = s_max[lane];
state.sum = s_sum[lane];
} else {
state.max_val = -INFINITY;
state.sum = 0.0f;
}
if (wid == 0) {
state = warp_reduce_softmax(state);
}
__shared__ float final_max, final_sum;
if (threadIdx.x == 0) {
final_max = state.max_val;
final_sum = state.sum;
}
__syncthreads();
// Normalize
float inv_sum = 1.0f / final_sum;
for (int i = threadIdx.x; i < cols; i += BLOCK_SIZE) {
out_row[i] = expf(in_row[i] - final_max) * inv_sum;
}
}
Fused Softmax (与其他操作融合)
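本节 kernel 用到的 block_reduce_softmax 在原文中未单独给出,可以把 softmax_online 里的 block 归约封装成函数,示意如下(归约结果只保证在线程 0 上有效):
template<int BLOCK_SIZE>
__device__ SoftmaxState block_reduce_softmax(SoftmaxState state) {
__shared__ float s_max[32];
__shared__ float s_sum[32];
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
state = warp_reduce_softmax(state);
if (lane == 0) {
s_max[wid] = state.max_val;
s_sum[wid] = state.sum;
}
__syncthreads();
if (threadIdx.x < BLOCK_SIZE / 32) {
state.max_val = s_max[lane];
state.sum = s_sum[lane];
} else {
state.max_val = -INFINITY;
state.sum = 0.0f;
}
if (wid == 0) {
state = warp_reduce_softmax(state);
}
return state;
}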
// Fused Attention Score Softmax
// input: attention scores [batch, heads, seq, seq]
// output: softmax(input) * V
template<int BLOCK_SIZE, int HEAD_DIM>
__global__ void fused_softmax_v_kernel(
const float* __restrict__ scores, // [B, H, S, S]
const float* __restrict__ V, // [B, H, S, D]
float* __restrict__ output, // [B, H, S, D]
int batch_size, int num_heads, int seq_len
) {
int batch_head = blockIdx.x;
int query_pos = blockIdx.y;
int b = batch_head / num_heads;
int h = batch_head % num_heads;
if (b >= batch_size || query_pos >= seq_len) return;
// Score row: [seq_len]
const float* score_row = scores +
(b * num_heads * seq_len + h * seq_len + query_pos) * seq_len;
// Online softmax
SoftmaxState state = {-INFINITY, 0.0f};
for (int k = threadIdx.x; k < seq_len; k += BLOCK_SIZE) {
float s = score_row[k];
SoftmaxState new_state = {s, 1.0f};
state = softmax_state_combine(state, new_state);
}
// Block reduce
state = block_reduce_softmax<BLOCK_SIZE>(state);
__shared__ float s_max, s_sum;
if (threadIdx.x == 0) {
s_max = state.max_val;
s_sum = state.sum;
}
__syncthreads();
float inv_sum = 1.0f / s_sum;
// Compute weighted sum: output = softmax(scores) @ V
// Each thread handles one element of output
float* out_row = output +
(b * num_heads * seq_len + h * seq_len + query_pos) * HEAD_DIM;
for (int d = threadIdx.x; d < HEAD_DIM; d += BLOCK_SIZE) {
float acc = 0.0f;
for (int k = 0; k < seq_len; k++) {
float attn = expf(score_row[k] - s_max) * inv_sum;
float v_val = V[(b * num_heads * seq_len + h * seq_len + k) * HEAD_DIM + d];
acc += attn * v_val;
}
out_row[d] = acc;
}
}
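网格布局为:blockIdx.x 对应 (batch, head) 组合,blockIdx.y 对应 query 位置。启动示意(假设 HEAD_DIM = 64、BLOCK_SIZE = 128,d_* 为假设的设备指针):
dim3 grid(batch_size * num_heads, seq_len);
fused_softmax_v_kernel<128, 64><<<grid, 128>>>(
d_scores, d_V, d_output, batch_size, num_heads, seq_len);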
LayerNorm 实现
基础 LayerNorm
// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta
template<int BLOCK_SIZE>
__global__ void layer_norm_kernel(
const float* __restrict__ input,
const float* __restrict__ gamma,
const float* __restrict__ beta,
float* __restrict__ output,
int batch_size, int hidden_size,
float eps
) {
int batch_idx = blockIdx.x;
if (batch_idx >= batch_size) return;
const float* in_row = input + batch_idx * hidden_size;
float* out_row = output + batch_idx * hidden_size;
// Step 1: Compute mean
float local_sum = 0.0f;
for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
local_sum += in_row[i];
}
float mean = block_reduce_sum<BLOCK_SIZE>(local_sum) / hidden_size;
__shared__ float s_mean;
if (threadIdx.x == 0) s_mean = mean;
__syncthreads();
mean = s_mean;
// Step 2: Compute variance
float local_var_sum = 0.0f;
for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
float diff = in_row[i] - mean;
local_var_sum += diff * diff;
}
float variance = block_reduce_sum<BLOCK_SIZE>(local_var_sum) / hidden_size;
__shared__ float s_var;
if (threadIdx.x == 0) s_var = variance;
__syncthreads();
variance = s_var;
// Step 3: Normalize
float inv_std = rsqrtf(variance + eps);
for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
float normalized = (in_row[i] - mean) * inv_std;
out_row[i] = normalized * gamma[i] + beta[i];
}
}
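一行一个 block 的启动方式示意(eps 常取 1e-5,d_* 为假设的设备指针):
layer_norm_kernel<256><<<batch_size, 256>>>(
d_input, d_gamma, d_beta, d_output, batch_size, hidden_size, 1e-5f);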
Welford 算法 (数值稳定)
/*
* Welford's online algorithm for computing mean and variance
* 数值更稳定,避免大数相减的精度问题
*/
struct WelfordState {
float mean;
float m2; // Sum of squared differences
float count;
};
__device__ __forceinline__ WelfordState welford_combine(
WelfordState a, WelfordState b
) {
// 空状态 (count = 0) 直接返回另一方,避免除以 0
if (a.count == 0.0f) return b;
if (b.count == 0.0f) return a;
float count = a.count + b.count;
float delta = b.mean - a.mean;
float mean = a.mean + delta * b.count / count;
float m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count;
return {mean, m2, count};
}
__device__ __forceinline__ WelfordState warp_reduce_welford(WelfordState state) {
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
WelfordState other;
other.mean = __shfl_down_sync(0xffffffff, state.mean, offset);
other.m2 = __shfl_down_sync(0xffffffff, state.m2, offset);
other.count = __shfl_down_sync(0xffffffff, state.count, offset);
state = welford_combine(state, other);
}
return state;
}
template<int BLOCK_SIZE>
__global__ void layer_norm_welford(
const float* __restrict__ input,
const float* __restrict__ gamma,
const float* __restrict__ beta,
float* __restrict__ output,
float* __restrict__ mean_out, // 可选:输出均值
float* __restrict__ rstd_out, // 可选:输出标准差倒数
int batch_size, int hidden_size,
float eps
) {
int batch_idx = blockIdx.x;
if (batch_idx >= batch_size) return;
const float* in_row = input + batch_idx * hidden_size;
float* out_row = output + batch_idx * hidden_size;
// Welford reduction
WelfordState state = {0.0f, 0.0f, 0.0f};
for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
float x = in_row[i];
WelfordState new_state = {x, 0.0f, 1.0f};
state = welford_combine(state, new_state);
}
// Block reduce using Welford
__shared__ float s_mean[32], s_m2[32], s_count[32];
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
state = warp_reduce_welford(state);
if (lane == 0) {
s_mean[wid] = state.mean;
s_m2[wid] = state.m2;
s_count[wid] = state.count;
}
__syncthreads();
if (threadIdx.x < BLOCK_SIZE / 32) {
state.mean = s_mean[lane];
state.m2 = s_m2[lane];
state.count = s_count[lane];
} else {
state = {0.0f, 0.0f, 0.0f};
}
if (wid == 0) {
state = warp_reduce_welford(state);
}
__shared__ float final_mean, final_rstd;
if (threadIdx.x == 0) {
float variance = state.m2 / state.count;
final_mean = state.mean;
final_rstd = rsqrtf(variance + eps);
if (mean_out) mean_out[batch_idx] = final_mean;
if (rstd_out) rstd_out[batch_idx] = final_rstd;
}
__syncthreads();
// Normalize
float mean = final_mean;
float rstd = final_rstd;
for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
float normalized = (in_row[i] - mean) * rstd;
out_row[i] = normalized * gamma[i] + beta[i];
}
}
向量化 LayerNorm
// 使用 float4 向量化访问(假设 hidden_size 为 4 的倍数,且指针满足 16 字节对齐)
template<int BLOCK_SIZE>
__global__ void layer_norm_vectorized(
const float4* __restrict__ input, // hidden_size / 4
const float4* __restrict__ gamma,
const float4* __restrict__ beta,
float4* __restrict__ output,
int batch_size, int hidden_size, // 原始 hidden_size
float eps
) {
int batch_idx = blockIdx.x;
if (batch_idx >= batch_size) return;
int vec_hidden = hidden_size / 4;
const float4* in_row = input + batch_idx * vec_hidden;
float4* out_row = output + batch_idx * vec_hidden;
// Compute mean
float local_sum = 0.0f;
for (int i = threadIdx.x; i < vec_hidden; i += BLOCK_SIZE) {
float4 val = in_row[i];
local_sum += val.x + val.y + val.z + val.w;
}
float mean = block_reduce_sum<BLOCK_SIZE>(local_sum) / hidden_size;
__shared__ float s_mean;
if (threadIdx.x == 0) s_mean = mean;
__syncthreads();
mean = s_mean;
// Compute variance
float local_var = 0.0f;
for (int i = threadIdx.x; i < vec_hidden; i += BLOCK_SIZE) {
float4 val = in_row[i];
float d0 = val.x - mean, d1 = val.y - mean;
float d2 = val.z - mean, d3 = val.w - mean;
local_var += d0*d0 + d1*d1 + d2*d2 + d3*d3;
}
float var = block_reduce_sum<BLOCK_SIZE>(local_var) / hidden_size;
__shared__ float s_rstd;
if (threadIdx.x == 0) s_rstd = rsqrtf(var + eps);
__syncthreads();
float rstd = s_rstd;
// Normalize with vectorized access
for (int i = threadIdx.x; i < vec_hidden; i += BLOCK_SIZE) {
float4 val = in_row[i];
float4 g = gamma[i];
float4 b = beta[i];
float4 out;
out.x = (val.x - mean) * rstd * g.x + b.x;
out.y = (val.y - mean) * rstd * g.y + b.y;
out.z = (val.z - mean) * rstd * g.z + b.z;
out.w = (val.w - mean) * rstd * g.w + b.w;
out_row[i] = out;
}
}
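主机端可以按 hidden_size 是否为 4 的倍数在两个版本间分发(cudaMalloc 返回的指针天然满足 16 字节对齐),示意:
if (hidden_size % 4 == 0) {
layer_norm_vectorized<256><<<batch_size, 256>>>(
reinterpret_cast<const float4*>(d_input),
reinterpret_cast<const float4*>(d_gamma),
reinterpret_cast<const float4*>(d_beta),
reinterpret_cast<float4*>(d_output),
batch_size, hidden_size, eps);
} else {
layer_norm_kernel<256><<<batch_size, 256>>>(
d_input, d_gamma, d_beta, d_output, batch_size, hidden_size, eps);
}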
性能优化总结
优化策略对比
| 优化技术 | 适用场景 | 收益 |
|---|---|---|
| 共享内存 Tiling | 数据复用高的算子 | 减少全局内存访问 |
| 寄存器分块 | 计算密集型 | 最大化计算吞吐 |
| 向量化访问 | 内存密集型 | 提高带宽利用率 |
| Warp Shuffle | Reduce 类操作 | 减少同步开销 |
| 双缓冲 | 内存延迟敏感 | 隐藏访存延迟 |
| 循环展开 | 小循环体 | 减少指令开销 |
Kernel 开发检查清单
□ 内存访问是否合并?
□ 是否存在共享内存 Bank Conflict?
□ 是否存在 Warp Divergence?
□ 寄存器压力是否过大?
□ 占用率是否足够?
□ 是否可以向量化?
□ 是否可以融合其他操作?