05-Introduction to Triton Programming
Overview
Triton is a language and compiler developed by OpenAI for writing efficient GPU programs. Compared with CUDA, Triton offers a higher-level abstraction: developers can focus on algorithm design while the compiler takes care of low-level details such as memory management and parallelization.
Triton Architecture
Design Philosophy
┌─────────────────────────────────────────────────────────────────────────┐
│                        Triton Design Philosophy                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Challenges of CUDA programming:                                         │
│  ├── Shared memory must be managed by hand                               │
│  ├── Bank conflicts must be considered                                   │
│  ├── Memory coalescing must be handled explicitly                        │
│  ├── Synchronization must be managed                                     │
│  └── Block/grid sizes must be tuned                                      │
│                                                                          │
│  Triton's answer:                                                        │
│  ├── Block-level programming model                                       │
│  │     └── Each kernel instance processes one tile                       │
│  ├── Automatic memory management                                         │
│  │     └── The compiler decides when to use shared memory / registers    │
│  ├── Automatic parallelization                                           │
│  │     └── The compiler handles warp-level optimization                  │
│  └── Automatic tuning                                                    │
│        └── The best configuration can be searched at JIT compile time    │
│                                                                          │
│  Abstraction levels compared:                                            │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                                                                 │    │
│  │  High ──► NumPy/PyTorch ───► easy to use, but can be slow       │    │
│  │   │                                                             │    │
│  │   │       Triton ──────────► balanced: fairly easy, fast        │    │
│  │   │                                                             │    │
│  │  Low  ──► CUDA C++ ────────► maximum control, but complex       │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
Compilation Pipeline
┌─────────────────────────────────────────────────────────────────────────┐
│                       Triton Compilation Pipeline                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   Python + @triton.jit                                                   │
│          │                                                               │
│          ▼                                                               │
│   ┌─────────────────────┐                                                │
│   │  Triton-IR (MLIR)   │  high-level intermediate representation        │
│   └──────────┬──────────┘                                                │
│              │                                                           │
│              ▼                                                           │
│   ┌─────────────────────┐                                                │
│   │ Optimization passes │                                                │
│   │ ├── tile layout     │                                                │
│   │ ├── memory hierarchy│                                                │
│   │ └── parallelism     │                                                │
│   └──────────┬──────────┘                                                │
│              │                                                           │
│              ▼                                                           │
│   ┌─────────────────────┐                                                │
│   │       LLVM IR       │                                                │
│   └──────────┬──────────┘                                                │
│              │                                                           │
│              ▼                                                           │
│   ┌─────────────────────┐                                                │
│   │     PTX / CUBIN     │  executable GPU code                           │
│   └─────────────────────┘                                                │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
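This pipeline is normally invisible to the user. As a rough, version-dependent sketch (the launch-handle and its `.asm` dictionary with keys such as 'ttir', 'llir', 'ptx' are an assumption that holds for some recent Triton 2.x releases, not a stable public API), the artifacts can be peeked at like this:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

x = torch.arange(1024, device='cuda', dtype=torch.float32)
y = torch.empty_like(x)
# On some versions the launch returns a handle to the compiled kernel.
handle = copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK_SIZE=256)
# If the handle exposes the generated artifacts, list them; otherwise fall back
# to inspecting the on-disk cache (location configurable via TRITON_CACHE_DIR).
print(handle.asm.keys() if hasattr(handle, 'asm') else 'inspect the cache directory instead')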
Basic Syntax
Your First Triton Kernel
import triton
import triton.language as tl
import torch

# Vector addition
@triton.jit
def add_kernel(
    x_ptr,                    # pointer to input tensor X
    y_ptr,                    # pointer to input tensor Y
    output_ptr,               # pointer to the output tensor
    n_elements,               # number of elements
    BLOCK_SIZE: tl.constexpr, # compile-time constant
):
    # Get the ID of the current program instance (similar to blockIdx)
    pid = tl.program_id(axis=0)

    # Compute the range of elements this instance is responsible for
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Build a mask to handle the boundary
    mask = offsets < n_elements

    # Load the data
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute
    output = x + y

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    # Compute the grid size
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Launch the kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

# Test
x = torch.rand(1000000, device='cuda')
y = torch.rand(1000000, device='cuda')
output = add(x, y)

# Verify
assert torch.allclose(output, x + y)
Core Concepts
import triton
import triton.language as tl

@triton.jit
def concepts_demo(
    ptr,
    n,
    BLOCK_SIZE: tl.constexpr,
):
    """Demonstrates the core Triton concepts."""
    # 1. Program ID - index of the current kernel instance
    #    axis=0 is the default; multiple axes are possible
    pid = tl.program_id(axis=0)  # similar to CUDA's blockIdx.x

    # 2. arange - creates a sequence of consecutive integers
    #    This is one of Triton's core primitives
    offsets = tl.arange(0, BLOCK_SIZE)  # [0, 1, 2, ..., BLOCK_SIZE-1]

    # 3. Pointer arithmetic
    #    ptr + offsets produces a block of pointers
    ptrs = ptr + pid * BLOCK_SIZE + offsets

    # 4. Mask - handles boundary conditions
    mask = (pid * BLOCK_SIZE + offsets) < n

    # 5. Load - reads from global memory
    #    The mask prevents out-of-bounds accesses
    #    `other` gives the default value where the mask is False
    data = tl.load(ptrs, mask=mask, other=0.0)

    # 6. Compute - most arithmetic operations are supported
    result = data * 2.0 + 1.0
    result = tl.exp(result)
    result = tl.maximum(result, 0.0)

    # 7. Store - writes back to global memory
    tl.store(ptrs, result, mask=mask)

    # 8. Atomic operations (a runnable sketch combining 8 and 9 follows this kernel)
    #    tl.atomic_add(ptr, val)
    #    tl.atomic_max(ptr, val)
    #    tl.atomic_min(ptr, val)

    # 9. Reduce operations
    sum_val = tl.sum(data, axis=0)
    max_val = tl.max(data, axis=0)

    # 10. constexpr - compile-time constants
    #     BLOCK_SIZE is fixed at compile time and can be used for shapes, etc.
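The atomic and reduce primitives above are only sketched in comments. Below is a minimal hedged sketch (kernel and wrapper names are illustrative, not from the original text) that combines them: each program instance reduces its own block with tl.sum, then accumulates the partial result into one global scalar with tl.atomic_add.

@triton.jit
def block_sum_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # In-block reduction, then a cross-block atomic accumulation.
    partial = tl.sum(x, axis=0)
    tl.atomic_add(out_ptr, partial)

def total_sum(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(1, device=x.device, dtype=torch.float32)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
    block_sum_kernel[grid](x, out, x.numel(), BLOCK_SIZE=BLOCK_SIZE)
    return out

Note that the accumulation order across blocks is nondeterministic, so the result can differ from torch.sum by floating-point rounding.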
# 2D grid example
@triton.jit
def kernel_2d(
    ptr,
    M, N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # 2D program ID
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Offsets along each axis
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # 2D indices (row-major layout)
    offs = offs_m[:, None] * N + offs_n[None, :]

    # 2D mask
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    data = tl.load(ptr + offs, mask=mask)
    # ...
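The 2D kernel above stops short of doing anything with the loaded tile. As a small self-contained sketch (names are illustrative), the same indexing pattern can scale every element of a contiguous M×N matrix in place:

@triton.jit
def scale_2d_kernel(ptr, M, N, alpha,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row-major flat offsets and the matching 2D boundary mask
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(ptr + offs, mask=mask)
    tl.store(ptr + offs, tile * alpha, mask=mask)

def scale_2d(x: torch.Tensor, alpha: float = 2.0) -> torch.Tensor:
    # Assumes a contiguous row-major [M, N] tensor.
    M, N = x.shape
    BLOCK_M, BLOCK_N = 32, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    scale_2d_kernel[grid](x, M, N, alpha, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return x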
Matrix Multiplication
Basic GEMM
@triton.jit
def matmul_kernel(
    # Pointers
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    C = A @ B
    A: [M, K], B: [K, N], C: [M, N]
    """
    # Program IDs
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Output region this block is responsible for
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Pointers into A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Iterate over the K dimension
    for k in range(0, K, BLOCK_K):
        # Load a tile of A
        a_mask = (offs_m[:, None] < M) & ((k + offs_k[None, :]) < K)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # Load a tile of B
        b_mask = ((k + offs_k[:, None]) < K) & (offs_n[None, :] < N)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        # Multiply-accumulate
        acc += tl.dot(a, b)

        # Advance the pointers to the next K block
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Write the result
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Configuration
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c
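A quick sanity check against PyTorch (a hedged sketch; the shapes, fp16 dtype, and tolerances are illustrative assumptions):

a = torch.randn(512, 384, device='cuda', dtype=torch.float16)
b = torch.randn(384, 256, device='cuda', dtype=torch.float16)
c_triton = matmul(a, b)
c_torch = a @ b
# Loose tolerances because the fp32 accumulator is downcast to fp16 on store.
assert torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2)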
Optimized GEMM (with Autotuning)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],  # the configuration is re-selected when these change
)
@triton.jit
def matmul_kernel_autotune(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # L2 cache optimization
):
    """
    Uses grouped program ordering to improve the L2 cache hit rate.
    """
    pid = tl.program_id(axis=0)

    # Grid dimensions
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n

    # Grouped reordering
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Pointers into A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Main loop
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_K)
        b_mask = (offs_k[:, None] < K - k * BLOCK_K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Store
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

def matmul_autotune(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
    )
    matmul_kernel_autotune[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
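A hedged usage sketch: because @triton.autotune injects BLOCK_M/BLOCK_N/BLOCK_K/GROUP_M itself, the wrapper is called like an ordinary function, and triton.testing.do_bench can compare it against cuBLAS (shapes and dtype below are illustrative assumptions):

import triton.testing

a = torch.randn(2048, 2048, device='cuda', dtype=torch.float16)
b = torch.randn(2048, 2048, device='cuda', dtype=torch.float16)

ms_triton = triton.testing.do_bench(lambda: matmul_autotune(a, b))
ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))

# 2 * M * N * K floating-point operations per GEMM
tflops = lambda ms: 2 * 2048**3 * 1e-12 / (ms * 1e-3)
print(f"Triton: {tflops(ms_triton):.1f} TFLOPS, cuBLAS: {tflops(ms_torch):.1f} TFLOPS")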
Softmax Implementation
Fused Softmax
@triton.jit
def softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise softmax.
    Each program handles one row.
    """
    # Current row
    row_idx = tl.program_id(0)

    # Start of this row
    row_start = row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load one row
    input_ptrs = input_ptr + row_start + col_offsets
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

    # Row max (for numerical stability)
    row_max = tl.max(row, axis=0)

    # exp(x - max)
    numerator = tl.exp(row - row_max)

    # Sum
    denominator = tl.sum(numerator, axis=0)

    # Normalize
    softmax_output = numerator / denominator

    # Store the result
    output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)

def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # Choose the block size (must be a power of 2). This single-pass kernel
    # loads an entire row at once, so it requires BLOCK_SIZE >= n_cols;
    # for longer rows use softmax_kernel_large below.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # One program per row
    grid = (n_rows,)
    softmax_kernel[grid](
        x, output,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
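A quick check against PyTorch (a minimal sketch; shape and tolerances are assumptions):

x = torch.randn(128, 1000, device='cuda', dtype=torch.float32)
y_triton = softmax(x)
y_torch = torch.softmax(x, dim=-1)
assert torch.allclose(y_triton, y_torch, atol=1e-6, rtol=1e-5)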
# Tiled softmax for rows of arbitrary length
@triton.jit
def softmax_kernel_large(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Softmax for rows of arbitrary length,
    using the online softmax algorithm.
    """
    row_idx = tl.program_id(0)
    row_start = row_idx * input_row_stride

    # Online softmax: two passes over the row.
    # Pass 1: track the running max and the running sum of exponentials.
    m_i = -float('inf')  # running max
    l_i = 0.0            # running sum of exp

    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols

        # Load one block
        input_ptrs = input_ptr + row_start + col_offsets
        x = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # Update the max
        m_ij = tl.max(x, axis=0)
        m_new = tl.maximum(m_i, m_ij)

        # Update the sum: l_new = l_old * exp(m_old - m_new) + sum(exp(x - m_new))
        l_i = l_i * tl.exp(m_i - m_new) + tl.sum(tl.exp(x - m_new), axis=0)
        m_i = m_new

    # Pass 2: compute the softmax and write it back.
    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols

        input_ptrs = input_ptr + row_start + col_offsets
        x = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # softmax = exp(x - m) / l
        softmax_out = tl.exp(x - m_i) / l_i

        output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
        tl.store(output_ptrs, softmax_out, mask=mask)
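The tiled kernel above has no launcher in the text. A minimal hedged wrapper (the function name and default block size are illustrative) mirrors the one for the fused version:

def softmax_large(x: torch.Tensor, BLOCK_SIZE: int = 1024) -> torch.Tensor:
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    # One program per row; each program loops over its row in BLOCK_SIZE chunks.
    softmax_kernel_large[(n_rows,)](
        x, output,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output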
LayerNorm Implementation
@triton.jit
def layernorm_kernel(
    x_ptr,
    y_ptr,
    w_ptr,      # gamma
    b_ptr,      # beta
    mean_ptr,
    rstd_ptr,
    stride,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    y = (x - mean) / sqrt(var + eps) * gamma + beta
    """
    row = tl.program_id(0)
    x_ptr += row * stride
    y_ptr += row * stride

    # Load the row
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Mean
    mean = tl.sum(x, axis=0) / N

    # Variance (zero out the padded lanes so they do not contribute)
    xmean = tl.where(mask, x - mean, 0.0)
    var = tl.sum(xmean * xmean, axis=0) / N

    # Normalize
    rstd = 1.0 / tl.sqrt(var + eps)
    x_hat = xmean * rstd

    # Load gamma and beta
    w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Scale and shift
    y = x_hat * w + b

    # Store the result
    tl.store(y_ptr + cols, y, mask=mask)

    # Optionally store mean and rstd (needed for the backward pass)
    if mean_ptr is not None:
        tl.store(mean_ptr + row, mean)
    if rstd_ptr is not None:
        tl.store(rstd_ptr + row, rstd)

def layernorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5
) -> torch.Tensor:
    shape = x.shape
    x = x.view(-1, shape[-1])
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)

    # This single-block kernel loads the whole row at once, so BLOCK_SIZE must
    # cover n_cols (see the tiled Welford version below for very long rows).
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    grid = (n_rows,)
    layernorm_kernel[grid](
        x, y, weight, bias,
        None, None,  # mean_ptr, rstd_ptr
        x.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y.view(shape)
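A quick check against PyTorch (a minimal sketch; shapes and tolerances are assumptions):

import torch.nn.functional as F

x = torch.randn(64, 768, device='cuda', dtype=torch.float32)
w = torch.randn(768, device='cuda')
b = torch.randn(768, device='cuda')
y_triton = layernorm(x, w, b)
y_torch = F.layer_norm(x, (768,), weight=w, bias=b, eps=1e-5)
assert torch.allclose(y_triton, y_torch, atol=1e-5, rtol=1e-4)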
# Numerically stable version using Welford's algorithm
@triton.jit
def layernorm_kernel_welford(
    x_ptr,
    y_ptr,
    w_ptr,
    b_ptr,
    stride,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Numerically stable LayerNorm based on Welford's (chunked) algorithm."""
    row = tl.program_id(0)
    x_ptr += row * stride
    y_ptr += row * stride

    # Welford online statistics
    mean = 0.0
    m2 = 0.0
    count = 0.0

    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        block_count = tl.sum(mask.to(tl.float32))

        # Per-block mean and sum of squared deviations
        block_mean = tl.sum(x, axis=0) / block_count
        # Zero out the padded lanes so they do not inflate the variance
        diff = tl.where(mask, x - block_mean, 0.0)
        block_m2 = tl.sum(diff * diff, axis=0)

        # Combine with the running statistics
        new_count = count + block_count
        delta = block_mean - mean
        mean = mean + delta * block_count / new_count
        m2 = m2 + block_m2 + delta * delta * count * block_count / new_count
        count = new_count

    var = m2 / count
    rstd = 1.0 / tl.sqrt(var + eps)

    # Second pass: normalize
    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        y = (x - mean) * rstd * w + b
        tl.store(y_ptr + cols, y, mask=mask)
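The Welford variant also lacks a launcher in the text. A minimal hedged wrapper (names and the default block size are illustrative):

def layernorm_welford(x, weight, bias, eps: float = 1e-5, BLOCK_SIZE: int = 1024):
    shape = x.shape
    x = x.view(-1, shape[-1])
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # One program per row; each program streams its row in BLOCK_SIZE chunks.
    layernorm_kernel_welford[(n_rows,)](
        x, y, weight, bias,
        x.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y.view(shape)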
Flash Attention
@triton.jit
def flash_attention_kernel(
    Q, K, V,
    Out,
    Lse,  # log-sum-exp, kept for the backward pass
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_ob, stride_oh, stride_om, stride_ok,
    stride_lb, stride_lh, stride_lm,
    B, H, M, N, K,
    scale,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Flash Attention forward pass
    Q: [B, H, M, K]
    K: [B, H, N, K]
    V: [B, H, N, K]
    Out: [B, H, M, K]
    """
    # Batch and head indices
    batch = tl.program_id(2)
    head = tl.program_id(1)
    q_block = tl.program_id(0)

    # Base offsets
    q_offset = batch * stride_qb + head * stride_qh
    k_offset = batch * stride_kb + head * stride_kh
    v_offset = batch * stride_vb + head * stride_vh
    o_offset = batch * stride_ob + head * stride_oh

    # Range of this Q block
    q_start = q_block * BLOCK_M
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)

    # Load the Q block
    q_ptrs = Q + q_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # Running statistics
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    o_i = tl.zeros([BLOCK_M, BLOCK_K], dtype=tl.float32)

    # Iterate over the K/V blocks
    for kv_block in range(0, tl.cdiv(N, BLOCK_N)):
        kv_start = kv_block * BLOCK_N
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        # Load a K block
        k_ptrs = K + k_offset + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
        k_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K)
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)

        # S = Q K^T * scale
        s = tl.dot(q, tl.trans(k)) * scale

        # Causal mask (optional)
        # if CAUSAL:
        #     s = tl.where(offs_m[:, None] >= offs_n[None, :], s, float('-inf'))

        # Online softmax
        m_ij = tl.max(s, axis=1)
        m_new = tl.maximum(m_i, m_ij)

        # Rescaling factor
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new[:, None])

        # Update l
        l_new = alpha * l_i + tl.sum(p, axis=1)

        # Load a V block
        v_ptrs = V + v_offset + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
        v_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K)
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # Update O (cast the probabilities to V's dtype so tl.dot sees matching operands)
        o_i = o_i * alpha[:, None] + tl.dot(p.to(v.dtype), v)

        # Update the running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization
    o_i = o_i / l_i[:, None]

    # Store the output
    o_ptrs = Out + o_offset + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    o_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    tl.store(o_ptrs, o_i, mask=o_mask)

    # Store the LSE (for the backward pass)
    lse_ptrs = Lse + batch * stride_lb + head * stride_lh + offs_m * stride_lm
    lse_mask = offs_m < M
    tl.store(lse_ptrs, m_i + tl.log(l_i), mask=lse_mask)

def flash_attention(q, k, v, scale=None):
    B, H, M, K = q.shape
    _, _, N, _ = k.shape
    if scale is None:
        scale = 1.0 / (K ** 0.5)

    out = torch.empty_like(q)
    lse = torch.empty((B, H, M), device=q.device, dtype=torch.float32)

    # BLOCK_K spans the whole head dimension, so K is assumed to be a power of 2
    # (typically 64 or 128).
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, K
    grid = (triton.cdiv(M, BLOCK_M), H, B)

    flash_attention_kernel[grid](
        q, k, v, out, lse,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        lse.stride(0), lse.stride(1), lse.stride(2),
        B, H, M, N, K,
        scale,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return out
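A quick check against a straightforward PyTorch reference that materializes the full attention matrix (a minimal sketch; the shapes, fp16 dtype, and tolerances are assumptions):

B, H, M, N, K = 2, 4, 128, 128, 64
q = torch.randn(B, H, M, K, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, N, K, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, N, K, device='cuda', dtype=torch.float16)

out_triton = flash_attention(q, k, v)

# Reference: softmax(Q K^T * scale) V, computed in fp32
scale = 1.0 / (K ** 0.5)
attn = torch.softmax((q.float() @ k.float().transpose(-2, -1)) * scale, dim=-1)
out_ref = (attn @ v.float()).to(torch.float16)

assert torch.allclose(out_triton, out_ref, atol=1e-2, rtol=1e-2)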
Performance Tuning
Tuning Strategies
# Use @triton.autotune for automatic tuning
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],  # the best configuration is selected per value of these parameters
)
@triton.jit
def kernel_with_autotune(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask)
    y = x * 2  # trivial computation
    tl.store(output_ptr + offsets, y, mask=mask)
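A hedged usage sketch: the autotuner supplies BLOCK_SIZE itself, so the launch only passes the runtime arguments and the grid reads the chosen block size from META (the wrapper name is illustrative):

def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    # BLOCK_SIZE comes from the selected triton.Config, not from the call site.
    grid = lambda META: (triton.cdiv(n, META['BLOCK_SIZE']),)
    kernel_with_autotune[grid](x, out, n)
    return out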
# Manual benchmarking
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[2**i for i in range(12, 28, 1)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'PyTorch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='Vector Add Performance',
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: x + y)
    elif provider == 'triton':
        ms = triton.testing.do_bench(lambda: add(x, y))
    # 2 reads + 1 write per element
    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# Run the benchmark
benchmark.run(show_plots=True, print_data=True)
Summary
Triton vs CUDA
| Aspect | Triton | CUDA |
|---|---|---|
| Learning curve | Gentle | Steep |
| Abstraction level | Block-level | Thread-level |
| Memory management | Automatic | Manual |
| Performance | Close to optimal | Optimal (hand-tuned) |
| Development efficiency | High | Low |
| Flexibility | Moderate | Highest |
Best Practices
□ Choose an appropriate BLOCK_SIZE (typically 64-1024)
□ Use @triton.autotune for automatic tuning
□ Use masks consistently to handle boundaries
□ Use tl.dot for matrix multiplication
□ Declare compile-time constants with constexpr
□ Profile to verify the performance