04-Operator Fusion and Optimization Techniques
Overview
Operator fusion is one of the core techniques in deep learning performance optimization. By merging multiple operators into a single kernel, it reduces memory traffic and kernel launch overhead, which can improve compute efficiency significantly.
Operator Fusion Fundamentals
Why Operator Fusion Is Needed
┌─────────────────────────────────────────────────────────────────────────┐
│ Why Fuse Operators? │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Unfused operator chain: MatMul → Add Bias → GELU │
│ │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ Input │ → │ MatMul │ → │ Global │ → │ Add │ → │ Global │ │
│ │ (HBM) │ │ Kernel │ │ Memory │ │ Bias │ │ Memory │ │
│ └────────┘ └────────┘ └────────┘ └────────┘ └────────┘ │
│ ↓ ↓ │
│ ┌────────┐ ┌────────┐ │
│ │ GELU │ │ Output │ │
│ │ Kernel │ │ (HBM) │ │
│ └────────┘ └────────┘ │
│ │
│ Problems: │
│ ├── 3 kernel launches (each ~5-10 μs) │
│ ├── 2 writes of intermediate results to HBM │
│ ├── 2 reads of intermediate results from HBM │
│ └── Bandwidth bottleneck: HBM bandwidth is limited (A100: ~2 TB/s) │
│ │
│ Fused: MatMul + Add Bias + GELU │
│ │
│ ┌────────┐ ┌────────────────────────┐ ┌────────┐ │
│ │ Input │ → │ Fused Kernel │ → │ Output │ │
│ │ (HBM) │ │ MatMul+Bias+GELU │ │ (HBM) │ │
│ └────────┘ │ (Bias+GELU done in registers) │ └────────┘ │
│ └────────────────────────┘ │
│ │
│ Benefits: │
│ ├── 1 kernel launch │
│ ├── 0 writes/reads of intermediate results to HBM │
│ └── Speedup: 1.5x - 3x (depending on the operators involved) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
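To put rough numbers on the bandwidth argument above, the following back-of-the-envelope sketch estimates the extra HBM traffic the unfused chain pays for its two intermediate tensors. The shape, dtype, and the ~2 TB/s figure are example assumptions, not measurements.
#include <cstdio>
// Extra HBM traffic of the unfused MatMul -> Add Bias -> GELU chain:
// the [M, N] intermediate is written and read twice more than necessary.
int main() {
    const long long M = 4096, N = 4096;   // example activation shape
    const long long bytes_per_elem = 2;   // FP16
    const double hbm_bw = 2.0e12;         // ~2 TB/s (A100-class HBM)
    long long tensor_bytes = M * N * bytes_per_elem;
    long long extra_traffic = 4 * tensor_bytes;  // 2 extra writes + 2 extra reads
    printf("extra HBM traffic: %.1f MB\n", extra_traffic / 1e6);
    printf("time at ~2 TB/s:   %.1f us\n", extra_traffic / hbm_bw * 1e6);
    // On top of this come two avoidable kernel launches (~5-10 us each).
    return 0;
}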
Fusion Type Classification
┌─────────────────────────────────────────────────────────────────────────┐
│ Types of Operator Fusion │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. Element-wise fusion │
│ ├── Simplest form: merge several element-wise operations │
│ ├── Examples: Add + ReLU, Mul + Add + Sigmoid │
│ └── Benefit: fewer memory accesses │
│ │
│ 2. Epilogue fusion │
│ ├── Fuse post-processing into the main operator │
│ ├── Example: GEMM + Bias + Activation │
│ └── Benefit: no intermediate write-back to HBM │
│ │
│ 3. Reduce fusion │
│ ├── Fuse a reduction with surrounding operations │
│ ├── Example: Softmax (max + exp + sum + div) │
│ └── Benefit: fewer round trips to global memory │
│ │
│ 4. Cross-layer fusion │
│ ├── Fuse multiple network layers │
│ ├── Examples: Conv + BN + ReLU (folding sketch below the box), Attention block │
│ └── Benefit: maximizes data reuse │
│ │
│ 5. Dynamic fusion (JIT) │
│ ├── Generate fused kernels at runtime from the compute graph │
│ ├── Examples: TorchScript, XLA, TVM │
│ └── Benefit: adapts flexibly to different model structures │
│ │
└─────────────────────────────────────────────────────────────────────────┘
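As noted in item 4, the classic inference-time case of cross-layer fusion is folding BatchNorm into the preceding convolution, after which Conv + BN + ReLU runs as a single convolution with a fused activation. A minimal host-side sketch of the folding step follows; the weight layout and parameter names are illustrative, not taken from any particular framework.
#include <cmath>
#include <vector>
// Inference-time BN folding:
//   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
// is equivalent to a plain convolution with
//   w' = w * gamma / sqrt(var + eps)
//   b' = (b - mean) * gamma / sqrt(var + eps) + beta
void fold_bn_into_conv(std::vector<float>& w,      // [out_c * in_c * kh * kw]
                       std::vector<float>& b,      // [out_c]
                       const std::vector<float>& gamma,
                       const std::vector<float>& beta,
                       const std::vector<float>& mean,
                       const std::vector<float>& var,
                       int out_c, int weights_per_filter, float eps) {
    for (int oc = 0; oc < out_c; oc++) {
        float scale = gamma[oc] / std::sqrt(var[oc] + eps);
        for (int i = 0; i < weights_per_filter; i++) {
            w[oc * weights_per_filter + i] *= scale;    // scale the whole filter
        }
        b[oc] = (b[oc] - mean[oc]) * scale + beta[oc];  // fold shift into the bias
    }
}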
Element-wise Fusion
Basic Fused Implementations
// Unfused version: two separate kernels, one intermediate tensor in HBM
__global__ void add_kernel(float* c, const float* a, const float* b, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) c[idx] = a[idx] + b[idx];
}
__global__ void relu_kernel(float* y, const float* x, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) y[idx] = fmaxf(x[idx], 0.0f);
}
// Fused version: Add + ReLU in a single kernel
__global__ void fused_add_relu(
float* __restrict__ output,
const float* __restrict__ a,
const float* __restrict__ b,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float val = a[idx] + b[idx];
output[idx] = fmaxf(val, 0.0f);
}
}
// A more complex fusion: Mul + Add + Sigmoid
__global__ void fused_mul_add_sigmoid(
float* __restrict__ output,
const float* __restrict__ x,
const float* __restrict__ w,
const float* __restrict__ b,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float val = x[idx] * w[idx] + b[idx];
output[idx] = 1.0f / (1.0f + expf(-val));
}
}
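A minimal host-side driver for comparing the two versions above with CUDA events; the problem size and launch configuration are arbitrary example choices, and error checking is omitted for brevity.
#include <cuda_runtime.h>
#include <cstdio>
// Times the unfused add_kernel + relu_kernel pair against fused_add_relu.
void benchmark_fusion(int n = 1 << 24) {
    float *a, *b, *tmp, *out;
    cudaMalloc(&a, n * sizeof(float));
    cudaMalloc(&b, n * sizeof(float));
    cudaMalloc(&tmp, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));
    dim3 block(256), grid((n + 255) / 256);
    cudaEvent_t t0, t1;
    cudaEventCreate(&t0); cudaEventCreate(&t1);
    // Unfused: the intermediate result goes through global memory.
    cudaEventRecord(t0);
    add_kernel<<<grid, block>>>(tmp, a, b, n);
    relu_kernel<<<grid, block>>>(out, tmp, n);
    cudaEventRecord(t1);
    cudaEventSynchronize(t1);
    float ms_unfused; cudaEventElapsedTime(&ms_unfused, t0, t1);
    // Fused: one launch, no intermediate tensor.
    cudaEventRecord(t0);
    fused_add_relu<<<grid, block>>>(out, a, b, n);
    cudaEventRecord(t1);
    cudaEventSynchronize(t1);
    float ms_fused; cudaEventElapsedTime(&ms_fused, t0, t1);
    printf("unfused: %.3f ms, fused: %.3f ms\n", ms_unfused, ms_fused);
    cudaFree(a); cudaFree(b); cudaFree(tmp); cudaFree(out);
    cudaEventDestroy(t0); cudaEventDestroy(t1);
}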
Vectorized Fusion
// Fused kernel using float4 vectorized loads/stores
__global__ void fused_add_relu_vec4(
float4* __restrict__ output,
const float4* __restrict__ a,
const float4* __restrict__ b,
int n // n is the number of float4 elements
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float4 va = a[idx];
float4 vb = b[idx];
float4 result;
result.x = fmaxf(va.x + vb.x, 0.0f);
result.y = fmaxf(va.y + vb.y, 0.0f);
result.z = fmaxf(va.z + vb.z, 0.0f);
result.w = fmaxf(va.w + vb.w, 0.0f);
output[idx] = result;
}
}
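Since n is rarely a multiple of 4, a common pattern is to run the vectorized kernel on the aligned bulk and fall back to the scalar fused kernel for the tail. A possible launch helper, assuming the two kernels above and 16-byte-aligned allocations (which cudaMalloc guarantees):
// Launch helper: vectorized bulk + scalar tail.
void launch_fused_add_relu(float* out, const float* a, const float* b, int n,
                           cudaStream_t stream = 0) {
    int n_vec4 = n / 4;            // number of full float4 elements
    int tail   = n - n_vec4 * 4;   // leftover scalars (0..3)
    if (n_vec4 > 0) {
        dim3 block(256), grid((n_vec4 + 255) / 256);
        fused_add_relu_vec4<<<grid, block, 0, stream>>>(
            reinterpret_cast<float4*>(out),
            reinterpret_cast<const float4*>(a),
            reinterpret_cast<const float4*>(b),
            n_vec4);
    }
    if (tail > 0) {
        // Reuse the scalar fused kernel for the last few elements.
        fused_add_relu<<<1, 32, 0, stream>>>(
            out + n_vec4 * 4, a + n_vec4 * 4, b + n_vec4 * 4, tail);
    }
}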
// Fused LayerNorm + Dropout + Add (residual connection)
// Requires #include <curand_kernel.h>; block_reduce_sum<> is a block-wide
// sum-reduction helper (a shuffle-based sketch follows this kernel).
__global__ void fused_layernorm_dropout_add(
float* __restrict__ output,
const float* __restrict__ input,
const float* __restrict__ residual,
const float* __restrict__ gamma,
const float* __restrict__ beta,
float* __restrict__ dropout_mask,
int batch_size,
int hidden_size,
float dropout_prob,
float eps,
unsigned long long seed
) {
int batch_idx = blockIdx.x;
if (batch_idx >= batch_size) return;
const float* in_row = input + batch_idx * hidden_size;
const float* res_row = residual + batch_idx * hidden_size;
float* out_row = output + batch_idx * hidden_size;
float* mask_row = dropout_mask + batch_idx * hidden_size;
// Step 1: Compute mean
float local_sum = 0.0f;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
local_sum += in_row[i];
}
float mean = block_reduce_sum<256>(local_sum) / hidden_size;
__shared__ float s_mean;
if (threadIdx.x == 0) s_mean = mean;
__syncthreads();
mean = s_mean;
// Step 2: Compute variance
float local_var = 0.0f;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float diff = in_row[i] - mean;
local_var += diff * diff;
}
float var = block_reduce_sum<256>(local_var) / hidden_size;
__shared__ float s_rstd;
if (threadIdx.x == 0) s_rstd = rsqrtf(var + eps);
__syncthreads();
float rstd = s_rstd;
// Step 3: Normalize + Dropout + Add residual
// Generate the dropout mask with curand
curandStatePhilox4_32_10_t state;
curand_init(seed, batch_idx * hidden_size + threadIdx.x, 0, &state);
float scale = 1.0f / (1.0f - dropout_prob);
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
// LayerNorm
float normalized = (in_row[i] - mean) * rstd;
float ln_out = normalized * gamma[i] + beta[i];
// Dropout
float rand = curand_uniform(&state);
float mask = (rand > dropout_prob) ? 1.0f : 0.0f;
mask_row[i] = mask;
float dropout_out = ln_out * mask * scale;
// Add residual
out_row[i] = dropout_out + res_row[i];
}
}
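The kernel above calls a block_reduce_sum helper that is not shown. A minimal shuffle-based sketch is given below; in an actual build it would be declared before the kernel, the template parameter must equal blockDim.x, and the block size is assumed to be a multiple of 32.
// Block-wide sum reduction: warp shuffles plus one round through shared memory.
// Thread 0 ends up holding the sum of all BLOCK inputs, which is what the
// LayerNorm kernel above consumes (it then broadcasts the value via shared memory).
template<int BLOCK>
__device__ float block_reduce_sum(float val) {
    __shared__ float warp_sums[BLOCK / 32];
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    // 1) Reduce within each warp.
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
    if (lane == 0) warp_sums[warp] = val;
    __syncthreads();
    // 2) The first warp reduces the per-warp partial sums.
    val = (threadIdx.x < BLOCK / 32) ? warp_sums[lane] : 0.0f;
    if (warp == 0)
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
    __syncthreads();  // make warp_sums safe to reuse on the next call
    return val;
}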
GEMM Epilogue Fusion
GEMM + Bias + Activation
// Fused post-processing of a GEMM result
// D = activation(alpha * (A @ B) + bias)   (the beta * C term is omitted in this sketch)
template<typename Activation>
__global__ void gemm_epilogue_fused(
const float* __restrict__ gemm_output, // raw GEMM output (A @ B)
const float* __restrict__ bias,
float* __restrict__ output,
int M, int N,
float alpha,
float beta // unused in this simplified epilogue
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
int idx = row * N + col;
float val = alpha * gemm_output[idx] + bias[col];
output[idx] = Activation::apply(val);
}
}
// Activation functors
struct ReLU {
__device__ __forceinline__ static float apply(float x) {
return fmaxf(x, 0.0f);
}
};
struct GELU {
__device__ __forceinline__ static float apply(float x) {
// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
const float c = 0.7978845608f; // sqrt(2/π)
const float k = 0.044715f;
float x3 = x * x * x;
return 0.5f * x * (1.0f + tanhf(c * (x + k * x3)));
}
};
struct SiLU {
__device__ __forceinline__ static float apply(float x) {
return x / (1.0f + expf(-x));
}
};
// Defining the epilogue in CUTLASS
#include <cutlass/epilogue/thread/linear_combination_relu.h>
#include <cutlass/epilogue/thread/linear_combination_gelu.h>
using EpilogueWithReLU = cutlass::epilogue::thread::LinearCombinationRelu<
float, // Element type
128 / cutlass::sizeof_bits<float>::value, // Elements per access
float, // Accumulator type
float // Compute type
>;
using EpilogueWithGELU = cutlass::epilogue::thread::LinearCombinationGELU<
float,
128 / cutlass::sizeof_bits<float>::value,
float,
float
>;
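For the hand-written epilogue kernel above, the activation is chosen at compile time through the template parameter. A possible launch wrapper; the 32x8 block shape is an arbitrary example choice:
// Launch the templated epilogue kernel with GELU as the activation.
void launch_epilogue_gelu(const float* gemm_out, const float* bias,
                          float* out, int M, int N) {
    dim3 block(32, 8);   // x walks columns, y walks rows (matches the kernel)
    dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
    gemm_epilogue_fused<GELU><<<grid, block>>>(gemm_out, bias, out,
                                               M, N, /*alpha=*/1.0f, /*beta=*/0.0f);
}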
A Complete Fused GEMM Kernel
/*
 * Complete fused GEMM + Bias + GELU implementation
 * using Tensor Cores (WMMA).
 * Requires #include <mma.h> and `using namespace nvcuda;`.
 */
#define BM 128
#define BN 128
#define BK 32
#define TM 8
#define TN 8
// Forward declaration: gelu_forward is defined after the kernel below.
__device__ __forceinline__ float gelu_forward(float x);
__global__ void fused_gemm_bias_gelu(
const half* __restrict__ A,
const half* __restrict__ B,
const half* __restrict__ bias,
half* __restrict__ C,
int M, int N, int K
) {
__shared__ half As[BK][BM];
__shared__ half Bs[BK][BN];
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
int tid = ty * blockDim.x + tx;
// WMMA fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag[2];
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2];
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag[2][2];
// Initialize the accumulator fragments
#pragma unroll
for (int i = 0; i < 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
wmma::fill_fragment(c_frag[i][j], 0.0f);
}
}
int warp_id = tid / 32;
int warp_row = warp_id / 2;
int warp_col = warp_id % 2;
// Iterate over the K dimension
for (int bk = 0; bk < K; bk += BK) {
// Load the A and B tiles into shared memory
// (load_shared_a/b are cooperative tile-copy helpers, not shown here)
load_shared_a(A, As, by, bk, tid, M, K);
load_shared_b(B, Bs, bx, bk, tid, K, N);
__syncthreads();
// WMMA compute
#pragma unroll
for (int k = 0; k < BK; k += 16) {
#pragma unroll
for (int m = 0; m < 2; m++) {
wmma::load_matrix_sync(a_frag[m],
&As[k][warp_row * 64 + m * 16], BM);
}
#pragma unroll
for (int n = 0; n < 2; n++) {
wmma::load_matrix_sync(b_frag[n],
&Bs[k][warp_col * 64 + n * 16], BN);
}
#pragma unroll
for (int m = 0; m < 2; m++) {
#pragma unroll
for (int n = 0; n < 2; n++) {
wmma::mma_sync(c_frag[m][n], a_frag[m], b_frag[n], c_frag[m][n]);
}
}
}
__syncthreads();
}
// Epilogue: Bias + GELU
int c_row_base = by * BM + warp_row * 64;
int c_col_base = bx * BN + warp_col * 64;
#pragma unroll
for (int m = 0; m < 2; m++) {
#pragma unroll
for (int n = 0; n < 2; n++) {
// Apply bias and GELU to every element of the fragment
#pragma unroll
for (int i = 0; i < c_frag[m][n].num_elements; i++) {
// Column index of this element (simplified: the element-to-lane mapping
// inside a WMMA accumulator fragment is architecture-dependent)
int elem_col = c_col_base + n * 16 + (i % 16);
if (elem_col < N) {
float val = c_frag[m][n].x[i];
// Add bias
val += __half2float(bias[elem_col]);
// GELU
val = gelu_forward(val);
c_frag[m][n].x[i] = val;
}
}
// Convert to half and store
wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_half_frag;
#pragma unroll
for (int i = 0; i < c_frag[m][n].num_elements; i++) {
c_half_frag.x[i] = __float2half(c_frag[m][n].x[i]);
}
int c_row = c_row_base + m * 16;
int c_col = c_col_base + n * 16;
if (c_row < M && c_col < N) {
wmma::store_matrix_sync(&C[c_row * N + c_col], c_half_frag, N,
wmma::mem_row_major);
}
}
}
}
__device__ __forceinline__ float gelu_forward(float x) {
const float c = 0.7978845608f;
const float k = 0.044715f;
return 0.5f * x * (1.0f + tanhf(c * (x + k * x * x * x)));
}
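The tile constants and warp indexing above imply a specific launch shape: four warps (128 threads) per block, addressed via warp_row/warp_col, and one block per BM x BN output tile. A hedged launch sketch under that assumption:
// Launch configuration implied by BM/BN and the 2x2 warp layout of the kernel.
void launch_fused_gemm_bias_gelu(const half* A, const half* B,
                                 const half* bias, half* C,
                                 int M, int N, int K) {
    dim3 block(128);              // 4 warps: warp_row, warp_col in {0, 1}
    dim3 grid((N + BN - 1) / BN,  // blockIdx.x walks the N dimension
              (M + BM - 1) / BM); // blockIdx.y walks the M dimension
    fused_gemm_bias_gelu<<<grid, block>>>(A, B, bias, C, M, N, K);
}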
Attention Fusion (Flash Attention)
How Flash Attention Works
┌─────────────────────────────────────────────────────────────────────────┐
│ How Flash Attention Works │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Standard attention: │
│ O = softmax(QK^T / √d) V │
│ │
│ Memory footprint of the naive implementation: │
│ ├── S = QK^T: O(N²) intermediate matrix │
│ ├── P = softmax(S): O(N²) intermediate matrix │
│ └── O = PV: final output │
│ Total: O(N²) device memory, infeasible for long sequences │
│ │
│ Core ideas of Flash Attention: │
│ ├── Tiling: split Q, K, V into small blocks │
│ ├── Online softmax: update the softmax statistics on the fly │
│ └── Never materialize the intermediate S and P matrices │
│ │
│ Tiling strategy: │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Q [N, d] K [N, d] V [N, d] O [N, d] │ │
│ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ │
│ │ │Q_1│ │K_1│ │V_1│ │O_1│ │ │
│ │ ├───┤ ───> ├───┤ ───> ├───┤ ───> ├───┤ │ │
│ │ │Q_2│ │K_2│ │V_2│ │O_2│ │ │
│ │ ├───┤ ├───┤ ├───┤ ├───┤ │ │
│ │ │...│ │...│ │...│ │...│ │ │
│ │ └───┘ └───┘ └───┘ └───┘ │ │
│ │ │ │
│ │ Load one Q_i block at a time and sweep over all K_j, V_j blocks │ │
│ │ Compute stays in on-chip SRAM, avoiding HBM round trips │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
│ Online softmax update (unnormalized form, as used in the kernels below; │
│ divide by l once at the very end): │
│ For a fixed Q_i, while sweeping over the K_j blocks: │
│ ├── m_new = max(m_old, rowmax(S_ij)) │
│ ├── l_new = l_old * exp(m_old - m_new) + rowsum(exp(S_ij - m_new)) │
│ └── O_new = O_old * exp(m_old - m_new) + exp(S_ij - m_new) @ V_j │
│ │
└─────────────────────────────────────────────────────────────────────────┘
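The update rule is easiest to see in isolation. Below is a plain scalar reference of online softmax over a single row of scores; the kernels that follow apply exactly this recurrence tile by tile, keeping O unnormalized until the end.
#include <cmath>
#include <cfloat>
// One-pass "online" softmax statistics: the running max m and running sum l
// are corrected whenever a new maximum appears, so no second pass over the
// scores is needed before the final normalization.
void online_softmax(const float* scores, float* probs, int n) {
    float m = -FLT_MAX;  // running max
    float l = 0.0f;      // running sum of exp(score - m)
    for (int j = 0; j < n; j++) {
        float m_new = std::fmax(m, scores[j]);
        l = l * std::exp(m - m_new) + std::exp(scores[j] - m_new);
        m = m_new;
    }
    for (int j = 0; j < n; j++)
        probs[j] = std::exp(scores[j] - m) / l;
}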
A Simplified Flash Attention Implementation
/*
 * Simplified Flash Attention kernel.
 * For illustration only; production implementations are far more involved.
 * load_block() is a cooperative tile-copy helper (a sketch follows the kernel).
 */
#define BR 64 // Q block size
#define BC 64 // K/V block size
#define d 64 // head dimension
__global__ void flash_attention_kernel(
const float* __restrict__ Q, // [N, d]
const float* __restrict__ K, // [N, d]
const float* __restrict__ V, // [N, d]
float* __restrict__ O, // [N, d]
float* __restrict__ L, // [N] softmax normalization factor (running row sum)
float* __restrict__ M, // [N] running row max
int N,
float scale
) {
// Shared-memory tiles
__shared__ float Qi[BR][d]; // Q tile
__shared__ float Kj[BC][d]; // K tile
__shared__ float Vj[BC][d]; // V tile
__shared__ float Sij[BR][BC]; // attention-score tile
int batch_head = blockIdx.x;
int q_block = blockIdx.y;
int q_start = q_block * BR;
// Load the Q tile into shared memory (load_block: cooperative copy helper)
load_block(Q + batch_head * N * d, Qi, q_start, N, d);
__syncthreads();
// Per-row accumulators: each of the first BR threads owns one Q row
float Oi[d] = {0.0f};
float mi = -INFINITY;
float li = 0.0f;
// Sweep over the K/V tiles
for (int j = 0; j < (N + BC - 1) / BC; j++) {
int kv_start = j * BC;
// Load the K and V tiles
load_block(K + batch_head * N * d, Kj, kv_start, N, d);
load_block(V + batch_head * N * d, Vj, kv_start, N, d);
__syncthreads();
// Compute S_ij = Q_i @ K_j^T * scale
// (strided over all BR*BC score entries so any block size works)
for (int s = threadIdx.x; s < BR * BC; s += blockDim.x) {
int r = s / BC;
int c = s % BC;
float sum = 0.0f;
#pragma unroll
for (int k = 0; k < d; k++) {
sum += Qi[r][k] * Kj[c][k];
}
Sij[r][c] = sum * scale;
}
__syncthreads();
// Each of the first BR threads runs the online softmax for its own Q row
if (threadIdx.x < BR) {
int i = threadIdx.x;
// Row max over the current tile
float row_max = -INFINITY;
#pragma unroll
for (int jj = 0; jj < BC && kv_start + jj < N; jj++) {
row_max = fmaxf(row_max, Sij[i][jj]);
}
// Online softmax update
float mi_new = fmaxf(mi, row_max);
float li_new = li * expf(mi - mi_new);
// Rescale the previously accumulated output
float scale_old = expf(mi - mi_new);
#pragma unroll
for (int k = 0; k < d; k++) {
Oi[k] *= scale_old;
}
// Accumulate this tile's contribution
#pragma unroll
for (int jj = 0; jj < BC && kv_start + jj < N; jj++) {
float pij = expf(Sij[i][jj] - mi_new);
li_new += pij;
#pragma unroll
for (int k = 0; k < d; k++) {
Oi[k] += pij * Vj[jj][k];
}
}
mi = mi_new;
li = li_new;
}
__syncthreads();
}
// Normalize and write back
if (threadIdx.x < BR && q_start + threadIdx.x < N) {
int row = q_start + threadIdx.x;
float inv_l = 1.0f / li;
#pragma unroll
for (int k = 0; k < d; k++) {
O[(batch_head * N + row) * d + k] = Oi[k] * inv_l;
}
L[batch_head * N + row] = li;
M[batch_head * N + row] = mi;
}
}
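load_block above is left undefined. A minimal cooperative copy in which every thread of the block participates, zero-padding rows past the end of the sequence, could look like the following; the signature is an assumption chosen to match the call sites (BR == BC == 64 here), and in an actual build it would be declared before the kernel.
// Cooperative tile load: copies up to 64 rows x dim columns starting at
// row_start from global memory into a shared-memory tile, zero-padding
// rows beyond N.
__device__ void load_block(const float* __restrict__ src,
                           float dst[][64], int row_start, int N, int dim) {
    const int rows = 64;
    for (int idx = threadIdx.x; idx < rows * dim; idx += blockDim.x) {
        int r = idx / dim;
        int c = idx % dim;
        int g = row_start + r;
        dst[r][c] = (g < N) ? src[g * dim + c] : 0.0f;
    }
}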
Flash Attention 2 Optimizations
/*
 * Flash Attention 2 highlights:
 * 1. Better parallelism: also parallelize across the sequence dimension
 * 2. Fewer non-matmul operations
 * 3. Better work partitioning across warps
 *
 * The kernel below is a schematic sketch: per-row softmax statistics are
 * collapsed into per-warp scalars, and the helper routines (load_q_block,
 * load_kv_block, apply_causal_mask, reduce_max) are omitted.
 */
// Undefine the tile macros from the previous sketch so they do not collide
// with the template parameters of the same names.
#undef BR
#undef BC
#undef d
template<int BR, int BC, int d>
__global__ void flash_attention_2_kernel(
const half* __restrict__ Q,
const half* __restrict__ K,
const half* __restrict__ V,
half* __restrict__ O,
int N,
float scale
) {
// Each thread block processes one Q tile,
// using Tensor Cores (WMMA)
__shared__ half Qi[BR][d];
__shared__ half Kj[BC][d];
__shared__ half Vj[BC][d];
// WMMA fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> v_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> s_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> o_frag;
wmma::fill_fragment(o_frag, 0.0f);
int warp_id = threadIdx.x / 32;
int q_block = blockIdx.y;
int batch_head = blockIdx.x;
float mi = -INFINITY;
float li = 0.0f;
// Load the Q tile (only once per block)
load_q_block(Q + batch_head * N * d, Qi, q_block * BR, N, d);
__syncthreads();
// Sweep over the K/V tiles
for (int kv_block = 0; kv_block < (N + BC - 1) / BC; kv_block++) {
// Load the K and V tiles
load_kv_block(K, V, Kj, Vj, batch_head, kv_block * BC, N, d);
__syncthreads();
// Compute S = Q @ K^T with Tensor Cores
wmma::fill_fragment(s_frag, 0.0f);
for (int k = 0; k < d; k += 16) {
wmma::load_matrix_sync(q_frag, &Qi[warp_id * 16][k], d);
wmma::load_matrix_sync(k_frag, &Kj[0][k], d);
wmma::mma_sync(s_frag, q_frag, k_frag, s_frag);
}
// Scale
#pragma unroll
for (int i = 0; i < s_frag.num_elements; i++) {
s_frag.x[i] *= scale;
}
// Causal mask (if needed)
apply_causal_mask(s_frag, q_block * BR + warp_id * 16, kv_block * BC);
// Online softmax
float new_max = reduce_max(s_frag);
float mi_new = fmaxf(mi, new_max);
// Rescale previous output
float scale_factor = expf(mi - mi_new);
#pragma unroll
for (int i = 0; i < o_frag.num_elements; i++) {
o_frag.x[i] *= scale_factor;
}
// Compute softmax and accumulate
float row_sum = 0.0f;
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> p_frag;
#pragma unroll
for (int i = 0; i < s_frag.num_elements; i++) {
float p = expf(s_frag.x[i] - mi_new);
row_sum += p;
p_frag.x[i] = __float2half(p);
}
// O += P @ V
wmma::load_matrix_sync(v_frag, &Vj[0][0], d);
wmma::mma_sync(o_frag, p_frag, v_frag, o_frag);
// Update stats
li = li * scale_factor + row_sum;
mi = mi_new;
__syncthreads();
}
// Normalize and store
float inv_li = 1.0f / li;
#pragma unroll
for (int i = 0; i < o_frag.num_elements; i++) {
o_frag.x[i] *= inv_li;
}
wmma::store_matrix_sync(
&O[(batch_head * N + q_block * BR + warp_id * 16) * d],
o_frag, d, wmma::mem_row_major
);
}
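A launch sketch consistent with the indexing above: one block per (batch*head, Q tile) pair and four warps per block, each warp owning a 16-row slice of the 64-row Q tile. The tile sizes are the same example values used throughout this section.
// Launch sketch for the Flash Attention 2 kernel above.
void launch_flash_attention_2(const half* Q, const half* K, const half* V,
                              half* O, int batch_heads, int N, float scale) {
    dim3 grid(batch_heads, (N + 63) / 64);  // (batch*head, Q tile index)
    dim3 block(128);                        // 4 warps, 16 Q rows per warp
    flash_attention_2_kernel<64, 64, 64><<<grid, block>>>(Q, K, V, O, N, scale);
}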
Dynamic Fusion and JIT Compilation
Runtime Kernel Generation
# Dynamic fusion with Triton
import torch
import triton
import triton.language as tl
@triton.jit
def fused_elementwise_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
OP1: tl.constexpr, # op selectors (OP1/OP2), baked in when the kernel is JIT-specialized
OP2: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# Select the first op (constexpr branches are resolved at compile time)
if OP1 == 0: # Add
result = x + y
elif OP1 == 1: # Mul
result = x * y
elif OP1 == 2: # Sub
result = x - y
if OP2 == 0: # ReLU
result = tl.maximum(result, 0)
elif OP2 == 1: # GELU
result = 0.5 * result * (1 + tl.math.tanh(
0.7978845608 * (result + 0.044715 * result * result * result)))
elif OP2 == 2: # Sigmoid
result = 1 / (1 + tl.exp(-result))
tl.store(output_ptr + offsets, result, mask=mask)
# Usage
def fused_op(x, y, op1='add', op2='relu'):
op1_map = {'add': 0, 'mul': 1, 'sub': 2}
op2_map = {'relu': 0, 'gelu': 1, 'sigmoid': 2}
output = torch.empty_like(x)
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
fused_elementwise_kernel[grid](
x, y, output, n_elements,
BLOCK_SIZE=1024,
OP1=op1_map[op1],
OP2=op2_map[op2],
)
return output
TorchScript Fusion
import torch
from torch import nn
import torch.jit
# TorchScript can automatically fuse compatible operations
class FusedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size):
super().__init__()
self.fc1 = nn.Linear(hidden_size, intermediate_size)
self.fc2 = nn.Linear(intermediate_size, hidden_size)
def forward(self, x):
# Pointwise parts of this sequence (e.g., the GELU) can be fused by TorchScript's fuser
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
return x
# Compile to TorchScript
model = FusedMLP(768, 3072).cuda()
scripted = torch.jit.script(model)
# Inspect the graph to see what was fused
print(scripted.graph)
# torch.compile (PyTorch 2.0+) performs more aggressive fusion
@torch.compile(mode="reduce-overhead")
def optimized_forward(model, x):
return model(x)
Best Practices for Fusion Optimization
Fusion Decision Guide
┌─────────────────────────────────────────────────────────────────────────┐
│ Fusion Decision Guide │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ When fusion helps: │
│ ├── Chains of element-wise operations │
│ ├── Data-dependent operators whose intermediates need not be kept │
│ ├── Operators bottlenecked by memory bandwidth │
│ └── Small kernels where launch overhead is significant │
│ │
│ When not to fuse: │
│ ├── The fused kernel gets so complex that registers spill │
│ ├── Occupancy drops too low after fusion │
│ ├── The intermediate result is consumed by several downstream operators │
│ └── The operator is already highly optimized (e.g., cuBLAS GEMM) │
│ │
│ Performance evaluation: │
│ ├── Use Nsight to analyze memory bandwidth and compute utilization │
│ ├── Compare end-to-end performance before and after fusion │
│ └── Verify correctness on edge cases │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Common Fusion Patterns
| Fusion pattern | Operator combination | Typical gain |
|---|---|---|
| Bias + activation | GEMM + Bias + ReLU/GELU | 1.5-2x |
| LayerNorm fusion | Mean + Var + Normalize + Scale | 2-3x |
| Softmax fusion | Max + Exp + Sum + Div | 2-4x |
| Attention | QK^T + Softmax + V | 2-4x (Flash Attention) |
| Residual | Add + LayerNorm | 1.5x |
Summary
Key Points for Operator Fusion
□ Identify memory-bound operator chains
□ Keep intermediate results in shared memory / registers
□ Use online algorithms to avoid multiple passes over the data
□ Consider epilogue fusion for Tensor Core GEMMs
□ Balance fusion granularity against kernel complexity
□ Verify fusion gains with a profiler