动态批处理

概述

动态批处理(Dynamic Batching)是提升大模型推理吞吐量的核心技术。与传统的静态批处理不同,动态批处理能够在运行时根据请求特性自适应地组织批次,充分利用 GPU 并行计算能力。本章深入讲解动态批处理的原理、实现策略以及在生产环境中的优化实践。

批处理基础

为什么需要批处理

GPU 的设计特点决定了批处理的必要性:

单请求处理 vs 批处理对比:

单请求:
┌─────────────────────────────────────────────────────────────┐
│ GPU 利用率: 10-20%                                           │
│ ┌─────┐                                                     │
│ │Req 1│────────────────────────────────────────────────────►│
│ └─────┘                                                     │
│ 大量 GPU 核心空闲                                            │
└─────────────────────────────────────────────────────────────┘

批处理 (batch_size=8):
┌─────────────────────────────────────────────────────────────┐
│ GPU 利用率: 70-90%                                           │
│ ┌─────┐                                                     │
│ │Req 1│──┐                                                  │
│ ├─────┤  │                                                  │
│ │Req 2│  ├──────────────────────────────────────────────►   │
│ ├─────┤  │                                                  │
│ │ ... │  │                                                  │
│ ├─────┤  │                                                  │
│ │Req 8│──┘                                                  │
│ └─────┘                                                     │
│ GPU 核心充分利用                                              │
└─────────────────────────────────────────────────────────────┘
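
decode 阶段通常受显存带宽限制:每生成一个 token 都要完整读取一遍模型权重,而批内的多个请求可以共享这一次读取。下面是一个极简的吞吐估算示意(假设完全带宽受限,忽略 KV Cache 读取、算力瓶颈与调度开销,权重大小与带宽均为假设值):

package main

import "fmt"

func main() {
    const (
        weightBytes  = 14e9 // 假设:7B 模型 fp16 权重约 14 GB
        bandwidthBps = 2e12 // 假设:GPU 显存带宽约 2 TB/s
    )

    // 带宽受限时,每次 decode 迭代至少要把权重完整读取一遍,批内请求共享这次读取
    stepSeconds := weightBytes / bandwidthBps

    for _, batchSize := range []int{1, 8, 32, 128} {
        throughput := float64(batchSize) / stepSeconds // tokens/s
        fmt.Printf("batch=%3d  约 %.0f tokens/s\n", batchSize, throughput)
    }
}

在这个带宽受限的简化模型下,吞吐量近似随 batch size 线性增长(直到转为算力受限),这正是动态批处理追求更大有效批次的原因。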

静态批处理的局限

package batching

import (
    "sync"
    "time"
)

// StaticBatcher 静态批处理器
// 问题:等待固定数量请求,延迟不可控
type StaticBatcher struct {
    batchSize    int
    requests     []*InferenceRequest
    mu           sync.Mutex
    batchReady   chan []*InferenceRequest
}

type InferenceRequest struct {
    ID           string
    Prompt       string
    InputTokens  []int32
    MaxNewTokens int
    Temperature  float32
    TopP         float32
    ResultChan   chan *InferenceResult
    ArrivalTime  time.Time
}

type InferenceResult struct {
    RequestID    string
    OutputTokens []int32
    Text         string
    FinishReason string
    Latency      time.Duration
    Error        error
}

func NewStaticBatcher(batchSize int) *StaticBatcher {
    return &StaticBatcher{
        batchSize:  batchSize,
        requests:   make([]*InferenceRequest, 0, batchSize),
        batchReady: make(chan []*InferenceRequest, 10),
    }
}

func (b *StaticBatcher) Submit(req *InferenceRequest) {
    b.mu.Lock()
    defer b.mu.Unlock()

    req.ArrivalTime = time.Now()
    b.requests = append(b.requests, req)

    // 问题1: 必须等待足够请求
    if len(b.requests) >= b.batchSize {
        batch := make([]*InferenceRequest, len(b.requests))
        copy(batch, b.requests)
        b.requests = b.requests[:0]
        b.batchReady <- batch
    }
    // 问题2: 低流量时请求可能永远等待
}

// 静态批处理的问题:
// 1. 序列长度不一致:短序列需要等待长序列完成
// 2. 延迟不可预测:取决于批次中最长序列
// 3. 资源利用不均:padding 浪费计算资源
// 4. 无法动态调整:批次大小固定
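
下面用一个小例子量化第 3 点的 padding 浪费:静态批处理需要把批内序列补齐到最长序列,短序列上的 padding 位置全部是无效计算(序列长度为假设的示例值):

package main

import "fmt"

func main() {
    // 假设一个静态批次内各请求的实际序列长度
    seqLens := []int{32, 64, 512, 48, 1024, 96, 128, 256}

    maxLen, total := 0, 0
    for _, l := range seqLens {
        total += l
        if l > maxLen {
            maxLen = l
        }
    }

    padded := maxLen * len(seqLens) // 补齐到最长序列后的总 token 数
    waste := 1 - float64(total)/float64(padded)
    fmt.Printf("有效 token: %d, 补齐后: %d, 浪费比例: %.1f%%\n", total, padded, waste*100)
}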

动态批处理架构

连续批处理 (Continuous Batching)

连续批处理是现代推理引擎的核心技术,允许请求在任意 token 位置加入或离开批次:

package batching

import (
    "container/heap"
    "context"
    "errors"
    "sync"
    "sync/atomic"
    "time"
)

// ContinuousBatcher 连续批处理器
type ContinuousBatcher struct {
    config          *BatcherConfig
    waitingQueue    *PriorityQueue      // 等待队列
    runningBatch    *RunningBatch       // 运行中的批次
    scheduler       *BatchScheduler     // 调度器
    engine          InferenceEngine     // 推理引擎

    mu              sync.RWMutex
    stats           *BatcherStats
    stopCh          chan struct{}
}

type BatcherConfig struct {
    MaxBatchSize       int           // 最大批次大小
    MaxWaitingRequests int           // 最大等待请求数
    MaxTokensPerBatch  int           // 批次最大 token 数
    MaxWaitTime        time.Duration // 最大等待时间
    PrefillChunkSize   int           // Prefill 分块大小
    SchedulingPolicy   string        // 调度策略: fcfs, sjf, priority
}

type RunningBatch struct {
    requests    map[string]*ActiveRequest
    totalTokens int64
    mu          sync.RWMutex
}

// Size 返回当前批次中的请求数
func (rb *RunningBatch) Size() int {
    rb.mu.RLock()
    defer rb.mu.RUnlock()
    return len(rb.requests)
}

// BatcherStats 批处理统计(原子更新)
type BatcherStats struct {
    TotalRequests     int64
    ScheduledRequests int64
    CompletedRequests int64
}

// ErrQueueFull 等待队列已满
var ErrQueueFull = errors.New("waiting queue is full")

type ActiveRequest struct {
    *InferenceRequest
    State           RequestState
    GeneratedTokens []int32
    KVCacheBlocks   []int          // KV Cache 块索引
    CurrentStep     int
    StartTime       time.Time
    PrefillDone     bool
}

type RequestState int

const (
    StateWaiting RequestState = iota
    StatePrefill
    StateDecoding
    StateFinished
)

func NewContinuousBatcher(config *BatcherConfig, engine InferenceEngine) *ContinuousBatcher {
    cb := &ContinuousBatcher{
        config:       config,
        waitingQueue: NewPriorityQueue(),
        runningBatch: &RunningBatch{
            requests: make(map[string]*ActiveRequest),
        },
        scheduler: NewBatchScheduler(config),
        engine:    engine,
        stats:     &BatcherStats{},
        stopCh:    make(chan struct{}),
    }
    return cb
}

// Submit 提交推理请求
func (cb *ContinuousBatcher) Submit(ctx context.Context, req *InferenceRequest) (*InferenceResult, error) {
    req.ArrivalTime = time.Now()
    req.ResultChan = make(chan *InferenceResult, 1)

    // 检查队列容量
    cb.mu.Lock()
    if cb.waitingQueue.Len() >= cb.config.MaxWaitingRequests {
        cb.mu.Unlock()
        return nil, ErrQueueFull
    }

    // 加入等待队列
    heap.Push(cb.waitingQueue, &QueueItem{
        Request:  req,
        Priority: cb.calculatePriority(req),
    })
    cb.mu.Unlock()

    atomic.AddInt64(&cb.stats.TotalRequests, 1)

    // 等待结果
    select {
    case result := <-req.ResultChan:
        return result, nil
    case <-ctx.Done():
        cb.cancelRequest(req.ID)
        return nil, ctx.Err()
    }
}

// Run 运行批处理循环
func (cb *ContinuousBatcher) Run(ctx context.Context) error {
    ticker := time.NewTicker(time.Millisecond) // 1ms 调度间隔
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-cb.stopCh:
            return nil
        case <-ticker.C:
            cb.step()
        }
    }
}

// step 执行一个调度步骤
func (cb *ContinuousBatcher) step() {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    // 1. 移除已完成的请求
    cb.evictFinishedRequests()

    // 2. 调度新请求加入批次
    cb.scheduleNewRequests()

    // 3. 如果批次非空,执行一步推理
    if cb.runningBatch.Size() > 0 {
        cb.executeStep()
    }
}

// evictFinishedRequests 移除已完成的请求
func (cb *ContinuousBatcher) evictFinishedRequests() {
    cb.runningBatch.mu.Lock()
    defer cb.runningBatch.mu.Unlock()

    for id, req := range cb.runningBatch.requests {
        if req.State == StateFinished {
            // 释放 KV Cache
            cb.engine.ReleaseKVCache(req.KVCacheBlocks)

            // 发送结果
            result := &InferenceResult{
                RequestID:    id,
                OutputTokens: req.GeneratedTokens,
                Text:         cb.engine.Decode(req.GeneratedTokens),
                FinishReason: cb.getFinishReason(req),
                Latency:      time.Since(req.StartTime),
            }
            req.ResultChan <- result

            delete(cb.runningBatch.requests, id)
            // 归还该请求在批次 token 预算中的占用
            cb.runningBatch.totalTokens -= int64(len(req.InputTokens) + req.MaxNewTokens)
            atomic.AddInt64(&cb.stats.CompletedRequests, 1)
        }
    }
}

// scheduleNewRequests 调度新请求
func (cb *ContinuousBatcher) scheduleNewRequests() {
    for cb.waitingQueue.Len() > 0 {
        // 检查批次容量
        if cb.runningBatch.Size() >= cb.config.MaxBatchSize {
            break
        }

        // 检查 token 容量
        item := cb.waitingQueue.Peek()
        requiredTokens := len(item.Request.InputTokens) + item.Request.MaxNewTokens
        if cb.runningBatch.totalTokens+int64(requiredTokens) > int64(cb.config.MaxTokensPerBatch) {
            break
        }

        // 分配 KV Cache
        blocks, err := cb.engine.AllocateKVCache(requiredTokens)
        if err != nil {
            break // KV Cache 不足
        }

        // 从等待队列取出
        heap.Pop(cb.waitingQueue)

        // 创建活跃请求
        activeReq := &ActiveRequest{
            InferenceRequest: item.Request,
            State:           StatePrefill,
            GeneratedTokens: make([]int32, 0, item.Request.MaxNewTokens),
            KVCacheBlocks:   blocks,
            StartTime:       time.Now(),
        }

        cb.runningBatch.mu.Lock()
        cb.runningBatch.requests[item.Request.ID] = activeReq
        cb.runningBatch.totalTokens += int64(requiredTokens)
        cb.runningBatch.mu.Unlock()

        atomic.AddInt64(&cb.stats.ScheduledRequests, 1)
    }
}

// executeStep 执行一步推理
func (cb *ContinuousBatcher) executeStep() {
    cb.runningBatch.mu.Lock()
    defer cb.runningBatch.mu.Unlock()

    // 分离 prefill 和 decode 请求
    prefillReqs := make([]*ActiveRequest, 0)
    decodeReqs := make([]*ActiveRequest, 0)

    for _, req := range cb.runningBatch.requests {
        if req.State == StatePrefill {
            prefillReqs = append(prefillReqs, req)
        } else if req.State == StateDecoding {
            decodeReqs = append(decodeReqs, req)
        }
    }

    // 执行 Prefill(分块处理)
    for _, req := range prefillReqs {
        cb.executePrefill(req)
    }

    // 执行 Decode
    if len(decodeReqs) > 0 {
        cb.executeDecode(decodeReqs)
    }
}

// executePrefill 执行 Prefill 阶段
func (cb *ContinuousBatcher) executePrefill(req *ActiveRequest) {
    inputLen := len(req.InputTokens)
    chunkSize := cb.config.PrefillChunkSize

    // 分块 Prefill,避免长序列阻塞
    start := req.CurrentStep
    end := min(start+chunkSize, inputLen)

    chunk := req.InputTokens[start:end]
    cb.engine.Prefill(req.KVCacheBlocks, chunk, start)

    req.CurrentStep = end

    if req.CurrentStep >= inputLen {
        req.PrefillDone = true
        req.State = StateDecoding
    }
}

// executeDecode 执行 Decode 阶段
func (cb *ContinuousBatcher) executeDecode(reqs []*ActiveRequest) {
    // 构建批处理输入
    batchInput := &BatchDecodeInput{
        RequestIDs:    make([]string, len(reqs)),
        LastTokens:    make([]int32, len(reqs)),
        KVCacheBlocks: make([][]int, len(reqs)),
        Positions:     make([]int, len(reqs)),
    }

    for i, req := range reqs {
        batchInput.RequestIDs[i] = req.ID
        if len(req.GeneratedTokens) == 0 {
            batchInput.LastTokens[i] = req.InputTokens[len(req.InputTokens)-1]
        } else {
            batchInput.LastTokens[i] = req.GeneratedTokens[len(req.GeneratedTokens)-1]
        }
        batchInput.KVCacheBlocks[i] = req.KVCacheBlocks
        batchInput.Positions[i] = len(req.InputTokens) + len(req.GeneratedTokens)
    }

    // 批量解码
    outputs := cb.engine.BatchDecode(batchInput)

    // 处理输出
    for i, req := range reqs {
        token := outputs.NextTokens[i]
        req.GeneratedTokens = append(req.GeneratedTokens, token)

        // 检查是否完成
        if cb.isFinished(req, token) {
            req.State = StateFinished
        }
    }
}

func (cb *ContinuousBatcher) isFinished(req *ActiveRequest, token int32) bool {
    // EOS token
    if token == cb.engine.GetEOSToken() {
        return true
    }
    // 达到最大长度
    if len(req.GeneratedTokens) >= req.MaxNewTokens {
        return true
    }
    return false
}
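
下面给出 ContinuousBatcher 的一个最小使用示意(假设 myEngine 是实现了上文 InferenceEngine 接口的某个推理引擎,token ID 与各项参数均为演示值,并非完整实现):

package batching

import (
    "context"
    "log"
    "time"
)

// RunBatcherExample 使用示意:假设 myEngine 实现了上文的 InferenceEngine 接口
func RunBatcherExample(ctx context.Context, myEngine InferenceEngine) {
    cb := NewContinuousBatcher(&BatcherConfig{
        MaxBatchSize:       256,
        MaxWaitingRequests: 1000,
        MaxTokensPerBatch:  8192,
        MaxWaitTime:        100 * time.Millisecond,
        PrefillChunkSize:   512,
        SchedulingPolicy:   "priority",
    }, myEngine)

    // 后台运行调度循环
    go cb.Run(ctx)

    // 提交请求并同步等待结果
    result, err := cb.Submit(ctx, &InferenceRequest{
        ID:           "req-1",
        Prompt:       "介绍一下动态批处理",
        InputTokens:  []int32{101, 2769, 3221, 102}, // 实际应由 tokenizer 生成,这里仅示意
        MaxNewTokens: 128,
        Temperature:  0.7,
        TopP:         0.9,
    })
    if err != nil {
        log.Printf("inference failed: %v", err)
        return
    }
    log.Printf("output: %s", result.Text)
}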

优先级队列实现

package batching

import "container/heap"

// QueueItem 队列项
type QueueItem struct {
    Request  *InferenceRequest
    Priority float64
    Index    int
}

// PriorityQueue 优先级队列
type PriorityQueue []*QueueItem

func NewPriorityQueue() *PriorityQueue {
    pq := make(PriorityQueue, 0)
    heap.Init(&pq)
    return &pq
}

func (pq PriorityQueue) Len() int { return len(pq) }

func (pq PriorityQueue) Less(i, j int) bool {
    // 高优先级在前
    return pq[i].Priority > pq[j].Priority
}

func (pq PriorityQueue) Swap(i, j int) {
    pq[i], pq[j] = pq[j], pq[i]
    pq[i].Index = i
    pq[j].Index = j
}

func (pq *PriorityQueue) Push(x interface{}) {
    n := len(*pq)
    item := x.(*QueueItem)
    item.Index = n
    *pq = append(*pq, item)
}

func (pq *PriorityQueue) Pop() interface{} {
    old := *pq
    n := len(old)
    item := old[n-1]
    old[n-1] = nil
    item.Index = -1
    *pq = old[0 : n-1]
    return item
}

func (pq *PriorityQueue) Peek() *QueueItem {
    if len(*pq) == 0 {
        return nil
    }
    return (*pq)[0]
}

// calculatePriority 计算请求优先级
func (cb *ContinuousBatcher) calculatePriority(req *InferenceRequest) float64 {
    switch cb.config.SchedulingPolicy {
    case "fcfs":
        // 先来先服务:按到达时间排序
        return -float64(req.ArrivalTime.UnixNano())

    case "sjf":
        // 短作业优先:输入+输出 token 数少的优先
        totalTokens := len(req.InputTokens) + req.MaxNewTokens
        return -float64(totalTokens)

    case "priority":
        // 优先级调度:结合多个因素
        return cb.calculateCompositePriority(req)

    default:
        return -float64(req.ArrivalTime.UnixNano())
    }
}

func (cb *ContinuousBatcher) calculateCompositePriority(req *InferenceRequest) float64 {
    // 基于多因素的优先级计算
    // 1. 等待时间因子
    waitTime := time.Since(req.ArrivalTime).Seconds()
    waitFactor := waitTime * 0.1 // 每秒增加 0.1 优先级

    // 2. 序列长度因子(短序列优先)
    totalTokens := float64(len(req.InputTokens) + req.MaxNewTokens)
    lengthFactor := 1000.0 / totalTokens

    // 3. 用户优先级(如有)
    userPriority := 1.0 // 默认优先级

    return waitFactor + lengthFactor + userPriority
}

迭代级调度 (Iteration-Level Scheduling)

细粒度调度策略

迭代级调度允许在每个 decode 迭代中动态调整批次组成:

package batching

import (
    "sort"
    "time"
)

// IterationScheduler 迭代级调度器
type IterationScheduler struct {
    config        *SchedulerConfig
    preemptable   bool              // 是否支持抢占
    swapManager   *SwapManager      // 交换管理器
}

type SchedulerConfig struct {
    MaxBatchTokens   int     // 每批次最大 token
    MaxPrefillTokens int     // Prefill 最大 token
    MaxNumSeqs       int     // 最大并发序列
    BlockSize        int     // KV Cache 块大小(token 数)
    PreemptionMode   string  // 抢占模式: recompute, swap
    SwapSpaceGB      float64 // 交换空间大小
}

// SchedulerOutput 调度输出
type SchedulerOutput struct {
    ScheduledSeqs    []*SequenceGroup  // 调度的序列组
    PreemptedSeqs    []*SequenceGroup  // 被抢占的序列组
    BlocksToSwapIn   map[int]int       // 需要换入的块
    BlocksToSwapOut  map[int]int       // 需要换出的块
    BlocksToCopy     map[int]int       // 需要复制的块
}

// SequenceGroup 序列组(支持 beam search)
type SequenceGroup struct {
    RequestID     string
    Sequences     []*Sequence
    SamplingParams *SamplingParams
    ArrivalTime   time.Time
    State         SequenceGroupState
}

type SequenceGroupState int

const (
    SGStateWaiting SequenceGroupState = iota
    SGStateRunning
    SGStateSwapped
    SGStateFinished
)

type Sequence struct {
    SeqID         int
    TokenIDs      []int32
    LogicalBlocks []int      // 逻辑块索引
    OutputLen     int
    Status        SequenceStatus
}

type SequenceStatus int

const (
    SeqWaiting SequenceStatus = iota
    SeqRunning
    SeqFinishedStopped
    SeqFinishedLengthCapped
    SeqFinishedAborted
)

// Schedule 执行一次调度
func (s *IterationScheduler) Schedule(
    waitingQueue []*SequenceGroup,
    runningQueue []*SequenceGroup,
    swappedQueue []*SequenceGroup,
    blockManager *BlockManager,
) *SchedulerOutput {

    output := &SchedulerOutput{
        ScheduledSeqs:   make([]*SequenceGroup, 0),
        PreemptedSeqs:   make([]*SequenceGroup, 0),
        BlocksToSwapIn:  make(map[int]int),
        BlocksToSwapOut: make(map[int]int),
        BlocksToCopy:    make(map[int]int),
    }

    // 阶段1: 处理正在运行的序列
    s.scheduleRunning(runningQueue, blockManager, output)

    // 阶段2: 处理被交换的序列
    s.scheduleSwapped(swappedQueue, blockManager, output)

    // 阶段3: 处理等待队列
    s.scheduleWaiting(waitingQueue, blockManager, output)

    return output
}

// scheduleRunning 调度运行中的序列
func (s *IterationScheduler) scheduleRunning(
    runningQueue []*SequenceGroup,
    blockManager *BlockManager,
    output *SchedulerOutput,
) {
    // 按到达时间排序(FCFS)
    sort.Slice(runningQueue, func(i, j int) bool {
        return runningQueue[i].ArrivalTime.Before(runningQueue[j].ArrivalTime)
    })

    budgetTokens := s.config.MaxBatchTokens
    budgetSeqs := s.config.MaxNumSeqs

    for _, seqGroup := range runningQueue {
        if seqGroup.State == SGStateFinished {
            continue
        }

        // 计算需要的资源
        numTokens := s.getNumTokens(seqGroup)
        numSeqs := len(seqGroup.Sequences)

        // 检查资源是否足够
        if numTokens > budgetTokens || numSeqs > budgetSeqs {
            // 需要抢占
            if s.preemptable {
                s.preempt(seqGroup, blockManager, output)
            }
            continue
        }

        // 尝试为新 token 分配块
        if !s.allocateNewBlocks(seqGroup, blockManager) {
            // 分配失败,需要抢占其他序列
            if s.preemptable {
                s.preemptLowerPriority(seqGroup, runningQueue, blockManager, output)
            }
            continue
        }

        // 调度成功
        output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
        budgetTokens -= numTokens
        budgetSeqs -= numSeqs
    }
}

// scheduleSwapped 调度被交换的序列
func (s *IterationScheduler) scheduleSwapped(
    swappedQueue []*SequenceGroup,
    blockManager *BlockManager,
    output *SchedulerOutput,
) {
    // 尝试换入被交换的序列
    for _, seqGroup := range swappedQueue {
        // 计算需要的 GPU 块
        requiredBlocks := s.getRequiredBlocks(seqGroup)

        // 检查是否有足够的空闲块
        if !blockManager.CanAllocate(requiredBlocks) {
            continue
        }

        // 执行换入
        swapIn := s.swapIn(seqGroup, blockManager)
        for src, dst := range swapIn {
            output.BlocksToSwapIn[src] = dst
        }

        seqGroup.State = SGStateRunning
        output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
    }
}

// scheduleWaiting 调度等待队列
func (s *IterationScheduler) scheduleWaiting(
    waitingQueue []*SequenceGroup,
    blockManager *BlockManager,
    output *SchedulerOutput,
) {
    budgetTokens := s.config.MaxPrefillTokens

    for _, seqGroup := range waitingQueue {
        // 计算 prefill 需要的 token 数
        prefillTokens := s.getPrefillTokens(seqGroup)

        if prefillTokens > budgetTokens {
            // 超出 prefill 预算
            break
        }

        // 尝试分配初始块
        requiredBlocks := (prefillTokens + s.config.BlockSize - 1) / s.config.BlockSize
        if !blockManager.CanAllocate(requiredBlocks) {
            // 块不足,尝试抢占
            if s.preemptable && s.preemptForPrefill(seqGroup, blockManager, output) {
                // 抢占成功,重试分配
                if !blockManager.CanAllocate(requiredBlocks) {
                    continue
                }
            } else {
                continue
            }
        }

        // 分配块
        blocks := blockManager.Allocate(requiredBlocks)
        s.assignBlocks(seqGroup, blocks)

        seqGroup.State = SGStateRunning
        output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
        budgetTokens -= prefillTokens
    }
}

// preempt 抢占序列
func (s *IterationScheduler) preempt(
    seqGroup *SequenceGroup,
    blockManager *BlockManager,
    output *SchedulerOutput,
) {
    switch s.config.PreemptionMode {
    case "recompute":
        // 重计算模式:释放所有块,稍后重新 prefill
        for _, seq := range seqGroup.Sequences {
            blockManager.Free(seq.LogicalBlocks)
            seq.LogicalBlocks = nil
        }
        seqGroup.State = SGStateWaiting

    case "swap":
        // 交换模式:将块换出到 CPU
        swapOut := s.swapOut(seqGroup, blockManager)
        for src, dst := range swapOut {
            output.BlocksToSwapOut[src] = dst
        }
        seqGroup.State = SGStateSwapped
    }

    output.PreemptedSeqs = append(output.PreemptedSeqs, seqGroup)
}

块管理器

package batching

import (
    "errors"
    "sync"
)

// BlockManager GPU 内存块管理器
type BlockManager struct {
    blockSize     int           // 每块的 token 数
    numGPUBlocks  int           // GPU 块总数
    numCPUBlocks  int           // CPU 块总数

    gpuAllocator  *BlockAllocator
    cpuAllocator  *BlockAllocator

    // 块映射:逻辑块 -> 物理块
    blockTables   map[int]map[int]int // seqID -> logicalBlock -> physicalBlock
    refCounts     map[int]int         // physicalBlock -> refCount

    mu sync.RWMutex
}

type BlockAllocator struct {
    freeBlocks  []int
    usedBlocks  map[int]bool
    totalBlocks int
    mu          sync.Mutex
}

func NewBlockManager(config *BlockManagerConfig) *BlockManager {
    // 每个 KV Cache 块的字节数(简化估算:2 表示 K/V 两份,2 表示 fp16 字节数;
    // 严格计算还应乘以 KV head 数,这里沿用简化公式)
    bytesPerBlock := float64(config.BlockSize * config.NumLayers * config.HeadDim * 2 * 2)
    gpuBlocks := config.GPUMemoryGB * 1024 * 1024 * 1024 / bytesPerBlock
    cpuBlocks := config.CPUSwapSpaceGB * 1024 * 1024 * 1024 / bytesPerBlock

    return &BlockManager{
        blockSize:    config.BlockSize,
        numGPUBlocks: int(gpuBlocks),
        numCPUBlocks: int(cpuBlocks),
        gpuAllocator: NewBlockAllocator(int(gpuBlocks)),
        cpuAllocator: NewBlockAllocator(int(cpuBlocks)),
        blockTables:  make(map[int]map[int]int),
        refCounts:    make(map[int]int),
    }
}

func NewBlockAllocator(numBlocks int) *BlockAllocator {
    freeBlocks := make([]int, numBlocks)
    for i := 0; i < numBlocks; i++ {
        freeBlocks[i] = i
    }
    return &BlockAllocator{
        freeBlocks:  freeBlocks,
        usedBlocks:  make(map[int]bool),
        totalBlocks: numBlocks,
    }
}

// CanAllocate 检查是否可以分配
func (bm *BlockManager) CanAllocate(numBlocks int) bool {
    bm.mu.RLock()
    defer bm.mu.RUnlock()
    return len(bm.gpuAllocator.freeBlocks) >= numBlocks
}

// Allocate 分配块
func (bm *BlockManager) Allocate(numBlocks int) []int {
    bm.mu.Lock()
    defer bm.mu.Unlock()

    if len(bm.gpuAllocator.freeBlocks) < numBlocks {
        return nil
    }

    blocks := make([]int, numBlocks)
    for i := 0; i < numBlocks; i++ {
        block := bm.gpuAllocator.freeBlocks[0]
        bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
        bm.gpuAllocator.usedBlocks[block] = true
        bm.refCounts[block] = 1
        blocks[i] = block
    }
    return blocks
}

// Free 释放块
func (bm *BlockManager) Free(blocks []int) {
    bm.mu.Lock()
    defer bm.mu.Unlock()

    for _, block := range blocks {
        bm.refCounts[block]--
        if bm.refCounts[block] <= 0 {
            delete(bm.gpuAllocator.usedBlocks, block)
            bm.gpuAllocator.freeBlocks = append(bm.gpuAllocator.freeBlocks, block)
            delete(bm.refCounts, block)
        }
    }
}

// AppendSlot 追加一个 slot(用于 decode)
func (bm *BlockManager) AppendSlot(seqID int) (int, int, error) {
    bm.mu.Lock()
    defer bm.mu.Unlock()

    blockTable, ok := bm.blockTables[seqID]
    if !ok {
        return -1, -1, errors.New("sequence not found")
    }

    // 计算当前逻辑块数量
    numLogicalBlocks := len(blockTable)

    // 检查最后一个块是否有空间
    // 这里简化处理(未跟踪块内已用 slot 数),直接分配新块;
    // 实际实现应优先复用最后一个未写满的物理块

    // 如果需要新块
    if len(bm.gpuAllocator.freeBlocks) == 0 {
        return -1, -1, errors.New("no free blocks")
    }

    newBlock := bm.gpuAllocator.freeBlocks[0]
    bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
    bm.gpuAllocator.usedBlocks[newBlock] = true
    bm.refCounts[newBlock] = 1

    blockTable[numLogicalBlocks] = newBlock

    return newBlock, 0, nil // 返回块索引和块内偏移
}

// ForkSequence 分叉序列(用于 beam search)
func (bm *BlockManager) ForkSequence(parentSeqID, childSeqID int) error {
    bm.mu.Lock()
    defer bm.mu.Unlock()

    parentTable, ok := bm.blockTables[parentSeqID]
    if !ok {
        return errors.New("parent sequence not found")
    }

    // 创建子序列的块表,使用 copy-on-write
    childTable := make(map[int]int)
    for logical, physical := range parentTable {
        childTable[logical] = physical
        bm.refCounts[physical]++ // 增加引用计数
    }

    bm.blockTables[childSeqID] = childTable
    return nil
}

// CopyOnWrite 写时复制
func (bm *BlockManager) CopyOnWrite(seqID, logicalBlock int) (int, error) {
    bm.mu.Lock()
    defer bm.mu.Unlock()

    blockTable := bm.blockTables[seqID]
    physicalBlock := blockTable[logicalBlock]

    // 如果只有一个引用,不需要复制
    if bm.refCounts[physicalBlock] == 1 {
        return physicalBlock, nil
    }

    // 分配新块
    if len(bm.gpuAllocator.freeBlocks) == 0 {
        return -1, errors.New("no free blocks for copy")
    }

    newBlock := bm.gpuAllocator.freeBlocks[0]
    bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
    bm.gpuAllocator.usedBlocks[newBlock] = true
    bm.refCounts[newBlock] = 1

    // 减少原块引用
    bm.refCounts[physicalBlock]--

    // 更新块表
    blockTable[logicalBlock] = newBlock

    return newBlock, nil
}

// GetBlockTable 获取序列的块表
func (bm *BlockManager) GetBlockTable(seqID int) []int {
    bm.mu.RLock()
    defer bm.mu.RUnlock()

    table := bm.blockTables[seqID]
    result := make([]int, len(table))
    for logical, physical := range table {
        result[logical] = physical
    }
    return result
}

// GetFreeBlockCount 获取空闲块数量
func (bm *BlockManager) GetFreeBlockCount() int {
    bm.mu.RLock()
    defer bm.mu.RUnlock()
    return len(bm.gpuAllocator.freeBlocks)
}
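
下面是 BlockManager 的一个使用示意,演示 beam search 场景下的序列分叉与写时复制(seqID 与块数均为演示用的假设值;示例为单线程演示,直接写 blockTables 仅为说明,实际应由调度/分配流程维护块表):

package batching

import "fmt"

// ExampleForkCoW BlockManager 使用示意:beam search 分叉与写时复制(单线程演示)
func ExampleForkCoW(bm *BlockManager) error {
    // 为父序列(seqID=1)分配 4 个物理块,并登记到块表
    // 注:这里直接写 blockTables 仅为演示,实际应由分配流程在持锁状态下维护
    blocks := bm.Allocate(4)
    if blocks == nil {
        return fmt.Errorf("no free blocks")
    }
    table := make(map[int]int, len(blocks))
    for logical, physical := range blocks {
        table[logical] = physical
    }
    bm.blockTables[1] = table

    // beam search 分叉:子序列(seqID=2)共享父序列的物理块,引用计数 +1
    if err := bm.ForkSequence(1, 2); err != nil {
        return err
    }

    // 子序列写入逻辑块 3 之前触发写时复制,获得独立的物理块
    newBlock, err := bm.CopyOnWrite(2, 3)
    if err != nil {
        return err
    }
    fmt.Printf("seq 2 的逻辑块 3 复制到物理块 %d\n", newBlock)
    return nil
}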

Chunked Prefill

分块预填充实现

Chunked Prefill 将长序列的 prefill 阶段分成多个小块,避免阻塞 decode:

package batching

import (
    "context"
    "sync"
)

// ChunkedPrefillScheduler 分块 Prefill 调度器
type ChunkedPrefillScheduler struct {
    config        *ChunkedPrefillConfig
    prefillQueue  []*PrefillChunk
    decodeQueue   []*DecodeRequest
    mu            sync.Mutex
}

type ChunkedPrefillConfig struct {
    MaxPrefillTokens   int     // 每次调度最大 prefill token
    MaxDecodeTokens    int     // 每次调度最大 decode token
    ChunkSize          int     // Prefill 块大小
    PrefillDecodeRatio float64 // Prefill:Decode token 比例
}

type PrefillChunk struct {
    SeqGroup    *SequenceGroup
    StartPos    int
    EndPos      int
    IsLast      bool
}

type DecodeRequest struct {
    SeqGroup    *SequenceGroup
    Position    int
}

// ScheduleMixed 混合调度 Prefill 和 Decode
func (s *ChunkedPrefillScheduler) ScheduleMixed(
    waitingQueue []*SequenceGroup,
    runningQueue []*SequenceGroup,
    blockManager *BlockManager,
) *MixedScheduleOutput {

    s.mu.Lock()
    defer s.mu.Unlock()

    output := &MixedScheduleOutput{
        PrefillChunks: make([]*PrefillChunk, 0),
        DecodeSeqs:    make([]*SequenceGroup, 0),
    }

    prefillBudget := s.config.MaxPrefillTokens
    decodeBudget := s.config.MaxDecodeTokens

    // 优先处理已在运行的 decode 请求
    for _, seqGroup := range runningQueue {
        if seqGroup.State != SGStateRunning {
            continue
        }

        // 检查是否还在 prefill
        if !s.isPrefillComplete(seqGroup) {
            // 分配下一个 prefill chunk
            chunk := s.getNextPrefillChunk(seqGroup)
            chunkTokens := chunk.EndPos - chunk.StartPos

            if chunkTokens <= prefillBudget {
                output.PrefillChunks = append(output.PrefillChunks, chunk)
                prefillBudget -= chunkTokens
            }
            continue
        }

        // Decode 请求
        if decodeBudget > 0 {
            output.DecodeSeqs = append(output.DecodeSeqs, seqGroup)
            decodeBudget--
        }
    }

    // 处理等待队列中的新请求
    for _, seqGroup := range waitingQueue {
        // 计算首个 prefill chunk
        promptLen := s.getPromptLength(seqGroup)
        chunkSize := min(promptLen, s.config.ChunkSize)

        if chunkSize > prefillBudget {
            break
        }

        // 尝试分配初始块
        requiredBlocks := (chunkSize + blockManager.blockSize - 1) / blockManager.blockSize
        if !blockManager.CanAllocate(requiredBlocks) {
            continue
        }

        blocks := blockManager.Allocate(requiredBlocks)
        s.assignBlocks(seqGroup, blocks)

        chunk := &PrefillChunk{
            SeqGroup: seqGroup,
            StartPos: 0,
            EndPos:   chunkSize,
            IsLast:   chunkSize >= promptLen,
        }

        output.PrefillChunks = append(output.PrefillChunks, chunk)
        prefillBudget -= chunkSize
        seqGroup.State = SGStateRunning
    }

    return output
}

type MixedScheduleOutput struct {
    PrefillChunks []*PrefillChunk
    DecodeSeqs    []*SequenceGroup
}

// ExecuteMixed 执行混合批次
func (s *ChunkedPrefillScheduler) ExecuteMixed(
    ctx context.Context,
    output *MixedScheduleOutput,
    engine InferenceEngine,
) error {
    // 构建混合批次输入
    input := &MixedBatchInput{
        PrefillInputs: make([]*PrefillInput, 0),
        DecodeInputs:  make([]*DecodeInput, 0),
    }

    // 添加 Prefill 输入
    for _, chunk := range output.PrefillChunks {
        seq := chunk.SeqGroup.Sequences[0]
        prefillInput := &PrefillInput{
            SeqID:       seq.SeqID,
            TokenIDs:    seq.TokenIDs[chunk.StartPos:chunk.EndPos],
            StartPos:    chunk.StartPos,
            BlockTable:  s.getBlockTable(chunk.SeqGroup),
        }
        input.PrefillInputs = append(input.PrefillInputs, prefillInput)
    }

    // 添加 Decode 输入
    for _, seqGroup := range output.DecodeSeqs {
        for _, seq := range seqGroup.Sequences {
            decodeInput := &DecodeInput{
                SeqID:      seq.SeqID,
                LastToken:  seq.TokenIDs[len(seq.TokenIDs)-1],
                Position:   len(seq.TokenIDs),
                BlockTable: s.getBlockTable(seqGroup),
            }
            input.DecodeInputs = append(input.DecodeInputs, decodeInput)
        }
    }

    // 执行混合前向传播
    result, err := engine.ForwardMixed(ctx, input)
    if err != nil {
        return err
    }

    // 处理 Prefill 结果
    for i, chunk := range output.PrefillChunks {
        if chunk.IsLast {
            // Prefill 完成,获取首个生成 token
            token := result.PrefillOutputs[i].FirstToken
            seq := chunk.SeqGroup.Sequences[0]
            seq.TokenIDs = append(seq.TokenIDs, token)
        }
        // 更新进度
        s.updatePrefillProgress(chunk.SeqGroup, chunk.EndPos)
    }

    // 处理 Decode 结果
    decodeIdx := 0
    for _, seqGroup := range output.DecodeSeqs {
        for _, seq := range seqGroup.Sequences {
            token := result.DecodeOutputs[decodeIdx].NextToken
            seq.TokenIDs = append(seq.TokenIDs, token)
            decodeIdx++

            // 检查是否完成
            if s.isFinished(seq, token) {
                seq.Status = SeqFinishedStopped
            }
        }

        // 更新序列组状态
        if s.allSequencesFinished(seqGroup) {
            seqGroup.State = SGStateFinished
        }
    }

    return nil
}

自适应分块策略

package batching

import "math"

// AdaptiveChunker 自适应分块器
type AdaptiveChunker struct {
    config           *AdaptiveConfig
    latencyHistory   *LatencyHistory
    throughputTarget float64
}

type AdaptiveConfig struct {
    MinChunkSize      int
    MaxChunkSize      int
    MaxDecodeLoad     int // 允许同时 decode 的最大序列数(用于负载因子计算)
    TargetLatencyMs   float64
    LatencyWindowSize int
}

type LatencyHistory struct {
    prefillLatencies []float64  // ms/token
    decodeLatencies  []float64  // ms/token
    windowSize       int
}

// OptimalChunkSize 计算最优分块大小
func (ac *AdaptiveChunker) OptimalChunkSize(
    promptLength int,
    currentDecodeLoad int,
) int {
    // 获取历史延迟统计
    avgPrefillLatency := ac.latencyHistory.AveragePrefillLatency()
    avgDecodeLatency := ac.latencyHistory.AverageDecodeLatency()

    // 目标:保证 decode 请求延迟不超过阈值
    // decode 延迟 = prefill 时间 + decode 时间
    // prefill 时间 = chunk_size * prefill_latency
    // decode 等待时间 = chunk_size * prefill_latency

    // 计算 decode 可接受的等待时间
    targetWaitTime := ac.config.TargetLatencyMs - avgDecodeLatency
    if targetWaitTime <= 0 {
        return ac.config.MinChunkSize
    }

    // 计算最大可接受的 chunk size
    maxChunk := int(targetWaitTime / avgPrefillLatency)

    // 考虑当前 decode 负载
    // 高负载时减小 chunk size 以减少 decode 等待
    loadFactor := 1.0 - float64(currentDecodeLoad)/float64(ac.config.MaxDecodeLoad)
    adjustedMaxChunk := int(float64(maxChunk) * math.Max(0.2, loadFactor))

    // 边界约束
    optimalSize := min(
        max(adjustedMaxChunk, ac.config.MinChunkSize),
        ac.config.MaxChunkSize,
    )

    // 不超过 prompt 长度
    return min(optimalSize, promptLength)
}

// UpdateLatency 更新延迟统计
func (lh *LatencyHistory) UpdateLatency(isPrefill bool, tokens int, durationMs float64) {
    latencyPerToken := durationMs / float64(tokens)

    if isPrefill {
        lh.prefillLatencies = append(lh.prefillLatencies, latencyPerToken)
        if len(lh.prefillLatencies) > lh.windowSize {
            lh.prefillLatencies = lh.prefillLatencies[1:]
        }
    } else {
        lh.decodeLatencies = append(lh.decodeLatencies, latencyPerToken)
        if len(lh.decodeLatencies) > lh.windowSize {
            lh.decodeLatencies = lh.decodeLatencies[1:]
        }
    }
}

func (lh *LatencyHistory) AveragePrefillLatency() float64 {
    if len(lh.prefillLatencies) == 0 {
        return 1.0 // 默认值
    }
    sum := 0.0
    for _, l := range lh.prefillLatencies {
        sum += l
    }
    return sum / float64(len(lh.prefillLatencies))
}

func (lh *LatencyHistory) AverageDecodeLatency() float64 {
    if len(lh.decodeLatencies) == 0 {
        return 1.0 // 默认值
    }
    sum := 0.0
    for _, l := range lh.decodeLatencies {
        sum += l
    }
    return sum / float64(len(lh.decodeLatencies))
}
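
一个简单的使用示意(假设 ac 已初始化,延迟数值均为演示值):每次迭代结束后把实测延迟回填给 LatencyHistory,再调用 OptimalChunkSize 决定下一个 prefill chunk 的大小:

package batching

// NextChunkSizeExample 使用示意:假设 ac 已初始化,数值均为演示值
func NextChunkSizeExample(ac *AdaptiveChunker) int {
    // 每次迭代后回填实测延迟:512 个 prefill token 耗时 80ms;decode 单 token 15ms
    ac.latencyHistory.UpdateLatency(true, 512, 80.0)
    ac.latencyHistory.UpdateLatency(false, 1, 15.0)

    // 根据目标延迟和当前 decode 负载(24 个并发序列)计算下一个 prefill chunk 大小
    return ac.OptimalChunkSize(4096, 24)
}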

Speculative Decoding

推测解码实现

Speculative Decoding 使用小模型预测多个 token,然后由大模型并行验证:

package batching

import (
    "context"
    "math"
    "math/rand"
    "sync"
)

// SpeculativeDecoder 推测解码器
type SpeculativeDecoder struct {
    targetModel  LargeModel     // 目标大模型
    draftModel   SmallModel     // 草稿小模型
    config       *SpecConfig
    acceptStats  *AcceptanceStats
}

type SpecConfig struct {
    NumSpecTokens    int     // 每次推测的 token 数
    AcceptThreshold  float64 // 接受阈值
    DraftTemperature float64 // 草稿模型温度
    TargetTemperature float64 // 目标模型温度
}

type AcceptanceStats struct {
    TotalTokens    int64
    AcceptedTokens int64
    mu             sync.Mutex
}

// SpeculativeStep 执行一步推测解码
func (sd *SpeculativeDecoder) SpeculativeStep(
    ctx context.Context,
    seqGroup *SequenceGroup,
) (*SpeculativeResult, error) {

    seq := seqGroup.Sequences[0]
    currentTokens := seq.TokenIDs

    // 阶段1: 草稿模型生成多个候选 token
    draftTokens, draftProbs, err := sd.generateDraft(ctx, currentTokens)
    if err != nil {
        return nil, err
    }

    // 阶段2: 目标模型并行验证所有候选
    targetProbs, err := sd.verifyWithTarget(ctx, currentTokens, draftTokens)
    if err != nil {
        return nil, err
    }

    // 阶段3: 使用 rejection sampling 确定接受的 token
    acceptedTokens := sd.rejectionSampling(draftTokens, draftProbs, targetProbs)

    // 更新统计
    sd.acceptStats.mu.Lock()
    sd.acceptStats.TotalTokens += int64(len(draftTokens))
    sd.acceptStats.AcceptedTokens += int64(len(acceptedTokens))
    sd.acceptStats.mu.Unlock()

    return &SpeculativeResult{
        AcceptedTokens: acceptedTokens,
        AcceptRate:     float64(len(acceptedTokens)) / float64(len(draftTokens)),
    }, nil
}

// generateDraft 使用草稿模型生成候选 token
func (sd *SpeculativeDecoder) generateDraft(
    ctx context.Context,
    inputTokens []int32,
) ([]int32, [][]float32, error) {

    draftTokens := make([]int32, 0, sd.config.NumSpecTokens)
    draftProbs := make([][]float32, 0, sd.config.NumSpecTokens)

    currentInput := inputTokens

    for i := 0; i < sd.config.NumSpecTokens; i++ {
        // 草稿模型前向传播
        logits, err := sd.draftModel.Forward(ctx, currentInput)
        if err != nil {
            return nil, nil, err
        }

        // 采样
        probs := softmax(logits, sd.config.DraftTemperature)
        token := sample(probs)

        draftTokens = append(draftTokens, token)
        draftProbs = append(draftProbs, probs)

        // 更新输入
        currentInput = append(currentInput, token)
    }

    return draftTokens, draftProbs, nil
}

// verifyWithTarget 目标模型验证
func (sd *SpeculativeDecoder) verifyWithTarget(
    ctx context.Context,
    inputTokens []int32,
    draftTokens []int32,
) ([][]float32, error) {

    // 构建完整序列
    fullSequence := make([]int32, len(inputTokens)+len(draftTokens))
    copy(fullSequence, inputTokens)
    copy(fullSequence[len(inputTokens):], draftTokens)

    // 目标模型并行计算所有位置的概率
    // 这是推测解码的关键优势:一次前向传播验证多个 token
    allLogits, err := sd.targetModel.Forward(ctx, fullSequence)
    if err != nil {
        return nil, err
    }

    // 提取验证位置的概率
    targetProbs := make([][]float32, len(draftTokens))
    for i := 0; i < len(draftTokens); i++ {
        pos := len(inputTokens) + i
        targetProbs[i] = softmax(allLogits[pos], sd.config.TargetTemperature)
    }

    return targetProbs, nil
}

// rejectionSampling 拒绝采样确定接受的 token
func (sd *SpeculativeDecoder) rejectionSampling(
    draftTokens []int32,
    draftProbs [][]float32,
    targetProbs [][]float32,
) []int32 {

    acceptedTokens := make([]int32, 0, len(draftTokens))

    for i, token := range draftTokens {
        // 获取该 token 在两个分布下的概率
        pDraft := draftProbs[i][token]
        pTarget := targetProbs[i][token]

        // 接受概率 = min(1, p_target / p_draft)
        acceptProb := min(1.0, float64(pTarget)/float64(pDraft))

        // 随机决定是否接受
        if randomFloat() < acceptProb {
            acceptedTokens = append(acceptedTokens, token)
        } else {
            // 拒绝后,从修正分布中采样
            correctedProbs := sd.computeCorrectedDistribution(
                targetProbs[i], draftProbs[i],
            )
            newToken := sample(correctedProbs)
            acceptedTokens = append(acceptedTokens, newToken)
            break // 后续 token 作废
        }
    }

    return acceptedTokens
}

// computeCorrectedDistribution 计算修正分布
func (sd *SpeculativeDecoder) computeCorrectedDistribution(
    targetProbs, draftProbs []float32,
) []float32 {

    corrected := make([]float32, len(targetProbs))
    var sum float32 = 0

    for i := range corrected {
        // p_corrected = max(0, p_target - p_draft)
        diff := targetProbs[i] - draftProbs[i]
        if diff > 0 {
            corrected[i] = diff
            sum += diff
        }
    }

    // 归一化
    if sum > 0 {
        for i := range corrected {
            corrected[i] /= sum
        }
    } else {
        // 如果修正分布为零,使用目标分布
        copy(corrected, targetProbs)
    }

    return corrected
}

type SpeculativeResult struct {
    AcceptedTokens []int32
    AcceptRate     float64
}

// 辅助函数
func randomFloat() float64 {
    return rand.Float64()
}

func softmax(logits []float32, temperature float64) []float32 {
    probs := make([]float32, len(logits))
    var maxLogit float32 = logits[0]
    for _, l := range logits {
        if l > maxLogit {
            maxLogit = l
        }
    }

    var sum float32 = 0
    for i, l := range logits {
        probs[i] = float32(math.Exp(float64((l - maxLogit) / float32(temperature))))
        sum += probs[i]
    }

    for i := range probs {
        probs[i] /= sum
    }
    return probs
}

func sample(probs []float32) int32 {
    r := randomFloat()
    var cumSum float32 = 0
    for i, p := range probs {
        cumSum += p
        if float32(r) < cumSum {
            return int32(i)
        }
    }
    return int32(len(probs) - 1)
}

自适应推测长度

package batching

import (
    "math"
)

// AdaptiveSpeculator 自适应推测器
type AdaptiveSpeculator struct {
    baseSpecLen   int
    minSpecLen    int
    maxSpecLen    int
    acceptHistory *RingBuffer
    windowSize    int
}

// AdaptiveSpecConfig 自适应推测配置
type AdaptiveSpecConfig struct {
    BaseSpecLen int
    MinSpecLen  int
    MaxSpecLen  int
    WindowSize  int
}

// NewAdaptiveSpeculator 创建自适应推测器
func NewAdaptiveSpeculator(config *AdaptiveSpecConfig) *AdaptiveSpeculator {
    return &AdaptiveSpeculator{
        baseSpecLen:   config.BaseSpecLen,
        minSpecLen:    config.MinSpecLen,
        maxSpecLen:    config.MaxSpecLen,
        acceptHistory: NewRingBuffer(config.WindowSize),
        windowSize:    config.WindowSize,
    }
}

// OptimalSpecLength 计算最优推测长度
func (as *AdaptiveSpeculator) OptimalSpecLength() int {
    avgAcceptRate := as.acceptHistory.Average()

    if avgAcceptRate < 0.001 {
        return as.baseSpecLen
    }

    // 基于接受率计算期望加速比
    // 期望生成 token 数 = sum(accept_rate^i) for i in 0..spec_len
    // 加速比 = 期望生成数 / (1 + draft_cost)

    bestLen := as.minSpecLen
    bestSpeedup := 0.0

    draftCostRatio := 0.1 // 草稿模型成本约为目标模型的 10%

    for specLen := as.minSpecLen; specLen <= as.maxSpecLen; specLen++ {
        expectedTokens := 0.0
        for i := 0; i <= specLen; i++ {
            expectedTokens += math.Pow(avgAcceptRate, float64(i))
        }

        // 考虑验证成本(一次前向传播)和草稿成本
        totalCost := 1.0 + float64(specLen)*draftCostRatio
        speedup := expectedTokens / totalCost

        if speedup > bestSpeedup {
            bestSpeedup = speedup
            bestLen = specLen
        }
    }

    return bestLen
}

// UpdateAcceptRate 更新接受率
func (as *AdaptiveSpeculator) UpdateAcceptRate(accepted, total int) {
    if total > 0 {
        rate := float64(accepted) / float64(total)
        as.acceptHistory.Push(rate)
    }
}

// RingBuffer 环形缓冲区
type RingBuffer struct {
    data  []float64
    head  int
    size  int
    count int
}

func NewRingBuffer(size int) *RingBuffer {
    return &RingBuffer{
        data: make([]float64, size),
        size: size,
    }
}

func (rb *RingBuffer) Push(value float64) {
    rb.data[rb.head] = value
    rb.head = (rb.head + 1) % rb.size
    if rb.count < rb.size {
        rb.count++
    }
}

func (rb *RingBuffer) Average() float64 {
    if rb.count == 0 {
        return 0
    }
    sum := 0.0
    for i := 0; i < rb.count; i++ {
        sum += rb.data[i]
    }
    return sum / float64(rb.count)
}
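
可以用一个小例子直观感受上面的收益模型:给定接受率 α 和推测长度 k,目标模型一次验证的期望产出约为 1 + α + α² + ... + α^k 个 token,再除以总成本即得到估算加速比(接受率与草稿成本比例均为假设值):

package main

import (
    "fmt"
    "math"
)

func main() {
    acceptRate := 0.7     // 假设:单 token 接受率 α
    draftCostRatio := 0.1 // 假设:草稿模型单步成本约为目标模型的 10%

    for _, specLen := range []int{2, 4, 6, 8} {
        // 期望每次验证产出的 token 数 ≈ 1 + α + α^2 + ... + α^specLen
        expected := 0.0
        for i := 0; i <= specLen; i++ {
            expected += math.Pow(acceptRate, float64(i))
        }
        cost := 1.0 + float64(specLen)*draftCostRatio
        fmt.Printf("specLen=%d  期望产出 %.2f token/步, 估算加速比 %.2f\n",
            specLen, expected, expected/cost)
    }
}

按这些假设值计算,接受率 0.7 时加速比在推测长度 4 附近达到峰值(约 2 倍),继续加长会被草稿成本摊薄,这正是上面 OptimalSpecLength 在线搜索最优长度的出发点。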

批处理性能优化

内存优化策略

package batching

import (
    "sync"
    "unsafe"
)

// BatchMemoryManager 批处理内存管理器
type BatchMemoryManager struct {
    tokenPool     *sync.Pool
    logitPool     *sync.Pool
    attentionPool *sync.Pool
    config        *MemoryConfig
}

type MemoryConfig struct {
    MaxBatchSize    int
    MaxSeqLen       int
    VocabSize       int
    HiddenSize      int
    NumHeads        int
    PreallocateMB   int
}

func NewBatchMemoryManager(config *MemoryConfig) *BatchMemoryManager {
    bmm := &BatchMemoryManager{config: config}

    // Token ID 池
    bmm.tokenPool = &sync.Pool{
        New: func() interface{} {
            return make([]int32, 0, config.MaxBatchSize*config.MaxSeqLen)
        },
    }

    // Logits 池
    bmm.logitPool = &sync.Pool{
        New: func() interface{} {
            return make([]float32, 0, config.MaxBatchSize*config.VocabSize)
        },
    }

    // Attention 中间结果池
    bmm.attentionPool = &sync.Pool{
        New: func() interface{} {
            size := config.MaxBatchSize * config.NumHeads * config.MaxSeqLen * config.MaxSeqLen
            return make([]float32, 0, size)
        },
    }

    // 预热池
    bmm.warmupPools()

    return bmm
}

func (bmm *BatchMemoryManager) warmupPools() {
    // 预分配对象减少运行时分配
    for i := 0; i < 10; i++ {
        tokens := bmm.tokenPool.Get().([]int32)
        bmm.tokenPool.Put(tokens[:0])

        logits := bmm.logitPool.Get().([]float32)
        bmm.logitPool.Put(logits[:0])

        attn := bmm.attentionPool.Get().([]float32)
        bmm.attentionPool.Put(attn[:0])
    }
}

// AllocateTokenBuffer 分配 token 缓冲区
func (bmm *BatchMemoryManager) AllocateTokenBuffer(size int) []int32 {
    buf := bmm.tokenPool.Get().([]int32)
    if cap(buf) < size {
        buf = make([]int32, size)
    } else {
        buf = buf[:size]
    }
    return buf
}

// ReleaseTokenBuffer 释放 token 缓冲区
func (bmm *BatchMemoryManager) ReleaseTokenBuffer(buf []int32) {
    bmm.tokenPool.Put(buf[:0])
}

// BatchBuffer 批处理缓冲区
type BatchBuffer struct {
    InputIDs      []int32
    AttentionMask []int32
    PositionIDs   []int32
    BlockTables   [][]int
    SeqLens       []int

    // 输出缓冲
    Logits        []float32
    NextTokens    []int32

    // 元数据
    BatchSize     int
    MaxSeqLen     int

    pool          *BatchMemoryManager
}

func (bmm *BatchMemoryManager) NewBatchBuffer(batchSize, maxSeqLen int) *BatchBuffer {
    totalTokens := batchSize * maxSeqLen

    return &BatchBuffer{
        InputIDs:      bmm.AllocateTokenBuffer(totalTokens),
        AttentionMask: bmm.AllocateTokenBuffer(totalTokens),
        PositionIDs:   bmm.AllocateTokenBuffer(totalTokens),
        BlockTables:   make([][]int, batchSize),
        SeqLens:       make([]int, batchSize),
        Logits:        bmm.logitPool.Get().([]float32)[:batchSize*bmm.config.VocabSize],
        NextTokens:    make([]int32, batchSize),
        BatchSize:     batchSize,
        MaxSeqLen:     maxSeqLen,
        pool:          bmm,
    }
}

func (bb *BatchBuffer) Release() {
    bb.pool.ReleaseTokenBuffer(bb.InputIDs)
    bb.pool.ReleaseTokenBuffer(bb.AttentionMask)
    bb.pool.ReleaseTokenBuffer(bb.PositionIDs)
    bb.pool.logitPool.Put(bb.Logits[:0])
}

// ZeroCopyBatch 零拷贝批处理
type ZeroCopyBatch struct {
    requests []*ActiveRequest
    views    []*TensorView
}

type TensorView struct {
    Data   unsafe.Pointer
    Shape  []int
    Stride []int
    Offset int
}

// CreateBatchView 创建批次视图(避免数据拷贝)
func CreateBatchView(requests []*ActiveRequest) *ZeroCopyBatch {
    batch := &ZeroCopyBatch{
        requests: requests,
        views:    make([]*TensorView, len(requests)),
    }

    for i, req := range requests {
        // 直接引用请求的数据,不复制
        batch.views[i] = &TensorView{
            Data:   unsafe.Pointer(&req.InputTokens[0]),
            Shape:  []int{len(req.InputTokens)},
            Stride: []int{1},
            Offset: 0,
        }
    }

    return batch
}

并发优化

package batching

import (
    "context"
    "runtime"
    "sync"
)

// ParallelBatchProcessor 并行批处理器
type ParallelBatchProcessor struct {
    numWorkers    int
    workerPool    chan struct{}
    preprocessCh  chan *PreprocessTask
    postprocessCh chan *PostprocessTask
    wg            sync.WaitGroup
}

type PreprocessTask struct {
    Request  *InferenceRequest
    ResultCh chan *PreprocessedInput
}

type PostprocessTask struct {
    Output   *RawOutput
    Request  *InferenceRequest
    ResultCh chan *InferenceResult
}

type PreprocessedInput struct {
    TokenIDs     []int32
    AttentionMask []int32
}

type RawOutput struct {
    Logits []float32
}

func NewParallelBatchProcessor(numWorkers int) *ParallelBatchProcessor {
    if numWorkers <= 0 {
        numWorkers = runtime.NumCPU()
    }

    pbp := &ParallelBatchProcessor{
        numWorkers:    numWorkers,
        workerPool:    make(chan struct{}, numWorkers),
        preprocessCh:  make(chan *PreprocessTask, 1000),
        postprocessCh: make(chan *PostprocessTask, 1000),
    }

    // 启动预处理 worker
    for i := 0; i < numWorkers; i++ {
        go pbp.preprocessWorker()
    }

    // 启动后处理 worker
    for i := 0; i < numWorkers; i++ {
        go pbp.postprocessWorker()
    }

    return pbp
}

func (pbp *ParallelBatchProcessor) preprocessWorker() {
    for task := range pbp.preprocessCh {
        result := pbp.doPreprocess(task.Request)
        task.ResultCh <- result
    }
}

func (pbp *ParallelBatchProcessor) postprocessWorker() {
    for task := range pbp.postprocessCh {
        result := pbp.doPostprocess(task.Output, task.Request)
        task.ResultCh <- result
    }
}

func (pbp *ParallelBatchProcessor) doPreprocess(req *InferenceRequest) *PreprocessedInput {
    // Tokenization
    tokens := tokenize(req.Prompt)

    // 创建 attention mask
    mask := make([]int32, len(tokens))
    for i := range mask {
        mask[i] = 1
    }

    return &PreprocessedInput{
        TokenIDs:      tokens,
        AttentionMask: mask,
    }
}

func (pbp *ParallelBatchProcessor) doPostprocess(
    output *RawOutput,
    req *InferenceRequest,
) *InferenceResult {
    // 采样
    token := sampleFromLogits(output.Logits, req.Temperature, req.TopP)

    // 解码
    text := decode([]int32{token})

    return &InferenceResult{
        RequestID:    req.ID,
        OutputTokens: []int32{token},
        Text:         text,
    }
}

// PreprocessBatch 并行预处理批次
func (pbp *ParallelBatchProcessor) PreprocessBatch(
    ctx context.Context,
    requests []*InferenceRequest,
) []*PreprocessedInput {

    results := make([]*PreprocessedInput, len(requests))
    resultChs := make([]chan *PreprocessedInput, len(requests))

    // 提交所有任务
    for i, req := range requests {
        resultChs[i] = make(chan *PreprocessedInput, 1)
        pbp.preprocessCh <- &PreprocessTask{
            Request:  req,
            ResultCh: resultChs[i],
        }
    }

    // 收集结果
    for i, ch := range resultChs {
        select {
        case result := <-ch:
            results[i] = result
        case <-ctx.Done():
            return nil
        }
    }

    return results
}

// PostprocessBatch 并行后处理批次
func (pbp *ParallelBatchProcessor) PostprocessBatch(
    ctx context.Context,
    outputs []*RawOutput,
    requests []*InferenceRequest,
) []*InferenceResult {

    results := make([]*InferenceResult, len(outputs))
    resultChs := make([]chan *InferenceResult, len(outputs))

    // 提交所有任务
    for i := range outputs {
        resultChs[i] = make(chan *InferenceResult, 1)
        pbp.postprocessCh <- &PostprocessTask{
            Output:   outputs[i],
            Request:  requests[i],
            ResultCh: resultChs[i],
        }
    }

    // 收集结果
    for i, ch := range resultChs {
        select {
        case result := <-ch:
            results[i] = result
        case <-ctx.Done():
            return nil
        }
    }

    return results
}
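
ParallelBatchProcessor 的典型用法是把 tokenization、采样与解码等 CPU 密集的前后处理放进 worker 池并行执行,推理调用夹在两者之间。下面是一个省略了推理调用的使用示意(outputs 假设由推理引擎返回,与 requests 一一对应):

package batching

import "context"

// RunPrePostExample 使用示意:并行前处理 →(推理,此处省略)→ 并行后处理
// outputs 假设由推理引擎返回,与 requests 一一对应
func RunPrePostExample(
    ctx context.Context,
    requests []*InferenceRequest,
    outputs []*RawOutput,
) []*InferenceResult {
    pbp := NewParallelBatchProcessor(0) // 传 0 表示按 CPU 核数启动 worker

    // 并行执行 tokenization 等前处理
    inputs := pbp.PreprocessBatch(ctx, requests)
    if inputs == nil {
        return nil // ctx 已取消
    }
    _ = inputs // 实际应交给推理引擎,得到与 requests 对应的 outputs

    // 并行执行采样、解码等后处理
    return pbp.PostprocessBatch(ctx, outputs, requests)
}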

监控与调优

批处理指标

package batching

import (
    "sync"
    "sync/atomic"
    "time"

    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

// BatchMetrics 批处理指标收集器
type BatchMetrics struct {
    // 请求指标
    requestsTotal      prometheus.Counter
    requestsQueued     prometheus.Gauge
    requestsRunning    prometheus.Gauge
    requestLatency     prometheus.Histogram

    // 批次指标
    batchSize          prometheus.Histogram
    batchTokens        prometheus.Histogram
    batchUtilization   prometheus.Gauge
    iterationLatency   prometheus.Histogram

    // 调度指标
    schedulerOverhead  prometheus.Histogram
    preemptionCount    prometheus.Counter
    swapInCount        prometheus.Counter
    swapOutCount       prometheus.Counter

    // 吞吐量指标
    tokensGenerated    prometheus.Counter
    prefillTokens      prometheus.Counter
    decodeTokens       prometheus.Counter

    // 推测解码指标
    specAcceptRate     prometheus.Gauge
    specSpeedup        prometheus.Gauge
}

func NewBatchMetrics(namespace string) *BatchMetrics {
    return &BatchMetrics{
        requestsTotal: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "requests_total",
            Help:      "Total number of inference requests",
        }),
        requestsQueued: promauto.NewGauge(prometheus.GaugeOpts{
            Namespace: namespace,
            Name:      "requests_queued",
            Help:      "Number of requests in queue",
        }),
        requestsRunning: promauto.NewGauge(prometheus.GaugeOpts{
            Namespace: namespace,
            Name:      "requests_running",
            Help:      "Number of requests currently running",
        }),
        requestLatency: promauto.NewHistogram(prometheus.HistogramOpts{
            Namespace: namespace,
            Name:      "request_latency_seconds",
            Help:      "Request end-to-end latency",
            Buckets:   []float64{0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30},
        }),
        batchSize: promauto.NewHistogram(prometheus.HistogramOpts{
            Namespace: namespace,
            Name:      "batch_size",
            Help:      "Number of sequences per batch",
            Buckets:   []float64{1, 2, 4, 8, 16, 32, 64, 128, 256},
        }),
        batchTokens: promauto.NewHistogram(prometheus.HistogramOpts{
            Namespace: namespace,
            Name:      "batch_tokens",
            Help:      "Total tokens per batch",
            Buckets:   []float64{100, 500, 1000, 2000, 4000, 8000, 16000},
        }),
        batchUtilization: promauto.NewGauge(prometheus.GaugeOpts{
            Namespace: namespace,
            Name:      "batch_utilization",
            Help:      "Batch slot utilization ratio",
        }),
        iterationLatency: promauto.NewHistogram(prometheus.HistogramOpts{
            Namespace: namespace,
            Name:      "iteration_latency_ms",
            Help:      "Single decode iteration latency",
            Buckets:   []float64{1, 2, 5, 10, 20, 50, 100, 200},
        }),
        schedulerOverhead: promauto.NewHistogram(prometheus.HistogramOpts{
            Namespace: namespace,
            Name:      "scheduler_overhead_us",
            Help:      "Scheduler overhead per iteration",
            Buckets:   []float64{10, 50, 100, 500, 1000, 5000},
        }),
        preemptionCount: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "preemption_total",
            Help:      "Total number of preemptions",
        }),
        swapInCount: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "swap_in_total",
            Help:      "Total number of swap-in operations",
        }),
        swapOutCount: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "swap_out_total",
            Help:      "Total number of swap-out operations",
        }),
        tokensGenerated: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "tokens_generated_total",
            Help:      "Total tokens generated",
        }),
        prefillTokens: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "prefill_tokens_total",
            Help:      "Total prefill tokens processed",
        }),
        decodeTokens: promauto.NewCounter(prometheus.CounterOpts{
            Namespace: namespace,
            Name:      "decode_tokens_total",
            Help:      "Total decode tokens generated",
        }),
        specAcceptRate: promauto.NewGauge(prometheus.GaugeOpts{
            Namespace: namespace,
            Name:      "speculative_accept_rate",
            Help:      "Speculative decoding acceptance rate",
        }),
        specSpeedup: promauto.NewGauge(prometheus.GaugeOpts{
            Namespace: namespace,
            Name:      "speculative_speedup",
            Help:      "Speculative decoding speedup ratio",
        }),
    }
}

// BatchProfiler 批处理性能分析器
type BatchProfiler struct {
    metrics      *BatchMetrics
    startTimes   map[string]time.Time
    mu           sync.RWMutex

    // 统计数据
    totalBatches     int64
    totalIterations  int64
    totalLatencyUs   int64
}

func NewBatchProfiler(metrics *BatchMetrics) *BatchProfiler {
    return &BatchProfiler{
        metrics:    metrics,
        startTimes: make(map[string]time.Time),
    }
}

// RecordBatchStart 记录批次开始
func (bp *BatchProfiler) RecordBatchStart(batchID string, batchSize, totalTokens int) {
    bp.mu.Lock()
    bp.startTimes[batchID] = time.Now()
    bp.mu.Unlock()

    bp.metrics.batchSize.Observe(float64(batchSize))
    bp.metrics.batchTokens.Observe(float64(totalTokens))
}

// RecordBatchEnd 记录批次结束
func (bp *BatchProfiler) RecordBatchEnd(batchID string) {
    bp.mu.Lock()
    startTime, ok := bp.startTimes[batchID]
    if ok {
        delete(bp.startTimes, batchID)
    }
    bp.mu.Unlock()

    if ok {
        latency := time.Since(startTime).Milliseconds()
        bp.metrics.iterationLatency.Observe(float64(latency))
        atomic.AddInt64(&bp.totalLatencyUs, latency*1000)
        atomic.AddInt64(&bp.totalIterations, 1)
    }
}

// RecordSchedulerOverhead 记录调度器开销
func (bp *BatchProfiler) RecordSchedulerOverhead(duration time.Duration) {
    bp.metrics.schedulerOverhead.Observe(float64(duration.Microseconds()))
}

// GetAverageLatency 获取平均延迟
func (bp *BatchProfiler) GetAverageLatency() float64 {
    iterations := atomic.LoadInt64(&bp.totalIterations)
    if iterations == 0 {
        return 0
    }
    totalUs := atomic.LoadInt64(&bp.totalLatencyUs)
    return float64(totalUs) / float64(iterations) / 1000.0 // 返回毫秒
}

// ThroughputCalculator 吞吐量计算器
type ThroughputCalculator struct {
    windowSize    time.Duration
    tokenCounts   []TokenCount
    mu            sync.RWMutex
}

type TokenCount struct {
    Timestamp time.Time
    Count     int64
}

func NewThroughputCalculator(windowSize time.Duration) *ThroughputCalculator {
    return &ThroughputCalculator{
        windowSize:  windowSize,
        tokenCounts: make([]TokenCount, 0),
    }
}

func (tc *ThroughputCalculator) RecordTokens(count int64) {
    tc.mu.Lock()
    defer tc.mu.Unlock()

    now := time.Now()
    tc.tokenCounts = append(tc.tokenCounts, TokenCount{
        Timestamp: now,
        Count:     count,
    })

    // 清理过期数据
    cutoff := now.Add(-tc.windowSize)
    validIdx := 0
    for i, c := range tc.tokenCounts {
        if c.Timestamp.After(cutoff) {
            validIdx = i
            break
        }
    }
    tc.tokenCounts = tc.tokenCounts[validIdx:]
}

func (tc *ThroughputCalculator) GetThroughput() float64 {
    tc.mu.RLock()
    defer tc.mu.RUnlock()

    if len(tc.tokenCounts) < 2 {
        return 0
    }

    var totalTokens int64
    for _, c := range tc.tokenCounts {
        totalTokens += c.Count
    }

    duration := tc.tokenCounts[len(tc.tokenCounts)-1].Timestamp.Sub(
        tc.tokenCounts[0].Timestamp,
    ).Seconds()

    if duration == 0 {
        return 0
    }

    return float64(totalTokens) / duration
}

性能调优建议

# 批处理调优配置示例
batching:
  # 基础配置
  max_batch_size: 256        # 最大批次大小
  max_tokens_per_batch: 8192 # 批次最大 token 数
  max_waiting_requests: 1000 # 最大等待请求数

  # 调度配置
  scheduling_policy: "priority"  # fcfs, sjf, priority
  max_wait_time_ms: 100         # 最大等待时间
  scheduler_interval_ms: 1      # 调度间隔

  # Prefill 配置
  prefill_chunk_size: 512       # Prefill 分块大小
  max_prefill_tokens: 2048      # 每次调度最大 prefill token

  # 抢占配置
  enable_preemption: true
  preemption_mode: "swap"       # recompute, swap
  swap_space_gb: 16

  # 推测解码配置
  speculative_decoding:
    enabled: true
    num_spec_tokens: 5
    draft_model: "distill-llama-68m"
    accept_threshold: 0.7
    adaptive_length: true

# 调优建议
tuning_guide:
  high_throughput:
    description: "优化吞吐量场景"
    settings:
      max_batch_size: 512
      prefill_chunk_size: 1024
      scheduling_policy: "sjf"
      enable_preemption: true

  low_latency:
    description: "优化延迟场景"
    settings:
      max_batch_size: 32
      prefill_chunk_size: 256
      scheduling_policy: "fcfs"
      max_wait_time_ms: 10

  memory_constrained:
    description: "内存受限场景"
    settings:
      max_batch_size: 64
      max_tokens_per_batch: 4096
      preemption_mode: "swap"
      swap_space_gb: 32

小结

本章深入探讨了动态批处理的核心技术:

  1. 连续批处理:允许请求在任意 token 位置加入或离开批次,充分利用 GPU 资源
  2. 迭代级调度:细粒度控制每个 decode 迭代的批次组成,支持抢占和交换
  3. Chunked Prefill:分块处理长序列的 prefill 阶段,避免阻塞短序列的 decode
  4. 推测解码:使用小模型预测 + 大模型验证,提升生成速度
  5. 内存优化:对象池、零拷贝、块管理等技术减少内存开销
  6. 监控调优:完善的指标体系和性能分析工具

动态批处理是大模型推理系统的核心组件,直接决定了系统的吞吐量和延迟。合理的批处理策略能够在保证服务质量的前提下,最大化 GPU 利用率。

下一章我们将探讨 推理优化技术,讲解如何通过模型优化和系统优化进一步提升推理性能。
