动态批处理
概述
动态批处理(Dynamic Batching)是提升大模型推理吞吐量的核心技术。与传统的静态批处理不同,动态批处理能够在运行时根据请求特性自适应地组织批次,充分利用 GPU 并行计算能力。本章深入讲解动态批处理的原理、实现策略以及在生产环境中的优化实践。
批处理基础
为什么需要批处理
GPU 的设计特点决定了批处理的必要性:
单请求处理 vs 批处理对比:
单请求:
┌─────────────────────────────────────────────────────────────┐
│ GPU 利用率: 10-20% │
│ ┌─────┐ │
│ │Req 1│────────────────────────────────────────────────────►│
│ └─────┘ │
│ 大量 GPU 核心空闲 │
└─────────────────────────────────────────────────────────────┘
批处理 (batch_size=8):
┌─────────────────────────────────────────────────────────────┐
│ GPU 利用率: 70-90% │
│ ┌─────┐ │
│ │Req 1│──┐ │
│ ├─────┤ │ │
│ │Req 2│ ├──────────────────────────────────────────────► │
│ ├─────┤ │ │
│ │ ... │ │ │
│ ├─────┤ │ │
│ │Req 8│──┘ │
│ └─────┘ │
│ GPU 核心充分利用 │
└─────────────────────────────────────────────────────────────┘
静态批处理的局限
package batching
import (
    "sync"
    "time"
)
// StaticBatcher 静态批处理器
// 问题:等待固定数量请求,延迟不可控
type StaticBatcher struct {
batchSize int
requests []*InferenceRequest
mu sync.Mutex
batchReady chan []*InferenceRequest
}
type InferenceRequest struct {
ID string
Prompt string
InputTokens []int32
MaxNewTokens int
Temperature float32
TopP float32
ResultChan chan *InferenceResult
ArrivalTime time.Time
}
type InferenceResult struct {
RequestID string
OutputTokens []int32
Text string
FinishReason string
Latency time.Duration
Error error
}
func NewStaticBatcher(batchSize int) *StaticBatcher {
return &StaticBatcher{
batchSize: batchSize,
requests: make([]*InferenceRequest, 0, batchSize),
batchReady: make(chan []*InferenceRequest, 10),
}
}
func (b *StaticBatcher) Submit(req *InferenceRequest) {
b.mu.Lock()
defer b.mu.Unlock()
req.ArrivalTime = time.Now()
b.requests = append(b.requests, req)
// 问题1: 必须等待足够请求
if len(b.requests) >= b.batchSize {
batch := make([]*InferenceRequest, len(b.requests))
copy(batch, b.requests)
b.requests = b.requests[:0]
b.batchReady <- batch
}
// 问题2: 低流量时请求可能永远等待
}
// 静态批处理的问题:
// 1. 序列长度不一致:短序列需要等待长序列完成
// 2. 延迟不可预测:取决于批次中最长序列
// 3. 资源利用不均:padding 浪费计算资源
// 4. 无法动态调整:批次大小固定
动态批处理架构
连续批处理 (Continuous Batching)
连续批处理是现代推理引擎的核心技术,允许请求在任意 token 位置加入或离开批次:
package batching
import (
"container/heap"
"context"
"sync"
"sync/atomic"
"time"
)
// ContinuousBatcher 连续批处理器
type ContinuousBatcher struct {
config *BatcherConfig
waitingQueue *PriorityQueue // 等待队列
runningBatch *RunningBatch // 运行中的批次
scheduler *BatchScheduler // 调度器
engine InferenceEngine // 推理引擎
mu sync.RWMutex
stats *BatcherStats
stopCh chan struct{}
}
type BatcherConfig struct {
MaxBatchSize int // 最大批次大小
MaxWaitingRequests int // 最大等待请求数
MaxTokensPerBatch int // 批次最大 token 数
MaxWaitTime time.Duration // 最大等待时间
PrefillChunkSize int // Prefill 分块大小
SchedulingPolicy string // 调度策略: fcfs, sjf, priority
}
type RunningBatch struct {
requests map[string]*ActiveRequest
totalTokens int64
mu sync.RWMutex
}
type ActiveRequest struct {
*InferenceRequest
State RequestState
GeneratedTokens []int32
KVCacheBlocks []int // KV Cache 块索引
CurrentStep int
StartTime time.Time
PrefillDone bool
}
type RequestState int
const (
StateWaiting RequestState = iota
StatePrefill
StateDecoding
StateFinished
)
func NewContinuousBatcher(config *BatcherConfig, engine InferenceEngine) *ContinuousBatcher {
cb := &ContinuousBatcher{
config: config,
waitingQueue: NewPriorityQueue(),
runningBatch: &RunningBatch{
requests: make(map[string]*ActiveRequest),
},
scheduler: NewBatchScheduler(config),
engine: engine,
stats: &BatcherStats{},
stopCh: make(chan struct{}),
}
return cb
}
// Submit 提交推理请求
func (cb *ContinuousBatcher) Submit(ctx context.Context, req *InferenceRequest) (*InferenceResult, error) {
req.ArrivalTime = time.Now()
req.ResultChan = make(chan *InferenceResult, 1)
// 检查队列容量
cb.mu.Lock()
if cb.waitingQueue.Len() >= cb.config.MaxWaitingRequests {
cb.mu.Unlock()
return nil, ErrQueueFull
}
// 加入等待队列
heap.Push(cb.waitingQueue, &QueueItem{
Request: req,
Priority: cb.calculatePriority(req),
})
cb.mu.Unlock()
atomic.AddInt64(&cb.stats.TotalRequests, 1)
// 等待结果
select {
case result := <-req.ResultChan:
return result, nil
case <-ctx.Done():
cb.cancelRequest(req.ID)
return nil, ctx.Err()
}
}
// Run 运行批处理循环
func (cb *ContinuousBatcher) Run(ctx context.Context) error {
ticker := time.NewTicker(time.Millisecond) // 1ms 调度间隔
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-cb.stopCh:
return nil
case <-ticker.C:
cb.step()
}
}
}
// step 执行一个调度步骤
func (cb *ContinuousBatcher) step() {
cb.mu.Lock()
defer cb.mu.Unlock()
// 1. 移除已完成的请求
cb.evictFinishedRequests()
// 2. 调度新请求加入批次
cb.scheduleNewRequests()
// 3. 如果批次非空,执行一步推理
if cb.runningBatch.Size() > 0 {
cb.executeStep()
}
}
// evictFinishedRequests 移除已完成的请求
func (cb *ContinuousBatcher) evictFinishedRequests() {
cb.runningBatch.mu.Lock()
defer cb.runningBatch.mu.Unlock()
for id, req := range cb.runningBatch.requests {
if req.State == StateFinished {
// 释放 KV Cache
cb.engine.ReleaseKVCache(req.KVCacheBlocks)
// 发送结果
result := &InferenceResult{
RequestID: id,
OutputTokens: req.GeneratedTokens,
Text: cb.engine.Decode(req.GeneratedTokens),
FinishReason: cb.getFinishReason(req),
Latency: time.Since(req.StartTime),
}
req.ResultChan <- result
delete(cb.runningBatch.requests, id)
// 归还该请求预留的 token 配额,否则 totalTokens 只增不减,批次会逐渐无法接纳新请求
cb.runningBatch.totalTokens -= int64(len(req.InputTokens) + req.MaxNewTokens)
atomic.AddInt64(&cb.stats.CompletedRequests, 1)
}
}
}
// scheduleNewRequests 调度新请求
func (cb *ContinuousBatcher) scheduleNewRequests() {
for cb.waitingQueue.Len() > 0 {
// 检查批次容量
if cb.runningBatch.Size() >= cb.config.MaxBatchSize {
break
}
// 检查 token 容量
item := cb.waitingQueue.Peek()
requiredTokens := len(item.Request.InputTokens) + item.Request.MaxNewTokens
if cb.runningBatch.totalTokens+int64(requiredTokens) > int64(cb.config.MaxTokensPerBatch) {
break
}
// 分配 KV Cache
blocks, err := cb.engine.AllocateKVCache(requiredTokens)
if err != nil {
break // KV Cache 不足
}
// 从等待队列取出
heap.Pop(cb.waitingQueue)
// 创建活跃请求
activeReq := &ActiveRequest{
InferenceRequest: item.Request,
State: StatePrefill,
GeneratedTokens: make([]int32, 0, item.Request.MaxNewTokens),
KVCacheBlocks: blocks,
StartTime: time.Now(),
}
cb.runningBatch.mu.Lock()
cb.runningBatch.requests[item.Request.ID] = activeReq
cb.runningBatch.totalTokens += int64(requiredTokens)
cb.runningBatch.mu.Unlock()
atomic.AddInt64(&cb.stats.ScheduledRequests, 1)
}
}
// executeStep 执行一步推理
func (cb *ContinuousBatcher) executeStep() {
cb.runningBatch.mu.Lock()
defer cb.runningBatch.mu.Unlock()
// 分离 prefill 和 decode 请求
prefillReqs := make([]*ActiveRequest, 0)
decodeReqs := make([]*ActiveRequest, 0)
for _, req := range cb.runningBatch.requests {
if req.State == StatePrefill {
prefillReqs = append(prefillReqs, req)
} else if req.State == StateDecoding {
decodeReqs = append(decodeReqs, req)
}
}
// 执行 Prefill(分块处理)
for _, req := range prefillReqs {
cb.executePrefill(req)
}
// 执行 Decode
if len(decodeReqs) > 0 {
cb.executeDecode(decodeReqs)
}
}
// executePrefill 执行 Prefill 阶段
func (cb *ContinuousBatcher) executePrefill(req *ActiveRequest) {
inputLen := len(req.InputTokens)
chunkSize := cb.config.PrefillChunkSize
// 分块 Prefill,避免长序列阻塞
start := req.CurrentStep
end := min(start+chunkSize, inputLen)
chunk := req.InputTokens[start:end]
cb.engine.Prefill(req.KVCacheBlocks, chunk, start)
req.CurrentStep = end
if req.CurrentStep >= inputLen {
req.PrefillDone = true
req.State = StateDecoding
}
}
// executeDecode 执行 Decode 阶段
func (cb *ContinuousBatcher) executeDecode(reqs []*ActiveRequest) {
// 构建批处理输入
batchInput := &BatchDecodeInput{
RequestIDs: make([]string, len(reqs)),
LastTokens: make([]int32, len(reqs)),
KVCacheBlocks: make([][]int, len(reqs)),
Positions: make([]int, len(reqs)),
}
for i, req := range reqs {
batchInput.RequestIDs[i] = req.ID
if len(req.GeneratedTokens) == 0 {
batchInput.LastTokens[i] = req.InputTokens[len(req.InputTokens)-1]
} else {
batchInput.LastTokens[i] = req.GeneratedTokens[len(req.GeneratedTokens)-1]
}
batchInput.KVCacheBlocks[i] = req.KVCacheBlocks
batchInput.Positions[i] = len(req.InputTokens) + len(req.GeneratedTokens)
}
// 批量解码
outputs := cb.engine.BatchDecode(batchInput)
// 处理输出
for i, req := range reqs {
token := outputs.NextTokens[i]
req.GeneratedTokens = append(req.GeneratedTokens, token)
// 检查是否完成
if cb.isFinished(req, token) {
req.State = StateFinished
}
}
}
func (cb *ContinuousBatcher) isFinished(req *ActiveRequest, token int32) bool {
// EOS token
if token == cb.engine.GetEOSToken() {
return true
}
// 达到最大长度
if len(req.GeneratedTokens) >= req.MaxNewTokens {
return true
}
return false
}
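上面的 ContinuousBatcher 引用了若干正文未给出定义的辅助类型和方法(InferenceEngine、BatcherStats、ErrQueueFull、BatchDecodeInput/Output、RunningBatch.Size、getFinishReason、cancelRequest)。下面是一份假设性的最小草图,仅用于说明这些接口的大致形状;其中 ForwardMixed 供后文 Chunked Prefill 的混合批次使用,相关输入输出类型见该节之后的补充定义,具体签名应以实际推理引擎为准:
package batching
import (
    "container/heap"
    "context"
    "errors"
)
// ErrQueueFull 等待队列已满
var ErrQueueFull = errors.New("waiting queue is full")
// BatcherStats 批处理统计(字段通过 atomic 更新)
type BatcherStats struct {
    TotalRequests     int64
    CompletedRequests int64
    ScheduledRequests int64
}
// InferenceEngine 推理引擎接口(假设的最小形状)
type InferenceEngine interface {
    AllocateKVCache(numTokens int) ([]int, error)
    ReleaseKVCache(blocks []int)
    Prefill(blocks []int, tokens []int32, startPos int)
    BatchDecode(input *BatchDecodeInput) *BatchDecodeOutput
    ForwardMixed(ctx context.Context, input *MixedBatchInput) (*MixedBatchOutput, error)
    Decode(tokens []int32) string
    GetEOSToken() int32
}
// BatchDecodeInput / BatchDecodeOutput 批量 decode 的输入输出
type BatchDecodeInput struct {
    RequestIDs    []string
    LastTokens    []int32
    KVCacheBlocks [][]int
    Positions     []int
}
type BatchDecodeOutput struct {
    NextTokens []int32
}
// Size 返回运行批次中的请求数
func (rb *RunningBatch) Size() int {
    rb.mu.RLock()
    defer rb.mu.RUnlock()
    return len(rb.requests)
}
// getFinishReason 返回请求的结束原因
func (cb *ContinuousBatcher) getFinishReason(req *ActiveRequest) string {
    if len(req.GeneratedTokens) >= req.MaxNewTokens {
        return "length"
    }
    return "stop"
}
// cancelRequest 取消请求:仍在等待队列中则直接移除,
// 已进入运行批次则标记为完成,由下一轮调度释放其资源
func (cb *ContinuousBatcher) cancelRequest(id string) {
    cb.mu.Lock()
    defer cb.mu.Unlock()
    for i, item := range *cb.waitingQueue {
        if item.Request.ID == id {
            heap.Remove(cb.waitingQueue, i)
            return
        }
    }
    cb.runningBatch.mu.Lock()
    if req, ok := cb.runningBatch.requests[id]; ok {
        req.State = StateFinished
    }
    cb.runningBatch.mu.Unlock()
}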
优先级队列实现
package batching
import "container/heap"
// QueueItem 队列项
type QueueItem struct {
Request *InferenceRequest
Priority float64
Index int
}
// PriorityQueue 优先级队列
type PriorityQueue []*QueueItem
func NewPriorityQueue() *PriorityQueue {
pq := make(PriorityQueue, 0)
heap.Init(&pq)
return &pq
}
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
// 高优先级在前
return pq[i].Priority > pq[j].Priority
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].Index = i
pq[j].Index = j
}
func (pq *PriorityQueue) Push(x interface{}) {
n := len(*pq)
item := x.(*QueueItem)
item.Index = n
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
old[n-1] = nil
item.Index = -1
*pq = old[0 : n-1]
return item
}
func (pq *PriorityQueue) Peek() *QueueItem {
if len(*pq) == 0 {
return nil
}
return (*pq)[0]
}
// calculatePriority 计算请求优先级
func (cb *ContinuousBatcher) calculatePriority(req *InferenceRequest) float64 {
switch cb.config.SchedulingPolicy {
case "fcfs":
// 先来先服务:按到达时间排序
return -float64(req.ArrivalTime.UnixNano())
case "sjf":
// 短作业优先:输入+输出 token 数少的优先
totalTokens := len(req.InputTokens) + req.MaxNewTokens
return -float64(totalTokens)
case "priority":
// 优先级调度:结合多个因素
return cb.calculateCompositePriority(req)
default:
return -float64(req.ArrivalTime.UnixNano())
}
}
func (cb *ContinuousBatcher) calculateCompositePriority(req *InferenceRequest) float64 {
// 基于多因素的优先级计算
// 1. 等待时间因子
waitTime := time.Since(req.ArrivalTime).Seconds()
waitFactor := waitTime * 0.1 // 每秒增加 0.1 优先级
// 2. 序列长度因子(短序列优先)
totalTokens := float64(len(req.InputTokens) + req.MaxNewTokens)
lengthFactor := 1000.0 / totalTokens
// 3. 用户优先级(如有)
userPriority := 1.0 // 默认优先级
return waitFactor + lengthFactor + userPriority
}
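下面用一个小示例演示该优先级队列在 SJF 策略下的出队顺序(demoSJFOrdering 为假设的演示函数,其中的请求与优先级数值均为示意):
package batching
import (
    "container/heap"
    "fmt"
)
// demoSJFOrdering 演示 SJF 策略下短请求先出队
func demoSJFOrdering() {
    pq := NewPriorityQueue()
    long := &InferenceRequest{ID: "long", InputTokens: make([]int32, 1000), MaxNewTokens: 256}
    short := &InferenceRequest{ID: "short", InputTokens: make([]int32, 50), MaxNewTokens: 64}
    // SJF:优先级 = -(输入 + 输出 token 数),token 越少优先级越高
    heap.Push(pq, &QueueItem{Request: long, Priority: -float64(1000 + 256)})
    heap.Push(pq, &QueueItem{Request: short, Priority: -float64(50 + 64)})
    first := heap.Pop(pq).(*QueueItem)
    fmt.Println(first.Request.ID) // 输出: short
}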
迭代级调度 (Iteration-Level Scheduling)
细粒度调度策略
迭代级调度允许在每个 decode 迭代中动态调整批次组成:
package batching
import (
    "sort"
    "time"
)
// IterationScheduler 迭代级调度器
type IterationScheduler struct {
config *SchedulerConfig
preemptable bool // 是否支持抢占
swapManager *SwapManager // 交换管理器
}
type SchedulerConfig struct {
MaxBatchTokens int // 每批次最大 token
MaxPrefillTokens int // Prefill 最大 token
MaxNumSeqs int // 最大并发序列
PreemptionMode string // 抢占模式: recompute, swap
SwapSpaceGB float64 // 交换空间大小
}
// SchedulerOutput 调度输出
type SchedulerOutput struct {
ScheduledSeqs []*SequenceGroup // 调度的序列组
PreemptedSeqs []*SequenceGroup // 被抢占的序列组
BlocksToSwapIn map[int]int // 需要换入的块
BlocksToSwapOut map[int]int // 需要换出的块
BlocksToCopy map[int]int // 需要复制的块
}
// SequenceGroup 序列组(支持 beam search)
type SequenceGroup struct {
RequestID string
Sequences []*Sequence
SamplingParams *SamplingParams
ArrivalTime time.Time
State SequenceGroupState
}
type SequenceGroupState int
const (
SGStateWaiting SequenceGroupState = iota
SGStateRunning
SGStateSwapped
SGStateFinished
)
type Sequence struct {
SeqID int
TokenIDs []int32
LogicalBlocks []int // 逻辑块索引
OutputLen int
Status SequenceStatus
}
type SequenceStatus int
const (
SeqWaiting SequenceStatus = iota
SeqRunning
SeqFinishedStopped
SeqFinishedLengthCapped
SeqFinishedAborted
)
// Schedule 执行一次调度
func (s *IterationScheduler) Schedule(
waitingQueue []*SequenceGroup,
runningQueue []*SequenceGroup,
swappedQueue []*SequenceGroup,
blockManager *BlockManager,
) *SchedulerOutput {
output := &SchedulerOutput{
ScheduledSeqs: make([]*SequenceGroup, 0),
PreemptedSeqs: make([]*SequenceGroup, 0),
BlocksToSwapIn: make(map[int]int),
BlocksToSwapOut: make(map[int]int),
BlocksToCopy: make(map[int]int),
}
// 阶段1: 处理正在运行的序列
s.scheduleRunning(runningQueue, blockManager, output)
// 阶段2: 处理被交换的序列
s.scheduleSwapped(swappedQueue, blockManager, output)
// 阶段3: 处理等待队列
s.scheduleWaiting(waitingQueue, blockManager, output)
return output
}
// scheduleRunning 调度运行中的序列
func (s *IterationScheduler) scheduleRunning(
runningQueue []*SequenceGroup,
blockManager *BlockManager,
output *SchedulerOutput,
) {
// 按到达时间排序(FCFS)
sort.Slice(runningQueue, func(i, j int) bool {
return runningQueue[i].ArrivalTime.Before(runningQueue[j].ArrivalTime)
})
budgetTokens := s.config.MaxBatchTokens
budgetSeqs := s.config.MaxNumSeqs
for _, seqGroup := range runningQueue {
if seqGroup.State == SGStateFinished {
continue
}
// 计算需要的资源
numTokens := s.getNumTokens(seqGroup)
numSeqs := len(seqGroup.Sequences)
// 检查资源是否足够
if numTokens > budgetTokens || numSeqs > budgetSeqs {
// 需要抢占
if s.preemptable {
s.preempt(seqGroup, blockManager, output)
}
continue
}
// 尝试为新 token 分配块
if !s.allocateNewBlocks(seqGroup, blockManager) {
// 分配失败,需要抢占其他序列
if s.preemptable {
s.preemptLowerPriority(seqGroup, runningQueue, blockManager, output)
}
continue
}
// 调度成功
output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
budgetTokens -= numTokens
budgetSeqs -= numSeqs
}
}
// scheduleSwapped 调度被交换的序列
func (s *IterationScheduler) scheduleSwapped(
swappedQueue []*SequenceGroup,
blockManager *BlockManager,
output *SchedulerOutput,
) {
// 尝试换入被交换的序列
for _, seqGroup := range swappedQueue {
// 计算需要的 GPU 块
requiredBlocks := s.getRequiredBlocks(seqGroup)
// 检查是否有足够的空闲块
if !blockManager.CanAllocate(requiredBlocks) {
continue
}
// 执行换入
swapIn := s.swapIn(seqGroup, blockManager)
for src, dst := range swapIn {
output.BlocksToSwapIn[src] = dst
}
seqGroup.State = SGStateRunning
output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
}
}
// scheduleWaiting 调度等待队列
func (s *IterationScheduler) scheduleWaiting(
waitingQueue []*SequenceGroup,
blockManager *BlockManager,
output *SchedulerOutput,
) {
budgetTokens := s.config.MaxPrefillTokens
for _, seqGroup := range waitingQueue {
// 计算 prefill 需要的 token 数
prefillTokens := s.getPrefillTokens(seqGroup)
if prefillTokens > budgetTokens {
// 超出 prefill 预算
break
}
// 尝试分配初始块
requiredBlocks := (prefillTokens + blockManager.blockSize - 1) / blockManager.blockSize
if !blockManager.CanAllocate(requiredBlocks) {
// 块不足,尝试抢占
if s.preemptable && s.preemptForPrefill(seqGroup, blockManager, output) {
// 抢占成功,重试分配
if !blockManager.CanAllocate(requiredBlocks) {
continue
}
} else {
continue
}
}
// 分配块
blocks := blockManager.Allocate(requiredBlocks)
s.assignBlocks(seqGroup, blocks)
seqGroup.State = SGStateRunning
output.ScheduledSeqs = append(output.ScheduledSeqs, seqGroup)
budgetTokens -= prefillTokens
}
}
// preempt 抢占序列
func (s *IterationScheduler) preempt(
seqGroup *SequenceGroup,
blockManager *BlockManager,
output *SchedulerOutput,
) {
switch s.config.PreemptionMode {
case "recompute":
// 重计算模式:释放所有块,稍后重新 prefill
for _, seq := range seqGroup.Sequences {
blockManager.Free(seq.LogicalBlocks)
seq.LogicalBlocks = nil
}
seqGroup.State = SGStateWaiting
case "swap":
// 交换模式:将块换出到 CPU
swapOut := s.swapOut(seqGroup, blockManager)
for src, dst := range swapOut {
output.BlocksToSwapOut[src] = dst
}
seqGroup.State = SGStateSwapped
}
output.PreemptedSeqs = append(output.PreemptedSeqs, seqGroup)
}
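调度器引用的 getNumTokens、getPrefillTokens、getRequiredBlocks、allocateNewBlocks、assignBlocks 等辅助方法正文未给出。下面是一份假设性的草图,示意它们的大致语义;swapIn/swapOut、preemptLowerPriority、preemptForPrefill 可按同样思路实现,此处从略:
// SamplingParams 采样参数(占位定义,字段按需扩展)
type SamplingParams struct {
    Temperature float32
    TopP        float32
    MaxTokens   int
}
// SwapManager 记录 GPU/CPU 块交换关系(占位定义)
type SwapManager struct{}
// getNumTokens 返回该序列组在一次 decode 迭代中处理的 token 数
// (decode 阶段每个序列一步只生成 1 个 token)
func (s *IterationScheduler) getNumTokens(sg *SequenceGroup) int {
    return len(sg.Sequences)
}
// getPrefillTokens 返回 prefill 阶段需要处理的 prompt token 数
func (s *IterationScheduler) getPrefillTokens(sg *SequenceGroup) int {
    if len(sg.Sequences) == 0 {
        return 0
    }
    return len(sg.Sequences[0].TokenIDs)
}
// getRequiredBlocks 估算换入/运行该序列组所需的 GPU 块数
func (s *IterationScheduler) getRequiredBlocks(sg *SequenceGroup) int {
    total := 0
    for _, seq := range sg.Sequences {
        total += len(seq.LogicalBlocks)
    }
    return total
}
// allocateNewBlocks 在新 token 跨越块边界时为序列追加物理块
func (s *IterationScheduler) allocateNewBlocks(sg *SequenceGroup, bm *BlockManager) bool {
    for _, seq := range sg.Sequences {
        if len(seq.TokenIDs)%bm.blockSize == 0 { // 当前块已写满
            if !bm.CanAllocate(1) {
                return false
            }
            seq.LogicalBlocks = append(seq.LogicalBlocks, bm.Allocate(1)...)
        }
    }
    return true
}
// assignBlocks 将新分配的块登记到序列组上
func (s *IterationScheduler) assignBlocks(sg *SequenceGroup, blocks []int) {
    if len(sg.Sequences) > 0 {
        sg.Sequences[0].LogicalBlocks = append(sg.Sequences[0].LogicalBlocks, blocks...)
    }
}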
块管理器
package batching
import (
"errors"
"sync"
)
// BlockManager GPU 内存块管理器
type BlockManager struct {
blockSize int // 每块的 token 数
numGPUBlocks int // GPU 块总数
numCPUBlocks int // CPU 块总数
gpuAllocator *BlockAllocator
cpuAllocator *BlockAllocator
// 块映射:逻辑块 -> 物理块
blockTables map[int]map[int]int // seqID -> logicalBlock -> physicalBlock
refCounts map[int]int // physicalBlock -> refCount
mu sync.RWMutex
}
type BlockAllocator struct {
freeBlocks []int
usedBlocks map[int]bool
totalBlocks int
mu sync.Mutex
}
func NewBlockManager(config *BlockManagerConfig) *BlockManager {
    // 每个 KV Cache 块的字节数:
    // BlockSize(token 数)× 层数 × KV 头数 × HeadDim × 2(K 和 V)× 2 字节(fp16)
    bytesPerBlock := config.BlockSize * config.NumLayers * config.NumKVHeads * config.HeadDim * 2 * 2
    gpuBlocks := config.GPUMemoryGB * 1024 * 1024 * 1024 / bytesPerBlock
    cpuBlocks := config.CPUSwapSpaceGB * 1024 * 1024 * 1024 / bytesPerBlock
    return &BlockManager{
        blockSize:    config.BlockSize,
        numGPUBlocks: gpuBlocks,
        numCPUBlocks: cpuBlocks,
        gpuAllocator: NewBlockAllocator(gpuBlocks),
        cpuAllocator: NewBlockAllocator(cpuBlocks),
        blockTables:  make(map[int]map[int]int),
        refCounts:    make(map[int]int),
    }
}
func NewBlockAllocator(numBlocks int) *BlockAllocator {
freeBlocks := make([]int, numBlocks)
for i := 0; i < numBlocks; i++ {
freeBlocks[i] = i
}
return &BlockAllocator{
freeBlocks: freeBlocks,
usedBlocks: make(map[int]bool),
totalBlocks: numBlocks,
}
}
// CanAllocate 检查是否可以分配
func (bm *BlockManager) CanAllocate(numBlocks int) bool {
bm.mu.RLock()
defer bm.mu.RUnlock()
return len(bm.gpuAllocator.freeBlocks) >= numBlocks
}
// Allocate 分配块
func (bm *BlockManager) Allocate(numBlocks int) []int {
bm.mu.Lock()
defer bm.mu.Unlock()
if len(bm.gpuAllocator.freeBlocks) < numBlocks {
return nil
}
blocks := make([]int, numBlocks)
for i := 0; i < numBlocks; i++ {
block := bm.gpuAllocator.freeBlocks[0]
bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
bm.gpuAllocator.usedBlocks[block] = true
bm.refCounts[block] = 1
blocks[i] = block
}
return blocks
}
// Free 释放块
func (bm *BlockManager) Free(blocks []int) {
bm.mu.Lock()
defer bm.mu.Unlock()
for _, block := range blocks {
bm.refCounts[block]--
if bm.refCounts[block] <= 0 {
delete(bm.gpuAllocator.usedBlocks, block)
bm.gpuAllocator.freeBlocks = append(bm.gpuAllocator.freeBlocks, block)
delete(bm.refCounts, block)
}
}
}
// AppendSlot 追加一个 slot(用于 decode)
func (bm *BlockManager) AppendSlot(seqID int) (int, int, error) {
bm.mu.Lock()
defer bm.mu.Unlock()
blockTable, ok := bm.blockTables[seqID]
if !ok {
return -1, -1, errors.New("sequence not found")
}
// 简化处理:这里总是分配一个新块;实际实现需要跟踪最后一个物理块
// 已使用的 slot 数,只有在其写满时才分配新块
numLogicalBlocks := len(blockTable)
if len(bm.gpuAllocator.freeBlocks) == 0 {
return -1, -1, errors.New("no free blocks")
}
newBlock := bm.gpuAllocator.freeBlocks[0]
bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
bm.gpuAllocator.usedBlocks[newBlock] = true
bm.refCounts[newBlock] = 1
blockTable[numLogicalBlocks] = newBlock
return newBlock, 0, nil // 返回块索引和块内偏移
}
// ForkSequence 分叉序列(用于 beam search)
func (bm *BlockManager) ForkSequence(parentSeqID, childSeqID int) error {
bm.mu.Lock()
defer bm.mu.Unlock()
parentTable, ok := bm.blockTables[parentSeqID]
if !ok {
return errors.New("parent sequence not found")
}
// 创建子序列的块表,使用 copy-on-write
childTable := make(map[int]int)
for logical, physical := range parentTable {
childTable[logical] = physical
bm.refCounts[physical]++ // 增加引用计数
}
bm.blockTables[childSeqID] = childTable
return nil
}
// CopyOnWrite 写时复制
func (bm *BlockManager) CopyOnWrite(seqID, logicalBlock int) (int, error) {
bm.mu.Lock()
defer bm.mu.Unlock()
blockTable := bm.blockTables[seqID]
physicalBlock := blockTable[logicalBlock]
// 如果只有一个引用,不需要复制
if bm.refCounts[physicalBlock] == 1 {
return physicalBlock, nil
}
// 分配新块
if len(bm.gpuAllocator.freeBlocks) == 0 {
return -1, errors.New("no free blocks for copy")
}
newBlock := bm.gpuAllocator.freeBlocks[0]
bm.gpuAllocator.freeBlocks = bm.gpuAllocator.freeBlocks[1:]
bm.gpuAllocator.usedBlocks[newBlock] = true
bm.refCounts[newBlock] = 1
// 减少原块引用
bm.refCounts[physicalBlock]--
// 更新块表
blockTable[logicalBlock] = newBlock
return newBlock, nil
}
// GetBlockTable 获取序列的块表
func (bm *BlockManager) GetBlockTable(seqID int) []int {
bm.mu.RLock()
defer bm.mu.RUnlock()
table := bm.blockTables[seqID]
result := make([]int, len(table))
for logical, physical := range table {
result[logical] = physical
}
return result
}
// GetFreeBlockCount 获取空闲块数量
func (bm *BlockManager) GetFreeBlockCount() int {
bm.mu.RLock()
defer bm.mu.RUnlock()
return len(bm.gpuAllocator.freeBlocks)
}
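NewBlockManager 用到的 BlockManagerConfig 未在正文定义;此外 ForkSequence/CopyOnWrite 要求先登记序列的块表。下面给出假设性的配置结构、一个登记辅助方法,以及 fork + 写时复制的使用示意(RegisterSequence 与 demoForkAndCopyOnWrite 均为假设的补充):
package batching
import "errors"
// BlockManagerConfig 块管理器配置(假设的字段定义,单位见注释)
type BlockManagerConfig struct {
    BlockSize      int // 每块可容纳的 token 数,例如 16
    NumLayers      int // Transformer 层数
    NumKVHeads     int // KV 注意力头数(GQA/MQA 下小于 Query 头数)
    HeadDim        int // 每个注意力头的维度
    GPUMemoryGB    int // 预留给 KV Cache 的显存(GB)
    CPUSwapSpaceGB int // CPU 交换空间(GB)
}
// RegisterSequence 登记一个新序列的块表(假设的辅助方法)
func (bm *BlockManager) RegisterSequence(seqID int, blocks []int) {
    bm.mu.Lock()
    defer bm.mu.Unlock()
    table := make(map[int]int, len(blocks))
    for logical, physical := range blocks {
        table[logical] = physical
    }
    bm.blockTables[seqID] = table
}
// demoForkAndCopyOnWrite 演示 beam search 场景下的 fork + 写时复制
func demoForkAndCopyOnWrite(bm *BlockManager) error {
    // 父序列占用 4 个物理块
    parentBlocks := bm.Allocate(4)
    if parentBlocks == nil {
        return errors.New("not enough free blocks")
    }
    bm.RegisterSequence(1, parentBlocks)
    // fork 出子序列:共享物理块,仅增加引用计数
    if err := bm.ForkSequence(1, 2); err != nil {
        return err
    }
    // 子序列写入最后一个逻辑块时触发写时复制,获得独立的物理块
    newBlock, err := bm.CopyOnWrite(2, 3)
    if err != nil {
        return err
    }
    _ = newBlock
    return nil
}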
Chunked Prefill
分块预填充实现
Chunked Prefill 将长序列的 prefill 阶段分成多个小块,避免阻塞 decode:
package batching
import (
"context"
"sync"
)
// ChunkedPrefillScheduler 分块 Prefill 调度器
type ChunkedPrefillScheduler struct {
config *ChunkedPrefillConfig
prefillQueue []*PrefillChunk
decodeQueue []*DecodeRequest
mu sync.Mutex
}
type ChunkedPrefillConfig struct {
MaxPrefillTokens int // 每次调度最大 prefill token
MaxDecodeTokens int // 每次调度最大 decode token
ChunkSize int // Prefill 块大小
PrefillDecodeRatio float64 // Prefill:Decode token 比例
}
type PrefillChunk struct {
SeqGroup *SequenceGroup
StartPos int
EndPos int
IsLast bool
}
type DecodeRequest struct {
SeqGroup *SequenceGroup
Position int
}
// ScheduleMixed 混合调度 Prefill 和 Decode
func (s *ChunkedPrefillScheduler) ScheduleMixed(
waitingQueue []*SequenceGroup,
runningQueue []*SequenceGroup,
blockManager *BlockManager,
) *MixedScheduleOutput {
s.mu.Lock()
defer s.mu.Unlock()
output := &MixedScheduleOutput{
PrefillChunks: make([]*PrefillChunk, 0),
DecodeSeqs: make([]*SequenceGroup, 0),
}
prefillBudget := s.config.MaxPrefillTokens
decodeBudget := s.config.MaxDecodeTokens
// 优先处理已在运行的 decode 请求
for _, seqGroup := range runningQueue {
if seqGroup.State != SGStateRunning {
continue
}
// 检查是否还在 prefill
if !s.isPrefillComplete(seqGroup) {
// 分配下一个 prefill chunk
chunk := s.getNextPrefillChunk(seqGroup)
chunkTokens := chunk.EndPos - chunk.StartPos
if chunkTokens <= prefillBudget {
output.PrefillChunks = append(output.PrefillChunks, chunk)
prefillBudget -= chunkTokens
}
continue
}
// Decode 请求
if decodeBudget > 0 {
output.DecodeSeqs = append(output.DecodeSeqs, seqGroup)
decodeBudget--
}
}
// 处理等待队列中的新请求
for _, seqGroup := range waitingQueue {
// 计算首个 prefill chunk
promptLen := s.getPromptLength(seqGroup)
chunkSize := min(promptLen, s.config.ChunkSize)
if chunkSize > prefillBudget {
break
}
// 尝试分配初始块
requiredBlocks := (chunkSize + blockManager.blockSize - 1) / blockManager.blockSize
if !blockManager.CanAllocate(requiredBlocks) {
continue
}
blocks := blockManager.Allocate(requiredBlocks)
s.assignBlocks(seqGroup, blocks)
chunk := &PrefillChunk{
SeqGroup: seqGroup,
StartPos: 0,
EndPos: chunkSize,
IsLast: chunkSize >= promptLen,
}
output.PrefillChunks = append(output.PrefillChunks, chunk)
prefillBudget -= chunkSize
seqGroup.State = SGStateRunning
}
return output
}
type MixedScheduleOutput struct {
PrefillChunks []*PrefillChunk
DecodeSeqs []*SequenceGroup
}
// ExecuteMixed 执行混合批次
func (s *ChunkedPrefillScheduler) ExecuteMixed(
ctx context.Context,
output *MixedScheduleOutput,
engine InferenceEngine,
) error {
// 构建混合批次输入
input := &MixedBatchInput{
PrefillInputs: make([]*PrefillInput, 0),
DecodeInputs: make([]*DecodeInput, 0),
}
// 添加 Prefill 输入
for _, chunk := range output.PrefillChunks {
seq := chunk.SeqGroup.Sequences[0]
prefillInput := &PrefillInput{
SeqID: seq.SeqID,
TokenIDs: seq.TokenIDs[chunk.StartPos:chunk.EndPos],
StartPos: chunk.StartPos,
BlockTable: s.getBlockTable(chunk.SeqGroup),
}
input.PrefillInputs = append(input.PrefillInputs, prefillInput)
}
// 添加 Decode 输入
for _, seqGroup := range output.DecodeSeqs {
for _, seq := range seqGroup.Sequences {
decodeInput := &DecodeInput{
SeqID: seq.SeqID,
LastToken: seq.TokenIDs[len(seq.TokenIDs)-1],
Position: len(seq.TokenIDs),
BlockTable: s.getBlockTable(seqGroup),
}
input.DecodeInputs = append(input.DecodeInputs, decodeInput)
}
}
// 执行混合前向传播
result, err := engine.ForwardMixed(ctx, input)
if err != nil {
return err
}
// 处理 Prefill 结果
for i, chunk := range output.PrefillChunks {
if chunk.IsLast {
// Prefill 完成,获取首个生成 token
token := result.PrefillOutputs[i].FirstToken
seq := chunk.SeqGroup.Sequences[0]
seq.TokenIDs = append(seq.TokenIDs, token)
}
// 更新进度
s.updatePrefillProgress(chunk.SeqGroup, chunk.EndPos)
}
// 处理 Decode 结果
decodeIdx := 0
for _, seqGroup := range output.DecodeSeqs {
for _, seq := range seqGroup.Sequences {
token := result.DecodeOutputs[decodeIdx].NextToken
seq.TokenIDs = append(seq.TokenIDs, token)
decodeIdx++
// 检查是否完成
if s.isFinished(seq, token) {
seq.Status = SeqFinishedStopped
}
}
// 更新序列组状态
if s.allSequencesFinished(seqGroup) {
seqGroup.State = SGStateFinished
}
}
return nil
}
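ScheduleMixed/ExecuteMixed 还依赖若干正文未定义的类型和方法:混合批次的输入输出结构,以及 prefill 进度跟踪(getPromptLength、isPrefillComplete、getNextPrefillChunk、updatePrefillProgress)。下面是一份假设性的草图;其中假定 ChunkedPrefillScheduler 额外持有 promptLens、progress 两个 map 字段(构造时初始化),assignBlocks、getBlockTable、isFinished、allSequencesFinished 等可参照前文 BlockManager 与 ContinuousBatcher 的同名逻辑实现,此处从略:
// 混合批次的输入输出(按 ExecuteMixed 的用法推断的假设定义)
type MixedBatchInput struct {
    PrefillInputs []*PrefillInput
    DecodeInputs  []*DecodeInput
}
type PrefillInput struct {
    SeqID      int
    TokenIDs   []int32
    StartPos   int
    BlockTable []int
}
type DecodeInput struct {
    SeqID      int
    LastToken  int32
    Position   int
    BlockTable []int
}
type PrefillOutput struct {
    FirstToken int32 // 仅当该 chunk 是最后一块时有效
}
type DecodeOutput struct {
    NextToken int32
}
type MixedBatchOutput struct {
    PrefillOutputs []PrefillOutput
    DecodeOutputs  []DecodeOutput
}
// 以下方法假定调度器维护:
//   promptLens map[string]int // RequestID -> prompt 长度(首次调度时记录)
//   progress   map[string]int // RequestID -> 已完成 prefill 的 token 位置
func (s *ChunkedPrefillScheduler) getPromptLength(sg *SequenceGroup) int {
    if n, ok := s.promptLens[sg.RequestID]; ok {
        return n
    }
    n := len(sg.Sequences[0].TokenIDs) // 首次调度时 TokenIDs 只含 prompt
    s.promptLens[sg.RequestID] = n
    return n
}
func (s *ChunkedPrefillScheduler) isPrefillComplete(sg *SequenceGroup) bool {
    return s.progress[sg.RequestID] >= s.getPromptLength(sg)
}
func (s *ChunkedPrefillScheduler) getNextPrefillChunk(sg *SequenceGroup) *PrefillChunk {
    start := s.progress[sg.RequestID]
    end := min(start+s.config.ChunkSize, s.getPromptLength(sg))
    return &PrefillChunk{
        SeqGroup: sg,
        StartPos: start,
        EndPos:   end,
        IsLast:   end >= s.getPromptLength(sg),
    }
}
func (s *ChunkedPrefillScheduler) updatePrefillProgress(sg *SequenceGroup, pos int) {
    s.progress[sg.RequestID] = pos
}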
自适应分块策略
package batching
import "math"
// AdaptiveChunker 自适应分块器
type AdaptiveChunker struct {
config *AdaptiveConfig
latencyHistory *LatencyHistory
throughputTarget float64
}
type AdaptiveConfig struct {
    MinChunkSize      int
    MaxChunkSize      int
    TargetLatencyMs   float64
    MaxDecodeLoad     int // 最大并发 decode 序列数,用于计算负载因子
    LatencyWindowSize int
}
type LatencyHistory struct {
prefillLatencies []float64 // ms/token
decodeLatencies []float64 // ms/token
windowSize int
}
// OptimalChunkSize 计算最优分块大小
func (ac *AdaptiveChunker) OptimalChunkSize(
promptLength int,
currentDecodeLoad int,
) int {
// 获取历史延迟统计
avgPrefillLatency := ac.latencyHistory.AveragePrefillLatency()
avgDecodeLatency := ac.latencyHistory.AverageDecodeLatency()
// 目标:保证 decode 请求延迟不超过阈值
// decode 延迟 = prefill 时间 + decode 时间
// prefill 时间 = chunk_size * prefill_latency
// decode 等待时间 = chunk_size * prefill_latency
// 计算 decode 可接受的等待时间
targetWaitTime := ac.config.TargetLatencyMs - avgDecodeLatency
if targetWaitTime <= 0 {
return ac.config.MinChunkSize
}
// 计算最大可接受的 chunk size
maxChunk := int(targetWaitTime / avgPrefillLatency)
// 考虑当前 decode 负载
// 高负载时减小 chunk size 以减少 decode 等待
loadFactor := 1.0 - float64(currentDecodeLoad)/float64(ac.config.MaxDecodeLoad)
adjustedMaxChunk := int(float64(maxChunk) * math.Max(0.2, loadFactor))
// 边界约束
optimalSize := min(
max(adjustedMaxChunk, ac.config.MinChunkSize),
ac.config.MaxChunkSize,
)
// 不超过 prompt 长度
return min(optimalSize, promptLength)
}
// UpdateLatency 更新延迟统计
func (lh *LatencyHistory) UpdateLatency(isPrefill bool, tokens int, durationMs float64) {
latencyPerToken := durationMs / float64(tokens)
if isPrefill {
lh.prefillLatencies = append(lh.prefillLatencies, latencyPerToken)
if len(lh.prefillLatencies) > lh.windowSize {
lh.prefillLatencies = lh.prefillLatencies[1:]
}
} else {
lh.decodeLatencies = append(lh.decodeLatencies, latencyPerToken)
if len(lh.decodeLatencies) > lh.windowSize {
lh.decodeLatencies = lh.decodeLatencies[1:]
}
}
}
func (lh *LatencyHistory) AveragePrefillLatency() float64 {
if len(lh.prefillLatencies) == 0 {
return 1.0 // 默认值
}
sum := 0.0
for _, l := range lh.prefillLatencies {
sum += l
}
return sum / float64(len(lh.prefillLatencies))
}
func (lh *LatencyHistory) AverageDecodeLatency() float64 {
if len(lh.decodeLatencies) == 0 {
return 1.0 // 默认值
}
sum := 0.0
for _, l := range lh.decodeLatencies {
sum += l
}
return sum / float64(len(lh.decodeLatencies))
}
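下面用一组假设的数值演示 OptimalChunkSize 的计算过程(demoOptimalChunkSize 及其中的延迟、负载数值均为示意):
package batching
import "fmt"
// demoOptimalChunkSize 用假设的延迟数据演示最优分块大小的计算
func demoOptimalChunkSize() {
    chunker := &AdaptiveChunker{
        config: &AdaptiveConfig{
            MinChunkSize:    64,
            MaxChunkSize:    1024,
            TargetLatencyMs: 50,
            MaxDecodeLoad:   256,
        },
        latencyHistory: &LatencyHistory{windowSize: 100},
    }
    // 假设统计得到:prefill 0.0625 ms/token(1600 token 耗时 100ms),单步 decode 迭代 20ms
    chunker.latencyHistory.UpdateLatency(true, 1600, 100)
    chunker.latencyHistory.UpdateLatency(false, 1, 20)
    // 可接受等待 = 50 - 20 = 30ms;maxChunk = 30 / 0.0625 = 480;
    // 当前 decode 负载 128/256,负载因子 0.5,调整后为 240
    size := chunker.OptimalChunkSize(2048, 128)
    fmt.Println(size) // 输出: 240
}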
Speculative Decoding
推测解码实现
Speculative Decoding 使用小模型预测多个 token,然后由大模型并行验证:
package batching
import (
    "context"
    "math"
    "sync"
)
// SpeculativeDecoder 推测解码器
type SpeculativeDecoder struct {
targetModel LargeModel // 目标大模型
draftModel SmallModel // 草稿小模型
config *SpecConfig
acceptStats *AcceptanceStats
}
type SpecConfig struct {
NumSpecTokens int // 每次推测的 token 数
AcceptThreshold float64 // 接受阈值
DraftTemperature float64 // 草稿模型温度
TargetTemperature float64 // 目标模型温度
}
type AcceptanceStats struct {
TotalTokens int64
AcceptedTokens int64
mu sync.Mutex
}
// SpeculativeStep 执行一步推测解码
func (sd *SpeculativeDecoder) SpeculativeStep(
ctx context.Context,
seqGroup *SequenceGroup,
) (*SpeculativeResult, error) {
seq := seqGroup.Sequences[0]
currentTokens := seq.TokenIDs
// 阶段1: 草稿模型生成多个候选 token
draftTokens, draftProbs, err := sd.generateDraft(ctx, currentTokens)
if err != nil {
return nil, err
}
// 阶段2: 目标模型并行验证所有候选
targetProbs, err := sd.verifyWithTarget(ctx, currentTokens, draftTokens)
if err != nil {
return nil, err
}
// 阶段3: 使用 rejection sampling 确定接受的 token
acceptedTokens := sd.rejectionSampling(draftTokens, draftProbs, targetProbs)
// 更新统计
sd.acceptStats.mu.Lock()
sd.acceptStats.TotalTokens += int64(len(draftTokens))
sd.acceptStats.AcceptedTokens += int64(len(acceptedTokens))
sd.acceptStats.mu.Unlock()
return &SpeculativeResult{
AcceptedTokens: acceptedTokens,
AcceptRate: float64(len(acceptedTokens)) / float64(len(draftTokens)),
}, nil
}
// generateDraft 使用草稿模型生成候选 token
func (sd *SpeculativeDecoder) generateDraft(
ctx context.Context,
inputTokens []int32,
) ([]int32, [][]float32, error) {
draftTokens := make([]int32, 0, sd.config.NumSpecTokens)
draftProbs := make([][]float32, 0, sd.config.NumSpecTokens)
currentInput := inputTokens
for i := 0; i < sd.config.NumSpecTokens; i++ {
// 草稿模型前向传播
logits, err := sd.draftModel.Forward(ctx, currentInput)
if err != nil {
return nil, nil, err
}
// 采样
probs := softmax(logits, sd.config.DraftTemperature)
token := sample(probs)
draftTokens = append(draftTokens, token)
draftProbs = append(draftProbs, probs)
// 更新输入
currentInput = append(currentInput, token)
}
return draftTokens, draftProbs, nil
}
// verifyWithTarget 目标模型验证
func (sd *SpeculativeDecoder) verifyWithTarget(
ctx context.Context,
inputTokens []int32,
draftTokens []int32,
) ([][]float32, error) {
// 构建完整序列
fullSequence := make([]int32, len(inputTokens)+len(draftTokens))
copy(fullSequence, inputTokens)
copy(fullSequence[len(inputTokens):], draftTokens)
// 目标模型并行计算所有位置的概率
// 这是推测解码的关键优势:一次前向传播验证多个 token
allLogits, err := sd.targetModel.Forward(ctx, fullSequence)
if err != nil {
return nil, err
}
// 提取验证位置的概率
targetProbs := make([][]float32, len(draftTokens))
for i := 0; i < len(draftTokens); i++ {
pos := len(inputTokens) + i
targetProbs[i] = softmax(allLogits[pos], sd.config.TargetTemperature)
}
return targetProbs, nil
}
// rejectionSampling 拒绝采样确定接受的 token
func (sd *SpeculativeDecoder) rejectionSampling(
draftTokens []int32,
draftProbs [][]float32,
targetProbs [][]float32,
) []int32 {
acceptedTokens := make([]int32, 0, len(draftTokens))
for i, token := range draftTokens {
// 获取该 token 在两个分布下的概率
pDraft := draftProbs[i][token]
pTarget := targetProbs[i][token]
// 接受概率 = min(1, p_target / p_draft)
acceptProb := min(1.0, float64(pTarget)/float64(pDraft))
// 随机决定是否接受
if randomFloat() < acceptProb {
acceptedTokens = append(acceptedTokens, token)
} else {
// 拒绝后,从修正分布中采样
correctedProbs := sd.computeCorrectedDistribution(
targetProbs[i], draftProbs[i],
)
newToken := sample(correctedProbs)
acceptedTokens = append(acceptedTokens, newToken)
break // 后续 token 作废
}
}
return acceptedTokens
}
// computeCorrectedDistribution 计算修正分布
func (sd *SpeculativeDecoder) computeCorrectedDistribution(
targetProbs, draftProbs []float32,
) []float32 {
corrected := make([]float32, len(targetProbs))
var sum float32 = 0
for i := range corrected {
// p_corrected = max(0, p_target - p_draft)
diff := targetProbs[i] - draftProbs[i]
if diff > 0 {
corrected[i] = diff
sum += diff
}
}
// 归一化
if sum > 0 {
for i := range corrected {
corrected[i] /= sum
}
} else {
// 如果修正分布为零,使用目标分布
copy(corrected, targetProbs)
}
return corrected
}
type SpeculativeResult struct {
AcceptedTokens []int32
AcceptRate float64
}
// 辅助函数
func softmax(logits []float32, temperature float64) []float32 {
probs := make([]float32, len(logits))
var maxLogit float32 = logits[0]
for _, l := range logits {
if l > maxLogit {
maxLogit = l
}
}
var sum float32 = 0
for i, l := range logits {
probs[i] = float32(math.Exp(float64((l - maxLogit) / float32(temperature))))
sum += probs[i]
}
for i := range probs {
probs[i] /= sum
}
return probs
}
func sample(probs []float32) int32 {
r := randomFloat()
var cumSum float32 = 0
for i, p := range probs {
cumSum += p
if float32(r) < cumSum {
return int32(i)
}
}
return int32(len(probs) - 1)
}
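上述实现假设了 LargeModel、SmallModel 两个模型接口以及 randomFloat 辅助函数,正文未给出定义。一份假设性的最小定义如下:
package batching
import (
    "context"
    "math/rand"
)
// SmallModel 草稿模型:一次前向返回最后一个位置的 logits
type SmallModel interface {
    Forward(ctx context.Context, tokens []int32) ([]float32, error)
}
// LargeModel 目标模型:一次前向返回每个位置的 logits,以便并行验证多个草稿 token
type LargeModel interface {
    Forward(ctx context.Context, tokens []int32) ([][]float32, error)
}
// randomFloat 返回 [0, 1) 区间的随机数
func randomFloat() float64 {
    return rand.Float64()
}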
自适应推测长度
package batching
import (
"math"
)
// AdaptiveSpeculator 自适应推测器
type AdaptiveSpeculator struct {
baseSpecLen int
minSpecLen int
maxSpecLen int
acceptHistory *RingBuffer
windowSize int
}
// NewAdaptiveSpeculator 创建自适应推测器
func NewAdaptiveSpeculator(config *AdaptiveSpecConfig) *AdaptiveSpeculator {
return &AdaptiveSpeculator{
baseSpecLen: config.BaseSpecLen,
minSpecLen: config.MinSpecLen,
maxSpecLen: config.MaxSpecLen,
acceptHistory: NewRingBuffer(config.WindowSize),
windowSize: config.WindowSize,
}
}
// OptimalSpecLength 计算最优推测长度
func (as *AdaptiveSpeculator) OptimalSpecLength() int {
avgAcceptRate := as.acceptHistory.Average()
if avgAcceptRate < 0.001 {
return as.baseSpecLen
}
// 基于接受率计算期望加速比
// 期望生成 token 数 = sum(accept_rate^i) for i in 0..spec_len
// 加速比 = 期望生成数 / (1 + draft_cost)
bestLen := as.minSpecLen
bestSpeedup := 0.0
draftCostRatio := 0.1 // 草稿模型成本约为目标模型的 10%
for specLen := as.minSpecLen; specLen <= as.maxSpecLen; specLen++ {
expectedTokens := 0.0
for i := 0; i <= specLen; i++ {
expectedTokens += math.Pow(avgAcceptRate, float64(i))
}
// 考虑验证成本(一次前向传播)和草稿成本
totalCost := 1.0 + float64(specLen)*draftCostRatio
speedup := expectedTokens / totalCost
if speedup > bestSpeedup {
bestSpeedup = speedup
bestLen = specLen
}
}
return bestLen
}
// UpdateAcceptRate 更新接受率
func (as *AdaptiveSpeculator) UpdateAcceptRate(accepted, total int) {
if total > 0 {
rate := float64(accepted) / float64(total)
as.acceptHistory.Push(rate)
}
}
// RingBuffer 环形缓冲区
type RingBuffer struct {
data []float64
head int
size int
count int
}
func NewRingBuffer(size int) *RingBuffer {
return &RingBuffer{
data: make([]float64, size),
size: size,
}
}
func (rb *RingBuffer) Push(value float64) {
rb.data[rb.head] = value
rb.head = (rb.head + 1) % rb.size
if rb.count < rb.size {
rb.count++
}
}
func (rb *RingBuffer) Average() float64 {
if rb.count == 0 {
return 0
}
sum := 0.0
for i := 0; i < rb.count; i++ {
sum += rb.data[i]
}
return sum / float64(rb.count)
}
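NewAdaptiveSpeculator 使用的 AdaptiveSpecConfig 未在正文定义,可以假设为如下结构;随后给出一个更新接受率并取下一轮推测长度的使用示意(nextSpecLength 为假设的辅助函数):
// AdaptiveSpecConfig 自适应推测配置(假设的字段定义)
type AdaptiveSpecConfig struct {
    BaseSpecLen int // 默认推测长度,如 5
    MinSpecLen  int // 最小推测长度,如 1
    MaxSpecLen  int // 最大推测长度,如 10
    WindowSize  int // 接受率统计窗口,如 100
}
// nextSpecLength 每轮推测解码结束后更新接受率,并返回下一轮的推测长度
func nextSpecLength(as *AdaptiveSpeculator, result *SpeculativeResult, numDraft int) int {
    as.UpdateAcceptRate(len(result.AcceptedTokens), numDraft)
    return as.OptimalSpecLength()
}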
批处理性能优化
内存优化策略
package batching
import (
    "sync"
    "unsafe"
)
// BatchMemoryManager 批处理内存管理器
type BatchMemoryManager struct {
tokenPool *sync.Pool
logitPool *sync.Pool
attentionPool *sync.Pool
config *MemoryConfig
}
type MemoryConfig struct {
MaxBatchSize int
MaxSeqLen int
VocabSize int
HiddenSize int
NumHeads int
PreallocateMB int
}
func NewBatchMemoryManager(config *MemoryConfig) *BatchMemoryManager {
bmm := &BatchMemoryManager{config: config}
// Token ID 池
bmm.tokenPool = &sync.Pool{
New: func() interface{} {
return make([]int32, 0, config.MaxBatchSize*config.MaxSeqLen)
},
}
// Logits 池
bmm.logitPool = &sync.Pool{
New: func() interface{} {
return make([]float32, 0, config.MaxBatchSize*config.VocabSize)
},
}
// Attention 中间结果池
bmm.attentionPool = &sync.Pool{
New: func() interface{} {
size := config.MaxBatchSize * config.NumHeads * config.MaxSeqLen * config.MaxSeqLen
return make([]float32, 0, size)
},
}
// 预热池
bmm.warmupPools()
return bmm
}
func (bmm *BatchMemoryManager) warmupPools() {
// 预分配对象减少运行时分配
for i := 0; i < 10; i++ {
tokens := bmm.tokenPool.Get().([]int32)
bmm.tokenPool.Put(tokens[:0])
logits := bmm.logitPool.Get().([]float32)
bmm.logitPool.Put(logits[:0])
attn := bmm.attentionPool.Get().([]float32)
bmm.attentionPool.Put(attn[:0])
}
}
// AllocateTokenBuffer 分配 token 缓冲区
func (bmm *BatchMemoryManager) AllocateTokenBuffer(size int) []int32 {
buf := bmm.tokenPool.Get().([]int32)
if cap(buf) < size {
buf = make([]int32, size)
} else {
buf = buf[:size]
}
return buf
}
// ReleaseTokenBuffer 释放 token 缓冲区
func (bmm *BatchMemoryManager) ReleaseTokenBuffer(buf []int32) {
bmm.tokenPool.Put(buf[:0])
}
// BatchBuffer 批处理缓冲区
type BatchBuffer struct {
InputIDs []int32
AttentionMask []int32
PositionIDs []int32
BlockTables [][]int
SeqLens []int
// 输出缓冲
Logits []float32
NextTokens []int32
// 元数据
BatchSize int
MaxSeqLen int
pool *BatchMemoryManager
}
func (bmm *BatchMemoryManager) NewBatchBuffer(batchSize, maxSeqLen int) *BatchBuffer {
totalTokens := batchSize * maxSeqLen
return &BatchBuffer{
InputIDs: bmm.AllocateTokenBuffer(totalTokens),
AttentionMask: bmm.AllocateTokenBuffer(totalTokens),
PositionIDs: bmm.AllocateTokenBuffer(totalTokens),
BlockTables: make([][]int, batchSize),
SeqLens: make([]int, batchSize),
Logits: bmm.logitPool.Get().([]float32)[:batchSize*bmm.config.VocabSize],
NextTokens: make([]int32, batchSize),
BatchSize: batchSize,
MaxSeqLen: maxSeqLen,
pool: bmm,
}
}
func (bb *BatchBuffer) Release() {
bb.pool.ReleaseTokenBuffer(bb.InputIDs)
bb.pool.ReleaseTokenBuffer(bb.AttentionMask)
bb.pool.ReleaseTokenBuffer(bb.PositionIDs)
bb.pool.logitPool.Put(bb.Logits[:0])
}
// ZeroCopyBatch 零拷贝批处理
type ZeroCopyBatch struct {
requests []*ActiveRequest
views []*TensorView
}
type TensorView struct {
Data unsafe.Pointer
Shape []int
Stride []int
Offset int
}
// CreateBatchView 创建批次视图(避免数据拷贝)
func CreateBatchView(requests []*ActiveRequest) *ZeroCopyBatch {
batch := &ZeroCopyBatch{
requests: requests,
views: make([]*TensorView, len(requests)),
}
for i, req := range requests {
// 直接引用请求的数据,不复制
batch.views[i] = &TensorView{
Data: unsafe.Pointer(&req.InputTokens[0]),
Shape: []int{len(req.InputTokens)},
Stride: []int{1},
Offset: 0,
}
}
return batch
}
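下面是 BatchBuffer 生命周期的使用示意:从池中取出缓冲区、按当前批次填充、用完后归还(demoBatchBufferLifecycle 为假设的演示函数):
// demoBatchBufferLifecycle 演示一次 decode 迭代中缓冲区的取用与归还
func demoBatchBufferLifecycle(bmm *BatchMemoryManager, requests []*ActiveRequest) {
    // decode 阶段每个请求只输入 1 个 token,因此 maxSeqLen 取 1
    buf := bmm.NewBatchBuffer(len(requests), 1)
    defer buf.Release() // 归还到 sync.Pool,供下一次迭代复用
    for i, req := range requests {
        buf.SeqLens[i] = len(req.InputTokens) + len(req.GeneratedTokens)
        buf.BlockTables[i] = req.KVCacheBlocks
        if n := len(req.GeneratedTokens); n > 0 {
            buf.InputIDs[i] = req.GeneratedTokens[n-1]
        } else {
            buf.InputIDs[i] = req.InputTokens[len(req.InputTokens)-1]
        }
    }
    // 此处将 buf 交给推理引擎执行一步批量 decode(省略)
}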
并发优化
package batching
import (
"context"
"runtime"
"sync"
)
// ParallelBatchProcessor 并行批处理器
type ParallelBatchProcessor struct {
numWorkers int
workerPool chan struct{}
preprocessCh chan *PreprocessTask
postprocessCh chan *PostprocessTask
wg sync.WaitGroup
}
type PreprocessTask struct {
Request *InferenceRequest
ResultCh chan *PreprocessedInput
}
type PostprocessTask struct {
Output *RawOutput
Request *InferenceRequest
ResultCh chan *InferenceResult
}
type PreprocessedInput struct {
TokenIDs []int32
AttentionMask []int32
}
type RawOutput struct {
Logits []float32
}
func NewParallelBatchProcessor(numWorkers int) *ParallelBatchProcessor {
if numWorkers <= 0 {
numWorkers = runtime.NumCPU()
}
pbp := &ParallelBatchProcessor{
numWorkers: numWorkers,
workerPool: make(chan struct{}, numWorkers),
preprocessCh: make(chan *PreprocessTask, 1000),
postprocessCh: make(chan *PostprocessTask, 1000),
}
// 启动预处理 worker
for i := 0; i < numWorkers; i++ {
go pbp.preprocessWorker()
}
// 启动后处理 worker
for i := 0; i < numWorkers; i++ {
go pbp.postprocessWorker()
}
return pbp
}
func (pbp *ParallelBatchProcessor) preprocessWorker() {
for task := range pbp.preprocessCh {
result := pbp.doPreprocess(task.Request)
task.ResultCh <- result
}
}
func (pbp *ParallelBatchProcessor) postprocessWorker() {
for task := range pbp.postprocessCh {
result := pbp.doPostprocess(task.Output, task.Request)
task.ResultCh <- result
}
}
func (pbp *ParallelBatchProcessor) doPreprocess(req *InferenceRequest) *PreprocessedInput {
// Tokenization
tokens := tokenize(req.Prompt)
// 创建 attention mask
mask := make([]int32, len(tokens))
for i := range mask {
mask[i] = 1
}
return &PreprocessedInput{
TokenIDs: tokens,
AttentionMask: mask,
}
}
func (pbp *ParallelBatchProcessor) doPostprocess(
output *RawOutput,
req *InferenceRequest,
) *InferenceResult {
// 采样
token := sampleFromLogits(output.Logits, req.Temperature, req.TopP)
// 解码
text := decode([]int32{token})
return &InferenceResult{
RequestID: req.ID,
OutputTokens: []int32{token},
Text: text,
}
}
// PreprocessBatch 并行预处理批次
func (pbp *ParallelBatchProcessor) PreprocessBatch(
ctx context.Context,
requests []*InferenceRequest,
) []*PreprocessedInput {
results := make([]*PreprocessedInput, len(requests))
resultChs := make([]chan *PreprocessedInput, len(requests))
// 提交所有任务
for i, req := range requests {
resultChs[i] = make(chan *PreprocessedInput, 1)
pbp.preprocessCh <- &PreprocessTask{
Request: req,
ResultCh: resultChs[i],
}
}
// 收集结果
for i, ch := range resultChs {
select {
case result := <-ch:
results[i] = result
case <-ctx.Done():
return nil
}
}
return results
}
// PostprocessBatch 并行后处理批次
func (pbp *ParallelBatchProcessor) PostprocessBatch(
ctx context.Context,
outputs []*RawOutput,
requests []*InferenceRequest,
) []*InferenceResult {
results := make([]*InferenceResult, len(outputs))
resultChs := make([]chan *InferenceResult, len(outputs))
// 提交所有任务
for i := range outputs {
resultChs[i] = make(chan *InferenceResult, 1)
pbp.postprocessCh <- &PostprocessTask{
Output: outputs[i],
Request: requests[i],
ResultCh: resultChs[i],
}
}
// 收集结果
for i, ch := range resultChs {
select {
case result := <-ch:
results[i] = result
case <-ctx.Done():
return nil
}
}
return results
}
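预处理/后处理中用到的 tokenize、decode、sampleFromLogits 是占位函数,实际系统应接入具体的分词器与采样实现。下面是一份假设性的草图,仅做按字符映射与温度采样,未实现完整的 top-p 截断:
package batching
import "math/rand"
// tokenize / decode 占位实现:真实系统应替换为 BPE/SentencePiece 等分词器
func tokenize(prompt string) []int32 {
    tokens := make([]int32, 0, len(prompt))
    for _, r := range prompt {
        tokens = append(tokens, int32(r)) // 仅示意:按字符映射
    }
    return tokens
}
func decode(tokens []int32) string {
    runes := make([]rune, len(tokens))
    for i, t := range tokens {
        runes[i] = rune(t)
    }
    return string(runes)
}
// sampleFromLogits 温度缩放后随机采样(简化示意,省略 top-p 截断)
func sampleFromLogits(logits []float32, temperature, topP float32) int32 {
    _ = topP // top-p 截断在此示意中省略
    if temperature <= 0 {
        // 温度为 0 时退化为贪心采样
        best := 0
        for i, l := range logits {
            if l > logits[best] {
                best = i
            }
        }
        return int32(best)
    }
    probs := softmax(logits, float64(temperature)) // 复用上文的 softmax
    r := float32(rand.Float64())
    var cum float32
    for i, p := range probs {
        cum += p
        if r < cum {
            return int32(i)
        }
    }
    return int32(len(probs) - 1)
}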
监控与调优
批处理指标
package batching
import (
    "sync"
    "sync/atomic"
    "time"
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)
// BatchMetrics 批处理指标收集器
type BatchMetrics struct {
// 请求指标
requestsTotal prometheus.Counter
requestsQueued prometheus.Gauge
requestsRunning prometheus.Gauge
requestLatency prometheus.Histogram
// 批次指标
batchSize prometheus.Histogram
batchTokens prometheus.Histogram
batchUtilization prometheus.Gauge
iterationLatency prometheus.Histogram
// 调度指标
schedulerOverhead prometheus.Histogram
preemptionCount prometheus.Counter
swapInCount prometheus.Counter
swapOutCount prometheus.Counter
// 吞吐量指标
tokensGenerated prometheus.Counter
prefillTokens prometheus.Counter
decodeTokens prometheus.Counter
// 推测解码指标
specAcceptRate prometheus.Gauge
specSpeedup prometheus.Gauge
}
func NewBatchMetrics(namespace string) *BatchMetrics {
return &BatchMetrics{
requestsTotal: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "requests_total",
Help: "Total number of inference requests",
}),
requestsQueued: promauto.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "requests_queued",
Help: "Number of requests in queue",
}),
requestsRunning: promauto.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "requests_running",
Help: "Number of requests currently running",
}),
requestLatency: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: namespace,
Name: "request_latency_seconds",
Help: "Request end-to-end latency",
Buckets: []float64{0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30},
}),
batchSize: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: namespace,
Name: "batch_size",
Help: "Number of sequences per batch",
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256},
}),
batchTokens: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: namespace,
Name: "batch_tokens",
Help: "Total tokens per batch",
Buckets: []float64{100, 500, 1000, 2000, 4000, 8000, 16000},
}),
batchUtilization: promauto.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "batch_utilization",
Help: "Batch slot utilization ratio",
}),
iterationLatency: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: namespace,
Name: "iteration_latency_ms",
Help: "Single decode iteration latency",
Buckets: []float64{1, 2, 5, 10, 20, 50, 100, 200},
}),
schedulerOverhead: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: namespace,
Name: "scheduler_overhead_us",
Help: "Scheduler overhead per iteration",
Buckets: []float64{10, 50, 100, 500, 1000, 5000},
}),
preemptionCount: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "preemption_total",
Help: "Total number of preemptions",
}),
swapInCount: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "swap_in_total",
Help: "Total number of swap-in operations",
}),
swapOutCount: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "swap_out_total",
Help: "Total number of swap-out operations",
}),
tokensGenerated: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "tokens_generated_total",
Help: "Total tokens generated",
}),
prefillTokens: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "prefill_tokens_total",
Help: "Total prefill tokens processed",
}),
decodeTokens: promauto.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Name: "decode_tokens_total",
Help: "Total decode tokens generated",
}),
specAcceptRate: promauto.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "speculative_accept_rate",
Help: "Speculative decoding acceptance rate",
}),
specSpeedup: promauto.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Name: "speculative_speedup",
Help: "Speculative decoding speedup ratio",
}),
}
}
// BatchProfiler 批处理性能分析器
type BatchProfiler struct {
metrics *BatchMetrics
startTimes map[string]time.Time
mu sync.RWMutex
// 统计数据
totalBatches int64
totalIterations int64
totalLatencyUs int64
}
func NewBatchProfiler(metrics *BatchMetrics) *BatchProfiler {
return &BatchProfiler{
metrics: metrics,
startTimes: make(map[string]time.Time),
}
}
// RecordBatchStart 记录批次开始
func (bp *BatchProfiler) RecordBatchStart(batchID string, batchSize, totalTokens int) {
bp.mu.Lock()
bp.startTimes[batchID] = time.Now()
bp.mu.Unlock()
bp.metrics.batchSize.Observe(float64(batchSize))
bp.metrics.batchTokens.Observe(float64(totalTokens))
}
// RecordBatchEnd 记录批次结束
func (bp *BatchProfiler) RecordBatchEnd(batchID string) {
bp.mu.Lock()
startTime, ok := bp.startTimes[batchID]
if ok {
delete(bp.startTimes, batchID)
}
bp.mu.Unlock()
if ok {
latency := time.Since(startTime).Milliseconds()
bp.metrics.iterationLatency.Observe(float64(latency))
atomic.AddInt64(&bp.totalLatencyUs, latency*1000)
atomic.AddInt64(&bp.totalIterations, 1)
}
}
// RecordSchedulerOverhead 记录调度器开销
func (bp *BatchProfiler) RecordSchedulerOverhead(duration time.Duration) {
bp.metrics.schedulerOverhead.Observe(float64(duration.Microseconds()))
}
// GetAverageLatency 获取平均延迟
func (bp *BatchProfiler) GetAverageLatency() float64 {
iterations := atomic.LoadInt64(&bp.totalIterations)
if iterations == 0 {
return 0
}
totalUs := atomic.LoadInt64(&bp.totalLatencyUs)
return float64(totalUs) / float64(iterations) / 1000.0 // 返回毫秒
}
// ThroughputCalculator 吞吐量计算器
type ThroughputCalculator struct {
windowSize time.Duration
tokenCounts []TokenCount
mu sync.RWMutex
}
type TokenCount struct {
Timestamp time.Time
Count int64
}
func NewThroughputCalculator(windowSize time.Duration) *ThroughputCalculator {
return &ThroughputCalculator{
windowSize: windowSize,
tokenCounts: make([]TokenCount, 0),
}
}
func (tc *ThroughputCalculator) RecordTokens(count int64) {
    tc.mu.Lock()
    defer tc.mu.Unlock()
    now := time.Now()
    tc.tokenCounts = append(tc.tokenCounts, TokenCount{
        Timestamp: now,
        Count:     count,
    })
    // 清理滑动窗口之外的过期数据
    cutoff := now.Add(-tc.windowSize)
    validIdx := len(tc.tokenCounts)
    for i, c := range tc.tokenCounts {
        if c.Timestamp.After(cutoff) {
            validIdx = i
            break
        }
    }
    tc.tokenCounts = tc.tokenCounts[validIdx:]
}
func (tc *ThroughputCalculator) GetThroughput() float64 {
    tc.mu.RLock()
    defer tc.mu.RUnlock()
    if len(tc.tokenCounts) < 2 {
        return 0
    }
    var totalTokens int64
    for _, c := range tc.tokenCounts {
        totalTokens += c.Count
    }
    duration := tc.tokenCounts[len(tc.tokenCounts)-1].Timestamp.Sub(
        tc.tokenCounts[0].Timestamp,
    ).Seconds()
    if duration == 0 {
        return 0
    }
    return float64(totalTokens) / duration
}
性能调优建议
# 批处理调优配置示例
batching:
# 基础配置
max_batch_size: 256 # 最大批次大小
max_tokens_per_batch: 8192 # 批次最大 token 数
max_waiting_requests: 1000 # 最大等待请求数
# 调度配置
scheduling_policy: "priority" # fcfs, sjf, priority
max_wait_time_ms: 100 # 最大等待时间
scheduler_interval_ms: 1 # 调度间隔
# Prefill 配置
prefill_chunk_size: 512 # Prefill 分块大小
max_prefill_tokens: 2048 # 每次调度最大 prefill token
# 抢占配置
enable_preemption: true
preemption_mode: "swap" # recompute, swap
swap_space_gb: 16
# 推测解码配置
speculative_decoding:
enabled: true
num_spec_tokens: 5
draft_model: "distill-llama-68m"
accept_threshold: 0.7
adaptive_length: true
# 调优建议
tuning_guide:
high_throughput:
description: "优化吞吐量场景"
settings:
max_batch_size: 512
prefill_chunk_size: 1024
scheduling_policy: "sjf"
enable_preemption: true
low_latency:
description: "优化延迟场景"
settings:
max_batch_size: 32
prefill_chunk_size: 256
scheduling_policy: "fcfs"
max_wait_time_ms: 10
memory_constrained:
description: "内存受限场景"
settings:
max_batch_size: 64
max_tokens_per_batch: 4096
preemption_mode: "swap"
swap_space_gb: 32
小结
本章深入探讨了动态批处理的核心技术:
- 连续批处理:允许请求在任意 token 位置加入或离开批次,充分利用 GPU 资源
- 迭代级调度:细粒度控制每个 decode 迭代的批次组成,支持抢占和交换
- Chunked Prefill:分块处理长序列的 prefill 阶段,避免阻塞短序列的 decode
- 推测解码:使用小模型预测 + 大模型验证,提升生成速度
- 内存优化:对象池、零拷贝、块管理等技术减少内存开销
- 监控调优:完善的指标体系和性能分析工具
动态批处理是大模型推理系统的核心组件,直接决定了系统的吞吐量和延迟。合理的批处理策略能够在保证服务质量的前提下,最大化 GPU 利用率。
下一章我们将探讨 推理优化技术,讲解如何通过模型优化和系统优化进一步提升推理性能。