推理优化技术
概述
大模型推理优化是一个系统工程,涉及模型层面、算子层面、系统层面的多维度优化。本章深入讲解各类推理优化技术的原理与实现,包括模型压缩、计算图优化、内存优化、系统级优化等关键技术。
模型量化
量化基础原理
量化是将高精度浮点数(如 FP32/FP16)映射到低精度数值表示(如 INT8/INT4 整数或 FP8/FP4 低比特浮点)的过程。下面的示例以 int8 存储为例:
package quantization
import (
"math"
)
// QuantizationType enumerates the numeric formats a tensor can be quantized to.
type QuantizationType int
// Available quantization formats; the values are assigned by iota in
// declaration order, starting at zero.
const (
QuantInt8 QuantizationType = iota
QuantInt4
QuantFP8
QuantFP4
)
// QuantizationConfig holds the settings attached to a Quantizer.
// NOTE(review): none of the Quantizer methods shown in this file read these
// fields; they take bits and group size as arguments instead.
type QuantizationConfig struct {
Type QuantizationType
Symmetric bool // symmetric quantization
PerChannel bool // per-channel quantization
GroupSize int // group size for grouped quantization
CalibrationMethod string // calibration method: minmax, percentile, entropy
}
// Quantizer converts float32 tensors into QuantizedTensor values using the
// symmetric, asymmetric and per-group schemes implemented by its methods.
type Quantizer struct {
config *QuantizationConfig
}
// QuantizedTensor is the result of quantizing a float32 tensor.
type QuantizedTensor struct {
Data []int8 // quantized values (int8 storage in this example)
Scale []float32 // scale factors: a single entry for global quantization, one per group otherwise
ZeroPoint []int8 // zero points, indexed the same way as Scale
Shape []int
GroupSize int // elements per group; Dequantize uses the single global scale when this is <= 0
}
// QuantizeSymmetric quantizes tensor with a single symmetric scale shared by
// all elements:
//
//	x_q = round(x / scale)
//	x   = x_q * scale
//
// The scale maps the largest absolute value onto qMax = 2^(bits-1) - 1 and the
// zero point is always 0. When the tensor has no dynamic range (empty or all
// zeros) or bits leaves no quantization levels, the scale falls back to 1 so
// the division below can never produce NaN or Inf. Values are stored as int8,
// so bits is expected to be at most 8.
func (q *Quantizer) QuantizeSymmetric(tensor []float32, bits int) *QuantizedTensor {
	maxVal := q.findAbsMax(tensor)
	qMax := float32(math.Pow(2, float64(bits-1)) - 1)

	scale := maxVal / qMax
	if scale == 0 || math.IsNaN(float64(scale)) || math.IsInf(float64(scale), 0) {
		// Degenerate range: any finite scale represents the data exactly.
		scale = 1
	}

	quantized := make([]int8, len(tensor))
	for i, v := range tensor {
		// Clamp to [-qMax, qMax] before rounding so the int8 conversion is
		// always in range.
		qVal := min(max(v/scale, -qMax), qMax)
		quantized[i] = int8(math.Round(float64(qVal)))
	}

	return &QuantizedTensor{
		Data:      quantized,
		Scale:     []float32{scale},
		ZeroPoint: []int8{0},
		Shape:     []int{len(tensor)},
	}
}
// QuantizeAsymmetric quantizes tensor with one scale and one zero point:
//
//	x_q = round(x / scale) + zero_point
//	x   = (x_q - zero_point) * scale
//
// The quantized range is [0, 2^bits-1] while that fits into the int8 storage
// of QuantizedTensor; for bits == 8 the equivalent signed range [-128, 127]
// is used instead, because values above 127 cannot be stored in an int8.
// The float range is widened to contain 0 so that the zero point always lies
// inside the quantized range (and therefore inside int8). A constant tensor
// falls back to scale 1, and an empty tensor yields an empty result instead
// of panicking.
func (q *Quantizer) QuantizeAsymmetric(tensor []float32, bits int) *QuantizedTensor {
	if len(tensor) == 0 {
		return &QuantizedTensor{
			Data:      []int8{},
			Scale:     []float32{1},
			ZeroPoint: []int8{0},
			Shape:     []int{0},
		}
	}

	minVal, maxVal := q.findMinMax(tensor)
	// Make sure 0 is representable; this bounds the zero point to [qMin, qMax].
	minVal = min(minVal, 0)
	maxVal = max(maxVal, 0)

	qMin := float32(0)
	qMax := float32(math.Pow(2, float64(bits)) - 1)
	if qMax > math.MaxInt8 {
		// The unsigned range does not fit into int8; shift to the signed one.
		half := float32(math.Pow(2, float64(bits-1)))
		qMin, qMax = -half, half-1
	}

	scale := (maxVal - minVal) / (qMax - qMin)
	if scale == 0 || math.IsNaN(float64(scale)) || math.IsInf(float64(scale), 0) {
		scale = 1
	}
	// minVal maps onto qMin, hence zero_point = qMin - minVal/scale.
	zeroPoint := int8(math.Round(float64(qMin - minVal/scale)))

	quantized := make([]int8, len(tensor))
	for i, v := range tensor {
		qVal := min(max(v/scale+float32(zeroPoint), qMin), qMax)
		quantized[i] = int8(math.Round(float64(qVal)))
	}

	return &QuantizedTensor{
		Data:      quantized,
		Scale:     []float32{scale},
		ZeroPoint: []int8{zeroPoint},
		Shape:     []int{len(tensor)},
	}
}
// QuantizePerGroup performs symmetric quantization where every run of
// groupSize consecutive elements gets its own scale (the zero points are all
// 0). The last group may be shorter than groupSize.
//
// A non-positive groupSize is treated as a single group covering the whole
// tensor instead of dividing by zero. An all-zero group has no dynamic range
// and uses scale 1, so the division cannot produce NaN.
func (q *Quantizer) QuantizePerGroup(
	tensor []float32,
	groupSize int,
	bits int,
) *QuantizedTensor {
	if groupSize <= 0 {
		groupSize = max(len(tensor), 1)
	}

	numGroups := (len(tensor) + groupSize - 1) / groupSize
	scales := make([]float32, numGroups)
	zeroPoints := make([]int8, numGroups)
	quantized := make([]int8, len(tensor))
	qMax := float32(math.Pow(2, float64(bits-1)) - 1)

	for g := 0; g < numGroups; g++ {
		start := g * groupSize
		end := min(start+groupSize, len(tensor))
		group := tensor[start:end]

		// Quantization parameters of this group.
		scale := q.findAbsMax(group) / qMax
		if scale == 0 || math.IsNaN(float64(scale)) || math.IsInf(float64(scale), 0) {
			scale = 1
		}
		scales[g] = scale

		// Quantize the group, clamping before the int8 conversion.
		for i, v := range group {
			qVal := min(max(v/scale, -qMax), qMax)
			quantized[start+i] = int8(math.Round(float64(qVal)))
		}
	}

	return &QuantizedTensor{
		Data:      quantized,
		Scale:     scales,
		ZeroPoint: zeroPoints,
		Shape:     []int{len(tensor)},
		GroupSize: groupSize,
	}
}
// Dequantize reconstructs float32 values as (x_q - zero_point) * scale.
//
// When GroupSize is positive, element i uses the scale and zero point of
// group i/GroupSize; otherwise the single global pair at index 0 is used.
// The subtraction is carried out in int32: in int8 arithmetic a difference
// such as 127 - (-128) wraps around and yields a wrong value.
func (qt *QuantizedTensor) Dequantize() []float32 {
	result := make([]float32, len(qt.Data))
	for i, qVal := range qt.Data {
		group := 0
		if qt.GroupSize > 0 {
			group = i / qt.GroupSize
		}
		shifted := int32(qVal) - int32(qt.ZeroPoint[group])
		result[i] = float32(shifted) * qt.Scale[group]
	}
	return result
}
// findAbsMax returns the largest absolute value in tensor, or 0 when tensor
// is empty.
func (q *Quantizer) findAbsMax(tensor []float32) float32 {
	var peak float32
	for idx := range tensor {
		if magnitude := float32(math.Abs(float64(tensor[idx]))); magnitude > peak {
			peak = magnitude
		}
	}
	return peak
}
// findMinMax returns the smallest and largest value in tensor.
// An empty tensor yields (0, 0) instead of panicking on tensor[0].
func (q *Quantizer) findMinMax(tensor []float32) (float32, float32) {
	if len(tensor) == 0 {
		return 0, 0
	}
	minVal, maxVal := tensor[0], tensor[0]
	for _, v := range tensor[1:] {
		if v < minVal {
			minVal = v
		}
		if v > maxVal {
			maxVal = v
		}
	}
	return minVal, maxVal
}
GPTQ 量化
GPTQ (Generative Pre-trained Transformer Quantization) 是一种高效的后训练量化方法:
package quantization
import (
"math"
"sync"
)
// GPTQQuantizer quantizes layer weights with the GPTQ procedure implemented
// by QuantizeLayer.
type GPTQQuantizer struct {
bits int
groupSize int
symmetric bool
actOrder bool // activation-order optimisation: getQuantOrder sorts columns by Hessian diagonal when set
dampingPct float64 // damping coefficient, applied as a fraction of the mean Hessian diagonal
}
// NewGPTQQuantizer builds a GPTQQuantizer whose settings are copied from
// config. config must not be nil.
func NewGPTQQuantizer(config *GPTQConfig) *GPTQQuantizer {
	quantizer := new(GPTQQuantizer)
	quantizer.bits = config.Bits
	quantizer.groupSize = config.GroupSize
	quantizer.symmetric = config.Symmetric
	quantizer.actOrder = config.ActOrder
	quantizer.dampingPct = config.DampingPct
	return quantizer
}
// QuantizeLayer quantizes the weights of a single layer, following the OBS
// (Optimal Brain Surgeon) framework.
//
// weight is [out_features, in_features] and is not modified; hessian is
// [in_features, in_features]. Columns are processed in the order returned by
// getQuantOrder, in blocks of 128. After a column is quantized, its error is
// redistributed over the not-yet-quantized columns of the same block.
//
// Group parameters are computed the first time any column of a group is
// visited for a row. Keying this on "first column of the group" would, with
// activation ordering enabled, quantize columns visited earlier with an
// uninitialized scale of 0.
//
// A non-positive group size is treated as one group spanning the whole row,
// and an empty weight matrix yields an empty result. An error is returned
// only when the Cholesky decomposition fails.
func (gq *GPTQQuantizer) QuantizeLayer(
	weight [][]float32, // [out_features, in_features]
	hessian [][]float32, // Hessian matrix [in_features, in_features]
) (*GPTQLayerResult, error) {
	outFeatures := len(weight)
	if outFeatures == 0 || len(weight[0]) == 0 {
		return &GPTQLayerResult{GroupSize: gq.groupSize}, nil
	}
	inFeatures := len(weight[0])

	groupSize := gq.groupSize
	if groupSize <= 0 {
		groupSize = inFeatures
	}

	// Work on a copy so the caller's weights stay untouched.
	W := make([][]float32, outFeatures)
	for i := range W {
		W[i] = make([]float32, inFeatures)
		copy(W[i], weight[i])
	}

	// Add the damping term to the Hessian diagonal.
	H := gq.addDamping(hessian)

	// Cholesky decomposition.
	L, err := gq.choleskyDecomposition(H)
	if err != nil {
		return nil, err
	}

	// Inverse used for the error compensation below.
	Hinv := gq.invertLowerTriangular(L)

	// Storage for the quantization parameters and results.
	numGroups := (inFeatures + groupSize - 1) / groupSize
	scales := make([][]float32, outFeatures)
	zeros := make([][]int32, outFeatures)
	quantizedWeight := make([][]int8, outFeatures)
	paramsReady := make([][]bool, outFeatures) // per (row, group): parameters computed?
	for i := range outFeatures {
		scales[i] = make([]float32, numGroups)
		zeros[i] = make([]int32, numGroups)
		quantizedWeight[i] = make([]int8, inFeatures)
		paramsReady[i] = make([]bool, numGroups)
	}

	// Column processing order.
	order := gq.getQuantOrder(H)

	// Quantize column by column, block by block.
	const blockSize = 128
	for blockStart := 0; blockStart < inFeatures; blockStart += blockSize {
		blockEnd := min(blockStart+blockSize, inFeatures)

		for j := blockStart; j < blockEnd; j++ {
			col := order[j]
			groupIdx := col / groupSize
			groupStart := groupIdx * groupSize
			groupEnd := min(groupStart+groupSize, inFeatures)

			for i := 0; i < outFeatures; i++ {
				w := W[i][col]

				// Compute the group's parameters on first contact, from the
				// current (partially compensated) weights of that group.
				if !paramsReady[i][groupIdx] {
					scale, zero := gq.computeGroupParams(W[i][groupStart:groupEnd])
					scales[i][groupIdx] = scale
					zeros[i][groupIdx] = zero
					paramsReady[i][groupIdx] = true
				}
				scale := scales[i][groupIdx]
				zero := zeros[i][groupIdx]

				// Quantize, then dequantize to measure the error.
				qVal := gq.quantize(w, scale, zero)
				quantizedWeight[i][col] = qVal
				wHat := gq.dequantize(qVal, scale, zero)
				quantErr := w - wHat

				// Error compensation (the core of OBS): spread the error over
				// the remaining columns of this block.
				if j < blockEnd-1 {
					errWeight := quantErr / float32(Hinv[col][col])
					for k := j + 1; k < blockEnd; k++ {
						nextCol := order[k]
						W[i][nextCol] -= errWeight * float32(Hinv[nextCol][col])
					}
				}
			}
		}
	}

	return &GPTQLayerResult{
		QuantizedWeight: quantizedWeight,
		Scales:          scales,
		Zeros:           zeros,
		GroupSize:       groupSize,
	}, nil
}
// ComputeHessian builds the Hessian approximation used for a linear layer,
//
//	H = X^T * X / num_samples
//
// from calibrationData of shape [num_samples, in_features]. It returns nil
// when no calibration samples are given.
//
// Rows are split across up to 8 goroutines. The worker owning row i writes
// H[i][j] and the mirrored H[j][i] for j >= i only, so no element is written
// by two workers.
func (gq *GPTQQuantizer) ComputeHessian(
	calibrationData [][]float32, // calibration data [num_samples, in_features]
) [][]float32 {
	numSamples := len(calibrationData)
	if numSamples == 0 {
		return nil
	}
	inFeatures := len(calibrationData[0])

	H := make([][]float32, inFeatures)
	for i := range H {
		H[i] = make([]float32, inFeatures)
	}

	const numWorkers = 8
	rowsPerWorker := (inFeatures + numWorkers - 1) / numWorkers

	var wg sync.WaitGroup
	for w := 0; w < numWorkers; w++ {
		startRow := w * rowsPerWorker
		if startRow >= inFeatures {
			break // no rows left; do not spawn idle workers
		}
		endRow := min(startRow+rowsPerWorker, inFeatures)

		wg.Add(1)
		go func(startRow, endRow int) {
			defer wg.Done()
			for i := startRow; i < endRow; i++ {
				for j := i; j < inFeatures; j++ {
					var sum float32
					for s := 0; s < numSamples; s++ {
						sum += calibrationData[s][i] * calibrationData[s][j]
					}
					H[i][j] = sum / float32(numSamples)
					if i != j {
						H[j][i] = H[i][j] // symmetric matrix
					}
				}
			}
		}(startRow, endRow)
	}
	wg.Wait()
	return H
}
// addDamping returns a copy of H with dampingPct times the mean of the
// diagonal added to every diagonal element. H itself is left unchanged.
func (gq *GPTQQuantizer) addDamping(H [][]float32) [][]float32 {
	size := len(H)
	damped := make([][]float32, size)

	// Copy each row and accumulate the trace on the way.
	var trace float32
	for r := range damped {
		damped[r] = make([]float32, size)
		copy(damped[r], H[r])
		trace += damped[r][r]
	}

	mean := trace / float32(size)
	offset := float32(gq.dampingPct) * mean
	for r := range damped {
		damped[r][r] += offset
	}
	return damped
}
// getQuantOrder returns the order in which columns are quantized. Without
// actOrder this is 0..n-1; with actOrder the columns are arranged by
// descending Hessian diagonal, so columns with larger diagonal entries come
// first.
func (gq *GPTQQuantizer) getQuantOrder(H [][]float32) []int {
	count := len(H)
	order := make([]int, count)
	for pos := range order {
		order[pos] = pos
	}
	if !gq.actOrder {
		return order
	}

	diag := make([]float32, count)
	for pos := range diag {
		diag[pos] = H[pos][pos]
	}

	// Exchange sort, descending by diagonal value.
	for left := 0; left+1 < count; left++ {
		for right := left + 1; right < count; right++ {
			if diag[order[left]] < diag[order[right]] {
				order[left], order[right] = order[right], order[left]
			}
		}
	}
	return order
}
// computeGroupParams derives the symmetric quantization parameters of one
// group: the scale maps the largest absolute value onto 2^(bits-1)-1, and the
// zero point is always 0. A scale of exactly 0 is replaced by 1.
func (gq *GPTQQuantizer) computeGroupParams(group []float32) (float32, int32) {
	var peak float32
	for idx := range group {
		if magnitude := float32(math.Abs(float64(group[idx]))); magnitude > peak {
			peak = magnitude
		}
	}

	levels := float32(math.Pow(2, float64(gq.bits-1)) - 1)
	step := peak / levels
	if step == 0 {
		return 1, 0
	}
	return step, 0
}
// quantize maps val to round(val/scale + zero), after clamping the value to
// [-qMax, qMax] with qMax = 2^(bits-1)-1.
func (gq *GPTQQuantizer) quantize(val, scale float32, zero int32) int8 {
	limit := float64(float32(math.Pow(2, float64(gq.bits-1)) - 1))
	shifted := val/scale + float32(zero)

	clamped := math.Min(limit, float64(shifted))
	clamped = math.Max(-limit, clamped)
	return int8(math.Round(clamped))
}
// dequantize maps a quantized value back to (qVal - zero) * scale. The
// subtraction is done in int32.
func (gq *GPTQQuantizer) dequantize(qVal int8, scale float32, zero int32) float32 {
	shifted := int32(qVal) - zero
	return float32(shifted) * scale
}
// GPTQConfig carries the settings NewGPTQQuantizer copies into a
// GPTQQuantizer; each field maps to the quantizer field of the same name.
type GPTQConfig struct {
Bits int
GroupSize int
Symmetric bool
ActOrder bool
DampingPct float64
}
// GPTQLayerResult is the output of GPTQQuantizer.QuantizeLayer.
// QuantizedWeight is indexed [row][column]; Scales and Zeros are indexed
// [row][group], where a column belongs to group column/GroupSize.
type GPTQLayerResult struct {
QuantizedWeight [][]int8
Scales [][]float32
Zeros [][]int32
GroupSize int
}
AWQ 量化
AWQ (Activation-aware Weight Quantization) 是一种激活感知的权重量化方法:先根据校准激活的幅度评估各输入通道的重要性,再在量化前按通道对权重进行缩放:
package quantization
import (
"math"
)
// AWQQuantizer quantizes layer weights with the activation-aware procedure
// implemented by QuantizeLayer.
type AWQQuantizer struct {
bits int
groupSize int
searchScale bool // when set, QuantizeLayer grid-searches the channel scales instead of using the channel importance directly
}
// AWQConfig carries the settings for an AWQQuantizer.
// NOTE(review): NewAWQQuantizer copies Bits, GroupSize and SearchScale;
// NumSamples is not read by the code shown here.
type AWQConfig struct {
Bits int
GroupSize int
SearchScale bool
NumSamples int
}
// NewAWQQuantizer builds an AWQQuantizer from config, taking over its bit
// width, group size and scale-search switch. config must not be nil.
func NewAWQQuantizer(config *AWQConfig) *AWQQuantizer {
	quantizer := new(AWQQuantizer)
	quantizer.bits = config.Bits
	quantizer.groupSize = config.GroupSize
	quantizer.searchScale = config.SearchScale
	return quantizer
}
// QuantizeLayer quantizes one layer with AWQ.
//
// weight is [out_features, in_features]; activations holds the calibration
// activations, [num_samples, in_features]. The steps are:
//
//  1. compute the per-channel importance from the activations,
//  2. pick per-channel scales (grid search when searchScale is set,
//     otherwise the importance itself),
//  3. multiply the weights by those scales,
//  4. quantize the scaled weights group by group.
//
// Without the search, scales are floored at 0.01, the same floor the search
// path applies: a scale of 0 would wipe the channel's weights and could not
// be undone by dividing by the scale later. A non-positive group size is
// treated as one group spanning the whole row, and an empty weight matrix
// yields an empty result. The returned error is always nil.
func (aq *AWQQuantizer) QuantizeLayer(
	weight [][]float32, // [out_features, in_features]
	activations [][]float32, // calibration activations [num_samples, in_features]
) (*AWQLayerResult, error) {
	outFeatures := len(weight)
	if outFeatures == 0 || len(weight[0]) == 0 {
		return &AWQLayerResult{GroupSize: aq.groupSize}, nil
	}
	inFeatures := len(weight[0])

	groupSize := aq.groupSize
	if groupSize <= 0 {
		groupSize = inFeatures
	}

	// 1. Importance of every input channel.
	channelImportance := aq.computeChannelImportance(activations)

	// 2. Per-channel scales.
	var scales []float32
	if aq.searchScale {
		scales = aq.searchOptimalScales(weight, channelImportance)
	} else {
		// Use the activation magnitude as the scale, kept away from zero.
		const minChannelScale = 0.01
		scales = make([]float32, len(channelImportance))
		for c, importance := range channelImportance {
			scales[c] = max(importance, minChannelScale)
		}
	}

	// 3. Apply the scales to the weights.
	scaledWeight := aq.applyScales(weight, scales)

	// 4. Quantize the scaled weights.
	numGroups := (inFeatures + groupSize - 1) / groupSize
	quantizedWeight := make([][]int8, outFeatures)
	qScales := make([][]float32, outFeatures)
	qZeros := make([][]int32, outFeatures)

	for i := 0; i < outFeatures; i++ {
		quantizedWeight[i] = make([]int8, inFeatures)
		qScales[i] = make([]float32, numGroups)
		qZeros[i] = make([]int32, numGroups)

		for g := 0; g < numGroups; g++ {
			start := g * groupSize
			end := min(start+groupSize, inFeatures)

			// Quantization parameters of this group.
			scale, zero := aq.computeQuantParams(scaledWeight[i][start:end])
			qScales[i][g] = scale
			qZeros[i][g] = zero

			for j := start; j < end; j++ {
				quantizedWeight[i][j] = aq.quantize(scaledWeight[i][j], scale, zero)
			}
		}
	}

	return &AWQLayerResult{
		QuantizedWeight: quantizedWeight,
		Scales:          qScales,
		Zeros:           qZeros,
		ChannelScales:   scales,
		GroupSize:       groupSize,
	}, nil
}
// computeChannelImportance returns, for every input channel, the mean
// absolute activation over all samples, divided by the largest such mean so
// the most important channel has value 1. The division is skipped when the
// largest mean is not positive. activations must contain at least one sample
// with at least one channel.
func (aq *AWQQuantizer) computeChannelImportance(activations [][]float32) []float32 {
	sampleCount := len(activations)
	channelCount := len(activations[0])

	// Accumulate |activation| per channel, sample by sample.
	importance := make([]float32, channelCount)
	for _, sample := range activations {
		for ch := 0; ch < channelCount; ch++ {
			importance[ch] += float32(math.Abs(float64(sample[ch])))
		}
	}
	for ch := range importance {
		importance[ch] = importance[ch] / float32(sampleCount)
	}

	// Normalize by the largest mean.
	peak := importance[0]
	for _, value := range importance[1:] {
		if value > peak {
			peak = value
		}
	}
	if peak <= 0 {
		return importance
	}
	for ch := range importance {
		importance[ch] /= peak
	}
	return importance
}
// searchOptimalScales picks a scale for every input channel by grid search.
//
// The candidates are 0.1 .. 1.0 times the channel's importance (floored at
// 0.01); the candidate with the smallest computeQuantError wins, and the
// first one wins ties. It returns nil for an empty weight matrix.
func (aq *AWQQuantizer) searchOptimalScales(
	weight [][]float32,
	channelImportance []float32,
) []float32 {
	if len(weight) == 0 {
		return nil
	}
	inFeatures := len(weight[0])
	scales := make([]float32, inFeatures)

	// Grid of ratios applied to the base scale.
	searchRange := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}

	for c := 0; c < inFeatures; c++ {
		baseScale := max(channelImportance[c], 0.01)

		bestScale := float32(1.0)
		bestError := float32(math.MaxFloat32)
		for _, ratio := range searchRange {
			scale := float32(ratio) * baseScale

			// Quantization error under this candidate. The local is not
			// called "error", which would shadow the predeclared type.
			quantErr := aq.computeQuantError(weight, c, scale)
			if quantErr < bestError {
				bestError = quantErr
				bestScale = scale
			}
		}
		scales[c] = bestScale
	}
	return scales
}
// computeQuantError simulates quantizing one input channel (column) of weight
// after multiplying it by scale, and returns the summed absolute difference
// between the original weights and their quantize/dequantize round trip.
//
// NOTE(review): the quantization step is derived from each element on its
// own, so every non-zero element lands exactly on +/-qMax before rounding and
// the reported error is close to zero for any scale. Deriving the step from
// the whole column (or group) looks like the intent — confirm.
func (aq *AWQQuantizer) computeQuantError(weight [][]float32, channel int, scale float32) float32 {
	levels := float32(math.Pow(2, float64(aq.bits-1)) - 1)

	var accumulated float32
	for _, row := range weight {
		original := row[channel]
		scaled := original * scale

		// Simulated quantization.
		step := float32(math.Abs(float64(scaled))) / levels
		if step == 0 {
			step = 1
		}
		level := float64(scaled / step)
		level = math.Min(float64(levels), level)
		level = math.Max(-float64(levels), level)
		level = math.Round(level)

		// Dequantize and undo the channel scale.
		restored := float32(level) * step / scale

		accumulated += float32(math.Abs(float64(original - restored)))
	}
	return accumulated
}
// applyScales returns a new matrix in which column j of weight is multiplied
// by scales[j]. weight is not modified; it must have at least one row.
func (aq *AWQQuantizer) applyScales(weight [][]float32, scales []float32) [][]float32 {
	rowCount := len(weight)
	colCount := len(weight[0])

	result := make([][]float32, rowCount)
	for r, source := range weight {
		target := make([]float32, colCount)
		for c := range target {
			target[c] = source[c] * scales[c]
		}
		result[r] = target
	}
	return result
}
// computeQuantParams returns the symmetric quantization parameters of a
// group: a scale mapping the largest absolute value onto 2^(bits-1)-1, and a
// zero point of 0. A scale of exactly 0 is replaced by 1.
func (aq *AWQQuantizer) computeQuantParams(group []float32) (float32, int32) {
	var largest float32
	for _, value := range group {
		magnitude := float32(math.Abs(float64(value)))
		if magnitude <= largest {
			continue
		}
		largest = magnitude
	}

	topLevel := float32(math.Pow(2, float64(aq.bits-1)) - 1)
	step := largest / topLevel
	if step == 0 {
		step = 1
	}
	return step, 0
}
// quantize computes val/scale + zero, clamps it to [-qMax, qMax] with
// qMax = 2^(bits-1)-1, and rounds the result to the nearest integer.
func (aq *AWQQuantizer) quantize(val, scale float32, zero int32) int8 {
	bound := float64(float32(math.Pow(2, float64(aq.bits-1)) - 1))
	level := float64(val/scale + float32(zero))

	switch {
	case level > bound:
		level = bound
	case level < -bound:
		level = -bound
	}
	return int8(math.Round(level))
}
// AWQLayerResult is the output of AWQQuantizer.QuantizeLayer.
// QuantizedWeight is indexed [row][column]; Scales and Zeros are indexed
// [row][group]. ChannelScales holds the per-input-channel scales that were
// multiplied into the weights before quantization.
type AWQLayerResult struct {
QuantizedWeight [][]int8
Scales [][]float32
Zeros [][]int32
ChannelScales []float32
GroupSize int
}
计算图优化
算子融合
算子融合是减少内存访问和 kernel 启动开销的关键技术:
package optimization
import (
"fmt"
)
// GraphOptimizer rewrites a compute graph by repeatedly applying its fusion
// rules (see Optimize).
type GraphOptimizer struct {
graph *ComputeGraph
fusionRules []FusionRule
}
// ComputeGraph is a compute graph: nodes keyed by their ID, plus the list of
// edges connecting them.
type ComputeGraph struct {
Nodes map[string]*GraphNode
Edges []*GraphEdge
}
// GraphNode is one operation in a ComputeGraph.
type GraphNode struct {
ID string
OpType string
Inputs []string
Outputs []string
Attrs map[string]interface{}
}
// GraphEdge is a directed edge of a ComputeGraph. Source and Target are node
// IDs (they are looked up in ComputeGraph.Nodes).
type GraphEdge struct {
Source string
Target string
TensorID string
}
// FusionRule is one pattern-based graph rewrite used by GraphOptimizer.
type FusionRule interface {
// Match returns the IDs of the nodes forming the pattern that starts at
// nodeID, or nil when the pattern does not match there.
Match(graph *ComputeGraph, nodeID string) []string
// Fuse builds the single node that replaces the nodes returned by Match.
Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode
// Name returns the rule's name.
Name() string
}
// LinearGELUFusion is the fusion rule for a Linear node followed by GELU.
type LinearGELUFusion struct{}

// Name returns the rule's name.
func (f *LinearGELUFusion) Name() string {
	const ruleName = "LinearGELUFusion"
	return ruleName
}
// Match reports a Linear node whose output feeds a GELU node. It returns
// {linearID, geluID} for the first such edge, or nil when nodeID is unknown,
// is not a Linear node, or has no GELU consumer. Edges whose target is
// missing from the graph are skipped rather than dereferenced.
func (f *LinearGELUFusion) Match(graph *ComputeGraph, nodeID string) []string {
	node := graph.Nodes[nodeID]
	if node == nil || node.OpType != "Linear" {
		return nil
	}

	// Look for an outgoing edge that ends in a GELU node.
	for _, edge := range graph.Edges {
		if edge.Source != nodeID {
			continue
		}
		if target := graph.Nodes[edge.Target]; target != nil && target.OpType == "GELU" {
			return []string{nodeID, edge.Target}
		}
	}
	return nil
}
// Fuse builds the "LinearGELU" node replacing the pair returned by Match:
// it takes the Linear node's inputs, weight and bias, and the GELU node's
// outputs. nodeIDs must be {linearID, geluID}.
func (f *LinearGELUFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
	linear, gelu := graph.Nodes[nodeIDs[0]], graph.Nodes[nodeIDs[1]]

	fused := new(GraphNode)
	fused.ID = fmt.Sprintf("fused_%s_%s", linear.ID, gelu.ID)
	fused.OpType = "LinearGELU"
	fused.Inputs = linear.Inputs
	fused.Outputs = gelu.Outputs
	fused.Attrs = map[string]interface{}{
		"weight": linear.Attrs["weight"],
		"bias":   linear.Attrs["bias"],
	}
	return fused
}
// MultiHeadAttentionFusion is the fusion rule for multi-head attention.
type MultiHeadAttentionFusion struct{}

// Name returns the rule's name.
func (f *MultiHeadAttentionFusion) Name() string {
	const ruleName = "MultiHeadAttentionFusion"
	return ruleName
}
// Match recognises the pattern
//
//	Q/K/V projection -> attention -> output projection
//
// starting at the Q projection nodeID. It returns
// {qID, kID, vID, attentionID, outputID}, or nil when any part of the
// pattern is missing. Each lookup is checked before its result is handed to
// the next helper, so no helper receives a nil node.
func (f *MultiHeadAttentionFusion) Match(graph *ComputeGraph, nodeID string) []string {
	node := graph.Nodes[nodeID]

	// The anchor must be the Q projection.
	if node == nil || node.OpType != "Linear" || !isQProjection(node) {
		return nil
	}

	// Associated K and V projections.
	kNode := findKProjection(graph, node)
	vNode := findVProjection(graph, node)
	if kNode == nil || vNode == nil {
		return nil
	}

	// Attention node, then the output projection that consumes it.
	attnNode := findAttentionNode(graph, node, kNode, vNode)
	if attnNode == nil {
		return nil
	}
	outNode := findOutputProjection(graph, attnNode)
	if outNode == nil {
		return nil
	}

	return []string{node.ID, kNode.ID, vNode.ID, attnNode.ID, outNode.ID}
}
// Fuse builds the "FusedMultiHeadAttention" node replacing the nodes returned
// by Match ({q, k, v, attention, output projection}).
//
// The fused node takes over the output tensors of the output projection, the
// same way the other fusion rules reuse the outputs of their last node. A
// fixed name such as "mha_output" would no longer match the tensor IDs the
// consumers list as inputs, and would collide between fused attention blocks.
func (f *MultiHeadAttentionFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
	qNode := graph.Nodes[nodeIDs[0]]
	attnNode := graph.Nodes[nodeIDs[3]]
	outNode := graph.Nodes[nodeIDs[4]]

	return &GraphNode{
		ID:      fmt.Sprintf("fused_mha_%s", nodeIDs[0]),
		OpType:  "FusedMultiHeadAttention",
		Inputs:  qNode.Inputs,
		Outputs: outNode.Outputs,
		Attrs: map[string]interface{}{
			"num_heads": extractNumHeads(attnNode),
			"qkv_fused": true,
		},
	}
}
// RMSNormFusion is the fusion rule for the RMSNorm operator chain.
type RMSNormFusion struct{}

// Name returns the rule's name.
func (f *RMSNormFusion) Name() string {
	const ruleName = "RMSNormFusion"
	return ruleName
}
// Match recognises the chain
//
//	x -> pow(2) -> mean -> rsqrt -> mul -> mul(weight)
//
// starting at the Pow node nodeID. It returns the IDs of the five nodes in
// chain order, or nil when nodeID is unknown or the chain is incomplete.
// The "exponent" attribute may be stored as any of the common numeric types;
// it must equal 2.
func (f *RMSNormFusion) Match(graph *ComputeGraph, nodeID string) []string {
	node := graph.Nodes[nodeID]
	if node == nil || node.OpType != "Pow" {
		return nil
	}

	// The exponent must be 2, whatever numeric type it was stored as.
	switch exp := node.Attrs["exponent"].(type) {
	case float64:
		if exp != 2 {
			return nil
		}
	case float32:
		if exp != 2 {
			return nil
		}
	case int:
		if exp != 2 {
			return nil
		}
	case int32:
		if exp != 2 {
			return nil
		}
	case int64:
		if exp != 2 {
			return nil
		}
	default:
		return nil
	}

	// Follow the chain of successor operations.
	meanNode := findNextOp(graph, nodeID, "Mean")
	if meanNode == nil {
		return nil
	}
	rsqrtNode := findNextOp(graph, meanNode.ID, "Rsqrt")
	if rsqrtNode == nil {
		return nil
	}
	mul1Node := findNextOp(graph, rsqrtNode.ID, "Mul")
	if mul1Node == nil {
		return nil
	}
	mul2Node := findNextOp(graph, mul1Node.ID, "Mul")
	if mul2Node == nil {
		return nil
	}

	return []string{node.ID, meanNode.ID, rsqrtNode.ID, mul1Node.ID, mul2Node.ID}
}
// Fuse builds the "FusedRMSNorm" node replacing the chain returned by Match:
// inputs come from the Pow node, outputs and the weight (attribute "other")
// from the last Mul node. eps is fixed at 1e-6.
func (f *RMSNormFusion) Fuse(graph *ComputeGraph, nodeIDs []string) *GraphNode {
	first := graph.Nodes[nodeIDs[0]]
	last := graph.Nodes[nodeIDs[4]]

	attrs := make(map[string]interface{}, 2)
	attrs["eps"] = 1e-6
	attrs["weight"] = last.Attrs["other"]

	fused := &GraphNode{OpType: "FusedRMSNorm", Attrs: attrs}
	fused.ID = fmt.Sprintf("fused_rmsnorm_%s", nodeIDs[0])
	fused.Inputs = first.Inputs
	fused.Outputs = last.Outputs
	return fused
}
// Optimize applies the fusion rules until no rule matches any node and
// returns the resulting graph. After every successful fusion the scan
// restarts, because the fusion changes the node set being iterated.
func (opt *GraphOptimizer) Optimize() *ComputeGraph {
	working := opt.graph.Clone()

	// fuseOnce applies the first rule that matches any node and reports
	// whether a fusion took place.
	fuseOnce := func() bool {
		for nodeID := range working.Nodes {
			for _, rule := range opt.fusionRules {
				matched := rule.Match(working, nodeID)
				if matched == nil {
					continue
				}
				replacement := rule.Fuse(working, matched)
				opt.applyFusion(working, matched, replacement)
				return true
			}
		}
		return false
	}

	for fuseOnce() {
	}
	return working
}
// applyFusion replaces oldNodes by newNode inside graph: newNode is added,
// edges between two old nodes are dropped, edges entering or leaving the old
// nodes are redirected to newNode (the edge objects are updated in place),
// and finally the old nodes are removed.
func (opt *GraphOptimizer) applyFusion(
	graph *ComputeGraph,
	oldNodes []string,
	newNode *GraphNode,
) {
	graph.Nodes[newNode.ID] = newNode

	kept := []*GraphEdge{}
	for _, edge := range graph.Edges {
		leavesFused := contains(oldNodes, edge.Source)
		entersFused := contains(oldNodes, edge.Target)

		switch {
		case leavesFused && entersFused:
			// Internal edge of the fused pattern: drop it.
			continue
		case leavesFused:
			// Output edge: it now starts at the fused node.
			edge.Source = newNode.ID
		case entersFused:
			// Input edge: it now ends at the fused node.
			edge.Target = newNode.ID
		}
		kept = append(kept, edge)
	}
	graph.Edges = kept

	for _, staleID := range oldNodes {
		delete(graph.Nodes, staleID)
	}
}
// Helper functions.

// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
	for idx := 0; idx < len(slice); idx++ {
		if slice[idx] == item {
			return true
		}
	}
	return false
}
// findNextOp returns the first node of type opType that is the target of an
// edge leaving nodeID, or nil when there is none. Edges whose target is not
// present in graph.Nodes are skipped instead of being dereferenced.
func findNextOp(graph *ComputeGraph, nodeID, opType string) *GraphNode {
	for _, edge := range graph.Edges {
		if edge.Source != nodeID {
			continue
		}
		if target := graph.Nodes[edge.Target]; target != nil && target.OpType == opType {
			return target
		}
	}
	return nil
}
// isQProjection reports whether node is a Q projection.
// Simplified placeholder: it currently returns true for every node.
func isQProjection(node *GraphNode) bool {
// Meant to decide based on naming or attributes.
return true // simplified implementation
}
// findKProjection returns the K projection belonging to qNode.
// Simplified placeholder: it currently always returns nil.
func findKProjection(graph *ComputeGraph, qNode *GraphNode) *GraphNode {
return nil // simplified implementation
}
// findVProjection returns the V projection belonging to qNode.
// Simplified placeholder: it currently always returns nil.
func findVProjection(graph *ComputeGraph, qNode *GraphNode) *GraphNode {
return nil // simplified implementation
}
// findAttentionNode returns the attention node fed by the q, k and v
// projections. Simplified placeholder: it currently always returns nil.
func findAttentionNode(graph *ComputeGraph, q, k, v *GraphNode) *GraphNode {
return nil // simplified implementation
}
// findOutputProjection returns the output projection following attnNode.
// Simplified placeholder: it currently always returns nil.
func findOutputProjection(graph *ComputeGraph, attnNode *GraphNode) *GraphNode {
return nil // simplified implementation
}
// extractNumHeads returns the number of attention heads of node.
// Simplified placeholder: it currently returns 32 regardless of node.
func extractNumHeads(node *GraphNode) int {
return 32 // simplified implementation
}
// Clone returns a deep copy of the graph structure: every node (including
// its Inputs/Outputs slices and its Attrs map) and every edge is duplicated,
// so fusing nodes in the clone leaves g untouched. Attribute values
// themselves are shared, not copied. A nil receiver yields nil.
func (g *ComputeGraph) Clone() *ComputeGraph {
	if g == nil {
		return nil
	}

	clone := &ComputeGraph{
		Nodes: make(map[string]*GraphNode, len(g.Nodes)),
		Edges: make([]*GraphEdge, 0, len(g.Edges)),
	}

	for id, node := range g.Nodes {
		if node == nil {
			clone.Nodes[id] = nil
			continue
		}
		copied := *node
		copied.Inputs = append([]string(nil), node.Inputs...)
		copied.Outputs = append([]string(nil), node.Outputs...)
		if node.Attrs != nil {
			copied.Attrs = make(map[string]interface{}, len(node.Attrs))
			for key, value := range node.Attrs {
				copied.Attrs[key] = value
			}
		}
		clone.Nodes[id] = &copied
	}

	for _, edge := range g.Edges {
		if edge == nil {
			clone.Edges = append(clone.Edges, nil)
			continue
		}
		copied := *edge
		clone.Edges = append(clone.Edges, &copied)
	}
	return clone
}
内存规划
高效的内存规划减少显存占用:
package optimization
import (
"sort"
)
// MemoryPlanner assigns memory offsets to the tensors of a compute graph
// based on their lifetimes (see Plan).
type MemoryPlanner struct {
graph *ComputeGraph
tensorLifetimes map[string]*TensorLifetime // keyed by tensor ID; rebuilt by analyzeTensorLifetimes
memoryPlan *MemoryPlan
}
// TensorLifetime records when a tensor is live and where it was placed.
type TensorLifetime struct {
TensorID string
Size int64
FirstUse int // index of the first operation using the tensor
LastUse int // index of the last operation using the tensor
MemOffset int64 // assigned memory offset
}
// MemoryPlan is the result of MemoryPlanner.Plan: the total size of the
// memory area and the allocation of every tensor, keyed by tensor ID.
// NOTE(review): ReusableBuffers is not filled in by Plan as shown here.
type MemoryPlan struct {
TotalSize int64
Allocations map[string]*Allocation
ReusableBuffers []*BufferPool
}
// Allocation is the placement of one tensor inside the planned memory area.
// Reused is set when the space came from the free list rather than from
// growing the area.
type Allocation struct {
TensorID string
Offset int64
Size int64
Reused bool
}
// BufferPool is a reusable buffer together with the tensors assigned to it.
type BufferPool struct {
Size int64
Tensors []string
}
// Plan computes a memory plan for the given execution order.
//
// It analyses the tensor lifetimes, then places the tensors in the order
// returned by getSortedTensors: each tensor takes the best-fitting region
// that is already free at its first use, or extends the memory area when
// none fits. Every placed region goes back onto the free list, marked as
// available after the tensor's last use.
func (mp *MemoryPlanner) Plan(executionOrder []string) *MemoryPlan {
	mp.analyzeTensorLifetimes(executionOrder)

	plan := &MemoryPlan{Allocations: make(map[string]*Allocation)}
	available := NewFreeList()

	for _, lifetime := range mp.getSortedTensors() {
		offset, reused := available.FindBestFit(lifetime.Size, lifetime.FirstUse)
		if !reused {
			// Nothing reusable: append at the end of the area.
			offset = plan.TotalSize
			plan.TotalSize += lifetime.Size
		}
		lifetime.MemOffset = offset

		allocation := new(Allocation)
		allocation.TensorID = lifetime.TensorID
		allocation.Offset = offset
		allocation.Size = lifetime.Size
		allocation.Reused = reused
		plan.Allocations[lifetime.TensorID] = allocation

		// The region becomes reusable once the tensor's last use is over.
		released := &FreeBlock{Offset: offset, Size: lifetime.Size, FreeAfter: lifetime.LastUse}
		available.Add(released)
	}
	return plan
}
// analyzeTensorLifetimes rebuilds mp.tensorLifetimes from executionOrder.
//
// A tensor's lifetime starts at the operation that outputs it and is
// extended to every later operation that lists it as an input. Inputs that
// no executed node has produced so far are not tracked. Node IDs missing
// from the graph are skipped instead of causing a nil dereference.
func (mp *MemoryPlanner) analyzeTensorLifetimes(executionOrder []string) {
	mp.tensorLifetimes = make(map[string]*TensorLifetime)

	for opIdx, nodeID := range executionOrder {
		node := mp.graph.Nodes[nodeID]
		if node == nil {
			continue
		}

		// Inputs: extend the lifetime of already known tensors.
		for _, inputID := range node.Inputs {
			if lt, exists := mp.tensorLifetimes[inputID]; exists {
				lt.LastUse = opIdx
			}
		}

		// Outputs: a new tensor is born at this operation.
		for _, outputID := range node.Outputs {
			mp.tensorLifetimes[outputID] = &TensorLifetime{
				TensorID: outputID,
				Size:     mp.estimateTensorSize(nodeID, outputID),
				FirstUse: opIdx,
				LastUse:  opIdx,
			}
		}
	}
}
// getSortedTensors returns all tensor lifetimes in allocation order: larger
// tensors first, then longer-lived ones. Remaining ties are broken by first
// use and finally by tensor ID, so the order — and with it the memory plan —
// does not depend on map iteration order or on the unstable sort.
func (mp *MemoryPlanner) getSortedTensors() []*TensorLifetime {
	tensors := make([]*TensorLifetime, 0, len(mp.tensorLifetimes))
	for _, lt := range mp.tensorLifetimes {
		tensors = append(tensors, lt)
	}

	sort.Slice(tensors, func(i, j int) bool {
		a, b := tensors[i], tensors[j]
		// Size, descending.
		if a.Size != b.Size {
			return a.Size > b.Size
		}
		// Lifetime length, descending.
		lifeA, lifeB := a.LastUse-a.FirstUse, b.LastUse-b.FirstUse
		if lifeA != lifeB {
			return lifeA > lifeB
		}
		// Deterministic tie-breakers.
		if a.FirstUse != b.FirstUse {
			return a.FirstUse < b.FirstUse
		}
		return a.TensorID < b.TensorID
	})
	return tensors
}
// estimateTensorSize returns the size in bytes of tensor tensorID produced by
// node nodeID. Simplified placeholder: it currently returns 1 MiB for every
// tensor.
func (mp *MemoryPlanner) estimateTensorSize(nodeID, tensorID string) int64 {
// Meant to estimate the size from the node type and attributes;
// simplified here.
return 1024 * 1024 // 1MB
}
// FreeList keeps the memory regions that become reusable during planning.
type FreeList struct {
	blocks []*FreeBlock
}

// FreeBlock is a memory region that may be reused by operations with an
// index greater than FreeAfter.
type FreeBlock struct {
	Offset    int64
	Size      int64
	FreeAfter int
}

// NewFreeList returns an empty free list.
func NewFreeList() *FreeList {
	return &FreeList{blocks: []*FreeBlock{}}
}

// FindBestFit looks for the smallest block that is at least size bytes large
// and already free at operation currentOp (FreeAfter < currentOp). On
// success it returns the block's offset and true; the block is removed from
// the list when it fits exactly, otherwise it is replaced by the remainder
// behind the taken part. When nothing fits it returns (0, false).
func (fl *FreeList) FindBestFit(size int64, currentOp int) (int64, bool) {
	chosen := -1
	for idx, candidate := range fl.blocks {
		if candidate.FreeAfter >= currentOp || candidate.Size < size {
			continue // still in use, or too small
		}
		if chosen < 0 || candidate.Size < fl.blocks[chosen].Size {
			chosen = idx
		}
	}
	if chosen < 0 {
		return 0, false
	}

	hit := fl.blocks[chosen]
	if hit.Size > size {
		// Keep the unused tail of the block on the list.
		fl.blocks[chosen] = &FreeBlock{
			Offset:    hit.Offset + size,
			Size:      hit.Size - size,
			FreeAfter: hit.FreeAfter,
		}
	} else {
		// Exact fit: the block is used up.
		fl.blocks = append(fl.blocks[:chosen], fl.blocks[chosen+1:]...)
	}
	return hit.Offset, true
}

// Add appends block to the free list.
func (fl *FreeList) Add(block *FreeBlock) {
	fl.blocks = append(fl.blocks, block)
}
// InPlaceOptimizer marks the operations of a graph that can run in place.
type InPlaceOptimizer struct {
graph *ComputeGraph
}
// Optimize sets the attribute "inplace" to true on every node that
// canBeInPlace accepts. Nodes without an attribute map get one first, since
// writing into a nil map would panic; nil nodes are skipped.
func (ipo *InPlaceOptimizer) Optimize() {
	for _, node := range ipo.graph.Nodes {
		if node == nil || !ipo.canBeInPlace(node) {
			continue
		}
		if node.Attrs == nil {
			node.Attrs = make(map[string]interface{})
		}
		node.Attrs["inplace"] = true
	}
}
// canBeInPlace reports whether node's operation may run in place:
// activation functions and Dropout always, Add and Mul when
// hasMatchingShapes accepts their inputs, everything else never.
func (ipo *InPlaceOptimizer) canBeInPlace(node *GraphNode) bool {
	switch op := node.OpType; op {
	case "Add", "Mul":
		// Element-wise operations need inputs of matching shape.
		return ipo.hasMatchingShapes(node.Inputs)
	case "ReLU", "GELU", "Sigmoid", "Tanh", "Dropout":
		return true
	}
	return false
}
// hasMatchingShapes reports whether the given input tensors have matching
// shapes. Simplified placeholder: it currently always returns true.
func (ipo *InPlaceOptimizer) hasMatchingShapes(inputs []string) bool {
// Meant to compare the input shapes.
return true // simplified implementation
}
KV Cache 优化
PagedAttention 实现
PagedAttention 是 vLLM 的核心技术,通过分页管理 KV Cache:
package optimization
import (
"errors"
"sync"
)
// PagedKVCache is a KV cache that is managed in fixed-size blocks. Every
// sequence owns a block table mapping its logical blocks to physical blocks;
// physical blocks are reference counted so that forked sequences can share
// them. All exported methods are guarded by mu.
type PagedKVCache struct {
	config      *PagedKVConfig
	blocks      []*KVBlock
	freeBlocks  []int
	blockTables map[int][]int // sequence_id -> block_indices
	mu          sync.RWMutex
}

// PagedKVConfig holds the dimensions of a PagedKVCache.
type PagedKVConfig struct {
	NumLayers    int
	NumHeads     int
	HeadDim      int
	BlockSize    int // number of tokens per block
	NumGPUBlocks int
	NumCPUBlocks int
	DType        string // fp16, fp32
}

// KVBlock is one physical block of the cache.
type KVBlock struct {
	KeyCache   [][]float16 // [num_tokens, num_heads * head_dim]
	ValueCache [][]float16 // [num_tokens, num_heads * head_dim]
	NumTokens  int
	RefCount   int
}

type float16 float32 // simplified representation

// NewPagedKVCache allocates NumGPUBlocks blocks of BlockSize token slots
// each, with NumHeads*HeadDim values per slot, and marks all of them free.
func NewPagedKVCache(config *PagedKVConfig) *PagedKVCache {
	cache := &PagedKVCache{
		config:      config,
		blocks:      make([]*KVBlock, config.NumGPUBlocks),
		freeBlocks:  make([]int, config.NumGPUBlocks),
		blockTables: make(map[int][]int),
	}

	slotWidth := config.NumHeads * config.HeadDim
	for i := 0; i < config.NumGPUBlocks; i++ {
		block := &KVBlock{
			KeyCache:   make([][]float16, config.BlockSize),
			ValueCache: make([][]float16, config.BlockSize),
		}
		for j := 0; j < config.BlockSize; j++ {
			block.KeyCache[j] = make([]float16, slotWidth)
			block.ValueCache[j] = make([]float16, slotWidth)
		}
		cache.blocks[i] = block
		cache.freeBlocks[i] = i
	}
	return cache
}

// AllocateBlocks gives seqID a fresh block table of numBlocks blocks and
// returns a copy of it. Blocks the sequence held before are released first,
// so calling it again for the same sequence does not leak them. It fails
// when numBlocks is negative or exceeds the number of free blocks; in that
// case nothing is changed.
func (c *PagedKVCache) AllocateBlocks(seqID, numBlocks int) ([]int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if numBlocks < 0 {
		return nil, errors.New("negative block count")
	}
	if len(c.freeBlocks) < numBlocks {
		return nil, errors.New("not enough free blocks")
	}

	// Give back whatever the sequence owned until now.
	c.releaseBlocksLocked(c.blockTables[seqID])

	allocated := make([]int, numBlocks)
	copy(allocated, c.freeBlocks[:numBlocks])
	c.freeBlocks = c.freeBlocks[numBlocks:]

	for _, blockIdx := range allocated {
		c.blocks[blockIdx].RefCount = 1
	}
	c.blockTables[seqID] = allocated

	// Hand out a copy so callers cannot modify the block table.
	result := make([]int, numBlocks)
	copy(result, allocated)
	return result, nil
}

// AppendToken stores the key and value of the next token of seqID.
//
// The slot is derived from the number of tokens stored so far. A new block
// is taken from the free list when the sequence's blocks are full. When the
// target block is shared with another sequence (after ForkSequence), it is
// copied first, so the write never changes data another sequence reads.
//
// layer is currently unused: the cache keeps a single key/value row per
// token.
func (c *PagedKVCache) AppendToken(
	seqID int,
	layer int,
	key, value []float16,
) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	blockTable := c.blockTables[seqID]
	if blockTable == nil {
		return errors.New("sequence not found")
	}

	// Position of the new token.
	totalTokens := c.getSequenceLength(seqID)
	blockIdx := totalTokens / c.config.BlockSize
	slotIdx := totalTokens % c.config.BlockSize

	if blockIdx >= len(blockTable) {
		// All blocks are full: take a new one.
		if len(c.freeBlocks) == 0 {
			return errors.New("no free blocks")
		}
		newBlockIdx := c.freeBlocks[0]
		c.freeBlocks = c.freeBlocks[1:]
		c.blocks[newBlockIdx].RefCount = 1
		blockTable = append(blockTable, newBlockIdx)
		c.blockTables[seqID] = blockTable
	} else if _, err := c.copyOnWriteLocked(blockTable, blockIdx); err != nil {
		// The block is shared and could not be made private.
		return err
	}

	// Write the key/value row.
	block := c.blocks[blockTable[blockIdx]]
	copy(block.KeyCache[slotIdx], key)
	copy(block.ValueCache[slotIdx], value)
	block.NumTokens++
	return nil
}

// ForkSequence creates childSeqID as a fork of parentSeqID (used for beam
// search). The child shares the parent's blocks; their reference counts are
// incremented and the data is copied only when one side writes. It fails
// when the parent is unknown or the child already exists, because replacing
// an existing child would leak its blocks.
func (c *PagedKVCache) ForkSequence(parentSeqID, childSeqID int) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	parentTable := c.blockTables[parentSeqID]
	if parentTable == nil {
		return errors.New("parent sequence not found")
	}
	if _, exists := c.blockTables[childSeqID]; exists {
		return errors.New("child sequence already exists")
	}

	// The child's block table points at the same physical blocks.
	childTable := make([]int, len(parentTable))
	copy(childTable, parentTable)
	for _, blockIdx := range childTable {
		c.blocks[blockIdx].RefCount++
	}

	c.blockTables[childSeqID] = childTable
	return nil
}

// CopyOnWrite makes logical block logicalBlockIdx of seqID private to that
// sequence and returns its physical index. A block that is not shared is
// returned as is; a shared block is copied into a free block first. On
// failure it returns -1 and an error.
func (c *PagedKVCache) CopyOnWrite(seqID, logicalBlockIdx int) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	blockTable := c.blockTables[seqID]
	if blockTable == nil || logicalBlockIdx < 0 || logicalBlockIdx >= len(blockTable) {
		return -1, errors.New("invalid sequence or block index")
	}
	return c.copyOnWriteLocked(blockTable, logicalBlockIdx)
}

// copyOnWriteLocked implements CopyOnWrite for a valid table entry.
// The caller must hold c.mu for writing.
func (c *PagedKVCache) copyOnWriteLocked(blockTable []int, logicalBlockIdx int) (int, error) {
	physicalBlockIdx := blockTable[logicalBlockIdx]
	oldBlock := c.blocks[physicalBlockIdx]
	if oldBlock.RefCount <= 1 {
		return physicalBlockIdx, nil // not shared, nothing to copy
	}

	if len(c.freeBlocks) == 0 {
		return -1, errors.New("no free blocks for copy")
	}
	newBlockIdx := c.freeBlocks[0]
	c.freeBlocks = c.freeBlocks[1:]

	// Copy the data into the private block.
	newBlock := c.blocks[newBlockIdx]
	for i := 0; i < c.config.BlockSize; i++ {
		copy(newBlock.KeyCache[i], oldBlock.KeyCache[i])
		copy(newBlock.ValueCache[i], oldBlock.ValueCache[i])
	}
	newBlock.NumTokens = oldBlock.NumTokens
	newBlock.RefCount = 1

	// This sequence no longer references the shared block.
	oldBlock.RefCount--
	blockTable[logicalBlockIdx] = newBlockIdx
	return newBlockIdx, nil
}

// FreeSequence drops seqID and releases its blocks. Blocks whose reference
// count reaches zero return to the free list. Unknown sequences are ignored.
func (c *PagedKVCache) FreeSequence(seqID int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	blockTable := c.blockTables[seqID]
	if blockTable == nil {
		return
	}
	c.releaseBlocksLocked(blockTable)
	delete(c.blockTables, seqID)
}

// releaseBlocksLocked decrements the reference count of every block in
// blockTable and puts blocks that are no longer referenced back on the free
// list. The caller must hold c.mu for writing.
func (c *PagedKVCache) releaseBlocksLocked(blockTable []int) {
	for _, blockIdx := range blockTable {
		block := c.blocks[blockIdx]
		block.RefCount--
		if block.RefCount == 0 {
			block.NumTokens = 0
			c.freeBlocks = append(c.freeBlocks, blockIdx)
		}
	}
}

// GetBlockTable returns a copy of the block table of seqID, or nil when the
// sequence is unknown.
func (c *PagedKVCache) GetBlockTable(seqID int) []int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	table := c.blockTables[seqID]
	if table == nil {
		return nil
	}
	result := make([]int, len(table))
	copy(result, table)
	return result
}

// getSequenceLength returns the number of tokens stored for seqID, summed
// over its blocks. The caller must hold c.mu.
func (c *PagedKVCache) getSequenceLength(seqID int) int {
	totalTokens := 0
	for _, blockIdx := range c.blockTables[seqID] {
		totalTokens += c.blocks[blockIdx].NumTokens
	}
	return totalTokens
}

// GetFreeBlockCount returns the number of free blocks.
func (c *PagedKVCache) GetFreeBlockCount() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return len(c.freeBlocks)
}

// GetUtilization returns the fraction of blocks currently in use, in [0, 1].
// A cache without blocks reports 0 instead of dividing by zero.
func (c *PagedKVCache) GetUtilization() float64 {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.config.NumGPUBlocks <= 0 {
		return 0
	}
	usedBlocks := c.config.NumGPUBlocks - len(c.freeBlocks)
	return float64(usedBlocks) / float64(c.config.NumGPUBlocks)
}
Prefix Caching
前缀缓存允许多个请求共享相同前缀的 KV Cache:
package optimization
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
// PrefixCache 前缀缓存
type PrefixCache struct {
cache map[string]*PrefixEntry
maxEntries int
evictionPolicy string
mu sync.RWMutex
}
// PrefixEntry 前缀条目
type PrefixEntry struct {
Hash string
Tokens []int32
BlockTable []int
NumTokens int
RefCount int
LastAccess time.Time
CreateTime time.Time
}
func NewPrefixCache(maxEntries int, evictionPolicy string) *PrefixCache {
return &PrefixCache{
cache: make(map[string]*PrefixEntry),
maxEntries: maxEntries,
evictionPolicy: evictionPolicy,
}
}
// ComputePrefixHash 计算前缀哈希
func (pc *PrefixCache) ComputePrefixHash(tokens []int32) string {
h := sha256.New()
for _, t := range tokens {
h.Write([]byte{
byte(t >> 24),
byte(t >> 16),
byte(t >> 8),
byte(t),
})
}
return hex.EncodeToString(h.Sum(nil))[:16]
}
// FindPrefix 查找最长匹配前缀
func (pc *PrefixCache) FindPrefix(tokens []int32) (*PrefixEntry, int) {
pc.mu.RLock()
defer pc.mu.RUnlock()
var bestMatch *PrefixEntry
var bestLen int
// 尝试不同长度的前缀
for prefixLen := len(tokens); prefixLen > 0; prefixLen-- {
hash := pc.ComputePrefixHash(tokens[:prefixLen])
if entry, exists := pc.cache[hash]; exists {
if prefixLen > bestLen {
bestMatch = entry
bestLen = prefixLen
}
break // 找到最长匹配
}
}
if bestMatch != nil {
bestMatch.LastAccess = time.Now()
bestMatch.RefCount++
}
return bestMatch, bestLen
}
// InsertPrefix 插入前缀
func (pc *PrefixCache) InsertPrefix(tokens []int32, blockTable []int) *PrefixEntry {
pc.mu.Lock()
defer pc.mu.Unlock()
// 检查是否需要驱逐
if len(pc.cache) >= pc.maxEntries {
pc.evict()
}
hash := pc.ComputePrefixHash(tokens)
entry := &PrefixEntry{
Hash: hash,
Tokens: append([]int32{}, tokens...),
BlockTable: append([]int{}, blockTable...),
NumTokens: len(tokens),
RefCount: 1,
LastAccess: time.Now(),
CreateTime: time.Now(),
}
pc.cache[hash] = entry
return entry
}
// evict 驱逐策略
func (pc *PrefixCache) evict() {
switch pc.evictionPolicy {
case "lru":
pc.evictLRU()
case "lfu":
pc.evictLFU()
default:
pc.evictLRU()
}
}
// evictLRU deletes the entry with the oldest LastAccess among those whose
// RefCount is exactly zero. Nothing is deleted when no such entry exists.
func (pc *PrefixCache) evictLRU() {
	victimKey := ""
	var victim *PrefixEntry
	for key, candidate := range pc.cache {
		if candidate.RefCount != 0 {
			continue // still referenced
		}
		if victim != nil && !candidate.LastAccess.Before(victim.LastAccess) {
			continue // not older than the current pick
		}
		victim, victimKey = candidate, key
	}
	if victimKey == "" {
		return
	}
	delete(pc.cache, victimKey)
}
// evictLFU deletes the entry with the smallest RefCount.
//
// NOTE(review): unlike evictLRU this does not skip entries whose RefCount is
// above zero, so an entry that is still referenced can be removed when every
// entry is in use. PrefixEntry has no separate hit counter, so RefCount is
// the only frequency signal available here.
func (pc *PrefixCache) evictLFU() {
	leastHash := ""
	minRef := int(^uint(0) >> 1) // largest int value
	for hash, entry := range pc.cache {
		if entry.RefCount < minRef {
			minRef = entry.RefCount
			leastHash = hash
		}
	}
	if leastHash != "" {
		delete(pc.cache, leastHash)
	}
}
// ReleasePrefix drops one reference from the entry stored under hash.
// Unknown hashes are ignored. The count never goes below zero: a negative
// RefCount would make the entry permanently ineligible for evictLRU, which
// only evicts entries whose RefCount is exactly zero.
func (pc *PrefixCache) ReleasePrefix(hash string) {
	pc.mu.Lock()
	defer pc.mu.Unlock()

	entry, exists := pc.cache[hash]
	if !exists || entry.RefCount <= 0 {
		return
	}
	entry.RefCount--
}
// RadixPrefixCache is a tree-based prefix cache: prefixes are looked up by
// walking child links token by token instead of hashing every candidate prefix.
type RadixPrefixCache struct {
root *RadixNode // tree root; its Children are keyed by the first token
blockCache *PagedKVCache // stored by the constructor; not otherwise used in this section
mu sync.RWMutex // guards the tree
}
// RadixNode is one node of the tree used by RadixPrefixCache.
type RadixNode struct {
Tokens []int32 // tokens recorded on the node when Insert creates it
BlockTable []int // block table stored when the node is marked as an endpoint
Children map[int32]*RadixNode // child nodes keyed by token
IsEndpoint bool // true when an inserted prefix ends at this node
RefCount int // set to 1 by Insert and incremented by FindLongestPrefix
LastAccess time.Time // set by Insert and refreshed by FindLongestPrefix
}
// NewRadixPrefixCache returns a cache whose tree contains only an empty root
// node and which keeps a reference to blockCache.
func NewRadixPrefixCache(blockCache *PagedKVCache) *RadixPrefixCache {
	root := &RadixNode{Children: map[int32]*RadixNode{}}
	return &RadixPrefixCache{root: root, blockCache: blockCache}
}
// Insert records blockTable for the prefix tokens.
//
// The tree holds exactly one token per node, which is the layout
// FindLongestPrefix walks. Missing nodes along the path are created as plain
// interior nodes; only the node for the last token is marked as an endpoint.
// That endpoint gets a private copy of blockTable, RefCount 1 and a fresh
// LastAccess. An empty tokens slice is ignored.
func (rpc *RadixPrefixCache) Insert(tokens []int32, blockTable []int) {
	if len(tokens) == 0 {
		return
	}
	rpc.mu.Lock()
	defer rpc.mu.Unlock()

	node := rpc.root
	for _, token := range tokens {
		child, exists := node.Children[token]
		if !exists {
			child = &RadixNode{
				Tokens:   []int32{token},
				Children: make(map[int32]*RadixNode),
			}
			node.Children[token] = child
		}
		node = child
	}
	node.BlockTable = append([]int{}, blockTable...)
	node.IsEndpoint = true
	node.RefCount = 1
	node.LastAccess = time.Now()
}

// FindLongestPrefix returns the block table of the longest inserted prefix
// that tokens starts with, together with that prefix's length in tokens.
// It returns (nil, 0) when no inserted prefix matches.
//
// Only the endpoint that is returned has its LastAccess refreshed and its
// RefCount incremented. Because a hit mutates the node, the exclusive lock is
// taken rather than the read lock.
func (rpc *RadixPrefixCache) FindLongestPrefix(tokens []int32) ([]int, int) {
	rpc.mu.Lock()
	defer rpc.mu.Unlock()

	var best *RadixNode
	bestLen := 0
	node := rpc.root
	for i, token := range tokens {
		child, exists := node.Children[token]
		if !exists {
			break
		}
		node = child
		if node.IsEndpoint {
			best, bestLen = node, i+1
		}
	}
	if best == nil {
		return nil, 0
	}
	best.LastAccess = time.Now()
	best.RefCount++
	return best.BlockTable, bestLen
}

// Evict removes every endpoint that has RefCount 0 and was last accessed more
// than maxAge ago, then prunes interior nodes that no longer lead anywhere.
//
// NOTE(review): nothing in this section decrements RadixNode.RefCount, so
// endpoints only become evictable if code elsewhere releases them.
func (rpc *RadixPrefixCache) Evict(maxAge time.Duration) {
	rpc.mu.Lock()
	defer rpc.mu.Unlock()

	cutoff := time.Now().Add(-maxAge)
	rpc.evictRecursive(rpc.root, cutoff)
}

// evictRecursive processes the subtree below node bottom-up. Descendants are
// handled before their parent so that a stale node is never removed while a
// referenced or recently used endpoint still hangs beneath it.
func (rpc *RadixPrefixCache) evictRecursive(node *RadixNode, cutoff time.Time) {
	for token, child := range node.Children {
		rpc.evictRecursive(child, cutoff)

		if child.IsEndpoint && child.RefCount == 0 && child.LastAccess.Before(cutoff) {
			// TODO: return child.BlockTable's blocks to rpc.blockCache here;
			// the original code left this step as a placeholder as well.
			child.BlockTable = nil
			child.IsEndpoint = false
		}
		// A node that is neither an endpoint nor a path to one is dead weight.
		// Deleting the current key while ranging over a map is safe in Go.
		if !child.IsEndpoint && len(child.Children) == 0 {
			delete(node.Children, token)
		}
	}
}
系统级优化
CUDA 优化
package optimization
/*
#cgo LDFLAGS: -lcudart -lcublas
#include <cuda_runtime.h>
#include <cublas_v2.h>
// 自定义 CUDA kernel 声明
extern void fused_attention_kernel(
float* query, float* key, float* value, float* output,
int batch_size, int num_heads, int seq_len, int head_dim,
cudaStream_t stream
);
extern void fused_rope_kernel(
float* input, float* cos_cache, float* sin_cache,
int batch_size, int seq_len, int num_heads, int head_dim,
cudaStream_t stream
);
extern void quantized_matmul_kernel(
int8_t* A, int8_t* B, float* C,
float* scale_a, float* scale_b,
int M, int N, int K,
cudaStream_t stream
);
*/
import "C"
import (
	"fmt"
	"unsafe"
)
// CUDAOptimizer bundles the per-device CUDA handles used by the kernel
// wrappers defined on it.
type CUDAOptimizer struct {
deviceID int // device ordinal passed to cudaSetDevice
stream C.cudaStream_t // stream passed to every kernel launch and async call of this type
cublasHandle C.cublasHandle_t // cuBLAS handle, bound to stream by NewCUDAOptimizer
memPool *CUDAMemoryPool // pool returned by createMemoryPool
}
// CUDAMemoryPool wraps a CUDA memory pool handle.
type CUDAMemoryPool struct {
pool C.cudaMemPool_t // handle filled in by cudaMemPoolCreate
allocSize int64 // never written in this section; NOTE(review): presumably bytes allocated — confirm
}
// NewCUDAOptimizer selects deviceID, creates a stream and a cuBLAS handle
// bound to that stream, and creates the memory pool.
//
// Every CUDA and cuBLAS status is checked. On failure the handles created so
// far are destroyed and an error naming the failing call and its numeric
// status code is returned together with a nil optimizer.
func NewCUDAOptimizer(deviceID int) (*CUDAOptimizer, error) {
	opt := &CUDAOptimizer{deviceID: deviceID}

	if st := C.cudaSetDevice(C.int(deviceID)); st != C.cudaSuccess {
		return nil, fmt.Errorf("cudaSetDevice(%d): CUDA error %d", deviceID, int(st))
	}
	if st := C.cudaStreamCreate(&opt.stream); st != C.cudaSuccess {
		return nil, fmt.Errorf("cudaStreamCreate: CUDA error %d", int(st))
	}
	if st := C.cublasCreate(&opt.cublasHandle); st != C.CUBLAS_STATUS_SUCCESS {
		C.cudaStreamDestroy(opt.stream)
		return nil, fmt.Errorf("cublasCreate: cuBLAS status %d", int(st))
	}
	if st := C.cublasSetStream(opt.cublasHandle, opt.stream); st != C.CUBLAS_STATUS_SUCCESS {
		C.cublasDestroy(opt.cublasHandle)
		C.cudaStreamDestroy(opt.stream)
		return nil, fmt.Errorf("cublasSetStream: cuBLAS status %d", int(st))
	}

	// createMemoryPool has no error return; its result is stored as-is.
	opt.memPool = opt.createMemoryPool()
	return opt, nil
}
// createMemoryPool creates a CUDA memory pool located on the optimizer's
// device and returns it wrapped in a CUDAMemoryPool. It returns nil when
// cudaMemPoolCreate reports a failure, so callers never hold an invalid handle.
func (opt *CUDAOptimizer) createMemoryPool() *CUDAMemoryPool {
	var props C.cudaMemPoolProps
	props.allocType = C.cudaMemAllocationTypePinned
	props.handleTypes = C.cudaMemHandleTypeNone
	// The C field is named "type", a Go keyword; cgo exposes such fields with
	// a leading underscore.
	props.location._type = C.cudaMemLocationTypeDevice
	props.location.id = C.int(opt.deviceID)

	var pool C.cudaMemPool_t
	if st := C.cudaMemPoolCreate(&pool, &props); st != C.cudaSuccess {
		return nil
	}
	return &CUDAMemoryPool{pool: pool}
}
// FusedAttention launches fused_attention_kernel on the optimizer's stream.
// The four pointers are passed to the kernel as float* and the four
// dimensions as C ints; the call is not followed by a synchronization here.
func (opt *CUDAOptimizer) FusedAttention(
	query, key, value, output unsafe.Pointer,
	batchSize, numHeads, seqLen, headDim int,
) {
	q := (*C.float)(query)
	k := (*C.float)(key)
	v := (*C.float)(value)
	out := (*C.float)(output)

	C.fused_attention_kernel(
		q, k, v, out,
		C.int(batchSize), C.int(numHeads), C.int(seqLen), C.int(headDim),
		opt.stream,
	)
}
// FusedRoPE launches fused_rope_kernel (rotary position embedding) on the
// optimizer's stream. The three pointers are passed as float* and the four
// dimensions as C ints.
func (opt *CUDAOptimizer) FusedRoPE(
	input, cosCache, sinCache unsafe.Pointer,
	batchSize, seqLen, numHeads, headDim int,
) {
	in := (*C.float)(input)
	cosTable := (*C.float)(cosCache)
	sinTable := (*C.float)(sinCache)

	C.fused_rope_kernel(
		in, cosTable, sinTable,
		C.int(batchSize), C.int(seqLen), C.int(numHeads), C.int(headDim),
		opt.stream,
	)
}
// QuantizedMatmul launches quantized_matmul_kernel (INT8 matrix multiply) on
// the optimizer's stream. A and B are passed as int8_t*, out and the two
// scale pointers as float*, and M, N, K as C ints.
//
// The output parameter is named out rather than C so that it does not shadow
// the cgo pseudo-package C inside the function body. Go arguments are
// positional, so existing callers are unaffected.
func (opt *CUDAOptimizer) QuantizedMatmul(
	A, B, out, scaleA, scaleB unsafe.Pointer,
	M, N, K int,
) {
	C.quantized_matmul_kernel(
		(*C.int8_t)(A),
		(*C.int8_t)(B),
		(*C.float)(out),
		(*C.float)(scaleA),
		(*C.float)(scaleB),
		C.int(M),
		C.int(N),
		C.int(K),
		opt.stream,
	)
}
// AllocAsync requests size bytes from cudaMallocAsync on the optimizer's
// stream and returns the resulting pointer. It returns nil when size is
// negative (which would otherwise wrap to a huge size_t) or when the CUDA
// call reports a failure.
func (opt *CUDAOptimizer) AllocAsync(size int64) unsafe.Pointer {
	if size < 0 {
		return nil
	}
	var ptr unsafe.Pointer
	if st := C.cudaMallocAsync(&ptr, C.size_t(size), opt.stream); st != C.cudaSuccess {
		return nil
	}
	return ptr
}
// FreeAsync passes ptr to cudaFreeAsync on the optimizer's stream.
// The status returned by the CUDA call is discarded.
func (opt *CUDAOptimizer) FreeAsync(ptr unsafe.Pointer) {
C.cudaFreeAsync(ptr, opt.stream)
}
// Synchronize calls cudaStreamSynchronize on the optimizer's stream.
// The status returned by the CUDA call is discarded.
func (opt *CUDAOptimizer) Synchronize() {
C.cudaStreamSynchronize(opt.stream)
}
// Close releases the resources created by NewCUDAOptimizer in reverse order
// of creation: the memory pool (previously never destroyed), the cuBLAS
// handle, and finally the stream. Status codes are discarded because Close
// has no error return.
func (opt *CUDAOptimizer) Close() {
	if opt.memPool != nil {
		C.cudaMemPoolDestroy(opt.memPool.pool)
		opt.memPool = nil
	}
	C.cublasDestroy(opt.cublasHandle)
	C.cudaStreamDestroy(opt.stream)
}
// MultiStreamExecutor owns a fixed set of CUDA streams and hands them out in
// round-robin order.
type MultiStreamExecutor struct {
streams []C.cudaStream_t // streams created by NewMultiStreamExecutor
numStreams int // modulus used when advancing currentIdx
currentIdx int // index of the stream the next GetNextStream call returns
}
// NewMultiStreamExecutor creates an executor owning numStreams CUDA streams.
//
// A numStreams below 1 is raised to 1: a negative value would panic when the
// slice is allocated, and zero would make GetNextStream fail on its first
// call. Status codes from cudaStreamCreate are discarded because the
// constructor has no error return.
func NewMultiStreamExecutor(numStreams int) *MultiStreamExecutor {
	if numStreams < 1 {
		numStreams = 1
	}
	exec := &MultiStreamExecutor{
		streams:    make([]C.cudaStream_t, numStreams),
		numStreams: numStreams,
	}
	for i := range exec.streams {
		C.cudaStreamCreate(&exec.streams[i])
	}
	return exec
}
// GetNextStream returns the stream at the current position and advances the
// position by one, wrapping around after the last stream. It takes no lock.
func (exec *MultiStreamExecutor) GetNextStream() C.cudaStream_t {
	cur := exec.currentIdx
	picked := exec.streams[cur]
	exec.currentIdx = (cur + 1) % exec.numStreams
	return picked
}
// SyncAll calls cudaStreamSynchronize on each owned stream in slice order.
func (exec *MultiStreamExecutor) SyncAll() {
	for i := range exec.streams {
		C.cudaStreamSynchronize(exec.streams[i])
	}
}
// Close calls cudaStreamDestroy on each owned stream in slice order.
func (exec *MultiStreamExecutor) Close() {
	for i := range exec.streams {
		C.cudaStreamDestroy(exec.streams[i])
	}
}
并行化策略
package optimization
import (
"context"
"sync"
)
// TensorParallelism describes one participant of a tensor-parallel group.
type TensorParallelism struct {
worldSize int // number of shards an unsharded input is divided into
rank int // this participant's index; selects its slice of an unsharded input
commBackend CommunicationBackend // collective operations used by the parallel layers
}
// CommunicationBackend is the set of communication primitives the parallel
// layers call.
// NOTE(review): no implementation is visible here; the method names suggest
// the usual collective operations, but the exact semantics (in-place or not,
// blocking or not) must be confirmed against the implementation.
type CommunicationBackend interface {
AllReduce(tensor []float32) []float32
AllGather(tensor []float32) []float32
Broadcast(tensor []float32, srcRank int) []float32
Send(tensor []float32, dstRank int)
Recv(srcRank int) []float32
}
// ColumnParallelLinear is a linear layer whose output features are split
// across the tensor-parallel group.
type ColumnParallelLinear struct {
tp *TensorParallelism // group description and communication backend
weight [][]float32 // [out_features/world_size, in_features]
bias []float32 // [out_features/world_size]
gatherOutput bool // when true, Forward returns the AllGather of the local output
}
// Forward multiplies input by the local weight shard and, when a bias is set,
// adds it element-wise, cycling through the bias by index modulo its length.
// When gatherOutput is set the result is passed through AllGather; otherwise
// the local output is returned directly.
func (cpl *ColumnParallelLinear) Forward(input []float32) []float32 {
	out := matmul(input, cpl.weight)

	if cpl.bias != nil {
		nb := len(cpl.bias)
		for i := 0; i < len(out); i++ {
			out[i] += cpl.bias[i%nb]
		}
	}

	if !cpl.gatherOutput {
		return out
	}
	return cpl.tp.commBackend.AllGather(out)
}
// RowParallelLinear is a linear layer whose input features are split across
// the tensor-parallel group.
type RowParallelLinear struct {
tp *TensorParallelism // group description and communication backend
weight [][]float32 // [out_features, in_features/world_size]
bias []float32 // [out_features]
inputIsParallel bool // when true, Forward uses its input as-is instead of slicing it by rank
}
// Forward computes this rank's partial product, combines the partial results
// of all ranks with AllReduce, and then adds the bias.
//
// When inputIsParallel is false the input is split into worldSize equal
// slices and only the slice at index rank is used. The bias is added on
// every rank after the reduction, so all ranks return the same output. Adding
// it on rank 0 only, after the reduction, left the other ranks without bias.
func (rpl *RowParallelLinear) Forward(input []float32) []float32 {
	localInput := input
	if !rpl.inputIsParallel {
		localSize := len(input) / rpl.tp.worldSize
		start := rpl.tp.rank * localSize
		localInput = input[start : start+localSize]
	}

	localOutput := matmul(localInput, rpl.weight)
	output := rpl.tp.commBackend.AllReduce(localOutput)

	if nb := len(rpl.bias); nb > 0 {
		for i := range output {
			output[i] += rpl.bias[i%nb]
		}
	}
	return output
}
// PipelineParallelism runs micro-batches through an ordered list of stages.
type PipelineParallelism struct {
stages []PipelineStage // applied first-to-last on forward, last-to-first on backward
numStages int // used by Execute1F1B to size the warm-up phase (numStages - 1)
microBatches int // number of micro-batches Execute1F1B splits its input into
commBackend CommunicationBackend // not referenced by the methods in this section
}
// PipelineStage is one stage of a PipelineParallelism pipeline.
type PipelineStage interface {
// Forward maps the previous stage's output to this stage's output.
Forward(input []float32) []float32
// Backward maps the gradient coming from the next stage to the gradient
// handed to the previous stage.
Backward(grad []float32) []float32
// GetRank is not called in this section.
// NOTE(review): presumably the stage's rank in the pipeline — confirm.
GetRank() int
}
// Execute1F1B splits input into pp.microBatches equal micro-batches, runs the
// forward and the backward pass of each one, and returns the forward outputs
// concatenated in micro-batch order.
//
// Work is launched in the three phases of a 1F1B schedule: warm-up (forward
// only), steady state (one forward plus one backward) and cool-down (backward
// only). The backward pass of a micro-batch waits for that micro-batch's
// forward pass to finish before reading its output.
//
// The warm-up length is clamped to [0, microBatches], so having fewer
// micro-batches than stages is handled. A non-positive microBatches yields
// nil. Once ctx is cancelled, passes that have not started yet are skipped
// and their outputs stay empty. Input elements beyond
// microBatches*(len(input)/microBatches) are not processed.
func (pp *PipelineParallelism) Execute1F1B(ctx context.Context, input []float32) []float32 {
	numMicroBatches := pp.microBatches
	if numMicroBatches <= 0 {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}

	warmupSteps := pp.numStages - 1
	if warmupSteps < 0 {
		warmupSteps = 0
	}
	if warmupSteps > numMicroBatches {
		warmupSteps = numMicroBatches
	}
	steadySteps := numMicroBatches - warmupSteps

	// Split the input into micro-batches.
	microBatchSize := len(input) / numMicroBatches
	microBatches := make([][]float32, numMicroBatches)
	for i := range microBatches {
		start := i * microBatchSize
		microBatches[i] = input[start : start+microBatchSize]
	}

	outputs := make([][]float32, numMicroBatches)
	// forwardDone[i] is closed once outputs[i] has been written (or skipped).
	forwardDone := make([]chan struct{}, numMicroBatches)
	for i := range forwardDone {
		forwardDone[i] = make(chan struct{})
	}

	var wg sync.WaitGroup
	launchForward := func(idx int) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer close(forwardDone[idx])
			if ctx.Err() != nil {
				return
			}
			outputs[idx] = pp.forwardMicroBatch(microBatches[idx], idx)
		}()
	}
	launchBackward := func(idx int) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-forwardDone[idx] // the gradient input is this micro-batch's forward output
			if ctx.Err() != nil {
				return
			}
			pp.backwardMicroBatch(outputs[idx], idx)
		}()
	}

	// Warm-up: forward only.
	for i := 0; i < warmupSteps; i++ {
		launchForward(i)
	}
	// Steady state: one forward and one backward per step.
	for i := 0; i < steadySteps; i++ {
		launchForward(warmupSteps + i)
		launchBackward(i)
	}
	// Cool-down: the remaining backward passes.
	for i := steadySteps; i < numMicroBatches; i++ {
		launchBackward(i)
	}
	wg.Wait()

	// Concatenate the outputs in micro-batch order.
	total := 0
	for _, out := range outputs {
		total += len(out)
	}
	result := make([]float32, 0, total)
	for _, out := range outputs {
		result = append(result, out...)
	}
	return result
}
// forwardMicroBatch feeds input through every stage's Forward, first stage
// first, and returns the last stage's output. microBatchID is not used.
func (pp *PipelineParallelism) forwardMicroBatch(input []float32, microBatchID int) []float32 {
	stages := pp.stages
	activations := input
	for i := 0; i < len(stages); i++ {
		activations = stages[i].Forward(activations)
	}
	return activations
}
// backwardMicroBatch feeds grad through every stage's Backward, last stage
// first. The gradient produced by the first stage is discarded and
// microBatchID is not used.
func (pp *PipelineParallelism) backwardMicroBatch(grad []float32, microBatchID int) {
	g := grad
	for remaining := len(pp.stages); remaining > 0; remaining-- {
		g = pp.stages[remaining-1].Backward(g)
	}
}
// matmul multiplies a by the transpose of b.
//
// b is a weight matrix laid out as [out_features][in_features], matching the
// layout documented on the parallel linear layers. a is either one input
// vector of in_features values or several such vectors stored back to back.
// The result holds out_features values per input vector, in the same order.
//
// nil is returned when b is empty, when its rows differ in length, or when
// len(a) is zero or not a multiple of in_features.
func matmul(a []float32, b [][]float32) []float32 {
	outFeatures := len(b)
	if outFeatures == 0 {
		return nil
	}
	inFeatures := len(b[0])
	if inFeatures == 0 || len(a) == 0 || len(a)%inFeatures != 0 {
		return nil
	}
	for _, row := range b {
		if len(row) != inFeatures {
			return nil
		}
	}

	rows := len(a) / inFeatures
	out := make([]float32, rows*outFeatures)
	for r := 0; r < rows; r++ {
		x := a[r*inFeatures : (r+1)*inFeatures]
		dst := out[r*outFeatures : (r+1)*outFeatures]
		for j, w := range b {
			var sum float32
			for k, xv := range x {
				sum += xv * w[k]
			}
			dst[j] = sum
		}
	}
	return out
}
性能基准测试
基准测试框架
package benchmark
import (
"context"
"fmt"
"sync"
"time"
)
// InferenceBenchmark drives an InferenceEngine through a matrix of test cases
// and collects latency and throughput results.
type InferenceBenchmark struct {
engine InferenceEngine // engine under test
config *BenchmarkConfig // test matrix and limits
results *BenchmarkResults // replaced at the start of every Run
}
// BenchmarkConfig describes the test matrix and the limits of a benchmark run.
// Run executes one test case per combination of PromptLength, OutputLength and
// ConcurrentReqs.
type BenchmarkConfig struct {
// Test matrix
NumPrompts int // requests per test case, divided among the workers
PromptLength []int // prompt lengths to test, as passed to generatePrompt
OutputLength []int // values passed as maxTokens to the engine
ConcurrentReqs []int // numbers of concurrent workers to test
// Warm-up
WarmupRequests int // sequential requests issued before measuring
WarmupDuration time.Duration // not read in this section
// Time limits
TestDuration time.Duration // a worker stops issuing requests once this has elapsed
RequestTimeout time.Duration // timeout applied to each request's context
}
// BenchmarkResults holds the aggregate metrics of a run and the per-request
// results they were computed from.
type BenchmarkResults struct {
// Latency metrics
TTFT LatencyStats // Time to First Token
TPOT LatencyStats // Time per Output Token
E2ELatency LatencyStats // end-to-end latency
// Throughput metrics
TokensPerSecond float64
RequestsPerSecond float64
// Resource metrics (not populated in this section)
GPUUtilization float64 // GenerateReport multiplies this by 100 before printing it as a percentage
MemoryUsage int64 // GenerateReport divides this by 1024*1024 and prints it as MB
PeakMemory int64 // GenerateReport divides this by 1024*1024 and prints it as MB
// Per-request results
DetailedResults []*TestCaseResult
}
// LatencyStats summarises a set of latency samples.
type LatencyStats struct {
Min time.Duration
Max time.Duration
Mean time.Duration
P50 time.Duration
P90 time.Duration
P95 time.Duration
P99 time.Duration
Stddev time.Duration // not filled in by calculateLatencyStats
}
// TestCaseResult is the outcome of a single benchmark request.
type TestCaseResult struct {
PromptLength int // len(prompt) in bytes
OutputLength int // maxTokens requested from the engine
Concurrency int
TTFT time.Duration // time to first token, as reported by the engine
TPOT time.Duration // time per output token, derived from TotalLatency and TTFT
TotalLatency time.Duration // wall-clock duration of the Generate call
TokensGenerated int // len(Tokens) of the engine's result
Success bool // false when Generate returned an error
Error error // the error returned by Generate, if any
}
// InferenceEngine is the system under test.
type InferenceEngine interface {
// Generate produces a completion for prompt; the benchmark passes the
// configured output length as maxTokens.
Generate(ctx context.Context, prompt string, maxTokens int) (*GenerateResult, error)
}
// GenerateResult is what an InferenceEngine returns for one request.
type GenerateResult struct {
Text string
Tokens []int32 // the benchmark uses len(Tokens) as the generated-token count
TTFT time.Duration // copied into TestCaseResult.TTFT
TotalLatency time.Duration // not read by the benchmark, which times the call itself
TokensPerSec float64 // not read by the benchmark
}
// Run executes the warm-up and then one test case for every combination of
// prompt length, output length and concurrency in the configuration. It
// computes the aggregate metrics and returns the results.
//
// A warm-up failure aborts the run with that error. If ctx is cancelled
// between test cases the run stops and ctx.Err() is returned, instead of
// working through the rest of the matrix with requests that can only fail.
func (b *InferenceBenchmark) Run(ctx context.Context) (*BenchmarkResults, error) {
	b.results = &BenchmarkResults{
		DetailedResults: make([]*TestCaseResult, 0),
	}

	fmt.Println("Warming up...")
	if err := b.warmup(ctx); err != nil {
		return nil, err
	}

	for _, promptLen := range b.config.PromptLength {
		for _, outputLen := range b.config.OutputLength {
			for _, concurrency := range b.config.ConcurrentReqs {
				if err := ctx.Err(); err != nil {
					return nil, err
				}
				fmt.Printf("Testing: prompt=%d, output=%d, concurrency=%d\n",
					promptLen, outputLen, concurrency)
				results := b.runTestCase(ctx, promptLen, outputLen, concurrency)
				b.results.DetailedResults = append(b.results.DetailedResults, results...)
			}
		}
	}

	b.calculateAggregates()
	return b.results, nil
}
// warmup sends WarmupRequests sequential requests, each with a 100-byte
// prompt and a 50-token limit, and returns the first error encountered.
func (b *InferenceBenchmark) warmup(ctx context.Context) error {
	warmupPrompt := generatePrompt(100)
	for sent := 0; sent < b.config.WarmupRequests; sent++ {
		if _, err := b.engine.Generate(ctx, warmupPrompt, 50); err != nil {
			return err
		}
	}
	return nil
}
// runTestCase issues NumPrompts requests with a prompt of promptLen bytes and
// a limit of outputLen tokens, spread over concurrency worker goroutines, and
// returns one result per request that was issued.
//
// The requests are divided as evenly as possible: when NumPrompts is not a
// multiple of concurrency the first workers take one extra request, so none
// are dropped. A non-positive concurrency or NumPrompts yields an empty
// result instead of a division by zero. Each worker stops issuing requests
// once TestDuration has elapsed since the start of the test case. Every
// result has its Concurrency field set.
func (b *InferenceBenchmark) runTestCase(
	ctx context.Context,
	promptLen, outputLen, concurrency int,
) []*TestCaseResult {
	numRequests := b.config.NumPrompts
	if concurrency <= 0 || numRequests <= 0 {
		return make([]*TestCaseResult, 0)
	}

	results := make([]*TestCaseResult, 0, numRequests)
	var mu sync.Mutex
	var wg sync.WaitGroup

	prompt := generatePrompt(promptLen)
	base, extra := numRequests/concurrency, numRequests%concurrency

	startTime := time.Now()
	for w := 0; w < concurrency; w++ {
		quota := base
		if w < extra {
			quota++
		}
		if quota == 0 {
			break // more workers than requests; the rest would be idle
		}
		wg.Add(1)
		go func(quota int) {
			defer wg.Done()
			for i := 0; i < quota; i++ {
				if time.Since(startTime) > b.config.TestDuration {
					return
				}
				reqCtx, cancel := context.WithTimeout(ctx, b.config.RequestTimeout)
				result := b.runSingleRequest(reqCtx, prompt, outputLen)
				cancel()
				result.Concurrency = concurrency

				mu.Lock()
				results = append(results, result)
				mu.Unlock()
			}
		}(quota)
	}
	wg.Wait()
	return results
}
// runSingleRequest sends one request to the engine and records its metrics.
//
// On error the result carries the error and Success=false. On success it
// records the engine-reported TTFT, the wall-clock latency of the call and the
// number of generated tokens. TPOT is the time after the first token divided
// by the number of remaining tokens (TokensGenerated-1), because the first
// token is already accounted for by TTFT. It stays zero when fewer than two
// tokens were generated.
func (b *InferenceBenchmark) runSingleRequest(
	ctx context.Context,
	prompt string,
	maxTokens int,
) *TestCaseResult {
	result := &TestCaseResult{
		PromptLength: len(prompt),
		OutputLength: maxTokens,
	}

	startTime := time.Now()
	genResult, err := b.engine.Generate(ctx, prompt, maxTokens)
	if err != nil {
		result.Error = err
		result.Success = false
		return result
	}

	result.TTFT = genResult.TTFT
	result.TotalLatency = time.Since(startTime)
	result.TokensGenerated = len(genResult.Tokens)
	if result.TokensGenerated > 1 {
		result.TPOT = (result.TotalLatency - result.TTFT) / time.Duration(result.TokensGenerated-1)
	}
	result.Success = true
	return result
}
// calculateAggregates fills the latency statistics and throughput figures of
// b.results from the successful entries of DetailedResults. Nothing is
// written when there is no successful request.
//
// Per-request results carry no timestamps, so elapsed time is estimated: each
// request contributes its latency divided by the concurrency it ran under
// (treated as 1 when unset). This approximates the wall-clock time of a test
// case whose workers issue requests back to back. Throughput is the total
// token or request count divided by that estimate; previously it was divided
// by the latency of whichever request happened to be recorded last.
func (b *InferenceBenchmark) calculateAggregates() {
	var ttfts, tpots, e2eLatencies []time.Duration
	var totalTokens, successCount int
	var elapsedSeconds float64

	for _, r := range b.results.DetailedResults {
		if !r.Success {
			continue
		}
		ttfts = append(ttfts, r.TTFT)
		tpots = append(tpots, r.TPOT)
		e2eLatencies = append(e2eLatencies, r.TotalLatency)
		totalTokens += r.TokensGenerated
		successCount++

		workers := r.Concurrency
		if workers < 1 {
			workers = 1
		}
		elapsedSeconds += r.TotalLatency.Seconds() / float64(workers)
	}
	if successCount == 0 {
		return
	}

	b.results.TTFT = calculateLatencyStats(ttfts)
	b.results.TPOT = calculateLatencyStats(tpots)
	b.results.E2ELatency = calculateLatencyStats(e2eLatencies)

	// Guard against a zero denominator, which would yield +Inf.
	if elapsedSeconds > 0 {
		b.results.TokensPerSecond = float64(totalTokens) / elapsedSeconds
		b.results.RequestsPerSecond = float64(successCount) / elapsedSeconds
	}
}
// calculateLatencyStats returns the minimum, maximum, mean and the 50th, 90th,
// 95th and 99th percentiles of latencies. The input is left untouched because
// a copy is sorted. An empty input yields the zero LatencyStats. Stddev is
// not computed and stays zero.
func calculateLatencyStats(latencies []time.Duration) LatencyStats {
	n := len(latencies)
	if n == 0 {
		return LatencyStats{}
	}

	ordered := make([]time.Duration, n)
	copy(ordered, latencies)
	sortDurations(ordered)

	var total time.Duration
	for i := range ordered {
		total += ordered[i]
	}

	var stats LatencyStats
	stats.Min = ordered[0]
	stats.Max = ordered[n-1]
	stats.Mean = total / time.Duration(n)
	stats.P50 = percentile(ordered, 50)
	stats.P90 = percentile(ordered, 90)
	stats.P95 = percentile(ordered, 95)
	stats.P99 = percentile(ordered, 99)
	return stats
}
// percentile returns the p-th percentile of sorted, which must be in
// ascending order, using the nearest-rank method: the element at rank
// ceil(p/100 * n), counted from 1. For example, P50 of 100 samples is the
// 50th smallest value rather than the 51st.
//
// An empty slice yields 0. Values of p at or below 0 yield the first element
// and values at or above 100 yield the last.
func percentile(sorted []time.Duration, p int) time.Duration {
	n := len(sorted)
	if n == 0 {
		return 0
	}
	rank := (n*p + 99) / 100 // integer ceil(n*p/100) for non-negative p
	idx := rank - 1
	if idx < 0 {
		idx = 0
	}
	if idx >= n {
		idx = n - 1
	}
	return sorted[idx]
}
// sortDurations sorts d in place in ascending order using a simple quadratic
// algorithm.
func sortDurations(d []time.Duration) {
	for next := 1; next < len(d); next++ {
		value := d[next]
		slot := next
		// Shift larger elements one position right to open a slot for value.
		for slot > 0 && d[slot-1] > value {
			d[slot] = d[slot-1]
			slot--
		}
		d[slot] = value
	}
}
// generatePrompt returns a test prompt of exactly length bytes, built by
// repeating a fixed English sentence and cutting it to size. A length of zero
// or less yields the empty string instead of a slice-bounds panic.
func generatePrompt(length int) string {
	if length <= 0 {
		return ""
	}
	const text = "The quick brown fox jumps over the lazy dog. "
	// A byte buffer sized up front avoids re-copying the string on every
	// repetition.
	buf := make([]byte, 0, length+len(text))
	for len(buf) < length {
		buf = append(buf, text...)
	}
	return string(buf[:length])
}
// GenerateReport renders the collected results as a plain-text report with
// latency, throughput and resource sections. GPU utilization is printed as a
// percentage (the stored value times 100) and memory figures in MB (bytes
// divided by 1024*1024).
//
// When called before Run has produced results, a report of zero values is
// returned instead of dereferencing a nil pointer.
func (b *InferenceBenchmark) GenerateReport() string {
	r := b.results
	if r == nil {
		r = &BenchmarkResults{}
	}
	return fmt.Sprintf(`
=== Inference Benchmark Report ===
Latency Metrics:
TTFT (Time to First Token):
Min: %v, Max: %v, Mean: %v
P50: %v, P90: %v, P95: %v, P99: %v
TPOT (Time per Output Token):
Min: %v, Max: %v, Mean: %v
P50: %v, P90: %v, P95: %v, P99: %v
E2E Latency:
Min: %v, Max: %v, Mean: %v
P50: %v, P90: %v, P95: %v, P99: %v
Throughput Metrics:
Tokens/sec: %.2f
Requests/sec: %.2f
Resource Metrics:
GPU Utilization: %.2f%%
Memory Usage: %d MB
Peak Memory: %d MB
`,
		r.TTFT.Min, r.TTFT.Max, r.TTFT.Mean,
		r.TTFT.P50, r.TTFT.P90, r.TTFT.P95, r.TTFT.P99,
		r.TPOT.Min, r.TPOT.Max, r.TPOT.Mean,
		r.TPOT.P50, r.TPOT.P90, r.TPOT.P95, r.TPOT.P99,
		r.E2ELatency.Min, r.E2ELatency.Max, r.E2ELatency.Mean,
		r.E2ELatency.P50, r.E2ELatency.P90, r.E2ELatency.P95, r.E2ELatency.P99,
		r.TokensPerSecond, r.RequestsPerSecond,
		r.GPUUtilization*100, r.MemoryUsage/1024/1024, r.PeakMemory/1024/1024,
	)
}
小结
本章深入探讨了大模型推理优化的核心技术:
- 模型量化:INT8/INT4 量化、GPTQ、AWQ 等后训练量化方法
- 计算图优化:算子融合、内存规划、原地优化
- KV Cache 优化:PagedAttention、Prefix Caching、Copy-on-Write
- 系统级优化:CUDA 优化、多流执行、张量并行、流水线并行
- 性能基准测试:TTFT、TPOT、吞吐量等关键指标的测量方法
推理优化是一个需要从多个层面综合考虑的系统工程。通过合理应用这些优化技术,可以在保持模型精度的前提下,显著提升推理效率,降低服务成本。
下一章我们将探讨多模型服务,讲解如何在同一集群中高效管理和调度多个大模型。