02 - TorchDynamo and torch.compile

Overview

torch.compile, introduced in PyTorch 2.0, is a new compilation API driven under the hood by TorchDynamo and TorchInductor. TorchDynamo captures computation graphs by tracing Python bytecode, and TorchInductor generates efficient GPU/CPU code from them. This chapter digs into how they work and how the source is organized.

Overall Architecture
┌────────────────────────────────────────────────────────────────────┐
│                     torch.compile architecture                     │
├────────────────────────────────────────────────────────────────────┤
│                                                                    │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Python code                                                  │  │
│  │   @torch.compile                                             │  │
│  │   def fn(x, y):                                              │  │
│  │       return torch.sin(x) + torch.cos(y)                     │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ TorchDynamo                                                  │  │
│  │   ├─ Python bytecode interception                            │  │
│  │   ├─ symbolic tracing                                        │  │
│  │   ├─ guard generation                                        │  │
│  │   └─ graph break handling                                    │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ FX Graph                                                     │  │
│  │   graph():                                                   │  │
│  │     %x : [num_users=1] = placeholder[target=x]               │  │
│  │     %y : [num_users=1] = placeholder[target=y]               │  │
│  │     %sin : [num_users=1] = call_function[target=torch.sin]   │  │
│  │     %cos : [num_users=1] = call_function[target=torch.cos]   │  │
│  │     %add : [num_users=1] = call_function[target=operator.add]│  │
│  │     return add                                               │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Backend (TorchInductor)                                      │  │
│  │   ├─ graph optimization (fusion, DCE, CSE)                   │  │
│  │   ├─ scheduling                                              │  │
│  │   └─ code generation (Triton / C++)                          │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Compiled code: Triton kernel / C++ kernel                    │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                    │
└────────────────────────────────────────────────────────────────────┘
1. How TorchDynamo Works

1.1 Python Bytecode Tracing

# The core of TorchDynamo: intercepting Python bytecode execution
import dis
import torch

def simple_fn(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# Inspect the Python bytecode
print("Bytecode:")
dis.dis(simple_fn)
"""
Output (CPython <= 3.10 shown; 3.11+ replaces CALL_FUNCTION / BINARY_ADD
with CALL / BINARY_OP):
  2           0 LOAD_GLOBAL              0 (torch)
              2 LOAD_ATTR                1 (sin)
              4 LOAD_FAST                0 (x)
              6 CALL_FUNCTION            1
              8 STORE_FAST               2 (a)
  3          10 LOAD_GLOBAL              0 (torch)
             12 LOAD_ATTR                2 (cos)
             14 LOAD_FAST                1 (y)
             16 CALL_FUNCTION            1
             18 STORE_FAST               3 (b)
  4          20 LOAD_FAST                2 (a)
             22 LOAD_FAST                3 (b)
             24 BINARY_ADD
             26 RETURN_VALUE
"""

# How TorchDynamo works:
# 1. Install a frame evaluation hook (PEP 523)
# 2. Intercept the bytecode of each frame before it runs
# 3. Symbolically trace the tensor operations
# 4. Emit an FX Graph
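
The easiest way to watch this pipeline from the outside is a custom backend: torch.compile hands the backend the FX graph that Dynamo captured. A minimal sketch (my_backend is an illustrative name; the backend contract itself, a callable taking a GraphModule and example inputs, is the documented API):

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls this with the FX graph it captured from the bytecode
    gm.graph.print_tabular()
    return gm.forward  # must return a callable

@torch.compile(backend=my_backend)
def traced_fn(x, y):
    return torch.sin(x) + torch.cos(y)

traced_fn(torch.randn(4), torch.randn(4))  # prints the captured graph, then runs it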
1.2 The Symbolic Tracing Mechanism

# torch/_dynamo/symbolic_convert.py (simplified)
class InstructionTranslator:
    """
    Bytecode instruction translator.
    Converts Python bytecode into FX graph operations.
    """
    def __init__(self, frame, compiler_fn):
        self.frame = frame
        self.graph = torch.fx.Graph()
        self.symbolic_locals = {}  # symbolic local variables
        self.compiler_fn = compiler_fn
        self.instructions = list(dis.get_instructions(frame.f_code))
        self.stack = []  # simulated Python value stack

    def push(self, value):
        self.stack.append(value)

    def pop(self):
        return self.stack.pop()

    def LOAD_FAST(self, inst):
        """
        Handle the LOAD_FAST instruction: load a local variable.
        """
        name = inst.argval
        if name in self.symbolic_locals:
            # the variable is already symbolic
            self.push(self.symbolic_locals[name])
        else:
            # fetch the value from the real frame
            value = self.frame.f_locals[name]
            if isinstance(value, torch.Tensor):
                # create a placeholder node
                proxy = self.graph.placeholder(name)
                self.symbolic_locals[name] = proxy
                self.push(proxy)
            else:
                self.push(value)

    def CALL_FUNCTION(self, inst):
        """
        Handle the CALL_FUNCTION instruction: a function call.
        """
        nargs = inst.argval
        args = [self.pop() for _ in range(nargs)][::-1]
        fn = self.pop()
        # check whether this is a supported PyTorch op
        if self.is_torch_op(fn):
            # create a call_function node
            result = self.graph.call_function(fn, tuple(args))
            self.push(result)
        else:
            # unsupported call: trigger a graph break
            raise GraphBreak(f"Unsupported function: {fn}")

    def BINARY_ADD(self, inst):
        """
        Handle the BINARY_ADD instruction: addition.
        """
        b = self.pop()
        a = self.pop()
        if isinstance(a, torch.fx.Proxy) or isinstance(b, torch.fx.Proxy):
            # create an add node
            import operator
            result = self.graph.call_function(operator.add, (a, b))
            self.push(result)
        else:
            # constant folding
            self.push(a + b)

    def run(self):
        """
        Run the bytecode translation loop.
        """
        for inst in self.instructions:
            handler = getattr(self, inst.opname, None)
            if handler is None:
                raise GraphBreak(f"Unsupported opcode: {inst.opname}")
            handler(inst)
        # return the FX Graph
        return self.graph

class SymbolicValue:
    """
    Symbolic value: represents a symbolic tensor during tracing.
    """
    def __init__(self, proxy, example_value):
        self.proxy = proxy                  # FX Proxy node
        self.example_value = example_value  # example value, used for shape inference

    @property
    def shape(self):
        return self.example_value.shape

    @property
    def dtype(self):
        return self.example_value.dtype
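
In the actual implementation, example_value is typically a FakeTensor: a tensor that carries shape/dtype/device metadata but no storage, so tracing never runs real kernels. A small sketch against the (internal) FakeTensorMode API:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as fake_mode:
    fake_x = fake_mode.from_tensor(torch.randn(8, 16))
    fake_y = torch.sin(fake_x)         # no real computation happens
    print(fake_y.shape, fake_y.dtype)  # metadata is still propagated: (8, 16), float32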
1.3 The Guard Mechanism

# torch/_dynamo/guards.py (simplified)
class GuardBuilder:
    """
    Guard builder.
    Generates runtime checks that keep the compiled code valid.
    """
    def __init__(self):
        self.guards = []

    def tensor_guard(self, name, tensor):
        """
        Tensor guard: check that tensor properties still match.
        """
        # shape / dtype / device / requires_grad checks
        self.guards.append(
            f"isinstance({name}, torch.Tensor) and "
            f"{name}.shape == {tuple(tensor.shape)} and "
            f"{name}.dtype == {tensor.dtype} and "
            f"str({name}.device) == {str(tensor.device)!r} and "
            f"{name}.requires_grad == {tensor.requires_grad}"
        )

    def type_guard(self, name, value):
        """
        Type guard.
        """
        self.guards.append(f"type({name}) is {type(value).__name__}")

    def value_guard(self, name, value):
        """
        Value guard (for constants).
        """
        self.guards.append(f"{name} == {repr(value)}")

    def build(self):
        """
        Build the guard function.
        """
        guard_code = " and ".join(self.guards)
        # look guard names up in the caller's locals dict at check time
        return lambda locals_: eval(guard_code, {"torch": torch}, locals_)

class GuardFail(Exception):
    pass

class GuardedCode:
    """
    Compiled code with guards attached.
    """
    def __init__(self, compiled_fn, guard_fn, graph):
        self.compiled_fn = compiled_fn
        self.guard_fn = guard_fn
        self.graph = graph

    def __call__(self, *args, **kwargs):
        # check the guards
        locals_dict = dict(zip(self.graph.input_names, args))
        if not self.guard_fn(locals_dict):
            # guards failed: a recompile is needed
            raise GuardFail("Guard check failed")
        return self.compiled_fn(*args, **kwargs)

# Guard workflow example
def example_guard_flow():
    """
    How guards behave across calls.
    """
    @torch.compile
    def fn(x, y):
        return x + y

    # first call
    x1 = torch.randn(10, 10)
    y1 = torch.randn(10, 10)
    result1 = fn(x1, y1)  # compiles and installs guards

    # second call (same shapes)
    x2 = torch.randn(10, 10)
    y2 = torch.randn(10, 10)
    result2 = fn(x2, y2)  # guards pass: reuse the compiled code

    # third call (different shapes)
    x3 = torch.randn(20, 20)
    y3 = torch.randn(20, 20)
    result3 = fn(x3, y3)  # guards fail: recompile
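
The guards PyTorch actually installs, and why it recompiles, can be printed through the documented logging facility (the TORCH_LOGS environment variable, or its Python equivalent torch._logging.set_logs); a sketch:

import torch

torch._logging.set_logs(guards=True, recompiles=True)
# equivalent: TORCH_LOGS="guards,recompiles" python script.py

@torch.compile
def fn(x, y):
    return x + y

fn(torch.randn(10, 10), torch.randn(10, 10))  # logs the installed guards
fn(torch.randn(20, 20), torch.randn(20, 20))  # logs the recompile and its reason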
1.4 Handling Graph Breaks

# torch/_dynamo/exc.py
class GraphBreak(Exception):
    """
    Graph break exception.
    Raised when an operation cannot be traced.
    """
    pass

# Common causes of graph breaks
GRAPH_BREAK_REASONS = {
    "data_dependent_control_flow": "data-dependent control flow",
    "unsupported_op": "unsupported operation",
    "dynamic_shape": "dynamic shape operation",
    "inline_call": "failed to inline a call",
    "graph_output": "graph output problem",
}

# torch/_dynamo/symbolic_convert.py
class GraphBreakHandler:
    """
    Graph break handler.
    """
    def __init__(self, translator):
        self.translator = translator
        self.sub_graphs = []

    def handle_break(self, break_reason):
        """
        Handle a graph break: split the current graph and keep tracing.
        """
        # 1. finish the current subgraph
        current_graph = self.translator.finish_subgraph()
        self.sub_graphs.append(current_graph)
        # 2. run the code that caused the break (untraced)
        self.execute_break_code()
        # 3. start a new subgraph
        self.translator.start_new_subgraph()

    def execute_break_code(self):
        """
        Execute the Python code that caused the break.
        """
        # fall back to the regular Python interpreter
        pass

# Graph break example
def graph_break_example():
    @torch.compile
    def fn(x, cond):
        a = torch.sin(x)   # traceable
        if cond.item():    # graph break! data-dependent control flow
            b = torch.cos(a)
        else:
            b = torch.tan(a)
        c = b + 1          # a new subgraph starts here
        return c

    # this produces multiple subgraphs
    x = torch.randn(10)
    cond = torch.tensor(True)
    result = fn(x, cond)

# Tricks for avoiding graph breaks
def avoid_graph_break():
    # 1. use torch.where instead of if/else
    @torch.compile
    def good_fn(x, cond):
        a = torch.sin(x)
        b = torch.where(cond, torch.cos(a), torch.tan(a))
        return b + 1

    # 2. use torch.cond (PyTorch 2.1+)
    @torch.compile
    def better_fn(x, cond):
        a = torch.sin(x)
        def true_fn(a):
            return torch.cos(a)
        def false_fn(a):
            return torch.tan(a)
        b = torch.cond(cond, true_fn, false_fn, (a,))
        return b + 1
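
A quick way to verify such a rewrite: compile both versions with fullgraph=True. The data-dependent branch raises, while the torch.where version compiles as one graph (a sketch; the exact error type varies by version):

x, cond = torch.randn(10), torch.tensor(True)

def branchy(x, cond):
    a = torch.sin(x)
    return (torch.cos(a) if cond.item() else torch.tan(a)) + 1

def branchless(x, cond):
    a = torch.sin(x)
    return torch.where(cond, torch.cos(a), torch.tan(a)) + 1

torch.compile(branchless, fullgraph=True)(x, cond)  # OK: a single graph
try:
    torch.compile(branchy, fullgraph=True)(x, cond)  # graph break on .item() -> error
except Exception as e:
    print(type(e).__name__)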
2. FX Graph in Depth

2.1 FX Graph Structure

# torch/fx/graph.py (simplified)
class Node:
    """
    A node in an FX graph.
    """
    def __init__(self, graph, name, op, target, args, kwargs):
        self.graph = graph
        self.name = name
        self.op = op          # op kind: placeholder, call_function, call_method, etc.
        self.target = target  # target function or attribute
        self.args = args      # positional arguments
        self.kwargs = kwargs  # keyword arguments
        self.users = {}       # nodes that consume this node

    def __repr__(self):
        return f"{self.name} = {self.op}[target={self.target}]({self.args})"

class Graph:
    """
    An FX computation graph.
    """
    def __init__(self):
        self.nodes = []
        self._node_name_counter = 0

    def placeholder(self, name):
        """Create an input placeholder."""
        node = Node(self, name, 'placeholder', name, (), {})
        self.nodes.append(node)
        return Proxy(node)

    def call_function(self, target, args, kwargs=None):
        """Create a function-call node."""
        name = self._generate_name(target.__name__)
        node = Node(self, name, 'call_function', target, args, kwargs or {})
        self.nodes.append(node)
        return Proxy(node)

    def call_method(self, target, args, kwargs=None):
        """Create a method-call node."""
        name = self._generate_name(target)
        node = Node(self, name, 'call_method', target, args, kwargs or {})
        self.nodes.append(node)
        return Proxy(node)

    def output(self, result):
        """Set the graph output."""
        node = Node(self, 'output', 'output', 'output', (result,), {})
        self.nodes.append(node)

    def _generate_name(self, base):
        self._node_name_counter += 1
        return f"{base}_{self._node_name_counter}"

    def print_tabular(self):
        """Print the graph structure."""
        print("opcode       target           args         kwargs")
        print("-" * 60)
        for node in self.nodes:
            print(f"{node.op:12} {str(node.target):16} {str(node.args):12} {str(node.kwargs)}")

class Proxy:
    """
    Proxy object: a fluent API for building the graph.
    """
    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        import operator
        return self.node.graph.call_function(operator.add, (self, other))

    def __mul__(self, other):
        import operator
        return self.node.graph.call_function(operator.mul, (self, other))

# FX Graph usage example
def fx_graph_example():
    import torch
    import torch.fx as fx

    class MyModule(torch.nn.Module):
        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    # symbolic tracing
    traced = fx.symbolic_trace(MyModule())
    print(traced.graph)
    # Output:
    # graph():
    #     %x : [num_users=1] = placeholder[target=x]
    #     %y : [num_users=1] = placeholder[target=y]
    #     %sin : [num_users=1] = call_function[target=torch.sin](args = (%x,))
    #     %cos : [num_users=1] = call_function[target=torch.cos](args = (%y,))
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%sin, %cos))
    #     return add
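
The real torch.fx.Graph can also be driven by hand, in the same style as the simplified version above; this is genuine FX API:

import operator
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.placeholder("y")
s = g.call_function(torch.sin, (x,))
c = g.call_function(torch.cos, (y,))
g.output(g.call_function(operator.add, (s, c)))

gm = fx.GraphModule(torch.nn.Module(), g)  # wrap the graph into a runnable module
print(gm.code)                             # the Python source FX generates
print(gm(torch.randn(4), torch.randn(4)))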
2.2 FX Graph Transformations

# torch/fx/passes (simplified)
class GraphTransform:
    """
    Base class for graph transforms.
    """
    def __call__(self, graph):
        return self.transform(graph)

    def transform(self, graph):
        raise NotImplementedError

class PatternMatcher:
    """
    Pattern matcher: recognizes graph patterns that can be optimized.
    """
    def __init__(self, pattern):
        self.pattern = pattern

    def match(self, graph, node):
        """
        Try to match the pattern starting at the given node.
        """
        # simplified pattern-matching logic
        pass

class FusionPass(GraphTransform):
    """
    Operator fusion pass.
    """
    def transform(self, graph):
        # walk the nodes looking for fusion opportunities
        for node in graph.nodes:
            # example: fuse sin + cos
            if self._is_fusable_pattern(node):
                self._fuse_nodes(graph, node)
        return graph

    def _is_fusable_pattern(self, node):
        """Check whether the node is fusable."""
        # is it an element-wise op?
        # do the input/output shapes match?
        pass

    def _fuse_nodes(self, graph, node):
        """Perform the fusion."""
        pass

class DeadCodeElimination(GraphTransform):
    """
    Dead code elimination pass.
    """
    def transform(self, graph):
        # mark live nodes by walking backwards from the output node
        live_nodes = set()
        self._mark_live(graph.output_node, live_nodes)
        # drop the dead nodes
        graph.nodes = [n for n in graph.nodes if n in live_nodes]
        return graph

    def _mark_live(self, node, live_nodes):
        if node in live_nodes:
            return
        live_nodes.add(node)
        for arg in node.args:
            if isinstance(arg, Node):
                self._mark_live(arg, live_nodes)

class CommonSubexpressionElimination(GraphTransform):
    """
    Common subexpression elimination pass.
    """
    def transform(self, graph):
        # remember the expressions we have already seen
        seen = {}
        for node in graph.nodes:
            key = (node.op, node.target, tuple(node.args), tuple(node.kwargs.items()))
            if key in seen:
                # replace uses with the existing node
                self._replace_uses(graph, node, seen[key])
            else:
                seen[key] = node
        return graph

# Applying the passes
def apply_passes_example():
    import torch
    import torch.fx as fx

    class MyModule(torch.nn.Module):
        def forward(self, x):
            a = torch.sin(x)
            b = torch.sin(x)   # duplicate computation
            c = a + b
            d = torch.cos(x)   # dead code (never used)
            return c

    traced = fx.symbolic_trace(MyModule())

    # apply the optimization passes
    passes = [
        CommonSubexpressionElimination(),
        DeadCodeElimination(),
    ]
    graph = traced.graph
    for pass_fn in passes:
        graph = pass_fn(graph)
    print(graph)
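
For comparison, FX ships a dead-code pass of its own; on the module above, the built-in API gives the same effect (a sketch):

traced = fx.symbolic_trace(MyModule())
traced.graph.eliminate_dead_code()  # FX's built-in DCE removes the unused cos node
traced.recompile()                  # regenerate the module's Python code
print(traced.code)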
3. TorchInductor Code Generation

3.1 Inductor Architecture

┌────────────────────────────────────────────────────────────────────┐
│                     TorchInductor architecture                     │
├────────────────────────────────────────────────────────────────────┤
│                                                                    │
│                              FX Graph                              │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Decomposition                                                │  │
│  │   Break high-level ops into primitive ops                    │  │
│  │   e.g., LayerNorm → mean, sub, var, rsqrt, mul, add          │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Lowering                                                     │  │
│  │   Convert FX nodes into Inductor IR nodes                    │  │
│  │   Build the dependencies between operations                  │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Scheduling                                                   │  │
│  │   Decide which operations fuse into one kernel               │  │
│  │   Optimize memory access patterns                            │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ Code Generation                                              │  │
│  │   GPU: emit Triton code                                      │  │
│  │   CPU: emit C++ code                                         │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                 │                                  │
│                                 ▼                                  │
│                           Compiled Kernel                          │
│                                                                    │
└────────────────────────────────────────────────────────────────────┘
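
The decomposition tables really do exist in torch._decomp; a sketch of peeking at them (internal API, so the exact layout can shift between releases):

import torch
from torch._decomp import core_aten_decompositions

decomps = core_aten_decompositions()  # maps aten ops to their decomposed implementations
print(len(decomps))                   # several hundred entries
print(list(decomps)[:5])              # a few of the registered aten ops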
3.2 The Lowering Process

# torch/_inductor/lowering.py (simplified)
class IRNode:
    """
    Base class for Inductor IR nodes.
    """
    pass

class Pointwise(IRNode):
    """
    Pointwise (element-wise) operation.
    """
    def __init__(self, dtype, inner_fn, ranges):
        self.dtype = dtype
        self.inner_fn = inner_fn  # compute function
        self.ranges = ranges      # loop ranges

    def __repr__(self):
        return f"Pointwise(dtype={self.dtype}, ranges={self.ranges})"

class Reduction(IRNode):
    """
    Reduction operation.
    """
    def __init__(self, dtype, inner_fn, ranges, reduction_ranges, reduction_type):
        self.dtype = dtype
        self.inner_fn = inner_fn
        self.ranges = ranges
        self.reduction_ranges = reduction_ranges
        self.reduction_type = reduction_type  # sum, mean, max, etc.

class Buffer(IRNode):
    """
    Memory buffer.
    """
    def __init__(self, name, layout):
        self.name = name
        self.layout = layout

# Lowering functions
def lower_sin(x):
    """
    Lower torch.sin into a Pointwise op.
    """
    def inner_fn(index):
        return f"tl.sin({x.loader.make_load(index)})"
    return Pointwise(
        dtype=x.dtype,
        inner_fn=inner_fn,
        ranges=x.ranges,
    )

def lower_add(a, b):
    """
    Lower addition into a Pointwise op.
    """
    def inner_fn(index):
        return f"({a.loader.make_load(index)} + {b.loader.make_load(index)})"
    return Pointwise(
        dtype=a.dtype,
        inner_fn=inner_fn,
        ranges=a.ranges,
    )

def lower_sum(x, dims):
    """
    Lower a sum into a Reduction op.
    """
    def inner_fn(index):
        return x.loader.make_load(index)
    return Reduction(
        dtype=x.dtype,
        inner_fn=inner_fn,
        ranges=compute_output_ranges(x.ranges, dims),
        reduction_ranges=dims,
        reduction_type='sum',
    )

# Lowering registry
LOWERING_REGISTRY = {
    torch.sin: lower_sin,
    torch.cos: lambda x: lower_pointwise_unary(x, 'tl.cos'),
    torch.add: lower_add,
    torch.sum: lower_sum,
    # ... many more ops
}
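
Inductor's real registry lives in torch._inductor.lowering.lowerings, a dict keyed by ATen ops and populated by register_lowering decorators (internal API; a sketch):

import torch
from torch._inductor.lowering import lowerings

print(len(lowerings))                           # hundreds of registered lowerings
print(torch.ops.aten.sin.default in lowerings)  # True on recent releases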
3.3 Scheduling and Fusion

# torch/_inductor/scheduler.py (simplified)
class Scheduler:
    """
    Scheduler.
    Decides the execution order of IR nodes and the fusion strategy.
    """
    def __init__(self, nodes):
        self.nodes = nodes
        self.fused_groups = []

    def schedule(self):
        """
        Main scheduling entry point.
        """
        # 1. build the dependency graph
        dep_graph = self._build_dependency_graph()
        # 2. find fusion opportunities
        fusion_groups = self._find_fusion_groups(dep_graph)
        # 3. create the scheduled groups
        for group in fusion_groups:
            scheduled_group = ScheduledGroup(group)
            self.fused_groups.append(scheduled_group)
        return self.fused_groups

    def _build_dependency_graph(self):
        """
        Build the node dependency graph.
        """
        dep_graph = {}
        for node in self.nodes:
            deps = self._get_dependencies(node)
            dep_graph[node] = deps
        return dep_graph

    def _find_fusion_groups(self, dep_graph):
        """
        Find groups of nodes that can fuse.
        """
        groups = []
        visited = set()
        for node in self.nodes:
            if node in visited:
                continue
            # try to grow a fusion group
            group = self._grow_fusion_group(node, dep_graph, visited)
            if group:
                groups.append(group)
        return groups

    def _grow_fusion_group(self, start_node, dep_graph, visited):
        """
        Grow a fusion group outwards from the starting node.
        """
        group = [start_node]
        visited.add(start_node)
        # walk the consumers and check whether they can fuse in
        for user in self._get_users(start_node):
            if self._can_fuse(start_node, user):
                group.append(user)
                visited.add(user)
                # grow recursively
                self._grow_fusion_group(user, dep_graph, visited)
        return group

    def _can_fuse(self, producer, consumer):
        """
        Check whether two nodes can be fused.
        """
        # condition 1: both are Pointwise ops
        if not (isinstance(producer, Pointwise) and isinstance(consumer, Pointwise)):
            return False
        # condition 2: compatible shapes
        if producer.ranges != consumer.ranges:
            return False
        # condition 3: memory dependencies allow it
        # ...
        return True

class ScheduledGroup:
    """
    Scheduled group: a set of operations that will be fused into one kernel.
    """
    def __init__(self, nodes):
        self.nodes = nodes
        self.inputs = self._compute_inputs()
        self.outputs = self._compute_outputs()

    def _compute_inputs(self):
        """Compute the group's inputs."""
        inputs = set()
        internal = set(self.nodes)
        for node in self.nodes:
            for dep in node.dependencies:
                if dep not in internal:
                    inputs.add(dep)
        return inputs

    def _compute_outputs(self):
        """Compute the group's outputs."""
        outputs = []
        internal = set(self.nodes)
        for node in self.nodes:
            for user in node.users:
                if user not in internal:
                    outputs.append(node)
                    break
        return outputs
3.4 Triton Code Generation

# torch/_inductor/codegen/triton.py (simplified)
class TritonKernelGenerator:
    """
    Triton kernel code generator.
    """
    def __init__(self, scheduled_group):
        self.group = scheduled_group
        self.code = []

    def generate(self):
        """
        Generate the Triton kernel code.
        """
        # emit the kernel signature
        self._generate_signature()
        # emit the index computation
        self._generate_index()
        # emit the loads
        self._generate_loads()
        # emit the compute
        self._generate_compute()
        # emit the stores
        self._generate_stores()
        return "\n".join(self.code)

    def _generate_signature(self):
        """Emit the function signature."""
        params = []
        for inp in self.group.inputs:
            params.append(f"{inp.name}_ptr")
        for out in self.group.outputs:
            params.append(f"{out.name}_ptr")
        params.append("numel")
        self.code.append("@triton.jit")
        self.code.append(f"def kernel({', '.join(params)}, BLOCK_SIZE: tl.constexpr):")

    def _generate_index(self):
        """Emit the index computation."""
        self.code.append("    pid = tl.program_id(0)")
        self.code.append("    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        self.code.append("    mask = offset < numel")

    def _generate_loads(self):
        """Emit the data loads."""
        for inp in self.group.inputs:
            self.code.append(f"    {inp.name} = tl.load({inp.name}_ptr + offset, mask=mask)")

    def _generate_compute(self):
        """Emit the compute code."""
        for node in self.group.nodes:
            compute_code = self._lower_node_to_triton(node)
            self.code.append(f"    {compute_code}")

    def _generate_stores(self):
        """Emit the data stores."""
        for out in self.group.outputs:
            self.code.append(f"    tl.store({out.name}_ptr + offset, {out.name}, mask=mask)")

    def _lower_node_to_triton(self, node):
        """Lower an IR node to Triton code."""
        if isinstance(node, Pointwise):
            return node.inner_fn("offset")
        # ... more node kinds

# Example of the generated Triton code
def generated_triton_example():
    """
    Example: the Triton code generated for sin(x) + cos(y).
    """
    code = """
@triton.jit
def fused_kernel(
    x_ptr, y_ptr, output_ptr,
    numel,
    BLOCK_SIZE: tl.constexpr
):
    # index computation
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < numel
    # load the data
    x = tl.load(x_ptr + offset, mask=mask)
    y = tl.load(y_ptr + offset, mask=mask)
    # compute (the fused ops)
    tmp0 = tl.sin(x)
    tmp1 = tl.cos(y)
    tmp2 = tmp0 + tmp1
    # store the result
    tl.store(output_ptr + offset, tmp2, mask=mask)

def call_kernel(x, y):
    output = torch.empty_like(x)
    numel = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    fused_kernel[grid](
        x, y, output,
        numel,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output
"""
    return code
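
To see the kernel Inductor actually emits for this pattern, the documented TORCH_LOGS facility can dump the generated code (a sketch; the Triton path needs a CUDA device):

import torch

torch._logging.set_logs(output_code=True)  # same as TORCH_LOGS="output_code"

@torch.compile
def f(x, y):
    return torch.sin(x) + torch.cos(y)

f(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))
# the generated Triton kernel (very close to the sketch above) is logged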
4. Using torch.compile

4.1 Basic Usage

import torch

# Option 1: decorator
@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

# Option 2: plain function call
def fn2(x, y):
    return torch.sin(x) + torch.cos(y)
compiled_fn = torch.compile(fn2)

# Option 3: compile an nn.Module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(128, 128))

    def forward(self, x):
        return torch.relu(x @ self.weight)

model = MyModule()
compiled_model = torch.compile(model)

# Option 4: compile for training
model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters())
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()  # the backward pass is compiled too
    optimizer.step()
4.2 Compilation Options

import torch

# The mode option
@torch.compile(mode="default")          # default: balances compile time and performance
def fn1(x): return x + 1

@torch.compile(mode="reduce-overhead")  # cuts launch overhead; good for small batches
def fn2(x): return x + 1

@torch.compile(mode="max-autotune")     # maximum tuning at the cost of longer compile times
def fn3(x): return x + 1

# The backend option
@torch.compile(backend="inductor")      # default: TorchInductor
def fn4(x): return x + 1

@torch.compile(backend="eager")         # no compilation; useful for debugging
def fn5(x): return x + 1

@torch.compile(backend="aot_eager")     # runs AOTAutograd but executes eagerly
def fn6(x): return x + 1

# The fullgraph option
@torch.compile(fullgraph=True)          # require one full graph; any graph break raises
def fn7(x):
    return torch.sin(x)

# The dynamic option
@torch.compile(dynamic=True)            # support dynamic shapes
def fn8(x):
    return x.sum()

# Combining options
@torch.compile(
    mode="max-autotune",
    backend="inductor",
    fullgraph=True,
    dynamic=False,
)
def optimized_fn(x, y):
    return torch.matmul(x, y)
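
Fine-grained Inductor knobs can also be passed through options=; the keys mirror fields of torch._inductor.config (the values shown here are illustrative):

compiled_mm = torch.compile(
    torch.matmul,
    options={
        "max_autotune": True,     # the knob that mode="max-autotune" turns on
        "epilogue_fusion": True,  # fuse pointwise ops into matmul templates
    },
)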
4.3 Debugging Tips

import logging
import torch
import torch._dynamo as dynamo

# 1. Turn on compilation logs
torch._dynamo.config.verbose = True
torch._logging.set_logs(dynamo=logging.DEBUG)  # or: TORCH_LOGS="dynamo"

# 2. Inspect what Dynamo captured
@torch.compile(backend="aot_eager")
def fn(x):
    return torch.sin(x) + 1

# use explain to see graphs, graph breaks, and their reasons
explanation = torch._dynamo.explain(fn)(torch.randn(10))  # PyTorch 2.1+ calling convention
print(explanation)

# 3. Inspect the generated code
torch._inductor.config.debug = True

@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(1000, device='cuda')
y = torch.randn(1000, device='cuda')
result = fn(x, y)  # compile and run
# the generated code lands under /tmp/torchinductor_xxx/

# 4. Profiling
with torch.profiler.profile() as prof:
    for _ in range(10):
        result = fn(x, y)
print(prof.key_averages().table())

# 5. Suppress compile errors (debugging only)
torch._dynamo.config.suppress_errors = True  # fall back to eager when compilation fails

# 6. Reset the compilation cache
torch._dynamo.reset()

# 7. Detect graph breaks
@torch.compile(fullgraph=True)
def fn_with_break(x, cond):
    if cond.item():  # this causes a graph break
        return torch.sin(x)
    return torch.cos(x)
# with fullgraph=True, the graph break raises an error instead of silently splitting

# 8. The TORCH_COMPILE_DEBUG environment variable
# TORCH_COMPILE_DEBUG=1 python script.py
4.4 Common Problems and Solutions

# Problem 1: performance loss from graph breaks

# bad: .item() forces a graph break
@torch.compile
def bad_fn(x, threshold):
    if x.sum().item() > threshold:
        return x * 2
    return x

# good: keep the branch on-device
@torch.compile
def good_fn(x, threshold):
    mask = x.sum() > threshold
    return torch.where(mask, x * 2, x)

# Problem 2: dynamic shapes

# option 1: dynamic=True
@torch.compile(dynamic=True)
def dynamic_fn(x):
    return x.sum(dim=-1)

# option 2: mark specific dimensions as dynamic
from torch._dynamo import mark_dynamic
batch_size, seq_len, hidden_dim = 8, 128, 512  # example sizes
x = torch.randn(batch_size, seq_len, hidden_dim)
mark_dynamic(x, 0)  # the batch dimension is dynamic

# Problem 3: custom operators
# register a custom op so it can be compiled (torch.library, PyTorch 2.4+)
@torch.library.custom_op("mylib::custom_relu", mutates_args=())
def custom_relu(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)

@custom_relu.register_fake
def custom_relu_fake(x):
    return torch.empty_like(x)

# it can now be used inside compiled functions
@torch.compile
def fn_with_custom_op(x):
    return torch.ops.mylib.custom_relu(x)

# Problem 4: data-dependent control flow
# use torch.cond
@torch.compile
def fn_with_cond(x, pred):
    def true_fn(x):
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
    return torch.cond(pred, true_fn, false_fn, (x,))

# Problem 5: long compile times
# enable the FX graph cache
torch._inductor.config.fx_graph_cache = True
# or compile ahead of use and warm up
model = torch.compile(model)
for _ in range(3):
    _ = model(dummy_input)
# note: saving weights works as usual, but compiled kernels live in the on-disk
# cache, not in the checkpoint; state_dict keys gain an `_orig_mod.` prefix
torch.save(model.state_dict(), "compiled_model.pt")
5. Performance Optimization Best Practices

5.1 Optimization Tips

import torch

# 1. Pick the right mode
# training: the default mode
model = torch.compile(model)
# inference: reduce-overhead (small batches) or max-autotune (large batches)
model = torch.compile(model, mode="max-autotune")

# 2. Avoid graph breaks
# prefer torch.where, torch.cond, etc.
# avoid .item(), print(), and Python control flow on tensor values

# 3. Exploit operator fusion
@torch.compile
def fused_gelu(x):
    # these element-wise ops are fused into a single kernel
    return x * 0.5 * (1.0 + torch.tanh(
        0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))
    ))

# 4. Set dynamic shapes correctly
@torch.compile(dynamic=True)
def dynamic_forward(model, x):
    # handles varying sequence lengths without recompiling
    return model.transformer(x)

# 5. Combine torch.compile with AMP
model = torch.compile(model)
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(input)
    loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# 6. Compile several functions up front
functions_to_compile = [fn1, fn2, fn3]
compiled_functions = [torch.compile(fn) for fn in functions_to_compile]

# 7. Combine compile with FSDP (compile after wrapping)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model)
model = torch.compile(model)
5.2 Performance Comparison

import torch
import time

def benchmark(fn, x, warmup=10, repeat=100):
    """Simple wall-clock benchmark."""
    # warmup
    for _ in range(warmup):
        _ = fn(x)
    torch.cuda.synchronize()
    # timing
    start = time.time()
    for _ in range(repeat):
        _ = fn(x)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    return elapsed / repeat * 1000  # ms
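# Alternative (a sketch, not part of the original benchmark): CUDA events
# avoid host-side timer noise and measure device time only.
def benchmark_events(fn, x, warmup=10, repeat=100):
    for _ in range(warmup):
        _ = fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(repeat):
        _ = fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeat  # ms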
# the function under test
def attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

# create the inputs
batch, heads, seq, dim = 16, 8, 1024, 64
q = torch.randn(batch, heads, seq, dim, device='cuda')
k = torch.randn(batch, heads, seq, dim, device='cuda')
v = torch.randn(batch, heads, seq, dim, device='cuda')

# eager mode
eager_time = benchmark(lambda x: attention(x[0], x[1], x[2]), (q, k, v))

# compiled
compiled_attention = torch.compile(attention)
compile_time = benchmark(lambda x: compiled_attention(x[0], x[1], x[2]), (q, k, v))

# max-autotune
autotune_attention = torch.compile(attention, mode="max-autotune")
autotune_time = benchmark(lambda x: autotune_attention(x[0], x[1], x[2]), (q, k, v))

print(f"Eager:    {eager_time:.2f} ms")
print(f"Compile:  {compile_time:.2f} ms (speedup: {eager_time/compile_time:.2f}x)")
print(f"Autotune: {autotune_time:.2f} ms (speedup: {eager_time/autotune_time:.2f}x)")
6. Frequently Asked Interview Questions

Q1: How does TorchDynamo capture the computation graph?

Key points:
- PEP 523 hook: a frame evaluation hook intercepts bytecode execution
- Symbolic tracing: Python operations are converted into a symbolic representation
- Guard generation: runtime checks keep the compiled code valid
- Graph breaks: the graph is split when an unsupported operation is hit

Q2: What causes a graph break?

Key points:
- Data-dependent control flow (if x.item() > 0)
- Unsupported Python operations (print, some builtins)
- Dynamic attribute access
- Certain third-party library calls

Q3: How do the torch.compile modes differ?

Key points:
- default: balances compile time and runtime performance
- reduce-overhead: cuts kernel launch overhead; good for small batches
- max-autotune: maximizes performance at the cost of compile time; good for deployment

Q4: How does TorchInductor implement operator fusion?

Key points:
- Lowering: the FX graph is converted into IR nodes
- Scheduling analysis: node dependencies are analyzed
- Fusion decision: shape compatibility and memory access patterns are checked
- Code generation: a fused Triton/C++ kernel is emitted

Q5: How do you debug torch.compile performance problems?

Key points:
- Use torch._dynamo.explain() to inspect compilation info
- Set TORCH_COMPILE_DEBUG=1 for detailed logs
- Check for graph breaks; use fullgraph=True to pin them down
- Use torch.profiler to analyze kernel execution times
- Inspect the generated code (under /tmp/torchinductor_xxx/)
7. Learning Resources

Official documentation

Source code reading
- torch/_dynamo/ - the TorchDynamo implementation
- torch/_inductor/ - the TorchInductor implementation
- torch/fx/ - the FX Graph implementation