推理引擎原理
概述
推理引擎是将训练好的模型部署到生产环境的核心组件。对于大语言模型而言,推理效率直接决定了服务的延迟、吞吐量和成本。本章深入剖析推理引擎的工作原理,包括模型加载、计算优化、内存管理等关键技术。
推理流程解析
端到端推理流程
┌─────────────────────────────────────────────────────────────────┐
│ LLM 推理完整流程 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入: "What is the capital of France?" │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 1. Tokenization (分词) │ │
│ │ Text → [1024, 318, 262, 3139, 286, 4881, 30] │ │
│ └──────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼──────────────────────────────────┐ │
│ │ 2. Embedding (嵌入) │ │
│ │ Token IDs → [batch, seq_len, hidden_dim] │ │
│ └──────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼──────────────────────────────────┐ │
│ │ 3. Prefill Phase (预填充阶段) │ │
│ │ • 并行处理所有输入 token │ │
│ │ • 计算并缓存所有 KV Cache │ │
│ │ • 计算量大,但只执行一次 │ │
│ └──────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼──────────────────────────────────┐ │
│ │ 4. Decode Phase (解码阶段) │ │
│ │ Loop: │ │
│ │ ├─ 输入: 上一个生成的 token │ │
│ │ ├─ 使用 KV Cache 进行 attention │ │
│ │ ├─ 计算 logits │ │
│ │ ├─ Sampling (采样) │ │
│ │ └─ 输出: 新 token │ │
│ │ Until: EOS 或达到 max_length │ │
│ └──────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼──────────────────────────────────┐ │
│ │ 5. Detokenization (反分词) │ │
│ │ Token IDs → "Paris" │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 输出: "Paris" │
│ │
└─────────────────────────────────────────────────────────────────┘
Prefill vs Decode 特性对比
┌─────────────────────────────────────────────────────────────────┐
│ Prefill 与 Decode 阶段对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 特性 │ Prefill │ Decode │
│ ──────────────────┼────────────────────┼────────────────────── │
│ 处理 Token 数 │ N (prompt长度) │ 1 (每次一个) │
│ 计算特点 │ 计算密集 │ 内存带宽密集 │
│ GPU 利用率 │ 高 │ 低 │
│ KV Cache │ 生成全部 │ 增量更新 │
│ Attention 类型 │ Full Attention │ Incremental │
│ 批处理效率 │ 高 │ 低(需动态批处理) │
│ │
│ 计算量对比 (以 LLaMA-7B, seq_len=2048 为例): │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Prefill: ~26 TFLOPs (一次性) │ │
│ │ Decode: ~13 GFLOPs × 生成长度 (每个 token) │ │
│ │ │ │
│ │ 单 token 解码的计算量 ≈ Prefill 的 1/2000 │ │
│ │ 但解码阶段受限于内存带宽,不是计算 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
推理引擎架构
核心组件
// inference_engine.go
package inference
import (
"context"
"sync"
)
// InferenceEngine 推理引擎
type InferenceEngine struct {
// 模型
model *Model
modelConfig *ModelConfig
// KV Cache 管理
kvCacheManager *KVCacheManager
// 调度器
scheduler *RequestScheduler
// 采样器
sampler *Sampler
// 执行器
executor Executor
// 配置
config EngineConfig
}
// EngineConfig 引擎配置
type EngineConfig struct {
// 模型配置
ModelPath string
ModelFormat string // pytorch, safetensors, onnx
Dtype string // float16, bfloat16, int8, int4
// 硬件配置
DeviceType string // cuda, cpu
DeviceIDs []int
TensorParallel int
// 内存配置
MaxBatchSize int
MaxSeqLen int
KVCacheSize int64 // bytes
BlockSize int // PagedAttention block size
// 优化配置
UseFlashAttention bool
UseContinuousBatching bool
UseSpeculativeDecoding bool
}
// Model 模型结构
type Model struct {
// 模型层
EmbedTokens *Embedding
Layers []*TransformerLayer
Norm *RMSNorm
LMHead *Linear
// 配置
Config *ModelConfig
}
// ModelConfig 模型配置
type ModelConfig struct {
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"`
IntermediateSize int `json:"intermediate_size"`
NumLayers int `json:"num_hidden_layers"`
NumHeads int `json:"num_attention_heads"`
NumKVHeads int `json:"num_key_value_heads"` // GQA
HeadDim int `json:"head_dim"`
MaxSeqLen int `json:"max_position_embeddings"`
RopeTheta float64 `json:"rope_theta"`
RMSNormEps float64 `json:"rms_norm_eps"`
}
// TransformerLayer Transformer 层
type TransformerLayer struct {
InputNorm *RMSNorm
Attention *Attention
PostAttnNorm *RMSNorm
MLP *MLP
}
// Attention 注意力层
type Attention struct {
QProj *Linear
KProj *Linear
VProj *Linear
OProj *Linear
NumHeads int
NumKVHeads int
HeadDim int
}
// MLP 前馈层
type MLP struct {
GateProj *Linear
UpProj *Linear
DownProj *Linear
}
// NewInferenceEngine 创建推理引擎
func NewInferenceEngine(config EngineConfig) (*InferenceEngine, error) {
engine := &InferenceEngine{
config: config,
}
// 加载模型
model, modelConfig, err := engine.loadModel(config.ModelPath, config.ModelFormat)
if err != nil {
return nil, err
}
engine.model = model
engine.modelConfig = modelConfig
// 初始化 KV Cache 管理器
engine.kvCacheManager = NewKVCacheManager(
modelConfig,
config.KVCacheSize,
config.BlockSize,
config.DeviceIDs,
)
// 初始化调度器
engine.scheduler = NewRequestScheduler(
config.MaxBatchSize,
config.UseContinuousBatching,
)
// 初始化采样器
engine.sampler = NewSampler()
// 初始化执行器
engine.executor = NewCUDAExecutor(config.DeviceIDs, config.TensorParallel)
return engine, nil
}
// Generate 生成文本
func (e *InferenceEngine) Generate(ctx context.Context, request *GenerateRequest) (*GenerateResponse, error) {
// 创建序列
seq := &Sequence{
ID: generateSeqID(),
InputIDs: request.InputIDs,
MaxLen: request.MaxNewTokens + len(request.InputIDs),
SamplingParams: request.SamplingParams,
}
// 添加到调度队列
e.scheduler.Add(seq)
// 等待完成
return e.waitForCompletion(ctx, seq)
}
// Step 执行一步推理
func (e *InferenceEngine) Step() error {
// 获取调度批次
batch := e.scheduler.Schedule()
if batch == nil {
return nil
}
// 区分 Prefill 和 Decode
prefillSeqs, decodeSeqs := batch.Split()
// 执行 Prefill
if len(prefillSeqs) > 0 {
if err := e.executePrefill(prefillSeqs); err != nil {
return err
}
}
// 执行 Decode
if len(decodeSeqs) > 0 {
if err := e.executeDecode(decodeSeqs); err != nil {
return err
}
}
// 更新序列状态
e.updateSequences(batch)
return nil
}
// executePrefill 执行预填充
func (e *InferenceEngine) executePrefill(seqs []*Sequence) error {
// 准备输入
inputIDs := make([][]int, len(seqs))
positions := make([][]int, len(seqs))
for i, seq := range seqs {
inputIDs[i] = seq.InputIDs
positions[i] = makePositions(len(seq.InputIDs))
}
// 分配 KV Cache
for _, seq := range seqs {
blocks, err := e.kvCacheManager.Allocate(seq.ID, len(seq.InputIDs))
if err != nil {
return err
}
seq.BlockTable = blocks
}
// 执行前向传播
logits, kvCache, err := e.forward(inputIDs, positions, nil, true)
if err != nil {
return err
}
// 存储 KV Cache
for i, seq := range seqs {
e.kvCacheManager.Store(seq.ID, kvCache[i])
}
// 采样下一个 token
for i, seq := range seqs {
nextToken := e.sampler.Sample(logits[i], seq.SamplingParams)
seq.OutputIDs = append(seq.OutputIDs, nextToken)
seq.Stage = StageDecoding
}
return nil
}
// executeDecode 执行解码
func (e *InferenceEngine) executeDecode(seqs []*Sequence) error {
// 准备输入(每个序列只有最新的一个 token)
inputIDs := make([][]int, len(seqs))
positions := make([][]int, len(seqs))
blockTables := make([][]int, len(seqs))
for i, seq := range seqs {
lastToken := seq.OutputIDs[len(seq.OutputIDs)-1]
inputIDs[i] = []int{lastToken}
positions[i] = []int{len(seq.InputIDs) + len(seq.OutputIDs) - 1}
blockTables[i] = seq.BlockTable
}
// 扩展 KV Cache(如果需要新的 block)
for _, seq := range seqs {
curLen := len(seq.InputIDs) + len(seq.OutputIDs)
if curLen%e.config.BlockSize == 0 {
newBlock, err := e.kvCacheManager.AllocateBlock()
if err != nil {
// 处理内存不足:可能需要抢占
return err
}
seq.BlockTable = append(seq.BlockTable, newBlock)
}
}
// 执行前向传播
logits, newKV, err := e.forward(inputIDs, positions, blockTables, false)
if err != nil {
return err
}
// 更新 KV Cache
for i, seq := range seqs {
e.kvCacheManager.Append(seq.ID, newKV[i])
}
// 采样下一个 token
for i, seq := range seqs {
nextToken := e.sampler.Sample(logits[i], seq.SamplingParams)
seq.OutputIDs = append(seq.OutputIDs, nextToken)
// 检查是否结束
if nextToken == e.modelConfig.EosTokenID || len(seq.OutputIDs) >= seq.MaxLen {
seq.Stage = StageFinished
}
}
return nil
}
// forward 前向传播
func (e *InferenceEngine) forward(
inputIDs [][]int,
positions [][]int,
blockTables [][]int,
isPrefill bool,
) ([][]float32, []KVCache, error) {
// 实际实现会调用 CUDA kernel
return nil, nil, nil
}
// Sequence 序列
type Sequence struct {
ID string
InputIDs []int
OutputIDs []int
MaxLen int
SamplingParams *SamplingParams
BlockTable []int // KV Cache block indices
Stage SequenceStage
}
type SequenceStage int
const (
StagePrefill SequenceStage = iota
StageDecoding
StageFinished
)
// GenerateRequest 生成请求
type GenerateRequest struct {
InputIDs []int
MaxNewTokens int
SamplingParams *SamplingParams
}
// SamplingParams 采样参数
type SamplingParams struct {
Temperature float32
TopP float32
TopK int
RepetitionPenalty float32
PresencePenalty float32
FrequencyPenalty float32
StopTokenIDs []int
}
// GenerateResponse 生成响应
type GenerateResponse struct {
OutputIDs []int
OutputText string
FinishReason string // length, stop, error
Usage TokenUsage
}
// TokenUsage Token 使用统计
type TokenUsage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}
func generateSeqID() string {
return ""
}
func makePositions(length int) []int {
positions := make([]int, length)
for i := range positions {
positions[i] = i
}
return positions
}
KV Cache 管理
PagedAttention 原理
┌─────────────────────────────────────────────────────────────────┐
│ PagedAttention 原理 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统 KV Cache 问题: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 预分配固定大小内存(max_seq_len × batch_size) │ │
│ │ • 实际使用率低(多数序列短于 max_seq_len) │ │
│ │ • 内存碎片化严重 │ │
│ │ • 难以支持动态批处理 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ PagedAttention 解决方案: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 物理内存(GPU 显存) 逻辑 KV Cache │ │
│ │ ┌─────────────────────┐ ┌─────────────────────┐ │ │
│ │ │ Block 0 │ Block 1 │ │ │ Seq 1: [0,2,5] │ │ │
│ │ │─────────│─────────│ │ │ Seq 2: [1,3] │ │ │
│ │ │ Block 2 │ Block 3 │ │ │ Seq 3: [4,6,7,8] │ │ │
│ │ │─────────│─────────│ │ ←→ │ │ │ │
│ │ │ Block 4 │ Block 5 │ │ │ Block Table: │ │ │
│ │ │─────────│─────────│ │ │ Seq1: 0→2→5 │ │ │
│ │ │ Block 6 │ Block 7 │ │ │ Seq2: 1→3 │ │ │
│ │ │─────────│─────────│ │ │ Seq3: 4→6→7→8 │ │ │
│ │ │ Block 8 │ Free │ │ │ │ │ │
│ │ └─────────────────────┘ └─────────────────────┘ │ │
│ │ │ │
│ │ • 内存按固定大小 Block 分配 │ │
│ │ • 每个 Block 存储 block_size 个 token 的 KV │ │
│ │ • 使用 Block Table 建立逻辑→物理映射 │ │
│ │ • 按需分配,减少内存浪费 │ │
│ │ • 支持高效的序列复制(共享 Block) │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
KV Cache 管理器实现
// kv_cache_manager.go
package inference
import (
"errors"
"sync"
)
// KVCacheManager KV Cache 管理器
type KVCacheManager struct {
// 配置
numLayers int
numKVHeads int
headDim int
blockSize int // 每个 block 存储的 token 数
totalBlocks int
// Block 池
freeBlocks []int
usedBlocks map[int]bool
// 序列到 Block 的映射
seqBlockTables map[string][]int
// 物理内存(GPU 显存)
// shape: [num_blocks, 2, num_layers, block_size, num_kv_heads, head_dim]
// 2 表示 K 和 V
kvCache []float16
mu sync.RWMutex
}
// KVCacheConfig KV Cache 配置
type KVCacheConfig struct {
NumLayers int
NumKVHeads int
HeadDim int
BlockSize int
TotalMemory int64 // 总显存限制
Dtype string
}
func NewKVCacheManager(modelConfig *ModelConfig, totalMemory int64, blockSize int, deviceIDs []int) *KVCacheManager {
// 计算每个 block 的大小
// block_size = block_tokens × num_layers × 2 × num_kv_heads × head_dim × dtype_size
bytesPerToken := 2 * modelConfig.NumLayers * modelConfig.NumKVHeads * modelConfig.HeadDim * 2 // FP16
bytesPerBlock := blockSize * bytesPerToken
// 计算可分配的 block 数量
numBlocks := int(totalMemory / int64(bytesPerBlock))
// 初始化 free blocks
freeBlocks := make([]int, numBlocks)
for i := range freeBlocks {
freeBlocks[i] = i
}
return &KVCacheManager{
numLayers: modelConfig.NumLayers,
numKVHeads: modelConfig.NumKVHeads,
headDim: modelConfig.HeadDim,
blockSize: blockSize,
totalBlocks: numBlocks,
freeBlocks: freeBlocks,
usedBlocks: make(map[int]bool),
seqBlockTables: make(map[string][]int),
}
}
// Allocate 为序列分配 KV Cache blocks
func (m *KVCacheManager) Allocate(seqID string, numTokens int) ([]int, error) {
m.mu.Lock()
defer m.mu.Unlock()
// 计算需要的 block 数
numBlocks := (numTokens + m.blockSize - 1) / m.blockSize
// 检查是否有足够的 free blocks
if len(m.freeBlocks) < numBlocks {
return nil, errors.New("not enough free blocks")
}
// 分配 blocks
blocks := make([]int, numBlocks)
for i := 0; i < numBlocks; i++ {
blockIdx := m.freeBlocks[len(m.freeBlocks)-1]
m.freeBlocks = m.freeBlocks[:len(m.freeBlocks)-1]
blocks[i] = blockIdx
m.usedBlocks[blockIdx] = true
}
m.seqBlockTables[seqID] = blocks
return blocks, nil
}
// AllocateBlock 分配单个 block
func (m *KVCacheManager) AllocateBlock() (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.freeBlocks) == 0 {
return -1, errors.New("no free blocks available")
}
blockIdx := m.freeBlocks[len(m.freeBlocks)-1]
m.freeBlocks = m.freeBlocks[:len(m.freeBlocks)-1]
m.usedBlocks[blockIdx] = true
return blockIdx, nil
}
// Free 释放序列的所有 blocks
func (m *KVCacheManager) Free(seqID string) {
m.mu.Lock()
defer m.mu.Unlock()
blocks, ok := m.seqBlockTables[seqID]
if !ok {
return
}
for _, blockIdx := range blocks {
delete(m.usedBlocks, blockIdx)
m.freeBlocks = append(m.freeBlocks, blockIdx)
}
delete(m.seqBlockTables, seqID)
}
// Store 存储 KV Cache
func (m *KVCacheManager) Store(seqID string, kvCache KVCache) {
// 实际实现会写入 GPU 显存
}
// Append 追加 KV Cache(decode 阶段)
func (m *KVCacheManager) Append(seqID string, newKV KVCache) {
// 追加新的 KV 到最后一个 block
}
// GetBlockTable 获取 block table
func (m *KVCacheManager) GetBlockTable(seqID string) []int {
m.mu.RLock()
defer m.mu.RUnlock()
return m.seqBlockTables[seqID]
}
// Fork 复制序列的 KV Cache(用于 beam search 或投机解码)
func (m *KVCacheManager) Fork(srcSeqID, dstSeqID string) error {
m.mu.Lock()
defer m.mu.Unlock()
srcBlocks, ok := m.seqBlockTables[srcSeqID]
if !ok {
return errors.New("source sequence not found")
}
// 使用 copy-on-write:共享现有 blocks
dstBlocks := make([]int, len(srcBlocks))
copy(dstBlocks, srcBlocks)
// 增加 block 引用计数(实际实现需要)
m.seqBlockTables[dstSeqID] = dstBlocks
return nil
}
// GetStats 获取统计信息
func (m *KVCacheManager) GetStats() KVCacheStats {
m.mu.RLock()
defer m.mu.RUnlock()
return KVCacheStats{
TotalBlocks: m.totalBlocks,
UsedBlocks: len(m.usedBlocks),
FreeBlocks: len(m.freeBlocks),
Utilization: float64(len(m.usedBlocks)) / float64(m.totalBlocks),
}
}
type KVCacheStats struct {
TotalBlocks int
UsedBlocks int
FreeBlocks int
Utilization float64
}
// KVCache KV Cache 数据
type KVCache struct {
K []float16 // [num_layers, seq_len, num_kv_heads, head_dim]
V []float16
}
type float16 = uint16 // 简化表示
计算优化技术
Flash Attention
┌─────────────────────────────────────────────────────────────────┐
│ Flash Attention 原理 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 标准 Attention 问题: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 1. 计算 QK^T:O(N²d) 计算,O(N²) 内存 │ │
│ │ 2. Softmax:需要完整的 N×N 矩阵 │ │
│ │ 3. 计算 Score×V:O(N²d) 计算 │ │
│ │ │ │
│ │ 内存瓶颈:N² 随序列长度平方增长 │ │
│ │ 例:N=8192, FP16 → 128MB 仅存储 attention scores │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ Flash Attention 解决方案: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 核心思想:Tiling(分块计算)+ 重计算 │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ Q K^T V │ │ │
│ │ │ ┌──┬──┐ ┌──┬──┬──┐ ┌──┬──┬──┐ │ │ │
│ │ │ │Q1│ │ × │K1│K2│K3│ × │V1│V2│V3│ │ │ │
│ │ │ ├──┼──┤ └──┴──┴──┘ └──┴──┴──┘ │ │ │
│ │ │ │Q2│ │ │ │ │
│ │ │ └──┴──┘ │ │ │
│ │ │ │ │ │
│ │ │ 分块计算流程: │ │ │
│ │ │ for each Q block: │ │ │
│ │ │ for each K,V block: │ │ │
│ │ │ 1. 加载 Q_block, K_block, V_block 到 SRAM │ │ │
│ │ │ 2. 计算局部 attention │ │ │
│ │ │ 3. 使用 online softmax 累加结果 │ │ │
│ │ │ 输出 block 写回 HBM │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ Online Softmax 技巧: │ │
│ │ • 维护 running max 和 running sum │ │
│ │ • 无需存储完整的 attention matrix │ │
│ │ • 内存复杂度从 O(N²) 降到 O(N) │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 性能对比(A100, seq_len=2048): │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 标准 Attention: ~8ms, 内存 ~128MB │ │
│ │ Flash Attention: ~1.5ms, 内存 ~4MB │ │
│ │ 加速比: 5.3x, 内存节省: 32x │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
算子融合
// fused_ops.go
package inference
// FusedRMSNorm 融合的 RMS Normalization
// 将多个小 kernel 合并成一个大 kernel,减少内存访问和 kernel launch 开销
type FusedRMSNorm struct {
weight []float16
eps float32
hiddenSize int
}
// Forward 前向传播
// 融合操作:
// 1. 计算均方根
// 2. 归一化
// 3. 缩放
func (n *FusedRMSNorm) Forward(input []float16) []float16 {
// CUDA kernel 伪代码:
// __global__ void fused_rmsnorm_kernel(
// float16* input,
// float16* weight,
// float16* output,
// float eps,
// int hidden_size
// ) {
// // 每个 warp 处理一行
// int row = blockIdx.x;
// float sum_sq = 0.0f;
//
// // 计算平方和(使用 warp reduce)
// for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
// float val = __half2float(input[row * hidden_size + i]);
// sum_sq += val * val;
// }
// sum_sq = warp_reduce_sum(sum_sq);
//
// // 计算 RMS
// float rms = rsqrtf(sum_sq / hidden_size + eps);
//
// // 归一化并缩放
// for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
// float val = __half2float(input[row * hidden_size + i]);
// float w = __half2float(weight[i]);
// output[row * hidden_size + i] = __float2half(val * rms * w);
// }
// }
return nil
}
// FusedSiLUMul 融合的 SiLU 激活和乘法
// 用于 LLaMA MLP 中的 gate 和 up projection
// gate_out = SiLU(x @ gate_weight) * (x @ up_weight)
type FusedSiLUMul struct{}
func (f *FusedSiLUMul) Forward(gate, up []float16) []float16 {
// CUDA kernel 伪代码:
// __global__ void fused_silu_mul_kernel(
// float16* gate,
// float16* up,
// float16* output,
// int size
// ) {
// int idx = blockIdx.x * blockDim.x + threadIdx.x;
// if (idx < size) {
// float g = __half2float(gate[idx]);
// float u = __half2float(up[idx]);
// // SiLU(x) = x * sigmoid(x)
// float silu = g / (1.0f + expf(-g));
// output[idx] = __float2half(silu * u);
// }
// }
return nil
}
// FusedRotaryEmbedding 融合的 RoPE 位置编码
type FusedRotaryEmbedding struct {
dim int
maxSeqLen int
theta float64
freqsCis []complex64 // 预计算的频率
}
func NewFusedRotaryEmbedding(dim, maxSeqLen int, theta float64) *FusedRotaryEmbedding {
rope := &FusedRotaryEmbedding{
dim: dim,
maxSeqLen: maxSeqLen,
theta: theta,
}
rope.precomputeFreqs()
return rope
}
func (r *FusedRotaryEmbedding) precomputeFreqs() {
// 预计算频率
// freq_i = 1 / (theta^(2i/dim))
// freqs_cis[pos, i] = exp(i * pos * freq_i) = cos(pos*freq_i) + i*sin(pos*freq_i)
}
func (r *FusedRotaryEmbedding) Apply(q, k []float16, positions []int) ([]float16, []float16) {
// 应用 RoPE
// 融合到 attention kernel 中
return nil, nil
}
// FusedAttention 融合的 Attention(包含 Flash Attention)
type FusedAttention struct {
numHeads int
numKVHeads int
headDim int
scale float32
}
func (a *FusedAttention) Forward(
q, k, v []float16,
kvCache *KVCache,
blockTable []int,
positions []int,
isPrefill bool,
) []float16 {
if isPrefill {
// Prefill:使用 Flash Attention
return a.flashAttention(q, k, v)
} else {
// Decode:使用 Paged Attention
return a.pagedAttention(q, kvCache, blockTable, positions)
}
}
func (a *FusedAttention) flashAttention(q, k, v []float16) []float16 {
// Flash Attention 实现
// 分块计算,使用 online softmax
return nil
}
func (a *FusedAttention) pagedAttention(q []float16, kvCache *KVCache, blockTable []int, positions []int) []float16 {
// Paged Attention 实现
// 从不连续的 block 中读取 KV Cache
return nil
}
量化推理
// quantization.go
package inference
import (
"math"
)
// Quantizer 量化器
type Quantizer struct {
qtype QuantType
}
type QuantType string
const (
QuantTypeINT8 QuantType = "int8"
QuantTypeINT4 QuantType = "int4"
QuantTypeNF4 QuantType = "nf4" // Normal Float 4-bit
QuantTypeGPTQ QuantType = "gptq"
QuantTypeAWQ QuantType = "awq"
)
// INT8QuantizedTensor INT8 量化张量
type INT8QuantizedTensor struct {
Data []int8
Scale []float32 // per-channel 或 per-tensor
ZeroPoint []int8
}
// Quantize INT8 量化
func (q *Quantizer) QuantizeINT8(tensor []float32, perChannel bool) *INT8QuantizedTensor {
if perChannel {
return q.quantizeINT8PerChannel(tensor)
}
return q.quantizeINT8PerTensor(tensor)
}
func (q *Quantizer) quantizeINT8PerTensor(tensor []float32) *INT8QuantizedTensor {
// 找到 min/max
minVal, maxVal := tensor[0], tensor[0]
for _, v := range tensor[1:] {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
// 计算 scale 和 zero point
// 对称量化:scale = max(|min|, |max|) / 127
absMax := math.Max(math.Abs(float64(minVal)), math.Abs(float64(maxVal)))
scale := float32(absMax / 127.0)
// 量化
data := make([]int8, len(tensor))
for i, v := range tensor {
quantized := int(math.Round(float64(v / scale)))
if quantized > 127 {
quantized = 127
} else if quantized < -128 {
quantized = -128
}
data[i] = int8(quantized)
}
return &INT8QuantizedTensor{
Data: data,
Scale: []float32{scale},
}
}
func (q *Quantizer) quantizeINT8PerChannel(tensor []float32) *INT8QuantizedTensor {
// 按通道量化,每个输出通道有独立的 scale
return nil
}
// INT4QuantizedTensor INT4 量化张量
type INT4QuantizedTensor struct {
Data []uint8 // 2 个 INT4 打包成 1 个 uint8
Scale []float16
ZeroPoint []int8
GroupSize int // 量化组大小
}
// QuantizeINT4 INT4 量化(Group-wise)
func (q *Quantizer) QuantizeINT4(tensor []float32, groupSize int) *INT4QuantizedTensor {
numGroups := (len(tensor) + groupSize - 1) / groupSize
scales := make([]float16, numGroups)
data := make([]uint8, (len(tensor)+1)/2)
for g := 0; g < numGroups; g++ {
start := g * groupSize
end := start + groupSize
if end > len(tensor) {
end = len(tensor)
}
// 找到这个 group 的 max abs
maxAbs := float32(0)
for i := start; i < end; i++ {
abs := float32(math.Abs(float64(tensor[i])))
if abs > maxAbs {
maxAbs = abs
}
}
// 计算 scale
scale := maxAbs / 7.0 // INT4: -8 ~ 7
scales[g] = float32ToFloat16(scale)
// 量化这个 group
for i := start; i < end; i++ {
var quantized int
if scale > 0 {
quantized = int(math.Round(float64(tensor[i] / scale)))
}
if quantized > 7 {
quantized = 7
} else if quantized < -8 {
quantized = -8
}
// 打包:2 个 INT4 到 1 个 uint8
byteIdx := i / 2
if i%2 == 0 {
data[byteIdx] = uint8(quantized & 0x0F)
} else {
data[byteIdx] |= uint8((quantized & 0x0F) << 4)
}
}
}
return &INT4QuantizedTensor{
Data: data,
Scale: scales,
GroupSize: groupSize,
}
}
// GPTQ 量化
type GPTQQuantizer struct {
bits int
groupSize int
actOrder bool // activation reordering
}
func NewGPTQQuantizer(bits, groupSize int, actOrder bool) *GPTQQuantizer {
return &GPTQQuantizer{
bits: bits,
groupSize: groupSize,
actOrder: actOrder,
}
}
// QuantizeLayer 量化单层权重
func (g *GPTQQuantizer) QuantizeLayer(
weight [][]float32, // [out_features, in_features]
hessian [][]float32, // 使用校准数据计算的 Hessian 对角线
) (*GPTQQuantizedWeight, error) {
// GPTQ 量化算法:
// 1. 计算 Hessian 矩阵 H = 2 * X^T * X
// 2. 按 Hessian 对角线排序列(如果 actOrder=true)
// 3. 逐列量化,使用 OBS (Optimal Brain Surgeon) 更新剩余权重
// 简化实现
return nil, nil
}
type GPTQQuantizedWeight struct {
QWeight []int32 // 打包的量化权重
QZeros []int32 // 打包的零点
Scales []float16 // 缩放因子
GPerm []int32 // 列重排序(actOrder=true 时使用)
Bits int
GroupSize int
}
// AWQ 量化
type AWQQuantizer struct {
bits int
groupSize int
}
func NewAWQQuantizer(bits, groupSize int) *AWQQuantizer {
return &AWQQuantizer{
bits: bits,
groupSize: groupSize,
}
}
// QuantizeWithActivation AWQ: Activation-aware Weight Quantization
func (a *AWQQuantizer) QuantizeWithActivation(
weight [][]float32,
activations [][]float32, // 校准数据的激活值
) (*AWQQuantizedWeight, error) {
// AWQ 算法:
// 1. 分析激活值分布,找到重要的通道
// 2. 对重要通道使用更小的量化误差
// 3. 使用 per-channel scaling 补偿量化误差
return nil, nil
}
type AWQQuantizedWeight struct {
QWeight []int32
Scales []float16
Zeros []float16
Bits int
GroupSize int
}
// QuantizedLinear 量化线性层
type QuantizedLinear struct {
weight interface{} // INT8/INT4/GPTQ/AWQ
qtype QuantType
inFeatures int
outFeatures int
}
// Forward 量化推理前向
func (l *QuantizedLinear) Forward(input []float16) []float16 {
switch l.qtype {
case QuantTypeINT8:
return l.forwardINT8(input)
case QuantTypeINT4:
return l.forwardINT4(input)
case QuantTypeGPTQ:
return l.forwardGPTQ(input)
case QuantTypeAWQ:
return l.forwardAWQ(input)
}
return nil
}
func (l *QuantizedLinear) forwardINT8(input []float16) []float16 {
// INT8 GEMM
// 使用 cuBLAS INT8 或 TensorRT
return nil
}
func (l *QuantizedLinear) forwardINT4(input []float16) []float16 {
// INT4 需要自定义 kernel
// 解包 INT4 → FP16 → GEMM
return nil
}
func (l *QuantizedLinear) forwardGPTQ(input []float16) []float16 {
// GPTQ kernel(如 exllama)
return nil
}
func (l *QuantizedLinear) forwardAWQ(input []float16) []float16 {
// AWQ kernel
return nil
}
func float32ToFloat16(f float32) float16 {
// 简化实现
return float16(f)
}
采样策略
采样器实现
// sampler.go
package inference
import (
"math"
"math/rand"
"sort"
)
// Sampler 采样器
type Sampler struct {
rng *rand.Rand
}
func NewSampler() *Sampler {
return &Sampler{
rng: rand.New(rand.NewSource(42)),
}
}
// Sample 从 logits 采样下一个 token
func (s *Sampler) Sample(logits []float32, params *SamplingParams) int {
// 1. 应用 repetition penalty
if params.RepetitionPenalty != 1.0 {
logits = s.applyRepetitionPenalty(logits, params)
}
// 2. 应用 temperature
if params.Temperature != 1.0 {
logits = s.applyTemperature(logits, params.Temperature)
}
// 3. 转换为概率
probs := s.softmax(logits)
// 4. 应用 top-k
if params.TopK > 0 {
probs = s.applyTopK(probs, params.TopK)
}
// 5. 应用 top-p (nucleus sampling)
if params.TopP < 1.0 {
probs = s.applyTopP(probs, params.TopP)
}
// 6. 采样
return s.multinomialSample(probs)
}
// applyTemperature 应用温度
func (s *Sampler) applyTemperature(logits []float32, temperature float32) []float32 {
result := make([]float32, len(logits))
for i, l := range logits {
result[i] = l / temperature
}
return result
}
// applyRepetitionPenalty 应用重复惩罚
func (s *Sampler) applyRepetitionPenalty(logits []float32, params *SamplingParams) []float32 {
// 对已生成的 token 应用惩罚
// 这里简化处理,实际需要传入已生成的 token
return logits
}
// softmax 计算 softmax
func (s *Sampler) softmax(logits []float32) []float32 {
// 数值稳定的 softmax
maxLogit := logits[0]
for _, l := range logits[1:] {
if l > maxLogit {
maxLogit = l
}
}
probs := make([]float32, len(logits))
var sum float32
for i, l := range logits {
probs[i] = float32(math.Exp(float64(l - maxLogit)))
sum += probs[i]
}
for i := range probs {
probs[i] /= sum
}
return probs
}
// applyTopK Top-K 采样
func (s *Sampler) applyTopK(probs []float32, k int) []float32 {
if k >= len(probs) {
return probs
}
// 找到 top-k 的阈值
type indexedProb struct {
idx int
prob float32
}
indexed := make([]indexedProb, len(probs))
for i, p := range probs {
indexed[i] = indexedProb{i, p}
}
sort.Slice(indexed, func(i, j int) bool {
return indexed[i].prob > indexed[j].prob
})
threshold := indexed[k-1].prob
// 将低于阈值的概率置零
result := make([]float32, len(probs))
var sum float32
for i, p := range probs {
if p >= threshold {
result[i] = p
sum += p
}
}
// 重新归一化
for i := range result {
result[i] /= sum
}
return result
}
// applyTopP Top-P (Nucleus) 采样
func (s *Sampler) applyTopP(probs []float32, p float32) []float32 {
// 按概率排序
type indexedProb struct {
idx int
prob float32
}
indexed := make([]indexedProb, len(probs))
for i, prob := range probs {
indexed[i] = indexedProb{i, prob}
}
sort.Slice(indexed, func(i, j int) bool {
return indexed[i].prob > indexed[j].prob
})
// 找到累积概率 >= p 的截断点
var cumSum float32
cutoff := 0
for i, ip := range indexed {
cumSum += ip.prob
if cumSum >= p {
cutoff = i + 1
break
}
}
// 将截断点之外的概率置零
result := make([]float32, len(probs))
var sum float32
for i := 0; i < cutoff; i++ {
result[indexed[i].idx] = indexed[i].prob
sum += indexed[i].prob
}
// 重新归一化
for i := range result {
result[i] /= sum
}
return result
}
// multinomialSample 多项式采样
func (s *Sampler) multinomialSample(probs []float32) int {
r := s.rng.Float32()
var cumSum float32
for i, p := range probs {
cumSum += p
if r < cumSum {
return i
}
}
return len(probs) - 1
}
// GreedySample 贪婪采样(选择最高概率)
func (s *Sampler) GreedySample(logits []float32) int {
maxIdx := 0
maxVal := logits[0]
for i, l := range logits[1:] {
if l > maxVal {
maxVal = l
maxIdx = i + 1
}
}
return maxIdx
}
// BeamSearch Beam Search
type BeamSearch struct {
beamWidth int
sampler *Sampler
}
func NewBeamSearch(beamWidth int) *BeamSearch {
return &BeamSearch{
beamWidth: beamWidth,
sampler: NewSampler(),
}
}
// Beam 单个 beam
type Beam struct {
TokenIDs []int
Score float32
Finished bool
}
// Step 执行一步 beam search
func (b *BeamSearch) Step(beams []*Beam, logits [][]float32) []*Beam {
// 收集所有候选
type candidate struct {
beamIdx int
tokenID int
score float32
}
var candidates []candidate
for beamIdx, beam := range beams {
if beam.Finished {
candidates = append(candidates, candidate{
beamIdx: beamIdx,
tokenID: -1,
score: beam.Score,
})
continue
}
// 对每个 beam 的 logits 进行 log_softmax
logProbs := b.sampler.softmax(logits[beamIdx])
for tokenID, prob := range logProbs {
if prob > 0 {
score := beam.Score + float32(math.Log(float64(prob)))
candidates = append(candidates, candidate{
beamIdx: beamIdx,
tokenID: tokenID,
score: score,
})
}
}
}
// 选择 top-k 个候选
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].score > candidates[j].score
})
if len(candidates) > b.beamWidth {
candidates = candidates[:b.beamWidth]
}
// 构建新的 beams
newBeams := make([]*Beam, len(candidates))
for i, c := range candidates {
if c.tokenID == -1 {
// 已完成的 beam
newBeams[i] = beams[c.beamIdx]
} else {
newBeams[i] = &Beam{
TokenIDs: append(append([]int{}, beams[c.beamIdx].TokenIDs...), c.tokenID),
Score: c.score,
Finished: false, // 检查 EOS
}
}
}
return newBeams
}
性能指标
关键指标定义
┌─────────────────────────────────────────────────────────────────┐
│ 推理性能关键指标 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 延迟指标 (Latency) │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • TTFT (Time To First Token):首 token 延迟 │ │
│ │ - 主要受 Prefill 阶段影响 │ │
│ │ - 影响用户感知的响应速度 │ │
│ │ │ │
│ │ • TPOT (Time Per Output Token):每 token 延迟 │ │
│ │ - Decode 阶段的单 token 生成时间 │ │
│ │ - 影响流式输出的流畅度 │ │
│ │ │ │
│ │ • E2E Latency:端到端延迟 │ │
│ │ = TTFT + TPOT × (output_len - 1) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 2. 吞吐量指标 (Throughput) │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • Tokens/second:每秒生成 token 数 │ │
│ │ = total_output_tokens / total_time │ │
│ │ │ │
│ │ • Requests/second:每秒处理请求数 │ │
│ │ = num_requests / total_time │ │
│ │ │ │
│ │ • GPU Utilization:GPU 利用率 │ │
│ │ - Compute utilization │ │
│ │ - Memory bandwidth utilization │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 3. 效率指标 (Efficiency) │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • Memory Efficiency:内存效率 │ │
│ │ = actual_KV_cache_used / total_KV_cache_allocated │ │
│ │ │ │
│ │ • Batch Efficiency:批处理效率 │ │
│ │ = actual_batch_size / max_batch_size │ │
│ │ │ │
│ │ • Model FLOPS Utilization (MFU) │ │
│ │ = achieved_FLOPS / peak_FLOPS │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 典型值参考 (LLaMA-7B, A100 80GB): │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ TTFT (512 prompt): ~50ms │ │
│ │ TPOT: ~10-20ms │ │
│ │ Throughput (batch=32): ~2000 tokens/s │ │
│ │ Max batch size: ~256 (取决于 seq_len) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
性能监控
// metrics.go
package inference
import (
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
// 延迟指标
ttftHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_ttft_seconds",
Help: "Time to first token in seconds",
Buckets: []float64{0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0},
},
[]string{"model"},
)
tpotHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_tpot_seconds",
Help: "Time per output token in seconds",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25},
},
[]string{"model"},
)
e2eLatencyHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_e2e_latency_seconds",
Help: "End-to-end latency in seconds",
Buckets: []float64{0.1, 0.5, 1, 2.5, 5, 10, 30, 60},
},
[]string{"model"},
)
// 吞吐量指标
tokensGeneratedTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "llm_tokens_generated_total",
Help: "Total number of tokens generated",
},
[]string{"model", "type"}, // type: prompt, completion
)
requestsTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "llm_requests_total",
Help: "Total number of requests",
},
[]string{"model", "status"}, // status: success, error, timeout
)
// 批处理指标
batchSizeHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_batch_size",
Help: "Batch size distribution",
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256},
},
[]string{"model", "phase"}, // phase: prefill, decode
)
// KV Cache 指标
kvCacheUtilization = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "llm_kv_cache_utilization",
Help: "KV cache utilization ratio",
},
[]string{"model"},
)
kvCacheBlocks = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "llm_kv_cache_blocks",
Help: "Number of KV cache blocks",
},
[]string{"model", "state"}, // state: used, free
)
// GPU 指标
gpuMemoryUsed = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "llm_gpu_memory_bytes",
Help: "GPU memory usage in bytes",
},
[]string{"gpu_id", "type"}, // type: model, kv_cache, activation
)
gpuUtilization = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "llm_gpu_utilization",
Help: "GPU compute utilization",
},
[]string{"gpu_id"},
)
)
// MetricsCollector 指标收集器
type MetricsCollector struct {
model string
startTime time.Time
mu sync.Mutex
requestCount int64
tokenCount int64
lastReportTime time.Time
}
func NewMetricsCollector(model string) *MetricsCollector {
return &MetricsCollector{
model: model,
startTime: time.Now(),
lastReportTime: time.Now(),
}
}
// RecordRequest 记录请求
func (m *MetricsCollector) RecordRequest(
promptTokens int,
completionTokens int,
ttft time.Duration,
e2eLatency time.Duration,
status string,
) {
// 记录延迟
ttftHistogram.WithLabelValues(m.model).Observe(ttft.Seconds())
e2eLatencyHistogram.WithLabelValues(m.model).Observe(e2eLatency.Seconds())
// 计算 TPOT
if completionTokens > 1 {
tpot := (e2eLatency - ttft) / time.Duration(completionTokens-1)
tpotHistogram.WithLabelValues(m.model).Observe(tpot.Seconds())
}
// 记录 token 数
tokensGeneratedTotal.WithLabelValues(m.model, "prompt").Add(float64(promptTokens))
tokensGeneratedTotal.WithLabelValues(m.model, "completion").Add(float64(completionTokens))
// 记录请求
requestsTotal.WithLabelValues(m.model, status).Inc()
}
// RecordBatch 记录批处理
func (m *MetricsCollector) RecordBatch(size int, phase string) {
batchSizeHistogram.WithLabelValues(m.model, phase).Observe(float64(size))
}
// UpdateKVCacheStats 更新 KV Cache 统计
func (m *MetricsCollector) UpdateKVCacheStats(stats KVCacheStats) {
kvCacheUtilization.WithLabelValues(m.model).Set(stats.Utilization)
kvCacheBlocks.WithLabelValues(m.model, "used").Set(float64(stats.UsedBlocks))
kvCacheBlocks.WithLabelValues(m.model, "free").Set(float64(stats.FreeBlocks))
}
// UpdateGPUStats 更新 GPU 统计
func (m *MetricsCollector) UpdateGPUStats(gpuID string, memUsed int64, utilization float64) {
gpuMemoryUsed.WithLabelValues(gpuID, "total").Set(float64(memUsed))
gpuUtilization.WithLabelValues(gpuID).Set(utilization)
}
// GetThroughput 获取吞吐量
func (m *MetricsCollector) GetThroughput() (tokensPerSec, requestsPerSec float64) {
m.mu.Lock()
defer m.mu.Unlock()
elapsed := time.Since(m.startTime).Seconds()
if elapsed > 0 {
tokensPerSec = float64(m.tokenCount) / elapsed
requestsPerSec = float64(m.requestCount) / elapsed
}
return
}
小结
本章详细介绍了推理引擎的核心原理:
- 推理流程:Prefill 和 Decode 两个阶段的特性和优化方向
- 引擎架构:模型加载、调度器、执行器的设计
- KV Cache 管理:PagedAttention 原理和实现
- 计算优化:Flash Attention、算子融合、量化推理
- 采样策略:Temperature、Top-K、Top-P、Beam Search
- 性能指标:TTFT、TPOT、吞吐量、效率指标
下一章我们将探讨 模型服务框架,讲解如何使用 vLLM、TGI 等框架构建高性能推理服务。