Multi-Model Serving

Overview

In real production environments, you often need to deploy and manage multiple large models at the same time: LLMs of different sizes, multimodal models, specialized models, and so on. Multi-model serving brings challenges in resource scheduling, model switching, load balancing, and cost optimization. This chapter covers the architecture design and implementation of multi-model serving in depth.

Multi-Model Architecture Design

Overall Architecture

Multi-model serving architecture:

┌─────────────────────────────────────────────────────────────────┐
│                        API Gateway                               │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐            │
│  │  Auth   │  │ Router  │  │  Rate   │  │ Circuit │            │
│  │         │  │         │  │ Limiter │  │ Breaker │            │
│  └─────────┘  └─────────┘  └─────────┘  └─────────┘            │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│                     Model Orchestrator                           │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐              │
│  │   Model     │  │   Request   │  │   Resource  │              │
│  │  Registry   │  │   Router    │  │   Manager   │              │
│  └─────────────┘  └─────────────┘  └─────────────┘              │
└─────────────────────────────────────────────────────────────────┘
                              │
        ┌─────────────────────┼─────────────────────┐
        ▼                     ▼                     ▼
┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│  Model Pool  │     │  Model Pool  │     │  Model Pool  │
│ (GPT-4-class)│     │  (7B models) │     │ (Multimodal) │
│ ┌──────────┐ │     │ ┌──────────┐ │     │ ┌──────────┐ │
│ │Instance 1│ │     │ │Instance 1│ │     │ │Instance 1│ │
│ ├──────────┤ │     │ ├──────────┤ │     │ ├──────────┤ │
│ │Instance 2│ │     │ │Instance 2│ │     │ │Instance 2│ │
│ └──────────┘ │     │ └──────────┘ │     │ └──────────┘ │
└──────────────┘     └──────────────┘     └──────────────┘
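
The components of the orchestration layer (registry, router, resource manager, autoscaler) are implemented one by one in the rest of this chapter. As a preview, the following sketch shows how they could be wired together; the Orchestrator type itself is illustrative and is not part of the reference code below.

package multimodel

// Orchestrator is an illustrative wrapper that wires together the components
// implemented later in this chapter.
type Orchestrator struct {
    Registry    *ModelRegistry
    Router      *RequestRouter
    ResourceMgr *ResourceManager
    Scaler      *AutoScaler
}

// NewOrchestrator builds the orchestration layer on top of an etcd cluster.
func NewOrchestrator(etcdEndpoints []string) (*Orchestrator, error) {
    registry, err := NewModelRegistry(etcdEndpoints)
    if err != nil {
        return nil, err
    }
    rm := NewResourceManager()
    return &Orchestrator{
        Registry:    registry,
        Router:      NewRequestRouter(registry),
        ResourceMgr: rm,
        Scaler:      NewAutoScaler(registry, rm),
    }, nil
}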

Model Registry

package multimodel

import (
    "context"
    "encoding/json"
    "fmt"
    "strings"
    "sync"
    "time"

    clientv3 "go.etcd.io/etcd/client/v3"
)

// ModelRegistry 模型注册中心
type ModelRegistry struct {
    etcdClient  *clientv3.Client
    models      map[string]*ModelInfo
    instances   map[string][]*ModelInstance
    mu          sync.RWMutex
    watchCh     clientv3.WatchChan
}

// ModelInfo 模型信息
type ModelInfo struct {
    Name           string            `json:"name"`
    Version        string            `json:"version"`
    Type           ModelType         `json:"type"`
    Framework      string            `json:"framework"`      // vllm, tgi, triton
    ModelPath      string            `json:"model_path"`
    Config         *ModelConfig      `json:"config"`
    Requirements   *ResourceRequirements `json:"requirements"`
    Capabilities   []string          `json:"capabilities"`   // chat, completion, embedding
    Status         ModelStatus       `json:"status"`
    CreateTime     time.Time         `json:"create_time"`
    UpdateTime     time.Time         `json:"update_time"`
}

type ModelType string

const (
    ModelTypeLLM        ModelType = "llm"
    ModelTypeEmbedding  ModelType = "embedding"
    ModelTypeMultimodal ModelType = "multimodal"
    ModelTypeSpecialized ModelType = "specialized"
)

type ModelConfig struct {
    MaxSeqLen       int     `json:"max_seq_len"`
    MaxBatchSize    int     `json:"max_batch_size"`
    TensorParallel  int     `json:"tensor_parallel"`
    Quantization    string  `json:"quantization"`     // none, int8, int4
    DType           string  `json:"dtype"`            // fp16, bf16, fp32
}

type ResourceRequirements struct {
    GPUMemoryGB    float64 `json:"gpu_memory_gb"`
    GPUCount       int     `json:"gpu_count"`
    CPUCores       int     `json:"cpu_cores"`
    MemoryGB       float64 `json:"memory_gb"`
    GPUType        string  `json:"gpu_type"`         // A100, H100, etc.
}

type ModelStatus string

const (
    ModelStatusPending   ModelStatus = "pending"
    ModelStatusLoading   ModelStatus = "loading"
    ModelStatusReady     ModelStatus = "ready"
    ModelStatusUnloading ModelStatus = "unloading"
    ModelStatusFailed    ModelStatus = "failed"
)

// ModelInstance 模型实例
type ModelInstance struct {
    ID           string          `json:"id"`
    ModelName    string          `json:"model_name"`
    NodeID       string          `json:"node_id"`
    Endpoint     string          `json:"endpoint"`
    GPUs         []int           `json:"gpus"`
    Status       InstanceStatus  `json:"status"`
    Health       *HealthStatus   `json:"health"`
    Metrics      *InstanceMetrics `json:"metrics"`
    CreateTime   time.Time       `json:"create_time"`
    LastHeartbeat time.Time      `json:"last_heartbeat"`
}

type InstanceStatus string

const (
    InstanceStatusStarting   InstanceStatus = "starting"
    InstanceStatusRunning    InstanceStatus = "running"
    InstanceStatusDraining   InstanceStatus = "draining"
    InstanceStatusStopped    InstanceStatus = "stopped"
    InstanceStatusUnhealthy  InstanceStatus = "unhealthy"
)

type HealthStatus struct {
    Healthy     bool      `json:"healthy"`
    LastCheck   time.Time `json:"last_check"`
    FailCount   int       `json:"fail_count"`
    Message     string    `json:"message"`
}

type InstanceMetrics struct {
    RequestsTotal     int64   `json:"requests_total"`
    RequestsActive    int     `json:"requests_active"`
    TokensGenerated   int64   `json:"tokens_generated"`
    AverageLatencyMs  float64 `json:"average_latency_ms"`
    GPUUtilization    float64 `json:"gpu_utilization"`
    GPUMemoryUsed     float64 `json:"gpu_memory_used"`
}

func NewModelRegistry(etcdEndpoints []string) (*ModelRegistry, error) {
    client, err := clientv3.New(clientv3.Config{
        Endpoints:   etcdEndpoints,
        DialTimeout: 5 * time.Second,
    })
    if err != nil {
        return nil, err
    }

    registry := &ModelRegistry{
        etcdClient: client,
        models:     make(map[string]*ModelInfo),
        instances:  make(map[string][]*ModelInstance),
    }

    // 启动监听
    go registry.watchChanges()

    return registry, nil
}

// RegisterModel 注册模型
func (r *ModelRegistry) RegisterModel(ctx context.Context, model *ModelInfo) error {
    r.mu.Lock()
    defer r.mu.Unlock()

    model.CreateTime = time.Now()
    model.UpdateTime = time.Now()
    model.Status = ModelStatusPending

    data, err := json.Marshal(model)
    if err != nil {
        return err
    }

    key := fmt.Sprintf("/models/%s", model.Name)
    _, err = r.etcdClient.Put(ctx, key, string(data))
    if err != nil {
        return err
    }

    r.models[model.Name] = model
    return nil
}

// RegisterInstance 注册实例
func (r *ModelRegistry) RegisterInstance(ctx context.Context, instance *ModelInstance) error {
    r.mu.Lock()
    defer r.mu.Unlock()

    instance.CreateTime = time.Now()
    instance.LastHeartbeat = time.Now()

    data, err := json.Marshal(instance)
    if err != nil {
        return err
    }

    key := fmt.Sprintf("/instances/%s/%s", instance.ModelName, instance.ID)

    // 使用租约,实现自动过期
    lease, err := r.etcdClient.Grant(ctx, 30) // 30秒 TTL
    if err != nil {
        return err
    }

    _, err = r.etcdClient.Put(ctx, key, string(data), clientv3.WithLease(lease.ID))
    if err != nil {
        return err
    }

    // 启动续租
    go r.keepAlive(ctx, lease.ID, instance)

    r.instances[instance.ModelName] = append(r.instances[instance.ModelName], instance)
    return nil
}

// keepAlive 保持实例活跃
func (r *ModelRegistry) keepAlive(ctx context.Context, leaseID clientv3.LeaseID, instance *ModelInstance) {
    ch, err := r.etcdClient.KeepAlive(ctx, leaseID)
    if err != nil {
        return
    }

    for {
        select {
        case <-ctx.Done():
            return
        case resp, ok := <-ch:
            if !ok {
                return
            }
            if resp != nil {
                instance.LastHeartbeat = time.Now()
            }
        }
    }
}

// GetModel 获取模型信息
func (r *ModelRegistry) GetModel(name string) *ModelInfo {
    r.mu.RLock()
    defer r.mu.RUnlock()
    return r.models[name]
}

// GetInstances 获取模型实例
func (r *ModelRegistry) GetInstances(modelName string) []*ModelInstance {
    r.mu.RLock()
    defer r.mu.RUnlock()

    instances := r.instances[modelName]
    result := make([]*ModelInstance, 0)

    for _, inst := range instances {
        if inst.Status == InstanceStatusRunning {
            result = append(result, inst)
        }
    }

    return result
}

// GetHealthyInstances 获取健康实例
func (r *ModelRegistry) GetHealthyInstances(modelName string) []*ModelInstance {
    instances := r.GetInstances(modelName)
    result := make([]*ModelInstance, 0)

    for _, inst := range instances {
        if inst.Health != nil && inst.Health.Healthy {
            result = append(result, inst)
        }
    }

    return result
}

// ListModels 列出所有模型
func (r *ModelRegistry) ListModels() []*ModelInfo {
    r.mu.RLock()
    defer r.mu.RUnlock()

    result := make([]*ModelInfo, 0, len(r.models))
    for _, model := range r.models {
        result = append(result, model)
    }
    return result
}

// watchChanges 监听变更
func (r *ModelRegistry) watchChanges() {
    r.watchCh = r.etcdClient.Watch(context.Background(), "/", clientv3.WithPrefix())

    for resp := range r.watchCh {
        for _, event := range resp.Events {
            r.handleEvent(event)
        }
    }
}

func (r *ModelRegistry) handleEvent(event *clientv3.Event) {
    r.mu.Lock()
    defer r.mu.Unlock()

    key := string(event.Kv.Key)

    switch event.Type {
    case clientv3.EventTypePut:
        // 更新或新增
        if len(key) > 8 && key[:8] == "/models/" {
            var model ModelInfo
            if err := json.Unmarshal(event.Kv.Value, &model); err == nil {
                r.models[model.Name] = &model
            }
        } else if len(key) > 11 && key[:11] == "/instances/" {
            var instance ModelInstance
            if err := json.Unmarshal(event.Kv.Value, &instance); err == nil {
                r.updateInstance(&instance)
            }
        }

    case clientv3.EventTypeDelete:
        // 删除
        if len(key) > 11 && key[:11] == "/instances/" {
            r.removeInstance(key)
        }
    }
}

func (r *ModelRegistry) updateInstance(instance *ModelInstance) {
    instances := r.instances[instance.ModelName]
    for i, inst := range instances {
        if inst.ID == instance.ID {
            r.instances[instance.ModelName][i] = instance
            return
        }
    }
    r.instances[instance.ModelName] = append(instances, instance)
}

func (r *ModelRegistry) removeInstance(key string) {
    // Key format: /instances/{modelName}/{instanceID}; caller holds r.mu.
    parts := strings.Split(key, "/")
    if len(parts) != 4 {
        return
    }
    modelName, instanceID := parts[2], parts[3]
    insts := r.instances[modelName]
    for i, inst := range insts {
        if inst.ID == instanceID {
            r.instances[modelName] = append(insts[:i], insts[i+1:]...)
            return
        }
    }
}
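
Below is a minimal usage sketch of the registry. It assumes a local etcd endpoint at 127.0.0.1:2379; the model name, model path, endpoint, and resource numbers are placeholders.

package multimodel

import (
    "context"
    "log"
)

// demoRegistry registers a 7B chat model and one serving instance for it.
func demoRegistry() {
    registry, err := NewModelRegistry([]string{"127.0.0.1:2379"})
    if err != nil {
        log.Fatal(err)
    }

    ctx := context.Background()

    model := &ModelInfo{
        Name:      "llama-7b",
        Version:   "1.0",
        Type:      ModelTypeLLM,
        Framework: "vllm",
        ModelPath: "s3://models/llama-7b",
        Config: &ModelConfig{
            MaxSeqLen:      4096,
            MaxBatchSize:   32,
            TensorParallel: 1,
            DType:          "fp16",
        },
        Requirements: &ResourceRequirements{
            GPUMemoryGB: 24,
            GPUCount:    1,
            CPUCores:    8,
            MemoryGB:    64,
            GPUType:     "A100",
        },
        Capabilities: []string{"chat", "completion"},
    }
    if err := registry.RegisterModel(ctx, model); err != nil {
        log.Fatal(err)
    }

    // One running instance of the model on node-1, bound to GPU 0.
    // Status is set to running so that GetInstances returns it.
    instance := &ModelInstance{
        ID:        "llama-7b-0",
        ModelName: "llama-7b",
        NodeID:    "node-1",
        Endpoint:  "http://10.0.0.1:8000",
        GPUs:      []int{0},
        Status:    InstanceStatusRunning,
    }
    if err := registry.RegisterInstance(ctx, instance); err != nil {
        log.Fatal(err)
    }
}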

Request Router

package multimodel

import (
    "context"
    "errors"
    "fmt"
    "hash/fnv"
    "math/rand"
    "sort"
    "sync"
)

// RequestRouter routes inference requests to model instances.
type RequestRouter struct {
    registry      *ModelRegistry
    loadBalancers map[string]*LoadBalancer
    lbMu          sync.Mutex // guards loadBalancers, which is written lazily while r.mu is only read-locked
    routingRules  []*RoutingRule
    fallbackChain []string
    mu            sync.RWMutex
}

// RoutingRule 路由规则
type RoutingRule struct {
    Name        string              `json:"name"`
    Priority    int                 `json:"priority"`
    Conditions  []RoutingCondition  `json:"conditions"`
    Target      string              `json:"target"`       // 目标模型
    Weight      int                 `json:"weight"`       // 权重(用于流量分配)
}

type RoutingCondition struct {
    Field    string      `json:"field"`    // model, capability, prompt_length, user_tier
    Operator string      `json:"operator"` // eq, ne, gt, lt, in, contains
    Value    interface{} `json:"value"`
}

// InferenceRequest 推理请求
type InferenceRequest struct {
    ID           string                 `json:"id"`
    Model        string                 `json:"model"`       // 指定模型,可为空
    Messages     []Message              `json:"messages"`
    MaxTokens    int                    `json:"max_tokens"`
    Temperature  float32                `json:"temperature"`
    Stream       bool                   `json:"stream"`
    UserTier     string                 `json:"user_tier"`   // free, pro, enterprise
    Capabilities []string               `json:"capabilities"` // 需要的能力
    Metadata     map[string]interface{} `json:"metadata"`
}

type Message struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

func NewRequestRouter(registry *ModelRegistry) *RequestRouter {
    return &RequestRouter{
        registry:      registry,
        loadBalancers: make(map[string]*LoadBalancer),
        routingRules:  make([]*RoutingRule, 0),
        fallbackChain: []string{},
    }
}

// Route 路由请求
func (r *RequestRouter) Route(ctx context.Context, req *InferenceRequest) (*ModelInstance, error) {
    r.mu.RLock()
    defer r.mu.RUnlock()

    // 1. 如果指定了模型,直接路由
    if req.Model != "" {
        return r.routeToModel(ctx, req.Model, req)
    }

    // 2. 应用路由规则
    targetModel := r.applyRoutingRules(req)
    if targetModel != "" {
        instance, err := r.routeToModel(ctx, targetModel, req)
        if err == nil {
            return instance, nil
        }
        // 路由失败,尝试 fallback
    }

    // 3. 基于能力路由
    if len(req.Capabilities) > 0 {
        instance, err := r.routeByCapability(ctx, req)
        if err == nil {
            return instance, nil
        }
    }

    // 4. Fallback 链
    for _, modelName := range r.fallbackChain {
        instance, err := r.routeToModel(ctx, modelName, req)
        if err == nil {
            return instance, nil
        }
    }

    return nil, errors.New("no available model instance")
}

// routeToModel 路由到指定模型
func (r *RequestRouter) routeToModel(
    ctx context.Context,
    modelName string,
    req *InferenceRequest,
) (*ModelInstance, error) {

    model := r.registry.GetModel(modelName)
    if model == nil {
        return nil, fmt.Errorf("model %s not found", modelName)
    }

    if model.Status != ModelStatusReady {
        return nil, fmt.Errorf("model %s not ready", modelName)
    }

    instances := r.registry.GetHealthyInstances(modelName)
    if len(instances) == 0 {
        return nil, fmt.Errorf("no healthy instances for model %s", modelName)
    }

    // 获取或创建负载均衡器
    lb := r.getLoadBalancer(modelName)
    return lb.Select(instances, req)
}

// applyRoutingRules 应用路由规则
func (r *RequestRouter) applyRoutingRules(req *InferenceRequest) string {
    // 按优先级排序
    rules := make([]*RoutingRule, len(r.routingRules))
    copy(rules, r.routingRules)
    sort.Slice(rules, func(i, j int) bool {
        return rules[i].Priority > rules[j].Priority
    })

    for _, rule := range rules {
        if r.matchRule(rule, req) {
            return rule.Target
        }
    }

    return ""
}

// matchRule 匹配规则
func (r *RequestRouter) matchRule(rule *RoutingRule, req *InferenceRequest) bool {
    for _, cond := range rule.Conditions {
        if !r.matchCondition(cond, req) {
            return false
        }
    }
    return true
}

func (r *RequestRouter) matchCondition(cond RoutingCondition, req *InferenceRequest) bool {
    var value interface{}

    switch cond.Field {
    case "model":
        value = req.Model
    case "user_tier":
        value = req.UserTier
    case "prompt_length":
        value = r.calculatePromptLength(req)
    case "capability":
        value = req.Capabilities
    default:
        if req.Metadata != nil {
            value = req.Metadata[cond.Field]
        }
    }

    return r.compareValues(cond.Operator, value, cond.Value)
}

func (r *RequestRouter) compareValues(operator string, actual, expected interface{}) bool {
    switch operator {
    case "eq":
        return actual == expected
    case "ne":
        return actual != expected
    case "gt":
        if a, ok := actual.(int); ok {
            if e, ok := expected.(int); ok {
                return a > e
            }
        }
    case "lt":
        if a, ok := actual.(int); ok {
            if e, ok := expected.(int); ok {
                return a < e
            }
        }
    case "in":
        if expected, ok := expected.([]interface{}); ok {
            for _, e := range expected {
                if actual == e {
                    return true
                }
            }
        }
    case "contains":
        if actual, ok := actual.([]string); ok {
            if expected, ok := expected.(string); ok {
                for _, a := range actual {
                    if a == expected {
                        return true
                    }
                }
            }
        }
    }
    return false
}

// routeByCapability 按能力路由
func (r *RequestRouter) routeByCapability(
    ctx context.Context,
    req *InferenceRequest,
) (*ModelInstance, error) {

    models := r.registry.ListModels()

    // 找到满足所有能力的模型
    candidateModels := make([]*ModelInfo, 0)
    for _, model := range models {
        if model.Status != ModelStatusReady {
            continue
        }
        if r.hasAllCapabilities(model.Capabilities, req.Capabilities) {
            candidateModels = append(candidateModels, model)
        }
    }

    if len(candidateModels) == 0 {
        return nil, errors.New("no model with required capabilities")
    }

    // 选择最合适的模型(简单策略:选择资源需求最小的)
    sort.Slice(candidateModels, func(i, j int) bool {
        return candidateModels[i].Requirements.GPUMemoryGB <
               candidateModels[j].Requirements.GPUMemoryGB
    })

    return r.routeToModel(ctx, candidateModels[0].Name, req)
}

func (r *RequestRouter) hasAllCapabilities(modelCaps, reqCaps []string) bool {
    capSet := make(map[string]bool)
    for _, c := range modelCaps {
        capSet[c] = true
    }

    for _, c := range reqCaps {
        if !capSet[c] {
            return false
        }
    }
    return true
}

func (r *RequestRouter) calculatePromptLength(req *InferenceRequest) int {
    total := 0
    for _, msg := range req.Messages {
        total += len(msg.Content)
    }
    return total
}

func (r *RequestRouter) getLoadBalancer(modelName string) *LoadBalancer {
    // Route only holds r.mu for reading, so lazy creation of load balancers
    // needs its own lock to avoid a concurrent map write.
    r.lbMu.Lock()
    defer r.lbMu.Unlock()

    if lb, exists := r.loadBalancers[modelName]; exists {
        return lb
    }

    lb := NewLoadBalancer(LoadBalancerLeastConnections)
    r.loadBalancers[modelName] = lb
    return lb
}

// AddRoutingRule 添加路由规则
func (r *RequestRouter) AddRoutingRule(rule *RoutingRule) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.routingRules = append(r.routingRules, rule)
}

// SetFallbackChain 设置 fallback 链
func (r *RequestRouter) SetFallbackChain(models []string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.fallbackChain = models
}

// LoadBalancer selects an instance according to the configured strategy.
type LoadBalancer struct {
    strategy    LoadBalancerStrategy
    connections map[string]int
    rrCounter   uint64
    mu          sync.RWMutex
}

type LoadBalancerStrategy int

const (
    LoadBalancerRoundRobin LoadBalancerStrategy = iota
    LoadBalancerLeastConnections
    LoadBalancerWeightedRandom
    LoadBalancerConsistentHash
)

func NewLoadBalancer(strategy LoadBalancerStrategy) *LoadBalancer {
    return &LoadBalancer{
        strategy:    strategy,
        connections: make(map[string]int),
    }
}

// Select 选择实例
func (lb *LoadBalancer) Select(
    instances []*ModelInstance,
    req *InferenceRequest,
) (*ModelInstance, error) {

    if len(instances) == 0 {
        return nil, errors.New("no instances available")
    }

    switch lb.strategy {
    case LoadBalancerRoundRobin:
        return lb.selectRoundRobin(instances)
    case LoadBalancerLeastConnections:
        return lb.selectLeastConnections(instances)
    case LoadBalancerWeightedRandom:
        return lb.selectWeightedRandom(instances)
    case LoadBalancerConsistentHash:
        return lb.selectConsistentHash(instances, req)
    default:
        return instances[0], nil
    }
}

func (lb *LoadBalancer) selectLeastConnections(instances []*ModelInstance) (*ModelInstance, error) {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    var selected *ModelInstance
    minConn := int(^uint(0) >> 1)

    for _, inst := range instances {
        conn := lb.connections[inst.ID]
        if inst.Metrics != nil {
            conn += inst.Metrics.RequestsActive
        }

        if conn < minConn {
            minConn = conn
            selected = inst
        }
    }

    return selected, nil
}

func (lb *LoadBalancer) selectWeightedRandom(instances []*ModelInstance) (*ModelInstance, error) {
    // 基于 GPU 利用率的权重
    totalWeight := 0.0
    weights := make([]float64, len(instances))

    for i, inst := range instances {
        // 利用率越低,权重越高
        utilization := 0.5
        if inst.Metrics != nil {
            utilization = inst.Metrics.GPUUtilization
        }
        weights[i] = 1.0 - utilization
        totalWeight += weights[i]
    }

    // 随机选择
    r := randomFloat() * totalWeight
    cumWeight := 0.0

    for i, w := range weights {
        cumWeight += w
        if r <= cumWeight {
            return instances[i], nil
        }
    }

    return instances[len(instances)-1], nil
}

func (lb *LoadBalancer) selectConsistentHash(
    instances []*ModelInstance,
    req *InferenceRequest,
) (*ModelInstance, error) {

    // 基于请求 ID 的一致性哈希
    h := fnv.New32a()
    h.Write([]byte(req.ID))
    hash := h.Sum32()

    idx := int(hash % uint32(len(instances)))
    return instances[idx], nil
}

func (lb *LoadBalancer) selectRoundRobin(instances []*ModelInstance) (*ModelInstance, error) {
    lb.mu.Lock()
    defer lb.mu.Unlock()
    // Rotate through instances with a monotonically increasing counter.
    idx := int(lb.rrCounter % uint64(len(instances)))
    lb.rrCounter++
    return instances[idx], nil
}

// IncrementConnections 增加连接数
func (lb *LoadBalancer) IncrementConnections(instanceID string) {
    lb.mu.Lock()
    defer lb.mu.Unlock()
    lb.connections[instanceID]++
}

// DecrementConnections 减少连接数
func (lb *LoadBalancer) DecrementConnections(instanceID string) {
    lb.mu.Lock()
    defer lb.mu.Unlock()
    if lb.connections[instanceID] > 0 {
        lb.connections[instanceID]--
    }
}

func randomFloat() float64 {
    // Uniform value in [0, 1) used by weighted random selection.
    return rand.Float64()
}
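
A usage sketch for the router follows. It assumes the registry already contains models named llama-70b and llama-7b (placeholders, matching the Kubernetes example later in this chapter).

package multimodel

import (
    "context"
    "log"
)

// demoRouting shows routing rules, a fallback chain, and routing one request.
func demoRouting(registry *ModelRegistry) {
    router := NewRequestRouter(registry)

    // Premium users are routed to the large model.
    router.AddRoutingRule(&RoutingRule{
        Name:     "premium-users",
        Priority: 100,
        Conditions: []RoutingCondition{
            {Field: "user_tier", Operator: "eq", Value: "premium"},
        },
        Target: "llama-70b",
    })

    // If rule- and capability-based routing both fail, fall back in order.
    router.SetFallbackChain([]string{"llama-70b", "llama-7b"})

    req := &InferenceRequest{
        ID:        "req-123",
        Messages:  []Message{{Role: "user", Content: "Hello"}},
        MaxTokens: 256,
        UserTier:  "premium",
    }

    instance, err := router.Route(context.Background(), req)
    if err != nil {
        log.Printf("no instance available: %v", err)
        return
    }
    log.Printf("routed to instance %s at %s", instance.ID, instance.Endpoint)
}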

Resource Management and Scheduling

GPU Resource Manager

package multimodel

import (
    "context"
    "errors"
    "fmt"
    "sort"
    "sync"
    "time"
)

// ResourceManager 资源管理器
type ResourceManager struct {
    nodes       map[string]*NodeInfo
    allocations map[string]*ResourceAllocation
    scheduler   *ModelScheduler
    mu          sync.RWMutex
}

// NodeInfo 节点信息
type NodeInfo struct {
    ID           string       `json:"id"`
    Hostname     string       `json:"hostname"`
    GPUs         []*GPUInfo   `json:"gpus"`
    CPUCores     int          `json:"cpu_cores"`
    MemoryGB     float64      `json:"memory_gb"`
    Labels       map[string]string `json:"labels"`
    Status       NodeStatus   `json:"status"`
    LastHeartbeat time.Time   `json:"last_heartbeat"`
}

type NodeStatus string

const (
    NodeStatusReady      NodeStatus = "ready"
    NodeStatusNotReady   NodeStatus = "not_ready"
    NodeStatusDraining   NodeStatus = "draining"
)

type GPUInfo struct {
    Index        int     `json:"index"`
    UUID         string  `json:"uuid"`
    Name         string  `json:"name"`
    MemoryTotal  float64 `json:"memory_total_gb"`
    MemoryUsed   float64 `json:"memory_used_gb"`
    Utilization  float64 `json:"utilization"`
    Allocated    bool    `json:"allocated"`
    AllocatedTo  string  `json:"allocated_to"` // 模型实例 ID
}

// ResourceAllocation 资源分配
type ResourceAllocation struct {
    ID          string    `json:"id"`
    ModelName   string    `json:"model_name"`
    InstanceID  string    `json:"instance_id"`
    NodeID      string    `json:"node_id"`
    GPUs        []int     `json:"gpus"`
    CPUCores    int       `json:"cpu_cores"`
    MemoryGB    float64   `json:"memory_gb"`
    CreateTime  time.Time `json:"create_time"`
}

func NewResourceManager() *ResourceManager {
    rm := &ResourceManager{
        nodes:       make(map[string]*NodeInfo),
        allocations: make(map[string]*ResourceAllocation),
        scheduler:   NewModelScheduler(),
    }
    return rm
}

// RegisterNode 注册节点
func (rm *ResourceManager) RegisterNode(node *NodeInfo) error {
    rm.mu.Lock()
    defer rm.mu.Unlock()

    node.Status = NodeStatusReady
    node.LastHeartbeat = time.Now()
    rm.nodes[node.ID] = node

    return nil
}

// UpdateNodeHeartbeat 更新节点心跳
func (rm *ResourceManager) UpdateNodeHeartbeat(nodeID string, gpuMetrics []*GPUMetrics) error {
    rm.mu.Lock()
    defer rm.mu.Unlock()

    node, exists := rm.nodes[nodeID]
    if !exists {
        return errors.New("node not found")
    }

    node.LastHeartbeat = time.Now()

    // 更新 GPU 指标
    for _, metrics := range gpuMetrics {
        if metrics.Index < len(node.GPUs) {
            node.GPUs[metrics.Index].MemoryUsed = metrics.MemoryUsed
            node.GPUs[metrics.Index].Utilization = metrics.Utilization
        }
    }

    return nil
}

type GPUMetrics struct {
    Index       int     `json:"index"`
    MemoryUsed  float64 `json:"memory_used_gb"`
    Utilization float64 `json:"utilization"`
}

// AllocateResources 分配资源
func (rm *ResourceManager) AllocateResources(
    ctx context.Context,
    model *ModelInfo,
    instanceID string,
) (*ResourceAllocation, error) {

    rm.mu.Lock()
    defer rm.mu.Unlock()

    // 找到合适的节点
    node, gpuIndices, err := rm.scheduler.Schedule(rm.nodes, model.Requirements)
    if err != nil {
        return nil, err
    }

    // 标记 GPU 为已分配
    for _, idx := range gpuIndices {
        node.GPUs[idx].Allocated = true
        node.GPUs[idx].AllocatedTo = instanceID
    }

    allocation := &ResourceAllocation{
        ID:         fmt.Sprintf("alloc-%s-%d", instanceID, time.Now().UnixNano()),
        ModelName:  model.Name,
        InstanceID: instanceID,
        NodeID:     node.ID,
        GPUs:       gpuIndices,
        CPUCores:   model.Requirements.CPUCores,
        MemoryGB:   model.Requirements.MemoryGB,
        CreateTime: time.Now(),
    }

    rm.allocations[allocation.ID] = allocation
    return allocation, nil
}

// ReleaseResources 释放资源
func (rm *ResourceManager) ReleaseResources(allocationID string) error {
    rm.mu.Lock()
    defer rm.mu.Unlock()

    allocation, exists := rm.allocations[allocationID]
    if !exists {
        return errors.New("allocation not found")
    }

    node := rm.nodes[allocation.NodeID]
    if node != nil {
        for _, idx := range allocation.GPUs {
            if idx < len(node.GPUs) {
                node.GPUs[idx].Allocated = false
                node.GPUs[idx].AllocatedTo = ""
            }
        }
    }

    delete(rm.allocations, allocationID)
    return nil
}

// GetNodeStatus 获取节点状态
func (rm *ResourceManager) GetNodeStatus() []*NodeInfo {
    rm.mu.RLock()
    defer rm.mu.RUnlock()

    result := make([]*NodeInfo, 0, len(rm.nodes))
    for _, node := range rm.nodes {
        result = append(result, node)
    }
    return result
}

// GetAvailableGPUs 获取可用 GPU 数量
func (rm *ResourceManager) GetAvailableGPUs() int {
    rm.mu.RLock()
    defer rm.mu.RUnlock()

    count := 0
    for _, node := range rm.nodes {
        if node.Status != NodeStatusReady {
            continue
        }
        for _, gpu := range node.GPUs {
            if !gpu.Allocated {
                count++
            }
        }
    }
    return count
}

// ModelScheduler 模型调度器
type ModelScheduler struct {
    strategy SchedulingStrategy
}

type SchedulingStrategy int

const (
    StrategyBinPacking SchedulingStrategy = iota
    StrategySpread
    StrategyGPUAffinity
)

func NewModelScheduler() *ModelScheduler {
    return &ModelScheduler{
        strategy: StrategyBinPacking,
    }
}

// Schedule 调度模型到节点
func (s *ModelScheduler) Schedule(
    nodes map[string]*NodeInfo,
    requirements *ResourceRequirements,
) (*NodeInfo, []int, error) {

    candidates := s.findCandidateNodes(nodes, requirements)
    if len(candidates) == 0 {
        return nil, nil, errors.New("no suitable node found")
    }

    switch s.strategy {
    case StrategyBinPacking:
        return s.binPackingSchedule(candidates, requirements)
    case StrategySpread:
        return s.spreadSchedule(candidates, requirements)
    case StrategyGPUAffinity:
        return s.gpuAffinitySchedule(candidates, requirements)
    default:
        return s.binPackingSchedule(candidates, requirements)
    }
}

func (s *ModelScheduler) findCandidateNodes(
    nodes map[string]*NodeInfo,
    req *ResourceRequirements,
) []*NodeInfo {

    candidates := make([]*NodeInfo, 0)

    for _, node := range nodes {
        if node.Status != NodeStatusReady {
            continue
        }

        // 检查 GPU 数量
        availableGPUs := s.countAvailableGPUs(node, req.GPUType)
        if availableGPUs < req.GPUCount {
            continue
        }

        // 检查连续 GPU(对于张量并行)
        if req.GPUCount > 1 && !s.hasContiguousGPUs(node, req.GPUCount) {
            continue
        }

        // 检查 CPU 和内存
        if node.CPUCores < req.CPUCores || node.MemoryGB < req.MemoryGB {
            continue
        }

        candidates = append(candidates, node)
    }

    return candidates
}

func (s *ModelScheduler) countAvailableGPUs(node *NodeInfo, gpuType string) int {
    count := 0
    for _, gpu := range node.GPUs {
        if gpu.Allocated {
            continue
        }
        if gpuType != "" && gpu.Name != gpuType {
            continue
        }
        count++
    }
    return count
}

func (s *ModelScheduler) hasContiguousGPUs(node *NodeInfo, count int) bool {
    consecutive := 0
    for _, gpu := range node.GPUs {
        if !gpu.Allocated {
            consecutive++
            if consecutive >= count {
                return true
            }
        } else {
            consecutive = 0
        }
    }
    return false
}

// binPackingSchedule 装箱调度(优先填满节点)
func (s *ModelScheduler) binPackingSchedule(
    candidates []*NodeInfo,
    req *ResourceRequirements,
) (*NodeInfo, []int, error) {

    // 按可用 GPU 数量升序排序(优先选择 GPU 少的节点)
    sort.Slice(candidates, func(i, j int) bool {
        availI := s.countAvailableGPUs(candidates[i], req.GPUType)
        availJ := s.countAvailableGPUs(candidates[j], req.GPUType)
        return availI < availJ
    })

    for _, node := range candidates {
        gpuIndices := s.selectGPUs(node, req.GPUCount, req.GPUType)
        if len(gpuIndices) == req.GPUCount {
            return node, gpuIndices, nil
        }
    }

    return nil, nil, errors.New("scheduling failed")
}

// spreadSchedule 分散调度(均匀分布)
func (s *ModelScheduler) spreadSchedule(
    candidates []*NodeInfo,
    req *ResourceRequirements,
) (*NodeInfo, []int, error) {

    // 按可用 GPU 数量降序排序(优先选择 GPU 多的节点)
    sort.Slice(candidates, func(i, j int) bool {
        availI := s.countAvailableGPUs(candidates[i], req.GPUType)
        availJ := s.countAvailableGPUs(candidates[j], req.GPUType)
        return availI > availJ
    })

    for _, node := range candidates {
        gpuIndices := s.selectGPUs(node, req.GPUCount, req.GPUType)
        if len(gpuIndices) == req.GPUCount {
            return node, gpuIndices, nil
        }
    }

    return nil, nil, errors.New("scheduling failed")
}

// gpuAffinitySchedule GPU 亲和性调度
func (s *ModelScheduler) gpuAffinitySchedule(
    candidates []*NodeInfo,
    req *ResourceRequirements,
) (*NodeInfo, []int, error) {

    // 优先选择相邻 GPU(NVLink 连接)
    for _, node := range candidates {
        gpuIndices := s.selectContiguousGPUs(node, req.GPUCount, req.GPUType)
        if len(gpuIndices) == req.GPUCount {
            return node, gpuIndices, nil
        }
    }

    // 退化为普通调度
    return s.binPackingSchedule(candidates, req)
}

func (s *ModelScheduler) selectGPUs(node *NodeInfo, count int, gpuType string) []int {
    indices := make([]int, 0, count)

    for i, gpu := range node.GPUs {
        if gpu.Allocated {
            continue
        }
        if gpuType != "" && gpu.Name != gpuType {
            continue
        }
        indices = append(indices, i)
        if len(indices) == count {
            break
        }
    }

    return indices
}

func (s *ModelScheduler) selectContiguousGPUs(node *NodeInfo, count int, gpuType string) []int {
    start := -1
    consecutive := 0

    for i, gpu := range node.GPUs {
        if gpu.Allocated || (gpuType != "" && gpu.Name != gpuType) {
            start = -1
            consecutive = 0
            continue
        }

        if start == -1 {
            start = i
        }
        consecutive++

        if consecutive == count {
            indices := make([]int, count)
            for j := 0; j < count; j++ {
                indices[j] = start + j
            }
            return indices
        }
    }

    return nil
}
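
A usage sketch for the resource manager and scheduler. It assumes a single node with eight A100 GPUs and a model whose Requirements field is populated as in the registry sketch above; all node details are placeholders.

package multimodel

import (
    "context"
    "log"
)

// demoAllocation registers a GPU node and allocates resources for one instance.
func demoAllocation(model *ModelInfo) {
    rm := NewResourceManager()

    // A node with 8 A100-80G GPUs.
    gpus := make([]*GPUInfo, 8)
    for i := range gpus {
        gpus[i] = &GPUInfo{Index: i, Name: "A100", MemoryTotal: 80}
    }
    if err := rm.RegisterNode(&NodeInfo{
        ID:       "node-1",
        Hostname: "gpu-node-1",
        GPUs:     gpus,
        CPUCores: 128,
        MemoryGB: 1024,
    }); err != nil {
        log.Fatal(err)
    }

    alloc, err := rm.AllocateResources(context.Background(), model, "llama-7b-0")
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("allocated GPUs %v on node %s", alloc.GPUs, alloc.NodeID)

    // Release the allocation when the instance is torn down.
    _ = rm.ReleaseResources(alloc.ID)
}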

Autoscaling

package multimodel

import (
    "context"
    "errors"
    "fmt"
    "math"
    "sort"
    "sync"
    "time"
)

// AutoScaler 自动扩缩容器
type AutoScaler struct {
    registry       *ModelRegistry
    resourceMgr    *ResourceManager
    metrics        *MetricsCollector
    scalingPolicies map[string]*ScalingPolicy
    mu             sync.RWMutex
    stopCh         chan struct{}
}

// ScalingPolicy 扩缩容策略
type ScalingPolicy struct {
    ModelName        string            `json:"model_name"`
    MinReplicas      int               `json:"min_replicas"`
    MaxReplicas      int               `json:"max_replicas"`
    TargetMetrics    []TargetMetric    `json:"target_metrics"`
    ScaleUpCooldown  time.Duration     `json:"scale_up_cooldown"`
    ScaleDownCooldown time.Duration    `json:"scale_down_cooldown"`
    ScaleUpStep      int               `json:"scale_up_step"`
    ScaleDownStep    int               `json:"scale_down_step"`

    // 内部状态
    lastScaleUp      time.Time
    lastScaleDown    time.Time
    currentReplicas  int
}

type TargetMetric struct {
    Name       string  `json:"name"`   // qps, latency, gpu_utilization, queue_length
    Target     float64 `json:"target"` // 目标值
    Tolerance  float64 `json:"tolerance"` // 容忍度
}

// MetricsCollector 指标收集器
type MetricsCollector struct {
    registry *ModelRegistry
}

type ModelMetrics struct {
    ModelName       string
    QPS             float64
    AverageLatency  time.Duration
    P99Latency      time.Duration
    GPUUtilization  float64
    QueueLength     int
    ActiveRequests  int
    InstanceCount   int
}

func NewAutoScaler(
    registry *ModelRegistry,
    resourceMgr *ResourceManager,
) *AutoScaler {
    return &AutoScaler{
        registry:        registry,
        resourceMgr:     resourceMgr,
        metrics:         &MetricsCollector{registry: registry},
        scalingPolicies: make(map[string]*ScalingPolicy),
        stopCh:          make(chan struct{}),
    }
}

// SetPolicy 设置扩缩容策略
func (as *AutoScaler) SetPolicy(policy *ScalingPolicy) {
    as.mu.Lock()
    defer as.mu.Unlock()
    as.scalingPolicies[policy.ModelName] = policy
}

// Start 启动自动扩缩容
func (as *AutoScaler) Start(ctx context.Context) {
    ticker := time.NewTicker(15 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-as.stopCh:
            return
        case <-ticker.C:
            as.evaluate()
        }
    }
}

// Stop 停止
func (as *AutoScaler) Stop() {
    close(as.stopCh)
}

// evaluate 评估扩缩容
func (as *AutoScaler) evaluate() {
    as.mu.RLock()
    policies := make([]*ScalingPolicy, 0, len(as.scalingPolicies))
    for _, p := range as.scalingPolicies {
        policies = append(policies, p)
    }
    as.mu.RUnlock()

    for _, policy := range policies {
        as.evaluateModel(policy)
    }
}

func (as *AutoScaler) evaluateModel(policy *ScalingPolicy) {
    metrics := as.metrics.Collect(policy.ModelName)
    if metrics == nil {
        return
    }

    policy.currentReplicas = metrics.InstanceCount

    // 计算期望副本数
    desiredReplicas := as.calculateDesiredReplicas(policy, metrics)

    // 应用边界
    desiredReplicas = max(desiredReplicas, policy.MinReplicas)
    desiredReplicas = min(desiredReplicas, policy.MaxReplicas)

    // 检查是否需要扩缩容
    if desiredReplicas > policy.currentReplicas {
        as.scaleUp(policy, desiredReplicas)
    } else if desiredReplicas < policy.currentReplicas {
        as.scaleDown(policy, desiredReplicas)
    }
}

func (as *AutoScaler) calculateDesiredReplicas(
    policy *ScalingPolicy,
    metrics *ModelMetrics,
) int {

    maxRatio := 1.0

    for _, target := range policy.TargetMetrics {
        var currentValue float64

        switch target.Name {
        case "qps":
            currentValue = metrics.QPS
        case "latency":
            currentValue = float64(metrics.AverageLatency.Milliseconds())
        case "gpu_utilization":
            currentValue = metrics.GPUUtilization * 100
        case "queue_length":
            currentValue = float64(metrics.QueueLength)
        default:
            continue
        }

        if target.Target > 0 {
            ratio := currentValue / target.Target
            if ratio > maxRatio {
                maxRatio = ratio
            }
        }
    }

    // 计算期望副本数
    desired := int(math.Ceil(float64(policy.currentReplicas) * maxRatio))
    return desired
}

func (as *AutoScaler) scaleUp(policy *ScalingPolicy, desired int) {
    // 检查冷却时间
    if time.Since(policy.lastScaleUp) < policy.ScaleUpCooldown {
        return
    }

    // 计算实际扩容数量
    scaleCount := min(desired-policy.currentReplicas, policy.ScaleUpStep)
    if scaleCount <= 0 {
        return
    }

    // 执行扩容
    for i := 0; i < scaleCount; i++ {
        err := as.createInstance(policy.ModelName)
        if err != nil {
            break
        }
    }

    policy.lastScaleUp = time.Now()
}

func (as *AutoScaler) scaleDown(policy *ScalingPolicy, desired int) {
    // 检查冷却时间
    if time.Since(policy.lastScaleDown) < policy.ScaleDownCooldown {
        return
    }

    // 检查是否低于最小副本数
    if policy.currentReplicas <= policy.MinReplicas {
        return
    }

    // 计算实际缩容数量
    scaleCount := min(policy.currentReplicas-desired, policy.ScaleDownStep)
    scaleCount = min(scaleCount, policy.currentReplicas-policy.MinReplicas)
    if scaleCount <= 0 {
        return
    }

    // 执行缩容(优先删除负载最低的实例)
    instances := as.registry.GetInstances(policy.ModelName)
    sort.Slice(instances, func(i, j int) bool {
        loadI := 0
        loadJ := 0
        if instances[i].Metrics != nil {
            loadI = instances[i].Metrics.RequestsActive
        }
        if instances[j].Metrics != nil {
            loadJ = instances[j].Metrics.RequestsActive
        }
        return loadI < loadJ
    })

    for i := 0; i < scaleCount && i < len(instances); i++ {
        as.deleteInstance(instances[i])
    }

    policy.lastScaleDown = time.Now()
}

func (as *AutoScaler) createInstance(modelName string) error {
    model := as.registry.GetModel(modelName)
    if model == nil {
        return errors.New("model not found")
    }

    instanceID := fmt.Sprintf("%s-%d", modelName, time.Now().UnixNano())

    // 分配资源
    allocation, err := as.resourceMgr.AllocateResources(
        context.Background(), model, instanceID,
    )
    if err != nil {
        return err
    }

    // 启动实例(这里简化处理,实际需要调用 Kubernetes API 或容器运行时)
    instance := &ModelInstance{
        ID:        instanceID,
        ModelName: modelName,
        NodeID:    allocation.NodeID,
        GPUs:      allocation.GPUs,
        Status:    InstanceStatusStarting,
    }

    return as.registry.RegisterInstance(context.Background(), instance)
}

func (as *AutoScaler) deleteInstance(instance *ModelInstance) error {
    // 1. 设置为 draining 状态
    instance.Status = InstanceStatusDraining

    // 2. 等待请求处理完成
    // 实际实现需要等待或超时

    // 3. 释放资源
    for allocID, alloc := range as.resourceMgr.allocations {
        if alloc.InstanceID == instance.ID {
            as.resourceMgr.ReleaseResources(allocID)
            break
        }
    }

    // 4. 从注册中心移除
    // ...

    return nil
}

// Collect 收集模型指标
func (mc *MetricsCollector) Collect(modelName string) *ModelMetrics {
    instances := mc.registry.GetInstances(modelName)
    if len(instances) == 0 {
        return nil
    }

    metrics := &ModelMetrics{
        ModelName:     modelName,
        InstanceCount: len(instances),
    }

    var totalQPS float64
    var totalLatency time.Duration
    var totalUtilization float64
    var totalQueue int
    var totalActive int

    validCount := 0
    for _, inst := range instances {
        if inst.Metrics == nil {
            continue
        }
        totalQPS += float64(inst.Metrics.RequestsTotal) / 60.0 // 假设 1 分钟窗口
        totalLatency += time.Duration(inst.Metrics.AverageLatencyMs) * time.Millisecond
        totalUtilization += inst.Metrics.GPUUtilization
        totalActive += inst.Metrics.RequestsActive
        validCount++
    }

    if validCount > 0 {
        metrics.QPS = totalQPS
        metrics.AverageLatency = totalLatency / time.Duration(validCount)
        metrics.GPUUtilization = totalUtilization / float64(validCount)
        metrics.ActiveRequests = totalActive
    }

    return metrics
}
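
A usage sketch for the autoscaler, using the llama-7b model from earlier examples; the target values, step sizes, and cooldowns are illustrative.

package multimodel

import (
    "context"
    "time"
)

// demoAutoscaling wires a scaling policy and starts the evaluation loop.
func demoAutoscaling(registry *ModelRegistry, rm *ResourceManager) {
    scaler := NewAutoScaler(registry, rm)

    scaler.SetPolicy(&ScalingPolicy{
        ModelName:         "llama-7b",
        MinReplicas:       1,
        MaxReplicas:       8,
        ScaleUpStep:       2,
        ScaleDownStep:     1,
        ScaleUpCooldown:   2 * time.Minute,
        ScaleDownCooldown: 10 * time.Minute,
        TargetMetrics: []TargetMetric{
            {Name: "qps", Target: 100},
            {Name: "gpu_utilization", Target: 70},
        },
    })

    // Evaluates every 15 seconds until the context is cancelled;
    // call scaler.Stop() during shutdown to end the loop.
    go scaler.Start(context.Background())
}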

Hot Model Switching

Seamless Model Switching

package multimodel

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"
)

// ModelSwitcher 模型切换器
type ModelSwitcher struct {
    registry     *ModelRegistry
    resourceMgr  *ResourceManager
    router       *RequestRouter
    mu           sync.RWMutex
}

// SwitchRequest 切换请求
type SwitchRequest struct {
    OldModel     string        `json:"old_model"`
    NewModel     string        `json:"new_model"`
    Strategy     SwitchStrategy `json:"strategy"`
    TrafficSplit float64       `json:"traffic_split"` // 新模型流量比例
    Timeout      time.Duration `json:"timeout"`
}

type SwitchStrategy string

const (
    SwitchStrategyBlueGreen  SwitchStrategy = "blue_green"
    SwitchStrategyCanary     SwitchStrategy = "canary"
    SwitchStrategyRolling    SwitchStrategy = "rolling"
)

// SwitchProgress 切换进度
type SwitchProgress struct {
    RequestID    string        `json:"request_id"`
    Status       SwitchStatus  `json:"status"`
    OldModel     string        `json:"old_model"`
    NewModel     string        `json:"new_model"`
    TrafficSplit float64       `json:"traffic_split"`
    StartTime    time.Time     `json:"start_time"`
    EndTime      *time.Time    `json:"end_time"`
    Message      string        `json:"message"`
}

type SwitchStatus string

const (
    SwitchStatusPending   SwitchStatus = "pending"
    SwitchStatusRunning   SwitchStatus = "running"
    SwitchStatusCompleted SwitchStatus = "completed"
    SwitchStatusFailed    SwitchStatus = "failed"
    SwitchStatusRolledBack SwitchStatus = "rolled_back"
)

func NewModelSwitcher(
    registry *ModelRegistry,
    resourceMgr *ResourceManager,
    router *RequestRouter,
) *ModelSwitcher {
    return &ModelSwitcher{
        registry:    registry,
        resourceMgr: resourceMgr,
        router:      router,
    }
}

// Switch 执行模型切换
func (ms *ModelSwitcher) Switch(ctx context.Context, req *SwitchRequest) (*SwitchProgress, error) {
    progress := &SwitchProgress{
        RequestID:    fmt.Sprintf("switch-%d", time.Now().UnixNano()),
        Status:       SwitchStatusPending,
        OldModel:     req.OldModel,
        NewModel:     req.NewModel,
        TrafficSplit: 0,
        StartTime:    time.Now(),
    }

    switch req.Strategy {
    case SwitchStrategyBlueGreen:
        return ms.blueGreenSwitch(ctx, req, progress)
    case SwitchStrategyCanary:
        return ms.canarySwitch(ctx, req, progress)
    case SwitchStrategyRolling:
        return ms.rollingSwitch(ctx, req, progress)
    default:
        return nil, errors.New("unknown switch strategy")
    }
}

// blueGreenSwitch 蓝绿切换
func (ms *ModelSwitcher) blueGreenSwitch(
    ctx context.Context,
    req *SwitchRequest,
    progress *SwitchProgress,
) (*SwitchProgress, error) {

    progress.Status = SwitchStatusRunning

    // 1. 确保新模型已就绪
    newModel := ms.registry.GetModel(req.NewModel)
    if newModel == nil || newModel.Status != ModelStatusReady {
        progress.Status = SwitchStatusFailed
        progress.Message = "new model not ready"
        return progress, errors.New("new model not ready")
    }

    newInstances := ms.registry.GetHealthyInstances(req.NewModel)
    if len(newInstances) == 0 {
        progress.Status = SwitchStatusFailed
        progress.Message = "no healthy instances for new model"
        return progress, errors.New("no healthy instances")
    }

    // 2. 更新路由规则,将流量切换到新模型
    ms.router.mu.Lock()
    ms.updateRoutingForSwitch(req.OldModel, req.NewModel, 1.0)
    ms.router.mu.Unlock()

    progress.TrafficSplit = 1.0

    // 3. 等待旧模型请求处理完成
    err := ms.drainOldModel(ctx, req.OldModel, req.Timeout)
    if err != nil {
        // 回滚
        ms.router.mu.Lock()
        ms.updateRoutingForSwitch(req.NewModel, req.OldModel, 1.0)
        ms.router.mu.Unlock()
        progress.Status = SwitchStatusRolledBack
        progress.Message = err.Error()
        return progress, err
    }

    // 4. 可选:释放旧模型资源
    // ms.releaseModelResources(req.OldModel)

    progress.Status = SwitchStatusCompleted
    now := time.Now()
    progress.EndTime = &now
    return progress, nil
}

// canarySwitch 金丝雀切换
func (ms *ModelSwitcher) canarySwitch(
    ctx context.Context,
    req *SwitchRequest,
    progress *SwitchProgress,
) (*SwitchProgress, error) {

    progress.Status = SwitchStatusRunning

    // 渐进式增加流量
    trafficSteps := []float64{0.1, 0.25, 0.5, 0.75, 1.0}
    if req.TrafficSplit > 0 {
        // 使用自定义步长
        trafficSteps = []float64{req.TrafficSplit}
    }

    for _, split := range trafficSteps {
        // 检查上下文是否取消
        select {
        case <-ctx.Done():
            progress.Status = SwitchStatusFailed
            return progress, ctx.Err()
        default:
        }

        // 更新流量分配
        ms.router.mu.Lock()
        ms.updateRoutingForSwitch(req.OldModel, req.NewModel, split)
        ms.router.mu.Unlock()
        progress.TrafficSplit = split

        // 监控新模型表现
        if split < 1.0 {
            healthy, err := ms.monitorNewModel(ctx, req.NewModel, 30*time.Second)
            if !healthy || err != nil {
                // 回滚
                ms.router.mu.Lock()
                ms.updateRoutingForSwitch(req.NewModel, req.OldModel, 1.0)
                ms.router.mu.Unlock()
                progress.Status = SwitchStatusRolledBack
                progress.Message = "new model unhealthy during canary"
                return progress, errors.New("canary failed")
            }
        }
    }

    progress.Status = SwitchStatusCompleted
    now := time.Now()
    progress.EndTime = &now
    return progress, nil
}

// rollingSwitch 滚动切换
func (ms *ModelSwitcher) rollingSwitch(
    ctx context.Context,
    req *SwitchRequest,
    progress *SwitchProgress,
) (*SwitchProgress, error) {

    progress.Status = SwitchStatusRunning

    oldInstances := ms.registry.GetInstances(req.OldModel)
    if len(oldInstances) == 0 {
        progress.Status = SwitchStatusFailed
        return progress, errors.New("no old instances to replace")
    }

    // 逐个替换实例
    for i, oldInst := range oldInstances {
        select {
        case <-ctx.Done():
            progress.Status = SwitchStatusFailed
            return progress, ctx.Err()
        default:
        }

        // 1. 启动新实例
        newInstID := fmt.Sprintf("%s-instance-%d", req.NewModel, i)
        err := ms.createModelInstance(ctx, req.NewModel, newInstID)
        if err != nil {
            continue
        }

        // 2. 等待新实例就绪
        err = ms.waitInstanceReady(ctx, newInstID, 5*time.Minute)
        if err != nil {
            ms.deleteModelInstance(newInstID)
            continue
        }

        // 3. 将旧实例设为 draining
        oldInst.Status = InstanceStatusDraining

        // 4. 等待旧实例请求完成
        ms.waitInstanceDrained(ctx, oldInst.ID, 2*time.Minute)

        // 5. 删除旧实例
        ms.deleteModelInstance(oldInst.ID)

        progress.TrafficSplit = float64(i+1) / float64(len(oldInstances))
    }

    progress.Status = SwitchStatusCompleted
    now := time.Now()
    progress.EndTime = &now
    return progress, nil
}

func (ms *ModelSwitcher) updateRoutingForSwitch(oldModel, newModel string, newModelWeight float64) {
    // 添加流量分配规则
    rule := &RoutingRule{
        Name:     fmt.Sprintf("switch-%s-to-%s", oldModel, newModel),
        Priority: 100, // 高优先级
        Conditions: []RoutingCondition{
            {Field: "model", Operator: "eq", Value: oldModel},
        },
        Target: newModel,
        Weight: int(newModelWeight * 100),
    }

    // 移除旧规则
    for i, r := range ms.router.routingRules {
        if r.Name == rule.Name {
            ms.router.routingRules = append(
                ms.router.routingRules[:i],
                ms.router.routingRules[i+1:]...,
            )
            break
        }
    }

    if newModelWeight > 0 {
        ms.router.routingRules = append(ms.router.routingRules, rule)
    }
}

func (ms *ModelSwitcher) drainOldModel(ctx context.Context, modelName string, timeout time.Duration) error {
    instances := ms.registry.GetInstances(modelName)

    deadline := time.Now().Add(timeout)
    for _, inst := range instances {
        inst.Status = InstanceStatusDraining
    }

    // 等待所有活跃请求完成
    for {
        if time.Now().After(deadline) {
            return errors.New("drain timeout")
        }

        allDrained := true
        for _, inst := range instances {
            if inst.Metrics != nil && inst.Metrics.RequestsActive > 0 {
                allDrained = false
                break
            }
        }

        if allDrained {
            return nil
        }

        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-time.After(time.Second):
            continue
        }
    }
}

func (ms *ModelSwitcher) monitorNewModel(ctx context.Context, modelName string, duration time.Duration) (bool, error) {
    ticker := time.NewTicker(5 * time.Second)
    defer ticker.Stop()

    deadline := time.Now().Add(duration)
    errorCount := 0

    for time.Now().Before(deadline) {
        select {
        case <-ctx.Done():
            return false, ctx.Err()
        case <-ticker.C:
            instances := ms.registry.GetHealthyInstances(modelName)
            if len(instances) == 0 {
                errorCount++
                if errorCount >= 3 {
                    return false, errors.New("no healthy instances")
                }
            } else {
                errorCount = 0
            }
        }
    }

    return true, nil
}

func (ms *ModelSwitcher) createModelInstance(ctx context.Context, modelName, instanceID string) error {
    // 实际实现
    return nil
}

func (ms *ModelSwitcher) deleteModelInstance(instanceID string) error {
    // 实际实现
    return nil
}

func (ms *ModelSwitcher) waitInstanceReady(ctx context.Context, instanceID string, timeout time.Duration) error {
    // 实际实现
    return nil
}

func (ms *ModelSwitcher) waitInstanceDrained(ctx context.Context, instanceID string, timeout time.Duration) {
    // 实际实现
}
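
A usage sketch for the switcher, performing a canary rollout between two placeholder model versions that are assumed to already be registered and Ready.

package multimodel

import (
    "context"
    "log"
    "time"
)

// demoCanarySwitch gradually shifts traffic from the old model to the new one.
func demoCanarySwitch(registry *ModelRegistry, rm *ResourceManager, router *RequestRouter) {
    switcher := NewModelSwitcher(registry, rm, router)

    progress, err := switcher.Switch(context.Background(), &SwitchRequest{
        OldModel: "llama-7b-v1",
        NewModel: "llama-7b-v2",
        Strategy: SwitchStrategyCanary,
        Timeout:  10 * time.Minute,
    })
    if err != nil {
        if progress != nil {
            log.Printf("switch failed: %v (status=%s)", err, progress.Status)
        } else {
            log.Printf("switch failed: %v", err)
        }
        return
    }
    log.Printf("switch %s completed, traffic split %.0f%%",
        progress.RequestID, progress.TrafficSplit*100)
}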

Kubernetes Integration

Multi-Model CRD

# CRD definition
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  name: inferencemodels.ai.example.com
spec:
  group: ai.example.com
  versions:
    - name: v1
      served: true
      storage: true
      schema:
        openAPIV3Schema:
          type: object
          properties:
            spec:
              type: object
              required:
                - modelName
                - framework
              properties:
                modelName:
                  type: string
                version:
                  type: string
                framework:
                  type: string
                  enum: [vllm, tgi, triton, tensorrt-llm]
                modelPath:
                  type: string
                replicas:
                  type: integer
                  minimum: 0
                  default: 1
                resources:
                  type: object
                  properties:
                    gpuType:
                      type: string
                    gpuCount:
                      type: integer
                    gpuMemory:
                      type: string
                config:
                  type: object
                  properties:
                    maxSeqLen:
                      type: integer
                    maxBatchSize:
                      type: integer
                    tensorParallel:
                      type: integer
                    quantization:
                      type: string
                    dtype:
                      type: string
                autoscaling:
                  type: object
                  properties:
                    enabled:
                      type: boolean
                    minReplicas:
                      type: integer
                    maxReplicas:
                      type: integer
                    metrics:
                      type: array
                      items:
                        type: object
                        properties:
                          name:
                            type: string
                          target:
                            type: number
            status:
              type: object
              properties:
                phase:
                  type: string
                replicas:
                  type: integer
                readyReplicas:
                  type: integer
                conditions:
                  type: array
                  items:
                    type: object
                    properties:
                      type:
                        type: string
                      status:
                        type: string
                      message:
                        type: string
                      lastUpdateTime:
                        type: string
      subresources:
        status: {}
        scale:
          specReplicasPath: .spec.replicas
          statusReplicasPath: .status.replicas
  scope: Namespaced
  names:
    plural: inferencemodels
    singular: inferencemodel
    kind: InferenceModel
    shortNames:
      - im

---
# Example resource
apiVersion: ai.example.com/v1
kind: InferenceModel
metadata:
  name: llama-70b
  namespace: ai-inference
spec:
  modelName: meta-llama/Llama-2-70b-chat-hf
  version: "1.0"
  framework: vllm
  modelPath: s3://models/llama-70b
  replicas: 2
  resources:
    gpuType: nvidia.com/a100-80g
    gpuCount: 4
    gpuMemory: "320Gi"
  config:
    maxSeqLen: 4096
    maxBatchSize: 64
    tensorParallel: 4
    quantization: "none"
    dtype: "fp16"
  autoscaling:
    enabled: true
    minReplicas: 2
    maxReplicas: 8
    metrics:
      - name: qps
        target: 100
      - name: latency
        target: 500

---
# Multi-model service orchestration
apiVersion: ai.example.com/v1
kind: ModelOrchestrator
metadata:
  name: multi-model-service
  namespace: ai-inference
spec:
  models:
    - name: llama-70b
      ref: llama-70b
      weight: 50
      capabilities:
        - chat
        - completion
    - name: llama-7b
      ref: llama-7b
      weight: 30
      capabilities:
        - chat
    - name: codellama
      ref: codellama-34b
      weight: 20
      capabilities:
        - code

  routing:
    defaultModel: llama-7b
    rules:
      - name: premium-users
        priority: 100
        conditions:
          - field: user_tier
            operator: eq
            value: premium
        target: llama-70b
      - name: code-requests
        priority: 90
        conditions:
          - field: capability
            operator: contains
            value: code
        target: codellama
      - name: long-context
        priority: 80
        conditions:
          - field: prompt_length
            operator: gt
            value: 2000
        target: llama-70b

    fallbackChain:
      - llama-70b
      - llama-7b

  loadBalancing:
    strategy: least_connections
    healthCheck:
      path: /health
      interval: 10s
      timeout: 5s
      failureThreshold: 3

  gateway:
    replicas: 3
    resources:
      cpu: "2"
      memory: "4Gi"
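
The routing block above is declarative: rules are evaluated in descending priority order, the first rule whose conditions all match decides the target model, unmatched requests go to defaultModel, and fallbackChain presumably lists where traffic fails over when the chosen model is unavailable. As a rough illustration of those semantics (a minimal sketch under assumed types, not the gateway's actual implementation), the Go snippet below evaluates hypothetical RoutingRule/Condition values that mirror the CRD fields:

package routing

import (
    "sort"
    "strings"
)

// Hypothetical Go mirrors of the routing fields in the ModelOrchestrator spec (illustration only).
type Condition struct {
    Field    string
    Operator string // eq / contains / gt
    Value    any
}

type RoutingRule struct {
    Name       string
    Priority   int
    Conditions []Condition
    Target     string
}

// SelectModel walks the rules by descending priority and returns the target of the first
// rule whose conditions all match; otherwise it falls back to defaultModel.
func SelectModel(rules []RoutingRule, request map[string]any, defaultModel string) string {
    sort.Slice(rules, func(i, j int) bool { return rules[i].Priority > rules[j].Priority })
    for _, rule := range rules {
        matched := true
        for _, cond := range rule.Conditions {
            if !matches(cond, request[cond.Field]) {
                matched = false
                break
            }
        }
        if matched {
            return rule.Target
        }
    }
    return defaultModel
}

func matches(c Condition, actual any) bool {
    switch c.Operator {
    case "eq":
        return actual == c.Value
    case "contains":
        s, ok1 := actual.(string)
        sub, ok2 := c.Value.(string)
        return ok1 && ok2 && strings.Contains(s, sub)
    case "gt":
        a, ok1 := toFloat(actual)
        b, ok2 := toFloat(c.Value)
        return ok1 && ok2 && a > b
    default:
        return false
    }
}

func toFloat(v any) (float64, bool) {
    switch n := v.(type) {
    case int:
        return float64(n), true
    case float64:
        return n, true
    default:
        return 0, false
    }
}

With the rules from the example, a request carrying user_tier=premium resolves to llama-70b before the code-request or long-context rules are even considered, because its priority (100) is the highest.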

Operator Controller

package controller

import (
    "context"
    "fmt"
    "time"

    appsv1 "k8s.io/api/apps/v1"
    autoscalingv2 "k8s.io/api/autoscaling/v2"
    corev1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/errors"
    "k8s.io/apimachinery/pkg/api/resource"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
    "k8s.io/apimachinery/pkg/util/intstr"
    ctrl "sigs.k8s.io/controller-runtime"
    "sigs.k8s.io/controller-runtime/pkg/client"
    "sigs.k8s.io/controller-runtime/pkg/log"
)

// InferenceModelReconciler reconciles InferenceModel resources.
type InferenceModelReconciler struct {
    client.Client
    Scheme *runtime.Scheme
}

// InferenceModel is the Go representation of the CRD.
type InferenceModel struct {
    metav1.TypeMeta   `json:",inline"`
    metav1.ObjectMeta `json:"metadata,omitempty"`
    Spec              InferenceModelSpec   `json:"spec,omitempty"`
    Status            InferenceModelStatus `json:"status,omitempty"`
}

type InferenceModelSpec struct {
    ModelName    string            `json:"modelName"`
    Version      string            `json:"version,omitempty"`
    Framework    string            `json:"framework"`
    ModelPath    string            `json:"modelPath"`
    Replicas     int32             `json:"replicas"`
    Resources    ResourceSpec      `json:"resources"`
    Config       ModelConfig       `json:"config,omitempty"`
    Autoscaling  *AutoscalingSpec  `json:"autoscaling,omitempty"`
}

type ResourceSpec struct {
    GPUType   string `json:"gpuType"`
    GPUCount  int    `json:"gpuCount"`
    GPUMemory string `json:"gpuMemory"`
}

// ModelConfig mirrors spec.config in the CRD schema above and is read by buildEnvVars below.
type ModelConfig struct {
    MaxSeqLen      int    `json:"maxSeqLen,omitempty"`
    MaxBatchSize   int    `json:"maxBatchSize,omitempty"`
    TensorParallel int    `json:"tensorParallel,omitempty"`
    Quantization   string `json:"quantization,omitempty"`
    Dtype          string `json:"dtype,omitempty"`
}

type AutoscalingSpec struct {
    Enabled     bool     `json:"enabled"`
    MinReplicas int32    `json:"minReplicas"`
    MaxReplicas int32    `json:"maxReplicas"`
    Metrics     []Metric `json:"metrics"`
}

type Metric struct {
    Name   string  `json:"name"`
    Target float64 `json:"target"`
}

type InferenceModelStatus struct {
    Phase         string      `json:"phase"`
    Replicas      int32       `json:"replicas"`
    ReadyReplicas int32       `json:"readyReplicas"`
    Conditions    []Condition `json:"conditions,omitempty"`
}

type Condition struct {
    Type           string `json:"type"`
    Status         string `json:"status"`
    Message        string `json:"message"`
    LastUpdateTime string `json:"lastUpdateTime"`
}

// Reconcile drives an InferenceModel toward its desired state.
func (r *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
    logger := log.FromContext(ctx)

    // Fetch the InferenceModel resource.
    var model InferenceModel
    if err := r.Get(ctx, req.NamespacedName, &model); err != nil {
        if errors.IsNotFound(err) {
            return ctrl.Result{}, nil
        }
        return ctrl.Result{}, err
    }

    logger.Info("Reconciling InferenceModel", "name", model.Name)

    // Ensure the Deployment exists.
    if err := r.ensureDeployment(ctx, &model); err != nil {
        logger.Error(err, "Failed to ensure Deployment")
        return ctrl.Result{RequeueAfter: time.Minute}, err
    }

    // Ensure the Service exists.
    if err := r.ensureService(ctx, &model); err != nil {
        logger.Error(err, "Failed to ensure Service")
        return ctrl.Result{RequeueAfter: time.Minute}, err
    }

    // Ensure the HPA exists (when autoscaling is enabled).
    if model.Spec.Autoscaling != nil && model.Spec.Autoscaling.Enabled {
        if err := r.ensureHPA(ctx, &model); err != nil {
            logger.Error(err, "Failed to ensure HPA")
            return ctrl.Result{RequeueAfter: time.Minute}, err
        }
    }

    // Update the status.
    if err := r.updateStatus(ctx, &model); err != nil {
        logger.Error(err, "Failed to update status")
        return ctrl.Result{RequeueAfter: time.Minute}, err
    }

    return ctrl.Result{RequeueAfter: 30 * time.Second}, nil
}

func (r *InferenceModelReconciler) ensureDeployment(ctx context.Context, model *InferenceModel) error {
    deployment := r.buildDeployment(model)

    // Set the owner reference so that Owns() event mapping and cascading deletion work.
    if err := ctrl.SetControllerReference(model, deployment, r.Scheme); err != nil {
        return err
    }

    var existing appsv1.Deployment
    err := r.Get(ctx, client.ObjectKeyFromObject(deployment), &existing)
    if errors.IsNotFound(err) {
        return r.Create(ctx, deployment)
    }
    if err != nil {
        return err
    }

    // Update the existing Deployment in place.
    existing.Spec = deployment.Spec
    return r.Update(ctx, &existing)
}

func (r *InferenceModelReconciler) buildDeployment(model *InferenceModel) *appsv1.Deployment {
    labels := map[string]string{
        "app":   model.Name,
        "model": model.Spec.ModelName,
    }

    // Build the inference container.
    container := corev1.Container{
        Name:  "inference",
        Image: r.getFrameworkImage(model.Spec.Framework),
        Ports: []corev1.ContainerPort{
            {ContainerPort: 8000, Name: "http"},
        },
        Env: r.buildEnvVars(model),
        Resources: corev1.ResourceRequirements{
            Limits: corev1.ResourceList{
                "nvidia.com/gpu": resource.MustParse(fmt.Sprintf("%d", model.Spec.Resources.GPUCount)),
            },
        },
        VolumeMounts: []corev1.VolumeMount{
            {Name: "model-cache", MountPath: "/models"},
            {Name: "shm", MountPath: "/dev/shm"},
        },
        ReadinessProbe: &corev1.Probe{
            ProbeHandler: corev1.ProbeHandler{
                HTTPGet: &corev1.HTTPGetAction{
                    Path: "/health",
                    Port: intstr.FromInt(8000),
                },
            },
            InitialDelaySeconds: 30,
            PeriodSeconds:       10,
        },
    }

    return &appsv1.Deployment{
        ObjectMeta: metav1.ObjectMeta{
            Name:      model.Name,
            Namespace: model.Namespace,
            Labels:    labels,
        },
        Spec: appsv1.DeploymentSpec{
            Replicas: &model.Spec.Replicas,
            Selector: &metav1.LabelSelector{
                MatchLabels: labels,
            },
            Template: corev1.PodTemplateSpec{
                ObjectMeta: metav1.ObjectMeta{
                    Labels: labels,
                },
                Spec: corev1.PodSpec{
                    Containers: []corev1.Container{container},
                    Volumes: []corev1.Volume{
                        {
                            Name: "model-cache",
                            VolumeSource: corev1.VolumeSource{
                                PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
                                    ClaimName: model.Name + "-models",
                                },
                            },
                        },
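                        // Tensor-parallel inference workers (e.g. NCCL) need a large shared-memory
                        // segment, so /dev/shm is backed by a memory-medium emptyDir (32Gi limit).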
                        {
                            Name: "shm",
                            VolumeSource: corev1.VolumeSource{
                                EmptyDir: &corev1.EmptyDirVolumeSource{
                                    Medium:    corev1.StorageMediumMemory,
                                    SizeLimit: resource.NewQuantity(32*1024*1024*1024, resource.BinarySI),
                                },
                            },
                        },
                    },
                    NodeSelector: map[string]string{
                        "nvidia.com/gpu.product": model.Spec.Resources.GPUType,
                    },
                    Tolerations: []corev1.Toleration{
                        {
                            Key:      "nvidia.com/gpu",
                            Operator: corev1.TolerationOpExists,
                            Effect:   corev1.TaintEffectNoSchedule,
                        },
                    },
                },
            },
        },
    }
}

func (r *InferenceModelReconciler) buildEnvVars(model *InferenceModel) []corev1.EnvVar {
    envs := []corev1.EnvVar{
        {Name: "MODEL_NAME", Value: model.Spec.ModelName},
        {Name: "MODEL_PATH", Value: model.Spec.ModelPath},
        {Name: "MAX_SEQ_LEN", Value: fmt.Sprintf("%d", model.Spec.Config.MaxSeqLen)},
        {Name: "MAX_BATCH_SIZE", Value: fmt.Sprintf("%d", model.Spec.Config.MaxBatchSize)},
        {Name: "TENSOR_PARALLEL_SIZE", Value: fmt.Sprintf("%d", model.Spec.Config.TensorParallel)},
    }

    if model.Spec.Config.Quantization != "" {
        envs = append(envs, corev1.EnvVar{
            Name: "QUANTIZATION", Value: model.Spec.Config.Quantization,
        })
    }

    return envs
}

func (r *InferenceModelReconciler) getFrameworkImage(framework string) string {
    images := map[string]string{
        "vllm":          "vllm/vllm-openai:latest",
        "tgi":           "ghcr.io/huggingface/text-generation-inference:latest",
        "triton":        "nvcr.io/nvidia/tritonserver:23.12-trtllm-python-py3",
        "tensorrt-llm":  "nvcr.io/nvidia/tensorrt-llm:24.01",
    }
    return images[framework]
}

func (r *InferenceModelReconciler) ensureService(ctx context.Context, model *InferenceModel) error {
    service := &corev1.Service{
        ObjectMeta: metav1.ObjectMeta{
            Name:      model.Name,
            Namespace: model.Namespace,
        },
        Spec: corev1.ServiceSpec{
            Selector: map[string]string{
                "app": model.Name,
            },
            Ports: []corev1.ServicePort{
                {
                    Name:       "http",
                    Port:       80,
                    TargetPort: intstr.FromInt(8000),
                },
            },
        },
    }

    var existing corev1.Service
    err := r.Get(ctx, client.ObjectKeyFromObject(service), &existing)
    if errors.IsNotFound(err) {
        return r.Create(ctx, service)
    }
    return err
}

func (r *InferenceModelReconciler) ensureHPA(ctx context.Context, model *InferenceModel) error {
    // Build the HPA.
    hpa := &autoscalingv2.HorizontalPodAutoscaler{
        ObjectMeta: metav1.ObjectMeta{
            Name:      model.Name,
            Namespace: model.Namespace,
        },
        Spec: autoscalingv2.HorizontalPodAutoscalerSpec{
            ScaleTargetRef: autoscalingv2.CrossVersionObjectReference{
                APIVersion: "apps/v1",
                Kind:       "Deployment",
                Name:       model.Name,
            },
            MinReplicas: &model.Spec.Autoscaling.MinReplicas,
            MaxReplicas: model.Spec.Autoscaling.MaxReplicas,
            Metrics:     r.buildHPAMetrics(model.Spec.Autoscaling.Metrics),
        },
    }

    var existing autoscalingv2.HorizontalPodAutoscaler
    err := r.Get(ctx, client.ObjectKeyFromObject(hpa), &existing)
    if errors.IsNotFound(err) {
        return r.Create(ctx, hpa)
    }
    if err != nil {
        return err
    }

    existing.Spec = hpa.Spec
    return r.Update(ctx, &existing)
}

func (r *InferenceModelReconciler) buildHPAMetrics(metrics []Metric) []autoscalingv2.MetricSpec {
    result := make([]autoscalingv2.MetricSpec, 0, len(metrics))

    for _, m := range metrics {
        switch m.Name {
        case "qps":
            result = append(result, autoscalingv2.MetricSpec{
                Type: autoscalingv2.PodsMetricSourceType,
                Pods: &autoscalingv2.PodsMetricSource{
                    Metric: autoscalingv2.MetricIdentifier{
                        Name: "requests_per_second",
                    },
                    Target: autoscalingv2.MetricTarget{
                        Type:         autoscalingv2.AverageValueMetricType,
                        AverageValue: resource.NewQuantity(int64(m.Target), resource.DecimalSI),
                    },
                },
            })
        case "gpu_utilization":
            result = append(result, autoscalingv2.MetricSpec{
                Type: autoscalingv2.PodsMetricSourceType,
                Pods: &autoscalingv2.PodsMetricSource{
                    Metric: autoscalingv2.MetricIdentifier{
                        Name: "gpu_utilization_percent",
                    },
                    Target: autoscalingv2.MetricTarget{
                        Type:         autoscalingv2.AverageValueMetricType,
                        AverageValue: resource.NewQuantity(int64(m.Target), resource.DecimalSI),
                    },
                },
            })
        default:
            // Metrics with no mapping here (for example the latency metric used in the
            // sample resource above) are silently skipped; extend this switch to cover them.
        }
    }

    return result
}

func (r *InferenceModelReconciler) updateStatus(ctx context.Context, model *InferenceModel) error {
    // Read the Deployment status.
    var deployment appsv1.Deployment
    if err := r.Get(ctx, client.ObjectKey{Name: model.Name, Namespace: model.Namespace}, &deployment); err != nil {
        return err
    }

    model.Status.Replicas = deployment.Status.Replicas
    model.Status.ReadyReplicas = deployment.Status.ReadyReplicas

    if deployment.Status.ReadyReplicas == *deployment.Spec.Replicas {
        model.Status.Phase = "Ready"
    } else if deployment.Status.ReadyReplicas > 0 {
        model.Status.Phase = "PartiallyReady"
    } else {
        model.Status.Phase = "Pending"
    }

    return r.Status().Update(ctx, model)
}

// SetupWithManager registers the controller with the manager.
func (r *InferenceModelReconciler) SetupWithManager(mgr ctrl.Manager) error {
    return ctrl.NewControllerManagedBy(mgr).
        For(&InferenceModel{}).
        Owns(&appsv1.Deployment{}).
        Owns(&corev1.Service{}).
        Complete(r)
}
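
To show where SetupWithManager fits in, below is a minimal sketch of an operator entry point built on controller-runtime. The module path and the scheme registration for the InferenceModel API types (normally generated by kubebuilder) are assumptions for illustration, not part of the operator code above:

package main

import (
    "os"

    "k8s.io/apimachinery/pkg/runtime"
    clientgoscheme "k8s.io/client-go/kubernetes/scheme"
    ctrl "sigs.k8s.io/controller-runtime"

    "example.com/inference-operator/controller" // hypothetical module path
)

func main() {
    scheme := runtime.NewScheme()
    // Register the built-in types the reconciler creates (Deployment, Service, HPA, ...).
    _ = clientgoscheme.AddToScheme(scheme)
    // The InferenceModel types would be registered here via the generated AddToScheme
    // from the API package (omitted in this sketch).

    mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{Scheme: scheme})
    if err != nil {
        os.Exit(1)
    }

    if err := (&controller.InferenceModelReconciler{
        Client: mgr.GetClient(),
        Scheme: mgr.GetScheme(),
    }).SetupWithManager(mgr); err != nil {
        os.Exit(1)
    }

    // Blocks until the process receives SIGTERM/SIGINT.
    if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
        os.Exit(1)
    }
}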

Summary

This chapter took a deep look at the core techniques behind multi-model serving:

  1. Architecture design: the overall structure of the model registry, request router, and resource manager
  2. Intelligent routing: multi-dimensional routing strategies based on rules, capabilities, and load
  3. Resource management: GPU scheduling, bin-packing, and autoscaling
  4. Model switching: blue-green deployment, canary release, and rolling-update strategies
  5. Kubernetes integration: CRD design and the Operator controller implementation

Multi-model serving is foundational infrastructure for a large-scale AI platform: with a sound architecture and scheduling strategy, it delivers efficient resource utilization and stable, reliable service.

This concludes the inference serving (推理服务) chapters. The next chapter moves on to heterogeneous computing (异构计算), covering the scheduling and management of GPUs, TPUs, FPGAs, and other compute devices.
