03 - XLA Compiler Deep Dive
Overview
XLA (Accelerated Linear Algebra) is a deep learning compiler developed by Google that provides graph-level compilation and optimization for TensorFlow and JAX. This chapter dissects XLA's architecture, the HLO intermediate representation, the optimization passes, and the code-generation machinery, to help readers understand how a static-graph compiler works.
XLA Overall Architecture
┌─────────────────────────────────────────────────────────────────────────────┐
│                               XLA Architecture                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                            Frontend                                │    │
│  │     TensorFlow Graph │ JAX Jaxpr │ PyTorch/XLA │ StableHLO         │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                   │                                         │
│                                   ▼                                         │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                    HLO (High Level Optimizer)                      │    │
│  │     ├─ HLO IR: high-level intermediate representation              │    │
│  │     ├─ HLO Module: container for a computation graph               │    │
│  │     └─ HLO Instruction: a single computation node                  │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                   │                                         │
│                                   ▼                                         │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                      HLO Optimization Passes                       │    │
│  │     ├─ Algebraic Simplifier                                        │    │
│  │     ├─ Operator Fusion                                             │    │
│  │     ├─ Layout Assignment                                           │    │
│  │     ├─ Buffer Assignment                                           │    │
│  │     └─ many more...                                                │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                   │                                         │
│                                   ▼                                         │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                             Backend                                │    │
│  │     CPU (LLVM) │ GPU (CUDA/ROCm) │ TPU │ other accelerators        │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                   │                                         │
│                                   ▼                                         │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                          Executable Code                           │    │
│  │                   HloModule → Thunks/Executables                   │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
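The same pipeline can be observed end to end from JAX. The sketch below (assuming a recent JAX version; the exact textual output varies) prints the frontend IR (jaxpr), the StableHLO module handed to XLA, and the optimized HLO produced by the pass pipeline:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) @ x

x = jnp.ones((8, 8))
print(jax.make_jaxpr(f)(x))         # Frontend IR (jaxpr)
lowered = jax.jit(f).lower(x)
print(lowered.as_text())            # StableHLO module handed to XLA
print(lowered.compile().as_text())  # Optimized HLO after the pass pipeline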
1. The HLO Intermediate Representation
1.1 HLO Core Concepts
// tensorflow/compiler/xla/hlo/ir/hlo_instruction.h (simplified)
// HLO opcodes
enum class HloOpcode {
  // Element-wise operations
  kAdd,
  kSubtract,
  kMultiply,
  kDivide,
  kExp,
  kLog,
  kSin,
  kCos,
  kTanh,
  // Reduction operations
  kReduce,
  kReduceWindow,
  // Matrix operations
  kDot,
  kConvolution,
  // Data movement
  kBroadcast,
  kTranspose,
  kReshape,
  kSlice,
  kConcatenate,
  // Control flow
  kWhile,
  kConditional,
  kCall,
  // Communication (collectives)
  kAllReduce,
  kAllGather,
  kReduceScatter,
  // Miscellaneous
  kParameter,
  kConstant,
  kTuple,
  kGetTupleElement,
  // ...
};
class HloInstruction {
 public:
  // Opcode
  HloOpcode opcode() const { return opcode_; }
  // Shape information
  const Shape& shape() const { return shape_; }
  // Operands
  const std::vector<HloInstruction*>& operands() const { return operands_; }
  // Users (instructions that consume this instruction's result)
  const std::vector<HloInstruction*>& users() const { return users_; }
  // Name
  const std::string& name() const { return name_; }
  // Metadata
  const OpMetadata& metadata() const { return metadata_; }
 private:
  HloOpcode opcode_;
  Shape shape_;
  std::vector<HloInstruction*> operands_;
  std::vector<HloInstruction*> users_;
  std::string name_;
  OpMetadata metadata_;
};
class HloComputation {
 public:
  // Computation name
  const std::string& name() const { return name_; }
  // Instruction list
  const std::vector<std::unique_ptr<HloInstruction>>& instructions() const {
    return instructions_;
  }
  // Root instruction (the output)
  HloInstruction* root_instruction() const { return root_instruction_; }
  // Parameters
  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }
 private:
  std::string name_;
  std::vector<std::unique_ptr<HloInstruction>> instructions_;
  HloInstruction* root_instruction_;
  std::vector<HloInstruction*> param_instructions_;
};
class HloModule {
 public:
  // Entry computation
  HloComputation* entry_computation() const { return entry_computation_; }
  // All computations (including nested ones)
  const std::vector<std::unique_ptr<HloComputation>>& computations() const {
    return computations_;
  }
  // Configuration
  const HloModuleConfig& config() const { return config_; }
 private:
  HloComputation* entry_computation_;
  std::vector<std::unique_ptr<HloComputation>> computations_;
  HloModuleConfig config_;
};
1.2 HLO Text Representation
// HLO text format example
HloModule simple_example
// Entry computation
ENTRY main {
  // Parameter definitions
  %p0 = f32[128,256] parameter(0), metadata={op_name="input_a"}
  %p1 = f32[256,512] parameter(1), metadata={op_name="input_b"}
  %p2 = f32[512] parameter(2), metadata={op_name="bias"}
  // Matrix multiplication
  %dot = f32[128,512] dot(%p0, %p1),
      lhs_contracting_dims={1},
      rhs_contracting_dims={0}
  // Broadcast the bias
  %broadcast = f32[128,512] broadcast(%p2), dimensions={1}
  // Addition
  %add = f32[128,512] add(%dot, %broadcast)
  // ReLU (implemented with maximum)
  %zero = f32[] constant(0)
  %broadcast_zero = f32[128,512] broadcast(%zero), dimensions={}
  ROOT %relu = f32[128,512] maximum(%add, %broadcast_zero)
}
// A more complex example: softmax built from reduce
HloModule softmax_example
%max_computation {
  %p0 = f32[] parameter(0)
  %p1 = f32[] parameter(1)
  ROOT %max = f32[] maximum(%p0, %p1)
}
%add_computation {
  %p0 = f32[] parameter(0)
  %p1 = f32[] parameter(1)
  ROOT %add = f32[] add(%p0, %p1)
}
ENTRY softmax {
  %input = f32[16,1024] parameter(0)
  // Step 1: compute the row-wise max
  %neg_inf = f32[] constant(-inf)
  %max = f32[16] reduce(%input, %neg_inf), dimensions={1},
      to_apply=%max_computation
  // Step 2: subtract the max
  %broadcast_max = f32[16,1024] broadcast(%max), dimensions={0}
  %sub = f32[16,1024] subtract(%input, %broadcast_max)
  // Step 3: exponentiate
  %exp = f32[16,1024] exponential(%sub)
  // Step 4: sum
  %zero = f32[] constant(0)
  %sum = f32[16] reduce(%exp, %zero), dimensions={1},
      to_apply=%add_computation
  // Step 5: divide
  %broadcast_sum = f32[16,1024] broadcast(%sum), dimensions={0}
  ROOT %softmax = f32[16,1024] divide(%exp, %broadcast_sum)
}
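For cross-checking, the JAX program below should lower to essentially this module: the two scalar reduce computations correspond to %max_computation and %add_computation above (a sketch; the exact text differs across JAX/XLA versions):
import jax
import jax.numpy as jnp

@jax.jit
def softmax(x):
    m = jnp.max(x, axis=1, keepdims=True)         # reduce with a max computation
    e = jnp.exp(x - m)
    return e / jnp.sum(e, axis=1, keepdims=True)  # reduce with an add computation

print(jax.jit(softmax).lower(jnp.ones((16, 1024))).as_text())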
1.3 HLO Builder APIs
// Building HLO with the C++ API
#include "tensorflow/compiler/xla/client/xla_builder.h"
XlaOp BuildMatMulBiasRelu(XlaBuilder* builder,
                          XlaOp input,
                          XlaOp weight,
                          XlaOp bias) {
  // MatMul
  XlaOp dot = Dot(input, weight);
  // BiasAdd
  XlaOp biased = Add(dot, bias, /*broadcast_dimensions=*/{1});
  // ReLU
  XlaOp zero = ConstantR0<float>(builder, 0.0f);
  return Max(biased, zero);
}
# Python API (JAX)
import jax
import jax.numpy as jnp

@jax.jit
def matmul_bias_relu(x, weight, bias):
    """JAX compiles this function to HLO."""
    out = jnp.dot(x, weight)
    out = out + bias
    out = jnp.maximum(out, 0)
    return out

# Inspect the generated IR. Note: make_jaxpr yields the jaxpr, not HLO;
# use lower()/compile() to see the StableHLO input and the optimized HLO.
args = (jnp.zeros((128, 256)), jnp.zeros((256, 512)), jnp.zeros((512,)))
print(jax.make_jaxpr(matmul_bias_relu)(*args))  # jaxpr (frontend IR)
lowered = matmul_bias_relu.lower(*args)         # jitted functions expose .lower()
print(lowered.as_text())                        # StableHLO
print(lowered.compile().as_text())              # optimized HLO
2. HLO Optimization Passes
2.1 The Optimization Pass Framework
// tensorflow/compiler/xla/service/hlo_pass_interface.h
class HloPassInterface {
 public:
  virtual ~HloPassInterface() = default;
  // Pass name
  virtual std::string_view name() const = 0;
  // Run the pass
  virtual StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
};
// Base class for module-level passes
class HloModulePass : public HloPassInterface {
 protected:
  // Subclasses implement this method
  virtual StatusOr<bool> RunOnModule(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
};
// Pass pipeline
class HloPassPipeline : public HloPassInterface {
 public:
  void AddPass(std::unique_ptr<HloPassInterface> pass) {
    passes_.push_back(std::move(pass));
  }
  StatusOr<bool> Run(HloModule* module,
                     const absl::flat_hash_set<absl::string_view>& threads) override {
    bool changed = false;
    for (auto& pass : passes_) {
      TF_ASSIGN_OR_RETURN(bool pass_changed, pass->Run(module, threads));
      changed |= pass_changed;
    }
    return changed;
  }
 private:
  std::vector<std::unique_ptr<HloPassInterface>> passes_;
};
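To watch this pipeline operate on a real module, you can ask XLA to dump the HLO after every pass; the same flags reappear in section 5.1 (the dump directory and regex here are illustrative):
import os
# Must be set before JAX initializes the XLA backend.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*"

import jax
import jax.numpy as jnp

jax.jit(lambda x: x * 2.0 + 1.0)(jnp.ones(8))
# /tmp/xla_dump now holds one HLO snapshot per optimization pass.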
2.2 The Algebraic Simplifier
// tensorflow/compiler/xla/service/algebraic_simplifier.cc (simplified)
class AlgebraicSimplifier : public HloModulePass {
 public:
  std::string_view name() const override { return "algebraic-simplifier"; }
 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
        TF_ASSIGN_OR_RETURN(bool inst_changed, SimplifyInstruction(inst));
        changed |= inst_changed;
      }
    }
    return changed;
  }
 private:
  StatusOr<bool> SimplifyInstruction(HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
        return SimplifyAdd(inst);
      case HloOpcode::kMultiply:
        return SimplifyMultiply(inst);
      case HloOpcode::kDivide:
        return SimplifyDivide(inst);
      case HloOpcode::kReshape:
        return SimplifyReshape(inst);
      // ... many more rewrite rules
      default:
        return false;
    }
  }
  StatusOr<bool> SimplifyAdd(HloInstruction* add) {
    // Rule: x + 0 = x
    if (IsConstantZero(add->operand(1))) {
      return ReplaceWithOperand(add, 0);
    }
    if (IsConstantZero(add->operand(0))) {
      return ReplaceWithOperand(add, 1);
    }
    // Rule: x + x = 2 * x
    if (add->operand(0) == add->operand(1)) {
      auto two = MakeConstant(2.0f, add->shape());
      auto mul = MakeMultiply(two, add->operand(0));
      return ReplaceInstruction(add, mul);
    }
    return false;
  }
  StatusOr<bool> SimplifyMultiply(HloInstruction* mul) {
    // Rule: x * 0 = 0
    if (IsConstantZero(mul->operand(0)) || IsConstantZero(mul->operand(1))) {
      return ReplaceWithZero(mul);
    }
    // Rule: x * 1 = x
    if (IsConstantOne(mul->operand(1))) {
      return ReplaceWithOperand(mul, 0);
    }
    if (IsConstantOne(mul->operand(0))) {
      return ReplaceWithOperand(mul, 1);
    }
    return false;
  }
  StatusOr<bool> SimplifyReshape(HloInstruction* reshape) {
    // Rule: reshape(reshape(x)) = reshape(x)
    if (reshape->operand(0)->opcode() == HloOpcode::kReshape) {
      auto inner_reshape = reshape->operand(0);
      auto new_reshape = MakeReshape(inner_reshape->operand(0), reshape->shape());
      return ReplaceInstruction(reshape, new_reshape);
    }
    // Rule: reshape(x) is a no-op when the shape is unchanged
    if (ShapeUtil::Equal(reshape->shape(), reshape->operand(0)->shape())) {
      return ReplaceWithOperand(reshape, 0);
    }
    return false;
  }
};
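A quick way to see the simplifier at work from JAX (a sketch; the exact optimized HLO differs across versions and backends): both the add and the multiply below should disappear from the compiled module.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x + 0.0) * 1.0  # x + 0 -> x, then x * 1 -> x

x = jnp.ones((4, 4))
print(jax.jit(f).lower(x).compile().as_text())
# The optimized HLO typically reduces to a parameter pass-through/copy.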
2.3 Operator Fusion
// tensorflow/compiler/xla/service/gpu/gpu_fusible.cc (simplified)
class GpuInstructionFusion : public HloModulePass {
 public:
  std::string_view name() const override { return "gpu-instruction-fusion"; }
 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;
    for (HloComputation* computation : module->computations()) {
      // Simplified traversal; the real pass visits consumers near the output
      // first, so producers get pulled into their consumers.
      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
        if (!ShouldFuse(inst)) continue;
        // Try to fuse producers into this instruction
        for (HloInstruction* operand : inst->operands()) {
          if (CanFuseProducer(inst, operand)) {
            TF_ASSIGN_OR_RETURN(
                HloInstruction* fused,
                Fuse(inst, operand));
            (void)fused;
            changed = true;
          }
        }
      }
    }
    return changed;
  }
 private:
  bool ShouldFuse(HloInstruction* inst) {
    // Only fuse certain kinds of operations
    return inst->opcode() == HloOpcode::kFusion ||
           IsElementwise(inst) ||
           inst->opcode() == HloOpcode::kReduce;
  }
  bool CanFuseProducer(HloInstruction* consumer, HloInstruction* producer) {
    // Fusibility checks:
    // 1. The producer has exactly one consumer
    if (producer->user_count() != 1) return false;
    // 2. Both sides are element-wise operations
    if (!IsElementwise(producer) || !IsElementwise(consumer)) {
      // Special case: an element-wise producer feeding a reduce
      if (consumer->opcode() == HloOpcode::kReduce &&
          IsElementwise(producer)) {
        return true;
      }
      return false;
    }
    // 3. Shapes are compatible
    if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) {
      return false;
    }
    return true;
  }
  StatusOr<HloInstruction*> Fuse(HloInstruction* consumer,
                                 HloInstruction* producer) {
    HloComputation* computation = consumer->parent();
    // Create or extend a fusion instruction
    HloInstruction* fusion;
    if (consumer->opcode() == HloOpcode::kFusion) {
      // Extend the existing fusion
      fusion = consumer;
      fusion->FuseInstruction(producer);
    } else {
      // Create a new fusion
      fusion = computation->AddInstruction(
          HloInstruction::CreateFusion(
              consumer->shape(),
              HloInstruction::FusionKind::kLoop,  // loop fusion
              {producer},
              computation->AddEmbeddedComputation(...)));
      TF_RETURN_IF_ERROR(consumer->ReplaceAllUsesWith(fusion));
    }
    return fusion;
  }
};
// HLO fusion example
/*
Before fusion:
  %p0 = f32[128,256] parameter(0)
  %p1 = f32[128,256] parameter(1)
  %add = f32[128,256] add(%p0, %p1)
  %exp = f32[128,256] exponential(%add)
  %mul = f32[128,256] multiply(%exp, %p0)
After fusion:
  %fusion = f32[128,256] fusion(%p0, %p1), kind=kLoop,
    calls=
    {
      %param0 = f32[128,256] parameter(0)
      %param1 = f32[128,256] parameter(1)
      %add = f32[128,256] add(%param0, %param1)
      %exp = f32[128,256] exponential(%add)
      ROOT %mul = f32[128,256] multiply(%exp, %param0)
    }
*/
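This before/after can be reproduced from JAX (a sketch; on a GPU backend the optimized HLO usually contains a single kind=kLoop fusion holding the add, exponential, and multiply):
import jax
import jax.numpy as jnp

@jax.jit
def f(p0, p1):
    return jnp.exp(p0 + p1) * p0

x = jnp.ones((128, 256))
print(jax.jit(f).lower(x, x).compile().as_text())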
2.4 Layout Assignment
// tensorflow/compiler/xla/service/layout_assignment.cc (simplified)
// Data layouts:
// NCHW: batch, channel, height, width
// NHWC: batch, height, width, channel
class GpuLayoutAssignment : public HloModulePass {
 public:
  std::string_view name() const override { return "gpu-layout-assignment"; }
 protected:
  StatusOr<bool> RunOnModule(HloModule* module,
                             const absl::flat_hash_set<absl::string_view>&) override {
    bool changed = false;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* inst : computation->instructions()) {
        // Choose the best layout for each operation
        Layout optimal_layout = ChooseOptimalLayout(inst);
        if (NeedsLayoutChange(inst, optimal_layout)) {
          InsertLayoutTransform(inst, optimal_layout);
          changed = true;
        }
      }
    }
    return changed;
  }
 private:
  Layout ChooseOptimalLayout(HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kConvolution:
        // GPU convolutions prefer NHWC (pairs well with Tensor Cores)
        return Layout::NHWC();
      case HloOpcode::kDot:
        // Matrix multiplication: row-major
        return Layout::RowMajor();
      default:
        // Default: keep the input layout
        return inst->operand(0)->shape().layout();
    }
  }
  void InsertLayoutTransform(HloInstruction* inst, const Layout& target) {
    HloComputation* computation = inst->parent();
    // Insert transposes/reshapes around the operation
    for (int i = 0; i < inst->operand_count(); ++i) {
      HloInstruction* operand = inst->mutable_operand(i);
      if (operand->shape().layout() != target) {
        // Insert a layout conversion
        HloInstruction* transform = computation->AddInstruction(
            HloInstruction::CreateTranspose(
                ShapeUtil::MakeShapeWithLayout(...),
                operand,
                GetPermutation(operand->shape().layout(), target)));
        TF_CHECK_OK(inst->ReplaceOperandWith(i, transform));
      }
    }
  }
};
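Layouts are visible in optimized HLO text as the minor-to-major list in braces after each shape, e.g. f32[128,64]{1,0} for row-major. A quick way to inspect them (a sketch; the output depends on the backend):
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.transpose(x) @ x

x = jnp.ones((64, 128))
print(jax.jit(f).lower(x).compile().as_text())
# Look for shapes like f32[128,64]{1,0}: {1,0} is row-major and
# {0,1} column-major (minor-to-major dimension order).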
2.5 Buffer Assignment
// tensorflow/compiler/xla/service/buffer_assignment.cc (simplified)
class BufferAssignment {
 public:
  // Assign a buffer to every HLO value
  static StatusOr<std::unique_ptr<BufferAssignment>> Run(
      const HloModule* module,
      std::unique_ptr<HloOrdering> hlo_ordering,
      BufferValue::SizeFunction buffer_size,
      LogicalBuffer::AlignmentFunction alignment) {
    auto assignment = std::make_unique<BufferAssignment>(module);
    // 1. Collect every value that needs a buffer
    std::vector<const LogicalBuffer*> buffers;
    for (const HloComputation* comp : module->computations()) {
      for (const HloInstruction* inst : comp->instructions()) {
        for (const LogicalBuffer* buffer : GetBuffers(inst)) {
          buffers.push_back(buffer);
        }
      }
    }
    // 2. Liveness analysis
    LivenessAnalysis liveness = ComputeLiveness(module, *hlo_ordering);
    // 3. Assign buffers, reusing allocations wherever possible
    for (const LogicalBuffer* buffer : buffers) {
      // Look for a reusable allocation
      BufferAllocation* reusable = FindReusableAllocation(
          assignment.get(), buffer, liveness);
      if (reusable != nullptr) {
        reusable->AddAssignment(buffer);
      } else {
        // Create a new allocation
        assignment->NewAllocation(
            buffer, buffer_size(*buffer), alignment(*buffer));
      }
    }
    return assignment;
  }
 private:
  static BufferAllocation* FindReusableAllocation(
      BufferAssignment* assignment,
      const LogicalBuffer* buffer,
      const LivenessAnalysis& liveness) {
    for (BufferAllocation& alloc : assignment->allocations_) {
      // Check whether this allocation can be reused
      bool can_reuse = true;
      for (const LogicalBuffer* assigned : alloc.assigned_buffers()) {
        if (liveness.MayInterfere(*assigned, *buffer)) {
          can_reuse = false;
          break;
        }
      }
      if (can_reuse && alloc.size() >= GetSize(buffer)) {
        return &alloc;
      }
    }
    return nullptr;
  }
};
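The reuse logic reduces to interval interference. Below is a minimal Python sketch of the same idea, assuming each value is summarized by a (start, end) liveness interval and a byte size; it mirrors FindReusableAllocation: reuse an allocation only if no already-assigned buffer's interval overlaps the new one.
from dataclasses import dataclass, field

@dataclass
class Buffer:
    name: str
    start: int   # first program point where the value is live
    end: int     # last program point where the value is live
    size: int    # bytes

@dataclass
class Allocation:
    size: int
    assigned: list = field(default_factory=list)

def interferes(a, b):
    # Two values interfere if their live ranges overlap
    return a.start <= b.end and b.start <= a.end

def assign(buffers):
    allocations = []
    for buf in sorted(buffers, key=lambda b: b.start):
        reuse = next((a for a in allocations
                      if a.size >= buf.size
                      and not any(interferes(x, buf) for x in a.assigned)),
                     None)
        if reuse is None:
            reuse = Allocation(size=buf.size)
            allocations.append(reuse)
        reuse.assigned.append(buf)
    return allocations

allocs = assign([Buffer("dot", 0, 2, 1024),
                 Buffer("add", 2, 3, 1024),
                 Buffer("relu", 3, 4, 1024)])
print(len(allocs))  # 2: "relu" reuses "dot"'s allocation (disjoint liveness)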
3. GPU Code Generation
3.1 GPU Backend Architecture
┌─────────────────────────────────────────────────────────────────────────────┐
│                               XLA GPU Backend                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│   Optimized HLO Module                                                      │
│          │                                                                  │
│          ▼                                                                  │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                            IrEmitter                               │    │
│  │   Translates HLO instructions into LLVM IR                         │    │
│  │   ├─ Element-wise → inline LLVM IR/PTX                             │    │
│  │   ├─ Reduction    → multi-stage reduction kernels                  │    │
│  │   ├─ Dot          → cuBLAS calls or custom kernels                 │    │
│  │   └─ Fusion       → fused kernels                                  │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│          │                                                                  │
│          ▼                                                                  │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                             LLVM IR                                │    │
│  │                     LLVM optimization passes                       │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│          │                                                                  │
│          ▼                                                                  │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                          NVPTX Backend                             │    │
│  │                          LLVM IR → PTX                             │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│          │                                                                  │
│          ▼                                                                  │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                              PTXAS                                 │    │
│  │                    PTX → SASS (GPU machine code)                   │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│          │                                                                  │
│          ▼                                                                  │
│  ┌────────────────────────────────────────────────────────────────────┐    │
│  │                              Thunk                                 │    │
│  │                      Runtime execution units                       │    │
│  │   ├─ KernelThunk: launches a GPU kernel                            │    │
│  │   ├─ GemmThunk:   calls cuBLAS                                     │    │
│  │   └─ CopyThunk:   memory copies                                    │    │
│  └────────────────────────────────────────────────────────────────────┘    │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
3.2 IrEmitter Code Generation
// tensorflow/compiler/xla/service/gpu/ir_emitter.cc (simplified)
class IrEmitter {
 public:
  // Emit LLVM IR for an HLO computation
  Status EmitComputation(const HloComputation* computation) {
    for (const HloInstruction* inst : computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(HandleInstruction(inst));
    }
    return OkStatus();
  }
 private:
  Status HandleInstruction(const HloInstruction* inst) {
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
      case HloOpcode::kMultiply:
      case HloOpcode::kSubtract:
        return HandleElementwiseBinary(inst);
      case HloOpcode::kExp:
      case HloOpcode::kLog:
      case HloOpcode::kSin:
        return HandleElementwiseUnary(inst);
      case HloOpcode::kDot:
        return HandleDot(inst);
      case HloOpcode::kReduce:
        return HandleReduce(inst);
      case HloOpcode::kFusion:
        return HandleFusion(inst);
      // ... more opcodes
      default:
        return Unimplemented("Opcode: %s", HloOpcodeString(inst->opcode()));
    }
  }
  Status HandleElementwiseBinary(const HloInstruction* inst) {
    // Emit a kernel for an element-wise binary operation
    llvm::Function* kernel = CreateKernelFunction(inst);
    llvm::IRBuilder<> builder(&kernel->getEntryBlock());
    // Compute the global thread index
    llvm::Value* tid = EmitThreadId(&builder);
    llvm::Value* bid = EmitBlockId(&builder);
    llvm::Value* idx = builder.CreateAdd(
        builder.CreateMul(bid, GetBlockSize()),
        tid);
    // Bounds check
    llvm::Value* in_bounds = builder.CreateICmpULT(idx, GetNumElements(inst));
    llvm::BasicBlock* compute_block = CreateBlock("compute");
    llvm::BasicBlock* exit_block = CreateBlock("exit");
    builder.CreateCondBr(in_bounds, compute_block, exit_block);
    // Compute block
    builder.SetInsertPoint(compute_block);
    // Load the operands
    llvm::Value* lhs = EmitLoad(&builder, inst->operand(0), idx);
    llvm::Value* rhs = EmitLoad(&builder, inst->operand(1), idx);
    // Perform the operation
    llvm::Value* result = nullptr;
    switch (inst->opcode()) {
      case HloOpcode::kAdd:
        result = builder.CreateFAdd(lhs, rhs);
        break;
      case HloOpcode::kMultiply:
        result = builder.CreateFMul(lhs, rhs);
        break;
      // ...
    }
    // Store the result
    EmitStore(&builder, inst, result, idx);
    builder.CreateBr(exit_block);
    builder.SetInsertPoint(exit_block);
    builder.CreateRetVoid();
    return OkStatus();
  }
  Status HandleFusion(const HloInstruction* fusion) {
    // Emit a fused kernel
    const HloComputation* fused_computation = fusion->fused_instructions_computation();
    llvm::Function* kernel = CreateKernelFunction(fusion);
    llvm::IRBuilder<> builder(&kernel->getEntryBlock());
    // Compute the linear index
    llvm::Value* idx = EmitLinearIndex(&builder);
    // Emit the fused instructions in topological order
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> value_map;
    for (const HloInstruction* inst : fused_computation->MakeInstructionPostOrder()) {
      if (inst->opcode() == HloOpcode::kParameter) {
        // Parameters: load from global memory
        value_map[inst] = EmitLoad(&builder, fusion->operand(inst->parameter_number()), idx);
      } else {
        // Computations: consume previously emitted values
        std::vector<llvm::Value*> operand_values;
        for (const HloInstruction* operand : inst->operands()) {
          operand_values.push_back(value_map[operand]);
        }
        llvm::Value* result = EmitPrimitiveOp(&builder, inst, operand_values);
        value_map[inst] = result;
      }
    }
    // Store the root instruction's result
    HloInstruction* root = fused_computation->root_instruction();
    EmitStore(&builder, fusion, value_map[root], idx);
    return OkStatus();
  }
};
3.3 Thunk Execution
// tensorflow/compiler/xla/service/gpu/thunk.h (simplified)
class Thunk {
 public:
  enum class Kind {
    kKernel,
    kGemm,
    kCopy,
    kConvolution,
    kAllReduce,
    // ...
  };
  virtual ~Thunk() = default;
  // Execute this thunk
  virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;
  Kind kind() const { return kind_; }
 protected:
  explicit Thunk(Kind kind) : kind_(kind) {}
 private:
  Kind kind_;
};
class KernelThunk : public Thunk {
 public:
  KernelThunk(std::vector<BufferAllocation::Slice> args,
              std::string kernel_name,
              LaunchDimensions launch_dims)
      : Thunk(Kind::kKernel),
        args_(std::move(args)),
        kernel_name_(std::move(kernel_name)),
        launch_dims_(launch_dims) {}
  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Look up the kernel function
    se::KernelBase* kernel = GetKernel(kernel_name_, params.stream->parent());
    // Prepare the arguments
    se::KernelArgsArray</* max args */> kernel_args;
    for (const auto& slice : args_) {
      void* ptr = params.buffer_allocations->GetDeviceAddress(slice).opaque();
      kernel_args.add_argument(ptr);
    }
    // Launch the kernel
    return params.stream->ThenLaunch(
        launch_dims_.thread_counts_per_block(),
        launch_dims_.block_counts(),
        *kernel,
        kernel_args);
  }
 private:
  std::vector<BufferAllocation::Slice> args_;
  std::string kernel_name_;
  LaunchDimensions launch_dims_;
};
class GemmThunk : public Thunk {
 public:
  // GEMM: C = alpha * A @ B + beta * C
  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Resolve buffer pointers
    void* lhs_ptr = params.buffer_allocations->GetDeviceAddress(lhs_buffer_).opaque();
    void* rhs_ptr = params.buffer_allocations->GetDeviceAddress(rhs_buffer_).opaque();
    void* output_ptr = params.buffer_allocations->GetDeviceAddress(output_buffer_).opaque();
    // Call into cuBLAS
    se::blas::BlasSupport* blas = params.stream->parent()->AsBlas();
    return blas->DoBlasGemm(
        params.stream,
        se::blas::Transpose::kNoTranspose,
        se::blas::Transpose::kNoTranspose,
        m_, n_, k_,
        alpha_,
        lhs_ptr, lhs_type_, lda_,
        rhs_ptr, rhs_type_, ldb_,
        beta_,
        output_ptr, output_type_, ldc_);
  }
 private:
  BufferAllocation::Slice lhs_buffer_, rhs_buffer_, output_buffer_;
  int64_t m_, n_, k_;
  float alpha_, beta_;
  // ...
};
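The execution model is simpler than the C++ suggests: a compiled executable is just an ordered list of thunks replayed onto a stream, each reading and writing preassigned buffer slices. A minimal Python sketch of that pattern (all names here are hypothetical; the real runtime is the C++ above):
class Thunk:
    def execute(self, stream, buffers):
        raise NotImplementedError

class KernelThunk(Thunk):
    def __init__(self, kernel_name, arg_slices, launch_dims):
        self.kernel_name = kernel_name
        self.arg_slices = arg_slices    # indices into the buffer table
        self.launch_dims = launch_dims  # (grid, block)
    def execute(self, stream, buffers):
        ptrs = [buffers[s] for s in self.arg_slices]
        stream.launch(self.kernel_name, self.launch_dims, ptrs)

class CopyThunk(Thunk):
    def __init__(self, src_slice, dst_slice):
        self.src_slice, self.dst_slice = src_slice, dst_slice
    def execute(self, stream, buffers):
        stream.memcpy(buffers[self.dst_slice], buffers[self.src_slice])

class GpuExecutable:
    """Analogue of the HloModule -> Thunks/Executables box in the diagram."""
    def __init__(self, thunks):
        self.thunks = thunks
    def run(self, stream, buffers):
        for thunk in self.thunks:  # thunks run in schedule order
            thunk.execute(stream, buffers)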
4. JAX and XLA Integration
4.1 The JAX Compilation Flow
import jax
import jax.numpy as jnp
from jax import lax

# The compilation flow from a JAX function to XLA
@jax.jit
def attention(query, key, value):
    """
    Attention computation.
    JAX compiles this function to XLA HLO.
    """
    # Q @ K^T
    scores = jnp.einsum('bhqd,bhkd->bhqk', query, key)
    # Scale
    d_k = query.shape[-1]
    scores = scores / jnp.sqrt(d_k)
    # Softmax
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum
    output = jnp.einsum('bhqk,bhkd->bhqd', weights, value)
    return output

# Inspect the jaxpr (JAX's own IR)
print("Jaxpr:")
print(jax.make_jaxpr(attention)(
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64))
))

# Inspect the generated HLO
def get_hlo(fn, *args):
    """Return the optimized XLA HLO as text."""
    lowered = jax.jit(fn).lower(*args)
    compiled = lowered.compile()
    return compiled.as_text()

print("\nHLO:")
print(get_hlo(
    attention,
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64)),
    jnp.zeros((2, 8, 1024, 64))
))
4.2 JAX Control-Flow Primitives
import jax
import jax.numpy as jnp
from jax import lax

# JAX control flow compiles down to XLA control-flow instructions.

# 1. lax.cond - conditional branch
def cond_example(pred, x):
    return lax.cond(
        pred,
        lambda x: jnp.sin(x),  # true branch
        lambda x: jnp.cos(x),  # false branch
        x
    )

# 2. lax.while_loop - while loop
def while_example(x):
    def cond_fn(state):
        i, x = state
        return i < 10
    def body_fn(state):
        i, x = state
        return (i + 1, x * 2)
    init_state = (0, x)
    final_state = lax.while_loop(cond_fn, body_fn, init_state)
    return final_state[1]

# 3. lax.fori_loop - for loop
def fori_example(x):
    def body_fn(i, x):
        return x + jnp.sin(x * i)
    return lax.fori_loop(0, 10, body_fn, x)

# 4. lax.scan - sequence processing
def scan_example(xs):
    def step(carry, x):
        new_carry = carry + x
        output = carry * x
        return new_carry, output
    final_carry, outputs = lax.scan(step, jnp.array(0.0), xs)
    return outputs

# All of these compile to XLA While/Conditional instructions.
for fn in [cond_example, while_example, fori_example, scan_example]:
    print(f"\n{fn.__name__}:")
    if fn == cond_example:
        print(jax.make_jaxpr(fn)(True, jnp.array(1.0)))
    elif fn == scan_example:
        print(jax.make_jaxpr(fn)(jnp.arange(10.0)))
    else:
        print(jax.make_jaxpr(fn)(jnp.array(1.0)))
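To confirm the claim above, lower one of these under jit and look for the corresponding control-flow op in the StableHLO text (a sketch; the op spelling varies by version):
import jax
import jax.numpy as jnp

print(jax.jit(while_example).lower(jnp.array(1.0)).as_text())
# The module should contain a while op (e.g. stablehlo.while) whose two
# regions correspond to cond_fn and body_fn.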
4.3 Custom XLA Calls from JAX
import jax
from jax import core
from jax.interpreters import mlir
# The HLO dialect bindings live in a private module; the exact import path
# and builder API vary across JAX versions.
from jax._src.lib.mlir.dialects import hlo
import jax.numpy as jnp

# Define a custom JAX primitive
custom_add_p = core.Primitive("custom_add")

def custom_add(x, y):
    """Custom add operation."""
    return custom_add_p.bind(x, y)

# Abstract evaluation rule (shape/dtype inference)
@custom_add_p.def_abstract_eval
def custom_add_abstract_eval(x, y):
    assert x.shape == y.shape
    return core.ShapedArray(x.shape, x.dtype)

# MLIR lowering rule
def custom_add_lowering(ctx, x, y):
    """Lower the custom op to an HLO add."""
    return [hlo.AddOp(x, y).result]

mlir.register_lowering(custom_add_p, custom_add_lowering)

# Use the custom op
@jax.jit
def use_custom_add(x, y):
    return custom_add(x, y)

# Test
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = use_custom_add(x, y)
print(result)  # [5. 7. 9.]
5. XLA Performance Tuning
5.1 XLA Configuration Options
import os
import jax

# XLA configuration options. Note: XLA_FLAGS is a single space-separated
# string, so each assignment below REPLACES the previous one; combine the
# flags you need into one string, set before JAX initializes the backend.

# 1. Dump XLA artifacts for inspection
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump'
# 2. Control the optimization level
os.environ['XLA_FLAGS'] = '--xla_backend_optimization_level=3'
# 3. Enable specific optimizations
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_cudnn_frontend=true '  # use the cuDNN frontend
    '--xla_gpu_enable_triton_gemm=true '     # use Triton GEMM
    '--xla_gpu_autotune_level=4'             # autotuning level
)
# 4. Memory tuning
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_memory_limit_slop_factor=1.2 '
    '--xla_gpu_enable_async_all_reduce=true'
)
# 5. Debugging options
os.environ['XLA_FLAGS'] = (
    '--xla_dump_hlo_as_text=true '
    '--xla_dump_hlo_as_proto=true '
    '--xla_dump_hlo_pass_re=.*'              # dump HLO after every pass
)

# JAX configuration
jax.config.update("jax_enable_x64", True)  # enable 64-bit precision
jax.config.update("jax_debug_nans", True)  # detect NaNs
5.2 Performance Analysis
import jax
import jax.numpy as jnp
from jax.profiler import trace, StepTraceAnnotation

# 1. Using the JAX profiler
def profile_example():
    @jax.jit
    def fn(x):
        for _ in range(10):
            x = jnp.sin(x) + jnp.cos(x)
        return x
    x = jnp.ones((1000, 1000))
    # Profile
    with trace("/tmp/jax_trace"):
        for i in range(100):
            with StepTraceAnnotation("step", step_num=i):
                x = fn(x)
        x.block_until_ready()

# 2. Measuring compilation time
def analyze_compilation():
    @jax.jit
    def fn(x):
        return jnp.sin(x) + jnp.cos(x)
    x = jnp.ones((1000, 1000))
    # The first call triggers compilation
    import time
    start = time.time()
    _ = fn(x).block_until_ready()
    compile_time = time.time() - start
    print(f"First call (with compile): {compile_time:.3f}s")
    # Subsequent calls hit the compilation cache
    start = time.time()
    for _ in range(100):
        _ = fn(x).block_until_ready()
    run_time = (time.time() - start) / 100
    print(f"Subsequent calls: {run_time*1000:.3f}ms")

# 3. Analyzing the HLO
def analyze_hlo():
    @jax.jit
    def fn(x, y):
        a = jnp.sin(x)
        b = jnp.cos(y)
        c = a + b
        d = c * x
        return d
    x = jnp.ones((1000, 1000))
    y = jnp.ones((1000, 1000))
    # Get the lowered representation
    lowered = jax.jit(fn).lower(x, y)
    # Inspect the optimized HLO
    print("Optimized HLO:")
    print(lowered.compile().as_text())
    # Inspect the cost analysis
    print("\nCost analysis:")
    print(lowered.compile().cost_analysis())
6. Frequently Asked Interview Questions
Q1: What are the distinguishing features of XLA's HLO IR?
Key points:
- Strictly typed: every operand carries precise shape and type information
- SSA form: each value is assigned exactly once
- Explicit control flow: While, Conditional, etc. are first-class instructions
- Device independent: the same IR compiles to different backends
- Rich metadata: carries optimization hints and debugging information
Q2: How does XLA implement operator fusion?
Key points:
- Loop fusion: element-wise operations are fused into a single kernel
- Input fusion: producers are fused into their consumers
- Output fusion: multiple outputs share computation
- Fusion decisions: driven by memory-access patterns and compute intensity
Q3: How does JAX integrate with XLA?
Key points:
- Tracing: JAX traces Python functions into jaxprs
- Lowering: jaxprs are converted to StableHLO/MHLO
- Compilation: XLA compiles the HLO into executable code
- Caching: compiled executables are cached by input signature
Q4: What advantages does XLA have over eager execution?
Key points:
- Operator fusion: fewer kernel launches and memory round-trips
- Memory optimization: better buffer reuse
- Layout optimization: automatic selection of optimal data layouts
- Global optimization: cross-op optimizations such as CSE and DCE
Q5: How does XLA's Buffer Assignment optimize memory usage?
Key points:
- Liveness analysis: determines each value's lifetime
- Buffer reuse: buffers that are no longer live can be reused
- Layout coordination: avoids unnecessary data copies
- Temporary buffer pooling: preallocates and reuses scratch memory
7. Learning Resources
Official documentation
Source code to read
- tensorflow/compiler/xla/ - XLA core
- tensorflow/compiler/xla/service/ - optimization passes
- tensorflow/compiler/xla/service/gpu/ - GPU backend
- jax/_src/ - JAX implementation
Recommended papers
- "XLA: TensorFlow, Compiled"
- "JAX: Composable Transformations of Python+NumPy Programs"