04 - Operator Fusion and Kernel Optimization
Overview
Operator fusion is one of the most important optimizations in deep learning compilers. By merging multiple operators into a single kernel, it significantly reduces memory traffic and kernel launch overhead. This chapter covers the principles, types, and implementation of operator fusion, along with kernel-level optimization techniques.
The Value of Operator Fusion
Why operator fusion is needed

Unfused execution (3 kernels)            Fused execution (1 kernel)

┌─────────────────────────┐              ┌─────────────────────────┐
│ Kernel 1: sin(x)        │              │ Fused kernel:           │
│   read x from HBM       │              │   read x, y from HBM    │
│   compute sin           │              │   compute sin(x)        │
│   write result to HBM   │      →       │   compute cos(y)        │
└─────────────────────────┘              │   compute add           │
┌─────────────────────────┐              │   write result to HBM   │
│ Kernel 2: cos(y)        │              └─────────────────────────┘
│   read y from HBM       │
│   compute cos           │              Performance gains:
│   write result to HBM   │              - memory accesses:      6 → 2
└─────────────────────────┘              - kernel launches:      3 → 1
┌─────────────────────────┐              - intermediate tensors: 2 → 0
│ Kernel 3: add           │
│   read 2 intermediates  │              Typical speedup: 2-5x
│   compute add           │
│   write final result    │
└─────────────────────────┘
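To make the diagram's numbers concrete, here is a rough bandwidth-plus-launch-overhead estimate for z = sin(x) + cos(y) on two 1024×1024 fp32 tensors. The device parameters below are illustrative assumptions, not measurements:

# Back-of-the-envelope model: element-wise work is bandwidth-bound, so time is
# roughly (bytes moved / bandwidth) + (number of launches * launch overhead).
ELEM = 1024 * 1024
BYTES = 4 * ELEM              # one tensor-sized read or write (fp32)
BW = 2e12                     # assumed HBM bandwidth: 2 TB/s
LAUNCH = 5e-6                 # assumed launch overhead: 5 us per kernel

# Unfused: sin (read x, write t1), cos (read y, write t2), add (read t1, t2, write z)
unfused = 3 * LAUNCH + 7 * BYTES / BW
# Fused: one kernel that reads x and y once and writes z once
fused = 1 * LAUNCH + 3 * BYTES / BW

print(f"unfused ~ {unfused*1e6:.0f} us, fused ~ {fused*1e6:.0f} us, "
      f"speedup ~ {unfused/fused:.1f}x")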
1. Types of Operator Fusion
1.1 Element-wise Fusion
# The most common kind of fusion: a chain of element-wise operations
import torch
import triton
import triton.language as tl

# Unfused version: every intermediate (t1..t7) launches a separate kernel.
# tanh-approximation GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
def unfused_gelu(x):
    t1 = x * 0.5
    t2 = x ** 3
    t3 = t2 * 0.044715
    t4 = x + t3
    t5 = t4 * 0.7978845608028654   # sqrt(2/pi)
    t6 = torch.tanh(t5)
    t7 = 1.0 + t6
    return t1 * t7

# Fused version (Triton implementation)
@triton.jit
def fused_gelu_kernel(
    x_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # Single memory read
    x = tl.load(x_ptr + offset, mask=mask)
    # All intermediate computation stays in registers
    t1 = x * 0.5
    t2 = x * x * x
    t3 = t2 * 0.044715
    t4 = x + t3
    t5 = t4 * 0.7978845608028654   # sqrt(2/pi)
    t6 = tl.math.tanh(t5)
    t7 = 1.0 + t6
    output = t1 * t7
    # Single memory write
    tl.store(output_ptr + offset, output, mask=mask)

def fused_gelu(x):
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    fused_gelu_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024)
    return output

# Performance comparison
def benchmark():
    x = torch.randn(1024, 1024, device='cuda')
    # Warm-up
    for _ in range(10):
        _ = unfused_gelu(x)
        _ = fused_gelu(x)
    torch.cuda.synchronize()
    import time
    # Unfused
    start = time.time()
    for _ in range(100):
        _ = unfused_gelu(x)
    torch.cuda.synchronize()
    unfused_time = time.time() - start
    # Fused
    start = time.time()
    for _ in range(100):
        _ = fused_gelu(x)
    torch.cuda.synchronize()
    fused_time = time.time() - start
    print(f"Unfused: {unfused_time*10:.2f}ms")
    print(f"Fused: {fused_time*10:.2f}ms")
    print(f"Speedup: {unfused_time/fused_time:.2f}x")
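A quick correctness check is worthwhile before benchmarking. The sketch below (assuming the definitions above are in the same module) compares both paths against PyTorch's built-in tanh-approximation GELU:

def check_correctness():
    x = torch.randn(1024, 1024, device='cuda')
    # Reference: PyTorch's tanh-approximation GELU
    ref = torch.nn.functional.gelu(x, approximate='tanh')
    out_unfused = unfused_gelu(x)
    out_fused = fused_gelu(x)
    # Loose tolerances: the Triton tanh differs slightly from PyTorch's
    assert torch.allclose(out_unfused, ref, atol=1e-5, rtol=1e-4)
    assert torch.allclose(out_fused, ref, atol=1e-4, rtol=1e-3)
    print("max abs diff (fused vs unfused):",
          (out_fused - out_unfused).abs().max().item())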
1.2 Reduction Fusion
# Fuse element-wise operations with a reduction
@triton.jit
def fused_softmax_kernel(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride, output_row_stride,
    BLOCK_SIZE: tl.constexpr
):
    """
    Fused softmax.
    max, sub, exp, sum and div are fused into a single kernel.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # 1. Load one row
    input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
    # 2. Row max (reduction)
    row_max = tl.max(row, axis=0)
    # 3. Subtract the max and exponentiate (element-wise)
    numerator = tl.exp(row - row_max)
    # 4. Row sum (reduction)
    denominator = tl.sum(numerator, axis=0)
    # 5. Normalize (element-wise)
    softmax_output = numerator / denominator
    # 6. Store the result
    output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)

# A more involved example: fused LayerNorm
@triton.jit
def fused_layer_norm_kernel(
    x_ptr, weight_ptr, bias_ptr, output_ptr,
    N, D,
    eps,
    BLOCK_SIZE: tl.constexpr
):
    """
    Fused LayerNorm.
    mean, variance, normalize, scale and shift are all fused.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < D
    # Load one row
    x_ptrs = x_ptr + row_idx * D + col_offsets
    x = tl.load(x_ptrs, mask=mask, other=0.0)
    # Mean
    mean = tl.sum(x, axis=0) / D
    # Variance (zero the padded lanes so they do not contribute)
    x_centered = tl.where(mask, x - mean, 0.0)
    var = tl.sum(x_centered * x_centered, axis=0) / D
    # Normalize
    x_norm = x_centered / tl.sqrt(var + eps)
    # Affine transform
    weight = tl.load(weight_ptr + col_offsets, mask=mask)
    bias = tl.load(bias_ptr + col_offsets, mask=mask)
    output = x_norm * weight + bias
    # Store
    output_ptrs = output_ptr + row_idx * D + col_offsets
    tl.store(output_ptrs, output, mask=mask)
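These kernels still need a host-side launcher that picks the grid and BLOCK_SIZE. A minimal sketch for the softmax kernel (one program per row; BLOCK_SIZE rounded up to the next power of two, so a whole row must fit in one block):

def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 2 and x.is_cuda
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    # One Triton program handles one row; the row must fit in BLOCK_SIZE lanes
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    fused_softmax_kernel[(n_rows,)](
        x, out,
        n_cols,
        x.stride(0), out.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

# Usage:
# x = torch.randn(64, 1000, device='cuda')
# torch.testing.assert_close(fused_softmax(x), torch.softmax(x, dim=-1))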
1.3 MatMul Fusion
# MatMul + BiasAdd + Activation fusion
@triton.jit
def fused_matmul_bias_relu_kernel(
    a_ptr, b_ptr, bias_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    Fused matrix multiply + bias + ReLU:
    C = ReLU(A @ B + bias)
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Tile start offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main matmul loop
    for k in range(0, K, BLOCK_K):
        # Load A tile
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & ((k + offs_k[None, :]) < K), other=0.0)
        # Load B tile
        b_ptrs = b_ptr + (k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn
        b = tl.load(b_ptrs, mask=((k + offs_k[:, None]) < K) & (offs_n[None, :] < N), other=0.0)
        # Accumulate
        acc += tl.dot(a, b)
    # Load and add the bias
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    acc = acc + bias[None, :]
    # Apply ReLU (the fused activation)
    acc = tl.maximum(acc, 0.0)
    # Store the result
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)
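The kernel needs a 2D launch grid. A minimal host-side wrapper (block sizes hard-coded for illustration; in practice they would come from autotuning, see section 4):

def fused_linear_relu(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K2, N = b.shape
    assert K == K2 and bias.shape == (N,)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    fused_matmul_bias_relu_kernel[grid](
        a, b, bias, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c

# Reference for checking: torch.relu(a @ b + bias)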
1.4 Attention Fusion
# Flash Attention: fuse the whole attention computation into one kernel
@triton.jit
def flash_attention_kernel(
    Q, K, V, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    Fused Flash Attention.
    Attention is computed block by block, so the O(N^2) score matrix never
    materializes in HBM. BLOCK_K is assumed to equal the head dimension.
    """
    # Block indices
    start_m = tl.program_id(0) * BLOCK_M
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # Base pointers for this (batch, head)
    Q_block_ptr = Q + off_z * stride_qz + off_h * stride_qh
    K_block_ptr = K + off_z * stride_kz + off_h * stride_kh
    V_block_ptr = V + off_z * stride_vz + off_h * stride_vh
    O_block_ptr = Out + off_z * stride_oz + off_h * stride_oh
    # Offsets
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Load the Q block (kept on-chip for the whole loop)
    q_ptrs = Q_block_ptr + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX)
    # Running softmax statistics and output accumulator
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)  # running row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)                # running row sum
    acc = tl.zeros([BLOCK_M, BLOCK_K], dtype=tl.float32)       # output accumulator
    # Iterate over the K/V blocks
    for start_n in range(0, N_CTX, BLOCK_N):
        # Load a K block
        k_ptrs = K_block_ptr + (start_n + offs_n)[:, None] * stride_kn + offs_k[None, :] * stride_kk
        k = tl.load(k_ptrs, mask=(start_n + offs_n)[:, None] < N_CTX)
        # Compute QK^T
        qk = tl.dot(q, tl.trans(k))
        # Softmax scaling (1 / sqrt(head_dim))
        qk = qk * (1.0 / tl.sqrt(float(BLOCK_K)))
        # Update the running max and sum (online softmax)
        m_ij = tl.max(qk, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)
        beta = tl.exp(m_ij - m_new)
        l_new = alpha * l_i + beta * tl.sum(tl.exp(qk - m_ij[:, None]), axis=1)
        # Load a V block
        v_ptrs = V_block_ptr + (start_n + offs_n)[:, None] * stride_vn + offs_k[None, :] * stride_vk
        v = tl.load(v_ptrs, mask=(start_n + offs_n)[:, None] < N_CTX)
        # Update the (unnormalized) output accumulator
        p = tl.exp(qk - m_new[:, None])
        acc = alpha[:, None] * acc + tl.dot(p, v)
        # Carry the statistics forward
        m_i = m_new
        l_i = l_new
    # Normalize and store
    acc = acc / l_i[:, None]
    o_ptrs = O_block_ptr + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    tl.store(o_ptrs, acc, mask=offs_m[:, None] < N_CTX)
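A host-side driver for the kernel above might look like the following sketch (no causal mask, fp32 inputs in (Z, H, N_CTX, head_dim) layout, and the head dimension assumed to be a power of two such as 64 so it can serve as BLOCK_K):

def flash_attention(q, k, v, BLOCK_M=64, BLOCK_N=64):
    Z, H, N_CTX, D = q.shape                  # D plays the role of BLOCK_K
    out = torch.empty_like(q)
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
    flash_attention_kernel[grid](
        q, k, v, out,
        *q.stride(), *k.stride(), *v.stride(), *out.stride(),
        Z, H, N_CTX,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=D,
    )
    return out

# Reference check against eager attention:
# p = torch.softmax((q @ k.transpose(-2, -1)) / D ** 0.5, dim=-1)
# torch.testing.assert_close(flash_attention(q, k, v), p @ v, atol=1e-3, rtol=1e-3)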
2. Fusion Decision Algorithms
2.1 Fusion Graph Analysis
# Core algorithm for fusion decisions
class FusionAnalyzer:
    """
    Analyzes a computation graph and decides which operations can be fused.
    """
    def __init__(self, graph):
        self.graph = graph

    def find_fusion_groups(self):
        """
        Find all groups of operations that can be fused together.
        """
        fusion_groups = []
        visited = set()
        for node in self.graph.nodes:
            if node in visited:
                continue
            group = self._grow_fusion_group(node, visited)
            if len(group) > 1:
                fusion_groups.append(group)
        return fusion_groups

    def _grow_fusion_group(self, start_node, visited):
        """
        Grow a fusion group from a seed node using a greedy strategy.
        """
        group = [start_node]
        visited.add(start_node)
        frontier = list(start_node.users)
        while frontier:
            candidate = frontier.pop(0)
            if candidate in visited:
                continue
            if self._can_fuse(group, candidate):
                group.append(candidate)
                visited.add(candidate)
                frontier.extend(candidate.users)
        return group

    def _can_fuse(self, group, candidate):
        """
        Check whether candidate can be fused into group.
        """
        # Rule 1: element-wise operations can be fused
        if self._is_elementwise(candidate):
            # Check shape compatibility
            if all(self._shapes_compatible(n, candidate) for n in group):
                return True
        # Rule 2: broadcasts can be fused
        if candidate.op_type == 'Broadcast':
            return True
        # Rule 3: a single-output producer can be fused into its consumer
        if len(candidate.outputs) == 1:
            for user in candidate.users:
                if user in group:
                    return True
        # Rule 4: a reduce can absorb an element-wise prefix
        if candidate.op_type == 'Reduce':
            return all(self._is_elementwise(n) for n in group)
        return False

    def _is_elementwise(self, node):
        """Check whether a node is an element-wise operation."""
        elementwise_ops = {
            'Add', 'Sub', 'Mul', 'Div',
            'Exp', 'Log', 'Sin', 'Cos', 'Tanh',
            'Relu', 'Sigmoid', 'Gelu',
            'Cast', 'Neg', 'Abs'
        }
        return node.op_type in elementwise_ops

    def _shapes_compatible(self, node1, node2):
        """Check whether two nodes' output shapes are compatible for fusion."""
        shape1 = node1.output_shape
        shape2 = node2.output_shape
        # Identical shapes
        if shape1 == shape2:
            return True
        # Broadcastable shapes
        if self._can_broadcast(shape1, shape2):
            return True
        return False

class FusionCostModel:
    """
    Fusion cost model: estimates whether a fusion is profitable.
    """
    def __init__(self, target_device):
        self.device = target_device
        # Device parameters (illustrative values, roughly an H100 SXM)
        self.memory_bandwidth = 3.35e12      # ~3.35 TB/s HBM3
        self.compute_throughput = 1e15       # ~1 PFLOPS dense FP16 Tensor Core
        self.kernel_launch_overhead = 5e-6   # ~5 us per launch

    def should_fuse(self, group):
        """
        Decide whether fusing the group is beneficial.
        """
        unfused_cost = self._estimate_unfused_cost(group)
        fused_cost = self._estimate_fused_cost(group)
        return fused_cost < unfused_cost

    def _estimate_unfused_cost(self, group):
        """Estimate execution time without fusion."""
        total_cost = 0
        for node in group:
            # Kernel launch overhead
            total_cost += self.kernel_launch_overhead
            # Memory traffic time
            memory_bytes = self._compute_memory_bytes(node)
            total_cost += memory_bytes / self.memory_bandwidth
            # Compute time
            flops = self._compute_flops(node)
            total_cost += flops / self.compute_throughput
        return total_cost

    def _estimate_fused_cost(self, group):
        """Estimate execution time with fusion."""
        # A single kernel launch
        total_cost = self.kernel_launch_overhead
        # Only the group's external inputs and outputs touch memory
        input_bytes = sum(self._input_bytes(n) for n in group if self._is_input(n, group))
        output_bytes = sum(self._output_bytes(n) for n in group if self._is_output(n, group))
        total_cost += (input_bytes + output_bytes) / self.memory_bandwidth
        # Compute time (often with better instruction-level parallelism)
        total_flops = sum(self._compute_flops(n) for n in group)
        total_cost += total_flops / self.compute_throughput
        return total_cost
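The cost model above leaves the per-node accounting hooks abstract. A minimal sketch that fills them in for a hypothetical element-wise node type and queries the decision (ToyNode and ToyCostModel are illustrations, not part of any real compiler API):

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyNode:
    numel: int             # number of fp32 elements
    is_graph_input: bool   # reads an external tensor
    is_graph_output: bool  # writes an external tensor

class ToyCostModel(FusionCostModel):
    def _compute_memory_bytes(self, n):   # unfused: read + write one tensor
        return 2 * 4 * n.numel
    def _compute_flops(self, n):
        return n.numel                    # 1 flop per element
    def _input_bytes(self, n):
        return 4 * n.numel
    def _output_bytes(self, n):
        return 4 * n.numel
    def _is_input(self, n, group):
        return n.is_graph_input
    def _is_output(self, n, group):
        return n.is_graph_output

group = [ToyNode(1 << 20, True, False),   # reads the external input
         ToyNode(1 << 20, False, False),  # purely internal
         ToyNode(1 << 20, False, True)]   # writes the external output
print(ToyCostModel('cuda').should_fuse(group))  # expected: True (3 launches + 6 transfers vs 1 + 2)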
2.2 Loop Fusion Analysis
# Algebraic analysis for loop fusion
class LoopFusionAnalyzer:
    """
    Decides whether two loops can be fused.
    """
    def can_fuse_loops(self, loop1, loop2):
        """
        Check whether loop1 and loop2 can be fused.
        """
        # Condition 1: identical loop bounds
        if not self._same_bounds(loop1, loop2):
            return False
        # Condition 2: no loop-carried dependence
        if self._has_loop_carried_dependence(loop1, loop2):
            return False
        # Condition 3: the data dependences permit fusion
        if not self._data_dependence_allows(loop1, loop2):
            return False
        return True

    def _same_bounds(self, loop1, loop2):
        """Check whether the loop bounds match."""
        return (loop1.lower_bound == loop2.lower_bound and
                loop1.upper_bound == loop2.upper_bound and
                loop1.step == loop2.step)

    def _has_loop_carried_dependence(self, loop1, loop2):
        """
        Check for loop-carried dependences,
        i.e. whether a write in iteration i is read by some iteration j > i.
        """
        writes1 = self._get_write_accesses(loop1)
        reads2 = self._get_read_accesses(loop2)
        for write in writes1:
            for read in reads2:
                if self._may_alias(write, read):
                    # Check the dependence distance
                    distance = self._compute_dependence_distance(write, read)
                    if distance != 0:  # a non-zero distance means a loop-carried dependence
                        return True
        return False

    def _data_dependence_allows(self, loop1, loop2):
        """
        Check whether the dependence type permits fusion.
        """
        # Classify the dependence
        dep_type = self._analyze_dependence(loop1, loop2)
        # RAW (Read After Write)  - fusion allowed (order is preserved within an iteration)
        # WAR (Write After Read)  - fusion allowed
        # WAW (Write After Write) - ordering must be preserved
        # RAR (Read After Read)   - always allowed
        return dep_type in ['RAW', 'WAR', 'RAR', 'NONE']

class TileFusion:
    """
    Tiled fusion: fuse several loops and then apply tiling.
    """
    def fuse_and_tile(self, loops, tile_sizes):
        """
        Fuse the loops and tile the result.
        """
        # 1. Fuse the loops
        fused_loop = self._fuse_loops(loops)
        # 2. Apply tiling
        tiled_loop = self._tile_loop(fused_loop, tile_sizes)
        return tiled_loop

    def _fuse_loops(self, loops):
        """
        Merge several loops into one.
        """
        # Concatenate the loop bodies
        fused_body = []
        for loop in loops:
            fused_body.extend(loop.body)
        return Loop(
            bounds=loops[0].bounds,
            body=fused_body
        )

    def _tile_loop(self, loop, tile_sizes):
        """
        Apply tiling to a loop nest.
        """
        # Outer (tile) loops and inner (intra-tile) loops
        outer_loops = []
        inner_loops = []
        for dim, size in enumerate(tile_sizes):
            outer = Loop(
                var=f"tile_{dim}",
                bounds=(0, loop.bounds[dim], size)
            )
            inner = Loop(
                var=f"inner_{dim}",
                bounds=(0, size, 1)
            )
            outer_loops.append(outer)
            inner_loops.append(inner)
        # Compose the loop nest
        return self._compose_loops(outer_loops + inner_loops, loop.body)
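For intuition, here is what the transformation looks like on concrete code (a minimal pure-Python sketch: fusion removes the full-sized intermediate array `b`, and tiling then improves locality):

# Before fusion: two loops with identical bounds and only a distance-0 RAW
# dependence on b, so they satisfy the conditions checked above.
def unfused(a, n):
    b = [0.0] * n
    c = [0.0] * n
    for i in range(n):            # loop 1
        b[i] = a[i] * 2.0
    for i in range(n):            # loop 2
        c[i] = b[i] + 1.0
    return c

# After fusion + tiling: one pass over the data, b shrinks to a scalar.
def fused_tiled(a, n, tile=256):
    c = [0.0] * n
    for t in range(0, n, tile):                # outer tile loop
        for i in range(t, min(t + tile, n)):   # inner intra-tile loop
            b_i = a[i] * 2.0                   # former loop-1 body
            c[i] = b_i + 1.0                   # former loop-2 body
    return c

# assert fused_tiled(list(range(1000)), 1000) == unfused(list(range(1000)), 1000)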
3. Kernel Optimization Techniques
3.1 Memory Access Optimization
# Memory access optimization techniques
@triton.jit
def optimized_memory_access_kernel(
    input_ptr, output_ptr,
    N, stride,
    BLOCK_SIZE: tl.constexpr
):
    """
    Demonstrates memory access patterns.
    """
    pid = tl.program_id(0)
    # Technique 1: coalesced access
    # Adjacent threads access adjacent memory addresses
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N
    # Good: contiguous accesses -> coalesced 128-byte transactions
    data = tl.load(input_ptr + offset, mask=mask)
    # Bad: strided accesses (avoid) -> many small transactions
    # data = tl.load(input_ptr + offset * stride, mask=mask)
    # Technique 2: vectorized loads
    # Use wider types (e.g. float4 instead of float) to move more data per instruction
    # Technique 3: shared-memory caching
    # Stage data that is reused multiple times in shared memory first
    tl.store(output_ptr + offset, data, mask=mask)

@triton.jit
def tiled_matmul_with_shared_memory(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    Tiled matrix multiply that stages tiles through shared memory.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Tile position
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Tiled K loop
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # Load tiles from global memory into shared memory / registers
        # (Triton manages shared memory automatically)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :],
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :],
                    mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        # Synchronization: Triton inserts the necessary barriers automatically
        # Compute
        acc += tl.dot(a, b)
    # Store the result
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
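To see the effect of coalescing, a small experiment can compare a contiguous read against a strided one. This is a sketch; `copy_kernel` and `bench` are hypothetical helpers, and with stride=1 the loads coalesce while with stride=32 each lane of a warp touches a different cache line:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, stride, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # stride = 1  -> adjacent lanes read adjacent addresses (coalesced)
    # stride > 1  -> lanes scatter across cache lines (uncoalesced)
    x = tl.load(src_ptr + offs * stride, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)

def bench(stride, n=1 << 22, BLOCK=1024):
    src = torch.randn(n * stride, device='cuda')
    dst = torch.empty(n, device='cuda')
    grid = (triton.cdiv(n, BLOCK),)
    ms = triton.testing.do_bench(
        lambda: copy_kernel[grid](src, dst, n, stride, BLOCK=BLOCK))
    print(f"stride={stride:3d}: {ms:.3f} ms")

# bench(1)    # coalesced
# bench(32)   # strided; typically several times slower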
3.2 Compute Optimization
# Compute-side optimizations
@triton.jit
def optimized_compute_kernel(
    x_ptr, output_ptr, N,
    BLOCK_SIZE: tl.constexpr
):
    """
    Demonstrates compute optimization techniques.
    """
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offset, mask=offset < N)
    # Technique 1: fast math functions
    # Trade a little precision for speed
    y = tl.math.fast_expf(x)  # faster than the standard exp
    # Technique 2: strength reduction
    # Replace expensive operations with cheaper ones:
    #   x * 2        -> x + x
    #   x / 2        -> x * 0.5           (multiplication is cheaper than division)
    #   x / constant -> x * (1 / constant)
    # Technique 3: use special hardware units
    #   Tensor Cores: tl.dot()
    #   SFU: special functions such as sin, cos, rsqrt
    # Technique 4: instruction-level parallelism (ILP)
    # Expose independent instructions
    y0 = tl.math.sin(x)
    y1 = tl.math.cos(x)
    y2 = y0 * y1  # y0, y1 (and y) are mutually independent, exposing ILP
    tl.store(output_ptr + offset, y2, mask=offset < N)

# Tensor Core optimization
@triton.jit
def tensor_core_gemm(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    GEMM that targets Tensor Cores.
    For simplicity, M/N/K are assumed to be multiples of the block sizes (no masks).
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # tl.dot maps onto Tensor Cores when:
    #   - BLOCK_M, BLOCK_N, BLOCK_K are multiples of 16
    #   - the inputs are fp16 or bf16
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # Load tiles and convert to fp16
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :]).to(tl.float16)
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :]).to(tl.float16)
        # tl.dot uses Tensor Cores for fp16 inputs
        acc += tl.dot(a, b, out_dtype=tl.float32)
    # Store the fp32 result
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    tl.store(c_ptrs, acc)
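A host-side sketch for launching the Tensor Core GEMM above (it assumes row-major contiguous inputs and that M, N, K are multiples of the block sizes, since the kernel omits boundary masks):

def tc_matmul(a: torch.Tensor, b: torch.Tensor,
              BLOCK_M=64, BLOCK_N=64, BLOCK_K=32) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = (M // BLOCK_M, N // BLOCK_N)
    tensor_core_gemm[grid](a, b, c, M, N, K,
                           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c

# a = torch.randn(512, 256, device='cuda', dtype=torch.float16)
# b = torch.randn(256, 1024, device='cuda', dtype=torch.float16)
# torch.testing.assert_close(tc_matmul(a, b), (a @ b).float(), atol=1e-1, rtol=1e-2)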
3.3 Parallelization Optimization
# Optimizing the parallelization strategy
class ParallelizationOptimizer:
    """
    Chooses a kernel's parallelization configuration.
    """
    def optimize_grid_config(self, problem_size, device_info):
        """
        Choose the grid and block configuration.
        """
        # Device limits
        max_threads_per_block = device_info['max_threads_per_block']  # e.g. 1024
        max_blocks_per_sm = device_info['max_blocks_per_sm']          # e.g. 32
        num_sms = device_info['num_sms']                              # e.g. 132 for H100
        # Goal: maximize SM occupancy
        # Compute the block size
        problem_elements = problem_size[0] * problem_size[1]
        # Heuristic: give each block enough work
        elements_per_block = max(256, problem_elements // (num_sms * 4))
        # Round to a multiple of the warp size
        block_size = min(max_threads_per_block,
                         (elements_per_block + 31) // 32 * 32)
        # Compute the grid size
        num_blocks = (problem_elements + block_size - 1) // block_size
        # 2D grid configuration (when applicable)
        if len(problem_size) == 2:
            grid_x = (problem_size[0] + 15) // 16
            grid_y = (problem_size[1] + 15) // 16
            block_x = 16
            block_y = 16
            return ((grid_x, grid_y), (block_x, block_y))
        return ((num_blocks,), (block_size,))

    def balance_work_distribution(self, num_blocks, num_sms):
        """
        Balance the work distribution to avoid the tail effect.
        """
        # Make the number of blocks a multiple of the number of SMs
        # so that the last wave does not leave SMs idle.
        if num_blocks % num_sms != 0:
            # Round up to the next multiple
            num_blocks = ((num_blocks + num_sms - 1) // num_sms) * num_sms
        return num_blocks

# Dynamic parallelism over irregular work
@triton.jit
def dynamic_parallel_kernel(
    data_ptr, output_ptr,
    sizes_ptr,    # size of each task
    offsets_ptr,  # offset of each task
    num_tasks,
    BLOCK_SIZE: tl.constexpr
):
    """
    Handle an irregular parallel workload: one program per variable-sized task.
    """
    task_id = tl.program_id(0)
    if task_id >= num_tasks:
        return
    # Load the task descriptor
    size = tl.load(sizes_ptr + task_id)
    offset = tl.load(offsets_ptr + task_id)
    # Process the task in BLOCK_SIZE chunks
    for i in range(0, size, BLOCK_SIZE):
        idx = tl.arange(0, BLOCK_SIZE)
        mask = i + idx < size
        data = tl.load(data_ptr + offset + i + idx, mask=mask)
        result = process(data)  # placeholder for the per-element computation
        tl.store(output_ptr + offset + i + idx, result, mask=mask)
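A usage sketch for ParallelizationOptimizer. The device_info values are hard-coded assumptions here; in practice the SM count can be queried with torch.cuda.get_device_properties(0).multi_processor_count:

optimizer = ParallelizationOptimizer()

device_info = {
    'max_threads_per_block': 1024,  # assumed; CUDA limit on current GPUs
    'max_blocks_per_sm': 32,        # assumed
    'num_sms': 132,                 # assumed (H100 SXM)
}

# 2D problem: a 4096 x 4096 element-wise op -> 16x16 thread blocks
grid, block = optimizer.optimize_grid_config((4096, 4096), device_info)
print("grid:", grid, "block:", block)          # grid: (256, 256) block: (16, 16)

# Round the total block count up to a full wave across all SMs
num_blocks = optimizer.balance_work_distribution(grid[0] * grid[1],
                                                 device_info['num_sms'])
print("blocks after balancing:", num_blocks)   # a multiple of num_sms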
4. Auto-tuning
4.1 Defining the Search Space
# Define the search space for kernel tuning
import random

class KernelTuningSpace:
    """
    Kernel tuning search space.
    """
    def __init__(self):
        self.params = {}

    def add_param(self, name, values, condition=None):
        """
        Add a tunable parameter.
        """
        self.params[name] = {
            'values': values,
            'condition': condition
        }

    def sample(self):
        """
        Sample one configuration.
        """
        config = {}
        for name, spec in self.params.items():
            if spec['condition'] is None or spec['condition'](config):
                config[name] = random.choice(spec['values'])
        return config

    def all_configs(self):
        """
        Enumerate all valid configurations.
        """
        import itertools
        param_names = list(self.params.keys())
        param_values = [self.params[n]['values'] for n in param_names]
        for values in itertools.product(*param_values):
            config = dict(zip(param_names, values))
            if self._is_valid(config):
                yield config

    def _is_valid(self, config):
        """
        Check whether a configuration satisfies all constraints.
        """
        for name, spec in self.params.items():
            if spec['condition'] is not None:
                if not spec['condition'](config):
                    return False
        return True

# Example: tuning space for GEMM
def create_gemm_tuning_space(M, N, K):
    space = KernelTuningSpace()
    # Block sizes
    space.add_param('BLOCK_M', [32, 64, 128, 256])
    space.add_param('BLOCK_N', [32, 64, 128, 256])
    space.add_param('BLOCK_K', [16, 32, 64])
    # Software pipelining stages
    space.add_param('num_stages', [2, 3, 4, 5])
    # Warp configuration
    space.add_param('num_warps', [4, 8])
    # Constraint: the A and B tiles must fit in shared memory
    def shared_memory_constraint(config):
        shared_bytes = (
            config['BLOCK_M'] * config['BLOCK_K'] +
            config['BLOCK_K'] * config['BLOCK_N']
        ) * 2  # fp16
        return shared_bytes <= 48 * 1024  # 48KB limit
    space.add_param('_valid', [True],
                    condition=shared_memory_constraint)
    return space
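Usage sketch: enumerate the space and count how many configurations survive the shared-memory constraint (the count depends only on the value lists above, not on M/N/K):

space = create_gemm_tuning_space(4096, 4096, 4096)

valid = list(space.all_configs())
print(f"{len(valid)} valid configurations")   # out of 4*4*3*4*2 = 384 raw combinations
print("one random sample:", space.sample())

# Every surviving configuration respects the 48KB shared-memory budget:
# (BLOCK_M*BLOCK_K + BLOCK_K*BLOCK_N) * 2 bytes <= 48 * 1024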
4.2 Auto-tuner Implementation
# Auto-tuner
class AutoTuner:
    """
    Kernel auto-tuner.
    """
    def __init__(self, kernel_fn, space, key_fn):
        # kernel_fn is assumed to be a callable that launches the kernel
        # with the supplied meta-parameters (e.g. a host-side wrapper).
        self.kernel_fn = kernel_fn
        self.space = space
        self.key_fn = key_fn
        self.cache = {}

    def tune(self, *args, n_trials=100):
        """
        Tune the kernel for the given arguments.
        """
        key = self.key_fn(*args)
        if key in self.cache:
            return self.cache[key]
        best_config = None
        best_time = float('inf')
        for trial in range(n_trials):
            config = self.space.sample()
            try:
                elapsed = self._measure(config, *args)
                if elapsed < best_time:
                    best_time = elapsed
                    best_config = config
                    print(f"Trial {trial}: {elapsed:.4f}ms, config={config}")
            except Exception:
                # Invalid configuration, skip it
                continue
        self.cache[key] = best_config
        return best_config

    def _measure(self, config, *args, n_warmup=5, n_repeat=20):
        """
        Measure the kernel's execution time for one configuration.
        """
        import torch
        # Warm-up
        for _ in range(n_warmup):
            self.kernel_fn(*args, **config)
        torch.cuda.synchronize()
        # Timed runs
        import time
        start = time.time()
        for _ in range(n_repeat):
            self.kernel_fn(*args, **config)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        return elapsed / n_repeat * 1000  # ms

# Usage with Triton's built-in autotuner
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        # more configurations...
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_autotune(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # kernel body omitted
    pass
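When the decorated kernel is launched, Triton benchmarks every listed config the first time a new (M, N, K) key is seen and caches the winner. The caller does not pass the block sizes; the grid lambda reads them from the chosen config via `meta`. A minimal launch sketch (assuming the kernel body above were filled in):

def autotuned_matmul(a, b):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # BLOCK_M / BLOCK_N come from the autotuner's winning config
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)
    matmul_kernel_autotune[grid](a, b, c, M, N, K)
    return c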
5. Frequently Asked Interview Questions
Q1: What is operator fusion, and what are its benefits?
Key points:
- Definition: merge multiple operators into a single kernel
- Benefits:
  - fewer kernel launches
  - fewer reads/writes of intermediate results
  - better data locality
  - better instruction-level parallelism
- Typical speedup: 2-5x, depending on the compute-to-memory-access ratio
Q2: How does Flash Attention achieve fusion?
Key points:
- Blocked computation: Q, K and V are processed tile by tile
- Online softmax: softmax statistics are updated as tiles are processed
- Memory efficiency: O(N) memory instead of O(N^2)
- Single pass: avoids reading Q, K and V multiple times
Q3: How do you decide whether two operators can be fused?
Key points:
- Shape compatibility: identical or broadcastable output shapes
- Dependences: no loop-carried dependences
- Operation type: element-wise operations are the easiest to fuse
- Hardware limits: enough registers and shared memory
Q4: What are Triton's advantages over CUDA?
Key points:
- Ease of use: Python-like syntax, much easier to write
- Automatic optimization: block sizes and memory layout handled automatically
- Portability: multiple backends supported
- Fast iteration: quick compilation, easy to debug
Q5: How do you optimize a GPU kernel's memory accesses?
Key points:
- Coalesced accesses: adjacent threads access adjacent addresses
- Shared memory: cache data that is reused
- Vectorized loads: use wide types such as float4
- Prefetching: load the next tile asynchronously
- Avoid bank conflicts: tune shared-memory access patterns
6. Learning Resources
Official Documentation
Recommended Papers
- "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
- "Automatic Kernel Generation for Deep Learning"
- "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations"
Open-source Projects
- Triton - GPU programming language and compiler
- FlashAttention - efficient attention implementation
- xFormers - efficient Transformer building blocks