03-XLA Compiler Deep Dive

Overview

XLA (Accelerated Linear Algebra) is a deep learning compiler developed by Google that provides graph compilation and optimization for TensorFlow and JAX. This chapter dissects XLA's architecture, the HLO intermediate representation, the optimization passes, and the code-generation machinery, so that readers come away understanding how a static-graph compiler works.

Overall XLA Architecture

┌──────────────────────────────────────────────────────────────────────┐
│                           XLA Architecture                           │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                            Frontend                            │  │
│  │     TensorFlow Graph │ JAX Jaxpr │ PyTorch/XLA │ StableHLO     │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                  │                                   │
│                                  ▼                                   │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                   HLO (High Level Optimizer)                   │  │
│  │     ├─ HLO IR: high-level intermediate representation          │  │
│  │     ├─ HLO Module: container for the computation graph         │  │
│  │     └─ HLO Instruction: a single compute node                  │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                  │                                   │
│                                  ▼                                   │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                    HLO Optimization Passes                     │  │
│  │     ├─ Algebraic Simplifier                                    │  │
│  │     ├─ Fusion                                                  │  │
│  │     ├─ Layout Assignment                                       │  │
│  │     ├─ Buffer Assignment                                       │  │
│  │     └─ many more...                                            │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                  │                                   │
│                                  ▼                                   │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                            Backend                             │  │
│  │     CPU (LLVM) │ GPU (CUDA/ROCm) │ TPU │ other accelerators    │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                  │                                   │
│                                  ▼                                   │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                        Executable Code                         │  │
│  │     HloModule → Thunks/Executables                             │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘

1. The HLO Intermediate Representation

1.1 Core HLO Concepts

// tensorflow/compiler/xla/hlo/ir/hlo_instruction.h (simplified)

// HLO opcodes
enum class HloOpcode {
  // Element-wise ops
  kAdd,
  kSubtract,
  kMultiply,
  kDivide,
  kExp,
  kLog,
  kSin,
  kCos,
  kTanh,

  // Reduction ops
  kReduce,
  kReduceWindow,

  // Matrix ops
  kDot,
  kConvolution,

  // Data movement
  kBroadcast,
  kTranspose,
  kReshape,
  kSlice,
  kConcatenate,

  // Control flow
  kWhile,
  kConditional,
  kCall,

  // Communication (collectives)
  kAllReduce,
  kAllGather,
  kReduceScatter,

  // Miscellaneous
  kParameter,
  kConstant,
  kTuple,
  kGetTupleElement,
  // ...
};


class HloInstruction {
 public:
  // Opcode
  HloOpcode opcode() const { return opcode_; }

  // Shape information
  const Shape& shape() const { return shape_; }

  // Operands
  const std::vector<HloInstruction*>& operands() const { return operands_; }

  // Users (instructions that consume this one)
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Name
  const std::string& name() const { return name_; }

  // Metadata
  const OpMetadata& metadata() const { return metadata_; }

 private:
  HloOpcode opcode_;
  Shape shape_;
  std::vector<HloInstruction*> operands_;
  std::vector<HloInstruction*> users_;
  std::string name_;
  OpMetadata metadata_;
};


class HloComputation {
 public:
  // Computation name
  const std::string& name() const { return name_; }

  // Instruction list
  const std::vector<std::unique_ptr<HloInstruction>>& instructions() const {
    return instructions_;
  }

  // Root instruction (the output)
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Parameters
  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

 private:
  std::string name_;
  std::vector<std::unique_ptr<HloInstruction>> instructions_;
  HloInstruction* root_instruction_;
  std::vector<HloInstruction*> param_instructions_;
};


class HloModule {
 public:
  // Entry computation
  HloComputation* entry_computation() const { return entry_computation_; }

  // All computations (including nested ones)
  const std::vector<std::unique_ptr<HloComputation>>& computations() const {
    return computations_;
  }

  // Configuration
  const HloModuleConfig& config() const { return config_; }

 private:
  HloComputation* entry_computation_;
  std::vector<std::unique_ptr<HloComputation>> computations_;
  HloModuleConfig config_;
};

1.2 The HLO Text Format

// Example of the HLO text format

HloModule simple_example

// Entry computation
ENTRY main {
  // Parameter definitions
  %p0 = f32[128,256] parameter(0), metadata={op_name="input_a"}
  %p1 = f32[256,512] parameter(1), metadata={op_name="input_b"}
  %p2 = f32[512] parameter(2), metadata={op_name="bias"}

  // Matrix multiply
  %dot = f32[128,512] dot(%p0, %p1),
    lhs_contracting_dims={1},
    rhs_contracting_dims={0}

  // Broadcast the bias
  %broadcast = f32[128,512] broadcast(%p2), dimensions={1}

  // Add
  %add = f32[128,512] add(%dot, %broadcast)

  // ReLU (implemented with maximum)
  %zero = f32[] constant(0)
  %broadcast_zero = f32[128,512] broadcast(%zero), dimensions={}
  ROOT %relu = f32[128,512] maximum(%add, %broadcast_zero)
}


// A more involved example: softmax built from reduce
HloModule softmax_example

%max_computation {
  %p0 = f32[] parameter(0)
  %p1 = f32[] parameter(1)
  ROOT %max = f32[] maximum(%p0, %p1)
}

%add_computation {
  %p0 = f32[] parameter(0)
  %p1 = f32[] parameter(1)
  ROOT %add = f32[] add(%p0, %p1)
}

ENTRY softmax {
  %input = f32[16,1024] parameter(0)

  // Step 1: compute the row max
  %neg_inf = f32[] constant(-inf)
  %max = f32[16] reduce(%input, %neg_inf), dimensions={1},
    to_apply=%max_computation

  // Step 2: subtract the max
  %broadcast_max = f32[16,1024] broadcast(%max), dimensions={0}
  %sub = f32[16,1024] subtract(%input, %broadcast_max)

  // Step 3: exp
  %exp = f32[16,1024] exponential(%sub)

  // Step 4: sum
  %zero = f32[] constant(0)
  %sum = f32[16] reduce(%exp, %zero), dimensions={1},
    to_apply=%add_computation

  // Step 5: divide
  %broadcast_sum = f32[16,1024] broadcast(%sum), dimensions={0}
  ROOT %softmax = f32[16,1024] divide(%exp, %broadcast_sum)
}
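
You do not have to write this HLO by hand: a few lines of JAX produce an equivalent module. A minimal sketch (assuming a recent JAX with the AOT lower/compile API; the exact HLO output varies by version and backend):

import jax
import jax.numpy as jnp

def softmax(x):
    m = jnp.max(x, axis=1, keepdims=True)         # reduce with %max_computation
    e = jnp.exp(x - m)                            # subtract + exponential
    return e / jnp.sum(e, axis=1, keepdims=True)  # reduce with %add_computation

lowered = jax.jit(softmax).lower(jnp.zeros((16, 1024)))
print(lowered.as_text())            # StableHLO handed to XLA
print(lowered.compile().as_text())  # optimized HLO, comparable to the text above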

1.3 HLO Construction APIs

// Building HLO with the C++ client API

#include "tensorflow/compiler/xla/client/xla_builder.h"

XlaOp BuildMatMulBiasRelu(XlaBuilder* builder,
                          XlaOp input,
                          XlaOp weight,
                          XlaOp bias) {
  // MatMul
  XlaOp dot = Dot(input, weight);

  // BiasAdd
  XlaOp biased = Add(dot, bias, /*broadcast_dimensions=*/{1});

  // ReLU
  XlaOp zero = ConstantR0<float>(builder, 0.0f);
  return Max(biased, zero);
}


# Python API (JAX)
import jax
import jax.numpy as jnp

@jax.jit
def matmul_bias_relu(x, weight, bias):
    """JAX会将此函数编译为HLO"""
    out = jnp.dot(x, weight)
    out = out + bias
    out = jnp.maximum(out, 0)
    return out

# Inspect the generated Jaxpr (JAX's own IR)
jaxpr = jax.make_jaxpr(matmul_bias_relu)(
    jnp.zeros((128, 256)),
    jnp.zeros((256, 512)),
    jnp.zeros((512,))
)
print(jaxpr)
# For the actual HLO, lower and compile instead:
#   matmul_bias_relu.lower(x, w, b).compile().as_text()

2. HLO Optimization Passes

2.1 The Pass Framework

// tensorflow/compiler/xla/service/hlo_pass_interface.h

class HloPassInterface {
 public:
  virtual ~HloPassInterface() = default;

  // Pass name
  virtual std::string_view name() const = 0;

  // Run the pass
  virtual StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
};


// Base class for module-level passes
class HloModulePass : public HloPassInterface {
 protected:
  // Subclasses implement this method
  virtual StatusOr<bool> RunOnModule(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
};


// Pass pipeline
class HloPassPipeline : public HloPassInterface {
 public:
  void AddPass(std::unique_ptr<HloPassInterface> pass) {
    passes_.push_back(std::move(pass));
  }

  StatusOr<bool> Run(HloModule* module,
                     const absl::flat_hash_set<absl::string_view>& threads) override {
    bool changed = false;
    for (auto& pass : passes_) {
      TF_ASSIGN_OR_RETURN(bool pass_changed, pass->Run(module, threads));
      changed |= pass_changed;
    }
    return changed;
  }

 private:
  std::vector<std::unique_ptr<HloPassInterface>> passes_;
};
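
To watch this pipeline run on a real module, you can ask XLA to dump the HLO after every pass. A small sketch using the dump flags discussed later in Section 5.1 (the dump file naming scheme varies across XLA versions):

import os
# Must be set before JAX initializes the XLA backend.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*"

import jax
import jax.numpy as jnp

jax.jit(lambda x: jnp.sin(x) + 1.0)(jnp.ones((8,))).block_until_ready()
# /tmp/xla_dump now contains one HLO snapshot per pass, with names like
# module_0000.*.after_algebraic-simplifier.txt (naming is version-dependent).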

2.2 Algebraic Simplification

// tensorflow/compiler/xla/service/algebraic_simplifier.cc (simplified)

class AlgebraicSimplifier : public HloModulePass {
 public:
  std::string_view name() const override { return "algebraic-simplifier"; }

 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;

    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
        TF_ASSIGN_OR_RETURN(bool inst_changed, SimplifyInstruction(inst));
        changed |= inst_changed;
      }
    }

    return changed;
  }

 private:
  StatusOr<bool> SimplifyInstruction(HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
        return SimplifyAdd(inst);
      case HloOpcode::kMultiply:
        return SimplifyMultiply(inst);
      case HloOpcode::kDivide:
        return SimplifyDivide(inst);
      case HloOpcode::kReshape:
        return SimplifyReshape(inst);
      // ... more rewrite rules
      default:
        return false;
    }
  }

  StatusOr<bool> SimplifyAdd(HloInstruction* add) {
    // Rule: x + 0 = x
    if (IsConstantZero(add->operand(1))) {
      return ReplaceWithOperand(add, 0);
    }
    if (IsConstantZero(add->operand(0))) {
      return ReplaceWithOperand(add, 1);
    }

    // Rule: x + x = 2 * x
    if (add->operand(0) == add->operand(1)) {
      auto two = MakeConstant(2.0f, add->shape());
      auto mul = MakeMultiply(two, add->operand(0));
      return ReplaceInstruction(add, mul);
    }

    return false;
  }

  StatusOr<bool> SimplifyMultiply(HloInstruction* mul) {
    // Rule: x * 0 = 0 (only safe when NaN/Inf propagation can be ignored)
    if (IsConstantZero(mul->operand(0)) || IsConstantZero(mul->operand(1))) {
      return ReplaceWithZero(mul);
    }

    // Rule: x * 1 = x
    if (IsConstantOne(mul->operand(1))) {
      return ReplaceWithOperand(mul, 0);
    }
    if (IsConstantOne(mul->operand(0))) {
      return ReplaceWithOperand(mul, 1);
    }

    return false;
  }

  StatusOr<bool> SimplifyReshape(HloInstruction* reshape) {
    // Rule: reshape(reshape(x)) = reshape(x)
    if (reshape->operand(0)->opcode() == HloOpcode::kReshape) {
      auto inner_reshape = reshape->operand(0);
      auto new_reshape = MakeReshape(inner_reshape->operand(0), reshape->shape());
      return ReplaceInstruction(reshape, new_reshape);
    }

    // Rule: reshape(x) with unchanged shape = x
    if (ShapeUtil::Equal(reshape->shape(), reshape->operand(0)->shape())) {
      return ReplaceWithOperand(reshape, 0);
    }

    return false;
  }
};
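
These rewrites are easy to observe from JAX. A quick check (optimized HLO is version- and backend-dependent, so treat the expected output as indicative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x + 0.0) * 1.0   # both ops should be simplified away

print(f.lower(jnp.ones((128,))).compile().as_text())
# Expected: the optimized entry computation contains no add/multiply;
# x + 0 = x and x * 1 = x have been folded by the algebraic simplifier.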

2.3 Operator Fusion

// tensorflow/compiler/xla/service/gpu/gpu_fusible.cc (simplified)

class GpuInstructionFusion : public HloModulePass {
 public:
  std::string_view name() const override { return "gpu-instruction-fusion"; }

 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;

    for (HloComputation* computation : module->computations()) {
      // Walk the instructions and try to fuse each one's producers into it
      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
        if (!ShouldFuse(inst)) continue;

        // Try to fuse each producer (operand)
        for (HloInstruction* operand : inst->operands()) {
          if (CanFuseProducer(inst, operand)) {
            TF_ASSIGN_OR_RETURN(
                HloInstruction* fused,
                Fuse(inst, operand));
            changed = true;
          }
        }
      }
    }

    return changed;
  }

 private:
  bool ShouldFuse(HloInstruction* inst) {
    // Only certain op kinds are fusion candidates
    return inst->opcode() == HloOpcode::kFusion ||
           IsElementwise(inst) ||
           inst->opcode() == HloOpcode::kReduce;
  }

  bool CanFuseProducer(HloInstruction* consumer, HloInstruction* producer) {
    // Check whether fusion is legal and profitable

    // 1. The producer has exactly one consumer
    if (producer->user_count() != 1) return false;

    // 2. Both are element-wise ops
    if (!IsElementwise(producer) || !IsElementwise(consumer)) {
      // Special case: an element-wise producer feeding a reduce
      if (consumer->opcode() == HloOpcode::kReduce &&
          IsElementwise(producer)) {
        return true;
      }
      return false;
    }

    // 3. Shapes are compatible
    if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) {
      return false;
    }

    return true;
  }

  StatusOr<HloInstruction*> Fuse(HloInstruction* consumer,
                                  HloInstruction* producer) {
    HloComputation* computation = consumer->parent();

    // Create a new fusion instruction or extend an existing one
    HloInstruction* fusion;
    if (consumer->opcode() == HloOpcode::kFusion) {
      // Extend the existing fusion
      fusion = consumer;
      fusion->FuseInstruction(producer);
    } else {
      // Create a new fusion
      fusion = computation->AddInstruction(
          HloInstruction::CreateFusion(
              consumer->shape(),
              HloInstruction::FusionKind::kLoop,  // loop fusion
              {producer},
              computation->AddEmbeddedComputation(...)));
      TF_RETURN_IF_ERROR(consumer->ReplaceAllUsesWith(fusion));
    }

    return fusion;
  }
};


// Fusion example in HLO
/*
Before fusion:
%p0 = f32[128,256] parameter(0)
%p1 = f32[128,256] parameter(1)
%add = f32[128,256] add(%p0, %p1)
%exp = f32[128,256] exponential(%add)
%mul = f32[128,256] multiply(%exp, %p0)

After fusion:
%fusion = f32[128,256] fusion(%p0, %p1), kind=kLoop,
  calls=
  {
    %param0 = f32[128,256] parameter(0)
    %param1 = f32[128,256] parameter(1)
    %add = f32[128,256] add(%param0, %param1)
    %exp = f32[128,256] exponential(%add)
    ROOT %mul = f32[128,256] multiply(%exp, %param0)
  }
*/
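
The same before/after picture can be reproduced from JAX. A sketch (a GPU backend is assumed for kLoop fusion; CPU output will differ):

import jax
import jax.numpy as jnp

@jax.jit
def f(p0, p1):
    return jnp.exp(p0 + p1) * p0   # matches the example above

x = jnp.ones((128, 256))
print(f.lower(x, x).compile().as_text())
# Expected: a single fusion(...) instruction, kind=kLoop, whose fused
# computation contains the add, exponential, and multiply.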

2.4 Layout Assignment

// tensorflow/compiler/xla/service/layout_assignment.cc (simplified)

// Data layout conventions:
// NCHW: batch, channel, height, width
// NHWC: batch, height, width, channel

class GpuLayoutAssignment : public HloModulePass {
 public:
  std::string_view name() const override { return "gpu-layout-assignment"; }

 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;

    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* inst : computation->instructions()) {
        // Choose the optimal layout for each op
        TF_ASSIGN_OR_RETURN(Layout optimal_layout,
                            ChooseOptimalLayout(inst));

        if (NeedsLayoutChange(inst, optimal_layout)) {
          InsertLayoutTransform(inst, optimal_layout);
          changed = true;
        }
      }
    }

    return changed;
  }

 private:
  Layout ChooseOptimalLayout(HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kConvolution:
        // GPU convolutions prefer NHWC (Tensor Core friendly)
        return Layout::NHWC();

      case HloOpcode::kDot:
        // Matrix multiply: row-major
        return Layout::RowMajor();

      default:
        // Default: keep the input's layout
        return inst->operand(0)->shape().layout();
    }
  }

  void InsertLayoutTransform(HloInstruction* inst, const Layout& target) {
    HloComputation* computation = inst->parent();

    // Insert transposes/reshapes around the op as needed
    for (int i = 0; i < inst->operand_count(); ++i) {
      HloInstruction* operand = inst->mutable_operand(i);

      if (operand->shape().layout() != target) {
        // Insert a layout conversion
        HloInstruction* transform = computation->AddInstruction(
            HloInstruction::CreateTranspose(
                ShapeUtil::MakeShapeWithLayout(...),
                operand,
                GetPermutation(operand->shape().layout(), target)));

        TF_CHECK_OK(inst->ReplaceOperandWith(i, transform));
      }
    }
  }
};
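
In HLO text the assigned layout appears as a minor-to-major suffix on the shape: f32[128,256]{1,0} means dimension 1 varies fastest (row-major), while {0,1} would be column-major. A small sketch to see the layouts XLA picked (output is backend-dependent):

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

compiled = matmul.lower(jnp.ones((128, 256)), jnp.ones((256, 512))).compile()
print(compiled.as_text())
# Shapes in the optimized HLO carry layout annotations such as
# f32[128,256]{1,0}; any transposes inserted by layout assignment
# also become visible here.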

2.5 Buffer Assignment

// tensorflow/compiler/xla/service/buffer_assignment.cc (simplified)

class BufferAssignment {
 public:
  // Assign a buffer to every HLO value
  static StatusOr<std::unique_ptr<BufferAssignment>> Run(
      const HloModule* module,
      std::unique_ptr<HloOrdering> hlo_ordering,
      BufferValue::SizeFunction buffer_size,
      LogicalBuffer::AlignmentFunction alignment) {

    auto assignment = std::make_unique<BufferAssignment>(module);

    // 1. Collect every value that needs a buffer
    std::vector<const LogicalBuffer*> buffers;
    for (const HloComputation* comp : module->computations()) {
      for (const HloInstruction* inst : comp->instructions()) {
        for (const LogicalBuffer* buffer : GetBuffers(inst)) {
          buffers.push_back(buffer);
        }
      }
    }

    // 2. Liveness analysis
    LivenessAnalysis liveness = ComputeLiveness(module, *hlo_ordering);

    // 3. Allocate buffers, reusing them wherever possible
    for (const LogicalBuffer* buffer : buffers) {
      // Look for a reusable allocation
      BufferAllocation* reusable = FindReusableAllocation(
          assignment.get(), buffer, liveness);

      if (reusable != nullptr) {
        reusable->AddAssignment(buffer);
      } else {
        // Create a new allocation
        auto* new_allocation = assignment->NewAllocation(
            buffer, buffer_size(*buffer), alignment(*buffer));
      }
    }

    return assignment;
  }

 private:
  BufferAllocation* FindReusableAllocation(
      BufferAssignment* assignment,
      const LogicalBuffer* buffer,
      const LivenessAnalysis& liveness) {

    for (BufferAllocation& alloc : assignment->allocations_) {
      // Check whether this allocation can be reused
      bool can_reuse = true;
      for (const LogicalBuffer* assigned : alloc.assigned_buffers()) {
        if (liveness.MayInterfere(*assigned, *buffer)) {
          can_reuse = false;
          break;
        }
      }

      if (can_reuse && alloc.size() >= GetSize(buffer)) {
        return &alloc;
      }
    }

    return nullptr;
  }
};
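
The core reuse rule above — two values may share an allocation only if their live ranges do not interfere — fits in a few lines of Python. A minimal sketch of the idea (illustrative only, not XLA's actual heuristics):

def assign_buffers(live_intervals):
    """live_intervals: {value: (first_def, last_use)} under a fixed ordering.
    Returns {value: buffer_id}, reusing a buffer when all of its current
    occupants are dead before the new value is defined."""
    allocations = []  # allocations[i] = list of (value, interval) in buffer i
    result = {}
    for value, (start, end) in sorted(live_intervals.items(), key=lambda kv: kv[1]):
        for buf_id, occupants in enumerate(allocations):
            if all(o_end < start or end < o_start
                   for _, (o_start, o_end) in occupants):  # no interference
                occupants.append((value, (start, end)))
                result[value] = buf_id
                break
        else:
            allocations.append([(value, (start, end))])
            result[value] = len(allocations) - 1
    return result

# %add is dead by the time %mul is defined, so they can share a buffer:
print(assign_buffers({"%p0": (0, 3), "%add": (1, 2), "%exp": (2, 3), "%mul": (3, 4)}))
# {'%p0': 0, '%add': 1, '%exp': 2, '%mul': 1}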

3. GPU Code Generation

3.1 GPU Backend Architecture

┌──────────────────────────────────────────────────────────────────────┐
│                           XLA GPU Backend                            │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  Optimized HLO Module                                                │
│         │                                                            │
│         ▼                                                            │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                           IrEmitter                            │  │
│  │    Lowers HLO instructions to LLVM IR                          │  │
│  │    ├─ Element-wise → inlined PTX                               │  │
│  │    ├─ Reduction → multi-stage reduction kernels                │  │
│  │    ├─ Dot → cuBLAS calls or custom kernels                     │  │
│  │    └─ Fusion → fused kernels                                   │  │
│  └────────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                            LLVM IR                             │  │
│  │    LLVM optimization passes                                    │  │
│  └────────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                         NVPTX Backend                          │  │
│  │    LLVM IR → PTX                                               │  │
│  └────────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                            PTXAS                               │  │
│  │    PTX → SASS (GPU machine code)                               │  │
│  └────────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                            Thunks                              │  │
│  │    Runtime execution units                                     │  │
│  │    ├─ KernelThunk: GPU kernel launches                         │  │
│  │    ├─ GemmThunk: cuBLAS calls                                  │  │
│  │    └─ CopyThunk: memory copies                                 │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘

3.2 Code Generation with IrEmitter

// tensorflow/compiler/xla/service/gpu/ir_emitter.cc (simplified)

class IrEmitter {
 public:
  // Emit LLVM IR for an HLO computation
  Status EmitComputation(const HloComputation* computation) {
    for (const HloInstruction* inst : computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(HandleInstruction(inst));
    }
    return OkStatus();
  }

 private:
  Status HandleInstruction(const HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
      case HloOpcode::kMultiply:
      case HloOpcode::kSubtract:
        return HandleElementwiseBinary(inst);

      case HloOpcode::kExp:
      case HloOpcode::kLog:
      case HloOpcode::kSin:
        return HandleElementwiseUnary(inst);

      case HloOpcode::kDot:
        return HandleDot(inst);

      case HloOpcode::kReduce:
        return HandleReduce(inst);

      case HloOpcode::kFusion:
        return HandleFusion(inst);

      // ... more ops
      default:
        return Unimplemented("Opcode: %s", HloOpcodeString(inst->opcode()));
    }
  }

  Status HandleElementwiseBinary(const HloInstruction* inst) {
    // Emit the kernel for an element-wise binary op

    llvm::Function* kernel = CreateKernelFunction(inst);
    llvm::IRBuilder<> builder(&kernel->getEntryBlock());

    // Compute the thread index
    llvm::Value* tid = EmitThreadId(&builder);
    llvm::Value* bid = EmitBlockId(&builder);
    llvm::Value* idx = builder.CreateAdd(
        builder.CreateMul(bid, GetBlockSize()),
        tid);

    // Bounds check
    llvm::Value* in_bounds = builder.CreateICmpULT(idx, GetNumElements(inst));
    llvm::BasicBlock* compute_block = CreateBlock("compute");
    llvm::BasicBlock* exit_block = CreateBlock("exit");
    builder.CreateCondBr(in_bounds, compute_block, exit_block);

    // Compute block
    builder.SetInsertPoint(compute_block);

    // Load the operands
    llvm::Value* lhs = EmitLoad(&builder, inst->operand(0), idx);
    llvm::Value* rhs = EmitLoad(&builder, inst->operand(1), idx);

    // Perform the operation
    llvm::Value* result;
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
        result = builder.CreateFAdd(lhs, rhs);
        break;
      case HloOpcode::kMultiply:
        result = builder.CreateFMul(lhs, rhs);
        break;
      // ...
    }

    // Store the result
    EmitStore(&builder, inst, result, idx);
    builder.CreateBr(exit_block);

    builder.SetInsertPoint(exit_block);
    builder.CreateRetVoid();

    return OkStatus();
  }

  Status HandleFusion(const HloInstruction* fusion) {
    // Emit a fused kernel

    const HloComputation* fused_computation = fusion->fused_instructions_computation();
    llvm::Function* kernel = CreateKernelFunction(fusion);
    llvm::IRBuilder<> builder(&kernel->getEntryBlock());

    // Compute the linear index
    llvm::Value* idx = EmitLinearIndex(&builder);

    // Emit the fused instructions in topological order
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> value_map;

    for (const HloInstruction* inst : fused_computation->MakeInstructionPostOrder()) {
      if (inst->opcode() == HloOpcode::kParameter) {
        // Parameter: load from global memory
        value_map[inst] = EmitLoad(&builder, fusion->operand(inst->parameter_number()), idx);
      } else {
        // Computation: consume previously emitted values
        std::vector<llvm::Value*> operand_values;
        for (const HloInstruction* operand : inst->operands()) {
          operand_values.push_back(value_map[operand]);
        }

        llvm::Value* result = EmitPrimitiveOp(&builder, inst, operand_values);
        value_map[inst] = result;
      }
    }

    // Store the result of the root instruction
    HloInstruction* root = fused_computation->root_instruction();
    EmitStore(&builder, fusion, value_map[root], idx);

    return OkStatus();
  }
};
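
To make the emitted control flow concrete, here is what the element-wise kernel above computes, transcribed into plain Python (one GPU thread per element becomes one loop iteration; this is an illustration, not the generated code):

import numpy as np

def elementwise_add_kernel(lhs, rhs, out, block_size=256):
    n = out.size
    num_blocks = (n + block_size - 1) // block_size
    for bid in range(num_blocks):            # blockIdx.x
        for tid in range(block_size):        # threadIdx.x
            idx = bid * block_size + tid     # idx = bid * blockDim + tid
            if idx < n:                      # bounds check (CreateICmpULT)
                out.flat[idx] = lhs.flat[idx] + rhs.flat[idx]  # load, add, store

a = np.ones((128, 256), dtype=np.float32)
b = np.full((128, 256), 2.0, dtype=np.float32)
c = np.empty_like(a)
elementwise_add_kernel(a, b, c)
assert np.allclose(c, 3.0)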

3.3 Thunk Execution

// tensorflow/compiler/xla/service/gpu/thunk.h (simplified)

class Thunk {
 public:
  enum class Kind {
    kKernel,
    kGemm,
    kCopy,
    kConvolution,
    kAllReduce,
    // ...
  };

  virtual ~Thunk() = default;

  // Execute the thunk
  virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;

  Kind kind() const { return kind_; }

 protected:
  explicit Thunk(Kind kind) : kind_(kind) {}

 private:
  Kind kind_;
};


class KernelThunk : public Thunk {
 public:
  KernelThunk(std::vector<BufferAllocation::Slice> args,
              std::string kernel_name,
              LaunchDimensions launch_dims)
      : Thunk(Kind::kKernel),
        args_(std::move(args)),
        kernel_name_(std::move(kernel_name)),
        launch_dims_(launch_dims) {}

  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Look up the kernel function
    se::KernelBase* kernel = GetKernel(kernel_name_, params.stream->parent());

    // Prepare the arguments
    se::KernelArgsArray</* max args */> kernel_args;
    for (const auto& slice : args_) {
      void* ptr = params.buffer_allocations->GetDeviceAddress(slice).opaque();
      kernel_args.add_argument(ptr);
    }

    // Launch the kernel
    return params.stream->ThenLaunch(
        launch_dims_.thread_counts_per_block(),
        launch_dims_.block_counts(),
        *kernel,
        kernel_args);
  }

 private:
  std::vector<BufferAllocation::Slice> args_;
  std::string kernel_name_;
  LaunchDimensions launch_dims_;
};


class GemmThunk : public Thunk {
 public:
  // GEMM: C = alpha * A @ B + beta * C

  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Fetch the buffer pointers
    void* lhs_ptr = params.buffer_allocations->GetDeviceAddress(lhs_buffer_).opaque();
    void* rhs_ptr = params.buffer_allocations->GetDeviceAddress(rhs_buffer_).opaque();
    void* output_ptr = params.buffer_allocations->GetDeviceAddress(output_buffer_).opaque();

    // Call cuBLAS
    se::blas::BlasSupport* blas = params.stream->parent()->AsBlas();
    return blas->DoBlasGemm(
        params.stream,
        se::blas::Transpose::kNoTranspose,
        se::blas::Transpose::kNoTranspose,
        m_, n_, k_,
        alpha_,
        lhs_ptr, lhs_type_, lda_,
        rhs_ptr, rhs_type_, ldb_,
        beta_,
        output_ptr, output_type_, ldc_);
  }

 private:
  BufferAllocation::Slice lhs_buffer_, rhs_buffer_, output_buffer_;
  int64_t m_, n_, k_;
  float alpha_, beta_;
  // ...
};
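
The runtime side is conceptually simple: the compiled executable is a flat sequence of thunks replayed in order against a buffer table. A Python sketch of that idea (names are illustrative; real thunks enqueue asynchronous work on a stream):

from dataclasses import dataclass
from typing import Callable

@dataclass
class Thunk:
    kind: str                      # e.g. "kKernel", "kGemm", "kCopy"
    run: Callable[[dict], None]    # receives the buffer table

def execute_on_stream(thunks, buffers):
    for thunk in thunks:           # sequential replay; async thunks would
        thunk.run(buffers)         # additionally synchronize via events

buffers = {"a": 1.0, "b": 2.0}
thunks = [
    Thunk("kKernel", lambda bufs: bufs.update(tmp=bufs["a"] + bufs["b"])),
    Thunk("kCopy",   lambda bufs: bufs.update(out=bufs["tmp"])),
]
execute_on_stream(thunks, buffers)
print(buffers["out"])  # 3.0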

4. JAX and XLA Integration

4.1 The JAX Compilation Flow

import jax
import jax.numpy as jnp
from jax import lax

# From a JAX function to XLA

@jax.jit
def attention(query, key, value):
    """
    注意力计算
    JAX会将此函数编译为XLA HLO
    """
    # Q @ K^T
    scores = jnp.einsum('bhqd,bhkd->bhqk', query, key)

    # Scale
    d_k = query.shape[-1]
    scores = scores / jnp.sqrt(d_k)

    # Softmax
    weights = jax.nn.softmax(scores, axis=-1)

    # Weighted sum
    output = jnp.einsum('bhqk,bhkd->bhqd', weights, value)

    return output


# Inspect the generated Jaxpr (JAX's own IR)
print("Jaxpr:")
print(jax.make_jaxpr(attention)(
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64))
))


# Inspect the generated HLO
def get_hlo(fn, *args):
    """Return the optimized XLA HLO as text."""
    lowered = jax.jit(fn).lower(*args)
    compiled = lowered.compile()
    return compiled.as_text()

print("\nHLO:")
print(get_hlo(
    attention,
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64))
))

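The two stages shown above expose both ends of the pipeline. A compact sketch of the distinction (AOT API as in recent JAX releases):

import jax
import jax.numpy as jnp

lowered = jax.jit(lambda q, k: jnp.einsum('bhqd,bhkd->bhqk', q, k)).lower(
    jnp.zeros((2, 8, 1024, 64)), jnp.zeros((2, 8, 1024, 64)))

print(lowered.as_text())            # StableHLO handed to XLA (pre-optimization)
print(lowered.compile().as_text())  # optimized HLO after the pass pipeline
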
4.2 JAX Control-Flow Primitives

import jax
import jax.numpy as jnp
from jax import lax

# JAX control flow compiles to XLA control-flow instructions

# 1. lax.cond - conditional branch
def cond_example(pred, x):
    return lax.cond(
        pred,
        lambda x: jnp.sin(x),   # true branch
        lambda x: jnp.cos(x),   # false branch
        x
    )


# 2. lax.while_loop - while loop
def while_example(x):
    def cond_fn(state):
        i, x = state
        return i < 10

    def body_fn(state):
        i, x = state
        return (i + 1, x * 2)

    init_state = (0, x)
    final_state = lax.while_loop(cond_fn, body_fn, init_state)
    return final_state[1]


# 3. lax.fori_loop - for loop
def fori_example(x):
    def body_fn(i, x):
        return x + jnp.sin(x * i)

    return lax.fori_loop(0, 10, body_fn, x)


# 4. lax.scan - sequence processing
def scan_example(xs):
    def step(carry, x):
        new_carry = carry + x
        output = carry * x
        return new_carry, output

    final_carry, outputs = lax.scan(step, jnp.array(0.0), xs)
    return outputs


# All of these compile to XLA While/Conditional instructions
for fn in [cond_example, while_example, fori_example, scan_example]:
    print(f"\n{fn.__name__}:")
    if fn == cond_example:
        print(jax.make_jaxpr(fn)(True, jnp.array(1.0)))
    elif fn == scan_example:
        print(jax.make_jaxpr(fn)(jnp.arange(10.0)))
    else:
        print(jax.make_jaxpr(fn)(jnp.array(1.0)))
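
It is worth verifying that these primitives really become structured control flow rather than unrolled code. A quick check (op spelling in the dump varies across JAX/XLA versions):

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return lax.fori_loop(0, 10, lambda i, x: x + jnp.sin(x * i), x)

print(jax.jit(f).lower(jnp.array(1.0)).compile().as_text())
# Expected: a while(...) instruction with separate condition and body
# computations — the HloOpcode::kWhile from Section 1.1 — rather than
# ten unrolled adds.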

4.3 Custom XLA Lowering from JAX

import jax
from jax import core
from jax.interpreters import mlir
# Internal module path for the HLO dialect bindings; it has moved across
# JAX versions, so treat this import as version-dependent.
from jax._src.lib.mlir.dialects import hlo
import jax.numpy as jnp

# A custom JAX primitive
custom_add_p = core.Primitive("custom_add")

def custom_add(x, y):
    """Custom add operation."""
    return custom_add_p.bind(x, y)

# Abstract evaluation rule
@custom_add_p.def_abstract_eval
def custom_add_abstract_eval(x, y):
    assert x.shape == y.shape
    return core.ShapedArray(x.shape, x.dtype)

# MLIR lowering rule
def custom_add_lowering(ctx, x, y):
    """Lower the custom op to an HLO add."""
    return [hlo.AddOp(x, y).result]

mlir.register_lowering(custom_add_p, custom_add_lowering)

# Use the custom op
@jax.jit
def use_custom_add(x, y):
    return custom_add(x, y)

# Quick test
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = use_custom_add(x, y)
print(result)  # [5. 7. 9.]

5. XLA Performance Tuning

5.1 XLA Configuration Options

import os
import jax

# XLA configuration options.
# Note: each assignment below OVERWRITES XLA_FLAGS; to use several options
# together, combine them into one space-separated string, and set the
# variable before JAX initializes its backends.

# 1. Dump XLA compilation artifacts for inspection
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump'

# 2. Control the optimization level
os.environ['XLA_FLAGS'] = '--xla_backend_optimization_level=3'

# 3. Enable specific optimizations
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_cudnn_frontend=true '   # use the cuDNN frontend
    '--xla_gpu_enable_triton_gemm=true '      # use Triton GEMM
    '--xla_gpu_autotune_level=4'              # autotuning level
)

# 4. Memory optimizations
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_memory_limit_slop_factor=1.2 '
    '--xla_gpu_enable_async_all_reduce=true'
)

# 5. Debug options
os.environ['XLA_FLAGS'] = (
    '--xla_dump_hlo_as_text=true '
    '--xla_dump_hlo_as_proto=true '
    '--xla_dump_hlo_pass_re=.*'  # dump HLO after every pass
)


# JAX configuration
jax.config.update("jax_enable_x64", True)  # enable 64-bit precision
jax.config.update("jax_debug_nans", True)  # detect NaNs

5.2 Profiling and Analysis

import jax
import jax.numpy as jnp
from jax.profiler import trace, StepTraceAnnotation

# 1. Using the JAX profiler
def profile_example():
    @jax.jit
    def fn(x):
        for _ in range(10):
            x = jnp.sin(x) + jnp.cos(x)
        return x

    x = jnp.ones((1000, 1000))

    # Profile
    with trace("/tmp/jax_trace"):
        for i in range(100):
            with StepTraceAnnotation("step", step_num=i):
                x = fn(x)
                x.block_until_ready()


# 2. Measuring compilation time
def analyze_compilation():
    @jax.jit
    def fn(x):
        return jnp.sin(x) + jnp.cos(x)

    x = jnp.ones((1000, 1000))

    # The first call triggers compilation
    import time
    start = time.time()
    _ = fn(x).block_until_ready()
    compile_time = time.time() - start
    print(f"First call (with compile): {compile_time:.3f}s")

    # Subsequent calls hit the compilation cache
    start = time.time()
    for _ in range(100):
        _ = fn(x).block_until_ready()
    run_time = (time.time() - start) / 100
    print(f"Subsequent calls: {run_time*1000:.3f}ms")


# 3. Analyzing the HLO
def analyze_hlo():
    @jax.jit
    def fn(x, y):
        a = jnp.sin(x)
        b = jnp.cos(y)
        c = a + b
        d = c * x
        return d

    x = jnp.ones((1000, 1000))
    y = jnp.ones((1000, 1000))

    # Lower and compile once, then inspect the results
    lowered = jax.jit(fn).lower(x, y)
    compiled = lowered.compile()

    # Inspect the optimized HLO
    print("Optimized HLO:")
    print(compiled.as_text())

    # Inspect the cost analysis (FLOPs, bytes accessed, ...)
    print("\nCost analysis:")
    print(compiled.cost_analysis())

6. Common Interview Questions

Q1: What are the key characteristics of XLA's HLO IR?

Key points:

  1. Strictly typed: every operand carries precise shape and type information
  2. SSA form: each value is assigned exactly once
  3. Explicit control flow: While, Conditional, etc. are represented explicitly
  4. Device independent: the same IR can be compiled to different backends
  5. Rich metadata: carries optimization hints and debugging information

Q2: How does XLA implement operator fusion?

Key points:

  1. Loop fusion: element-wise ops are fused into a single kernel
  2. Input (producer) fusion: producers are fused into their consumers
  3. Output fusion: multiple outputs share common computation
  4. Fusion decisions: driven by memory access patterns and compute density

Q3: How does JAX integrate with XLA?

Key points:

  1. Tracing: JAX traces the Python function into a Jaxpr
  2. Lowering: the Jaxpr is converted to StableHLO/MHLO
  3. Compilation: XLA compiles the HLO into executable code
  4. Caching: compilation results are cached by input signature

Q4: What advantages does XLA have over eager execution?

Key points:

  1. Operator fusion: fewer kernel launches and memory round-trips
  2. Memory optimization: better buffer reuse
  3. Layout optimization: automatically selects the best data layouts
  4. Global optimization: cross-op optimizations such as CSE and DCE

Q5: How does XLA's Buffer Assignment optimize memory usage?

Key points:

  1. Liveness analysis: determines the lifetime of every value
  2. Buffer reuse: buffers whose values are dead can be reused
  3. Layout coordination: avoids unnecessary data copies
  4. Temporary buffer pools: preallocates and reuses scratch memory

7. Learning Resources

Official Documentation

  • XLA Overview
  • XLA Architecture
  • JAX Documentation

Source Code Reading

  • tensorflow/compiler/xla/ - XLA core code
  • tensorflow/compiler/xla/service/ - optimization passes
  • tensorflow/compiler/xla/service/gpu/ - the GPU backend
  • jax/_src/ - the JAX implementation

Recommended Papers

  • "XLA: TensorFlow, Compiled"
  • "JAX: Composable Transformations of Python+NumPy Programs"