HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

推理引擎原理

概述

推理引擎是将训练好的模型部署到生产环境的核心组件。对于大语言模型而言,推理效率直接决定了服务的延迟、吞吐量和成本。本章深入剖析推理引擎的工作原理,包括模型加载、计算优化、内存管理等关键技术。

推理流程解析

端到端推理流程

┌─────────────────────────────────────────────────────────────────┐
│                    LLM 推理完整流程                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  输入: "What is the capital of France?"                         │
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │ 1. Tokenization (分词)                                   │    │
│  │    Text → [1024, 318, 262, 3139, 286, 4881, 30]         │    │
│  └──────────────────────┬──────────────────────────────────┘    │
│                         │                                        │
│  ┌──────────────────────▼──────────────────────────────────┐    │
│  │ 2. Embedding (嵌入)                                      │    │
│  │    Token IDs → [batch, seq_len, hidden_dim]             │    │
│  └──────────────────────┬──────────────────────────────────┘    │
│                         │                                        │
│  ┌──────────────────────▼──────────────────────────────────┐    │
│  │ 3. Prefill Phase (预填充阶段)                            │    │
│  │    • 并行处理所有输入 token                              │    │
│  │    • 计算并缓存所有 KV Cache                             │    │
│  │    • 计算量大,但只执行一次                              │    │
│  └──────────────────────┬──────────────────────────────────┘    │
│                         │                                        │
│  ┌──────────────────────▼──────────────────────────────────┐    │
│  │ 4. Decode Phase (解码阶段)                               │    │
│  │    Loop:                                                 │    │
│  │    ├─ 输入: 上一个生成的 token                           │    │
│  │    ├─ 使用 KV Cache 进行 attention                       │    │
│  │    ├─ 计算 logits                                        │    │
│  │    ├─ Sampling (采样)                                    │    │
│  │    └─ 输出: 新 token                                     │    │
│  │    Until: EOS 或达到 max_length                          │    │
│  └──────────────────────┬──────────────────────────────────┘    │
│                         │                                        │
│  ┌──────────────────────▼──────────────────────────────────┐    │
│  │ 5. Detokenization (反分词)                               │    │
│  │    Token IDs → "Paris"                                   │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
│  输出: "Paris"                                                  │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Prefill vs Decode 特性对比

┌─────────────────────────────────────────────────────────────────┐
│              Prefill 与 Decode 阶段对比                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  特性              │   Prefill          │   Decode              │
│  ──────────────────┼────────────────────┼────────────────────── │
│  处理 Token 数     │   N (prompt长度)   │   1 (每次一个)        │
│  计算特点          │   计算密集         │   内存带宽密集        │
│  GPU 利用率        │   高               │   低                  │
│  KV Cache          │   生成全部         │   增量更新            │
│  Attention 类型    │   Full Attention   │   Incremental         │
│  批处理效率        │   高               │   低(需动态批处理)  │
│                                                                  │
│  计算量对比 (以 LLaMA-7B, seq_len=2048 为例):                   │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ Prefill: ~26 TFLOPs (一次性)                              │   │
│  │ Decode:  ~13 GFLOPs × 生成长度 (每个 token)               │   │
│  │                                                           │   │
│  │ 单 token 解码的计算量 ≈ Prefill 的 1/2000                 │   │
│  │ 但解码阶段受限于内存带宽,不是计算                        │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

推理引擎架构

核心组件

// inference_engine.go
package inference

import (
    "context"
    "sync"
)

// InferenceEngine 推理引擎
type InferenceEngine struct {
    // 模型
    model          *Model
    modelConfig    *ModelConfig

    // KV Cache 管理
    kvCacheManager *KVCacheManager

    // 调度器
    scheduler      *RequestScheduler

    // 采样器
    sampler        *Sampler

    // 执行器
    executor       Executor

    // 配置
    config         EngineConfig
}

// EngineConfig 引擎配置
type EngineConfig struct {
    // 模型配置
    ModelPath     string
    ModelFormat   string // pytorch, safetensors, onnx
    Dtype         string // float16, bfloat16, int8, int4

    // 硬件配置
    DeviceType    string // cuda, cpu
    DeviceIDs     []int
    TensorParallel int

    // 内存配置
    MaxBatchSize    int
    MaxSeqLen       int
    KVCacheSize     int64 // bytes
    BlockSize       int   // PagedAttention block size

    // 优化配置
    UseFlashAttention bool
    UseContinuousBatching bool
    UseSpeculativeDecoding bool
}

// Model 模型结构
type Model struct {
    // 模型层
    EmbedTokens   *Embedding
    Layers        []*TransformerLayer
    Norm          *RMSNorm
    LMHead        *Linear

    // 配置
    Config        *ModelConfig
}

// ModelConfig 模型配置
type ModelConfig struct {
    VocabSize     int     `json:"vocab_size"`
    HiddenSize    int     `json:"hidden_size"`
    IntermediateSize int  `json:"intermediate_size"`
    NumLayers     int     `json:"num_hidden_layers"`
    NumHeads      int     `json:"num_attention_heads"`
    NumKVHeads    int     `json:"num_key_value_heads"` // GQA
    HeadDim       int     `json:"head_dim"`
    MaxSeqLen     int     `json:"max_position_embeddings"`
    RopeTheta     float64 `json:"rope_theta"`
    RMSNormEps    float64 `json:"rms_norm_eps"`
}

// TransformerLayer Transformer 层
type TransformerLayer struct {
    InputNorm  *RMSNorm
    Attention  *Attention
    PostAttnNorm *RMSNorm
    MLP        *MLP
}

// Attention 注意力层
type Attention struct {
    QProj      *Linear
    KProj      *Linear
    VProj      *Linear
    OProj      *Linear

    NumHeads   int
    NumKVHeads int
    HeadDim    int
}

// MLP 前馈层
type MLP struct {
    GateProj   *Linear
    UpProj     *Linear
    DownProj   *Linear
}

// NewInferenceEngine 创建推理引擎
func NewInferenceEngine(config EngineConfig) (*InferenceEngine, error) {
    engine := &InferenceEngine{
        config: config,
    }

    // 加载模型
    model, modelConfig, err := engine.loadModel(config.ModelPath, config.ModelFormat)
    if err != nil {
        return nil, err
    }
    engine.model = model
    engine.modelConfig = modelConfig

    // 初始化 KV Cache 管理器
    engine.kvCacheManager = NewKVCacheManager(
        modelConfig,
        config.KVCacheSize,
        config.BlockSize,
        config.DeviceIDs,
    )

    // 初始化调度器
    engine.scheduler = NewRequestScheduler(
        config.MaxBatchSize,
        config.UseContinuousBatching,
    )

    // 初始化采样器
    engine.sampler = NewSampler()

    // 初始化执行器
    engine.executor = NewCUDAExecutor(config.DeviceIDs, config.TensorParallel)

    return engine, nil
}

// Generate 生成文本
func (e *InferenceEngine) Generate(ctx context.Context, request *GenerateRequest) (*GenerateResponse, error) {
    // 创建序列
    seq := &Sequence{
        ID:          generateSeqID(),
        InputIDs:    request.InputIDs,
        MaxLen:      request.MaxNewTokens + len(request.InputIDs),
        SamplingParams: request.SamplingParams,
    }

    // 添加到调度队列
    e.scheduler.Add(seq)

    // 等待完成
    return e.waitForCompletion(ctx, seq)
}

// Step 执行一步推理
func (e *InferenceEngine) Step() error {
    // 获取调度批次
    batch := e.scheduler.Schedule()
    if batch == nil {
        return nil
    }

    // 区分 Prefill 和 Decode
    prefillSeqs, decodeSeqs := batch.Split()

    // 执行 Prefill
    if len(prefillSeqs) > 0 {
        if err := e.executePrefill(prefillSeqs); err != nil {
            return err
        }
    }

    // 执行 Decode
    if len(decodeSeqs) > 0 {
        if err := e.executeDecode(decodeSeqs); err != nil {
            return err
        }
    }

    // 更新序列状态
    e.updateSequences(batch)

    return nil
}

// executePrefill 执行预填充
func (e *InferenceEngine) executePrefill(seqs []*Sequence) error {
    // 准备输入
    inputIDs := make([][]int, len(seqs))
    positions := make([][]int, len(seqs))

    for i, seq := range seqs {
        inputIDs[i] = seq.InputIDs
        positions[i] = makePositions(len(seq.InputIDs))
    }

    // 分配 KV Cache
    for _, seq := range seqs {
        blocks, err := e.kvCacheManager.Allocate(seq.ID, len(seq.InputIDs))
        if err != nil {
            return err
        }
        seq.BlockTable = blocks
    }

    // 执行前向传播
    logits, kvCache, err := e.forward(inputIDs, positions, nil, true)
    if err != nil {
        return err
    }

    // 存储 KV Cache
    for i, seq := range seqs {
        e.kvCacheManager.Store(seq.ID, kvCache[i])
    }

    // 采样下一个 token
    for i, seq := range seqs {
        nextToken := e.sampler.Sample(logits[i], seq.SamplingParams)
        seq.OutputIDs = append(seq.OutputIDs, nextToken)
        seq.Stage = StageDecoding
    }

    return nil
}

// executeDecode 执行解码
func (e *InferenceEngine) executeDecode(seqs []*Sequence) error {
    // 准备输入(每个序列只有最新的一个 token)
    inputIDs := make([][]int, len(seqs))
    positions := make([][]int, len(seqs))
    blockTables := make([][]int, len(seqs))

    for i, seq := range seqs {
        lastToken := seq.OutputIDs[len(seq.OutputIDs)-1]
        inputIDs[i] = []int{lastToken}
        positions[i] = []int{len(seq.InputIDs) + len(seq.OutputIDs) - 1}
        blockTables[i] = seq.BlockTable
    }

    // 扩展 KV Cache(如果需要新的 block)
    for _, seq := range seqs {
        curLen := len(seq.InputIDs) + len(seq.OutputIDs)
        if curLen%e.config.BlockSize == 0 {
            newBlock, err := e.kvCacheManager.AllocateBlock()
            if err != nil {
                // 处理内存不足:可能需要抢占
                return err
            }
            seq.BlockTable = append(seq.BlockTable, newBlock)
        }
    }

    // 执行前向传播
    logits, newKV, err := e.forward(inputIDs, positions, blockTables, false)
    if err != nil {
        return err
    }

    // 更新 KV Cache
    for i, seq := range seqs {
        e.kvCacheManager.Append(seq.ID, newKV[i])
    }

    // 采样下一个 token
    for i, seq := range seqs {
        nextToken := e.sampler.Sample(logits[i], seq.SamplingParams)
        seq.OutputIDs = append(seq.OutputIDs, nextToken)

        // 检查是否结束
        if nextToken == e.modelConfig.EosTokenID || len(seq.OutputIDs) >= seq.MaxLen {
            seq.Stage = StageFinished
        }
    }

    return nil
}

// forward 前向传播
func (e *InferenceEngine) forward(
    inputIDs [][]int,
    positions [][]int,
    blockTables [][]int,
    isPrefill bool,
) ([][]float32, []KVCache, error) {
    // 实际实现会调用 CUDA kernel
    return nil, nil, nil
}

// Sequence 序列
type Sequence struct {
    ID             string
    InputIDs       []int
    OutputIDs      []int
    MaxLen         int
    SamplingParams *SamplingParams
    BlockTable     []int // KV Cache block indices
    Stage          SequenceStage
}

type SequenceStage int

const (
    StagePrefill SequenceStage = iota
    StageDecoding
    StageFinished
)

// GenerateRequest 生成请求
type GenerateRequest struct {
    InputIDs       []int
    MaxNewTokens   int
    SamplingParams *SamplingParams
}

// SamplingParams 采样参数
type SamplingParams struct {
    Temperature    float32
    TopP           float32
    TopK           int
    RepetitionPenalty float32
    PresencePenalty  float32
    FrequencyPenalty float32
    StopTokenIDs   []int
}

// GenerateResponse 生成响应
type GenerateResponse struct {
    OutputIDs    []int
    OutputText   string
    FinishReason string // length, stop, error
    Usage        TokenUsage
}

// TokenUsage Token 使用统计
type TokenUsage struct {
    PromptTokens     int
    CompletionTokens int
    TotalTokens      int
}

func generateSeqID() string {
    return ""
}

func makePositions(length int) []int {
    positions := make([]int, length)
    for i := range positions {
        positions[i] = i
    }
    return positions
}

KV Cache 管理

PagedAttention 原理

┌─────────────────────────────────────────────────────────────────┐
│                    PagedAttention 原理                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  传统 KV Cache 问题:                                            │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 预分配固定大小内存(max_seq_len × batch_size)          │   │
│  │ • 实际使用率低(多数序列短于 max_seq_len)                 │   │
│  │ • 内存碎片化严重                                          │   │
│  │ • 难以支持动态批处理                                      │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  PagedAttention 解决方案:                                       │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │                                                           │   │
│  │  物理内存(GPU 显存)             逻辑 KV Cache            │   │
│  │  ┌─────────────────────┐        ┌─────────────────────┐  │   │
│  │  │ Block 0 │ Block 1 │ │        │ Seq 1: [0,2,5]     │  │   │
│  │  │─────────│─────────│ │        │ Seq 2: [1,3]       │  │   │
│  │  │ Block 2 │ Block 3 │ │        │ Seq 3: [4,6,7,8]   │  │   │
│  │  │─────────│─────────│ │   ←→   │                     │  │   │
│  │  │ Block 4 │ Block 5 │ │        │ Block Table:        │  │   │
│  │  │─────────│─────────│ │        │ Seq1: 0→2→5         │  │   │
│  │  │ Block 6 │ Block 7 │ │        │ Seq2: 1→3           │  │   │
│  │  │─────────│─────────│ │        │ Seq3: 4→6→7→8       │  │   │
│  │  │ Block 8 │ Free    │ │        │                     │  │   │
│  │  └─────────────────────┘        └─────────────────────┘  │   │
│  │                                                           │   │
│  │  • 内存按固定大小 Block 分配                              │   │
│  │  • 每个 Block 存储 block_size 个 token 的 KV             │   │
│  │  • 使用 Block Table 建立逻辑→物理映射                    │   │
│  │  • 按需分配,减少内存浪费                                │   │
│  │  • 支持高效的序列复制(共享 Block)                      │   │
│  │                                                           │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

KV Cache 管理器实现

// kv_cache_manager.go
package inference

import (
    "errors"
    "sync"
)

// KVCacheManager KV Cache 管理器
type KVCacheManager struct {
    // 配置
    numLayers     int
    numKVHeads    int
    headDim       int
    blockSize     int     // 每个 block 存储的 token 数
    totalBlocks   int

    // Block 池
    freeBlocks    []int
    usedBlocks    map[int]bool

    // 序列到 Block 的映射
    seqBlockTables map[string][]int

    // 物理内存(GPU 显存)
    // shape: [num_blocks, 2, num_layers, block_size, num_kv_heads, head_dim]
    // 2 表示 K 和 V
    kvCache       []float16

    mu            sync.RWMutex
}

// KVCacheConfig KV Cache 配置
type KVCacheConfig struct {
    NumLayers     int
    NumKVHeads    int
    HeadDim       int
    BlockSize     int
    TotalMemory   int64 // 总显存限制
    Dtype         string
}

func NewKVCacheManager(modelConfig *ModelConfig, totalMemory int64, blockSize int, deviceIDs []int) *KVCacheManager {
    // 计算每个 block 的大小
    // block_size = block_tokens × num_layers × 2 × num_kv_heads × head_dim × dtype_size
    bytesPerToken := 2 * modelConfig.NumLayers * modelConfig.NumKVHeads * modelConfig.HeadDim * 2 // FP16
    bytesPerBlock := blockSize * bytesPerToken

    // 计算可分配的 block 数量
    numBlocks := int(totalMemory / int64(bytesPerBlock))

    // 初始化 free blocks
    freeBlocks := make([]int, numBlocks)
    for i := range freeBlocks {
        freeBlocks[i] = i
    }

    return &KVCacheManager{
        numLayers:      modelConfig.NumLayers,
        numKVHeads:     modelConfig.NumKVHeads,
        headDim:        modelConfig.HeadDim,
        blockSize:      blockSize,
        totalBlocks:    numBlocks,
        freeBlocks:     freeBlocks,
        usedBlocks:     make(map[int]bool),
        seqBlockTables: make(map[string][]int),
    }
}

// Allocate 为序列分配 KV Cache blocks
func (m *KVCacheManager) Allocate(seqID string, numTokens int) ([]int, error) {
    m.mu.Lock()
    defer m.mu.Unlock()

    // 计算需要的 block 数
    numBlocks := (numTokens + m.blockSize - 1) / m.blockSize

    // 检查是否有足够的 free blocks
    if len(m.freeBlocks) < numBlocks {
        return nil, errors.New("not enough free blocks")
    }

    // 分配 blocks
    blocks := make([]int, numBlocks)
    for i := 0; i < numBlocks; i++ {
        blockIdx := m.freeBlocks[len(m.freeBlocks)-1]
        m.freeBlocks = m.freeBlocks[:len(m.freeBlocks)-1]
        blocks[i] = blockIdx
        m.usedBlocks[blockIdx] = true
    }

    m.seqBlockTables[seqID] = blocks
    return blocks, nil
}

// AllocateBlock 分配单个 block
func (m *KVCacheManager) AllocateBlock() (int, error) {
    m.mu.Lock()
    defer m.mu.Unlock()

    if len(m.freeBlocks) == 0 {
        return -1, errors.New("no free blocks available")
    }

    blockIdx := m.freeBlocks[len(m.freeBlocks)-1]
    m.freeBlocks = m.freeBlocks[:len(m.freeBlocks)-1]
    m.usedBlocks[blockIdx] = true

    return blockIdx, nil
}

// Free 释放序列的所有 blocks
func (m *KVCacheManager) Free(seqID string) {
    m.mu.Lock()
    defer m.mu.Unlock()

    blocks, ok := m.seqBlockTables[seqID]
    if !ok {
        return
    }

    for _, blockIdx := range blocks {
        delete(m.usedBlocks, blockIdx)
        m.freeBlocks = append(m.freeBlocks, blockIdx)
    }

    delete(m.seqBlockTables, seqID)
}

// Store 存储 KV Cache
func (m *KVCacheManager) Store(seqID string, kvCache KVCache) {
    // 实际实现会写入 GPU 显存
}

// Append 追加 KV Cache(decode 阶段)
func (m *KVCacheManager) Append(seqID string, newKV KVCache) {
    // 追加新的 KV 到最后一个 block
}

// GetBlockTable 获取 block table
func (m *KVCacheManager) GetBlockTable(seqID string) []int {
    m.mu.RLock()
    defer m.mu.RUnlock()
    return m.seqBlockTables[seqID]
}

// Fork 复制序列的 KV Cache(用于 beam search 或投机解码)
func (m *KVCacheManager) Fork(srcSeqID, dstSeqID string) error {
    m.mu.Lock()
    defer m.mu.Unlock()

    srcBlocks, ok := m.seqBlockTables[srcSeqID]
    if !ok {
        return errors.New("source sequence not found")
    }

    // 使用 copy-on-write:共享现有 blocks
    dstBlocks := make([]int, len(srcBlocks))
    copy(dstBlocks, srcBlocks)

    // 增加 block 引用计数(实际实现需要)
    m.seqBlockTables[dstSeqID] = dstBlocks

    return nil
}

// GetStats 获取统计信息
func (m *KVCacheManager) GetStats() KVCacheStats {
    m.mu.RLock()
    defer m.mu.RUnlock()

    return KVCacheStats{
        TotalBlocks: m.totalBlocks,
        UsedBlocks:  len(m.usedBlocks),
        FreeBlocks:  len(m.freeBlocks),
        Utilization: float64(len(m.usedBlocks)) / float64(m.totalBlocks),
    }
}

type KVCacheStats struct {
    TotalBlocks int
    UsedBlocks  int
    FreeBlocks  int
    Utilization float64
}

// KVCache KV Cache 数据
type KVCache struct {
    K []float16 // [num_layers, seq_len, num_kv_heads, head_dim]
    V []float16
}

type float16 = uint16 // 简化表示

计算优化技术

Flash Attention

┌─────────────────────────────────────────────────────────────────┐
│                    Flash Attention 原理                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  标准 Attention 问题:                                           │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ 1. 计算 QK^T:O(N²d) 计算,O(N²) 内存                     │   │
│  │ 2. Softmax:需要完整的 N×N 矩阵                           │   │
│  │ 3. 计算 Score×V:O(N²d) 计算                              │   │
│  │                                                           │   │
│  │ 内存瓶颈:N² 随序列长度平方增长                           │   │
│  │ 例:N=8192, FP16 → 128MB 仅存储 attention scores          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  Flash Attention 解决方案:                                      │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │                                                           │   │
│  │  核心思想:Tiling(分块计算)+ 重计算                     │   │
│  │                                                           │   │
│  │  ┌─────────────────────────────────────────────────┐     │   │
│  │  │     Q              K^T             V            │     │   │
│  │  │  ┌──┬──┐       ┌──┬──┬──┐     ┌──┬──┬──┐       │     │   │
│  │  │  │Q1│  │   ×   │K1│K2│K3│  ×  │V1│V2│V3│       │     │   │
│  │  │  ├──┼──┤       └──┴──┴──┘     └──┴──┴──┘       │     │   │
│  │  │  │Q2│  │                                        │     │   │
│  │  │  └──┴──┘                                        │     │   │
│  │  │                                                 │     │   │
│  │  │  分块计算流程:                                 │     │   │
│  │  │  for each Q block:                              │     │   │
│  │  │    for each K,V block:                          │     │   │
│  │  │      1. 加载 Q_block, K_block, V_block 到 SRAM  │     │   │
│  │  │      2. 计算局部 attention                      │     │   │
│  │  │      3. 使用 online softmax 累加结果            │     │   │
│  │  │    输出 block 写回 HBM                          │     │   │
│  │  └─────────────────────────────────────────────────┘     │   │
│  │                                                           │   │
│  │  Online Softmax 技巧:                                   │   │
│  │  • 维护 running max 和 running sum                       │   │
│  │  • 无需存储完整的 attention matrix                       │   │
│  │  • 内存复杂度从 O(N²) 降到 O(N)                          │   │
│  │                                                           │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  性能对比(A100, seq_len=2048):                               │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ 标准 Attention:  ~8ms,  内存 ~128MB                       │   │
│  │ Flash Attention: ~1.5ms, 内存 ~4MB                        │   │
│  │ 加速比: 5.3x,   内存节省: 32x                             │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

算子融合

// fused_ops.go
package inference

// FusedRMSNorm 融合的 RMS Normalization
// 将多个小 kernel 合并成一个大 kernel,减少内存访问和 kernel launch 开销
type FusedRMSNorm struct {
    weight   []float16
    eps      float32
    hiddenSize int
}

// Forward 前向传播
// 融合操作:
// 1. 计算均方根
// 2. 归一化
// 3. 缩放
func (n *FusedRMSNorm) Forward(input []float16) []float16 {
    // CUDA kernel 伪代码:
    // __global__ void fused_rmsnorm_kernel(
    //     float16* input,
    //     float16* weight,
    //     float16* output,
    //     float eps,
    //     int hidden_size
    // ) {
    //     // 每个 warp 处理一行
    //     int row = blockIdx.x;
    //     float sum_sq = 0.0f;
    //
    //     // 计算平方和(使用 warp reduce)
    //     for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    //         float val = __half2float(input[row * hidden_size + i]);
    //         sum_sq += val * val;
    //     }
    //     sum_sq = warp_reduce_sum(sum_sq);
    //
    //     // 计算 RMS
    //     float rms = rsqrtf(sum_sq / hidden_size + eps);
    //
    //     // 归一化并缩放
    //     for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    //         float val = __half2float(input[row * hidden_size + i]);
    //         float w = __half2float(weight[i]);
    //         output[row * hidden_size + i] = __float2half(val * rms * w);
    //     }
    // }

    return nil
}

// FusedSiLUMul 融合的 SiLU 激活和乘法
// 用于 LLaMA MLP 中的 gate 和 up projection
// gate_out = SiLU(x @ gate_weight) * (x @ up_weight)
type FusedSiLUMul struct{}

func (f *FusedSiLUMul) Forward(gate, up []float16) []float16 {
    // CUDA kernel 伪代码:
    // __global__ void fused_silu_mul_kernel(
    //     float16* gate,
    //     float16* up,
    //     float16* output,
    //     int size
    // ) {
    //     int idx = blockIdx.x * blockDim.x + threadIdx.x;
    //     if (idx < size) {
    //         float g = __half2float(gate[idx]);
    //         float u = __half2float(up[idx]);
    //         // SiLU(x) = x * sigmoid(x)
    //         float silu = g / (1.0f + expf(-g));
    //         output[idx] = __float2half(silu * u);
    //     }
    // }

    return nil
}

// FusedRotaryEmbedding 融合的 RoPE 位置编码
type FusedRotaryEmbedding struct {
    dim        int
    maxSeqLen  int
    theta      float64
    freqsCis   []complex64 // 预计算的频率
}

func NewFusedRotaryEmbedding(dim, maxSeqLen int, theta float64) *FusedRotaryEmbedding {
    rope := &FusedRotaryEmbedding{
        dim:       dim,
        maxSeqLen: maxSeqLen,
        theta:     theta,
    }
    rope.precomputeFreqs()
    return rope
}

func (r *FusedRotaryEmbedding) precomputeFreqs() {
    // 预计算频率
    // freq_i = 1 / (theta^(2i/dim))
    // freqs_cis[pos, i] = exp(i * pos * freq_i) = cos(pos*freq_i) + i*sin(pos*freq_i)
}

func (r *FusedRotaryEmbedding) Apply(q, k []float16, positions []int) ([]float16, []float16) {
    // 应用 RoPE
    // 融合到 attention kernel 中
    return nil, nil
}

// FusedAttention 融合的 Attention(包含 Flash Attention)
type FusedAttention struct {
    numHeads   int
    numKVHeads int
    headDim    int
    scale      float32
}

func (a *FusedAttention) Forward(
    q, k, v []float16,
    kvCache *KVCache,
    blockTable []int,
    positions []int,
    isPrefill bool,
) []float16 {
    if isPrefill {
        // Prefill:使用 Flash Attention
        return a.flashAttention(q, k, v)
    } else {
        // Decode:使用 Paged Attention
        return a.pagedAttention(q, kvCache, blockTable, positions)
    }
}

func (a *FusedAttention) flashAttention(q, k, v []float16) []float16 {
    // Flash Attention 实现
    // 分块计算,使用 online softmax
    return nil
}

func (a *FusedAttention) pagedAttention(q []float16, kvCache *KVCache, blockTable []int, positions []int) []float16 {
    // Paged Attention 实现
    // 从不连续的 block 中读取 KV Cache
    return nil
}

量化推理

// quantization.go
package inference

import (
    "math"
)

// Quantizer 量化器
type Quantizer struct {
    qtype QuantType
}

type QuantType string

const (
    QuantTypeINT8  QuantType = "int8"
    QuantTypeINT4  QuantType = "int4"
    QuantTypeNF4   QuantType = "nf4"   // Normal Float 4-bit
    QuantTypeGPTQ  QuantType = "gptq"
    QuantTypeAWQ   QuantType = "awq"
)

// INT8QuantizedTensor INT8 量化张量
type INT8QuantizedTensor struct {
    Data    []int8
    Scale   []float32 // per-channel 或 per-tensor
    ZeroPoint []int8
}

// Quantize INT8 量化
func (q *Quantizer) QuantizeINT8(tensor []float32, perChannel bool) *INT8QuantizedTensor {
    if perChannel {
        return q.quantizeINT8PerChannel(tensor)
    }
    return q.quantizeINT8PerTensor(tensor)
}

func (q *Quantizer) quantizeINT8PerTensor(tensor []float32) *INT8QuantizedTensor {
    // 找到 min/max
    minVal, maxVal := tensor[0], tensor[0]
    for _, v := range tensor[1:] {
        if v < minVal {
            minVal = v
        }
        if v > maxVal {
            maxVal = v
        }
    }

    // 计算 scale 和 zero point
    // 对称量化:scale = max(|min|, |max|) / 127
    absMax := math.Max(math.Abs(float64(minVal)), math.Abs(float64(maxVal)))
    scale := float32(absMax / 127.0)

    // 量化
    data := make([]int8, len(tensor))
    for i, v := range tensor {
        quantized := int(math.Round(float64(v / scale)))
        if quantized > 127 {
            quantized = 127
        } else if quantized < -128 {
            quantized = -128
        }
        data[i] = int8(quantized)
    }

    return &INT8QuantizedTensor{
        Data:  data,
        Scale: []float32{scale},
    }
}

func (q *Quantizer) quantizeINT8PerChannel(tensor []float32) *INT8QuantizedTensor {
    // 按通道量化,每个输出通道有独立的 scale
    return nil
}

// INT4QuantizedTensor INT4 量化张量
type INT4QuantizedTensor struct {
    Data      []uint8   // 2 个 INT4 打包成 1 个 uint8
    Scale     []float16
    ZeroPoint []int8
    GroupSize int       // 量化组大小
}

// QuantizeINT4 INT4 量化(Group-wise)
func (q *Quantizer) QuantizeINT4(tensor []float32, groupSize int) *INT4QuantizedTensor {
    numGroups := (len(tensor) + groupSize - 1) / groupSize

    scales := make([]float16, numGroups)
    data := make([]uint8, (len(tensor)+1)/2)

    for g := 0; g < numGroups; g++ {
        start := g * groupSize
        end := start + groupSize
        if end > len(tensor) {
            end = len(tensor)
        }

        // 找到这个 group 的 max abs
        maxAbs := float32(0)
        for i := start; i < end; i++ {
            abs := float32(math.Abs(float64(tensor[i])))
            if abs > maxAbs {
                maxAbs = abs
            }
        }

        // 计算 scale
        scale := maxAbs / 7.0 // INT4: -8 ~ 7
        scales[g] = float32ToFloat16(scale)

        // 量化这个 group
        for i := start; i < end; i++ {
            var quantized int
            if scale > 0 {
                quantized = int(math.Round(float64(tensor[i] / scale)))
            }
            if quantized > 7 {
                quantized = 7
            } else if quantized < -8 {
                quantized = -8
            }

            // 打包:2 个 INT4 到 1 个 uint8
            byteIdx := i / 2
            if i%2 == 0 {
                data[byteIdx] = uint8(quantized & 0x0F)
            } else {
                data[byteIdx] |= uint8((quantized & 0x0F) << 4)
            }
        }
    }

    return &INT4QuantizedTensor{
        Data:      data,
        Scale:     scales,
        GroupSize: groupSize,
    }
}

// GPTQ 量化
type GPTQQuantizer struct {
    bits      int
    groupSize int
    actOrder  bool // activation reordering
}

func NewGPTQQuantizer(bits, groupSize int, actOrder bool) *GPTQQuantizer {
    return &GPTQQuantizer{
        bits:      bits,
        groupSize: groupSize,
        actOrder:  actOrder,
    }
}

// QuantizeLayer 量化单层权重
func (g *GPTQQuantizer) QuantizeLayer(
    weight [][]float32,    // [out_features, in_features]
    hessian [][]float32,   // 使用校准数据计算的 Hessian 对角线
) (*GPTQQuantizedWeight, error) {
    // GPTQ 量化算法:
    // 1. 计算 Hessian 矩阵 H = 2 * X^T * X
    // 2. 按 Hessian 对角线排序列(如果 actOrder=true)
    // 3. 逐列量化,使用 OBS (Optimal Brain Surgeon) 更新剩余权重

    // 简化实现
    return nil, nil
}

type GPTQQuantizedWeight struct {
    QWeight   []int32     // 打包的量化权重
    QZeros    []int32     // 打包的零点
    Scales    []float16   // 缩放因子
    GPerm     []int32     // 列重排序(actOrder=true 时使用)
    Bits      int
    GroupSize int
}

// AWQ 量化
type AWQQuantizer struct {
    bits      int
    groupSize int
}

func NewAWQQuantizer(bits, groupSize int) *AWQQuantizer {
    return &AWQQuantizer{
        bits:      bits,
        groupSize: groupSize,
    }
}

// QuantizeWithActivation AWQ: Activation-aware Weight Quantization
func (a *AWQQuantizer) QuantizeWithActivation(
    weight [][]float32,
    activations [][]float32, // 校准数据的激活值
) (*AWQQuantizedWeight, error) {
    // AWQ 算法:
    // 1. 分析激活值分布,找到重要的通道
    // 2. 对重要通道使用更小的量化误差
    // 3. 使用 per-channel scaling 补偿量化误差

    return nil, nil
}

type AWQQuantizedWeight struct {
    QWeight   []int32
    Scales    []float16
    Zeros     []float16
    Bits      int
    GroupSize int
}

// QuantizedLinear 量化线性层
type QuantizedLinear struct {
    weight    interface{} // INT8/INT4/GPTQ/AWQ
    qtype     QuantType
    inFeatures  int
    outFeatures int
}

// Forward 量化推理前向
func (l *QuantizedLinear) Forward(input []float16) []float16 {
    switch l.qtype {
    case QuantTypeINT8:
        return l.forwardINT8(input)
    case QuantTypeINT4:
        return l.forwardINT4(input)
    case QuantTypeGPTQ:
        return l.forwardGPTQ(input)
    case QuantTypeAWQ:
        return l.forwardAWQ(input)
    }
    return nil
}

func (l *QuantizedLinear) forwardINT8(input []float16) []float16 {
    // INT8 GEMM
    // 使用 cuBLAS INT8 或 TensorRT
    return nil
}

func (l *QuantizedLinear) forwardINT4(input []float16) []float16 {
    // INT4 需要自定义 kernel
    // 解包 INT4 → FP16 → GEMM
    return nil
}

func (l *QuantizedLinear) forwardGPTQ(input []float16) []float16 {
    // GPTQ kernel(如 exllama)
    return nil
}

func (l *QuantizedLinear) forwardAWQ(input []float16) []float16 {
    // AWQ kernel
    return nil
}

func float32ToFloat16(f float32) float16 {
    // 简化实现
    return float16(f)
}

采样策略

采样器实现

// sampler.go
package inference

import (
    "math"
    "math/rand"
    "sort"
)

// Sampler 采样器
type Sampler struct {
    rng *rand.Rand
}

func NewSampler() *Sampler {
    return &Sampler{
        rng: rand.New(rand.NewSource(42)),
    }
}

// Sample 从 logits 采样下一个 token
func (s *Sampler) Sample(logits []float32, params *SamplingParams) int {
    // 1. 应用 repetition penalty
    if params.RepetitionPenalty != 1.0 {
        logits = s.applyRepetitionPenalty(logits, params)
    }

    // 2. 应用 temperature
    if params.Temperature != 1.0 {
        logits = s.applyTemperature(logits, params.Temperature)
    }

    // 3. 转换为概率
    probs := s.softmax(logits)

    // 4. 应用 top-k
    if params.TopK > 0 {
        probs = s.applyTopK(probs, params.TopK)
    }

    // 5. 应用 top-p (nucleus sampling)
    if params.TopP < 1.0 {
        probs = s.applyTopP(probs, params.TopP)
    }

    // 6. 采样
    return s.multinomialSample(probs)
}

// applyTemperature 应用温度
func (s *Sampler) applyTemperature(logits []float32, temperature float32) []float32 {
    result := make([]float32, len(logits))
    for i, l := range logits {
        result[i] = l / temperature
    }
    return result
}

// applyRepetitionPenalty 应用重复惩罚
func (s *Sampler) applyRepetitionPenalty(logits []float32, params *SamplingParams) []float32 {
    // 对已生成的 token 应用惩罚
    // 这里简化处理,实际需要传入已生成的 token
    return logits
}

// softmax 计算 softmax
func (s *Sampler) softmax(logits []float32) []float32 {
    // 数值稳定的 softmax
    maxLogit := logits[0]
    for _, l := range logits[1:] {
        if l > maxLogit {
            maxLogit = l
        }
    }

    probs := make([]float32, len(logits))
    var sum float32
    for i, l := range logits {
        probs[i] = float32(math.Exp(float64(l - maxLogit)))
        sum += probs[i]
    }

    for i := range probs {
        probs[i] /= sum
    }

    return probs
}

// applyTopK Top-K 采样
func (s *Sampler) applyTopK(probs []float32, k int) []float32 {
    if k >= len(probs) {
        return probs
    }

    // 找到 top-k 的阈值
    type indexedProb struct {
        idx  int
        prob float32
    }

    indexed := make([]indexedProb, len(probs))
    for i, p := range probs {
        indexed[i] = indexedProb{i, p}
    }

    sort.Slice(indexed, func(i, j int) bool {
        return indexed[i].prob > indexed[j].prob
    })

    threshold := indexed[k-1].prob

    // 将低于阈值的概率置零
    result := make([]float32, len(probs))
    var sum float32
    for i, p := range probs {
        if p >= threshold {
            result[i] = p
            sum += p
        }
    }

    // 重新归一化
    for i := range result {
        result[i] /= sum
    }

    return result
}

// applyTopP Top-P (Nucleus) 采样
func (s *Sampler) applyTopP(probs []float32, p float32) []float32 {
    // 按概率排序
    type indexedProb struct {
        idx  int
        prob float32
    }

    indexed := make([]indexedProb, len(probs))
    for i, prob := range probs {
        indexed[i] = indexedProb{i, prob}
    }

    sort.Slice(indexed, func(i, j int) bool {
        return indexed[i].prob > indexed[j].prob
    })

    // 找到累积概率 >= p 的截断点
    var cumSum float32
    cutoff := 0
    for i, ip := range indexed {
        cumSum += ip.prob
        if cumSum >= p {
            cutoff = i + 1
            break
        }
    }

    // 将截断点之外的概率置零
    result := make([]float32, len(probs))
    var sum float32
    for i := 0; i < cutoff; i++ {
        result[indexed[i].idx] = indexed[i].prob
        sum += indexed[i].prob
    }

    // 重新归一化
    for i := range result {
        result[i] /= sum
    }

    return result
}

// multinomialSample 多项式采样
func (s *Sampler) multinomialSample(probs []float32) int {
    r := s.rng.Float32()
    var cumSum float32
    for i, p := range probs {
        cumSum += p
        if r < cumSum {
            return i
        }
    }
    return len(probs) - 1
}

// GreedySample 贪婪采样(选择最高概率)
func (s *Sampler) GreedySample(logits []float32) int {
    maxIdx := 0
    maxVal := logits[0]
    for i, l := range logits[1:] {
        if l > maxVal {
            maxVal = l
            maxIdx = i + 1
        }
    }
    return maxIdx
}

// BeamSearch Beam Search
type BeamSearch struct {
    beamWidth int
    sampler   *Sampler
}

func NewBeamSearch(beamWidth int) *BeamSearch {
    return &BeamSearch{
        beamWidth: beamWidth,
        sampler:   NewSampler(),
    }
}

// Beam 单个 beam
type Beam struct {
    TokenIDs []int
    Score    float32
    Finished bool
}

// Step 执行一步 beam search
func (b *BeamSearch) Step(beams []*Beam, logits [][]float32) []*Beam {
    // 收集所有候选
    type candidate struct {
        beamIdx int
        tokenID int
        score   float32
    }

    var candidates []candidate

    for beamIdx, beam := range beams {
        if beam.Finished {
            candidates = append(candidates, candidate{
                beamIdx: beamIdx,
                tokenID: -1,
                score:   beam.Score,
            })
            continue
        }

        // 对每个 beam 的 logits 进行 log_softmax
        logProbs := b.sampler.softmax(logits[beamIdx])
        for tokenID, prob := range logProbs {
            if prob > 0 {
                score := beam.Score + float32(math.Log(float64(prob)))
                candidates = append(candidates, candidate{
                    beamIdx: beamIdx,
                    tokenID: tokenID,
                    score:   score,
                })
            }
        }
    }

    // 选择 top-k 个候选
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].score > candidates[j].score
    })

    if len(candidates) > b.beamWidth {
        candidates = candidates[:b.beamWidth]
    }

    // 构建新的 beams
    newBeams := make([]*Beam, len(candidates))
    for i, c := range candidates {
        if c.tokenID == -1 {
            // 已完成的 beam
            newBeams[i] = beams[c.beamIdx]
        } else {
            newBeams[i] = &Beam{
                TokenIDs: append(append([]int{}, beams[c.beamIdx].TokenIDs...), c.tokenID),
                Score:    c.score,
                Finished: false, // 检查 EOS
            }
        }
    }

    return newBeams
}

性能指标

关键指标定义

┌─────────────────────────────────────────────────────────────────┐
│                    推理性能关键指标                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  1. 延迟指标 (Latency)                                          │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • TTFT (Time To First Token):首 token 延迟               │   │
│  │   - 主要受 Prefill 阶段影响                               │   │
│  │   - 影响用户感知的响应速度                                 │   │
│  │                                                           │   │
│  │ • TPOT (Time Per Output Token):每 token 延迟             │   │
│  │   - Decode 阶段的单 token 生成时间                        │   │
│  │   - 影响流式输出的流畅度                                  │   │
│  │                                                           │   │
│  │ • E2E Latency:端到端延迟                                 │   │
│  │   = TTFT + TPOT × (output_len - 1)                        │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  2. 吞吐量指标 (Throughput)                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Tokens/second:每秒生成 token 数                        │   │
│  │   = total_output_tokens / total_time                      │   │
│  │                                                           │   │
│  │ • Requests/second:每秒处理请求数                         │   │
│  │   = num_requests / total_time                             │   │
│  │                                                           │   │
│  │ • GPU Utilization:GPU 利用率                             │   │
│  │   - Compute utilization                                   │   │
│  │   - Memory bandwidth utilization                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  3. 效率指标 (Efficiency)                                       │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Memory Efficiency:内存效率                             │   │
│  │   = actual_KV_cache_used / total_KV_cache_allocated      │   │
│  │                                                           │   │
│  │ • Batch Efficiency:批处理效率                            │   │
│  │   = actual_batch_size / max_batch_size                   │   │
│  │                                                           │   │
│  │ • Model FLOPS Utilization (MFU)                          │   │
│  │   = achieved_FLOPS / peak_FLOPS                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  典型值参考 (LLaMA-7B, A100 80GB):                              │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ TTFT (512 prompt):      ~50ms                             │   │
│  │ TPOT:                   ~10-20ms                          │   │
│  │ Throughput (batch=32):  ~2000 tokens/s                   │   │
│  │ Max batch size:         ~256 (取决于 seq_len)             │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

性能监控

// metrics.go
package inference

import (
    "sync"
    "time"

    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

var (
    // 延迟指标
    ttftHistogram = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "llm_ttft_seconds",
            Help:    "Time to first token in seconds",
            Buckets: []float64{0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0},
        },
        []string{"model"},
    )

    tpotHistogram = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "llm_tpot_seconds",
            Help:    "Time per output token in seconds",
            Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25},
        },
        []string{"model"},
    )

    e2eLatencyHistogram = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "llm_e2e_latency_seconds",
            Help:    "End-to-end latency in seconds",
            Buckets: []float64{0.1, 0.5, 1, 2.5, 5, 10, 30, 60},
        },
        []string{"model"},
    )

    // 吞吐量指标
    tokensGeneratedTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "llm_tokens_generated_total",
            Help: "Total number of tokens generated",
        },
        []string{"model", "type"}, // type: prompt, completion
    )

    requestsTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "llm_requests_total",
            Help: "Total number of requests",
        },
        []string{"model", "status"}, // status: success, error, timeout
    )

    // 批处理指标
    batchSizeHistogram = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "llm_batch_size",
            Help:    "Batch size distribution",
            Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256},
        },
        []string{"model", "phase"}, // phase: prefill, decode
    )

    // KV Cache 指标
    kvCacheUtilization = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "llm_kv_cache_utilization",
            Help: "KV cache utilization ratio",
        },
        []string{"model"},
    )

    kvCacheBlocks = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "llm_kv_cache_blocks",
            Help: "Number of KV cache blocks",
        },
        []string{"model", "state"}, // state: used, free
    )

    // GPU 指标
    gpuMemoryUsed = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "llm_gpu_memory_bytes",
            Help: "GPU memory usage in bytes",
        },
        []string{"gpu_id", "type"}, // type: model, kv_cache, activation
    )

    gpuUtilization = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "llm_gpu_utilization",
            Help: "GPU compute utilization",
        },
        []string{"gpu_id"},
    )
)

// MetricsCollector 指标收集器
type MetricsCollector struct {
    model     string
    startTime time.Time

    mu              sync.Mutex
    requestCount    int64
    tokenCount      int64
    lastReportTime  time.Time
}

func NewMetricsCollector(model string) *MetricsCollector {
    return &MetricsCollector{
        model:          model,
        startTime:      time.Now(),
        lastReportTime: time.Now(),
    }
}

// RecordRequest 记录请求
func (m *MetricsCollector) RecordRequest(
    promptTokens int,
    completionTokens int,
    ttft time.Duration,
    e2eLatency time.Duration,
    status string,
) {
    // 记录延迟
    ttftHistogram.WithLabelValues(m.model).Observe(ttft.Seconds())
    e2eLatencyHistogram.WithLabelValues(m.model).Observe(e2eLatency.Seconds())

    // 计算 TPOT
    if completionTokens > 1 {
        tpot := (e2eLatency - ttft) / time.Duration(completionTokens-1)
        tpotHistogram.WithLabelValues(m.model).Observe(tpot.Seconds())
    }

    // 记录 token 数
    tokensGeneratedTotal.WithLabelValues(m.model, "prompt").Add(float64(promptTokens))
    tokensGeneratedTotal.WithLabelValues(m.model, "completion").Add(float64(completionTokens))

    // 记录请求
    requestsTotal.WithLabelValues(m.model, status).Inc()
}

// RecordBatch 记录批处理
func (m *MetricsCollector) RecordBatch(size int, phase string) {
    batchSizeHistogram.WithLabelValues(m.model, phase).Observe(float64(size))
}

// UpdateKVCacheStats 更新 KV Cache 统计
func (m *MetricsCollector) UpdateKVCacheStats(stats KVCacheStats) {
    kvCacheUtilization.WithLabelValues(m.model).Set(stats.Utilization)
    kvCacheBlocks.WithLabelValues(m.model, "used").Set(float64(stats.UsedBlocks))
    kvCacheBlocks.WithLabelValues(m.model, "free").Set(float64(stats.FreeBlocks))
}

// UpdateGPUStats 更新 GPU 统计
func (m *MetricsCollector) UpdateGPUStats(gpuID string, memUsed int64, utilization float64) {
    gpuMemoryUsed.WithLabelValues(gpuID, "total").Set(float64(memUsed))
    gpuUtilization.WithLabelValues(gpuID).Set(utilization)
}

// GetThroughput 获取吞吐量
func (m *MetricsCollector) GetThroughput() (tokensPerSec, requestsPerSec float64) {
    m.mu.Lock()
    defer m.mu.Unlock()

    elapsed := time.Since(m.startTime).Seconds()
    if elapsed > 0 {
        tokensPerSec = float64(m.tokenCount) / elapsed
        requestsPerSec = float64(m.requestCount) / elapsed
    }
    return
}

小结

本章详细介绍了推理引擎的核心原理:

  1. 推理流程:Prefill 和 Decode 两个阶段的特性和优化方向
  2. 引擎架构:模型加载、调度器、执行器的设计
  3. KV Cache 管理:PagedAttention 原理和实现
  4. 计算优化:Flash Attention、算子融合、量化推理
  5. 采样策略:Temperature、Top-K、Top-P、Beam Search
  6. 性能指标:TTFT、TPOT、吞吐量、效率指标

下一章我们将探讨 模型服务框架,讲解如何使用 vLLM、TGI 等框架构建高性能推理服务。

Next
模型服务框架