HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

推理优化技术

概述

大模型推理优化是一个系统工程,涉及模型层面、算子层面、系统层面的多维度优化。本章深入讲解各类推理优化技术的原理与实现,包括模型压缩、计算图优化、内存优化、系统级优化等关键技术。

模型量化

量化基础原理

量化是将高精度浮点数映射到低精度数值表示(如 INT8/INT4 整数,或 FP8/FP4 低位宽浮点数)的过程:

package quantization

import (
    "math"
)

// QuantizationType enumerates the quantization formats named by the constants
// below.
type QuantizationType int

// QuantizationType values, numbered 0 through 3 in declaration order by iota.
const (
    QuantInt8 QuantizationType = iota
    QuantInt4
    QuantFP8
    QuantFP4
)

// QuantizationConfig holds the quantization settings.
// NOTE(review): none of the Quantizer methods visible in this section read
// these fields; bits and group size are passed to them as arguments instead.
type QuantizationConfig struct {
    Type          QuantizationType
    Symmetric     bool              // symmetric quantization
    PerChannel    bool              // per-channel quantization
    GroupSize     int               // group size for grouped quantization
    CalibrationMethod string        // calibration method: minmax, percentile, entropy
}

// Quantizer is the receiver for the quantization routines in this section
// (QuantizeSymmetric, QuantizeAsymmetric, QuantizePerGroup).
type Quantizer struct {
    config *QuantizationConfig
}

// QuantizedTensor is the result of quantizing a tensor.
//
// When GroupSize > 0, Dequantize reads Scale and ZeroPoint per group (element
// i belongs to group i / GroupSize); otherwise it uses Scale[0] and
// ZeroPoint[0] for every element.
type QuantizedTensor struct {
    Data      []int8           // quantized values (int8 used as the example storage type)
    Scale     []float32        // scale factor(s)
    ZeroPoint []int8           // zero point(s)
    Shape     []int
    GroupSize int
}

// QuantizeSymmetric quantizes tensor with one symmetric scale and a zero point
// of 0:
//
//   x_q = round(x / scale), clamped to [-qMax, qMax]
//   x   = x_q * scale
//
// where qMax = 2^(bits-1) - 1 and scale = max|x| / qMax.
//
// Results are stored as int8, so bits is expected to be in [2, 8]; wider
// settings do not fit the storage type.
//
// When max|x| is 0 (all-zero or empty tensor) the computed scale would be 0
// and x / scale would be NaN. The scale falls back to 1 in that case, so every
// element quantizes to 0 and dequantizes back to 0.
func (q *Quantizer) QuantizeSymmetric(tensor []float32, bits int) *QuantizedTensor {
    qMax := float32(math.Pow(2, float64(bits-1)) - 1)

    scale := q.findAbsMax(tensor) / qMax
    if scale <= 0 || math.IsNaN(float64(scale)) || math.IsInf(float64(scale), 0) {
        // 0 comes from an all-zero tensor; NaN/Inf arise when qMax is 0.
        scale = 1
    }

    quantized := make([]int8, len(tensor))
    for i, v := range tensor {
        // Clamp to [-qMax, qMax] first, then round to the nearest level.
        level := math.Max(float64(-qMax), math.Min(float64(qMax), float64(v/scale)))
        quantized[i] = int8(math.Round(level))
    }

    return &QuantizedTensor{
        Data:      quantized,
        Scale:     []float32{scale},
        ZeroPoint: []int8{0},
        Shape:     []int{len(tensor)},
    }
}

// QuantizeAsymmetric quantizes tensor with one scale and one zero point:
//
//   x_q = round(x / scale) + zero_point, clamped to [qMin, qMax]
//   x   = (x_q - zero_point) * scale
//
// The integer grid is the signed range qMin = -2^(bits-1), qMax = 2^(bits-1)-1
// so that every level (and the zero point) fits the int8 storage type for
// bits <= 8. An unsigned grid [0, 2^bits-1] has the same number of levels but
// reaches 255 at 8 bits, which int8 cannot hold.
//
// scale = (max - min) / (qMax - qMin) and zero_point = round(qMin - min/scale),
// so min maps to qMin and max maps to qMax.
//
// Edge cases:
//   - empty tensor: returns an empty result with scale 1 and zero point 0;
//   - constant tensor (max == min): the scale would be 0, so it falls back to 1.
func (q *Quantizer) QuantizeAsymmetric(tensor []float32, bits int) *QuantizedTensor {
    if len(tensor) == 0 {
        return &QuantizedTensor{
            Data:      []int8{},
            Scale:     []float32{1},
            ZeroPoint: []int8{0},
            Shape:     []int{0},
        }
    }

    minVal, maxVal := q.findMinMax(tensor)

    half := math.Pow(2, float64(bits-1))
    qMin := -half
    qMax := half - 1

    scale := (float64(maxVal) - float64(minVal)) / (qMax - qMin)
    if scale <= 0 || math.IsNaN(scale) || math.IsInf(scale, 0) {
        scale = 1
    }

    clamp := func(v float64) float64 {
        return math.Max(qMin, math.Min(qMax, v))
    }

    // Clamp the zero point as well: it is stored in the same int8 type.
    zeroPoint := clamp(math.Round(qMin - float64(minVal)/scale))

    quantized := make([]int8, len(tensor))
    for i, v := range tensor {
        quantized[i] = int8(clamp(math.Round(float64(v)/scale) + zeroPoint))
    }

    return &QuantizedTensor{
        Data:      quantized,
        Scale:     []float32{float32(scale)},
        ZeroPoint: []int8{int8(zeroPoint)},
        Shape:     []int{len(tensor)},
    }
}

// QuantizePerGroup performs symmetric quantization with an independent scale
// for every run of groupSize consecutive elements (the last group may be
// shorter). Zero points are all 0.
//
// Within a group, scale = max|x| / qMax with qMax = 2^(bits-1) - 1, and each
// value is clamped to [-qMax, qMax] and rounded.
//
// Edge cases:
//   - groupSize <= 0 would divide by zero, so the whole tensor is treated as a
//     single group instead;
//   - a group whose values are all 0 would get scale 0 (and NaN quotients), so
//     its scale falls back to 1.
func (q *Quantizer) QuantizePerGroup(
    tensor []float32,
    groupSize int,
    bits int,
) *QuantizedTensor {
    if groupSize <= 0 {
        groupSize = len(tensor)
        if groupSize == 0 {
            groupSize = 1
        }
    }

    numGroups := (len(tensor) + groupSize - 1) / groupSize
    scales := make([]float32, numGroups)
    zeroPoints := make([]int8, numGroups)
    quantized := make([]int8, len(tensor))

    qMax := float32(math.Pow(2, float64(bits-1)) - 1)

    for g := 0; g < numGroups; g++ {
        start := g * groupSize
        end := start + groupSize
        if end > len(tensor) {
            end = len(tensor)
        }
        group := tensor[start:end]

        // Quantization parameters of this group.
        scale := q.findAbsMax(group) / qMax
        if scale <= 0 || math.IsNaN(float64(scale)) || math.IsInf(float64(scale), 0) {
            scale = 1
        }
        scales[g] = scale

        for i, v := range group {
            level := math.Max(float64(-qMax), math.Min(float64(qMax), float64(v/scale)))
            quantized[start+i] = int8(math.Round(level))
        }
    }

    return &QuantizedTensor{
        Data:      quantized,
        Scale:     scales,
        ZeroPoint: zeroPoints,
        Shape:     []int{len(tensor)},
        GroupSize: groupSize,
    }
}

// Dequantize maps the stored integers back to float32 using
//
//   x = (x_q - zero_point) * scale
//
// With GroupSize > 0, element i uses the scale and zero point of group
// i / GroupSize; otherwise Scale[0] and ZeroPoint[0] apply to every element.
//
// The subtraction is carried out in int32: x_q and zero_point are both int8,
// and their difference can be as large as 255, which would wrap around if it
// were computed in int8.
func (qt *QuantizedTensor) Dequantize() []float32 {
    result := make([]float32, len(qt.Data))

    for i, qVal := range qt.Data {
        idx := 0
        if qt.GroupSize > 0 {
            idx = i / qt.GroupSize
        }
        diff := int32(qVal) - int32(qt.ZeroPoint[idx])
        result[i] = float32(diff) * qt.Scale[idx]
    }

    return result
}

// findAbsMax returns the largest absolute value in tensor, or 0 when tensor is
// empty.
func (q *Quantizer) findAbsMax(tensor []float32) float32 {
    var peak float32
    for idx := range tensor {
        if magnitude := float32(math.Abs(float64(tensor[idx]))); magnitude > peak {
            peak = magnitude
        }
    }
    return peak
}

// findMinMax returns the smallest and largest value in tensor.
// An empty tensor yields (0, 0) instead of indexing out of range.
func (q *Quantizer) findMinMax(tensor []float32) (float32, float32) {
    if len(tensor) == 0 {
        return 0, 0
    }

    minVal, maxVal := tensor[0], tensor[0]
    for _, v := range tensor[1:] {
        if v < minVal {
            minVal = v
        }
        if v > maxVal {
            maxVal = v
        }
    }
    return minVal, maxVal
}

GPTQ 量化

GPTQ (Generative Pre-trained Transformer Quantization) 是一种高效的后训练量化方法:

package quantization

import (
    "math"
    "sync"
)

// GPTQQuantizer holds the settings used by the GPTQ routines in this section.
//
// bits sets the quantization range (qMax = 2^(bits-1) - 1 in
// computeGroupParams and quantize); groupSize is the number of input columns
// that share one scale/zero pair.
// NOTE(review): symmetric is copied from the config but not read by any method
// visible here.
type GPTQQuantizer struct {
    bits       int
    groupSize  int
    symmetric  bool
    actOrder   bool     // activation-order optimisation: getQuantOrder sorts columns by Hessian diagonal, descending
    dampingPct float64  // damping coefficient: fraction of the mean Hessian diagonal added by addDamping
}

// NewGPTQQuantizer builds a GPTQQuantizer by copying every field of config.
// config must be non-nil.
func NewGPTQQuantizer(config *GPTQConfig) *GPTQQuantizer {
    gq := new(GPTQQuantizer)
    gq.bits = config.Bits
    gq.groupSize = config.GroupSize
    gq.symmetric = config.Symmetric
    gq.actOrder = config.ActOrder
    gq.dampingPct = config.DampingPct
    return gq
}

// QuantizeLayer quantizes the weights of one layer, following the OBS
// (Optimal Brain Surgeon) style procedure: columns are quantized one at a time
// and the quantization error of each column is redistributed onto the columns
// that have not been quantized yet, weighted by the inverse Hessian.
//
// weight is [out_features, in_features] and is not modified; hessian is
// [in_features, in_features]. The returned error is the one reported by the
// Cholesky decomposition of the damped Hessian.
//
// Notes on behaviour:
//   - An empty weight matrix returns an empty result instead of indexing
//     weight[0].
//   - A non-positive group size would divide by zero; it is treated as one
//     group spanning all input columns.
//   - Group scale/zero values are computed the first time any column of the
//     group is visited. With activation ordering, columns are visited in a
//     permuted order, so the first visited column of a group is generally not
//     the one at the group boundary.
//   - Columns are processed in blocks. Inside a block the error is applied to
//     the remaining columns immediately; after the block finishes, the
//     accumulated errors are applied to all columns of later blocks so that
//     they are compensated as well.
func (gq *GPTQQuantizer) QuantizeLayer(
    weight [][]float32, // [out_features, in_features]
    hessian [][]float32, // Hessian matrix [in_features, in_features]
) (*GPTQLayerResult, error) {

    outFeatures := len(weight)
    if outFeatures == 0 || len(weight[0]) == 0 {
        return &GPTQLayerResult{GroupSize: gq.groupSize}, nil
    }
    inFeatures := len(weight[0])

    groupSize := gq.groupSize
    if groupSize <= 0 {
        groupSize = inFeatures
    }

    // Work on a copy so the caller's weights stay untouched.
    W := make([][]float32, outFeatures)
    for i := range W {
        W[i] = make([]float32, inFeatures)
        copy(W[i], weight[i])
    }

    // Add the damping term to the Hessian diagonal.
    H := gq.addDamping(hessian)

    // Cholesky decomposition.
    L, err := gq.choleskyDecomposition(H)
    if err != nil {
        return nil, err
    }

    // Inverse used for error compensation.
    Hinv := gq.invertLowerTriangular(L)

    // Storage for the quantization parameters and results.
    numGroups := (inFeatures + groupSize - 1) / groupSize
    scales := make([][]float32, outFeatures)
    zeros := make([][]int32, outFeatures)
    quantizedWeight := make([][]int8, outFeatures)
    groupReady := make([][]bool, outFeatures)

    for i := 0; i < outFeatures; i++ {
        scales[i] = make([]float32, numGroups)
        zeros[i] = make([]int32, numGroups)
        quantizedWeight[i] = make([]int8, inFeatures)
        groupReady[i] = make([]bool, numGroups)
    }

    // Order in which columns are quantized.
    order := gq.getQuantOrder(H)

    const blockSize = 128

    // blockErr[i][j-blockStart] is the scaled error of row i for the j-th
    // processed column of the current block.
    blockErr := make([][]float32, outFeatures)
    for i := range blockErr {
        blockErr[i] = make([]float32, blockSize)
    }

    for blockStart := 0; blockStart < inFeatures; blockStart += blockSize {
        blockEnd := min(blockStart+blockSize, inFeatures)

        for j := blockStart; j < blockEnd; j++ {
            col := order[j]
            groupIdx := col / groupSize
            diag := float32(Hinv[col][col])

            for i := 0; i < outFeatures; i++ {
                // Compute the group parameters on first use of the group.
                if !groupReady[i][groupIdx] {
                    groupStart := groupIdx * groupSize
                    groupEnd := min(groupStart+groupSize, inFeatures)
                    scales[i][groupIdx], zeros[i][groupIdx] = gq.computeGroupParams(W[i][groupStart:groupEnd])
                    groupReady[i][groupIdx] = true
                }

                scale := scales[i][groupIdx]
                zero := zeros[i][groupIdx]

                // Quantize, then measure what was lost.
                w := W[i][col]
                qVal := gq.quantize(w, scale, zero)
                quantizedWeight[i][col] = qVal
                wHat := gq.dequantize(qVal, scale, zero)

                var scaledErr float32
                if diag != 0 {
                    scaledErr = (w - wHat) / diag
                }
                blockErr[i][j-blockStart] = scaledErr

                // Error compensation (the core of OBS): push the error onto
                // the not-yet-quantized columns of this block.
                for k := j + 1; k < blockEnd; k++ {
                    nextCol := order[k]
                    W[i][nextCol] -= scaledErr * float32(Hinv[nextCol][col])
                }
            }
        }

        // Apply the block's accumulated errors to every later column.
        if blockEnd < inFeatures {
            for i := 0; i < outFeatures; i++ {
                for j := blockStart; j < blockEnd; j++ {
                    scaledErr := blockErr[i][j-blockStart]
                    if scaledErr == 0 {
                        continue
                    }
                    col := order[j]
                    for k := blockEnd; k < inFeatures; k++ {
                        nextCol := order[k]
                        W[i][nextCol] -= scaledErr * float32(Hinv[nextCol][col])
                    }
                }
            }
        }
    }

    return &GPTQLayerResult{
        QuantizedWeight: quantizedWeight,
        Scales:          scales,
        Zeros:           zeros,
        GroupSize:       groupSize,
    }, nil
}

// ComputeHessian builds the matrix H = X^T * X / num_samples from the
// calibration inputs of a linear layer.
//
// calibrationData is [num_samples, in_features]; the result is
// [in_features, in_features] and symmetric. Empty calibration data yields nil
// instead of indexing calibrationData[0].
//
// Rows are split into at most eight contiguous ranges, one goroutine each.
// A goroutine computes H[i][j] for its rows i and all j >= i and mirrors the
// value into H[j][i], so every element is written by exactly one goroutine.
func (gq *GPTQQuantizer) ComputeHessian(
    calibrationData [][]float32, // calibration data [num_samples, in_features]
) [][]float32 {
    numSamples := len(calibrationData)
    if numSamples == 0 {
        return nil
    }
    inFeatures := len(calibrationData[0])

    H := make([][]float32, inFeatures)
    for i := range H {
        H[i] = make([]float32, inFeatures)
    }

    const numWorkers = 8
    rowsPerWorker := (inFeatures + numWorkers - 1) / numWorkers

    var wg sync.WaitGroup
    // Only start workers that actually have rows to process.
    for startRow := 0; startRow < inFeatures; startRow += rowsPerWorker {
        endRow := min(startRow+rowsPerWorker, inFeatures)

        wg.Add(1)
        go func(startRow, endRow int) {
            defer wg.Done()

            for i := startRow; i < endRow; i++ {
                for j := i; j < inFeatures; j++ {
                    var sum float32
                    for s := 0; s < numSamples; s++ {
                        sum += calibrationData[s][i] * calibrationData[s][j]
                    }
                    H[i][j] = sum / float32(numSamples)
                    if i != j {
                        H[j][i] = H[i][j] // symmetric matrix
                    }
                }
            }
        }(startRow, endRow)
    }
    wg.Wait()

    return H
}

// addDamping returns a copy of H whose diagonal has been increased by
// dampingPct times the mean of the diagonal. H itself is left unchanged.
func (gq *GPTQQuantizer) addDamping(H [][]float32) [][]float32 {
    size := len(H)
    damped := make([][]float32, size)

    // Copy row by row and accumulate the trace in the same pass.
    var trace float32
    for r := range damped {
        row := make([]float32, size)
        copy(row, H[r])
        damped[r] = row
        trace += row[r]
    }

    lambda := float32(gq.dampingPct) * (trace / float32(size))
    for r := range damped {
        damped[r][r] += lambda
    }

    return damped
}

// getQuantOrder returns the order in which columns are quantized.
//
// Without activation ordering this is 0, 1, ..., n-1. With it, column indices
// are arranged by descending Hessian diagonal value using an in-place
// exchange sort.
func (gq *GPTQQuantizer) getQuantOrder(H [][]float32) []int {
    order := make([]int, len(H))
    for pos := range order {
        order[pos] = pos
    }
    if !gq.actOrder {
        return order
    }

    for a := 0; a+1 < len(order); a++ {
        for b := a + 1; b < len(order); b++ {
            if H[order[a]][order[a]] < H[order[b]][order[b]] {
                order[a], order[b] = order[b], order[a]
            }
        }
    }
    return order
}

// computeGroupParams derives the symmetric quantization parameters of one
// group: scale = max|x| / (2^(bits-1) - 1), replaced by 1 when that is 0.
// The returned zero point is always 0.
func (gq *GPTQQuantizer) computeGroupParams(group []float32) (float32, int32) {
    var peak float32
    for idx := range group {
        if magnitude := float32(math.Abs(float64(group[idx]))); magnitude > peak {
            peak = magnitude
        }
    }

    levels := float32(math.Pow(2, float64(gq.bits-1)) - 1)
    if step := peak / levels; step != 0 {
        return step, 0
    }
    return 1, 0
}

// quantize maps val to an integer level: val/scale + zero, clamped to
// [-qMax, qMax] with qMax = 2^(bits-1) - 1, then rounded and stored as int8.
func (gq *GPTQQuantizer) quantize(val, scale float32, zero int32) int8 {
    limit := float64(float32(math.Pow(2, float64(gq.bits-1)) - 1))

    level := float64(val/scale + float32(zero))
    switch {
    case level > limit:
        level = limit
    case level < -limit:
        level = -limit
    }
    return int8(math.Round(level))
}

// dequantize converts an integer level back to a float: (qVal - zero) * scale.
// The subtraction is done in int32.
func (gq *GPTQQuantizer) dequantize(qVal int8, scale float32, zero int32) float32 {
    shifted := int32(qVal) - zero
    return float32(shifted) * scale
}

// GPTQConfig carries the settings that NewGPTQQuantizer copies, field by
// field, into a GPTQQuantizer.
type GPTQConfig struct {
    Bits       int
    GroupSize  int
    Symmetric  bool
    ActOrder   bool
    DampingPct float64
}

// GPTQLayerResult is what GPTQQuantizer.QuantizeLayer returns.
// QuantizedWeight is indexed [out_feature][in_feature]; Scales and Zeros are
// indexed [out_feature][group], where a group covers GroupSize input columns.
type GPTQLayerResult struct {
    QuantizedWeight [][]int8
    Scales          [][]float32
    Zeros           [][]int32
    GroupSize       int
}

AWQ 量化

AWQ (Activation-aware Weight Quantization) 基于激活感知的权重量化:

package quantization

import (
    "math"
)

// AWQQuantizer holds the settings used by the AWQ routines in this section:
// bits sets the quantization range, groupSize is the number of input columns
// sharing one scale/zero pair, and searchScale selects whether QuantizeLayer
// runs searchOptimalScales or uses the channel importance directly.
type AWQQuantizer struct {
    bits       int
    groupSize  int
    searchScale bool
}

// AWQConfig is the AWQ configuration consumed by NewAWQQuantizer.
// NOTE(review): NewAWQQuantizer copies Bits, GroupSize and SearchScale only;
// NumSamples is not read by the code visible here.
type AWQConfig struct {
    Bits        int
    GroupSize   int
    SearchScale bool
    NumSamples  int
}

// NewAWQQuantizer builds an AWQQuantizer from the Bits, GroupSize and
// SearchScale fields of config. config must be non-nil.
func NewAWQQuantizer(config *AWQConfig) *AWQQuantizer {
    aq := new(AWQQuantizer)
    aq.bits = config.Bits
    aq.groupSize = config.GroupSize
    aq.searchScale = config.SearchScale
    return aq
}

// QuantizeLayer quantizes one layer with AWQ.
//
// weight is [out_features, in_features]; activations are the calibration
// activations, [num_samples, in_features]. The steps are:
//  1. compute a per-input-channel importance from the activations;
//  2. pick per-channel scaling factors, either by grid search
//     (searchScale) or by using the importance values directly;
//  3. multiply every weight column by its channel's factor;
//  4. quantize the scaled weights row by row in groups of groupSize columns.
//
// The returned error is always nil in this implementation.
//
// Edge cases: an empty weight matrix returns an empty result instead of
// indexing weight[0], and a non-positive group size (which would divide by
// zero) is treated as one group spanning all input columns.
func (aq *AWQQuantizer) QuantizeLayer(
    weight [][]float32, // [out_features, in_features]
    activations [][]float32, // calibration activations [num_samples, in_features]
) (*AWQLayerResult, error) {

    outFeatures := len(weight)
    if outFeatures == 0 || len(weight[0]) == 0 {
        return &AWQLayerResult{GroupSize: aq.groupSize}, nil
    }
    inFeatures := len(weight[0])

    groupSize := aq.groupSize
    if groupSize <= 0 {
        groupSize = inFeatures
    }

    // 1. Importance of each activation channel.
    channelImportance := aq.computeChannelImportance(activations)

    // 2. Scaling factors: searched, or the activation magnitude itself.
    scales := channelImportance
    if aq.searchScale {
        scales = aq.searchOptimalScales(weight, channelImportance)
    }

    // 3. Apply the scaling to the weights.
    scaledWeight := aq.applyScales(weight, scales)

    // 4. Quantize the scaled weights group by group.
    numGroups := (inFeatures + groupSize - 1) / groupSize
    quantizedWeight := make([][]int8, outFeatures)
    qScales := make([][]float32, outFeatures)
    qZeros := make([][]int32, outFeatures)

    for i := 0; i < outFeatures; i++ {
        row := scaledWeight[i]
        quantizedWeight[i] = make([]int8, inFeatures)
        qScales[i] = make([]float32, numGroups)
        qZeros[i] = make([]int32, numGroups)

        for g := 0; g < numGroups; g++ {
            start := g * groupSize
            end := min(start+groupSize, inFeatures)

            scale, zero := aq.computeQuantParams(row[start:end])
            qScales[i][g] = scale
            qZeros[i][g] = zero

            for j := start; j < end; j++ {
                quantizedWeight[i][j] = aq.quantize(row[j], scale, zero)
            }
        }
    }

    return &AWQLayerResult{
        QuantizedWeight: quantizedWeight,
        Scales:          qScales,
        Zeros:           qZeros,
        ChannelScales:   scales,
        GroupSize:       groupSize,
    }, nil
}

// computeChannelImportance returns, for every input channel, the mean absolute
// activation over all samples, divided by the largest such mean so the values
// lie in [0, 1]. If every mean is 0 the values are left at 0.
//
// activations is [num_samples, in_features]. With no samples the result is
// nil; with zero features it is an empty slice.
func (aq *AWQQuantizer) computeChannelImportance(activations [][]float32) []float32 {
    numSamples := len(activations)
    if numSamples == 0 {
        return nil
    }
    inFeatures := len(activations[0])

    importance := make([]float32, inFeatures)
    if inFeatures == 0 {
        return importance
    }

    for c := 0; c < inFeatures; c++ {
        var sum float32
        for s := 0; s < numSamples; s++ {
            sum += float32(math.Abs(float64(activations[s][c])))
        }
        importance[c] = sum / float32(numSamples)
    }

    // Normalise by the maximum.
    maxImp := importance[0]
    for _, v := range importance[1:] {
        if v > maxImp {
            maxImp = v
        }
    }
    if maxImp > 0 {
        for i := range importance {
            importance[i] /= maxImp
        }
    }

    return importance
}

// searchOptimalScales chooses one scaling factor per input channel by grid
// search.
//
// For channel c the candidates are ratio * base, with ratio in 0.1, 0.2, ...,
// 1.0 and base = channelImportance[c] floored at 0.01. The candidate with the
// smallest computeQuantError wins; on ties the earliest candidate is kept.
//
// An empty weight matrix yields nil. A channel without an importance entry is
// treated as having importance 0 (so its base is the 0.01 floor).
func (aq *AWQQuantizer) searchOptimalScales(
    weight [][]float32,
    channelImportance []float32,
) []float32 {
    if len(weight) == 0 {
        return nil
    }

    inFeatures := len(weight[0])
    scales := make([]float32, inFeatures)

    // Grid of ratios to try.
    searchRange := [...]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}

    for c := 0; c < inFeatures; c++ {
        var baseScale float32
        if c < len(channelImportance) {
            baseScale = channelImportance[c]
        }
        if baseScale < 0.01 {
            baseScale = 0.01
        }

        bestScale := float32(1.0)
        bestError := float32(math.MaxFloat32)

        for _, ratio := range searchRange {
            scale := float32(ratio) * baseScale

            // Quantization error under this candidate.
            if quantErr := aq.computeQuantError(weight, c, scale); quantErr < bestError {
                bestError = quantErr
                bestScale = scale
            }
        }

        scales[c] = bestScale
    }

    return scales
}

// computeQuantError simulates quantizing one input channel under a candidate
// scaling factor and returns the summed absolute reconstruction error over all
// output rows.
//
// For each row, the channel's weight is multiplied by scale, quantized with
// the step of the group it belongs to, dequantized and divided by scale again.
// Groups are runs of groupSize consecutive columns within a row, the same
// grouping QuantizeLayer uses; a non-positive groupSize means one group per
// row. The step is max|w| over the group divided by qMax = 2^(bits-1) - 1,
// where the searched channel contributes its scaled value and the other
// channels their unscaled values (each channel is searched independently).
//
// Deriving the step from the group rather than from the single element being
// quantized matters: with a per-element step every weight maps exactly onto
// +/-qMax and the measured error is always ~0, regardless of scale.
//
// A scale of 0 wipes the channel, so the full |w| counts as error. Rows that
// are too short to contain the channel are skipped.
func (aq *AWQQuantizer) computeQuantError(weight [][]float32, channel int, scale float32) float32 {
    qMax := float32(math.Pow(2, float64(aq.bits-1)) - 1)

    var totalError float32
    for _, row := range weight {
        if channel < 0 || channel >= len(row) {
            continue
        }
        w := row[channel]

        if scale == 0 {
            totalError += float32(math.Abs(float64(w)))
            continue
        }
        scaledW := w * scale

        // Bounds of the quantization group containing the channel.
        groupSize := aq.groupSize
        if groupSize <= 0 {
            groupSize = len(row)
        }
        start := channel / groupSize * groupSize
        end := min(start+groupSize, len(row))

        // Shared step of the group.
        maxAbs := float32(math.Abs(float64(scaledW)))
        for j := start; j < end; j++ {
            if j == channel {
                continue
            }
            if a := float32(math.Abs(float64(row[j]))); a > maxAbs {
                maxAbs = a
            }
        }
        qScale := maxAbs / qMax
        if qScale == 0 {
            qScale = 1
        }

        // Simulated quantization.
        level := math.Max(-float64(qMax), math.Min(float64(qMax), float64(scaledW/qScale)))
        qVal := float32(math.Round(level))

        // Dequantize and undo the channel scaling.
        deqVal := qVal * qScale / scale

        totalError += float32(math.Abs(float64(w - deqVal)))
    }

    return totalError
}

// applyScales returns a new matrix in which column j of weight has been
// multiplied by scales[j]. weight is not modified.
//
// Each row is sized from its own length, so an empty weight matrix yields an
// empty result instead of indexing weight[0]. scales must have an entry for
// every column.
func (aq *AWQQuantizer) applyScales(weight [][]float32, scales []float32) [][]float32 {
    scaled := make([][]float32, len(weight))
    for i, row := range weight {
        out := make([]float32, len(row))
        for j, w := range row {
            out[j] = w * scales[j]
        }
        scaled[i] = out
    }

    return scaled
}

// computeQuantParams derives the symmetric quantization parameters of one
// group: scale = max|x| / (2^(bits-1) - 1), replaced by 1 when that is 0.
// The returned zero point is always 0.
func (aq *AWQQuantizer) computeQuantParams(group []float32) (float32, int32) {
    var peak float32
    for idx := range group {
        if magnitude := float32(math.Abs(float64(group[idx]))); magnitude > peak {
            peak = magnitude
        }
    }

    levels := float32(math.Pow(2, float64(aq.bits-1)) - 1)
    if step := peak / levels; step != 0 {
        return step, 0
    }
    return 1, 0
}

// quantize maps val to an integer level: val/scale + zero, clamped to
// [-qMax, qMax] with qMax = 2^(bits-1) - 1, then rounded and stored as int8.
func (aq *AWQQuantizer) quantize(val, scale float32, zero int32) int8 {
    limit := float64(float32(math.Pow(2, float64(aq.bits-1)) - 1))

    level := float64(val/scale + float32(zero))
    switch {
    case level > limit:
        level = limit
    case level < -limit:
        level = -limit
    }
    return int8(math.Round(level))
}

// AWQLayerResult is what AWQQuantizer.QuantizeLayer returns.
// QuantizedWeight is indexed [out_feature][in_feature]; Scales and Zeros are
// indexed [out_feature][group]; ChannelScales holds the per-input-channel
// factors that were multiplied into the weights before quantization.
type AWQLayerResult struct {
    QuantizedWeight [][]int8
    Scales          [][]float32
    Zeros           [][]int32
    ChannelScales   []float32
    GroupSize       int
}

计算图优化

算子融合

算子融合是减少内存访问和 kernel 启动开销的关键技术:

package optimization

import (
    "fmt"
)

// GraphOptimizer applies a list of fusion rules to a compute graph.
type GraphOptimizer struct {
    graph       *ComputeGraph
    fusionRules []FusionRule
}

// ComputeGraph is a compute graph: nodes keyed by their ID, plus a list of
// directed edges between node IDs.
type ComputeGraph struct {
    Nodes map[string]*GraphNode
    Edges []*GraphEdge
}

// GraphNode is one operation in the graph. OpType is the operation name the
// fusion rules compare against (for example "Linear" or "GELU"); Attrs holds
// operation-specific attributes such as "weight", "bias" or "exponent".
type GraphNode struct {
    ID       string
    OpType   string
    Inputs   []string
    Outputs  []string
    Attrs    map[string]interface{}
}

// GraphEdge is a directed edge. Source and Target are node IDs (they are used
// as keys into ComputeGraph.Nodes).
type GraphEdge struct {
    Source   string
    Target   string
    TensorID string
}

// FusionRule describes one fusion pattern.
//
// Match is called with a candidate starting node and returns the IDs of the
// nodes forming the pattern, or nil when the pattern does not start there.
// Fuse receives those IDs and returns the single node that replaces them.
// Name returns the rule's identifier.
type FusionRule interface {
    Match(graph *ComputeGraph, nodeID string) []string
    Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode
    Name() string
}

// LinearGELUFusion is the rule that fuses a Linear node with a following GELU.
type LinearGELUFusion struct{}

// Name returns the rule's identifier.
func (*LinearGELUFusion) Name() string { return "LinearGELUFusion" }

// Match reports whether nodeID is a Linear node whose output feeds a GELU node
// and nothing else. On success it returns {linearID, geluID}; otherwise nil.
//
// The pair is only matched when every edge leaving the Linear node goes to the
// same GELU node. If the Linear output had another consumer, fusing would
// redirect that consumer to the fused node's output (the GELU result) and
// change what it computes.
//
// Unknown node IDs and edges pointing at nodes missing from the graph result
// in no match rather than a nil dereference.
func (f *LinearGELUFusion) Match(graph *ComputeGraph, nodeID string) []string {
    node := graph.Nodes[nodeID]
    if node == nil || node.OpType != "Linear" {
        return nil
    }

    geluID := ""
    for _, edge := range graph.Edges {
        if edge == nil || edge.Source != nodeID {
            continue
        }

        target := graph.Nodes[edge.Target]
        if target == nil || target.OpType != "GELU" {
            // Some other consumer reads the Linear output.
            return nil
        }
        if geluID != "" && geluID != edge.Target {
            // Two different GELU consumers.
            return nil
        }
        geluID = edge.Target
    }

    if geluID == "" {
        return nil
    }
    return []string{nodeID, geluID}
}

// Fuse builds the node that replaces a matched pair. nodeIDs[0] is the Linear
// node and nodeIDs[1] the GELU node. The fused node takes the Linear node's
// inputs, the GELU node's outputs, and the Linear node's "weight" and "bias"
// attributes.
func (f *LinearGELUFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
    linear := graph.Nodes[nodeIDs[0]]
    gelu := graph.Nodes[nodeIDs[1]]

    fused := new(GraphNode)
    fused.ID = fmt.Sprintf("fused_%s_%s", linear.ID, gelu.ID)
    fused.OpType = "LinearGELU"
    fused.Inputs = linear.Inputs
    fused.Outputs = gelu.Outputs

    fused.Attrs = make(map[string]interface{}, 2)
    fused.Attrs["weight"] = linear.Attrs["weight"]
    fused.Attrs["bias"] = linear.Attrs["bias"]
    return fused
}

// MultiHeadAttentionFusion is the rule that fuses a multi-head attention
// sub-graph into a single node.
type MultiHeadAttentionFusion struct{}

// Name returns the rule's identifier.
func (*MultiHeadAttentionFusion) Name() string { return "MultiHeadAttentionFusion" }

// Match looks for the pattern
//
//   Q/K/V projections -> attention -> output projection
//
// starting at nodeID, which must be a Linear node accepted by isQProjection.
// On success it returns the IDs {q, k, v, attention, output}; otherwise nil.
//
// Every lookup is checked before its result is used: in particular the
// attention node is verified to exist before it is handed to
// findOutputProjection. An unknown nodeID yields nil.
func (f *MultiHeadAttentionFusion) Match(graph *ComputeGraph, nodeID string) []string {
    node := graph.Nodes[nodeID]
    if node == nil || node.OpType != "Linear" || !isQProjection(node) {
        return nil
    }

    // The K and V projections associated with this Q projection.
    kNode := findKProjection(graph, node)
    if kNode == nil {
        return nil
    }
    vNode := findVProjection(graph, node)
    if vNode == nil {
        return nil
    }

    // The attention node, then the output projection that follows it.
    attnNode := findAttentionNode(graph, node, kNode, vNode)
    if attnNode == nil {
        return nil
    }
    outNode := findOutputProjection(graph, attnNode)
    if outNode == nil {
        return nil
    }

    return []string{node.ID, kNode.ID, vNode.ID, attnNode.ID, outNode.ID}
}

// Fuse builds the node that replaces a matched attention sub-graph. nodeIDs is
// the slice returned by Match: {q, k, v, attention, output projection}.
//
// The fused node takes the Q projection's inputs and the output projection's
// outputs, mirroring how the other fusion rules in this file forward the
// outputs of the last node of the pattern. Reusing the real output names keeps
// downstream consumers connected and keeps two fused attention blocks from
// sharing one output name. The fixed name "mha_output" is only used as a
// fallback when the output projection is missing or declares no outputs.
func (f *MultiHeadAttentionFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
    qNode := graph.Nodes[nodeIDs[0]]

    outputs := []string{"mha_output"}
    if len(nodeIDs) > 4 {
        if outNode := graph.Nodes[nodeIDs[4]]; outNode != nil && len(outNode.Outputs) > 0 {
            outputs = outNode.Outputs
        }
    }

    return &GraphNode{
        ID:      fmt.Sprintf("fused_mha_%s", nodeIDs[0]),
        OpType:  "FusedMultiHeadAttention",
        Inputs:  qNode.Inputs,
        Outputs: outputs,
        Attrs: map[string]interface{}{
            "num_heads": extractNumHeads(graph.Nodes[nodeIDs[3]]),
            "qkv_fused": true,
        },
    }
}

// RMSNormFusion is the rule that fuses the element-wise chain making up an
// RMSNorm into a single node.
type RMSNormFusion struct{}

// Name returns the rule's identifier.
func (*RMSNormFusion) Name() string { return "RMSNormFusion" }

// Match looks for the chain
//
//   x -> pow(2) -> mean -> rsqrt -> mul -> mul(weight)
//
// starting at nodeID, which must be a Pow node whose "exponent" attribute
// equals 2. On success it returns the IDs {pow, mean, rsqrt, mul, mul};
// otherwise nil.
//
// The exponent is accepted as any of float64, float32, int, int32 or int64,
// because Attrs is an untyped map and the stored type depends on whoever
// populated it. An unknown nodeID yields nil.
func (f *RMSNormFusion) Match(graph *ComputeGraph, nodeID string) []string {
    node := graph.Nodes[nodeID]
    if node == nil || node.OpType != "Pow" {
        return nil
    }

    if !isSquareExponent(node.Attrs["exponent"]) {
        return nil
    }

    // Walk the chain of successor operations.
    ids := []string{node.ID}
    current := nodeID
    for _, opType := range [...]string{"Mean", "Rsqrt", "Mul", "Mul"} {
        next := findNextOp(graph, current, opType)
        if next == nil {
            return nil
        }
        ids = append(ids, next.ID)
        current = next.ID
    }

    return ids
}

// isSquareExponent reports whether v is a numeric attribute equal to 2.
func isSquareExponent(v interface{}) bool {
    switch e := v.(type) {
    case float64:
        return e == 2
    case float32:
        return e == 2
    case int:
        return e == 2
    case int32:
        return e == 2
    case int64:
        return e == 2
    }
    return false
}

// Fuse builds the node that replaces a matched RMSNorm chain. nodeIDs[0] is
// the Pow node and nodeIDs[4] the final Mul node. The fused node takes the Pow
// node's inputs and the final Mul node's outputs; its "weight" attribute is
// the final Mul node's "other" attribute and "eps" is fixed at 1e-6.
func (f *RMSNormFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
    first := graph.Nodes[nodeIDs[0]]
    last := graph.Nodes[nodeIDs[4]]

    fused := new(GraphNode)
    fused.ID = fmt.Sprintf("fused_rmsnorm_%s", nodeIDs[0])
    fused.OpType = "FusedRMSNorm"
    fused.Inputs = first.Inputs
    fused.Outputs = last.Outputs

    fused.Attrs = make(map[string]interface{}, 2)
    fused.Attrs["eps"] = 1e-6
    fused.Attrs["weight"] = last.Attrs["other"]
    return fused
}

// Optimize runs the fusion rules over a clone of the optimizer's graph until
// no rule matches any node, and returns that graph.
//
// Each pass scans the nodes and, per node, the rules in their configured
// order. The first match is fused and the scan restarts, because fusing
// changes the node set being iterated.
func (opt *GraphOptimizer) Optimize() *ComputeGraph {
    working := opt.graph.Clone()

    // fuseOnce applies the first matching rule it finds and reports whether
    // anything was fused.
    fuseOnce := func() bool {
        for nodeID := range working.Nodes {
            for _, rule := range opt.fusionRules {
                matched := rule.Match(working, nodeID)
                if matched == nil {
                    continue
                }
                opt.applyFusion(working, matched, rule.Fuse(working, matched))
                return true
            }
        }
        return false
    }

    for fuseOnce() {
    }

    return working
}

// applyFusion replaces the nodes listed in oldNodes with newNode.
//
// newNode is registered first. Edges between two old nodes are dropped; edges
// leaving an old node are re-sourced at newNode and edges entering an old node
// are re-targeted at newNode (the existing edge values are updated in place).
// Finally the old nodes are removed from the node map.
func (opt *GraphOptimizer) applyFusion(
    graph *ComputeGraph,
    oldNodes []string,
    newNode *GraphNode,
) {
    graph.Nodes[newNode.ID] = newNode

    retired := make(map[string]struct{}, len(oldNodes))
    for _, id := range oldNodes {
        retired[id] = struct{}{}
    }

    kept := make([]*GraphEdge, 0)
    for _, e := range graph.Edges {
        _, fromOld := retired[e.Source]
        _, toOld := retired[e.Target]

        switch {
        case fromOld && toOld:
            // Internal to the fused pattern.
            continue
        case fromOld:
            e.Source = newNode.ID
        case toOld:
            e.Target = newNode.ID
        }
        kept = append(kept, e)
    }
    graph.Edges = kept

    for _, id := range oldNodes {
        delete(graph.Nodes, id)
    }
}

// Helper functions.

// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
    for i := 0; i < len(slice); i++ {
        if slice[i] == item {
            return true
        }
    }
    return false
}

// findNextOp returns the first node (in edge order) that is a direct successor
// of nodeID and has the given OpType, or nil when there is none.
// Edges whose target is not present in graph.Nodes are ignored.
func findNextOp(graph *ComputeGraph, nodeID, opType string) *GraphNode {
    for _, edge := range graph.Edges {
        if edge == nil || edge.Source != nodeID {
            continue
        }
        if target := graph.Nodes[edge.Target]; target != nil && target.OpType == opType {
            return target
        }
    }
    return nil
}

// isQProjection is meant to decide, from naming or attributes, whether node is
// a Q projection. Placeholder: it accepts every node.
func isQProjection(node *GraphNode) bool {
    // Decide based on naming or attributes.
    return true // simplified implementation
}

// findKProjection is meant to locate the K projection that belongs with qNode.
// Placeholder: it always returns nil.
func findKProjection(graph *ComputeGraph, qNode *GraphNode) *GraphNode {
    return nil // simplified implementation
}

// findVProjection is meant to locate the V projection that belongs with qNode.
// Placeholder: it always returns nil.
func findVProjection(graph *ComputeGraph, qNode *GraphNode) *GraphNode {
    return nil // simplified implementation
}

// findAttentionNode is meant to locate the attention node fed by q, k and v.
// Placeholder: it always returns nil.
func findAttentionNode(graph *ComputeGraph, q, k, v *GraphNode) *GraphNode {
    return nil // simplified implementation
}

// findOutputProjection is meant to locate the output projection that follows
// attnNode. Placeholder: it always returns nil.
func findOutputProjection(graph *ComputeGraph, attnNode *GraphNode) *GraphNode {
    return nil // simplified implementation
}

// extractNumHeads is meant to read the head count from node.
// Placeholder: it always returns 32.
func extractNumHeads(node *GraphNode) int {
    return 32 // simplified implementation
}

// Clone returns a deep copy of the graph structure: a new node map with
// copied nodes (including their Inputs/Outputs slices and Attrs maps) and
// copied edges. Changes to the clone's topology, such as the in-place edge
// rewrites and node deletions performed during fusion, therefore do not affect
// the receiver.
//
// Attribute values themselves are shared between the original and the clone;
// only the map holding them is duplicated. A nil receiver yields nil.
func (g *ComputeGraph) Clone() *ComputeGraph {
    if g == nil {
        return nil
    }

    clone := &ComputeGraph{
        Nodes: make(map[string]*GraphNode, len(g.Nodes)),
        Edges: make([]*GraphEdge, 0, len(g.Edges)),
    }

    for id, node := range g.Nodes {
        if node == nil {
            clone.Nodes[id] = nil
            continue
        }
        cp := *node
        cp.Inputs = append([]string(nil), node.Inputs...)
        cp.Outputs = append([]string(nil), node.Outputs...)
        if node.Attrs != nil {
            cp.Attrs = make(map[string]interface{}, len(node.Attrs))
            for k, v := range node.Attrs {
                cp.Attrs[k] = v
            }
        }
        clone.Nodes[id] = &cp
    }

    for _, edge := range g.Edges {
        if edge == nil {
            clone.Edges = append(clone.Edges, nil)
            continue
        }
        cp := *edge
        clone.Edges = append(clone.Edges, &cp)
    }

    return clone
}

内存规划

高效的内存规划通过让生命周期不重叠的张量复用同一块空间来减少显存占用:

package optimization

import (
    "sort"
)

// MemoryPlanner plans memory for the tensors of a compute graph.
// tensorLifetimes is rebuilt by analyzeTensorLifetimes on every Plan call.
// NOTE(review): memoryPlan is not assigned by the methods visible here.
type MemoryPlanner struct {
    graph        *ComputeGraph
    tensorLifetimes map[string]*TensorLifetime
    memoryPlan   *MemoryPlan
}

// TensorLifetime records how long a tensor is live and where it was placed.
type TensorLifetime struct {
    TensorID   string
    Size       int64
    FirstUse   int // index of the first operation that uses the tensor
    LastUse    int // index of the last operation that uses the tensor
    MemOffset  int64 // assigned memory offset
}

// MemoryPlan is the result of planning: TotalSize is the sum of all freshly
// allocated (non-reused) regions, and Allocations maps a tensor ID to its
// placement.
// NOTE(review): ReusableBuffers is not populated by the Plan method visible
// here.
type MemoryPlan struct {
    TotalSize       int64
    Allocations     map[string]*Allocation
    ReusableBuffers []*BufferPool
}

// Allocation is the placement of one tensor: its offset and size, and whether
// the region was taken from the free list (Reused) rather than newly added to
// the plan's total size.
type Allocation struct {
    TensorID string
    Offset   int64
    Size     int64
    Reused   bool
}

// BufferPool is a reusable buffer pool: a size together with a list of tensor
// IDs.
// NOTE(review): presumably the tensors that share the buffer; no code visible
// here creates or reads a BufferPool, so confirm against its users.
type BufferPool struct {
    Size    int64
    Tensors []string
}

// Plan computes a memory plan for the given execution order.
//
// It first derives every tensor's lifetime, then visits the tensors in the
// order produced by getSortedTensors. For each tensor it asks the free list
// for the best-fitting region whose previous occupant is no longer live at the
// tensor's first use; if none exists, a new region is appended at the end of
// the plan. The tensor's region is then registered in the free list as
// becoming available after the tensor's last use.
func (mp *MemoryPlanner) Plan(executionOrder []string) *MemoryPlan {
    mp.analyzeTensorLifetimes(executionOrder)

    pool := NewFreeList()
    result := &MemoryPlan{Allocations: make(map[string]*Allocation)}

    for _, life := range mp.getSortedTensors() {
        at, recycled := pool.FindBestFit(life.Size, life.FirstUse)
        if !recycled {
            // Nothing reusable: grow the plan.
            at = result.TotalSize
            result.TotalSize += life.Size
        }

        life.MemOffset = at
        result.Allocations[life.TensorID] = &Allocation{
            TensorID: life.TensorID,
            Offset:   at,
            Size:     life.Size,
            Reused:   recycled,
        }

        // The region becomes reusable once the tensor's last use has passed.
        pool.Add(&FreeBlock{
            Offset:    at,
            Size:      life.Size,
            FreeAfter: life.LastUse,
        })
    }

    return result
}

// analyzeTensorLifetimes rebuilds mp.tensorLifetimes from the execution order.
//
// A tensor's lifetime starts at the index of the operation that outputs it and
// is extended to the index of every later operation that lists it as an input.
// Inputs that no earlier operation produced are not tracked.
//
// Entries of executionOrder that do not name a node of the graph are skipped
// (their index is still consumed) instead of dereferencing a nil node.
func (mp *MemoryPlanner) analyzeTensorLifetimes(executionOrder []string) {
    mp.tensorLifetimes = make(map[string]*TensorLifetime)

    for opIdx, nodeID := range executionOrder {
        node := mp.graph.Nodes[nodeID]
        if node == nil {
            continue
        }

        // Inputs: extend the lifetime of tensors already seen.
        for _, inputID := range node.Inputs {
            if lt, exists := mp.tensorLifetimes[inputID]; exists {
                lt.LastUse = opIdx
            }
        }

        // Outputs: a new lifetime begins here.
        for _, outputID := range node.Outputs {
            mp.tensorLifetimes[outputID] = &TensorLifetime{
                TensorID: outputID,
                Size:     mp.estimateTensorSize(nodeID, outputID),
                FirstUse: opIdx,
                LastUse:  opIdx,
            }
        }
    }
}

// getSortedTensors returns all tensor lifetimes ordered for allocation: larger
// tensors first, then longer-lived tensors first, then by tensor ID.
//
// The final tie-break on tensor ID makes the order (and therefore the
// resulting memory plan) deterministic. Without it, tensors of equal size and
// lifetime would come out in whatever order the map iteration and the
// unstable sort happened to produce.
func (mp *MemoryPlanner) getSortedTensors() []*TensorLifetime {
    tensors := make([]*TensorLifetime, 0, len(mp.tensorLifetimes))
    for _, lt := range mp.tensorLifetimes {
        tensors = append(tensors, lt)
    }

    sort.Slice(tensors, func(i, j int) bool {
        a, b := tensors[i], tensors[j]

        // Size, descending.
        if a.Size != b.Size {
            return a.Size > b.Size
        }
        // Lifetime length, descending.
        lifeA := a.LastUse - a.FirstUse
        lifeB := b.LastUse - b.FirstUse
        if lifeA != lifeB {
            return lifeA > lifeB
        }
        // Stable, reproducible order for otherwise equal tensors.
        return a.TensorID < b.TensorID
    })

    return tensors
}

// estimateTensorSize is meant to estimate a tensor's size from the node type
// and attributes. Placeholder: it ignores its arguments and returns 1 MiB.
func (mp *MemoryPlanner) estimateTensorSize(nodeID, tensorID string) int64 {
    // Estimate the tensor size from the node type and attributes.
    // Simplified here.
    return 1024 * 1024 // 1MB
}

// FreeList is the list of memory regions the planner may hand out again.
type FreeList struct {
    blocks []*FreeBlock
}

// FreeBlock is one region in a FreeList. FindBestFit treats it as available
// for an operation index strictly greater than FreeAfter.
type FreeBlock struct {
    Offset    int64
    Size      int64
    FreeAfter int
}

// NewFreeList returns an empty free list.
func NewFreeList() *FreeList {
    fl := new(FreeList)
    fl.blocks = []*FreeBlock{}
    return fl
}

// FindBestFit looks for the smallest block that is at least size bytes and
// whose FreeAfter is strictly before currentOp. Among equally small blocks the
// earliest in the list wins.
//
// On success it returns the block's offset and true. A larger block is split:
// the leading size bytes are handed out and the remainder stays in the list
// with the original FreeAfter. An exactly fitting block is removed from the
// list. When nothing fits it returns (0, false).
func (fl *FreeList) FindBestFit(size int64, currentOp int) (int64, bool) {
    chosen := -1
    for idx, candidate := range fl.blocks {
        if candidate.FreeAfter >= currentOp || candidate.Size < size {
            // Still occupied, or too small.
            continue
        }
        if chosen < 0 || candidate.Size < fl.blocks[chosen].Size {
            chosen = idx
        }
    }

    if chosen < 0 {
        return 0, false
    }

    hit := fl.blocks[chosen]
    if hit.Size > size {
        // Keep the unused tail available.
        fl.blocks[chosen] = &FreeBlock{
            Offset:    hit.Offset + size,
            Size:      hit.Size - size,
            FreeAfter: hit.FreeAfter,
        }
    } else {
        fl.blocks = append(fl.blocks[:chosen], fl.blocks[chosen+1:]...)
    }

    return hit.Offset, true
}

// Add appends block to the end of the free list.
func (fl *FreeList) Add(block *FreeBlock) {
    extended := append(fl.blocks, block)
    fl.blocks = extended
}

// InPlaceOptimizer marks graph nodes that can run in place.
type InPlaceOptimizer struct {
    graph *ComputeGraph
}

// Optimize sets the "inplace" attribute to true on every node for which
// canBeInPlace reports true. Other nodes are left untouched.
//
// A node without an attribute map gets one created first, since assigning into
// a nil map panics. Nil node entries are skipped.
func (ipo *InPlaceOptimizer) Optimize() {
    for _, node := range ipo.graph.Nodes {
        if node == nil || !ipo.canBeInPlace(node) {
            continue
        }
        if node.Attrs == nil {
            node.Attrs = make(map[string]interface{})
        }
        node.Attrs["inplace"] = true
    }
}

func (ipo *InPlaceOptimizer) canBeInPlace(node *GraphNode) bool {
    // 检查操作是否可以原地执行
    switch node.OpType {
    case "ReLU", "GELU", "Sigmoid", "Tanh":
        // 激活函数可以原地
        return true
    case "Add", "Mul":
        // 逐元素操作在形状相同时可以原地
        return ipo.hasMatchingShapes(node.Inputs)
    case "Dropout":
        // Dropout 可以原地
        return true
    default:
        return false
    }
}

func (ipo *InPlaceOptimizer) hasMatchingShapes(inputs []string) bool {
    // 检查输入形状是否匹配
    return true // 简化实现
}

KV Cache 优化

PagedAttention 实现

PagedAttention 是 vLLM 的核心技术,通过分页管理 KV Cache:

package optimization

import (
    "errors"
    "sync"
)

// PagedKVCache is a KV cache whose storage is split into fixed-size blocks.
// Every sequence owns a block table that maps its logical blocks to physical
// block indices. Physical blocks are reference counted so that forked
// sequences can share them until one of them writes.
//
// All exported methods are safe for concurrent use; they serialize on mu.
type PagedKVCache struct {
    config      *PagedKVConfig
    blocks      []*KVBlock
    freeBlocks  []int         // indices of blocks not referenced by any sequence
    blockTables map[int][]int // sequence_id -> block_indices
    mu          sync.RWMutex
}

// PagedKVConfig holds the sizing parameters of a PagedKVCache.
type PagedKVConfig struct {
    NumLayers    int
    NumHeads     int
    HeadDim      int
    BlockSize    int // number of tokens per block
    NumGPUBlocks int
    NumCPUBlocks int
    DType        string // fp16, fp32
}

// KVBlock is one physical cache block.
type KVBlock struct {
    KeyCache   [][]float16 // [num_tokens, num_heads * head_dim]
    ValueCache [][]float16 // [num_tokens, num_heads * head_dim]
    NumTokens  int         // number of slots already written
    RefCount   int         // number of block tables referencing this block
}

type float16 float32 // simplified representation

// NewPagedKVCache allocates config.NumGPUBlocks blocks of config.BlockSize
// slots, each slot NumHeads*HeadDim wide, and puts all of them on the free
// list.
func NewPagedKVCache(config *PagedKVConfig) *PagedKVCache {
    cache := &PagedKVCache{
        config:      config,
        blocks:      make([]*KVBlock, config.NumGPUBlocks),
        freeBlocks:  make([]int, config.NumGPUBlocks),
        blockTables: make(map[int][]int),
    }

    slotWidth := config.NumHeads * config.HeadDim
    for i := 0; i < config.NumGPUBlocks; i++ {
        block := &KVBlock{
            KeyCache:   make([][]float16, config.BlockSize),
            ValueCache: make([][]float16, config.BlockSize),
        }
        for j := 0; j < config.BlockSize; j++ {
            block.KeyCache[j] = make([]float16, slotWidth)
            block.ValueCache[j] = make([]float16, slotWidth)
        }
        cache.blocks[i] = block
        cache.freeBlocks[i] = i
    }

    return cache
}

// AllocateBlocks takes numBlocks blocks from the free list and installs them
// as the block table of seqID. It returns a copy of the new table.
//
// If seqID already had a block table, its references are released once the
// new table is in place, so replacing a table no longer leaks blocks. An error
// is returned when numBlocks is negative or exceeds the number of free blocks.
func (c *PagedKVCache) AllocateBlocks(seqID, numBlocks int) ([]int, error) {
    c.mu.Lock()
    defer c.mu.Unlock()

    if numBlocks < 0 {
        return nil, errors.New("invalid block count")
    }
    if len(c.freeBlocks) < numBlocks {
        return nil, errors.New("not enough free blocks")
    }

    allocated := make([]int, numBlocks)
    copy(allocated, c.freeBlocks[:numBlocks])
    c.freeBlocks = c.freeBlocks[numBlocks:]

    for _, blockIdx := range allocated {
        c.blocks[blockIdx].RefCount = 1
    }

    // Release the previous table only after the new blocks were taken, so the
    // two sets can never overlap.
    if previous, ok := c.blockTables[seqID]; ok {
        c.releaseBlocksLocked(previous)
    }
    c.blockTables[seqID] = allocated

    // Hand out a copy so callers cannot corrupt the internal table.
    result := make([]int, len(allocated))
    copy(result, allocated)
    return result, nil
}

// AppendToken stores key and value in the next free slot of seqID, growing
// the block table by one block when the current blocks are full.
//
// If the target block is shared with another sequence (RefCount > 1) it is
// copied first, so a write never becomes visible through a sibling created by
// ForkSequence. The layer argument is currently not used: the simplified
// layout keeps a single slot per token. key and value are copied into the
// slot; elements beyond the slot width are dropped by copy.
func (c *PagedKVCache) AppendToken(
    seqID int,
    layer int,
    key, value []float16,
) error {
    c.mu.Lock()
    defer c.mu.Unlock()

    blockTable := c.blockTables[seqID]
    if blockTable == nil {
        return errors.New("sequence not found")
    }
    if c.config.BlockSize <= 0 {
        return errors.New("invalid block size")
    }

    // Locate the slot that follows the tokens written so far.
    totalTokens := c.getSequenceLength(seqID)
    blockIdx := totalTokens / c.config.BlockSize
    slotIdx := totalTokens % c.config.BlockSize

    if blockIdx >= len(blockTable) {
        // All blocks are full: take a fresh one.
        if len(c.freeBlocks) == 0 {
            return errors.New("no free blocks")
        }
        newBlockIdx := c.freeBlocks[0]
        c.freeBlocks = c.freeBlocks[1:]
        c.blocks[newBlockIdx].RefCount = 1
        blockTable = append(blockTable, newBlockIdx)
        c.blockTables[seqID] = blockTable
    } else if _, err := c.copyOnWriteLocked(blockTable, blockIdx); err != nil {
        return err
    }

    physicalBlock := c.blocks[blockTable[blockIdx]]
    copy(physicalBlock.KeyCache[slotIdx], key)
    copy(physicalBlock.ValueCache[slotIdx], value)
    physicalBlock.NumTokens++

    return nil
}

// ForkSequence makes childSeqID share all blocks of parentSeqID (used for
// beam search). The blocks' reference counts are incremented; a later write
// through AppendToken or CopyOnWrite detaches the written block.
//
// A block table previously registered under childSeqID is released.
func (c *PagedKVCache) ForkSequence(parentSeqID, childSeqID int) error {
    c.mu.Lock()
    defer c.mu.Unlock()

    parentTable := c.blockTables[parentSeqID]
    if parentTable == nil {
        return errors.New("parent sequence not found")
    }

    childTable := make([]int, len(parentTable))
    copy(childTable, parentTable)

    for _, blockIdx := range childTable {
        c.blocks[blockIdx].RefCount++
    }

    // Drop the references of a table being replaced, after the increments
    // above so shared blocks never transiently hit zero.
    if previous, ok := c.blockTables[childSeqID]; ok {
        c.releaseBlocksLocked(previous)
    }

    c.blockTables[childSeqID] = childTable
    return nil
}

// CopyOnWrite guarantees that logical block logicalBlockIdx of seqID is not
// shared and returns its physical index. A block referenced once is returned
// as is; otherwise its contents are copied into a free block which replaces
// it in the block table.
//
// It returns -1 and an error for an unknown sequence, an out-of-range index
// (including negative ones) or when no free block is left.
func (c *PagedKVCache) CopyOnWrite(seqID, logicalBlockIdx int) (int, error) {
    c.mu.Lock()
    defer c.mu.Unlock()

    blockTable := c.blockTables[seqID]
    if blockTable == nil || logicalBlockIdx < 0 || logicalBlockIdx >= len(blockTable) {
        return -1, errors.New("invalid sequence or block index")
    }

    return c.copyOnWriteLocked(blockTable, logicalBlockIdx)
}

// copyOnWriteLocked implements CopyOnWrite for an already validated table
// entry. The caller must hold c.mu for writing.
func (c *PagedKVCache) copyOnWriteLocked(blockTable []int, logicalBlockIdx int) (int, error) {
    physicalBlockIdx := blockTable[logicalBlockIdx]
    oldBlock := c.blocks[physicalBlockIdx]

    if oldBlock.RefCount <= 1 {
        return physicalBlockIdx, nil // exclusively owned, nothing to copy
    }

    if len(c.freeBlocks) == 0 {
        return -1, errors.New("no free blocks for copy")
    }
    newBlockIdx := c.freeBlocks[0]
    c.freeBlocks = c.freeBlocks[1:]
    newBlock := c.blocks[newBlockIdx]

    for i := 0; i < c.config.BlockSize; i++ {
        copy(newBlock.KeyCache[i], oldBlock.KeyCache[i])
        copy(newBlock.ValueCache[i], oldBlock.ValueCache[i])
    }
    newBlock.NumTokens = oldBlock.NumTokens
    newBlock.RefCount = 1

    // This sequence no longer references the shared block.
    oldBlock.RefCount--
    blockTable[logicalBlockIdx] = newBlockIdx

    return newBlockIdx, nil
}

// releaseBlocksLocked drops one reference from every block in blockTable and
// returns blocks that reach zero references to the free list. The caller must
// hold c.mu for writing.
func (c *PagedKVCache) releaseBlocksLocked(blockTable []int) {
    for _, blockIdx := range blockTable {
        block := c.blocks[blockIdx]
        block.RefCount--
        if block.RefCount == 0 {
            block.NumTokens = 0
            c.freeBlocks = append(c.freeBlocks, blockIdx)
        }
    }
}

// FreeSequence releases every block reference held by seqID and forgets the
// sequence. Unknown sequences are ignored.
func (c *PagedKVCache) FreeSequence(seqID int) {
    c.mu.Lock()
    defer c.mu.Unlock()

    blockTable := c.blockTables[seqID]
    if blockTable == nil {
        return
    }

    c.releaseBlocksLocked(blockTable)
    delete(c.blockTables, seqID)
}

// GetBlockTable returns a copy of the block table of seqID, or nil when the
// sequence is unknown.
func (c *PagedKVCache) GetBlockTable(seqID int) []int {
    c.mu.RLock()
    defer c.mu.RUnlock()

    table := c.blockTables[seqID]
    if table == nil {
        return nil
    }

    result := make([]int, len(table))
    copy(result, table)
    return result
}

// getSequenceLength sums NumTokens over the blocks of seqID. The caller must
// hold c.mu.
func (c *PagedKVCache) getSequenceLength(seqID int) int {
    totalTokens := 0
    for _, blockIdx := range c.blockTables[seqID] {
        totalTokens += c.blocks[blockIdx].NumTokens
    }
    return totalTokens
}

// GetFreeBlockCount returns the number of blocks on the free list.
func (c *PagedKVCache) GetFreeBlockCount() int {
    c.mu.RLock()
    defer c.mu.RUnlock()
    return len(c.freeBlocks)
}

// GetUtilization returns the fraction of blocks currently in use, in [0, 1].
// A cache configured without blocks reports 0 instead of NaN.
func (c *PagedKVCache) GetUtilization() float64 {
    c.mu.RLock()
    defer c.mu.RUnlock()

    if c.config.NumGPUBlocks <= 0 {
        return 0
    }
    usedBlocks := c.config.NumGPUBlocks - len(c.freeBlocks)
    return float64(usedBlocks) / float64(c.config.NumGPUBlocks)
}

Prefix Caching

前缀缓存允许多个请求共享相同前缀的 KV Cache:

package optimization

import (
    "crypto/sha256"
    "encoding/hex"
    "sync"
    "time"
)

// PrefixCache maps token prefixes to the KV-cache block tables that already
// hold them, so requests sharing a prefix can reuse those blocks.
type PrefixCache struct {
    cache          map[string]*PrefixEntry // keyed by ComputePrefixHash of the prefix
    maxEntries     int
    evictionPolicy string // "lru" or "lfu"; anything else behaves like "lru"
    mu             sync.RWMutex
}

// PrefixEntry is one cached prefix.
type PrefixEntry struct {
    Hash       string
    Tokens     []int32
    BlockTable []int
    NumTokens  int
    RefCount   int
    LastAccess time.Time
    CreateTime time.Time
}

// NewPrefixCache returns an empty cache that starts evicting once it holds
// maxEntries entries.
func NewPrefixCache(maxEntries int, evictionPolicy string) *PrefixCache {
    return &PrefixCache{
        cache:          make(map[string]*PrefixEntry),
        maxEntries:     maxEntries,
        evictionPolicy: evictionPolicy,
    }
}

// ComputePrefixHash returns the first 16 hex characters of the SHA-256 of the
// big-endian encoding of tokens.
func (pc *PrefixCache) ComputePrefixHash(tokens []int32) string {
    h := sha256.New()
    for _, t := range tokens {
        h.Write([]byte{
            byte(t >> 24),
            byte(t >> 16),
            byte(t >> 8),
            byte(t),
        })
    }
    return hex.EncodeToString(h.Sum(nil))[:16]
}

// FindPrefix returns the cached entry for the longest prefix of tokens
// together with the prefix length, or (nil, 0) when no prefix is cached.
//
// A hit refreshes LastAccess and increments RefCount, so the write lock is
// taken: mutating an entry under the read lock would race with concurrent
// lookups. Because the hash is truncated, a hit is only accepted when the
// stored tokens equal the prefix.
func (pc *PrefixCache) FindPrefix(tokens []int32) (*PrefixEntry, int) {
    pc.mu.Lock()
    defer pc.mu.Unlock()

    // Try the longest prefix first; the first hit is the longest match.
    for prefixLen := len(tokens); prefixLen > 0; prefixLen-- {
        prefix := tokens[:prefixLen]
        entry, exists := pc.cache[pc.ComputePrefixHash(prefix)]
        if !exists || !prefixTokensEqual(entry.Tokens, prefix) {
            continue
        }
        entry.LastAccess = time.Now()
        entry.RefCount++
        return entry, prefixLen
    }

    return nil, 0
}

// prefixTokensEqual reports whether a and b hold the same tokens.
func prefixTokensEqual(a, b []int32) bool {
    if len(a) != len(b) {
        return false
    }
    for i := range a {
        if a[i] != b[i] {
            return false
        }
    }
    return true
}

// InsertPrefix stores copies of tokens and blockTable and returns the new
// entry, which starts with RefCount 1. An entry with the same hash is
// replaced.
//
// Eviction runs only when the insert would add a new key to a full cache;
// replacing an existing key does not grow the cache and evicts nothing.
func (pc *PrefixCache) InsertPrefix(tokens []int32, blockTable []int) *PrefixEntry {
    pc.mu.Lock()
    defer pc.mu.Unlock()

    hash := pc.ComputePrefixHash(tokens)

    if _, exists := pc.cache[hash]; !exists && len(pc.cache) >= pc.maxEntries {
        pc.evict()
    }

    now := time.Now()
    entry := &PrefixEntry{
        Hash:       hash,
        Tokens:     append([]int32{}, tokens...),
        BlockTable: append([]int{}, blockTable...),
        NumTokens:  len(tokens),
        RefCount:   1,
        LastAccess: now,
        CreateTime: now,
    }

    pc.cache[hash] = entry
    return entry
}

// evict removes at most one entry according to the configured policy. The
// caller must hold pc.mu for writing.
func (pc *PrefixCache) evict() {
    switch pc.evictionPolicy {
    case "lfu":
        pc.evictLFU()
    default: // "lru" and unknown policies
        pc.evictLRU()
    }
}

// evictLRU removes the unreferenced (RefCount == 0) entry with the oldest
// LastAccess. Nothing is removed while every entry is referenced.
func (pc *PrefixCache) evictLRU() {
    var oldest *PrefixEntry
    var oldestHash string

    for hash, entry := range pc.cache {
        if entry.RefCount != 0 {
            continue
        }
        if oldest == nil || entry.LastAccess.Before(oldest.LastAccess) {
            oldest = entry
            oldestHash = hash
        }
    }

    if oldest != nil {
        delete(pc.cache, oldestHash)
    }
}

// evictLFU removes the entry with the smallest RefCount.
//
// NOTE(review): RefCount is a reference count, not an access frequency, and
// unlike evictLRU this may remove an entry that is still referenced - confirm
// that this is intended.
func (pc *PrefixCache) evictLFU() {
    var leastHash string
    found := false
    minRef := int(^uint(0) >> 1)

    for hash, entry := range pc.cache {
        if entry.RefCount < minRef {
            minRef = entry.RefCount
            leastHash = hash
            found = true
        }
    }

    if found {
        delete(pc.cache, leastHash)
    }
}

// ReleasePrefix drops one reference from the entry stored under hash. Unknown
// hashes are ignored, and the count never goes below zero: a negative count
// would keep evictLRU from ever reclaiming the entry.
func (pc *PrefixCache) ReleasePrefix(hash string) {
    pc.mu.Lock()
    defer pc.mu.Unlock()

    if entry, exists := pc.cache[hash]; exists && entry.RefCount > 0 {
        entry.RefCount--
    }
}

// RadixPrefixCache is a prefix cache backed by a radix tree (compressed
// trie), which finds the longest cached prefix in a single walk.
type RadixPrefixCache struct {
    root       *RadixNode
    blockCache *PagedKVCache
    mu         sync.RWMutex
}

// RadixNode is one node of the tree. Tokens is the label of the edge leading
// from the parent to this node; the node is stored in the parent's Children
// under the first token of that label.
type RadixNode struct {
    Tokens     []int32
    BlockTable []int
    Children   map[int32]*RadixNode
    IsEndpoint bool // true when an inserted prefix ends exactly at this node
    RefCount   int
    LastAccess time.Time
}

// NewRadixPrefixCache returns an empty cache associated with blockCache.
func NewRadixPrefixCache(blockCache *PagedKVCache) *RadixPrefixCache {
    return &RadixPrefixCache{
        root: &RadixNode{
            Children: make(map[int32]*RadixNode),
        },
        blockCache: blockCache,
    }
}

// radixCommonPrefixLen returns the number of leading tokens shared by a and b.
func radixCommonPrefixLen(a, b []int32) int {
    n := len(a)
    if len(b) < n {
        n = len(b)
    }
    i := 0
    for i < n && a[i] == b[i] {
        i++
    }
    return i
}

// Insert registers tokens as a cached prefix backed by a copy of blockTable.
// The endpoint node gets RefCount 1; inserting an existing prefix replaces its
// block table and resets its RefCount.
//
// An edge that only partially matches the new prefix is split at the point of
// divergence, so every node always represents exactly the tokens on its path.
func (rpc *RadixPrefixCache) Insert(tokens []int32, blockTable []int) {
    rpc.mu.Lock()
    defer rpc.mu.Unlock()

    node := rpc.root
    rest := tokens
    for len(rest) > 0 {
        child, exists := node.Children[rest[0]]
        if !exists {
            // No edge starts with this token: hang the remainder below node
            // as a single compressed edge (copied to detach it from the
            // caller's slice).
            leaf := &RadixNode{
                Tokens:   append([]int32(nil), rest...),
                Children: make(map[int32]*RadixNode),
            }
            node.Children[rest[0]] = leaf
            node = leaf
            break
        }

        shared := radixCommonPrefixLen(child.Tokens, rest)
        if shared < len(child.Tokens) {
            // The prefix diverges inside the edge: split it so the shared
            // part becomes a branching node above the old child.
            branch := &RadixNode{
                Tokens:   child.Tokens[:shared:shared],
                Children: map[int32]*RadixNode{child.Tokens[shared]: child},
            }
            child.Tokens = child.Tokens[shared:]
            node.Children[rest[0]] = branch
            child = branch
        }

        node = child
        rest = rest[shared:]
    }

    node.BlockTable = append([]int{}, blockTable...)
    node.IsEndpoint = true
    node.RefCount = 1
    node.LastAccess = time.Now()
}

// FindLongestPrefix returns the block table of the longest cached prefix of
// tokens and the length of that prefix, or (nil, 0) when none is cached.
//
// Only the returned endpoint has its LastAccess refreshed and RefCount
// incremented. Because the lookup mutates that node it takes the write lock.
func (rpc *RadixPrefixCache) FindLongestPrefix(tokens []int32) ([]int, int) {
    rpc.mu.Lock()
    defer rpc.mu.Unlock()

    var best *RadixNode
    bestLen := 0

    node := rpc.root
    matched := 0
    for matched < len(tokens) {
        child, exists := node.Children[tokens[matched]]
        if !exists {
            break
        }
        edge := child.Tokens
        // The whole edge label has to match; a partial match does not reach
        // the child node.
        if len(edge) == 0 || radixCommonPrefixLen(edge, tokens[matched:]) < len(edge) {
            break
        }
        matched += len(edge)
        node = child

        if node.IsEndpoint {
            best = node
            bestLen = matched
        }
    }

    if best == nil {
        return nil, 0
    }
    best.LastAccess = time.Now()
    best.RefCount++
    return best.BlockTable, bestLen
}

// Evict removes cached prefixes that are unreferenced (RefCount == 0) and
// were last accessed more than maxAge ago.
//
// NOTE(review): nothing visible in this type decrements RefCount, so entries
// only become evictable if some other code resets it - confirm how references
// are meant to be released.
func (rpc *RadixPrefixCache) Evict(maxAge time.Duration) {
    rpc.mu.Lock()
    defer rpc.mu.Unlock()

    cutoff := time.Now().Add(-maxAge)
    rpc.evictRecursive(rpc.root, cutoff)
}

// evictRecursive processes the subtree below node bottom-up. A stale,
// unreferenced endpoint loses its endpoint status and block table; a node is
// unlinked only when it is no endpoint and has no children left, so live
// descendants of a stale node survive.
//
// NOTE(review): the block indices of an evicted endpoint are dropped without
// being handed back to blockCache (the original loop was a placeholder).
func (rpc *RadixPrefixCache) evictRecursive(node *RadixNode, cutoff time.Time) {
    for token, child := range node.Children {
        rpc.evictRecursive(child, cutoff)

        if child.IsEndpoint && child.RefCount == 0 && child.LastAccess.Before(cutoff) {
            child.IsEndpoint = false
            child.BlockTable = nil
        }
        if !child.IsEndpoint && len(child.Children) == 0 {
            // Deleting the current key while ranging over a map is allowed.
            delete(node.Children, token)
        }
    }
}

系统级优化

CUDA 优化

package optimization

/*
#cgo LDFLAGS: -lcudart -lcublas

#include <cuda_runtime.h>
#include <cublas_v2.h>

// 自定义 CUDA kernel 声明
extern void fused_attention_kernel(
    float* query, float* key, float* value, float* output,
    int batch_size, int num_heads, int seq_len, int head_dim,
    cudaStream_t stream
);

extern void fused_rope_kernel(
    float* input, float* cos_cache, float* sin_cache,
    int batch_size, int seq_len, int num_heads, int head_dim,
    cudaStream_t stream
);

extern void quantized_matmul_kernel(
    int8_t* A, int8_t* B, float* C,
    float* scale_a, float* scale_b,
    int M, int N, int K,
    cudaStream_t stream
);
*/
import "C"

import (
    "unsafe"
)

// CUDAOptimizer owns the per-device CUDA resources (stream, cuBLAS handle and
// memory pool) and launches the fused kernels declared in the cgo preamble on
// its stream.
type CUDAOptimizer struct {
    deviceID     int
    stream       C.cudaStream_t
    cublasHandle C.cublasHandle_t
    memPool      *CUDAMemoryPool
}

// CUDAMemoryPool wraps a CUDA memory pool handle.
type CUDAMemoryPool struct {
    pool      C.cudaMemPool_t
    allocSize int64
}

// cudaCallError describes a failed CUDA runtime or cuBLAS call.
type cudaCallError struct {
    op     string // name of the failed call
    detail string // human-readable reason
}

func (e *cudaCallError) Error() string {
    return e.op + ": " + e.detail
}

// cudaCheck converts a CUDA runtime status into an error; cudaSuccess maps to
// nil.
func cudaCheck(op string, rc C.cudaError_t) error {
    if rc == C.cudaSuccess {
        return nil
    }
    return &cudaCallError{op: op, detail: C.GoString(C.cudaGetErrorString(rc))}
}

// NewCUDAOptimizer selects deviceID and creates a stream, a cuBLAS handle
// bound to that stream and a memory pool.
//
// Every CUDA / cuBLAS status is checked. On failure the resources created so
// far are destroyed and a nil optimizer is returned with the error.
func NewCUDAOptimizer(deviceID int) (*CUDAOptimizer, error) {
    opt := &CUDAOptimizer{deviceID: deviceID}

    if err := cudaCheck("cudaSetDevice", C.cudaSetDevice(C.int(deviceID))); err != nil {
        return nil, err
    }

    if err := cudaCheck("cudaStreamCreate", C.cudaStreamCreate(&opt.stream)); err != nil {
        return nil, err
    }

    if C.cublasCreate(&opt.cublasHandle) != C.CUBLAS_STATUS_SUCCESS {
        C.cudaStreamDestroy(opt.stream)
        return nil, &cudaCallError{op: "cublasCreate", detail: "cuBLAS initialization failed"}
    }
    if C.cublasSetStream(opt.cublasHandle, opt.stream) != C.CUBLAS_STATUS_SUCCESS {
        C.cublasDestroy(opt.cublasHandle)
        C.cudaStreamDestroy(opt.stream)
        return nil, &cudaCallError{op: "cublasSetStream", detail: "cannot bind cuBLAS handle to stream"}
    }

    opt.memPool = opt.createMemoryPool()
    if opt.memPool == nil {
        C.cublasDestroy(opt.cublasHandle)
        C.cudaStreamDestroy(opt.stream)
        return nil, &cudaCallError{op: "cudaMemPoolCreate", detail: "cannot create memory pool"}
    }

    return opt, nil
}

// createMemoryPool creates a memory pool located on the optimizer's device.
// It returns nil when cudaMemPoolCreate fails.
func (opt *CUDAOptimizer) createMemoryPool() *CUDAMemoryPool {
    var poolProps C.cudaMemPoolProps
    poolProps.allocType = C.cudaMemAllocationTypePinned
    poolProps.handleTypes = C.cudaMemHandleTypeNone
    // The C field is called "type"; cgo exposes struct fields whose names are
    // Go keywords with a leading underscore.
    poolProps.location._type = C.cudaMemLocationTypeDevice
    poolProps.location.id = C.int(opt.deviceID)

    var pool C.cudaMemPool_t
    if C.cudaMemPoolCreate(&pool, &poolProps) != C.cudaSuccess {
        return nil
    }

    return &CUDAMemoryPool{pool: pool}
}

// FusedAttention launches fused_attention_kernel on the optimizer's stream.
// The pointers are passed through unchanged as float* arguments.
func (opt *CUDAOptimizer) FusedAttention(
    query, key, value, output unsafe.Pointer,
    batchSize, numHeads, seqLen, headDim int,
) {
    C.fused_attention_kernel(
        (*C.float)(query),
        (*C.float)(key),
        (*C.float)(value),
        (*C.float)(output),
        C.int(batchSize),
        C.int(numHeads),
        C.int(seqLen),
        C.int(headDim),
        opt.stream,
    )
}

// FusedRoPE launches fused_rope_kernel (fused rotary position embedding) on
// the optimizer's stream.
func (opt *CUDAOptimizer) FusedRoPE(
    input, cosCache, sinCache unsafe.Pointer,
    batchSize, seqLen, numHeads, headDim int,
) {
    C.fused_rope_kernel(
        (*C.float)(input),
        (*C.float)(cosCache),
        (*C.float)(sinCache),
        C.int(batchSize),
        C.int(seqLen),
        C.int(numHeads),
        C.int(headDim),
        opt.stream,
    )
}

// QuantizedMatmul launches quantized_matmul_kernel (INT8 matrix multiply) on
// the optimizer's stream. a and b are passed as int8_t*, c, scaleA and scaleB
// as float*.
//
// The parameters are lower-case on purpose: a parameter named C would shadow
// the cgo pseudo-package "C" inside the function body.
func (opt *CUDAOptimizer) QuantizedMatmul(
    a, b, c, scaleA, scaleB unsafe.Pointer,
    m, n, k int,
) {
    C.quantized_matmul_kernel(
        (*C.int8_t)(a),
        (*C.int8_t)(b),
        (*C.float)(c),
        (*C.float)(scaleA),
        (*C.float)(scaleB),
        C.int(m),
        C.int(n),
        C.int(k),
        opt.stream,
    )
}

// AllocAsync requests size bytes with cudaMallocAsync on the optimizer's
// stream. It returns nil when the allocation is rejected.
func (opt *CUDAOptimizer) AllocAsync(size int64) unsafe.Pointer {
    var ptr unsafe.Pointer
    if C.cudaMallocAsync(&ptr, C.size_t(size), opt.stream) != C.cudaSuccess {
        return nil
    }
    return ptr
}

// FreeAsync releases ptr with cudaFreeAsync on the optimizer's stream.
func (opt *CUDAOptimizer) FreeAsync(ptr unsafe.Pointer) {
    C.cudaFreeAsync(ptr, opt.stream)
}

// Synchronize blocks until all work queued on the optimizer's stream is done.
func (opt *CUDAOptimizer) Synchronize() {
    C.cudaStreamSynchronize(opt.stream)
}

// Close destroys the resources created by NewCUDAOptimizer in reverse order
// of creation, including the memory pool.
func (opt *CUDAOptimizer) Close() {
    if opt.memPool != nil {
        C.cudaMemPoolDestroy(opt.memPool.pool)
        opt.memPool = nil
    }
    C.cublasDestroy(opt.cublasHandle)
    C.cudaStreamDestroy(opt.stream)
}

// MultiStreamExecutor hands out a fixed set of CUDA streams in round-robin
// order.
type MultiStreamExecutor struct {
    streams    []C.cudaStream_t
    numStreams int
    currentIdx int // index of the stream returned by the next GetNextStream call
}

// NewMultiStreamExecutor creates numStreams CUDA streams. The status of
// cudaStreamCreate is not checked.
func NewMultiStreamExecutor(numStreams int) *MultiStreamExecutor {
    created := make([]C.cudaStream_t, numStreams)
    for i := range created {
        C.cudaStreamCreate(&created[i])
    }

    return &MultiStreamExecutor{
        streams:    created,
        numStreams: numStreams,
    }
}

// GetNextStream returns the current stream and advances the round-robin
// cursor, wrapping to the first stream after the last one.
func (exec *MultiStreamExecutor) GetNextStream() C.cudaStream_t {
    picked := exec.streams[exec.currentIdx]
    exec.currentIdx++
    if exec.currentIdx >= exec.numStreams {
        exec.currentIdx = 0
    }
    return picked
}

// SyncAll waits for every stream in turn.
func (exec *MultiStreamExecutor) SyncAll() {
    for i := 0; i < len(exec.streams); i++ {
        C.cudaStreamSynchronize(exec.streams[i])
    }
}

// Close destroys every stream.
func (exec *MultiStreamExecutor) Close() {
    for i := 0; i < len(exec.streams); i++ {
        C.cudaStreamDestroy(exec.streams[i])
    }
}

并行化策略

package optimization

import (
    "context"
    "sync"
)

// TensorParallelism describes this process's place in a tensor-parallel
// group: the group size, its own rank and the backend used for collectives.
type TensorParallelism struct {
    worldSize    int // number of ranks; used as the divisor when an input is split per rank
    rank         int // index of this process, selects its input slice
    commBackend  CommunicationBackend
}

// CommunicationBackend is the set of communication primitives the parallel
// layers rely on. The semantics of each call are defined by the
// implementation; the notes below are inferred from the method names.
type CommunicationBackend interface {
    // AllReduce: presumably combines tensor across all ranks and returns the
    // combined result on every rank - confirm against the implementation.
    AllReduce(tensor []float32) []float32
    // AllGather: presumably returns the per-rank tensors concatenated.
    AllGather(tensor []float32) []float32
    // Broadcast: presumably distributes srcRank's tensor to every rank.
    Broadcast(tensor []float32, srcRank int) []float32
    // Send and Recv: presumably point-to-point transfer to / from a rank.
    Send(tensor []float32, dstRank int)
    Recv(srcRank int) []float32
}

// ColumnParallelLinear is a linear layer whose output features are split
// across the tensor-parallel ranks; each rank holds a slice of the weight
// rows and of the bias.
type ColumnParallelLinear struct {
    tp           *TensorParallelism
    weight       [][]float32 // [out_features/world_size, in_features]
    bias         []float32   // [out_features/world_size]
    gatherOutput bool        // when true, Forward returns the AllGather of the local outputs
}

// Forward multiplies input with the local weight shard, adds the local bias
// (repeated cyclically over the output) and, if gatherOutput is set, gathers
// the result across ranks through the communication backend.
//
// An empty bias is treated like a missing one; indexing it modulo its length
// would otherwise panic with a division by zero.
func (cpl *ColumnParallelLinear) Forward(input []float32) []float32 {
    localOutput := matmul(input, cpl.weight)

    if biasLen := len(cpl.bias); biasLen > 0 {
        for i := range localOutput {
            localOutput[i] += cpl.bias[i%biasLen]
        }
    }

    if cpl.gatherOutput {
        return cpl.tp.commBackend.AllGather(localOutput)
    }

    return localOutput
}

// RowParallelLinear is a linear layer whose input features are split across
// the tensor-parallel ranks; each rank holds a slice of the weight columns
// and the partial products are combined with AllReduce.
type RowParallelLinear struct {
    tp              *TensorParallelism
    weight          [][]float32 // [out_features, in_features/world_size]
    bias            []float32   // [out_features]
    inputIsParallel bool        // true when input already is this rank's slice
}

// Forward multiplies this rank's input slice with the local weight shard,
// all-reduces the partial result and adds the bias.
//
// The bias is added on every rank after the reduction. Adding it on rank 0
// only, as before, left the other ranks with a result that lacks the bias, so
// the ranks disagreed on the layer output. An empty bias is treated like a
// missing one.
func (rpl *RowParallelLinear) Forward(input []float32) []float32 {
    localInput := input
    if !rpl.inputIsParallel {
        // Take this rank's contiguous slice of the full input.
        localSize := len(input) / rpl.tp.worldSize
        start := rpl.tp.rank * localSize
        localInput = input[start : start+localSize]
    }

    localOutput := matmul(localInput, rpl.weight)

    output := rpl.tp.commBackend.AllReduce(localOutput)

    if biasLen := len(rpl.bias); biasLen > 0 {
        for i := range output {
            output[i] += rpl.bias[i%biasLen]
        }
    }

    return output
}

// PipelineParallelism runs micro-batches through an ordered list of stages.
type PipelineParallelism struct {
    stages       []PipelineStage
    numStages    int
    microBatches int
    commBackend  CommunicationBackend
}

// PipelineStage is one stage of the pipeline.
type PipelineStage interface {
    Forward(input []float32) []float32
    Backward(grad []float32) []float32
    GetRank() int
}

// Execute1F1B splits input into pp.microBatches equal micro-batches, runs the
// forward and backward pass of each one and returns the concatenated forward
// outputs in micro-batch order. It returns nil when pp.microBatches is not
// positive.
//
// Passes are launched in 1F1B order (warm-up forwards, then one forward plus
// one backward per step, then the remaining backwards), each in its own
// goroutine. The backward pass of a micro-batch waits for that micro-batch's
// forward pass, so it never reads an output that has not been written. The
// warm-up length is clamped to [0, microBatches]; previously an unclamped
// value produced negative micro-batch indices.
//
// NOTE(review): ctx is not consulted, and trailing input elements that do not
// fill a whole micro-batch are ignored, as in the original implementation.
func (pp *PipelineParallelism) Execute1F1B(ctx context.Context, input []float32) []float32 {
    numMicroBatches := pp.microBatches
    if numMicroBatches <= 0 {
        return nil
    }

    warmupSteps := pp.numStages - 1
    if warmupSteps < 0 {
        warmupSteps = 0
    }
    if warmupSteps > numMicroBatches {
        warmupSteps = numMicroBatches
    }
    steadySteps := numMicroBatches - warmupSteps

    // Split the input into micro-batches.
    microBatchSize := len(input) / numMicroBatches
    microBatches := make([][]float32, numMicroBatches)
    for i := range microBatches {
        start := i * microBatchSize
        microBatches[i] = input[start : start+microBatchSize]
    }

    outputs := make([][]float32, numMicroBatches)
    // forwardDone[i] is closed once outputs[i] has been written.
    forwardDone := make([]chan struct{}, numMicroBatches)
    for i := range forwardDone {
        forwardDone[i] = make(chan struct{})
    }

    var wg sync.WaitGroup
    startForward := func(idx int) {
        wg.Add(1)
        go func() {
            defer wg.Done()
            defer close(forwardDone[idx])
            outputs[idx] = pp.forwardMicroBatch(microBatches[idx], idx)
        }()
    }
    startBackward := func(idx int) {
        wg.Add(1)
        go func() {
            defer wg.Done()
            <-forwardDone[idx]
            pp.backwardMicroBatch(outputs[idx], idx)
        }()
    }

    // Warm-up phase: forward only.
    for i := 0; i < warmupSteps; i++ {
        startForward(i)
    }

    // Steady phase: one forward and one backward per step.
    for i := 0; i < steadySteps; i++ {
        startForward(warmupSteps + i)
        startBackward(i)
    }

    // Cool-down phase: the remaining backward passes.
    for i := steadySteps; i < numMicroBatches; i++ {
        startBackward(i)
    }

    wg.Wait()

    // Concatenate the outputs.
    var result []float32
    for _, out := range outputs {
        result = append(result, out...)
    }

    return result
}

// forwardMicroBatch feeds input through all stages in order and returns the
// output of the last stage. microBatchID is not used.
func (pp *PipelineParallelism) forwardMicroBatch(input []float32, microBatchID int) []float32 {
    current := input
    for _, stage := range pp.stages {
        current = stage.Forward(current)
    }
    return current
}

// backwardMicroBatch feeds grad through all stages in reverse order. The
// gradient produced by the first stage is discarded; microBatchID is not used.
func (pp *PipelineParallelism) backwardMicroBatch(grad []float32, microBatchID int) {
    for i := len(pp.stages) - 1; i >= 0; i-- {
        grad = pp.stages[i].Backward(grad)
    }
}

// matmul multiplies a with the transpose of b.
//
// b is a weight matrix of shape [out_features][in_features], matching the
// layout noted on the layers' weight fields. a is read as one or more
// row-major input rows of in_features values each; the result holds
// out_features values per input row, in the same row-major order.
//
// It returns nil when b is empty or ragged, or when len(a) is zero or not a
// multiple of in_features. The previous placeholder returned nil for every
// input.
func matmul(a []float32, b [][]float32) []float32 {
    outFeatures := len(b)
    if outFeatures == 0 {
        return nil
    }
    inFeatures := len(b[0])
    if inFeatures == 0 || len(a) == 0 || len(a)%inFeatures != 0 {
        return nil
    }
    for _, weights := range b {
        if len(weights) != inFeatures {
            return nil
        }
    }

    rows := len(a) / inFeatures
    out := make([]float32, rows*outFeatures)
    for r := 0; r < rows; r++ {
        x := a[r*inFeatures : (r+1)*inFeatures]
        for o, weights := range b {
            var acc float32
            for k, w := range weights {
                acc += x[k] * w
            }
            out[r*outFeatures+o] = acc
        }
    }
    return out
}

性能基准测试

基准测试框架

package benchmark

import (
    "context"
    "fmt"
    "math"
    "sort"
    "strings"
    "sync"
    "time"
)

// InferenceBenchmark drives an InferenceEngine through a matrix of prompt
// length / output length / concurrency combinations and aggregates the
// measured latencies and throughput.
type InferenceBenchmark struct {
    engine  InferenceEngine
    config  *BenchmarkConfig
    results *BenchmarkResults
    elapsed time.Duration // wall-clock time spent inside test cases during the last Run
}

// BenchmarkConfig describes the test matrix and its limits.
type BenchmarkConfig struct {
    // Test parameters.
    NumPrompts     int   // requests issued per test case
    PromptLength   []int // prompt lengths to test
    OutputLength   []int // output lengths to test
    ConcurrentReqs []int // concurrency levels to test

    // Warm-up.
    WarmupRequests int
    WarmupDuration time.Duration

    // Time limits. A value <= 0 disables the respective limit.
    TestDuration   time.Duration // per test case
    RequestTimeout time.Duration // per request
}

// BenchmarkResults holds the aggregated and the per-request results.
type BenchmarkResults struct {
    // Latency metrics.
    TTFT       LatencyStats // Time to First Token
    TPOT       LatencyStats // Time per Output Token
    E2ELatency LatencyStats // end-to-end latency

    // Throughput metrics.
    TokensPerSecond   float64
    RequestsPerSecond float64

    // Resource metrics.
    GPUUtilization float64
    MemoryUsage    int64
    PeakMemory     int64

    // Per-request results.
    DetailedResults []*TestCaseResult
}

// LatencyStats summarizes a set of latency samples.
type LatencyStats struct {
    Min    time.Duration
    Max    time.Duration
    Mean   time.Duration
    P50    time.Duration
    P90    time.Duration
    P95    time.Duration
    P99    time.Duration
    Stddev time.Duration
}

// TestCaseResult is the outcome of a single request.
type TestCaseResult struct {
    PromptLength    int
    OutputLength    int
    Concurrency     int
    TTFT            time.Duration
    TPOT            time.Duration
    TotalLatency    time.Duration
    TokensGenerated int
    Success         bool
    Error           error
}

// InferenceEngine is the system under test.
type InferenceEngine interface {
    Generate(ctx context.Context, prompt string, maxTokens int) (*GenerateResult, error)
}

// GenerateResult is what an InferenceEngine reports for one request.
type GenerateResult struct {
    Text         string
    Tokens       []int32
    TTFT         time.Duration
    TotalLatency time.Duration
    TokensPerSec float64
}

// Run executes the warm-up followed by every combination of the configured
// prompt lengths, output lengths and concurrency levels, then computes the
// aggregate metrics. It returns an error only when the warm-up fails.
func (b *InferenceBenchmark) Run(ctx context.Context) (*BenchmarkResults, error) {
    b.results = &BenchmarkResults{
        DetailedResults: make([]*TestCaseResult, 0),
    }
    b.elapsed = 0

    fmt.Println("Warming up...")
    if err := b.warmup(ctx); err != nil {
        return nil, err
    }

    for _, promptLen := range b.config.PromptLength {
        for _, outputLen := range b.config.OutputLength {
            for _, concurrency := range b.config.ConcurrentReqs {
                fmt.Printf("Testing: prompt=%d, output=%d, concurrency=%d\n",
                    promptLen, outputLen, concurrency)

                // Only time spent serving requests counts towards throughput.
                caseStart := time.Now()
                results := b.runTestCase(ctx, promptLen, outputLen, concurrency)
                b.elapsed += time.Since(caseStart)

                b.results.DetailedResults = append(b.results.DetailedResults, results...)
            }
        }
    }

    b.calculateAggregates()

    return b.results, nil
}

// warmup sends config.WarmupRequests requests (100-character prompt, 50
// tokens) and stops at the first error.
func (b *InferenceBenchmark) warmup(ctx context.Context) error {
    prompt := generatePrompt(100)

    for i := 0; i < b.config.WarmupRequests; i++ {
        if _, err := b.engine.Generate(ctx, prompt, 50); err != nil {
            return err
        }
    }

    return nil
}

// runTestCase issues config.NumPrompts requests spread over concurrency
// workers and returns one result per request that was started.
//
// The requests are distributed as evenly as possible, so none are dropped
// when NumPrompts is not a multiple of concurrency. A concurrency <= 0 yields
// no results instead of dividing by zero. Workers stop starting new requests
// once config.TestDuration has elapsed, if that limit is set.
func (b *InferenceBenchmark) runTestCase(
    ctx context.Context,
    promptLen, outputLen, concurrency int,
) []*TestCaseResult {
    if concurrency <= 0 {
        return nil
    }

    numRequests := b.config.NumPrompts
    if numRequests < 0 {
        numRequests = 0
    }
    perWorker, remainder := numRequests/concurrency, numRequests%concurrency

    results := make([]*TestCaseResult, 0, numRequests)
    var mu sync.Mutex
    var wg sync.WaitGroup

    prompt := generatePrompt(promptLen)
    limit := b.config.TestDuration
    startTime := time.Now()

    for w := 0; w < concurrency; w++ {
        quota := perWorker
        if w < remainder {
            quota++ // the first workers absorb the remainder
        }
        if quota == 0 {
            continue
        }

        wg.Add(1)
        go func(quota int) {
            defer wg.Done()

            for i := 0; i < quota; i++ {
                if limit > 0 && time.Since(startTime) > limit {
                    return
                }

                reqCtx, cancel := b.requestContext(ctx)
                result := b.runSingleRequest(reqCtx, prompt, outputLen)
                cancel()
                result.Concurrency = concurrency

                mu.Lock()
                results = append(results, result)
                mu.Unlock()
            }
        }(quota)
    }

    wg.Wait()

    return results
}

// requestContext derives the context of a single request. When
// config.RequestTimeout is not positive the parent context is used unchanged
// and the returned cancel function does nothing.
func (b *InferenceBenchmark) requestContext(ctx context.Context) (context.Context, context.CancelFunc) {
    if b.config.RequestTimeout <= 0 {
        return ctx, func() {}
    }
    return context.WithTimeout(ctx, b.config.RequestTimeout)
}

// runSingleRequest sends one request and records its timings. On failure the
// result carries the error and Success is false. TPOT is the latency after
// the first token divided by the number of generated tokens.
func (b *InferenceBenchmark) runSingleRequest(
    ctx context.Context,
    prompt string,
    maxTokens int,
) *TestCaseResult {

    result := &TestCaseResult{
        PromptLength: len(prompt),
        OutputLength: maxTokens,
    }

    startTime := time.Now()

    genResult, err := b.engine.Generate(ctx, prompt, maxTokens)
    if err == nil && genResult == nil {
        err = fmt.Errorf("engine returned no result")
    }
    if err != nil {
        result.Error = err
        result.Success = false
        return result
    }

    result.TTFT = genResult.TTFT
    result.TotalLatency = time.Since(startTime)
    result.TokensGenerated = len(genResult.Tokens)

    if result.TokensGenerated > 0 {
        result.TPOT = (result.TotalLatency - result.TTFT) / time.Duration(result.TokensGenerated)
    }

    result.Success = true
    return result
}

// calculateAggregates fills the latency statistics and the throughput from
// the successful requests in DetailedResults.
//
// Throughput is computed over the wall-clock time Run spent inside test
// cases. (It used to be divided by the latency of whichever request happened
// to be recorded last.) It stays zero when no wall-clock time was recorded.
func (b *InferenceBenchmark) calculateAggregates() {
    var ttfts, tpots, e2eLatencies []time.Duration
    totalTokens, successCount := 0, 0

    for _, r := range b.results.DetailedResults {
        if !r.Success {
            continue
        }
        ttfts = append(ttfts, r.TTFT)
        tpots = append(tpots, r.TPOT)
        e2eLatencies = append(e2eLatencies, r.TotalLatency)
        totalTokens += r.TokensGenerated
        successCount++
    }

    if successCount == 0 {
        return
    }

    b.results.TTFT = calculateLatencyStats(ttfts)
    b.results.TPOT = calculateLatencyStats(tpots)
    b.results.E2ELatency = calculateLatencyStats(e2eLatencies)

    if b.elapsed > 0 {
        seconds := b.elapsed.Seconds()
        b.results.TokensPerSecond = float64(totalTokens) / seconds
        b.results.RequestsPerSecond = float64(successCount) / seconds
    }
}

// calculateLatencyStats returns min, max, mean, the P50/P90/P95/P99
// percentiles and the population standard deviation of latencies. The input
// slice is not modified; an empty input yields the zero LatencyStats.
func calculateLatencyStats(latencies []time.Duration) LatencyStats {
    if len(latencies) == 0 {
        return LatencyStats{}
    }

    // Sort a copy so the caller's slice keeps its order.
    sorted := make([]time.Duration, len(latencies))
    copy(sorted, latencies)
    sortDurations(sorted)

    var sum time.Duration
    for _, l := range sorted {
        sum += l
    }
    mean := sum / time.Duration(len(sorted))

    // Accumulate the squared deviations in float64 so they do not overflow
    // the nanosecond integer range.
    var squares float64
    for _, l := range sorted {
        d := float64(l - mean)
        squares += d * d
    }
    stddev := time.Duration(math.Sqrt(squares / float64(len(sorted))))

    return LatencyStats{
        Min:    sorted[0],
        Max:    sorted[len(sorted)-1],
        Mean:   mean,
        P50:    percentile(sorted, 50),
        P90:    percentile(sorted, 90),
        P95:    percentile(sorted, 95),
        P99:    percentile(sorted, 99),
        Stddev: stddev,
    }
}

// percentile returns the p-th percentile of sorted using the nearest-rank
// method: the value at rank ceil(p/100 * n), clamped to the valid range.
// sorted must be in ascending order. An empty slice yields 0.
//
// The previous index n*p/100 was one rank too high (P99 of 100 samples
// returned the maximum) and panicked on an empty slice.
func percentile(sorted []time.Duration, p int) time.Duration {
    n := len(sorted)
    if n == 0 {
        return 0
    }

    // ceil(n*p/100) converted to a zero-based index.
    idx := (n*p+99)/100 - 1
    if idx < 0 {
        idx = 0
    }
    if idx >= n {
        idx = n - 1
    }
    return sorted[idx]
}

// sortDurations sorts d in place in ascending order.
func sortDurations(d []time.Duration) {
    sort.Slice(d, func(i, j int) bool { return d[i] < d[j] })
}

// generatePrompt returns a test prompt of exactly length bytes, built by
// repeating a fixed pangram sentence and cutting it to size. A length <= 0
// yields the empty string.
func generatePrompt(length int) string {
    const text = "The quick brown fox jumps over the lazy dog. "
    if length <= 0 {
        return ""
    }

    // Repeat just often enough to cover length, then trim.
    repeats := (length + len(text) - 1) / len(text)
    return strings.Repeat(text, repeats)[:length]
}

// GenerateReport renders the benchmark results as a plain-text report with
// latency, throughput and resource sections. GPU utilization is printed as a
// percentage and the memory figures are divided by 1024*1024.
//
// When no results are available yet (Run has not been called) the report is
// rendered from zero values instead of dereferencing a nil pointer.
func (b *InferenceBenchmark) GenerateReport() string {
    r := b.results
    if r == nil {
        r = &BenchmarkResults{}
    }

    return fmt.Sprintf(`
=== Inference Benchmark Report ===

Latency Metrics:
  TTFT (Time to First Token):
    Min: %v, Max: %v, Mean: %v
    P50: %v, P90: %v, P95: %v, P99: %v

  TPOT (Time per Output Token):
    Min: %v, Max: %v, Mean: %v
    P50: %v, P90: %v, P95: %v, P99: %v

  E2E Latency:
    Min: %v, Max: %v, Mean: %v
    P50: %v, P90: %v, P95: %v, P99: %v

Throughput Metrics:
  Tokens/sec: %.2f
  Requests/sec: %.2f

Resource Metrics:
  GPU Utilization: %.2f%%
  Memory Usage: %d MB
  Peak Memory: %d MB
`,
        r.TTFT.Min, r.TTFT.Max, r.TTFT.Mean,
        r.TTFT.P50, r.TTFT.P90, r.TTFT.P95, r.TTFT.P99,
        r.TPOT.Min, r.TPOT.Max, r.TPOT.Mean,
        r.TPOT.P50, r.TPOT.P90, r.TPOT.P95, r.TPOT.P99,
        r.E2ELatency.Min, r.E2ELatency.Max, r.E2ELatency.Mean,
        r.E2ELatency.P50, r.E2ELatency.P90, r.E2ELatency.P95, r.E2ELatency.P99,
        r.TokensPerSecond, r.RequestsPerSecond,
        r.GPUUtilization*100, r.MemoryUsage/1024/1024, r.PeakMemory/1024/1024,
    )
}

小结

本章深入探讨了大模型推理优化的核心技术:

  1. 模型量化:INT8/INT4 量化、GPTQ、AWQ 等后训练量化方法
  2. 计算图优化:算子融合、内存规划、原地优化
  3. KV Cache 优化:PagedAttention、Prefix Caching、Copy-on-Write
  4. 系统级优化:CUDA 优化、多流执行、张量并行、流水线并行
  5. 性能基准测试:TTFT、TPOT、吞吐量等关键指标的测量方法

推理优化是一个需要从多个层面综合考虑的系统工程。通过合理应用这些优化技术,可以在保持模型精度的前提下,显著提升推理效率,降低服务成本。

下一章我们将探讨 多模型服务,讲解如何在同一集群中高效管理和调度多个大模型。

Prev
动态批处理
Next
多模型服务