HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

13-编译优化与图优化

章节概述

本章深入讲解深度学习编译器的核心技术,包括编译器架构、图优化、算子融合、代码生成和自动调度,帮助AI Infra工程师理解如何将高层模型表示转换为高效的底层代码。

知识体系

┌─────────────────────────────────────────────────────────────────────────────┐
│                      编译优化与图优化知识体系                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  ┌────────────────────────────────────────────────────────────────────┐     │
│  │                    编译器架构                                       │     │
│  │                                                                     │     │
│  │   前端           │   中间表示      │   优化Pass     │   后端        │     │
│  │   ├─ 图导入     │   ├─ High IR   │   ├─ 图级      │   ├─ CUDA    │     │
│  │   ├─ 类型推断   │   ├─ Low IR    │   ├─ 算子级    │   ├─ CPU     │     │
│  │   └─ 形状推断   │   └─ 调度IR    │   └─ 内存优化  │   └─ NPU     │     │
│  └────────────────────────────────────────────────────────────────────┘     │
│                                │                                             │
│  ┌────────────────────────────────────────────────────────────────────┐     │
│  │                    TorchDynamo + Inductor                           │     │
│  │                                                                     │     │
│  │   字节码追踪     │   FX Graph      │   代码生成     │   自动优化    │     │
│  │   ├─ 符号执行   │   ├─ Node      │   ├─ Triton   │   ├─ 融合     │     │
│  │   ├─ Guard机制  │   ├─ 变换      │   ├─ C++      │   ├─ 调度     │     │
│  │   └─ Graph Break│   └─ Pass      │   └─ OpenAI   │   └─ 内存     │     │
│  └────────────────────────────────────────────────────────────────────┘     │
│                                │                                             │
│  ┌────────────────────────────────────────────────────────────────────┐     │
│  │                    XLA编译器                                        │     │
│  │                                                                     │     │
│  │   HLO IR        │   优化Pass      │   GPU后端      │   JAX集成     │     │
│  │   ├─ 指令      │   ├─ 代数简化   │   ├─ IrEmitter │   ├─ JIT     │     │
│  │   ├─ 计算      │   ├─ 融合       │   ├─ Thunk    │   ├─ Jaxpr   │     │
│  │   └─ 模块      │   └─ 布局分配   │   └─ LLVM     │   └─ 控制流   │     │
│  └────────────────────────────────────────────────────────────────────┘     │
│                                │                                             │
│  ┌────────────────────────────────────────────────────────────────────┐     │
│  │                    算子融合与Kernel优化                             │     │
│  │                                                                     │     │
│  │   融合类型       │   融合决策      │   内存优化     │   计算优化    │     │
│  │   ├─ 元素级     │   ├─ 依赖分析  │   ├─ 合并访问  │   ├─ 向量化   │     │
│  │   ├─ 规约       │   ├─ 代价模型  │   ├─ 共享内存  │   ├─ Tensor Core│    │
│  │   └─ 注意力     │   └─ 循环分析  │   └─ 预取      │   └─ 指令级   │     │
│  └────────────────────────────────────────────────────────────────────┘     │
│                                │                                             │
│  ┌────────────────────────────────────────────────────────────────────┐     │
│  │                    自动调度与代码生成                               │     │
│  │                                                                     │     │
│  │   调度原语       │   搜索算法      │   代价模型     │   代码生成    │     │
│  │   ├─ split     │   ├─ 随机搜索  │   ├─ 实际测量  │   ├─ CUDA    │     │
│  │   ├─ reorder   │   ├─ 遗传算法  │   ├─ ML预测   │   ├─ LLVM    │     │
│  │   └─ bind      │   └─ 强化学习  │   └─ XGBoost  │   └─ Triton  │     │
│  └────────────────────────────────────────────────────────────────────┘     │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

章节内容

01-深度学习编译器概述

为什么需要编译器

  • 传统方法的局限性
  • 编译器的价值:自动优化、跨平台

编译器架构

  • 前端、中间表示、优化、后端
  • IR设计原则

主流编译器对比

  • TorchDynamo vs XLA vs TVM vs MLIR
  • 各自特点和适用场景

编译器优化技术概览

  • 图级优化
  • 算子级优化
  • 自动调优

02-TorchDynamo与torch.compile

TorchDynamo工作原理

  • Python字节码追踪
  • 符号追踪机制
  • Guard生成
  • Graph Break处理

FX Graph详解

  • Node结构
  • Graph变换
  • 优化Pass实现

TorchInductor代码生成

  • Lowering过程
  • 调度与融合
  • Triton代码生成

torch.compile使用

  • 基本用法
  • 编译选项
  • 调试技巧
  • 性能优化

03-XLA编译器深度解析

HLO中间表示

  • HloInstruction, HloComputation, HloModule
  • HLO文本表示
  • HLO构建API

HLO优化Pass

  • 代数简化 (Algebraic Simplifier)
  • 算子融合 (Fusion)
  • 布局分配 (Layout Assignment)
  • 内存分配 (Buffer Assignment)

GPU代码生成

  • IrEmitter
  • LLVM IR生成
  • Thunk执行

JAX与XLA集成

  • JAX编译流程
  • 控制流原语
  • 自定义XLA调用

04-算子融合与Kernel优化

算子融合类型

  • 元素级融合 (Element-wise Fusion)
  • 规约融合 (Reduction Fusion)
  • 矩阵乘法融合 (MatMul Fusion)
  • 注意力融合 (Flash Attention)

融合决策算法

  • 依赖分析
  • 代价模型
  • 循环融合分析

Kernel优化技术

  • 内存访问优化
  • 计算优化
  • 并行化优化

自动调优

  • 搜索空间定义
  • 调优实现

05-自动调度与代码生成

调度原语

  • 循环变换: split, reorder, tile, fuse
  • 并行化: parallel, vectorize, unroll
  • 内存: cache_read, cache_write, compute_at
  • 硬件映射: bind

搜索算法

  • 随机搜索
  • 模拟退火
  • 遗传算法
  • 基于代价模型的搜索
  • 强化学习

代码生成

  • CUDA代码生成
  • LLVM IR代码生成
  • Triton代码生成

核心代码示例

torch.compile使用

import torch

# 基本用法
@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

# 高级选项
@torch.compile(mode="max-autotune", fullgraph=True)
def optimized_fn(x, y):
    return torch.matmul(x, y)

# 查看编译信息
explanation = torch._dynamo.explain(fn)
print(explanation)

Triton Kernel融合

import triton
import triton.language as tl

@triton.jit
def fused_gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # 单次内存读取
    x = tl.load(x_ptr + offset, mask=mask)

    # 所有计算在寄存器中
    t = x * 0.7978845608028654 * (1.0 + 0.044715 * x * x)
    output = x * 0.5 * (1.0 + tl.math.tanh(t))

    # 单次内存写入
    tl.store(output_ptr + offset, output, mask=mask)

XLA HLO示例

HloModule softmax_example

ENTRY softmax {
  %input = f32[16,1024] parameter(0)

  // max
  %max = f32[16] reduce(%input, %neg_inf), dimensions={1}, to_apply=%max_fn

  // exp(x - max)
  %broadcast_max = f32[16,1024] broadcast(%max), dimensions={0}
  %sub = f32[16,1024] subtract(%input, %broadcast_max)
  %exp = f32[16,1024] exponential(%sub)

  // sum
  %sum = f32[16] reduce(%exp, %zero), dimensions={1}, to_apply=%add_fn

  // normalize
  %broadcast_sum = f32[16,1024] broadcast(%sum), dimensions={0}
  ROOT %softmax = f32[16,1024] divide(%exp, %broadcast_sum)
}

性能对比参考

编译器加速比

场景torch.compileXLATVM
ResNet-50推理1.5x1.4x1.6x
BERT训练1.3x1.4x-
GPT-2生成2.0x1.8x-
Custom ops2-5x2-3x3-10x

算子融合效果

融合模式内存减少加速比
GELU (7 ops → 1)6x2-3x
LayerNorm (5 ops → 1)4x2x
Softmax (5 ops → 1)4x2x
Flash AttentionN² → N2-4x

学习路径

Week 1-2: 编译器基础
├── 阅读 01-深度学习编译器概述
├── 理解编译器架构和IR设计
├── 对比不同编译器方案
└── 实践torch.compile基本用法

Week 3-4: TorchDynamo深入
├── 阅读 02-TorchDynamo与torch.compile
├── 理解字节码追踪机制
├── 学习FX Graph变换
├── 实践Inductor代码生成
└── 调试和优化技巧

Week 5-6: XLA编译器
├── 阅读 03-XLA编译器深度解析
├── 理解HLO IR设计
├── 分析优化Pass实现
├── 学习JAX编程
└── 性能分析与调优

Week 7-8: 算子融合
├── 阅读 04-算子融合与Kernel优化
├── 实现简单的融合kernel
├── 学习Triton编程
├── 分析Flash Attention实现
└── Kernel性能优化实践

Week 9-10: 自动调度
├── 阅读 05-自动调度与代码生成
├── 理解调度原语
├── 实现搜索算法
├── 学习代码生成技术
└── 端到端优化实践

推荐资源

官方文档

  • PyTorch Compiler
  • XLA Documentation
  • TVM Documentation
  • Triton Documentation
  • MLIR Documentation

开源项目

  • PyTorch - torch/_dynamo/, torch/_inductor/
  • TVM - 通用深度学习编译器
  • Triton - GPU编程语言
  • FlashAttention - 高效注意力实现

经典论文

  • "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning"
  • "XLA: TensorFlow, Compiled"
  • "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations"
  • "FlashAttention: Fast and Memory-Efficient Exact Attention"
  • "Ansor: Generating High-Performance Tensor Programs for Deep Learning"

下一步学习

完成本章后,建议回顾和综合应用:

  1. 10-CUDA编程与算子开发 - 理解底层GPU编程
  2. 11-通信与网络底层 - 理解分布式通信
  3. 12-框架源码解析 - 理解框架如何使用编译器
Next
01-深度学习编译器概述