HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

05-Triton 编程入门

概述

Triton 是 OpenAI 开发的一种用于编写高效 GPU 程序的语言和编译器。相比 CUDA,Triton 提供了更高级的抽象,让开发者可以专注于算法设计,而将内存管理、并行化等底层细节交给编译器处理。

Triton 架构

设计理念

┌─────────────────────────────────────────────────────────────────────────┐
│                      Triton 设计理念                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  CUDA 编程挑战:                                                         │
│  ├── 需要手动管理共享内存                                                │
│  ├── 需要考虑 bank conflict                                             │
│  ├── 需要显式处理内存合并                                                │
│  ├── 需要管理同步                                                       │
│  └── 需要调优 block/grid 大小                                           │
│                                                                         │
│  Triton 解决方案:                                                       │
│  ├── Block-level 编程模型                                               │
│  │   └── 每个 kernel 实例处理一个 tile                                  │
│  ├── 自动内存管理                                                       │
│  │   └── 编译器决定何时使用共享内存/寄存器                               │
│  ├── 自动并行化                                                         │
│  │   └── 编译器处理 warp 级别优化                                       │
│  └── 自动调优                                                           │
│      └── JIT 编译时可搜索最优配置                                        │
│                                                                         │
│  抽象层次对比:                                                          │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                                                                  │   │
│  │  高级 ──► NumPy/PyTorch ───► 易用,但可能低效                    │   │
│  │    │                                                             │   │
│  │    │     Triton ────────────► 平衡:较易用,高性能               │   │
│  │    │                                                             │   │
│  │  低级 ──► CUDA C++ ─────────► 最大控制,但复杂                   │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

编译流程

┌─────────────────────────────────────────────────────────────────────────┐
│                      Triton 编译流程                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Python + @triton.jit                                                   │
│       │                                                                 │
│       ▼                                                                 │
│  ┌─────────────────────┐                                               │
│  │   Triton-IR (MLIR)  │  高级中间表示                                  │
│  └──────────┬──────────┘                                               │
│             │                                                           │
│             ▼                                                           │
│  ┌─────────────────────┐                                               │
│  │   优化 Pass          │                                               │
│  │   ├── Tile 布局优化  │                                               │
│  │   ├── 内存层次优化   │                                               │
│  │   └── 并行度优化     │                                               │
│  └──────────┬──────────┘                                               │
│             │                                                           │
│             ▼                                                           │
│  ┌─────────────────────┐                                               │
│  │   LLVM IR            │                                               │
│  └──────────┬──────────┘                                               │
│             │                                                           │
│             ▼                                                           │
│  ┌─────────────────────┐                                               │
│  │   PTX / CUBIN        │  可执行的 GPU 代码                            │
│  └─────────────────────┘                                               │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

基础语法

第一个 Triton Kernel

import triton
import triton.language as tl
import torch

# 向量加法
@triton.jit
def add_kernel(
    x_ptr,      # 输入张量 X 的指针
    y_ptr,      # 输入张量 Y 的指针
    output_ptr, # 输出张量的指针
    n_elements, # 元素数量
    BLOCK_SIZE: tl.constexpr,  # 编译时常量
):
    # 获取当前程序实例的 ID (类似 blockIdx)
    pid = tl.program_id(axis=0)

    # 计算该实例负责的元素范围
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 创建 mask 处理边界
    mask = offsets < n_elements

    # 加载数据
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 计算
    output = x + y

    # 存储结果
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    # 计算 grid 大小
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # 启动 kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output


# 测试
x = torch.rand(1000000, device='cuda')
y = torch.rand(1000000, device='cuda')
output = add(x, y)

# 验证
assert torch.allclose(output, x + y)

核心概念

import triton
import triton.language as tl

@triton.jit
def concepts_demo(
    ptr,
    n,
    BLOCK_SIZE: tl.constexpr,
):
    """演示 Triton 核心概念"""

    # 1. Program ID - 获取当前实例的索引
    # axis=0 是默认的,可以有多个轴
    pid = tl.program_id(axis=0)  # 类似 CUDA 的 blockIdx.x

    # 2. arange - 创建连续整数序列
    # 这是 Triton 的核心原语
    offsets = tl.arange(0, BLOCK_SIZE)  # [0, 1, 2, ..., BLOCK_SIZE-1]

    # 3. 指针算术
    # ptr + offsets 创建一组指针
    ptrs = ptr + pid * BLOCK_SIZE + offsets

    # 4. Mask - 处理边界条件
    mask = (pid * BLOCK_SIZE + offsets) < n

    # 5. Load - 从全局内存加载
    # mask 确保不会访问越界
    # other 指定 mask 为 False 时的默认值
    data = tl.load(ptrs, mask=mask, other=0.0)

    # 6. 计算 - 支持大多数算术操作
    result = data * 2.0 + 1.0
    result = tl.exp(result)
    result = tl.maximum(result, 0.0)

    # 7. Store - 写回全局内存
    tl.store(ptrs, result, mask=mask)

    # 8. 原子操作
    # tl.atomic_add(ptr, val)
    # tl.atomic_max(ptr, val)
    # tl.atomic_min(ptr, val)

    # 9. Reduce 操作
    sum_val = tl.sum(data, axis=0)
    max_val = tl.max(data, axis=0)

    # 10. constexpr - 编译时常量
    # BLOCK_SIZE 在编译时确定,可用于数组大小等


# 2D Grid 示例
@triton.jit
def kernel_2d(
    ptr,
    M, N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # 2D program ID
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # 计算偏移
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # 2D 索引
    offs = offs_m[:, None] * N + offs_n[None, :]

    # 2D mask
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    data = tl.load(ptr + offs, mask=mask)
    # ...

矩阵乘法实现

基础 GEMM

@triton.jit
def matmul_kernel(
    # 指针
    a_ptr, b_ptr, c_ptr,
    # 矩阵维度
    M, N, K,
    # Strides
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Block 大小
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    C = A @ B
    A: [M, K], B: [K, N], C: [M, N]
    """
    # 获取 program ID
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # 计算该 block 负责的输出位置
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # 初始化累加器
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # A, B 的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # K 维度迭代
    for k in range(0, K, BLOCK_K):
        # 加载 A tile
        a_mask = (offs_m[:, None] < M) & ((k + offs_k[None, :]) < K)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # 加载 B tile
        b_mask = ((k + offs_k[:, None]) < K) & (offs_n[None, :] < N)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        # 矩阵乘累加
        acc += tl.dot(a, b)

        # 移动指针到下一个 K block
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # 写回结果
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    tl.store(c_ptrs, acc, mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape

    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # 配置
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )

    return c

优化的 GEMM (使用自动调优)

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],  # 根据这些参数选择配置
)
@triton.jit
def matmul_kernel_autotune(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # L2 cache 优化
):
    """
    使用分组策略优化 L2 cache 命中率
    """
    pid = tl.program_id(axis=0)

    # 计算 grid 大小
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n

    # 分组重排
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # 计算偏移
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # A, B 指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # 累加器
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # 主循环
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_K)
        b_mask = (offs_k[:, None] < K - k * BLOCK_K) & (offs_n[None, :] < N)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        acc += tl.dot(a, b)

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # 存储
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def matmul_autotune(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape

    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
    )

    matmul_kernel_autotune[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )

    return c

Softmax 实现

Fused Softmax

@triton.jit
def softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    按行计算 softmax
    每个 program 处理一行
    """
    # 获取当前行
    row_idx = tl.program_id(0)

    # 计算行的起始位置
    row_start = row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # 加载一行数据
    input_ptrs = input_ptr + row_start + col_offsets
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

    # 计算 max (数值稳定性)
    row_max = tl.max(row, axis=0)

    # 计算 exp(x - max)
    numerator = tl.exp(row - row_max)

    # 计算 sum
    denominator = tl.sum(numerator, axis=0)

    # 归一化
    softmax_output = numerator / denominator

    # 存储结果
    output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)


def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # 选择 block size (必须是 2 的幂)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # 限制 block size
    BLOCK_SIZE = min(BLOCK_SIZE, 4096)

    # 每个 program 处理一行
    grid = (n_rows,)

    softmax_kernel[grid](
        x, output,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output


# 支持大行的分块 softmax
@triton.jit
def softmax_kernel_large(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    支持任意大小行的 softmax
    使用 online softmax 算法
    """
    row_idx = tl.program_id(0)
    row_start = row_idx * input_row_stride

    # Online softmax: 两次遍历
    # 第一次:找 max 和计算 exp 的和
    m_i = -float('inf')  # 当前最大值
    l_i = 0.0            # 当前 exp 和

    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols

        # 加载块
        input_ptrs = input_ptr + row_start + col_offsets
        x = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # 更新 max
        m_ij = tl.max(x, axis=0)
        m_new = tl.maximum(m_i, m_ij)

        # 更新 sum: l_new = l_old * exp(m_old - m_new) + sum(exp(x - m_new))
        l_i = l_i * tl.exp(m_i - m_new) + tl.sum(tl.exp(x - m_new), axis=0)
        m_i = m_new

    # 第二次:计算 softmax 并写回
    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols

        input_ptrs = input_ptr + row_start + col_offsets
        x = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # softmax = exp(x - m) / l
        softmax_out = tl.exp(x - m_i) / l_i

        output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
        tl.store(output_ptrs, softmax_out, mask=mask)

LayerNorm 实现

@triton.jit
def layernorm_kernel(
    x_ptr,
    y_ptr,
    w_ptr,  # gamma
    b_ptr,  # beta
    mean_ptr,
    rstd_ptr,
    stride,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    y = (x - mean) / sqrt(var + eps) * gamma + beta
    """
    row = tl.program_id(0)
    x_ptr += row * stride
    y_ptr += row * stride

    # 加载数据
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # 计算均值
    mean = tl.sum(x, axis=0) / N

    # 计算方差
    xmean = x - mean
    var = tl.sum(xmean * xmean, axis=0) / N

    # 标准化
    rstd = 1.0 / tl.sqrt(var + eps)
    x_hat = xmean * rstd

    # 加载 gamma, beta
    w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # 缩放和偏移
    y = x_hat * w + b

    # 存储结果
    tl.store(y_ptr + cols, y, mask=mask)

    # 可选:存储 mean 和 rstd (用于反向传播)
    if mean_ptr is not None:
        tl.store(mean_ptr + row, mean)
    if rstd_ptr is not None:
        tl.store(rstd_ptr + row, rstd)


def layernorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5
) -> torch.Tensor:
    shape = x.shape
    x = x.view(-1, shape[-1])
    n_rows, n_cols = x.shape

    y = torch.empty_like(x)

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    BLOCK_SIZE = min(BLOCK_SIZE, 4096)

    grid = (n_rows,)

    layernorm_kernel[grid](
        x, y, weight, bias,
        None, None,  # mean_ptr, rstd_ptr
        x.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return y.view(shape)


# 带 Welford 算法的数值稳定版本
@triton.jit
def layernorm_kernel_welford(
    x_ptr,
    y_ptr,
    w_ptr,
    b_ptr,
    stride,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """使用 Welford 算法的数值稳定 LayerNorm"""
    row = tl.program_id(0)
    x_ptr += row * stride
    y_ptr += row * stride

    # Welford online algorithm
    mean = 0.0
    m2 = 0.0
    count = 0.0

    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N

        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        block_count = tl.sum(mask.to(tl.float32))

        # Block mean and m2
        block_mean = tl.sum(x, axis=0) / block_count
        block_m2 = tl.sum((x - block_mean) * (x - block_mean), axis=0)

        # Combine with running stats
        new_count = count + block_count
        delta = block_mean - mean
        mean = mean + delta * block_count / new_count
        m2 = m2 + block_m2 + delta * delta * count * block_count / new_count
        count = new_count

    var = m2 / count
    rstd = 1.0 / tl.sqrt(var + eps)

    # 第二次遍历:归一化
    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N

        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        y = (x - mean) * rstd * w + b

        tl.store(y_ptr + cols, y, mask=mask)

Flash Attention

@triton.jit
def flash_attention_kernel(
    Q, K, V,
    Out,
    Lse,  # log-sum-exp for backward
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_ob, stride_oh, stride_om, stride_ok,
    stride_lb, stride_lh, stride_lm,
    B, H, M, N, K,
    scale,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Flash Attention forward pass
    Q: [B, H, M, K]
    K: [B, H, N, K]
    V: [B, H, N, K]
    Out: [B, H, M, K]
    """
    # 获取 batch 和 head 索引
    batch = tl.program_id(2)
    head = tl.program_id(1)
    q_block = tl.program_id(0)

    # 计算偏移
    q_offset = batch * stride_qb + head * stride_qh
    k_offset = batch * stride_kb + head * stride_kh
    v_offset = batch * stride_vb + head * stride_vh
    o_offset = batch * stride_ob + head * stride_oh

    # Q block 范围
    q_start = q_block * BLOCK_M
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)

    # 加载 Q block
    q_ptrs = Q + q_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # 初始化
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    o_i = tl.zeros([BLOCK_M, BLOCK_K], dtype=tl.float32)

    # 遍历 K/V blocks
    for kv_block in range(0, tl.cdiv(N, BLOCK_N)):
        kv_start = kv_block * BLOCK_N
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        # 加载 K block
        k_ptrs = K + k_offset + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
        k_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K)
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)

        # 计算 QK^T
        s = tl.dot(q, tl.trans(k)) * scale

        # Causal mask (可选)
        # if CAUSAL:
        #     s = tl.where(offs_m[:, None] >= offs_n[None, :], s, float('-inf'))

        # Online softmax
        m_ij = tl.max(s, axis=1)
        m_new = tl.maximum(m_i, m_ij)

        # 缩放因子
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new[:, None])

        # 更新 l
        l_new = alpha * l_i + tl.sum(p, axis=1)

        # 加载 V block
        v_ptrs = V + v_offset + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
        v_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K)
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # 更新 O
        o_i = o_i * alpha[:, None] + tl.dot(p.to(tl.float16), v)

        # 更新统计量
        m_i = m_new
        l_i = l_new

    # 最终归一化
    o_i = o_i / l_i[:, None]

    # 存储输出
    o_ptrs = Out + o_offset + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    o_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    tl.store(o_ptrs, o_i, mask=o_mask)

    # 存储 LSE (用于反向传播)
    lse_ptrs = Lse + batch * stride_lb + head * stride_lh + offs_m * stride_lm
    lse_mask = offs_m < M
    tl.store(lse_ptrs, m_i + tl.log(l_i), mask=lse_mask)


def flash_attention(q, k, v, scale=None):
    B, H, M, K = q.shape
    _, _, N, _ = k.shape

    if scale is None:
        scale = 1.0 / (K ** 0.5)

    out = torch.empty_like(q)
    lse = torch.empty((B, H, M), device=q.device, dtype=torch.float32)

    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, K

    grid = (triton.cdiv(M, BLOCK_M), H, B)

    flash_attention_kernel[grid](
        q, k, v, out, lse,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        lse.stride(0), lse.stride(1), lse.stride(2),
        B, H, M, N, K,
        scale,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )

    return out

性能调优

调优策略

# 使用 @triton.autotune 进行自动调优
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],  # 根据这些参数选择最优配置
)
@triton.jit
def kernel_with_autotune(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(input_ptr + offsets, mask=mask)
    y = x * 2  # 简单计算
    tl.store(output_ptr + offsets, y, mask=mask)


# 手动 Benchmark
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[2**i for i in range(12, 28, 1)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'PyTorch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='Vector Add Performance',
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)

    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: x + y)
    elif provider == 'triton':
        ms = triton.testing.do_bench(lambda: add(x, y))

    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# 运行 benchmark
benchmark.run(show_plots=True, print_data=True)

总结

Triton vs CUDA

特性TritonCUDA
学习曲线较平缓陡峭
抽象级别Block-levelThread-level
内存管理自动手动
性能接近最优最优 (手动调优)
开发效率高低
灵活性中等最高

最佳实践

□ 选择合适的 BLOCK_SIZE (通常 64-1024)
□ 使用 @triton.autotune 自动调优
□ 合理使用 mask 处理边界
□ 利用 tl.dot 进行矩阵乘法
□ 使用 constexpr 声明编译时常量
□ Profile 验证性能
Prev
04-算子融合与优化技术