02-TorchDynamo and torch.compile

Overview

torch.compile, introduced in PyTorch 2.0, is a landmark compilation API driven by TorchDynamo and TorchInductor: TorchDynamo captures the computation graph by tracing Python bytecode, and TorchInductor generates efficient GPU/CPU code from that graph. This chapter digs into how both components work, illustrated with simplified excerpts of their source code.
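
Before diving into the internals, a quick way to see what TorchDynamo captures is to hand torch.compile a custom backend. This is a minimal sketch (the name inspect_backend is ours): a backend is any callable that receives the captured fx.GraphModule plus example inputs, and returning gm.forward simply runs the captured graph without further optimization.

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # dump the FX graph Dynamo captured
    return gm.forward         # run the captured graph unoptimized

@torch.compile(backend=inspect_backend)
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

fn(torch.randn(8), torch.randn(8))  # triggers tracing and prints the graph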

Overall Architecture

┌──────────────────────────────────────────────────────────────────────────┐
│                        torch.compile architecture                        │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Python code                                    │  │
│  │   @torch.compile                                                   │  │
│  │   def fn(x, y):                                                    │  │
│  │       return torch.sin(x) + torch.cos(y)                           │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                  │                                       │
│                                  ▼                                       │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     TorchDynamo                                    │  │
│  │   ├─ Python bytecode interception                                  │  │
│  │   ├─ Symbolic tracing                                              │  │
│  │   ├─ Guard generation                                              │  │
│  │   └─ Graph-break handling                                          │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                  │                                       │
│                                  ▼                                       │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     FX Graph                                       │  │
│  │   graph():                                                         │  │
│  │       %x : [num_users=1] = placeholder[target=x]                   │  │
│  │       %y : [num_users=1] = placeholder[target=y]                   │  │
│  │       %sin : [num_users=1] = call_function[target=torch.sin]       │  │
│  │       %cos : [num_users=1] = call_function[target=torch.cos]       │  │
│  │       %add : [num_users=1] = call_function[target=operator.add]    │  │
│  │       return add                                                   │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                  │                                       │
│                                  ▼                                       │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Backend (TorchInductor)                        │  │
│  │   ├─ Graph optimization (fusion, DCE, CSE)                         │  │
│  │   ├─ Scheduling                                                    │  │
│  │   └─ Code generation (Triton / C++)                                │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                  │                                       │
│                                  ▼                                       │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Compiled code                                  │  │
│  │   Triton kernels / C++ kernels                                     │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘

1. How TorchDynamo Works

1.1 Python Bytecode Tracing

# The core of TorchDynamo: intercepting Python bytecode execution

import dis
import torch

def simple_fn(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# Inspect the Python bytecode
print("Bytecode:")
dis.dis(simple_fn)
"""
Output (CPython 3.10 bytecode; 3.11+ uses different opcodes such as CALL and BINARY_OP):
  2           0 LOAD_GLOBAL              0 (torch)
              2 LOAD_ATTR                1 (sin)
              4 LOAD_FAST                0 (x)
              6 CALL_FUNCTION            1
              8 STORE_FAST               2 (a)

  3          10 LOAD_GLOBAL              0 (torch)
             12 LOAD_ATTR                2 (cos)
             14 LOAD_FAST                1 (y)
             16 CALL_FUNCTION            1
             18 STORE_FAST               3 (b)

  4          20 LOAD_FAST                2 (a)
             22 LOAD_FAST                3 (b)
             24 BINARY_ADD
             26 RETURN_VALUE
"""

# How TorchDynamo works:
# 1. Install a frame-evaluation hook (PEP 523)
# 2. Intercept each bytecode instruction
# 3. Symbolically trace the tensor operations
# 4. Emit an FX Graph

1.2 The Symbolic Tracing Mechanism

# torch/_dynamo/symbolic_convert.py (simplified)
# (GraphBreak is the exception from torch/_dynamo/exc.py; see section 1.4)

import dis
import torch

class InstructionTranslator:
    """
    Bytecode instruction translator.
    Converts Python bytecode into FX graph operations.
    """

    def __init__(self, frame, compiler_fn):
        self.frame = frame
        self.graph = torch.fx.Graph()
        self.symbolic_locals = {}  # symbolic local variables
        self.compiler_fn = compiler_fn
        self.stack = []            # simulated Python value stack
        self.instructions = list(dis.get_instructions(frame.f_code))

    def push(self, value):
        self.stack.append(value)

    def pop(self):
        return self.stack.pop()

    def is_torch_op(self, fn):
        # Simplified check: treat anything in the torch namespace as traceable
        return getattr(fn, "__module__", "").startswith("torch")

    def LOAD_FAST(self, inst):
        """
        Handle the LOAD_FAST instruction:
        load a local variable.
        """
        name = inst.argval
        if name in self.symbolic_locals:
            # The variable is already symbolic
            self.push(self.symbolic_locals[name])
        else:
            # Fetch the value from the real frame
            value = self.frame.f_locals[name]
            if isinstance(value, torch.Tensor):
                # Create a placeholder node
                proxy = self.graph.placeholder(name)
                self.symbolic_locals[name] = proxy
                self.push(proxy)
            else:
                self.push(value)

    def CALL_FUNCTION(self, inst):
        """
        Handle the CALL_FUNCTION instruction:
        a function call.
        """
        nargs = inst.argval
        args = [self.pop() for _ in range(nargs)][::-1]
        fn = self.pop()

        # Is this a supported PyTorch operation?
        if self.is_torch_op(fn):
            # Create a call_function node
            result = self.graph.call_function(fn, tuple(args))
            self.push(result)
        else:
            # Unsupported operation: trigger a graph break
            raise GraphBreak(f"Unsupported function: {fn}")

    def BINARY_ADD(self, inst):
        """
        Handle the BINARY_ADD instruction:
        addition.
        """
        b = self.pop()
        a = self.pop()

        if isinstance(a, torch.fx.Proxy) or isinstance(b, torch.fx.Proxy):
            # Create an add node
            import operator
            result = self.graph.call_function(operator.add, (a, b))
            self.push(result)
        else:
            # Constant folding
            self.push(a + b)

    def run(self):
        """
        Translate the bytecode.
        """
        for inst in self.instructions:
            handler = getattr(self, inst.opname, None)
            if handler is None:
                raise GraphBreak(f"Unsupported opcode: {inst.opname}")
            handler(inst)

        # Return the FX Graph
        return self.graph


class SymbolicValue:
    """
    A symbolic value:
    represents a symbolic tensor during tracing.
    """

    def __init__(self, proxy, example_value):
        self.proxy = proxy                  # FX Proxy node
        self.example_value = example_value  # example value used for shape inference

    @property
    def shape(self):
        return self.example_value.shape

    @property
    def dtype(self):
        return self.example_value.dtype

1.3 The Guard Mechanism

# torch/_dynamo/guards.py (simplified)

class GuardFail(Exception):
    """Raised when a guard check fails at call time."""
    pass


class GuardBuilder:
    """
    Guard builder:
    generates runtime checks that keep the compiled code valid.
    """

    def __init__(self):
        self.guards = []

    def tensor_guard(self, name, tensor):
        """
        Tensor guard:
        check that tensor properties still match.
        """
        # Shape / dtype / device / requires_grad checks
        self.guards.append(
            f"isinstance(locals_['{name}'], torch.Tensor) and "
            f"locals_['{name}'].shape == {tuple(tensor.shape)} and "
            f"locals_['{name}'].dtype == {tensor.dtype} and "
            f"locals_['{name}'].device == torch.device('{tensor.device}') and "
            f"locals_['{name}'].requires_grad == {tensor.requires_grad}"
        )

    def type_guard(self, name, value):
        """
        Type guard.
        """
        self.guards.append(f"type(locals_['{name}']).__name__ == '{type(value).__name__}'")

    def value_guard(self, name, value):
        """
        Value guard (for constants).
        """
        self.guards.append(f"locals_['{name}'] == {repr(value)}")

    def build(self):
        """
        Build the guard function.
        """
        guard_code = " and ".join(self.guards)
        return eval(f"lambda locals_: {guard_code}")


class GuardedCode:
    """
    Compiled code together with its guards.
    """

    def __init__(self, compiled_fn, guard_fn, graph):
        self.compiled_fn = compiled_fn
        self.guard_fn = guard_fn
        self.graph = graph

    def __call__(self, *args, **kwargs):
        # Check the guards
        locals_dict = dict(zip(self.graph.input_names, args))
        if not self.guard_fn(locals_dict):
            # Guards failed: a recompile is needed
            raise GuardFail("Guard check failed")

        return self.compiled_fn(*args, **kwargs)


# Guard usage example
def example_guard_flow():
    """
    The guard workflow.
    """
    @torch.compile
    def fn(x, y):
        return x + y

    # First call
    x1 = torch.randn(10, 10)
    y1 = torch.randn(10, 10)
    result1 = fn(x1, y1)  # compiles and installs guards

    # Second call (same shapes)
    x2 = torch.randn(10, 10)
    y2 = torch.randn(10, 10)
    result2 = fn(x2, y2)  # guards pass: reuse the compiled code

    # Third call (different shapes)
    x3 = torch.randn(20, 20)
    y3 = torch.randn(20, 20)
    result3 = fn(x3, y3)  # guards fail: recompile
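
To see the guards Dynamo actually installs (rather than the simplified builder above), recent PyTorch releases expose structured logging. A small sketch, assuming PyTorch 2.1+ where torch._logging.set_logs accepts the guards and recompiles flags; exact flag names can vary between versions.

import torch
import torch._logging

torch._logging.set_logs(guards=True, recompiles=True)

@torch.compile
def fn(x):
    return x * 2

fn(torch.randn(10, 10))  # first call: compiles; installed guards are logged
fn(torch.randn(20, 20))  # new shape: a guard fails and the recompile is logged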

1.4 Handling Graph Breaks

# torch/_dynamo/exc.py

class GraphBreak(Exception):
    """
    Graph-break exception,
    raised when an operation cannot be traced.
    """
    pass


# Common causes of graph breaks
GRAPH_BREAK_REASONS = {
    "data_dependent_control_flow": "data-dependent control flow",
    "unsupported_op": "unsupported operation",
    "dynamic_shape": "dynamic-shape operation",
    "inline_call": "failed function inlining",
    "graph_output": "graph-output problem",
}


# torch/_dynamo/symbolic_convert.py

class GraphBreakHandler:
    """
    Graph-break handler.
    """

    def __init__(self, translator):
        self.translator = translator
        self.sub_graphs = []

    def handle_break(self, break_reason):
        """
        Handle a graph break:
        split the current graph and continue tracing.
        """
        # 1. Finish the current subgraph
        current_graph = self.translator.finish_subgraph()
        self.sub_graphs.append(current_graph)

        # 2. Run the code that caused the break (untraced)
        self.execute_break_code()

        # 3. Start a new subgraph
        self.translator.start_new_subgraph()

    def execute_break_code(self):
        """
        Execute the Python code that caused the break.
        """
        # Fall back to the Python interpreter
        pass


# Graph-break example
def graph_break_example():
    @torch.compile
    def fn(x, cond):
        a = torch.sin(x)  # traceable

        if cond.item():  # graph break! data-dependent control flow
            b = torch.cos(a)
        else:
            b = torch.tan(a)

        c = b + 1  # a new subgraph starts here
        return c

    # This produces multiple subgraphs
    x = torch.randn(10)
    cond = torch.tensor(True)
    result = fn(x, cond)


# Tricks for avoiding graph breaks
def avoid_graph_break():
    # 1. Use torch.where instead of if/else
    @torch.compile
    def good_fn(x, cond):
        a = torch.sin(x)
        b = torch.where(cond, torch.cos(a), torch.tan(a))
        return b + 1

    # 2. Use torch.cond (PyTorch 2.1+)
    @torch.compile
    def better_fn(x, cond):
        a = torch.sin(x)

        def true_fn(a):
            return torch.cos(a)

        def false_fn(a):
            return torch.tan(a)

        b = torch.cond(cond, true_fn, false_fn, (a,))
        return b + 1
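
A handy way to count graph breaks in practice is torch._dynamo.explain. A sketch, assuming the PyTorch 2.1+ calling convention (explain(fn)(*args)); the attribute names below come from its ExplainOutput result and may differ in other releases.

import torch
import torch._dynamo

def fn(x, cond):
    a = torch.sin(x)
    if cond.item():  # data-dependent control flow -> graph break
        return torch.cos(a)
    return torch.tan(a)

explanation = torch._dynamo.explain(fn)(torch.randn(10), torch.tensor(True))
print(explanation.graph_break_count)  # expected: 1
print(explanation.break_reasons)      # human-readable break reasons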

2. FX Graph in Detail

2.1 FX Graph Structure

# torch/fx/graph.py (simplified)

class Node:
    """
    A node in the FX graph.
    """

    def __init__(self, graph, name, op, target, args, kwargs):
        self.graph = graph
        self.name = name
        self.op = op  # op kind: placeholder, call_function, call_method, etc.
        self.target = target  # target function or attribute
        self.args = args  # positional arguments
        self.kwargs = kwargs  # keyword arguments
        self.users = {}  # nodes that consume this node

    def __repr__(self):
        return f"{self.name} = {self.op}[target={self.target}]({self.args})"


class Graph:
    """
    An FX computation graph.
    """

    def __init__(self):
        self.nodes = []
        self._node_name_counter = 0

    def placeholder(self, name):
        """Create an input placeholder."""
        node = Node(self, name, 'placeholder', name, (), {})
        self.nodes.append(node)
        return Proxy(node)

    def call_function(self, target, args, kwargs=None):
        """Create a function-call node."""
        name = self._generate_name(target.__name__)
        node = Node(self, name, 'call_function', target, args, kwargs or {})
        self.nodes.append(node)
        return Proxy(node)

    def call_method(self, target, args, kwargs=None):
        """Create a method-call node."""
        name = self._generate_name(target)
        node = Node(self, name, 'call_method', target, args, kwargs or {})
        self.nodes.append(node)
        return Proxy(node)

    def output(self, result):
        """Set the graph output."""
        node = Node(self, 'output', 'output', 'output', (result,), {})
        self.nodes.append(node)

    def _generate_name(self, base):
        self._node_name_counter += 1
        return f"{base}_{self._node_name_counter}"

    def print_tabular(self):
        """Print the graph structure."""
        print("opcode       target           args         kwargs")
        print("-" * 60)
        for node in self.nodes:
            print(f"{node.op:12} {str(node.target):16} {str(node.args):12} {str(node.kwargs)}")


class Proxy:
    """
    Proxy object:
    a fluent API for building the graph.
    """

    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        import operator
        return self.node.graph.call_function(operator.add, (self, other))

    def __mul__(self, other):
        import operator
        return self.node.graph.call_function(operator.mul, (self, other))


# FX Graph usage example
def fx_graph_example():
    import torch
    import torch.fx as fx

    class MyModule(torch.nn.Module):
        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    # Symbolic tracing
    traced = fx.symbolic_trace(MyModule())
    print(traced.graph)

    # Output:
    # graph():
    #     %x : [num_users=1] = placeholder[target=x]
    #     %y : [num_users=1] = placeholder[target=y]
    #     %sin : [num_users=1] = call_function[target=torch.sin](args = (%x,))
    #     %cos : [num_users=1] = call_function[target=torch.cos](args = (%y,))
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%sin, %cos))
    #     return add
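
Because the graph is just data, it can be edited before recompiling. A small illustration of graph surgery on a traced module, using the real torch.fx API (not the simplified classes above): swap every torch.sin call for torch.cos and regenerate forward().

import torch
import torch.fx as fx

class SinModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.sin(x) + torch.cos(y)

traced = fx.symbolic_trace(SinModule())
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is torch.sin:
        node.target = torch.cos  # swap the op in place
traced.recompile()               # regenerate forward() from the edited graph
print(traced.code)               # now computes cos(x) + cos(y)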

2.2 FX Graph Transformations

# torch/fx/passes (simplified)

class GraphTransform:
    """
    Base class for graph transforms.
    """

    def __call__(self, graph):
        return self.transform(graph)

    def transform(self, graph):
        raise NotImplementedError


class PatternMatcher:
    """
    Pattern matcher,
    used to recognize optimizable graph patterns.
    """

    def __init__(self, pattern):
        self.pattern = pattern

    def match(self, graph, node):
        """
        Try to match the pattern starting at the given node.
        """
        # Simplified pattern-matching logic
        pass


class FusionPass(GraphTransform):
    """
    Operator-fusion pass.
    """

    def transform(self, graph):
        # Walk the nodes looking for fusion opportunities
        for node in graph.nodes:
            # Example: fuse sin + cos
            if self._is_fusable_pattern(node):
                self._fuse_nodes(graph, node)

        return graph

    def _is_fusable_pattern(self, node):
        """Check whether the node is fusable."""
        # Is it an element-wise op?
        # Do the input/output shapes match?
        pass

    def _fuse_nodes(self, graph, node):
        """Perform the fusion."""
        pass


class DeadCodeElimination(GraphTransform):
    """
    Dead-code elimination pass.
    """

    def transform(self, graph):
        # Mark live nodes backwards from the output node
        live_nodes = set()
        self._mark_live(graph.output_node, live_nodes)

        # Drop everything that is not live
        graph.nodes = [n for n in graph.nodes if n in live_nodes]

        return graph

    def _mark_live(self, node, live_nodes):
        if node in live_nodes:
            return
        live_nodes.add(node)
        for arg in node.args:
            if isinstance(arg, Node):
                self._mark_live(arg, live_nodes)


class CommonSubexpressionElimination(GraphTransform):
    """
    Common-subexpression elimination pass.
    """

    def transform(self, graph):
        # Remember expressions we have already seen
        seen = {}

        for node in graph.nodes:
            key = (node.op, node.target, tuple(node.args), tuple(node.kwargs.items()))

            if key in seen:
                # Replace uses with the existing node
                self._replace_uses(graph, node, seen[key])
            else:
                seen[key] = node

        return graph


# Applying the passes
def apply_passes_example():
    import torch
    import torch.fx as fx

    class MyModule(torch.nn.Module):
        def forward(self, x):
            a = torch.sin(x)
            b = torch.sin(x)  # duplicated computation
            c = a + b
            d = torch.cos(x)  # dead code (never used)
            return c

    traced = fx.symbolic_trace(MyModule())

    # Apply the optimization passes
    passes = [
        CommonSubexpressionElimination(),
        DeadCodeElimination(),
    ]

    graph = traced.graph
    for pass_fn in passes:
        graph = pass_fn(graph)

    print(graph)
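
The hand-rolled passes above are illustrative; torch.fx ships a built-in dead-code pass that can be applied directly to a traced module:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        unused = torch.cos(x)  # dead: the result is never used
        return torch.sin(x)

traced = fx.symbolic_trace(M())
traced.graph.eliminate_dead_code()  # built-in DCE, mutates the graph in place
traced.recompile()
print(traced.code)                  # the cos call is gone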

3. TorchInductor Code Generation

3.1 Inductor Architecture

┌──────────────────────────────────────────────────────────────────────────┐
│                        TorchInductor architecture                        │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  FX Graph                                                                │
│      │                                                                   │
│      ▼                                                                   │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Decomposition                                  │  │
│  │   break high-level ops into primitive ops,                         │  │
│  │   e.g. LayerNorm → mean, sub, var, rsqrt, mul, add                 │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│      │                                                                   │
│      ▼                                                                   │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Lowering                                       │  │
│  │   convert FX nodes into IR nodes and                               │  │
│  │   build the dependencies between operations                        │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│      │                                                                   │
│      ▼                                                                   │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Scheduling                                     │  │
│  │   decide which operations fuse into one kernel and                 │  │
│  │   optimize memory-access patterns                                  │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│      │                                                                   │
│      ▼                                                                   │
│  ┌────────────────────────────────────────────────────────────────────┐  │
│  │                     Code generation                                │  │
│  │   GPU: emit Triton code                                            │  │
│  │   CPU: emit C++ code                                               │  │
│  └────────────────────────────────────────────────────────────────────┘  │
│      │                                                                   │
│      ▼                                                                   │
│  Compiled kernel                                                         │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘

3.2 The Lowering Process

# torch/_inductor/lowering.py (simplified)

class IRNode:
    """
    Base class for Inductor IR nodes.
    """
    pass


class Pointwise(IRNode):
    """
    An element-wise operation.
    """

    def __init__(self, dtype, inner_fn, ranges):
        self.dtype = dtype
        self.inner_fn = inner_fn  # the compute function
        self.ranges = ranges  # loop ranges

    def __repr__(self):
        return f"Pointwise(dtype={self.dtype}, ranges={self.ranges})"


class Reduction(IRNode):
    """
    A reduction operation.
    """

    def __init__(self, dtype, inner_fn, ranges, reduction_ranges, reduction_type):
        self.dtype = dtype
        self.inner_fn = inner_fn
        self.ranges = ranges
        self.reduction_ranges = reduction_ranges
        self.reduction_type = reduction_type  # sum, mean, max, etc.


class Buffer(IRNode):
    """
    A memory buffer.
    """

    def __init__(self, name, layout):
        self.name = name
        self.layout = layout


# Lowering functions
def lower_sin(x):
    """
    Lower torch.sin to a Pointwise op.
    """
    def inner_fn(index):
        return f"tl.sin({x.loader.make_load(index)})"

    return Pointwise(
        dtype=x.dtype,
        inner_fn=inner_fn,
        ranges=x.ranges,
    )


def lower_add(a, b):
    """
    Lower addition to a Pointwise op.
    """
    def inner_fn(index):
        return f"({a.loader.make_load(index)} + {b.loader.make_load(index)})"

    return Pointwise(
        dtype=a.dtype,
        inner_fn=inner_fn,
        ranges=a.ranges,
    )


def lower_sum(x, dims):
    """
    Lower a sum to a Reduction op.
    """
    def inner_fn(index):
        return x.loader.make_load(index)

    return Reduction(
        dtype=x.dtype,
        inner_fn=inner_fn,
        ranges=compute_output_ranges(x.ranges, dims),  # helper assumed defined elsewhere
        reduction_ranges=dims,
        reduction_type='sum',
    )


# Lowering registry
LOWERING_REGISTRY = {
    torch.sin: lower_sin,
    torch.cos: lambda x: lower_pointwise_unary(x, 'tl.cos'),  # helper assumed defined elsewhere
    torch.add: lower_add,
    torch.sum: lower_sum,
    # ... more ops
}

3.3 Scheduling and Fusion

# torch/_inductor/scheduler.py (simplified)

class Scheduler:
    """
    Scheduler:
    decides execution order and fusion strategy for IR nodes.
    """

    def __init__(self, nodes):
        self.nodes = nodes
        self.fused_groups = []

    def schedule(self):
        """
        Main scheduling entry point.
        """
        # 1. Build the dependency graph
        dep_graph = self._build_dependency_graph()

        # 2. Find fusion opportunities
        fusion_groups = self._find_fusion_groups(dep_graph)

        # 3. Create scheduled groups
        for group in fusion_groups:
            scheduled_group = ScheduledGroup(group)
            self.fused_groups.append(scheduled_group)

        return self.fused_groups

    def _build_dependency_graph(self):
        """
        Build the node dependency graph.
        """
        dep_graph = {}
        for node in self.nodes:
            deps = self._get_dependencies(node)
            dep_graph[node] = deps
        return dep_graph

    def _find_fusion_groups(self, dep_graph):
        """
        Identify groups of fusable nodes.
        """
        groups = []
        visited = set()

        for node in self.nodes:
            if node in visited:
                continue

            # Try to grow a fusion group
            group = self._grow_fusion_group(node, dep_graph, visited)
            if group:
                groups.append(group)

        return groups

    def _grow_fusion_group(self, start_node, dep_graph, visited):
        """
        Grow a fusion group outward from a starting node.
        """
        group = [start_node]
        visited.add(start_node)

        # Visit the consumers and check whether they can fuse
        for user in self._get_users(start_node):
            if self._can_fuse(start_node, user):
                group.append(user)
                visited.add(user)
                # Recurse to keep growing
                self._grow_fusion_group(user, dep_graph, visited)

        return group

    def _can_fuse(self, producer, consumer):
        """
        Check whether two nodes can be fused.
        """
        # Condition 1: both are Pointwise ops
        if not (isinstance(producer, Pointwise) and isinstance(consumer, Pointwise)):
            return False

        # Condition 2: compatible shapes
        if producer.ranges != consumer.ranges:
            return False

        # Condition 3: memory dependencies allow it
        # ...

        return True


class ScheduledGroup:
    """
    A scheduled group:
    a set of operations that will be fused into a single kernel.
    """

    def __init__(self, nodes):
        self.nodes = nodes
        self.inputs = self._compute_inputs()
        self.outputs = self._compute_outputs()

    def _compute_inputs(self):
        """Compute the group's inputs."""
        inputs = set()
        internal = set(self.nodes)

        for node in self.nodes:
            for dep in node.dependencies:
                if dep not in internal:
                    inputs.add(dep)

        return inputs

    def _compute_outputs(self):
        """Compute the group's outputs."""
        outputs = []
        internal = set(self.nodes)

        for node in self.nodes:
            for user in node.users:
                if user not in internal:
                    outputs.append(node)
                    break

        return outputs

3.4 Triton Code Generation

# torch/_inductor/triton_ops.py (simplified)

class TritonKernelGenerator:
    """
    Triton kernel code generator.
    """

    def __init__(self, scheduled_group):
        self.group = scheduled_group
        self.code = []

    def generate(self):
        """
        Generate the Triton kernel source.
        """
        # Emit the kernel signature
        self._generate_signature()

        # Emit the index computation
        self._generate_index()

        # Emit the loads
        self._generate_loads()

        # Emit the computation
        self._generate_compute()

        # Emit the stores
        self._generate_stores()

        return "\n".join(self.code)

    def _generate_signature(self):
        """Emit the function signature."""
        params = []
        for inp in self.group.inputs:
            params.append(f"{inp.name}_ptr")
        for out in self.group.outputs:
            params.append(f"{out.name}_ptr")
        params.append("numel")

        self.code.append("@triton.jit")
        self.code.append(f"def kernel({', '.join(params)}, BLOCK_SIZE: tl.constexpr):")

    def _generate_index(self):
        """Emit the index computation."""
        self.code.append("    pid = tl.program_id(0)")
        self.code.append("    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        self.code.append("    mask = offset < numel")

    def _generate_loads(self):
        """Emit the data loads."""
        for inp in self.group.inputs:
            self.code.append(f"    {inp.name} = tl.load({inp.name}_ptr + offset, mask=mask)")

    def _generate_compute(self):
        """Emit the compute statements."""
        for node in self.group.nodes:
            compute_code = self._lower_node_to_triton(node)
            self.code.append(f"    {compute_code}")

    def _generate_stores(self):
        """Emit the data stores."""
        for out in self.group.outputs:
            self.code.append(f"    tl.store({out.name}_ptr + offset, {out.name}, mask=mask)")

    def _lower_node_to_triton(self, node):
        """Convert an IR node to Triton code."""
        if isinstance(node, Pointwise):
            return node.inner_fn("offset")
        # ... more node types


# Example of generated Triton code
def generated_triton_example():
    """
    Example: the Triton code generated for sin(x) + cos(y).
    """
    code = """
@triton.jit
def fused_kernel(
    x_ptr, y_ptr, output_ptr,
    numel,
    BLOCK_SIZE: tl.constexpr
):
    # index computation
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < numel

    # load the data
    x = tl.load(x_ptr + offset, mask=mask)
    y = tl.load(y_ptr + offset, mask=mask)

    # compute (the fused operations)
    tmp0 = tl.sin(x)
    tmp1 = tl.cos(y)
    tmp2 = tmp0 + tmp1

    # store the result
    tl.store(output_ptr + offset, tmp2, mask=mask)


def call_kernel(x, y):
    output = torch.empty_like(x)
    numel = x.numel()
    BLOCK_SIZE = 1024

    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    fused_kernel[grid](
        x, y, output,
        numel,
        BLOCK_SIZE=BLOCK_SIZE
    )

    return output
"""
    return code
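
The listing above is hand-written for readability. To dump the Triton code Inductor actually emits, enable output-code logging. A sketch; TORCH_LOGS="output_code" on the command line is equivalent, and flag names can vary slightly across versions.

import torch
import torch._logging

torch._logging.set_logs(output_code=True)

@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
fn(x, y)  # the generated Triton kernel source is printed to the log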

4. Using torch.compile

4.1 Basic Usage

import torch

# Option 1: as a decorator
@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

# Option 2: as a function call
def fn2(x, y):
    return torch.sin(x) + torch.cos(y)

compiled_fn = torch.compile(fn2)

# Option 3: compiling an nn.Module
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(16, 16))

    def forward(self, x):
        return torch.relu(x @ self.weight)

model = MyModule()
compiled_model = torch.compile(model)

# Option 4: compiling for training
model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters())

for batch in dataloader:
    loss = model(batch)
    loss.backward()  # the backward pass is compiled too
    optimizer.step()

4.2 Compilation Options

import torch

# The mode option
@torch.compile(mode="default")  # default: balances compile time and performance
def fn1(x): return x + 1

@torch.compile(mode="reduce-overhead")  # lowers per-call overhead; good for small models
def fn2(x): return x + 1

@torch.compile(mode="max-autotune")  # maximum tuning at the cost of longer compiles
def fn3(x): return x + 1


# The backend option
@torch.compile(backend="inductor")  # default: TorchInductor
def fn4(x): return x + 1

@torch.compile(backend="eager")  # no compilation; useful for debugging
def fn5(x): return x + 1

@torch.compile(backend="aot_eager")  # runs AOTAutograd but executes eagerly
def fn6(x): return x + 1


# The fullgraph option
@torch.compile(fullgraph=True)  # require a single graph; any graph break raises an error
def fn7(x):
    return torch.sin(x)


# The dynamic option
@torch.compile(dynamic=True)  # support dynamic shapes
def fn8(x):
    return x.sum()


# Combining options
@torch.compile(
    mode="max-autotune",
    backend="inductor",
    fullgraph=True,
    dynamic=False,
)
def optimized_fn(x, y):
    return torch.matmul(x, y)
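
The registered backends can also be listed at runtime, which helps when deciding what to pass to backend=. A sketch; the exact list depends on your install.

import torch
import torch._dynamo

print(torch._dynamo.list_backends())
# e.g. ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']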

4.3 Debugging Tips

import logging

import torch
import torch._dynamo as dynamo

# 1. Verbose compile logging
torch._dynamo.config.verbose = True
torch._dynamo.config.log_level = logging.DEBUG  # older releases; newer ones use torch._logging

# 2. Export the FX Graph
@torch.compile(backend="aot_eager")
def fn(x):
    return torch.sin(x) + 1

# Use explain() to inspect compilation results
explanation = torch._dynamo.explain(fn)
print(explanation)

# 3. Inspect the generated code
torch._inductor.config.debug = True

@torch.compile
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(1000, device='cuda')
y = torch.randn(1000, device='cuda')
result = fn(x, y)  # compile and run

# The generated code is written under a /tmp/torchinductor_xxx/ directory

# 4. Profiling
with torch.profiler.profile() as prof:
    for _ in range(10):
        result = fn(x, y)

print(prof.key_averages().table())

# 5. Disabling compilation (for debugging)
torch._dynamo.config.suppress_errors = True  # fall back to eager if compilation fails


# 6. Resetting the compilation cache
torch._dynamo.reset()


# 7. Detecting graph breaks
@torch.compile(fullgraph=True)
def fn_with_break(x, cond):
    if cond.item():  # this causes a graph break
        return torch.sin(x)
    return torch.cos(x)

# With fullgraph=True, any graph break raises an error at compile time


# 8. The TORCH_COMPILE_DEBUG environment variable
# TORCH_COMPILE_DEBUG=1 python script.py

4.4 Common Problems and Solutions

# Problem 1: performance loss caused by graph breaks

# Bad (causes a graph break)
@torch.compile
def bad_fn(x, threshold):
    if x.sum().item() > threshold:  # .item() forces a graph break
        return x * 2
    return x

# Good
@torch.compile
def good_fn(x, threshold):
    mask = x.sum() > threshold
    return torch.where(mask, x * 2, x)


# Problem 2: dynamic shapes

# Option 1: use dynamic=True
@torch.compile(dynamic=True)
def dynamic_fn(x):
    return x.sum(dim=-1)

# Option 2: mark specific dimensions as dynamic
from torch._dynamo import mark_dynamic
batch_size, seq_len, hidden_dim = 8, 128, 512  # example sizes
x = torch.randn(batch_size, seq_len, hidden_dim)
mark_dynamic(x, 0)  # the batch dimension is dynamic


# Problem 3: custom operators

# Register a custom op so it can be compiled
@torch.library.custom_op("mylib::custom_relu", mutates_args=())
def custom_relu(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)

@custom_relu.register_fake
def custom_relu_fake(x):
    return torch.empty_like(x)

# Now usable inside compiled functions
@torch.compile
def fn_with_custom_op(x):
    return torch.ops.mylib.custom_relu(x)


# Problem 4: data-dependent control flow

# Use torch.cond
@torch.compile
def fn_with_cond(x, pred):
    def true_fn(x):
        return torch.sin(x)

    def false_fn(x):
        return torch.cos(x)

    return torch.cond(pred, true_fn, false_fn, (x,))


# Problem 5: long compile times

# Enable the FX graph cache
torch._inductor.config.fx_graph_cache = True

# Or warm the compiler up ahead of time
model = torch.compile(model)
# warmup
for _ in range(3):
    _ = model(dummy_input)
# Note: state_dict() saves only the weights; the compiled kernels are cached
# on disk by Inductor and are not part of the checkpoint
torch.save(model.state_dict(), "compiled_model.pt")

5. Performance Optimization Best Practices

5.1 Optimization Tips

import torch

# 1. Pick the right mode
# Training: the default mode
model = torch.compile(model)

# Inference: reduce-overhead (small batches) or max-autotune (large batches)
model = torch.compile(model, mode="max-autotune")

# 2. Avoid graph breaks
# Prefer torch.where, torch.cond, etc.
# Avoid .item(), print(), and Python control flow on tensor values

# 3. Let the compiler fuse operators
@torch.compile
def fused_gelu(x):
    # these element-wise ops are fused automatically
    return x * 0.5 * (1.0 + torch.tanh(
        0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))
    ))

# 4. Set up dynamic shapes correctly
@torch.compile(dynamic=True)
def dynamic_forward(x):
    # supports varying sequence lengths; `transformer` is a module defined elsewhere
    return transformer(x)

# 5. Combine torch.cuda.amp with compile
model = torch.compile(model)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    output = model(input)
    loss = criterion(output, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()


# 6. Compile several functions in one go
functions_to_compile = [fn1, fn2, fn3]
compiled_functions = [torch.compile(fn) for fn in functions_to_compile]

# 7. Combine compile with FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model)
model = torch.compile(model)

5.2 Performance Comparison

import torch
import time

def benchmark(fn, x, warmup=10, repeat=100):
    """Simple latency benchmark."""
    # warmup
    for _ in range(warmup):
        _ = fn(x)
    torch.cuda.synchronize()

    # timed runs
    start = time.time()
    for _ in range(repeat):
        _ = fn(x)
    torch.cuda.synchronize()
    elapsed = time.time() - start

    return elapsed / repeat * 1000  # ms per call


# Function under test
def attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)


# Create the inputs
batch, heads, seq, dim = 16, 8, 1024, 64
q = torch.randn(batch, heads, seq, dim, device='cuda')
k = torch.randn(batch, heads, seq, dim, device='cuda')
v = torch.randn(batch, heads, seq, dim, device='cuda')

# Eager mode
eager_time = benchmark(lambda x: attention(x[0], x[1], x[2]), (q, k, v))

# Compiled
compiled_attention = torch.compile(attention)
compile_time = benchmark(lambda x: compiled_attention(x[0], x[1], x[2]), (q, k, v))

# max-autotune
autotune_attention = torch.compile(attention, mode="max-autotune")
autotune_time = benchmark(lambda x: autotune_attention(x[0], x[1], x[2]), (q, k, v))

print(f"Eager: {eager_time:.2f} ms")
print(f"Compile: {compile_time:.2f} ms (speedup: {eager_time/compile_time:.2f}x)")
print(f"Autotune: {autotune_time:.2f} ms (speedup: {eager_time/autotune_time:.2f}x)")

6. Frequently Asked Interview Questions

Q1: How does TorchDynamo capture the computation graph?

Key points:

  1. PEP 523 hook: a frame-evaluation hook intercepts bytecode execution
  2. Symbolic tracing: Python operations are converted into a symbolic representation
  3. Guard generation: runtime checks ensure the compiled code stays valid
  4. Graph breaks: the graph is split whenever an unsupported operation is hit

Q2: What causes a graph break?

Key points:

  1. Data-dependent control flow (if x.item() > 0)
  2. Unsupported Python operations (print, some builtins)
  3. Dynamic attribute access
  4. Calls into certain third-party libraries

Q3: How do the torch.compile modes differ?

Key points:

  • default: balances compile time against runtime performance
  • reduce-overhead: reduces kernel-launch overhead; suited to small batches
  • max-autotune: maximizes performance at the cost of longer compiles; suited to deployment

Q4: How does TorchInductor implement operator fusion?

Key points:

  1. Lowering: the FX graph is converted into IR nodes
  2. Scheduling analysis: dependencies between nodes are analyzed
  3. Fusion decisions: shape compatibility and memory-access patterns are checked
  4. Code generation: a fused Triton/C++ kernel is emitted

Q5: How do you debug torch.compile performance problems?

Key points:

  1. Inspect compilation with torch._dynamo.explain()
  2. Set TORCH_COMPILE_DEBUG=1 for detailed logs
  3. Look for graph breaks; use fullgraph=True to locate them
  4. Profile kernel execution times with torch.profiler
  5. Read the generated code (under /tmp/torchinductor_xxx/)

7. Learning Resources

Official documentation

  • PyTorch Compiler
  • TorchDynamo
  • TorchInductor

Source code to read

  • torch/_dynamo/ - the TorchDynamo implementation
  • torch/_inductor/ - the TorchInductor implementation
  • torch/fx/ - the FX Graph implementation

Recommended reading

  • PyTorch 2.0 Tutorial
  • Understanding torch.compile