多模型服务
概述
在实际生产环境中,往往需要同时部署和管理多个大模型,包括不同规模的 LLM、多模态模型、专用模型等。多模型服务面临着资源调度、模型切换、负载均衡、成本优化等挑战。本章深入讲解多模型服务的架构设计与实现。
多模型架构设计
整体架构
多模型服务架构:
┌─────────────────────────────────────────────────────────────────┐
│                           API Gateway                           │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐             │
│  │  Auth   │  │ Router  │  │  Rate   │  │ Circuit │             │
│  │         │  │         │  │ Limiter │  │ Breaker │             │
│  └─────────┘  └─────────┘  └─────────┘  └─────────┘             │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│                        Model Orchestrator                       │
│   ┌─────────────┐   ┌─────────────┐   ┌─────────────┐           │
│   │   Model     │   │   Request   │   │  Resource   │           │
│   │  Registry   │   │   Router    │   │   Manager   │           │
│   └─────────────┘   └─────────────┘   └─────────────┘           │
└─────────────────────────────────────────────────────────────────┘
                              │
        ┌─────────────────────┼─────────────────────┐
        ▼                     ▼                     ▼
┌──────────────┐      ┌──────────────┐      ┌──────────────┐
│  Model Pool  │      │  Model Pool  │      │  Model Pool  │
│  (GPT-4级)   │      │  (7B模型)    │      │  (多模态)    │
│ ┌──────────┐ │      │ ┌──────────┐ │      │ ┌──────────┐ │
│ │Instance 1│ │      │ │Instance 1│ │      │ │Instance 1│ │
│ ├──────────┤ │      │ ├──────────┤ │      │ ├──────────┤ │
│ │Instance 2│ │      │ │Instance 2│ │      │ │Instance 2│ │
│ └──────────┘ │      │ └──────────┘ │      │ └──────────┘ │
└──────────────┘      └──────────────┘      └──────────────┘
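在进入各组件实现之前,先用一段示意代码把上图中"网关 → Orchestrator → 模型实例"的调用链串起来。这只是一个最小草图:RequestRouter、InferenceRequest 等类型在本章后文定义,/v1/chat/completions 端点路径为假设值,错误处理也做了简化。
package multimodel

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
)

// HandleChat 示意一次完整调用链:先由路由器选出健康实例,再把请求转发到该实例的推理端点。
func HandleChat(ctx context.Context, router *RequestRouter, req *InferenceRequest) ([]byte, error) {
    // 1. Orchestrator 根据模型名 / 路由规则 / 能力选出实例
    instance, err := router.Route(ctx, req)
    if err != nil {
        return nil, err // 无可用实例,由网关层返回 503 或触发降级
    }
    // 2. 转发到实例端点(vLLM / TGI 等推理引擎)
    body, err := json.Marshal(req)
    if err != nil {
        return nil, err
    }
    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
        fmt.Sprintf("http://%s/v1/chat/completions", instance.Endpoint), bytes.NewReader(body))
    if err != nil {
        return nil, err
    }
    httpReq.Header.Set("Content-Type", "application/json")
    resp, err := http.DefaultClient.Do(httpReq)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    return io.ReadAll(resp.Body)
}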
模型注册中心
package multimodel
import (
    "context"
    "encoding/json"
    "fmt"
    "strings"
    "sync"
    "time"

    clientv3 "go.etcd.io/etcd/client/v3"
)
// ModelRegistry 模型注册中心
type ModelRegistry struct {
etcdClient *clientv3.Client
models map[string]*ModelInfo
instances map[string][]*ModelInstance
mu sync.RWMutex
watchCh clientv3.WatchChan
}
// ModelInfo 模型信息
type ModelInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Type ModelType `json:"type"`
Framework string `json:"framework"` // vllm, tgi, triton
ModelPath string `json:"model_path"`
Config *ModelConfig `json:"config"`
Requirements *ResourceRequirements `json:"requirements"`
Capabilities []string `json:"capabilities"` // chat, completion, embedding
Status ModelStatus `json:"status"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type ModelType string
const (
ModelTypeLLM ModelType = "llm"
ModelTypeEmbedding ModelType = "embedding"
ModelTypeMultimodal ModelType = "multimodal"
ModelTypeSpecialized ModelType = "specialized"
)
type ModelConfig struct {
MaxSeqLen int `json:"max_seq_len"`
MaxBatchSize int `json:"max_batch_size"`
TensorParallel int `json:"tensor_parallel"`
Quantization string `json:"quantization"` // none, int8, int4
DType string `json:"dtype"` // fp16, bf16, fp32
}
type ResourceRequirements struct {
GPUMemoryGB float64 `json:"gpu_memory_gb"`
GPUCount int `json:"gpu_count"`
CPUCores int `json:"cpu_cores"`
MemoryGB float64 `json:"memory_gb"`
GPUType string `json:"gpu_type"` // A100, H100, etc.
}
type ModelStatus string
const (
ModelStatusPending ModelStatus = "pending"
ModelStatusLoading ModelStatus = "loading"
ModelStatusReady ModelStatus = "ready"
ModelStatusUnloading ModelStatus = "unloading"
ModelStatusFailed ModelStatus = "failed"
)
// ModelInstance 模型实例
type ModelInstance struct {
ID string `json:"id"`
ModelName string `json:"model_name"`
NodeID string `json:"node_id"`
Endpoint string `json:"endpoint"`
GPUs []int `json:"gpus"`
Status InstanceStatus `json:"status"`
Health *HealthStatus `json:"health"`
Metrics *InstanceMetrics `json:"metrics"`
CreateTime time.Time `json:"create_time"`
LastHeartbeat time.Time `json:"last_heartbeat"`
}
type InstanceStatus string
const (
InstanceStatusStarting InstanceStatus = "starting"
InstanceStatusRunning InstanceStatus = "running"
InstanceStatusDraining InstanceStatus = "draining"
InstanceStatusStopped InstanceStatus = "stopped"
InstanceStatusUnhealthy InstanceStatus = "unhealthy"
)
type HealthStatus struct {
Healthy bool `json:"healthy"`
LastCheck time.Time `json:"last_check"`
FailCount int `json:"fail_count"`
Message string `json:"message"`
}
type InstanceMetrics struct {
RequestsTotal int64 `json:"requests_total"`
RequestsActive int `json:"requests_active"`
TokensGenerated int64 `json:"tokens_generated"`
AverageLatencyMs float64 `json:"average_latency_ms"`
GPUUtilization float64 `json:"gpu_utilization"`
GPUMemoryUsed float64 `json:"gpu_memory_used"`
}
func NewModelRegistry(etcdEndpoints []string) (*ModelRegistry, error) {
client, err := clientv3.New(clientv3.Config{
Endpoints: etcdEndpoints,
DialTimeout: 5 * time.Second,
})
if err != nil {
return nil, err
}
registry := &ModelRegistry{
etcdClient: client,
models: make(map[string]*ModelInfo),
instances: make(map[string][]*ModelInstance),
}
// 启动监听
go registry.watchChanges()
return registry, nil
}
// RegisterModel 注册模型
func (r *ModelRegistry) RegisterModel(ctx context.Context, model *ModelInfo) error {
r.mu.Lock()
defer r.mu.Unlock()
model.CreateTime = time.Now()
model.UpdateTime = time.Now()
model.Status = ModelStatusPending
data, err := json.Marshal(model)
if err != nil {
return err
}
key := fmt.Sprintf("/models/%s", model.Name)
_, err = r.etcdClient.Put(ctx, key, string(data))
if err != nil {
return err
}
r.models[model.Name] = model
return nil
}
// RegisterInstance 注册实例
func (r *ModelRegistry) RegisterInstance(ctx context.Context, instance *ModelInstance) error {
r.mu.Lock()
defer r.mu.Unlock()
instance.CreateTime = time.Now()
instance.LastHeartbeat = time.Now()
data, err := json.Marshal(instance)
if err != nil {
return err
}
key := fmt.Sprintf("/instances/%s/%s", instance.ModelName, instance.ID)
// 使用租约,实现自动过期
lease, err := r.etcdClient.Grant(ctx, 30) // 30秒 TTL
if err != nil {
return err
}
_, err = r.etcdClient.Put(ctx, key, string(data), clientv3.WithLease(lease.ID))
if err != nil {
return err
}
// 启动续租
go r.keepAlive(ctx, lease.ID, instance)
r.instances[instance.ModelName] = append(r.instances[instance.ModelName], instance)
return nil
}
// keepAlive 保持实例活跃
func (r *ModelRegistry) keepAlive(ctx context.Context, leaseID clientv3.LeaseID, instance *ModelInstance) {
ch, err := r.etcdClient.KeepAlive(ctx, leaseID)
if err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case resp, ok := <-ch:
if !ok {
return
}
if resp != nil {
instance.LastHeartbeat = time.Now()
}
}
}
}
// GetModel 获取模型信息
func (r *ModelRegistry) GetModel(name string) *ModelInfo {
r.mu.RLock()
defer r.mu.RUnlock()
return r.models[name]
}
// GetInstances 获取模型实例
func (r *ModelRegistry) GetInstances(modelName string) []*ModelInstance {
r.mu.RLock()
defer r.mu.RUnlock()
instances := r.instances[modelName]
result := make([]*ModelInstance, 0)
for _, inst := range instances {
if inst.Status == InstanceStatusRunning {
result = append(result, inst)
}
}
return result
}
// GetHealthyInstances 获取健康实例
func (r *ModelRegistry) GetHealthyInstances(modelName string) []*ModelInstance {
instances := r.GetInstances(modelName)
result := make([]*ModelInstance, 0)
for _, inst := range instances {
if inst.Health != nil && inst.Health.Healthy {
result = append(result, inst)
}
}
return result
}
// ListModels 列出所有模型
func (r *ModelRegistry) ListModels() []*ModelInfo {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]*ModelInfo, 0, len(r.models))
for _, model := range r.models {
result = append(result, model)
}
return result
}
// watchChanges 监听变更
func (r *ModelRegistry) watchChanges() {
r.watchCh = r.etcdClient.Watch(context.Background(), "/", clientv3.WithPrefix())
for resp := range r.watchCh {
for _, event := range resp.Events {
r.handleEvent(event)
}
}
}
func (r *ModelRegistry) handleEvent(event *clientv3.Event) {
r.mu.Lock()
defer r.mu.Unlock()
key := string(event.Kv.Key)
switch event.Type {
case clientv3.EventTypePut:
// 更新或新增
if len(key) > 8 && key[:8] == "/models/" {
var model ModelInfo
if err := json.Unmarshal(event.Kv.Value, &model); err == nil {
r.models[model.Name] = &model
}
} else if len(key) > 11 && key[:11] == "/instances/" {
var instance ModelInstance
if err := json.Unmarshal(event.Kv.Value, &instance); err == nil {
r.updateInstance(&instance)
}
}
case clientv3.EventTypeDelete:
// 删除
if len(key) > 11 && key[:11] == "/instances/" {
r.removeInstance(key)
}
}
}
func (r *ModelRegistry) updateInstance(instance *ModelInstance) {
instances := r.instances[instance.ModelName]
for i, inst := range instances {
if inst.ID == instance.ID {
r.instances[instance.ModelName][i] = instance
return
}
}
r.instances[instance.ModelName] = append(instances, instance)
}
func (r *ModelRegistry) removeInstance(key string) {
    // key 形如 /instances/{modelName}/{instanceID}
    parts := strings.Split(strings.TrimPrefix(key, "/instances/"), "/")
    if len(parts) != 2 {
        return
    }
    modelName, instanceID := parts[0], parts[1]
    instances := r.instances[modelName]
    for i, inst := range instances {
        if inst.ID == instanceID {
            r.instances[modelName] = append(instances[:i], instances[i+1:]...)
            return
        }
    }
}
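下面是一个使用上述注册中心的最小示例:注册一个 7B 模型并登记它的一个运行实例。etcd 地址、模型路径、端点等均为示意值,仅演示调用方式。
package multimodel

import (
    "context"
    "log"
)

// registerExample 演示模型与实例的注册流程(参数均为示意值)。
func registerExample() {
    registry, err := NewModelRegistry([]string{"127.0.0.1:2379"})
    if err != nil {
        log.Fatal(err)
    }
    ctx := context.Background()

    model := &ModelInfo{
        Name:      "llama-7b",
        Version:   "1.0",
        Type:      ModelTypeLLM,
        Framework: "vllm",
        ModelPath: "s3://models/llama-7b",
        Config: &ModelConfig{
            MaxSeqLen:    4096,
            MaxBatchSize: 32,
            DType:        "fp16",
        },
        Requirements: &ResourceRequirements{
            GPUMemoryGB: 24,
            GPUCount:    1,
            CPUCores:    8,
            MemoryGB:    64,
            GPUType:     "A100",
        },
        Capabilities: []string{"chat", "completion"},
    }
    if err := registry.RegisterModel(ctx, model); err != nil {
        log.Fatal(err)
    }

    // 实例通常由部署系统在模型进程就绪后登记
    instance := &ModelInstance{
        ID:        "llama-7b-0",
        ModelName: "llama-7b",
        NodeID:    "node-1",
        Endpoint:  "10.0.0.12:8000",
        GPUs:      []int{0},
        Status:    InstanceStatusRunning,
        Health:    &HealthStatus{Healthy: true},
    }
    if err := registry.RegisterInstance(ctx, instance); err != nil {
        log.Fatal(err)
    }
}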
请求路由器
package multimodel
import (
    "context"
    "errors"
    "fmt"
    "hash/fnv"
    "math/rand"
    "sort"
    "sync"
    "sync/atomic"
)
// RequestRouter 请求路由器
type RequestRouter struct {
    registry      *ModelRegistry
    loadBalancers map[string]*LoadBalancer
    lbMu          sync.Mutex // 保护 loadBalancers 的并发读写
    routingRules  []*RoutingRule
    fallbackChain []string
    mu            sync.RWMutex
}
// RoutingRule 路由规则
type RoutingRule struct {
Name string `json:"name"`
Priority int `json:"priority"`
Conditions []RoutingCondition `json:"conditions"`
Target string `json:"target"` // 目标模型
Weight int `json:"weight"` // 权重(用于流量分配)
}
type RoutingCondition struct {
Field string `json:"field"` // model, capability, prompt_length, user_tier
Operator string `json:"operator"` // eq, ne, gt, lt, in, contains
Value interface{} `json:"value"`
}
// InferenceRequest 推理请求
type InferenceRequest struct {
ID string `json:"id"`
Model string `json:"model"` // 指定模型,可为空
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature"`
Stream bool `json:"stream"`
UserTier string `json:"user_tier"` // free, pro, enterprise
Capabilities []string `json:"capabilities"` // 需要的能力
Metadata map[string]interface{} `json:"metadata"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
func NewRequestRouter(registry *ModelRegistry) *RequestRouter {
return &RequestRouter{
registry: registry,
loadBalancers: make(map[string]*LoadBalancer),
routingRules: make([]*RoutingRule, 0),
fallbackChain: []string{},
}
}
// Route 路由请求
func (r *RequestRouter) Route(ctx context.Context, req *InferenceRequest) (*ModelInstance, error) {
    r.mu.RLock()
    defer r.mu.RUnlock()
    // 1. 先应用路由规则:高优先级规则(如模型切换时下发的流量规则)可以覆盖请求指定的模型
    if targetModel := r.applyRoutingRules(req); targetModel != "" {
        instance, err := r.routeToModel(ctx, targetModel, req)
        if err == nil {
            return instance, nil
        }
        // 规则目标不可用,继续尝试其他路径
    }
    // 2. 如果指定了模型,直接路由
    if req.Model != "" {
        return r.routeToModel(ctx, req.Model, req)
    }
    // 3. 基于能力路由
    if len(req.Capabilities) > 0 {
        instance, err := r.routeByCapability(ctx, req)
        if err == nil {
            return instance, nil
        }
    }
    // 4. Fallback 链
    for _, modelName := range r.fallbackChain {
        instance, err := r.routeToModel(ctx, modelName, req)
        if err == nil {
            return instance, nil
        }
    }
    return nil, errors.New("no available model instance")
}
// routeToModel 路由到指定模型
func (r *RequestRouter) routeToModel(
ctx context.Context,
modelName string,
req *InferenceRequest,
) (*ModelInstance, error) {
model := r.registry.GetModel(modelName)
if model == nil {
return nil, fmt.Errorf("model %s not found", modelName)
}
if model.Status != ModelStatusReady {
return nil, fmt.Errorf("model %s not ready", modelName)
}
instances := r.registry.GetHealthyInstances(modelName)
if len(instances) == 0 {
return nil, fmt.Errorf("no healthy instances for model %s", modelName)
}
// 获取或创建负载均衡器
lb := r.getLoadBalancer(modelName)
return lb.Select(instances, req)
}
// applyRoutingRules 应用路由规则
func (r *RequestRouter) applyRoutingRules(req *InferenceRequest) string {
    // 按优先级排序
    rules := make([]*RoutingRule, len(r.routingRules))
    copy(rules, r.routingRules)
    sort.Slice(rules, func(i, j int) bool {
        return rules[i].Priority > rules[j].Priority
    })
    for _, rule := range rules {
        if !r.matchRule(rule, req) {
            continue
        }
        // Weight 取值 1~99 时按比例分流(用于金丝雀发布),否则视为全量命中
        if rule.Weight > 0 && rule.Weight < 100 && rand.Intn(100) >= rule.Weight {
            continue
        }
        return rule.Target
    }
    return ""
}
// matchRule 匹配规则
func (r *RequestRouter) matchRule(rule *RoutingRule, req *InferenceRequest) bool {
for _, cond := range rule.Conditions {
if !r.matchCondition(cond, req) {
return false
}
}
return true
}
func (r *RequestRouter) matchCondition(cond RoutingCondition, req *InferenceRequest) bool {
var value interface{}
switch cond.Field {
case "model":
value = req.Model
case "user_tier":
value = req.UserTier
case "prompt_length":
value = r.calculatePromptLength(req)
case "capability":
value = req.Capabilities
default:
if req.Metadata != nil {
value = req.Metadata[cond.Field]
}
}
return r.compareValues(cond.Operator, value, cond.Value)
}
func (r *RequestRouter) compareValues(operator string, actual, expected interface{}) bool {
switch operator {
case "eq":
return actual == expected
case "ne":
return actual != expected
case "gt":
if a, ok := actual.(int); ok {
if e, ok := expected.(int); ok {
return a > e
}
}
case "lt":
if a, ok := actual.(int); ok {
if e, ok := expected.(int); ok {
return a < e
}
}
case "in":
if expected, ok := expected.([]interface{}); ok {
for _, e := range expected {
if actual == e {
return true
}
}
}
case "contains":
if actual, ok := actual.([]string); ok {
if expected, ok := expected.(string); ok {
for _, a := range actual {
if a == expected {
return true
}
}
}
}
}
return false
}
// routeByCapability 按能力路由
func (r *RequestRouter) routeByCapability(
ctx context.Context,
req *InferenceRequest,
) (*ModelInstance, error) {
models := r.registry.ListModels()
// 找到满足所有能力的模型
candidateModels := make([]*ModelInfo, 0)
for _, model := range models {
if model.Status != ModelStatusReady {
continue
}
if r.hasAllCapabilities(model.Capabilities, req.Capabilities) {
candidateModels = append(candidateModels, model)
}
}
if len(candidateModels) == 0 {
return nil, errors.New("no model with required capabilities")
}
// 选择最合适的模型(简单策略:选择资源需求最小的)
sort.Slice(candidateModels, func(i, j int) bool {
return candidateModels[i].Requirements.GPUMemoryGB <
candidateModels[j].Requirements.GPUMemoryGB
})
return r.routeToModel(ctx, candidateModels[0].Name, req)
}
func (r *RequestRouter) hasAllCapabilities(modelCaps, reqCaps []string) bool {
capSet := make(map[string]bool)
for _, c := range modelCaps {
capSet[c] = true
}
for _, c := range reqCaps {
if !capSet[c] {
return false
}
}
return true
}
func (r *RequestRouter) calculatePromptLength(req *InferenceRequest) int {
total := 0
for _, msg := range req.Messages {
total += len(msg.Content)
}
return total
}
func (r *RequestRouter) getLoadBalancer(modelName string) *LoadBalancer {
    // Route 持有的是读锁,这里用独立的互斥锁保护 map 写入,避免并发写 panic
    r.lbMu.Lock()
    defer r.lbMu.Unlock()
    if lb, exists := r.loadBalancers[modelName]; exists {
        return lb
    }
    lb := NewLoadBalancer(LoadBalancerLeastConnections)
    r.loadBalancers[modelName] = lb
    return lb
}
// AddRoutingRule 添加路由规则
func (r *RequestRouter) AddRoutingRule(rule *RoutingRule) {
r.mu.Lock()
defer r.mu.Unlock()
r.routingRules = append(r.routingRules, rule)
}
// SetFallbackChain 设置 fallback 链
func (r *RequestRouter) SetFallbackChain(models []string) {
r.mu.Lock()
defer r.mu.Unlock()
r.fallbackChain = models
}
// LoadBalancer 负载均衡器
type LoadBalancer struct {
    strategy    LoadBalancerStrategy
    connections map[string]int
    counter     uint64 // 轮询计数器
    mu          sync.RWMutex
}
type LoadBalancerStrategy int
const (
LoadBalancerRoundRobin LoadBalancerStrategy = iota
LoadBalancerLeastConnections
LoadBalancerWeightedRandom
LoadBalancerConsistentHash
)
func NewLoadBalancer(strategy LoadBalancerStrategy) *LoadBalancer {
return &LoadBalancer{
strategy: strategy,
connections: make(map[string]int),
}
}
// Select 选择实例
func (lb *LoadBalancer) Select(
instances []*ModelInstance,
req *InferenceRequest,
) (*ModelInstance, error) {
if len(instances) == 0 {
return nil, errors.New("no instances available")
}
switch lb.strategy {
case LoadBalancerRoundRobin:
return lb.selectRoundRobin(instances)
case LoadBalancerLeastConnections:
return lb.selectLeastConnections(instances)
case LoadBalancerWeightedRandom:
return lb.selectWeightedRandom(instances)
case LoadBalancerConsistentHash:
return lb.selectConsistentHash(instances, req)
default:
return instances[0], nil
}
}
func (lb *LoadBalancer) selectLeastConnections(instances []*ModelInstance) (*ModelInstance, error) {
lb.mu.RLock()
defer lb.mu.RUnlock()
var selected *ModelInstance
minConn := int(^uint(0) >> 1)
for _, inst := range instances {
conn := lb.connections[inst.ID]
if inst.Metrics != nil {
conn += inst.Metrics.RequestsActive
}
if conn < minConn {
minConn = conn
selected = inst
}
}
return selected, nil
}
func (lb *LoadBalancer) selectWeightedRandom(instances []*ModelInstance) (*ModelInstance, error) {
// 基于 GPU 利用率的权重
totalWeight := 0.0
weights := make([]float64, len(instances))
for i, inst := range instances {
// 利用率越低,权重越高
utilization := 0.5
if inst.Metrics != nil {
utilization = inst.Metrics.GPUUtilization
}
weights[i] = 1.0 - utilization
totalWeight += weights[i]
}
// 随机选择
r := randomFloat() * totalWeight
cumWeight := 0.0
for i, w := range weights {
cumWeight += w
if r <= cumWeight {
return instances[i], nil
}
}
return instances[len(instances)-1], nil
}
func (lb *LoadBalancer) selectConsistentHash(
instances []*ModelInstance,
req *InferenceRequest,
) (*ModelInstance, error) {
// 基于请求 ID 的一致性哈希
h := fnv.New32a()
h.Write([]byte(req.ID))
hash := h.Sum32()
idx := int(hash) % len(instances)
return instances[idx], nil
}
func (lb *LoadBalancer) selectRoundRobin(instances []*ModelInstance) (*ModelInstance, error) {
    // 原子递增计数器,按顺序轮询实例
    idx := atomic.AddUint64(&lb.counter, 1)
    return instances[int(idx%uint64(len(instances)))], nil
}
// IncrementConnections 增加连接数
func (lb *LoadBalancer) IncrementConnections(instanceID string) {
lb.mu.Lock()
defer lb.mu.Unlock()
lb.connections[instanceID]++
}
// DecrementConnections 减少连接数
func (lb *LoadBalancer) DecrementConnections(instanceID string) {
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.connections[instanceID] > 0 {
lb.connections[instanceID]--
}
}
func randomFloat() float64 {
    return rand.Float64()
}
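把注册中心与路由器组合起来的一个配置示例如下,演示路由规则、fallback 链的设置与一次路由调用。模型名、阈值、用户等级均为示意值。
package multimodel

import (
    "context"
    "log"
)

// setupRouterExample 演示路由规则与 fallback 链的配置(取值均为示意)。
func setupRouterExample(registry *ModelRegistry) {
    router := NewRequestRouter(registry)

    // 企业用户优先走大模型
    router.AddRoutingRule(&RoutingRule{
        Name:     "premium-users",
        Priority: 100,
        Conditions: []RoutingCondition{
            {Field: "user_tier", Operator: "eq", Value: "enterprise"},
        },
        Target: "llama-70b",
    })
    // 超长 prompt 走长上下文模型
    router.AddRoutingRule(&RoutingRule{
        Name:     "long-context",
        Priority: 80,
        Conditions: []RoutingCondition{
            {Field: "prompt_length", Operator: "gt", Value: 2000},
        },
        Target: "llama-70b",
    })
    // 兜底顺序:先 70B,再 7B
    router.SetFallbackChain([]string{"llama-70b", "llama-7b"})

    instance, err := router.Route(context.Background(), &InferenceRequest{
        ID:       "req-001",
        UserTier: "enterprise",
        Messages: []Message{{Role: "user", Content: "介绍一下多模型服务"}},
    })
    if err != nil {
        log.Println("route failed:", err)
        return
    }
    log.Println("routed to instance:", instance.ID)
}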
资源管理与调度
GPU 资源管理器
package multimodel
import (
    "context"
    "errors"
    "fmt"
    "sort"
    "sync"
    "time"
)
// ResourceManager 资源管理器
type ResourceManager struct {
nodes map[string]*NodeInfo
allocations map[string]*ResourceAllocation
scheduler *ModelScheduler
mu sync.RWMutex
}
// NodeInfo 节点信息
type NodeInfo struct {
ID string `json:"id"`
Hostname string `json:"hostname"`
GPUs []*GPUInfo `json:"gpus"`
CPUCores int `json:"cpu_cores"`
MemoryGB float64 `json:"memory_gb"`
Labels map[string]string `json:"labels"`
Status NodeStatus `json:"status"`
LastHeartbeat time.Time `json:"last_heartbeat"`
}
type NodeStatus string
const (
NodeStatusReady NodeStatus = "ready"
NodeStatusNotReady NodeStatus = "not_ready"
NodeStatusDraining NodeStatus = "draining"
)
type GPUInfo struct {
Index int `json:"index"`
UUID string `json:"uuid"`
Name string `json:"name"`
MemoryTotal float64 `json:"memory_total_gb"`
MemoryUsed float64 `json:"memory_used_gb"`
Utilization float64 `json:"utilization"`
Allocated bool `json:"allocated"`
AllocatedTo string `json:"allocated_to"` // 模型实例 ID
}
// ResourceAllocation 资源分配
type ResourceAllocation struct {
ID string `json:"id"`
ModelName string `json:"model_name"`
InstanceID string `json:"instance_id"`
NodeID string `json:"node_id"`
GPUs []int `json:"gpus"`
CPUCores int `json:"cpu_cores"`
MemoryGB float64 `json:"memory_gb"`
CreateTime time.Time `json:"create_time"`
}
func NewResourceManager() *ResourceManager {
rm := &ResourceManager{
nodes: make(map[string]*NodeInfo),
allocations: make(map[string]*ResourceAllocation),
scheduler: NewModelScheduler(),
}
return rm
}
// RegisterNode 注册节点
func (rm *ResourceManager) RegisterNode(node *NodeInfo) error {
rm.mu.Lock()
defer rm.mu.Unlock()
node.Status = NodeStatusReady
node.LastHeartbeat = time.Now()
rm.nodes[node.ID] = node
return nil
}
// UpdateNodeHeartbeat 更新节点心跳
func (rm *ResourceManager) UpdateNodeHeartbeat(nodeID string, gpuMetrics []*GPUMetrics) error {
rm.mu.Lock()
defer rm.mu.Unlock()
node, exists := rm.nodes[nodeID]
if !exists {
return errors.New("node not found")
}
node.LastHeartbeat = time.Now()
// 更新 GPU 指标
for _, metrics := range gpuMetrics {
if metrics.Index < len(node.GPUs) {
node.GPUs[metrics.Index].MemoryUsed = metrics.MemoryUsed
node.GPUs[metrics.Index].Utilization = metrics.Utilization
}
}
return nil
}
type GPUMetrics struct {
Index int `json:"index"`
MemoryUsed float64 `json:"memory_used_gb"`
Utilization float64 `json:"utilization"`
}
// AllocateResources 分配资源
func (rm *ResourceManager) AllocateResources(
ctx context.Context,
model *ModelInfo,
instanceID string,
) (*ResourceAllocation, error) {
rm.mu.Lock()
defer rm.mu.Unlock()
// 找到合适的节点
node, gpuIndices, err := rm.scheduler.Schedule(rm.nodes, model.Requirements)
if err != nil {
return nil, err
}
// 标记 GPU 为已分配
for _, idx := range gpuIndices {
node.GPUs[idx].Allocated = true
node.GPUs[idx].AllocatedTo = instanceID
}
allocation := &ResourceAllocation{
ID: fmt.Sprintf("alloc-%s-%d", instanceID, time.Now().UnixNano()),
ModelName: model.Name,
InstanceID: instanceID,
NodeID: node.ID,
GPUs: gpuIndices,
CPUCores: model.Requirements.CPUCores,
MemoryGB: model.Requirements.MemoryGB,
CreateTime: time.Now(),
}
rm.allocations[allocation.ID] = allocation
return allocation, nil
}
// ReleaseResources 释放资源
func (rm *ResourceManager) ReleaseResources(allocationID string) error {
rm.mu.Lock()
defer rm.mu.Unlock()
allocation, exists := rm.allocations[allocationID]
if !exists {
return errors.New("allocation not found")
}
node := rm.nodes[allocation.NodeID]
if node != nil {
for _, idx := range allocation.GPUs {
if idx < len(node.GPUs) {
node.GPUs[idx].Allocated = false
node.GPUs[idx].AllocatedTo = ""
}
}
}
delete(rm.allocations, allocationID)
return nil
}
// GetNodeStatus 获取节点状态
func (rm *ResourceManager) GetNodeStatus() []*NodeInfo {
rm.mu.RLock()
defer rm.mu.RUnlock()
result := make([]*NodeInfo, 0, len(rm.nodes))
for _, node := range rm.nodes {
result = append(result, node)
}
return result
}
// GetAvailableGPUs 获取可用 GPU 数量
func (rm *ResourceManager) GetAvailableGPUs() int {
rm.mu.RLock()
defer rm.mu.RUnlock()
count := 0
for _, node := range rm.nodes {
if node.Status != NodeStatusReady {
continue
}
for _, gpu := range node.GPUs {
if !gpu.Allocated {
count++
}
}
}
return count
}
// ModelScheduler 模型调度器
type ModelScheduler struct {
strategy SchedulingStrategy
}
type SchedulingStrategy int
const (
StrategyBinPacking SchedulingStrategy = iota
StrategySpread
StrategyGPUAffinity
)
func NewModelScheduler() *ModelScheduler {
return &ModelScheduler{
strategy: StrategyBinPacking,
}
}
// Schedule 调度模型到节点
func (s *ModelScheduler) Schedule(
nodes map[string]*NodeInfo,
requirements *ResourceRequirements,
) (*NodeInfo, []int, error) {
candidates := s.findCandidateNodes(nodes, requirements)
if len(candidates) == 0 {
return nil, nil, errors.New("no suitable node found")
}
switch s.strategy {
case StrategyBinPacking:
return s.binPackingSchedule(candidates, requirements)
case StrategySpread:
return s.spreadSchedule(candidates, requirements)
case StrategyGPUAffinity:
return s.gpuAffinitySchedule(candidates, requirements)
default:
return s.binPackingSchedule(candidates, requirements)
}
}
func (s *ModelScheduler) findCandidateNodes(
nodes map[string]*NodeInfo,
req *ResourceRequirements,
) []*NodeInfo {
candidates := make([]*NodeInfo, 0)
for _, node := range nodes {
if node.Status != NodeStatusReady {
continue
}
// 检查 GPU 数量
availableGPUs := s.countAvailableGPUs(node, req.GPUType)
if availableGPUs < req.GPUCount {
continue
}
// 检查连续 GPU(对于张量并行)
if req.GPUCount > 1 && !s.hasContiguousGPUs(node, req.GPUCount) {
continue
}
// 检查 CPU 和内存
if node.CPUCores < req.CPUCores || node.MemoryGB < req.MemoryGB {
continue
}
candidates = append(candidates, node)
}
return candidates
}
func (s *ModelScheduler) countAvailableGPUs(node *NodeInfo, gpuType string) int {
count := 0
for _, gpu := range node.GPUs {
if gpu.Allocated {
continue
}
if gpuType != "" && gpu.Name != gpuType {
continue
}
count++
}
return count
}
func (s *ModelScheduler) hasContiguousGPUs(node *NodeInfo, count int) bool {
consecutive := 0
for _, gpu := range node.GPUs {
if !gpu.Allocated {
consecutive++
if consecutive >= count {
return true
}
} else {
consecutive = 0
}
}
return false
}
// binPackingSchedule 装箱调度(优先填满节点)
func (s *ModelScheduler) binPackingSchedule(
candidates []*NodeInfo,
req *ResourceRequirements,
) (*NodeInfo, []int, error) {
// 按可用 GPU 数量升序排序(优先选择 GPU 少的节点)
sort.Slice(candidates, func(i, j int) bool {
availI := s.countAvailableGPUs(candidates[i], req.GPUType)
availJ := s.countAvailableGPUs(candidates[j], req.GPUType)
return availI < availJ
})
for _, node := range candidates {
gpuIndices := s.selectGPUs(node, req.GPUCount, req.GPUType)
if len(gpuIndices) == req.GPUCount {
return node, gpuIndices, nil
}
}
return nil, nil, errors.New("scheduling failed")
}
// spreadSchedule 分散调度(均匀分布)
func (s *ModelScheduler) spreadSchedule(
candidates []*NodeInfo,
req *ResourceRequirements,
) (*NodeInfo, []int, error) {
// 按可用 GPU 数量降序排序(优先选择 GPU 多的节点)
sort.Slice(candidates, func(i, j int) bool {
availI := s.countAvailableGPUs(candidates[i], req.GPUType)
availJ := s.countAvailableGPUs(candidates[j], req.GPUType)
return availI > availJ
})
for _, node := range candidates {
gpuIndices := s.selectGPUs(node, req.GPUCount, req.GPUType)
if len(gpuIndices) == req.GPUCount {
return node, gpuIndices, nil
}
}
return nil, nil, errors.New("scheduling failed")
}
// gpuAffinitySchedule GPU 亲和性调度
func (s *ModelScheduler) gpuAffinitySchedule(
candidates []*NodeInfo,
req *ResourceRequirements,
) (*NodeInfo, []int, error) {
// 优先选择相邻 GPU(NVLink 连接)
for _, node := range candidates {
gpuIndices := s.selectContiguousGPUs(node, req.GPUCount, req.GPUType)
if len(gpuIndices) == req.GPUCount {
return node, gpuIndices, nil
}
}
// 退化为普通调度
return s.binPackingSchedule(candidates, req)
}
func (s *ModelScheduler) selectGPUs(node *NodeInfo, count int, gpuType string) []int {
indices := make([]int, 0, count)
for i, gpu := range node.GPUs {
if gpu.Allocated {
continue
}
if gpuType != "" && gpu.Name != gpuType {
continue
}
indices = append(indices, i)
if len(indices) == count {
break
}
}
return indices
}
func (s *ModelScheduler) selectContiguousGPUs(node *NodeInfo, count int, gpuType string) []int {
start := -1
consecutive := 0
for i, gpu := range node.GPUs {
if gpu.Allocated || (gpuType != "" && gpu.Name != gpuType) {
start = -1
consecutive = 0
continue
}
if start == -1 {
start = i
}
consecutive++
if consecutive == count {
indices := make([]int, count)
for j := 0; j < count; j++ {
indices[j] = start + j
}
return indices
}
}
return nil
}
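资源管理器的典型用法如下:节点上报 GPU 信息后,为某个模型实例申请 GPU。节点规格与资源需求数值均为示意。
package multimodel

import (
    "context"
    "log"
)

// allocateExample 演示节点注册与一次 4 卡分配(数值均为示意)。
func allocateExample() {
    rm := NewResourceManager()

    // 注册一个 8 卡 A100 节点
    gpus := make([]*GPUInfo, 8)
    for i := range gpus {
        gpus[i] = &GPUInfo{Index: i, Name: "A100", MemoryTotal: 80}
    }
    rm.RegisterNode(&NodeInfo{
        ID:       "node-1",
        Hostname: "gpu-node-1",
        GPUs:     gpus,
        CPUCores: 128,
        MemoryGB: 1024,
    })

    // 为 70B 模型(张量并行 4 卡)申请资源
    model := &ModelInfo{
        Name: "llama-70b",
        Requirements: &ResourceRequirements{
            GPUCount:    4,
            GPUType:     "A100",
            GPUMemoryGB: 320,
            CPUCores:    32,
            MemoryGB:    256,
        },
    }
    alloc, err := rm.AllocateResources(context.Background(), model, "llama-70b-0")
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("allocated on node %s, gpus %v", alloc.NodeID, alloc.GPUs)
}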
自动扩缩容
package multimodel
import (
    "context"
    "errors"
    "fmt"
    "math"
    "sort"
    "sync"
    "time"
)
// AutoScaler 自动扩缩容器
type AutoScaler struct {
registry *ModelRegistry
resourceMgr *ResourceManager
metrics *MetricsCollector
scalingPolicies map[string]*ScalingPolicy
mu sync.RWMutex
stopCh chan struct{}
}
// ScalingPolicy 扩缩容策略
type ScalingPolicy struct {
ModelName string `json:"model_name"`
MinReplicas int `json:"min_replicas"`
MaxReplicas int `json:"max_replicas"`
TargetMetrics []TargetMetric `json:"target_metrics"`
ScaleUpCooldown time.Duration `json:"scale_up_cooldown"`
ScaleDownCooldown time.Duration `json:"scale_down_cooldown"`
ScaleUpStep int `json:"scale_up_step"`
ScaleDownStep int `json:"scale_down_step"`
// 内部状态
lastScaleUp time.Time
lastScaleDown time.Time
currentReplicas int
}
type TargetMetric struct {
Name string `json:"name"` // qps, latency, gpu_utilization, queue_length
Target float64 `json:"target"` // 目标值
Tolerance float64 `json:"tolerance"` // 容忍度
}
// MetricsCollector 指标收集器
type MetricsCollector struct {
registry *ModelRegistry
}
type ModelMetrics struct {
ModelName string
QPS float64
AverageLatency time.Duration
P99Latency time.Duration
GPUUtilization float64
QueueLength int
ActiveRequests int
InstanceCount int
}
func NewAutoScaler(
registry *ModelRegistry,
resourceMgr *ResourceManager,
) *AutoScaler {
return &AutoScaler{
registry: registry,
resourceMgr: resourceMgr,
metrics: &MetricsCollector{registry: registry},
scalingPolicies: make(map[string]*ScalingPolicy),
stopCh: make(chan struct{}),
}
}
// SetPolicy 设置扩缩容策略
func (as *AutoScaler) SetPolicy(policy *ScalingPolicy) {
as.mu.Lock()
defer as.mu.Unlock()
as.scalingPolicies[policy.ModelName] = policy
}
// Start 启动自动扩缩容
func (as *AutoScaler) Start(ctx context.Context) {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-as.stopCh:
return
case <-ticker.C:
as.evaluate()
}
}
}
// Stop 停止
func (as *AutoScaler) Stop() {
close(as.stopCh)
}
// evaluate 评估扩缩容
func (as *AutoScaler) evaluate() {
as.mu.RLock()
policies := make([]*ScalingPolicy, 0, len(as.scalingPolicies))
for _, p := range as.scalingPolicies {
policies = append(policies, p)
}
as.mu.RUnlock()
for _, policy := range policies {
as.evaluateModel(policy)
}
}
func (as *AutoScaler) evaluateModel(policy *ScalingPolicy) {
metrics := as.metrics.Collect(policy.ModelName)
if metrics == nil {
return
}
policy.currentReplicas = metrics.InstanceCount
// 计算期望副本数
desiredReplicas := as.calculateDesiredReplicas(policy, metrics)
// 应用边界
desiredReplicas = max(desiredReplicas, policy.MinReplicas)
desiredReplicas = min(desiredReplicas, policy.MaxReplicas)
// 检查是否需要扩缩容
if desiredReplicas > policy.currentReplicas {
as.scaleUp(policy, desiredReplicas)
} else if desiredReplicas < policy.currentReplicas {
as.scaleDown(policy, desiredReplicas)
}
}
func (as *AutoScaler) calculateDesiredReplicas(
policy *ScalingPolicy,
metrics *ModelMetrics,
) int {
maxRatio := 1.0
for _, target := range policy.TargetMetrics {
var currentValue float64
switch target.Name {
case "qps":
currentValue = metrics.QPS
case "latency":
currentValue = float64(metrics.AverageLatency.Milliseconds())
case "gpu_utilization":
currentValue = metrics.GPUUtilization * 100
case "queue_length":
currentValue = float64(metrics.QueueLength)
default:
continue
}
if target.Target > 0 {
ratio := currentValue / target.Target
if ratio > maxRatio {
maxRatio = ratio
}
}
}
// 计算期望副本数
desired := int(math.Ceil(float64(policy.currentReplicas) * maxRatio))
return desired
}
func (as *AutoScaler) scaleUp(policy *ScalingPolicy, desired int) {
// 检查冷却时间
if time.Since(policy.lastScaleUp) < policy.ScaleUpCooldown {
return
}
// 计算实际扩容数量
scaleCount := min(desired-policy.currentReplicas, policy.ScaleUpStep)
if scaleCount <= 0 {
return
}
// 执行扩容
for i := 0; i < scaleCount; i++ {
err := as.createInstance(policy.ModelName)
if err != nil {
break
}
}
policy.lastScaleUp = time.Now()
}
func (as *AutoScaler) scaleDown(policy *ScalingPolicy, desired int) {
// 检查冷却时间
if time.Since(policy.lastScaleDown) < policy.ScaleDownCooldown {
return
}
// 检查是否低于最小副本数
if policy.currentReplicas <= policy.MinReplicas {
return
}
// 计算实际缩容数量
scaleCount := min(policy.currentReplicas-desired, policy.ScaleDownStep)
scaleCount = min(scaleCount, policy.currentReplicas-policy.MinReplicas)
if scaleCount <= 0 {
return
}
// 执行缩容(优先删除负载最低的实例)
instances := as.registry.GetInstances(policy.ModelName)
sort.Slice(instances, func(i, j int) bool {
loadI := 0
loadJ := 0
if instances[i].Metrics != nil {
loadI = instances[i].Metrics.RequestsActive
}
if instances[j].Metrics != nil {
loadJ = instances[j].Metrics.RequestsActive
}
return loadI < loadJ
})
for i := 0; i < scaleCount && i < len(instances); i++ {
as.deleteInstance(instances[i])
}
policy.lastScaleDown = time.Now()
}
func (as *AutoScaler) createInstance(modelName string) error {
model := as.registry.GetModel(modelName)
if model == nil {
return errors.New("model not found")
}
instanceID := fmt.Sprintf("%s-%d", modelName, time.Now().UnixNano())
// 分配资源
allocation, err := as.resourceMgr.AllocateResources(
context.Background(), model, instanceID,
)
if err != nil {
return err
}
// 启动实例(这里简化处理,实际需要调用 Kubernetes API 或容器运行时)
instance := &ModelInstance{
ID: instanceID,
ModelName: modelName,
NodeID: allocation.NodeID,
GPUs: allocation.GPUs,
Status: InstanceStatusStarting,
}
return as.registry.RegisterInstance(context.Background(), instance)
}
func (as *AutoScaler) deleteInstance(instance *ModelInstance) error {
    // 1. 设置为 draining 状态
    instance.Status = InstanceStatusDraining
    // 2. 等待请求处理完成
    // 实际实现需要等待或超时
    // 3. 释放资源:先在读锁下找到对应的分配 ID,再释放,避免无锁遍历 allocations
    var allocID string
    as.resourceMgr.mu.RLock()
    for id, alloc := range as.resourceMgr.allocations {
        if alloc.InstanceID == instance.ID {
            allocID = id
            break
        }
    }
    as.resourceMgr.mu.RUnlock()
    if allocID != "" {
        as.resourceMgr.ReleaseResources(allocID)
    }
    // 4. 从注册中心移除
    // ...
    return nil
}
// Collect 收集模型指标
func (mc *MetricsCollector) Collect(modelName string) *ModelMetrics {
instances := mc.registry.GetInstances(modelName)
if len(instances) == 0 {
return nil
}
metrics := &ModelMetrics{
ModelName: modelName,
InstanceCount: len(instances),
}
    var totalQPS float64
    var totalLatency time.Duration
    var totalUtilization float64
    var totalActive int
    validCount := 0
for _, inst := range instances {
if inst.Metrics == nil {
continue
}
totalQPS += float64(inst.Metrics.RequestsTotal) / 60.0 // 假设 1 分钟窗口
totalLatency += time.Duration(inst.Metrics.AverageLatencyMs) * time.Millisecond
totalUtilization += inst.Metrics.GPUUtilization
totalActive += inst.Metrics.RequestsActive
validCount++
}
if validCount > 0 {
metrics.QPS = totalQPS
metrics.AverageLatency = totalLatency / time.Duration(validCount)
metrics.GPUUtilization = totalUtilization / float64(validCount)
metrics.ActiveRequests = totalActive
}
return metrics
}
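自动扩缩容器的配置示例如下:为 llama-7b 设置目标指标与冷却时间,并在后台启动评估循环。目标值与冷却时间均为示意。
package multimodel

import (
    "context"
    "time"
)

// autoscalerExample 演示扩缩容策略的设置与启动(取值均为示意)。
func autoscalerExample(ctx context.Context, registry *ModelRegistry, rm *ResourceManager) {
    as := NewAutoScaler(registry, rm)
    as.SetPolicy(&ScalingPolicy{
        ModelName:   "llama-7b",
        MinReplicas: 2,
        MaxReplicas: 8,
        TargetMetrics: []TargetMetric{
            {Name: "gpu_utilization", Target: 70}, // 目标平均 GPU 利用率 70%
            {Name: "latency", Target: 500},        // 目标平均延迟 500ms
        },
        ScaleUpCooldown:   time.Minute,
        ScaleDownCooldown: 5 * time.Minute,
        ScaleUpStep:       2,
        ScaleDownStep:     1,
    })
    // 每 15 秒评估一次,ctx 取消或调用 Stop 时退出
    go as.Start(ctx)
}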
模型热切换
无缝模型切换
package multimodel
import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"
)
// ModelSwitcher 模型切换器
type ModelSwitcher struct {
registry *ModelRegistry
resourceMgr *ResourceManager
router *RequestRouter
mu sync.RWMutex
}
// SwitchRequest 切换请求
type SwitchRequest struct {
OldModel string `json:"old_model"`
NewModel string `json:"new_model"`
Strategy SwitchStrategy `json:"strategy"`
TrafficSplit float64 `json:"traffic_split"` // 新模型流量比例
Timeout time.Duration `json:"timeout"`
}
type SwitchStrategy string
const (
SwitchStrategyBlueGreen SwitchStrategy = "blue_green"
SwitchStrategyCanary SwitchStrategy = "canary"
SwitchStrategyRolling SwitchStrategy = "rolling"
)
// SwitchProgress 切换进度
type SwitchProgress struct {
RequestID string `json:"request_id"`
Status SwitchStatus `json:"status"`
OldModel string `json:"old_model"`
NewModel string `json:"new_model"`
TrafficSplit float64 `json:"traffic_split"`
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time"`
Message string `json:"message"`
}
type SwitchStatus string
const (
SwitchStatusPending SwitchStatus = "pending"
SwitchStatusRunning SwitchStatus = "running"
SwitchStatusCompleted SwitchStatus = "completed"
SwitchStatusFailed SwitchStatus = "failed"
SwitchStatusRolledBack SwitchStatus = "rolled_back"
)
func NewModelSwitcher(
registry *ModelRegistry,
resourceMgr *ResourceManager,
router *RequestRouter,
) *ModelSwitcher {
return &ModelSwitcher{
registry: registry,
resourceMgr: resourceMgr,
router: router,
}
}
// Switch 执行模型切换
func (ms *ModelSwitcher) Switch(ctx context.Context, req *SwitchRequest) (*SwitchProgress, error) {
progress := &SwitchProgress{
RequestID: fmt.Sprintf("switch-%d", time.Now().UnixNano()),
Status: SwitchStatusPending,
OldModel: req.OldModel,
NewModel: req.NewModel,
TrafficSplit: 0,
StartTime: time.Now(),
}
switch req.Strategy {
case SwitchStrategyBlueGreen:
return ms.blueGreenSwitch(ctx, req, progress)
case SwitchStrategyCanary:
return ms.canarySwitch(ctx, req, progress)
case SwitchStrategyRolling:
return ms.rollingSwitch(ctx, req, progress)
default:
return nil, errors.New("unknown switch strategy")
}
}
// blueGreenSwitch 蓝绿切换
func (ms *ModelSwitcher) blueGreenSwitch(
ctx context.Context,
req *SwitchRequest,
progress *SwitchProgress,
) (*SwitchProgress, error) {
progress.Status = SwitchStatusRunning
// 1. 确保新模型已就绪
newModel := ms.registry.GetModel(req.NewModel)
if newModel == nil || newModel.Status != ModelStatusReady {
progress.Status = SwitchStatusFailed
progress.Message = "new model not ready"
return progress, errors.New("new model not ready")
}
newInstances := ms.registry.GetHealthyInstances(req.NewModel)
if len(newInstances) == 0 {
progress.Status = SwitchStatusFailed
progress.Message = "no healthy instances for new model"
return progress, errors.New("no healthy instances")
}
// 2. 更新路由规则,将流量切换到新模型
ms.router.mu.Lock()
ms.updateRoutingForSwitch(req.OldModel, req.NewModel, 1.0)
ms.router.mu.Unlock()
progress.TrafficSplit = 1.0
// 3. 等待旧模型请求处理完成
err := ms.drainOldModel(ctx, req.OldModel, req.Timeout)
if err != nil {
// 回滚
ms.router.mu.Lock()
ms.updateRoutingForSwitch(req.NewModel, req.OldModel, 1.0)
ms.router.mu.Unlock()
progress.Status = SwitchStatusRolledBack
progress.Message = err.Error()
return progress, err
}
// 4. 可选:释放旧模型资源
// ms.releaseModelResources(req.OldModel)
progress.Status = SwitchStatusCompleted
now := time.Now()
progress.EndTime = &now
return progress, nil
}
// canarySwitch 金丝雀切换
func (ms *ModelSwitcher) canarySwitch(
ctx context.Context,
req *SwitchRequest,
progress *SwitchProgress,
) (*SwitchProgress, error) {
progress.Status = SwitchStatusRunning
// 渐进式增加流量
trafficSteps := []float64{0.1, 0.25, 0.5, 0.75, 1.0}
if req.TrafficSplit > 0 {
// 使用自定义步长
trafficSteps = []float64{req.TrafficSplit}
}
for _, split := range trafficSteps {
// 检查上下文是否取消
select {
case <-ctx.Done():
progress.Status = SwitchStatusFailed
return progress, ctx.Err()
default:
}
// 更新流量分配
ms.router.mu.Lock()
ms.updateRoutingForSwitch(req.OldModel, req.NewModel, split)
ms.router.mu.Unlock()
progress.TrafficSplit = split
// 监控新模型表现
if split < 1.0 {
healthy, err := ms.monitorNewModel(ctx, req.NewModel, 30*time.Second)
if !healthy || err != nil {
// 回滚
ms.router.mu.Lock()
ms.updateRoutingForSwitch(req.NewModel, req.OldModel, 1.0)
ms.router.mu.Unlock()
progress.Status = SwitchStatusRolledBack
progress.Message = "new model unhealthy during canary"
return progress, errors.New("canary failed")
}
}
}
progress.Status = SwitchStatusCompleted
now := time.Now()
progress.EndTime = &now
return progress, nil
}
// rollingSwitch 滚动切换
func (ms *ModelSwitcher) rollingSwitch(
ctx context.Context,
req *SwitchRequest,
progress *SwitchProgress,
) (*SwitchProgress, error) {
progress.Status = SwitchStatusRunning
oldInstances := ms.registry.GetInstances(req.OldModel)
if len(oldInstances) == 0 {
progress.Status = SwitchStatusFailed
return progress, errors.New("no old instances to replace")
}
// 逐个替换实例
for i, oldInst := range oldInstances {
select {
case <-ctx.Done():
progress.Status = SwitchStatusFailed
return progress, ctx.Err()
default:
}
// 1. 启动新实例
newInstID := fmt.Sprintf("%s-instance-%d", req.NewModel, i)
err := ms.createModelInstance(ctx, req.NewModel, newInstID)
if err != nil {
continue
}
// 2. 等待新实例就绪
err = ms.waitInstanceReady(ctx, newInstID, 5*time.Minute)
if err != nil {
ms.deleteModelInstance(newInstID)
continue
}
// 3. 将旧实例设为 draining
oldInst.Status = InstanceStatusDraining
// 4. 等待旧实例请求完成
ms.waitInstanceDrained(ctx, oldInst.ID, 2*time.Minute)
// 5. 删除旧实例
ms.deleteModelInstance(oldInst.ID)
progress.TrafficSplit = float64(i+1) / float64(len(oldInstances))
}
progress.Status = SwitchStatusCompleted
now := time.Now()
progress.EndTime = &now
return progress, nil
}
func (ms *ModelSwitcher) updateRoutingForSwitch(oldModel, newModel string, newModelWeight float64) {
// 添加流量分配规则
rule := &RoutingRule{
Name: fmt.Sprintf("switch-%s-to-%s", oldModel, newModel),
Priority: 100, // 高优先级
Conditions: []RoutingCondition{
{Field: "model", Operator: "eq", Value: oldModel},
},
Target: newModel,
Weight: int(newModelWeight * 100),
}
// 移除旧规则
for i, r := range ms.router.routingRules {
if r.Name == rule.Name {
ms.router.routingRules = append(
ms.router.routingRules[:i],
ms.router.routingRules[i+1:]...,
)
break
}
}
if newModelWeight > 0 {
ms.router.routingRules = append(ms.router.routingRules, rule)
}
}
func (ms *ModelSwitcher) drainOldModel(ctx context.Context, modelName string, timeout time.Duration) error {
instances := ms.registry.GetInstances(modelName)
deadline := time.Now().Add(timeout)
for _, inst := range instances {
inst.Status = InstanceStatusDraining
}
// 等待所有活跃请求完成
for {
if time.Now().After(deadline) {
return errors.New("drain timeout")
}
allDrained := true
for _, inst := range instances {
if inst.Metrics != nil && inst.Metrics.RequestsActive > 0 {
allDrained = false
break
}
}
if allDrained {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
continue
}
}
}
func (ms *ModelSwitcher) monitorNewModel(ctx context.Context, modelName string, duration time.Duration) (bool, error) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
deadline := time.Now().Add(duration)
errorCount := 0
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return false, ctx.Err()
case <-ticker.C:
instances := ms.registry.GetHealthyInstances(modelName)
if len(instances) == 0 {
errorCount++
if errorCount >= 3 {
return false, errors.New("no healthy instances")
}
} else {
errorCount = 0
}
}
}
return true, nil
}
func (ms *ModelSwitcher) createModelInstance(ctx context.Context, modelName, instanceID string) error {
// 实际实现
return nil
}
func (ms *ModelSwitcher) deleteModelInstance(instanceID string) error {
// 实际实现
return nil
}
func (ms *ModelSwitcher) waitInstanceReady(ctx context.Context, instanceID string, timeout time.Duration) error {
// 实际实现
return nil
}
func (ms *ModelSwitcher) waitInstanceDrained(ctx context.Context, instanceID string, timeout time.Duration) {
// 实际实现
}
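模型切换器的调用示例如下:以金丝雀策略把流量从旧版本逐步切到新版本。模型名与超时时间均为示意值。
package multimodel

import (
    "context"
    "log"
    "time"
)

// canaryExample 演示一次金丝雀切换(取值均为示意)。
func canaryExample(ms *ModelSwitcher) {
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
    defer cancel()

    progress, err := ms.Switch(ctx, &SwitchRequest{
        OldModel: "llama-7b-v1",
        NewModel: "llama-7b-v2",
        Strategy: SwitchStrategyCanary,
        Timeout:  10 * time.Minute,
    })
    if err != nil {
        if progress != nil {
            log.Printf("switch failed: %v (status=%s)", err, progress.Status)
        } else {
            log.Printf("switch failed: %v", err)
        }
        return
    }
    log.Printf("switch %s completed, new model traffic: %.0f%%",
        progress.RequestID, progress.TrafficSplit*100)
}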
Kubernetes 集成
多模型 CRD
# CRD 定义
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
name: inferencemodels.ai.example.com
spec:
group: ai.example.com
versions:
- name: v1
served: true
storage: true
schema:
openAPIV3Schema:
type: object
properties:
spec:
type: object
required:
- modelName
- framework
properties:
modelName:
type: string
version:
type: string
framework:
type: string
enum: [vllm, tgi, triton, tensorrt-llm]
modelPath:
type: string
replicas:
type: integer
minimum: 0
default: 1
resources:
type: object
properties:
gpuType:
type: string
gpuCount:
type: integer
gpuMemory:
type: string
config:
type: object
properties:
maxSeqLen:
type: integer
maxBatchSize:
type: integer
tensorParallel:
type: integer
quantization:
type: string
dtype:
type: string
autoscaling:
type: object
properties:
enabled:
type: boolean
minReplicas:
type: integer
maxReplicas:
type: integer
metrics:
type: array
items:
type: object
properties:
name:
type: string
target:
type: number
status:
type: object
properties:
phase:
type: string
replicas:
type: integer
readyReplicas:
type: integer
conditions:
type: array
items:
type: object
properties:
type:
type: string
status:
type: string
message:
type: string
lastUpdateTime:
type: string
subresources:
status: {}
scale:
specReplicasPath: .spec.replicas
statusReplicasPath: .status.replicas
scope: Namespaced
names:
plural: inferencemodels
singular: inferencemodel
kind: InferenceModel
shortNames:
- im
---
# 示例资源
apiVersion: ai.example.com/v1
kind: InferenceModel
metadata:
name: llama-70b
namespace: ai-inference
spec:
modelName: meta-llama/Llama-2-70b-chat-hf
version: "1.0"
framework: vllm
modelPath: s3://models/llama-70b
replicas: 2
resources:
gpuType: nvidia.com/a100-80g
gpuCount: 4
gpuMemory: "320Gi"
config:
maxSeqLen: 4096
maxBatchSize: 64
tensorParallel: 4
quantization: "none"
dtype: "fp16"
autoscaling:
enabled: true
minReplicas: 2
maxReplicas: 8
metrics:
- name: qps
target: 100
- name: latency
target: 500
---
# 多模型服务编排
apiVersion: ai.example.com/v1
kind: ModelOrchestrator
metadata:
name: multi-model-service
namespace: ai-inference
spec:
models:
- name: llama-70b
ref: llama-70b
weight: 50
capabilities:
- chat
- completion
- name: llama-7b
ref: llama-7b
weight: 30
capabilities:
- chat
- name: codellama
ref: codellama-34b
weight: 20
capabilities:
- code
routing:
defaultModel: llama-7b
rules:
- name: premium-users
priority: 100
conditions:
- field: user_tier
operator: eq
value: premium
target: llama-70b
- name: code-requests
priority: 90
conditions:
- field: capability
operator: contains
value: code
target: codellama
- name: long-context
priority: 80
conditions:
- field: prompt_length
operator: gt
value: 2000
target: llama-70b
fallbackChain:
- llama-70b
- llama-7b
loadBalancing:
strategy: least_connections
healthCheck:
path: /health
interval: 10s
timeout: 5s
failureThreshold: 3
gateway:
replicas: 3
resources:
cpu: "2"
memory: "4Gi"
Operator 控制器
package controller
import (
    "context"
    "fmt"
    "time"

    appsv1 "k8s.io/api/apps/v1"
    autoscalingv2 "k8s.io/api/autoscaling/v2"
    corev1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/errors"
    "k8s.io/apimachinery/pkg/api/resource"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
    "k8s.io/apimachinery/pkg/util/intstr"
    ctrl "sigs.k8s.io/controller-runtime"
    "sigs.k8s.io/controller-runtime/pkg/client"
    "sigs.k8s.io/controller-runtime/pkg/log"
)
// InferenceModelReconciler 推理模型控制器
type InferenceModelReconciler struct {
client.Client
Scheme *runtime.Scheme
}
// InferenceModel CRD 结构
type InferenceModel struct {
metav1.TypeMeta `json:",inline"`
metav1.ObjectMeta `json:"metadata,omitempty"`
Spec InferenceModelSpec `json:"spec,omitempty"`
Status InferenceModelStatus `json:"status,omitempty"`
}
type InferenceModelSpec struct {
ModelName string `json:"modelName"`
Version string `json:"version,omitempty"`
Framework string `json:"framework"`
ModelPath string `json:"modelPath"`
Replicas int32 `json:"replicas"`
Resources ResourceSpec `json:"resources"`
Config ModelConfig `json:"config,omitempty"`
Autoscaling *AutoscalingSpec `json:"autoscaling,omitempty"`
}
type ResourceSpec struct {
GPUType string `json:"gpuType"`
GPUCount int `json:"gpuCount"`
GPUMemory string `json:"gpuMemory"`
}
type AutoscalingSpec struct {
Enabled bool `json:"enabled"`
MinReplicas int32 `json:"minReplicas"`
MaxReplicas int32 `json:"maxReplicas"`
Metrics []Metric `json:"metrics"`
}
type Metric struct {
Name string `json:"name"`
Target float64 `json:"target"`
}
type InferenceModelStatus struct {
Phase string `json:"phase"`
Replicas int32 `json:"replicas"`
ReadyReplicas int32 `json:"readyReplicas"`
Conditions []Condition `json:"conditions,omitempty"`
}
type Condition struct {
Type string `json:"type"`
Status string `json:"status"`
Message string `json:"message"`
LastUpdateTime string `json:"lastUpdateTime"`
}
// Reconcile 协调函数
func (r *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
logger := log.FromContext(ctx)
// 获取 InferenceModel 资源
var model InferenceModel
if err := r.Get(ctx, req.NamespacedName, &model); err != nil {
if errors.IsNotFound(err) {
return ctrl.Result{}, nil
}
return ctrl.Result{}, err
}
logger.Info("Reconciling InferenceModel", "name", model.Name)
// 确保 Deployment 存在
if err := r.ensureDeployment(ctx, &model); err != nil {
logger.Error(err, "Failed to ensure Deployment")
return ctrl.Result{RequeueAfter: time.Minute}, err
}
// 确保 Service 存在
if err := r.ensureService(ctx, &model); err != nil {
logger.Error(err, "Failed to ensure Service")
return ctrl.Result{RequeueAfter: time.Minute}, err
}
// 确保 HPA 存在(如果启用了自动扩缩容)
if model.Spec.Autoscaling != nil && model.Spec.Autoscaling.Enabled {
if err := r.ensureHPA(ctx, &model); err != nil {
logger.Error(err, "Failed to ensure HPA")
return ctrl.Result{RequeueAfter: time.Minute}, err
}
}
// 更新状态
if err := r.updateStatus(ctx, &model); err != nil {
logger.Error(err, "Failed to update status")
return ctrl.Result{RequeueAfter: time.Minute}, err
}
return ctrl.Result{RequeueAfter: 30 * time.Second}, nil
}
func (r *InferenceModelReconciler) ensureDeployment(ctx context.Context, model *InferenceModel) error {
deployment := r.buildDeployment(model)
var existing appsv1.Deployment
err := r.Get(ctx, client.ObjectKeyFromObject(deployment), &existing)
if errors.IsNotFound(err) {
return r.Create(ctx, deployment)
}
if err != nil {
return err
}
// 更新 Deployment
existing.Spec = deployment.Spec
return r.Update(ctx, &existing)
}
func (r *InferenceModelReconciler) buildDeployment(model *InferenceModel) *appsv1.Deployment {
labels := map[string]string{
"app": model.Name,
"model": model.Spec.ModelName,
}
// 构建容器
container := corev1.Container{
Name: "inference",
Image: r.getFrameworkImage(model.Spec.Framework),
Ports: []corev1.ContainerPort{
{ContainerPort: 8000, Name: "http"},
},
Env: r.buildEnvVars(model),
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
"nvidia.com/gpu": resource.MustParse(fmt.Sprintf("%d", model.Spec.Resources.GPUCount)),
},
},
VolumeMounts: []corev1.VolumeMount{
{Name: "model-cache", MountPath: "/models"},
{Name: "shm", MountPath: "/dev/shm"},
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromInt(8000),
},
},
InitialDelaySeconds: 30,
PeriodSeconds: 10,
},
}
return &appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Name: model.Name,
Namespace: model.Namespace,
Labels: labels,
},
Spec: appsv1.DeploymentSpec{
Replicas: &model.Spec.Replicas,
Selector: &metav1.LabelSelector{
MatchLabels: labels,
},
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: labels,
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{container},
Volumes: []corev1.Volume{
{
Name: "model-cache",
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: model.Name + "-models",
},
},
},
{
Name: "shm",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: corev1.StorageMediumMemory,
SizeLimit: resource.NewQuantity(32*1024*1024*1024, resource.BinarySI),
},
},
},
},
NodeSelector: map[string]string{
"nvidia.com/gpu.product": model.Spec.Resources.GPUType,
},
Tolerations: []corev1.Toleration{
{
Key: "nvidia.com/gpu",
Operator: corev1.TolerationOpExists,
Effect: corev1.TaintEffectNoSchedule,
},
},
},
},
},
}
}
func (r *InferenceModelReconciler) buildEnvVars(model *InferenceModel) []corev1.EnvVar {
envs := []corev1.EnvVar{
{Name: "MODEL_NAME", Value: model.Spec.ModelName},
{Name: "MODEL_PATH", Value: model.Spec.ModelPath},
{Name: "MAX_SEQ_LEN", Value: fmt.Sprintf("%d", model.Spec.Config.MaxSeqLen)},
{Name: "MAX_BATCH_SIZE", Value: fmt.Sprintf("%d", model.Spec.Config.MaxBatchSize)},
{Name: "TENSOR_PARALLEL_SIZE", Value: fmt.Sprintf("%d", model.Spec.Config.TensorParallel)},
}
if model.Spec.Config.Quantization != "" {
envs = append(envs, corev1.EnvVar{
Name: "QUANTIZATION", Value: model.Spec.Config.Quantization,
})
}
return envs
}
func (r *InferenceModelReconciler) getFrameworkImage(framework string) string {
images := map[string]string{
"vllm": "vllm/vllm-openai:latest",
"tgi": "ghcr.io/huggingface/text-generation-inference:latest",
"triton": "nvcr.io/nvidia/tritonserver:23.12-trtllm-python-py3",
"tensorrt-llm": "nvcr.io/nvidia/tensorrt-llm:24.01",
}
return images[framework]
}
func (r *InferenceModelReconciler) ensureService(ctx context.Context, model *InferenceModel) error {
service := &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: model.Name,
Namespace: model.Namespace,
},
Spec: corev1.ServiceSpec{
Selector: map[string]string{
"app": model.Name,
},
Ports: []corev1.ServicePort{
{
Name: "http",
Port: 80,
TargetPort: intstr.FromInt(8000),
},
},
},
}
var existing corev1.Service
err := r.Get(ctx, client.ObjectKeyFromObject(service), &existing)
if errors.IsNotFound(err) {
return r.Create(ctx, service)
}
return err
}
func (r *InferenceModelReconciler) ensureHPA(ctx context.Context, model *InferenceModel) error {
// 构建 HPA
hpa := &autoscalingv2.HorizontalPodAutoscaler{
ObjectMeta: metav1.ObjectMeta{
Name: model.Name,
Namespace: model.Namespace,
},
Spec: autoscalingv2.HorizontalPodAutoscalerSpec{
ScaleTargetRef: autoscalingv2.CrossVersionObjectReference{
APIVersion: "apps/v1",
Kind: "Deployment",
Name: model.Name,
},
MinReplicas: &model.Spec.Autoscaling.MinReplicas,
MaxReplicas: model.Spec.Autoscaling.MaxReplicas,
Metrics: r.buildHPAMetrics(model.Spec.Autoscaling.Metrics),
},
}
var existing autoscalingv2.HorizontalPodAutoscaler
err := r.Get(ctx, client.ObjectKeyFromObject(hpa), &existing)
if errors.IsNotFound(err) {
return r.Create(ctx, hpa)
}
if err != nil {
return err
}
existing.Spec = hpa.Spec
return r.Update(ctx, &existing)
}
func (r *InferenceModelReconciler) buildHPAMetrics(metrics []Metric) []autoscalingv2.MetricSpec {
result := make([]autoscalingv2.MetricSpec, 0, len(metrics))
for _, m := range metrics {
switch m.Name {
case "qps":
result = append(result, autoscalingv2.MetricSpec{
Type: autoscalingv2.PodsMetricSourceType,
Pods: &autoscalingv2.PodsMetricSource{
Metric: autoscalingv2.MetricIdentifier{
Name: "requests_per_second",
},
Target: autoscalingv2.MetricTarget{
Type: autoscalingv2.AverageValueMetricType,
AverageValue: resource.NewQuantity(int64(m.Target), resource.DecimalSI),
},
},
})
case "gpu_utilization":
result = append(result, autoscalingv2.MetricSpec{
Type: autoscalingv2.PodsMetricSourceType,
Pods: &autoscalingv2.PodsMetricSource{
Metric: autoscalingv2.MetricIdentifier{
Name: "gpu_utilization_percent",
},
Target: autoscalingv2.MetricTarget{
Type: autoscalingv2.AverageValueMetricType,
AverageValue: resource.NewQuantity(int64(m.Target), resource.DecimalSI),
},
},
})
}
}
return result
}
func (r *InferenceModelReconciler) updateStatus(ctx context.Context, model *InferenceModel) error {
// 获取 Deployment 状态
var deployment appsv1.Deployment
if err := r.Get(ctx, client.ObjectKey{Name: model.Name, Namespace: model.Namespace}, &deployment); err != nil {
return err
}
model.Status.Replicas = deployment.Status.Replicas
model.Status.ReadyReplicas = deployment.Status.ReadyReplicas
if deployment.Status.ReadyReplicas == *deployment.Spec.Replicas {
model.Status.Phase = "Ready"
} else if deployment.Status.ReadyReplicas > 0 {
model.Status.Phase = "PartiallyReady"
} else {
model.Status.Phase = "Pending"
}
return r.Status().Update(ctx, model)
}
// SetupWithManager 设置控制器
func (r *InferenceModelReconciler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
For(&InferenceModel{}).
Owns(&appsv1.Deployment{}).
Owns(&corev1.Service{}).
Complete(r)
}
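最后给出一个启动该控制器的 main 函数草图。这里假设 InferenceModel 的 DeepCopy 方法与 AddToScheme 已由 kubebuilder 生成(示例中以注释代替),模块路径 example.com/ai-infra 为示意值。
package main

import (
    "os"

    "k8s.io/apimachinery/pkg/runtime"
    clientgoscheme "k8s.io/client-go/kubernetes/scheme"
    ctrl "sigs.k8s.io/controller-runtime"
    "sigs.k8s.io/controller-runtime/pkg/log/zap"

    "example.com/ai-infra/controller" // 示意的模块路径
)

func main() {
    ctrl.SetLogger(zap.New())

    scheme := runtime.NewScheme()
    // 注册内置资源;InferenceModel 的 AddToScheme 假设由 kubebuilder 生成
    _ = clientgoscheme.AddToScheme(scheme)
    // _ = aiv1.AddToScheme(scheme)

    mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{Scheme: scheme})
    if err != nil {
        os.Exit(1)
    }

    if err := (&controller.InferenceModelReconciler{
        Client: mgr.GetClient(),
        Scheme: mgr.GetScheme(),
    }).SetupWithManager(mgr); err != nil {
        os.Exit(1)
    }

    if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
        os.Exit(1)
    }
}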
小结
本章深入探讨了多模型服务的核心技术:
- 架构设计:模型注册中心、请求路由器、资源管理的整体架构
- 智能路由:基于规则、能力、负载的多维度路由策略
- 资源管理:GPU 调度、装箱算法、自动扩缩容
- 模型切换:蓝绿部署、金丝雀发布、滚动更新策略
- Kubernetes 集成:CRD 设计、Operator 控制器实现
多模型服务是大规模 AI 平台的基础设施,通过合理的架构设计和调度策略,可以实现资源的高效利用和服务的稳定可靠。
至此,推理服务章节全部完成。下一章我们将进入 异构计算,探讨 GPU、TPU、FPGA 等不同计算设备的调度与管理。