13-编译优化与图优化
章节概述
本章深入讲解深度学习编译器的核心技术,包括编译器架构、图优化、算子融合、代码生成和自动调度,帮助AI Infra工程师理解如何将高层模型表示转换为高效的底层代码。
知识体系
┌─────────────────────────────────────────────────────────────────────────────┐
│ 编译优化与图优化知识体系 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ 编译器架构 │ │
│ │ │ │
│ │ 前端 │ 中间表示 │ 优化Pass │ 后端 │ │
│ │ ├─ 图导入 │ ├─ High IR │ ├─ 图级 │ ├─ CUDA │ │
│ │ ├─ 类型推断 │ ├─ Low IR │ ├─ 算子级 │ ├─ CPU │ │
│ │ └─ 形状推断 │ └─ 调度IR │ └─ 内存优化 │ └─ NPU │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ TorchDynamo + Inductor │ │
│ │ │ │
│ │ 字节码追踪 │ FX Graph │ 代码生成 │ 自动优化 │ │
│ │ ├─ 符号执行 │ ├─ Node │ ├─ Triton │ ├─ 融合 │ │
│ │ ├─ Guard机制 │ ├─ 变换 │ ├─ C++ │ ├─ 调度 │ │
│ │ └─ Graph Break│ └─ Pass │ └─ OpenAI │ └─ 内存 │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ XLA编译器 │ │
│ │ │ │
│ │ HLO IR │ 优化Pass │ GPU后端 │ JAX集成 │ │
│ │ ├─ 指令 │ ├─ 代数简化 │ ├─ IrEmitter │ ├─ JIT │ │
│ │ ├─ 计算 │ ├─ 融合 │ ├─ Thunk │ ├─ Jaxpr │ │
│ │ └─ 模块 │ └─ 布局分配 │ └─ LLVM │ └─ 控制流 │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ 算子融合与Kernel优化 │ │
│ │ │ │
│ │ 融合类型 │ 融合决策 │ 内存优化 │ 计算优化 │ │
│ │ ├─ 元素级 │ ├─ 依赖分析 │ ├─ 合并访问 │ ├─ 向量化 │ │
│ │ ├─ 规约 │ ├─ 代价模型 │ ├─ 共享内存 │ ├─ Tensor Core│ │
│ │ └─ 注意力 │ └─ 循环分析 │ └─ 预取 │ └─ 指令级 │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ 自动调度与代码生成 │ │
│ │ │ │
│ │ 调度原语 │ 搜索算法 │ 代价模型 │ 代码生成 │ │
│ │ ├─ split │ ├─ 随机搜索 │ ├─ 实际测量 │ ├─ CUDA │ │
│ │ ├─ reorder │ ├─ 遗传算法 │ ├─ ML预测 │ ├─ LLVM │ │
│ │ └─ bind │ └─ 强化学习 │ └─ XGBoost │ └─ Triton │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
章节内容
01-深度学习编译器概述
为什么需要编译器
- 传统方法的局限性
- 编译器的价值:自动优化、跨平台
编译器架构
- 前端、中间表示、优化、后端
- IR设计原则
主流编译器对比
- TorchDynamo vs XLA vs TVM vs MLIR
- 各自特点和适用场景
编译器优化技术概览
- 图级优化
- 算子级优化
- 自动调优
02-TorchDynamo与torch.compile
TorchDynamo工作原理
- Python字节码追踪
- 符号追踪机制
- Guard生成
- Graph Break处理
FX Graph详解
- Node结构
- Graph变换
- 优化Pass实现
TorchInductor代码生成
- Lowering过程
- 调度与融合
- Triton代码生成
torch.compile使用
- 基本用法
- 编译选项
- 调试技巧
- 性能优化
03-XLA编译器深度解析
HLO中间表示
- HloInstruction, HloComputation, HloModule
- HLO文本表示
- HLO构建API
HLO优化Pass
- 代数简化 (Algebraic Simplifier)
- 算子融合 (Fusion)
- 布局分配 (Layout Assignment)
- 内存分配 (Buffer Assignment)
GPU代码生成
- IrEmitter
- LLVM IR生成
- Thunk执行
JAX与XLA集成
- JAX编译流程
- 控制流原语
- 自定义XLA调用
04-算子融合与Kernel优化
算子融合类型
- 元素级融合 (Element-wise Fusion)
- 规约融合 (Reduction Fusion)
- 矩阵乘法融合 (MatMul Fusion)
- 注意力融合 (Flash Attention)
融合决策算法
- 依赖分析
- 代价模型
- 循环融合分析
Kernel优化技术
- 内存访问优化
- 计算优化
- 并行化优化
自动调优
- 搜索空间定义
- 调优实现
05-自动调度与代码生成
调度原语
- 循环变换: split, reorder, tile, fuse
- 并行化: parallel, vectorize, unroll
- 内存: cache_read, cache_write, compute_at
- 硬件映射: bind
搜索算法
- 随机搜索
- 模拟退火
- 遗传算法
- 基于代价模型的搜索
- 强化学习
代码生成
- CUDA代码生成
- LLVM IR代码生成
- Triton代码生成
核心代码示例
torch.compile使用
import torch
# 基本用法
@torch.compile
def fn(x, y):
return torch.sin(x) + torch.cos(y)
# 高级选项
@torch.compile(mode="max-autotune", fullgraph=True)
def optimized_fn(x, y):
return torch.matmul(x, y)
# 查看编译信息
explanation = torch._dynamo.explain(fn)
print(explanation)
Triton Kernel融合
import triton
import triton.language as tl
@triton.jit
def fused_gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements
# 单次内存读取
x = tl.load(x_ptr + offset, mask=mask)
# 所有计算在寄存器中
t = x * 0.7978845608028654 * (1.0 + 0.044715 * x * x)
output = x * 0.5 * (1.0 + tl.math.tanh(t))
# 单次内存写入
tl.store(output_ptr + offset, output, mask=mask)
XLA HLO示例
HloModule softmax_example
ENTRY softmax {
%input = f32[16,1024] parameter(0)
// max
%max = f32[16] reduce(%input, %neg_inf), dimensions={1}, to_apply=%max_fn
// exp(x - max)
%broadcast_max = f32[16,1024] broadcast(%max), dimensions={0}
%sub = f32[16,1024] subtract(%input, %broadcast_max)
%exp = f32[16,1024] exponential(%sub)
// sum
%sum = f32[16] reduce(%exp, %zero), dimensions={1}, to_apply=%add_fn
// normalize
%broadcast_sum = f32[16,1024] broadcast(%sum), dimensions={0}
ROOT %softmax = f32[16,1024] divide(%exp, %broadcast_sum)
}
性能对比参考
编译器加速比
| 场景 | torch.compile | XLA | TVM |
|---|---|---|---|
| ResNet-50推理 | 1.5x | 1.4x | 1.6x |
| BERT训练 | 1.3x | 1.4x | - |
| GPT-2生成 | 2.0x | 1.8x | - |
| Custom ops | 2-5x | 2-3x | 3-10x |
算子融合效果
| 融合模式 | 内存减少 | 加速比 |
|---|---|---|
| GELU (7 ops → 1) | 6x | 2-3x |
| LayerNorm (5 ops → 1) | 4x | 2x |
| Softmax (5 ops → 1) | 4x | 2x |
| Flash Attention | N² → N | 2-4x |
学习路径
Week 1-2: 编译器基础
├── 阅读 01-深度学习编译器概述
├── 理解编译器架构和IR设计
├── 对比不同编译器方案
└── 实践torch.compile基本用法
Week 3-4: TorchDynamo深入
├── 阅读 02-TorchDynamo与torch.compile
├── 理解字节码追踪机制
├── 学习FX Graph变换
├── 实践Inductor代码生成
└── 调试和优化技巧
Week 5-6: XLA编译器
├── 阅读 03-XLA编译器深度解析
├── 理解HLO IR设计
├── 分析优化Pass实现
├── 学习JAX编程
└── 性能分析与调优
Week 7-8: 算子融合
├── 阅读 04-算子融合与Kernel优化
├── 实现简单的融合kernel
├── 学习Triton编程
├── 分析Flash Attention实现
└── Kernel性能优化实践
Week 9-10: 自动调度
├── 阅读 05-自动调度与代码生成
├── 理解调度原语
├── 实现搜索算法
├── 学习代码生成技术
└── 端到端优化实践
推荐资源
官方文档
开源项目
- PyTorch - torch/_dynamo/, torch/_inductor/
- TVM - 通用深度学习编译器
- Triton - GPU编程语言
- FlashAttention - 高效注意力实现
经典论文
- "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning"
- "XLA: TensorFlow, Compiled"
- "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations"
- "FlashAttention: Fast and Memory-Efficient Exact Attention"
- "Ansor: Generating High-Performance Tensor Programs for Deep Learning"
下一步学习
完成本章后,建议回顾和综合应用:
- 10-CUDA编程与算子开发 - 理解底层GPU编程
- 11-通信与网络底层 - 理解分布式通信
- 12-框架源码解析 - 理解框架如何使用编译器