HiHuo
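Elastic GPU Scheduling

Overview

In AI training and inference, GPU demand changes over time. Elastic GPU scheduling lets workloads adjust their GPU allocation to actual demand, through capabilities such as autoscaling, priority scheduling, and preemptive scheduling. This chapter covers the design and implementation of elastic GPU scheduling in depth.

1. Elastic Scheduling Architecture

1.1 Overall Architecture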

┌──────────────────────────────────────────────────────────────────────────┐
│                     Elastic GPU Scheduling Architecture                   │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                           │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                        Workload Layer                               │ │
│  │                                                                     │ │
│  │  ┌───────────────┐ ┌───────────────┐ ┌───────────────┐             │ │
│  │  │  Training     │ │   Inference   │ │   Notebook    │             │ │
│  │  │  Jobs         │ │   Services    │ │   Sessions    │             │ │
│  │  │  (Batch)      │ │  (Real-time)  │ │ (Interactive) │             │ │
│  │  └───────┬───────┘ └───────┬───────┘ └───────┬───────┘             │ │
│  └──────────┼─────────────────┼─────────────────┼───────────────────────┘ │
│             │                 │                 │                         │
│             ▼                 ▼                 ▼                         │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                    Elastic Scheduler Layer                          │ │
│  │                                                                     │ │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌────────────┐ │ │
│  │  │  Priority   │  │  Preemption │  │   Quota     │  │   Queue    │ │ │
│  │  │  Manager    │  │   Manager   │  │  Manager    │  │  Manager   │ │ │
│  │  └─────────────┘  └─────────────┘  └─────────────┘  └────────────┘ │ │
│  │                                                                     │ │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌────────────┐ │ │
│  │  │   Elastic   │  │   Gang      │  │   Fair      │  │  Capacity  │ │ │
│  │  │   Scaling   │  │ Scheduling  │  │   Share     │  │  Planning  │ │ │
│  │  └─────────────┘  └─────────────┘  └─────────────┘  └────────────┘ │ │
│  └─────────────────────────────────────────────────────────────────────┘ │
│                                      │                                    │
│                                      ▼                                    │
│  ┌─────────────────────────────────────────────────────────────────────┐ │
│  │                      Resource Layer                                 │ │
│  │                                                                     │ │
│  │  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────────┐ │ │
│  │  │   GPU Pool      │  │  Node Manager   │  │   Cluster Autoscaler│ │ │
│  │  │                 │  │                 │  │                     │ │ │
│  │  │ ┌───┐ ┌───┐     │  │ Node 1: 8 GPU   │  │ Scale Up/Down       │ │ │
│  │  │ │GPU│ │GPU│...  │  │ Node 2: 8 GPU   │  │ based on demand     │ │ │
│  │  │ └───┘ └───┘     │  │ Node 3: 4 GPU   │  │                     │ │ │
│  │  └─────────────────┘  └─────────────────┘  └─────────────────────┘ │ │
│  └─────────────────────────────────────────────────────────────────────┘ │
│                                                                           │
└──────────────────────────────────────────────────────────────────────────┘

1.2 Core Concepts

// pkg/elastic/types.go
package elastic

import (
    "time"

    v1 "k8s.io/api/core/v1"
)

// PriorityClass is the priority level of a GPU workload
type PriorityClass string

const (
    // PriorityProduction: production workloads, never preempted
    PriorityProduction PriorityClass = "production"
    // PriorityHighPriority: high priority, may preempt lower-priority workloads
    PriorityHighPriority PriorityClass = "high-priority"
    // PriorityNormal: normal priority
    PriorityNormal PriorityClass = "normal"
    // PriorityBestEffort: best effort, may be preempted
    PriorityBestEffort PriorityClass = "best-effort"
    // PrioritySpot: spot workloads, may be preempted at any time
    PrioritySpot PriorityClass = "spot"
)

// PriorityValue maps each class to a numeric priority (higher wins)
var PriorityValue = map[PriorityClass]int32{
    PriorityProduction:   1000000,
    PriorityHighPriority: 100000,
    PriorityNormal:       10000,
    PriorityBestEffort:   1000,
    PrioritySpot:         100,
}

// ElasticJob defines an elastic job
type ElasticJob struct {
    // Basic information
    Name      string
    Namespace string

    // GPU requirements
    MinGPUs int32  // minimum number of GPUs
    MaxGPUs int32  // maximum number of GPUs
    GPUType string // required GPU model

    // Priority
    Priority PriorityClass

    // Queue name
    Queue string

    // Elasticity settings
    ElasticConfig *ElasticConfig

    // Scheduling constraints
    Constraints *SchedulingConstraints

    // Status
    Status *ElasticJobStatus
}

// ElasticConfig holds the elasticity settings of a job
type ElasticConfig struct {
    // Whether elastic scaling is enabled
    Enabled bool
    // Scale-up policy
    ScaleUpPolicy *ScalePolicy
    // Scale-down policy
    ScaleDownPolicy *ScalePolicy
    // Cooldown period between scaling actions
    CooldownPeriod time.Duration
    // Whether the job may be preempted
    Preemptible bool
    // How to recover after preemption
    PreemptionRecovery PreemptionRecovery
}

// ScalePolicy describes a scale-up or scale-down policy
type ScalePolicy struct {
    // Policy type
    Type ScalePolicyType
    // Step size (for the Step type)
    Step int32
    // Percentage (for the Percent type)
    Percent int32
    // Stabilization window
    StabilizationWindow time.Duration
}

// ScalePolicyType is the kind of scaling policy
type ScalePolicyType string

const (
    ScalePolicyStep    ScalePolicyType = "Step"
    ScalePolicyPercent ScalePolicyType = "Percent"
)

// PreemptionRecovery is the recovery strategy after preemption
type PreemptionRecovery string

const (
    // RecoveryRestart: restart the job from scratch
    RecoveryRestart PreemptionRecovery = "Restart"
    // RecoveryResume: resume from a checkpoint
    RecoveryResume PreemptionRecovery = "Resume"
    // RecoveryGraceful: degrade gracefully (keep running with fewer GPUs)
    RecoveryGraceful PreemptionRecovery = "Graceful"
    // RecoveryFail: mark the job as failed
    RecoveryFail PreemptionRecovery = "Fail"
)

// SchedulingConstraints holds the scheduling constraints of a job
type SchedulingConstraints struct {
    // Topology constraint
    TopologyConstraint string
    // Node selector
    NodeSelector map[string]string
    // Affinity
    Affinity *v1.Affinity
    // Tolerations
    Tolerations []v1.Toleration
    // Maximum time to wait before being scheduled
    MaxWaitTime time.Duration
}

// ElasticJobStatus is the observed state of a job
type ElasticJobStatus struct {
    // Current phase
    Phase JobPhase
    // Number of allocated GPUs
    AllocatedGPUs int32
    // Number of running Pods
    RunningPods int32
    // Number of pending Pods
    PendingPods int32
    // Start time
    StartTime *time.Time
    // Completion time
    CompletionTime *time.Time
    // Time of the last scaling action
    LastScaleTime *time.Time
    // Number of times the job has been preempted
    PreemptionCount int32
    // Conditions
    Conditions []JobCondition
}

// JobPhase is the lifecycle phase of a job
type JobPhase string

const (
    JobPhasePending   JobPhase = "Pending"
    JobPhaseRunning   JobPhase = "Running"
    JobPhaseScaling   JobPhase = "Scaling"
    JobPhasePreempted JobPhase = "Preempted"
    JobPhaseSucceeded JobPhase = "Succeeded"
    JobPhaseFailed    JobPhase = "Failed"
)

// JobCondition describes one condition of a job
type JobCondition struct {
    Type               string
    Status             v1.ConditionStatus
    LastTransitionTime time.Time
    Reason             string
    Message            string
}

2. Priority Scheduling Implementation

2.1 Priority Queue

// pkg/elastic/priority_queue.go
package elastic

import (
    "container/heap"
    "sync"
    "time"
)

// PriorityQueue is a thread-safe priority queue of elastic jobs
type PriorityQueue struct {
    mu      sync.RWMutex
    items   priorityHeap
    itemSet map[string]*QueueItem // keyed by "namespace/name"
}

// QueueItem is one entry in the queue
type QueueItem struct {
    Job         *ElasticJob
    EnqueueTime time.Time
    Priority    int32
    Index       int // position in the heap, maintained by Swap/Push/Pop
}

// priorityHeap implements heap.Interface
type priorityHeap []*QueueItem

func (h priorityHeap) Len() int { return len(h) }

func (h priorityHeap) Less(i, j int) bool {
    // Higher priority first
    if h[i].Priority != h[j].Priority {
        return h[i].Priority > h[j].Priority
    }
    // Same priority: earlier enqueue time first (FIFO)
    return h[i].EnqueueTime.Before(h[j].EnqueueTime)
}

func (h priorityHeap) Swap(i, j int) {
    h[i], h[j] = h[j], h[i]
    h[i].Index = i
    h[j].Index = j
}

func (h *priorityHeap) Push(x interface{}) {
    item := x.(*QueueItem)
    item.Index = len(*h)
    *h = append(*h, item)
}

func (h *priorityHeap) Pop() interface{} {
    old := *h
    n := len(old)
    item := old[n-1]
    old[n-1] = nil
    item.Index = -1
    *h = old[0 : n-1]
    return item
}

// NewPriorityQueue creates an empty priority queue
func NewPriorityQueue() *PriorityQueue {
    return &PriorityQueue{
        items:   make(priorityHeap, 0),
        itemSet: make(map[string]*QueueItem),
    }
}

// Enqueue adds a job, or updates it in place if it is already queued
func (q *PriorityQueue) Enqueue(job *ElasticJob) {
    q.mu.Lock()
    defer q.mu.Unlock()

    key := job.Namespace + "/" + job.Name

    // Already queued: update it and restore heap order (enqueue time is kept)
    if existing, ok := q.itemSet[key]; ok {
        existing.Job = job
        existing.Priority = PriorityValue[job.Priority]
        heap.Fix(&q.items, existing.Index)
        return
    }

    // New entry
    item := &QueueItem{
        Job:         job,
        EnqueueTime: time.Now(),
        Priority:    PriorityValue[job.Priority],
    }
    heap.Push(&q.items, item)
    q.itemSet[key] = item
}

// Dequeue removes and returns the highest-priority job, or nil if empty
func (q *PriorityQueue) Dequeue() *ElasticJob {
    q.mu.Lock()
    defer q.mu.Unlock()

    if q.items.Len() == 0 {
        return nil
    }

    item := heap.Pop(&q.items).(*QueueItem)
    delete(q.itemSet, item.Job.Namespace+"/"+item.Job.Name)
    return item.Job
}

// Peek returns the head job without removing it, or nil if empty
func (q *PriorityQueue) Peek() *ElasticJob {
    q.mu.RLock()
    defer q.mu.RUnlock()

    if q.items.Len() == 0 {
        return nil
    }
    return q.items[0].Job
}

// Remove deletes a job by namespace and name and returns it, or nil if absent
func (q *PriorityQueue) Remove(namespace, name string) *ElasticJob {
    q.mu.Lock()
    defer q.mu.Unlock()

    key := namespace + "/" + name
    item, ok := q.itemSet[key]
    if !ok {
        return nil
    }

    heap.Remove(&q.items, item.Index)
    delete(q.itemSet, key)
    return item.Job
}

// Len returns the number of queued jobs
func (q *PriorityQueue) Len() int {
    q.mu.RLock()
    defer q.mu.RUnlock()
    return q.items.Len()
}

// GetJobsAbovePriority returns queued jobs whose priority is strictly higher than the given class
func (q *PriorityQueue) GetJobsAbovePriority(priority PriorityClass) []*ElasticJob {
    q.mu.RLock()
    defer q.mu.RUnlock()

    threshold := PriorityValue[priority]
    var jobs []*ElasticJob

    for _, item := range q.items {
        if item.Priority > threshold {
            jobs = append(jobs, item.Job)
        }
    }

    return jobs
}

2.2 Multi-Queue Management

// pkg/elastic/queue_manager.go
package elastic

import (
    "fmt"
    "sync"
)

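// QueueConfig configures one queue
type QueueConfig struct {
    Name string
    // Weight (used for fair scheduling)
    Weight int32
    // Maximum GPU quota
    MaxGPUs int32
    // Guaranteed minimum GPUs
    MinGPUs int32
    // Allowed priority classes (empty means all)
    AllowedPriorities []PriorityClass
    // Whether preemption is allowed
    AllowPreemption bool
    // Whether the queue may borrow resources beyond its quota
    AllowBorrowing bool
}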

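// QueueManager manages multiple job queues
type QueueManager struct {
    mu     sync.RWMutex
    queues map[string]*Queue
    // Queue used when a job does not name one
    defaultQueue string
}

// Queue is a single queue with its own priority queue and GPU accounting
type Queue struct {
    Config        QueueConfig
    PriorityQ     *PriorityQueue
    AllocatedGPUs int32
    UsedGPUs      int32
}

// NewQueueManager creates a queue manager from the given configs. Jobs that do
// not name a queue go to "default", so configs should include a queue with that name
func NewQueueManager(configs []QueueConfig) *QueueManager {
    qm := &QueueManager{
        queues:       make(map[string]*Queue),
        defaultQueue: "default",
    }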

    for _, config := range configs {
        qm.queues[config.Name] = &Queue{
            Config:    config,
            PriorityQ: NewPriorityQueue(),
        }
    }

    return qm
}

// Submit adds a job to its queue
func (qm *QueueManager) Submit(job *ElasticJob) error {
    qm.mu.Lock()
    defer qm.mu.Unlock()

    queueName := job.Queue
    if queueName == "" {
        queueName = qm.defaultQueue
    }

    queue, ok := qm.queues[queueName]
    if !ok {
        return fmt.Errorf("queue %s not found", queueName)
    }

    // Reject priorities the queue does not allow
    if !qm.isPriorityAllowed(queue, job.Priority) {
        return fmt.Errorf("priority %s not allowed in queue %s",
            job.Priority, queueName)
    }

    queue.PriorityQ.Enqueue(job)
    return nil
}

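// SelectNext picks the next job to schedule across all queues
func (qm *QueueManager) SelectNext(availableGPUs int32) *ElasticJob {
    qm.mu.Lock()
    defer qm.mu.Unlock()

    // Weighted fair queuing: pick the queue head with the best score
    var selectedJob *ElasticJob
    var selectedQueue *Queue
    bestScore := float64(-1)

    for _, queue := range qm.queues {
        job := queue.PriorityQ.Peek()
        if job == nil {
            continue
        }

        // Quota check: skip if admitting the job would exceed the queue's
        // MaxGPUs, unless the queue is allowed to borrow
        if queue.UsedGPUs+job.MinGPUs > queue.Config.MaxGPUs &&
            !queue.Config.AllowBorrowing {
            continue
        }

        // Skip if the cluster cannot satisfy even the job's minimum
        if job.MinGPUs > availableGPUs {
            continue
        }

        // Compute the fair-share score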
        score := qm.calculateFairScore(queue, job)
        if score > bestScore {
            bestScore = score
            selectedJob = job
            selectedQueue = queue
        }
    }

    if selectedJob != nil {
        selectedQueue.PriorityQ.Dequeue()
    }

    return selectedJob
}

// calculateFairScore computes the fair-share score of a queue's head job
func (qm *QueueManager) calculateFairScore(queue *Queue, job *ElasticJob) float64 {
    // Base score = weight / (used GPUs + 1)
    baseScore := float64(queue.Config.Weight) / float64(queue.UsedGPUs+1)

    // Priority bonus: from 0.0001 (spot) to 1.0 (production)
    priorityBonus := float64(PriorityValue[job.Priority]) / 1000000.0

    // A waiting-time bonus could be added here (not implemented); it would need
    // the enqueue time, which is stored on QueueItem rather than on ElasticJob

    return baseScore + priorityBonus
}

// isPriorityAllowed reports whether the queue accepts the given priority
func (qm *QueueManager) isPriorityAllowed(queue *Queue, priority PriorityClass) bool {
    if len(queue.Config.AllowedPriorities) == 0 {
        return true
    }
    for _, p := range queue.Config.AllowedPriorities {
        if p == priority {
            return true
        }
    }
    return false
}

// GetQueueStatus returns a snapshot of every queue
func (qm *QueueManager) GetQueueStatus() map[string]*QueueStatus {
    qm.mu.RLock()
    defer qm.mu.RUnlock()

    status := make(map[string]*QueueStatus)
    for name, queue := range qm.queues {
        status[name] = &QueueStatus{
            Name:           name,
            PendingJobs:    queue.PriorityQ.Len(),
            AllocatedGPUs:  queue.AllocatedGPUs,
            UsedGPUs:       queue.UsedGPUs,
            MaxGPUs:        queue.Config.MaxGPUs,
            MinGPUs:        queue.Config.MinGPUs,
            Weight:         queue.Config.Weight,
        }
    }
    return status
}

// QueueStatus is a point-in-time view of one queue
type QueueStatus struct {
    Name          string
    PendingJobs   int
    AllocatedGPUs int32
    UsedGPUs      int32
    MaxGPUs       int32
    MinGPUs       int32
    Weight        int32
}
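A small worked example makes the scale of `calculateFairScore` concrete. The sketch below reproduces its formula, `weight/(usedGPUs+1) + priorityValue/1e6`, as a standalone function. The queue weights, usage figures, and the `fairScore` name are made up for illustration.

```go
package main

import "fmt"

// Priority values copied from PriorityValue above.
const (
    production = 1000000
    normal     = 10000
)

// fairScore reproduces calculateFairScore: weight/(usedGPUs+1) plus a
// priority bonus of priorityValue/1e6 (1.0 for production, 0.01 for normal).
func fairScore(weight, usedGPUs, priorityValue int32) float64 {
    base := float64(weight) / float64(usedGPUs+1)
    return base + float64(priorityValue)/1000000.0
}

func main() {
    // Queue A: weight 3, already using 5 GPUs, head job has normal priority.
    a := fairScore(3, 5, normal) // 3/6 + 0.01 = 0.51
    // Queue B: weight 1, idle, head job has normal priority.
    b := fairScore(1, 0, normal) // 1/1 + 0.01 = 1.01
    // Queue A again, but its head job has production priority.
    aProd := fairScore(3, 5, production) // 0.5 + 1.0 = 1.5
    fmt.Printf("%.2f %.2f %.2f\n", a, b, aProd) // prints "0.51 1.01 1.50"
}
```

An idle low-weight queue beats a busy high-weight one, which is the intended fairness. However, the production bonus (1.0) is on the same order as the base score, so priority can override weight-based fairness. Whether that is acceptable depends on how strictly queues must be isolated.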

3. Preemptive Scheduling

3.1 Preemption Manager

// pkg/elastic/preemption.go
package elastic

import (
    "context"
    "fmt"
    "sort"
    "time"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/client-go/kubernetes"
)

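// PreemptionManager selects and evicts lower-priority GPU Pods
type PreemptionManager struct {
    kubeClient   kubernetes.Interface
    stateManager *GPUStateManager
    // Victim selection strategy
    strategy PreemptionStrategy
    // Graceful termination period for evicted Pods
    gracePeriod time.Duration
    // Minimum runtime protection: younger Pods are never preempted
    minRuntime time.Duration
}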

// PreemptionStrategy selects how victims are chosen
type PreemptionStrategy string

const (
    // StrategyLowestPriority: evict the lowest-priority Pods first
    StrategyLowestPriority PreemptionStrategy = "LowestPriority"
    // StrategyNewestFirst: evict the most recently started Pods first
    StrategyNewestFirst PreemptionStrategy = "NewestFirst"
    // StrategyMinimumVictims: evict as few Pods as possible
    StrategyMinimumVictims PreemptionStrategy = "MinimumVictims"
    // StrategyMinimumGPUWaste: free as few surplus GPUs as possible
    StrategyMinimumGPUWaste PreemptionStrategy = "MinimumGPUWaste"
)

// PreemptionCandidate is a running Pod that could be evicted
type PreemptionCandidate struct {
    Pod       *v1.Pod
    Job       *ElasticJob
    GPUs      int32
    Priority  int32
    StartTime time.Time
    NodeName  string
}

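// PreemptionPlan is the result of victim selection
type PreemptionPlan struct {
    // Job that needs the GPUs
    Requester *ElasticJob
    // Pods to evict
    Victims []*PreemptionCandidate
    // Total GPUs freed by evicting the victims
    FreedGPUs int32
    // Nodes on which GPUs will be freed
    TargetNodes []string
}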

// NewPreemptionManager creates a manager with defaults: lowest-priority strategy, 30s grace period, 5m minimum runtime
func NewPreemptionManager(kubeClient kubernetes.Interface,
    stateManager *GPUStateManager) *PreemptionManager {
    return &PreemptionManager{
        kubeClient:   kubeClient,
        stateManager: stateManager,
        strategy:     StrategyLowestPriority,
        gracePeriod:  30 * time.Second,
        minRuntime:   5 * time.Minute,
    }
}

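// FindPreemptionCandidates builds a plan whose victims free at least neededGPUs
func (pm *PreemptionManager) FindPreemptionCandidates(ctx context.Context,
    requester *ElasticJob, neededGPUs int32) (*PreemptionPlan, error) {

    // Compared against pod.Spec.Priority below, so the cluster's PriorityClass
    // objects must use the same numeric values as PriorityValue
    requesterPriority := PriorityValue[requester.Priority]

    // List all running Pods that use GPUs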
    runningPods, err := pm.getGPUPods(ctx)
    if err != nil {
        return nil, err
    }

    // Keep only Pods that may be preempted
    var candidates []*PreemptionCandidate
    for _, pod := range runningPods {
        candidate := pm.evaluateCandidate(pod, requesterPriority)
        if candidate != nil {
            candidates = append(candidates, candidate)
        }
    }

    if len(candidates) == 0 {
        return nil, fmt.Errorf("no preemptible candidates found")
    }

    // Choose victims according to the configured strategy
    victims := pm.selectVictims(candidates, neededGPUs)
    if victims == nil {
        return nil, fmt.Errorf("cannot find enough GPUs to preempt")
    }

    // Sum the freed GPUs and collect the affected nodes
    var freedGPUs int32
    targetNodes := make(map[string]bool)
    for _, v := range victims {
        freedGPUs += v.GPUs
        targetNodes[v.NodeName] = true
    }

    nodes := make([]string, 0, len(targetNodes))
    for node := range targetNodes {
        nodes = append(nodes, node)
    }

    return &PreemptionPlan{
        Requester:   requester,
        Victims:     victims,
        FreedGPUs:   freedGPUs,
        TargetNodes: nodes,
    }, nil
}

// evaluateCandidate returns a candidate if the Pod may be preempted, or nil
func (pm *PreemptionManager) evaluateCandidate(pod *v1.Pod,
    requesterPriority int32) *PreemptionCandidate {

    // Read the Pod's priority (0 if unset)
    podPriority := int32(0)
    if pod.Spec.Priority != nil {
        podPriority = *pod.Spec.Priority
    }

    // Only strictly lower-priority Pods can be preempted
    if podPriority >= requesterPriority {
        return nil
    }

    // Skip Pods that opted out of preemption
    if pm.isProtected(pod) {
        return nil
    }

    // Skip Pods that use no GPUs
    gpus := pm.getPodGPUs(pod)
    if gpus == 0 {
        return nil
    }

    // Minimum-runtime protection; StartTime may be nil, so guard the dereference
    var startTime time.Time
    if pod.Status.StartTime != nil {
        startTime = pod.Status.StartTime.Time
        if time.Since(startTime) < pm.minRuntime {
            return nil
        }
    }

    return &PreemptionCandidate{
        Pod:       pod,
        GPUs:      gpus,
        Priority:  podPriority,
        StartTime: startTime,
        NodeName:  pod.Spec.NodeName,
    }
}

// selectVictims dispatches to the configured strategy (LowestPriority by default)
func (pm *PreemptionManager) selectVictims(candidates []*PreemptionCandidate,
    neededGPUs int32) []*PreemptionCandidate {

    switch pm.strategy {
    case StrategyLowestPriority:
        return pm.selectLowestPriority(candidates, neededGPUs)
    case StrategyNewestFirst:
        return pm.selectNewestFirst(candidates, neededGPUs)
    case StrategyMinimumVictims:
        return pm.selectMinimumVictims(candidates, neededGPUs)
    case StrategyMinimumGPUWaste:
        return pm.selectMinimumWaste(candidates, neededGPUs)
    default:
        return pm.selectLowestPriority(candidates, neededGPUs)
    }
}

// selectLowestPriority evicts Pods in ascending priority order
func (pm *PreemptionManager) selectLowestPriority(candidates []*PreemptionCandidate,
    neededGPUs int32) []*PreemptionCandidate {

    // Sort by priority, ascending
    sort.Slice(candidates, func(i, j int) bool {
        if candidates[i].Priority != candidates[j].Priority {
            return candidates[i].Priority < candidates[j].Priority
        }
        // Same priority: prefer the most recently started Pod
        return candidates[i].StartTime.After(candidates[j].StartTime)
    })

    var victims []*PreemptionCandidate
    var freedGPUs int32

    for _, c := range candidates {
        victims = append(victims, c)
        freedGPUs += c.GPUs
        if freedGPUs >= neededGPUs {
            return victims
        }
    }

    return nil // not enough preemptible GPUs
}

// selectNewestFirst evicts the most recently started Pods first
func (pm *PreemptionManager) selectNewestFirst(candidates []*PreemptionCandidate,
    neededGPUs int32) []*PreemptionCandidate {

    // Sort by start time, descending (newest first)
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].StartTime.After(candidates[j].StartTime)
    })

    var victims []*PreemptionCandidate
    var freedGPUs int32

    for _, c := range candidates {
        victims = append(victims, c)
        freedGPUs += c.GPUs
        if freedGPUs >= neededGPUs {
            return victims
        }
    }

    return nil
}

// selectMinimumVictims evicts as few Pods as possible
func (pm *PreemptionManager) selectMinimumVictims(candidates []*PreemptionCandidate,
    neededGPUs int32) []*PreemptionCandidate {

    // Sort by GPU count, descending (largest first)
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].GPUs > candidates[j].GPUs
    })

    var victims []*PreemptionCandidate
    var freedGPUs int32

    for _, c := range candidates {
        victims = append(victims, c)
        freedGPUs += c.GPUs
        if freedGPUs >= neededGPUs {
            return victims
        }
    }

    return nil
}

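// selectMinimumWaste picks the victim set whose total GPUs is the smallest
// value >= neededGPUs (0/1 knapsack over reachable GPU sums)
func (pm *PreemptionManager) selectMinimumWaste(candidates []*PreemptionCandidate,
    neededGPUs int32) []*PreemptionCandidate {

    target := int(neededGPUs)
    // Size the table by the total GPUs of all candidates, so that sums above
    // target stay in bounds
    total := 0
    for _, c := range candidates {
        total += int(c.GPUs)
    }
    if target <= 0 || total < target {
        return nil // cannot free enough GPUs
    }

    // dp[s]: some subset of candidates sums to exactly s GPUs
    // parent[s]: index of the candidate that first reached s
    dp := make([]bool, total+1)
    dp[0] = true
    parent := make([]int, total+1)
    for i := range parent {
        parent[i] = -1
    }

    for idx, c := range candidates {
        gpus := int(c.GPUs)
        // Iterate downward so each candidate is used at most once
        for i := total; i >= gpus; i-- {
            if dp[i-gpus] && !dp[i] {
                dp[i] = true
                parent[i] = idx
            }
        }
    }

    // Find the smallest reachable sum >= target, then backtrack
    for i := target; i <= total; i++ {
        if dp[i] {
            var victims []*PreemptionCandidate
            remaining := i
            for remaining > 0 && parent[remaining] >= 0 {
                idx := parent[remaining]
                victims = append(victims, candidates[idx])
                remaining -= int(candidates[idx].GPUs)
            }
            return victims
        }
    }

    // Not expected to be reached: selecting every candidate always sums to total
    return pm.selectMinimumVictims(candidates, neededGPUs)
}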

// ExecutePreemption evicts every victim in the plan, stopping at the first error
func (pm *PreemptionManager) ExecutePreemption(ctx context.Context,
    plan *PreemptionPlan) error {

    for _, victim := range plan.Victims {
        // Annotate and gracefully delete the victim Pod
        if err := pm.preemptPod(ctx, victim); err != nil {
            return fmt.Errorf("failed to preempt pod %s/%s: %w",
                victim.Pod.Namespace, victim.Pod.Name, err)
        }
    }

    return nil
}

// preemptPod annotates a single Pod as preempted, then deletes it gracefully
func (pm *PreemptionManager) preemptPod(ctx context.Context,
    victim *PreemptionCandidate) error {

    pod := victim.Pod

    // Record the preemption in annotations
    if pod.Annotations == nil {
        pod.Annotations = make(map[string]string)
    }
    pod.Annotations["elastic.gpu/preempted"] = "true"
    pod.Annotations["elastic.gpu/preemption-time"] = time.Now().Format(time.RFC3339)

    // Persist the annotations
    _, err := pm.kubeClient.CoreV1().Pods(pod.Namespace).Update(ctx, pod, metav1.UpdateOptions{})
    if err != nil {
        return err
    }

    // Delete the Pod with a graceful termination period
    gracePeriod := int64(pm.gracePeriod.Seconds())
    return pm.kubeClient.CoreV1().Pods(pod.Namespace).Delete(ctx, pod.Name,
        metav1.DeleteOptions{
            GracePeriodSeconds: &gracePeriod,
        })
}

// isProtected reports whether the Pod must not be preempted
func (pm *PreemptionManager) isProtected(pod *v1.Pod) bool {
    // Opt-out via the elastic.gpu/preemptible annotation
    if pod.Annotations != nil {
        if v, ok := pod.Annotations["elastic.gpu/preemptible"]; ok && v == "false" {
            return true
        }
    }

    // Check PodDisruptionBudgets
    // ...

    return false
}

// getPodGPUs sums the nvidia.com/gpu limits of all containers in the Pod
func (pm *PreemptionManager) getPodGPUs(pod *v1.Pod) int32 {
    var total int32
    for _, c := range pod.Spec.Containers {
        if c.Resources.Limits != nil {
            if gpu, ok := c.Resources.Limits["nvidia.com/gpu"]; ok {
                total += int32(gpu.Value())
            }
        }
    }
    return total
}

// getGPUPods lists running Pods in all namespaces that use GPUs
func (pm *PreemptionManager) getGPUPods(ctx context.Context) ([]*v1.Pod, error) {
    podList, err := pm.kubeClient.CoreV1().Pods("").List(ctx, metav1.ListOptions{
        FieldSelector: "status.phase=Running",
    })
    if err != nil {
        return nil, err
    }

    var gpuPods []*v1.Pod
    for i := range podList.Items {
        pod := &podList.Items[i]
        if pm.getPodGPUs(pod) > 0 {
            gpuPods = append(gpuPods, pod)
        }
    }

    return gpuPods, nil
}

3.2 Preemption Recovery

// pkg/elastic/recovery.go
package elastic

import (
    "context"
    "fmt"
    "time"

    batchv1 "k8s.io/api/batch/v1"
    v1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/resource"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/client-go/kubernetes"
)

// RecoveryManager recovers jobs after preemption
type RecoveryManager struct {
    kubeClient   kubernetes.Interface
    checkpointer Checkpointer
}

// Checkpointer abstracts checkpoint storage
type Checkpointer interface {
    // SaveCheckpoint saves a checkpoint of the job
    SaveCheckpoint(ctx context.Context, job *ElasticJob) (*Checkpoint, error)
    // LoadCheckpoint loads the job's latest checkpoint
    LoadCheckpoint(ctx context.Context, job *ElasticJob) (*Checkpoint, error)
    // ListCheckpoints lists the job's checkpoints
    ListCheckpoints(ctx context.Context, job *ElasticJob) ([]*Checkpoint, error)
}

// Checkpoint describes one saved checkpoint
type Checkpoint struct {
    ID        string
    JobName   string
    Namespace string
    // Time the checkpoint was saved
    Timestamp time.Time
    // Checkpoint storage path
    Path string
    // Training step at which it was taken
    Step int64
    // Metadata
    Metadata map[string]string
}

// RecoverJob recovers a preempted job according to its PreemptionRecovery
// strategy; jobs without an ElasticConfig and unknown strategies are restarted
func (rm *RecoveryManager) RecoverJob(ctx context.Context, job *ElasticJob) error {
    if job.ElasticConfig == nil {
        return rm.recoverByRestart(ctx, job)
    }
    strategy := job.ElasticConfig.PreemptionRecovery

    switch strategy {
    case RecoveryRestart:
        return rm.recoverByRestart(ctx, job)
    case RecoveryResume:
        return rm.recoverByResume(ctx, job)
    case RecoveryGraceful:
        return rm.recoverGracefully(ctx, job)
    case RecoveryFail:
        return rm.markFailed(ctx, job)
    default:
        return rm.recoverByRestart(ctx, job)
    }
}

// recoverByRestart recreates the job from scratch
func (rm *RecoveryManager) recoverByRestart(ctx context.Context, job *ElasticJob) error {
    // Recreate the Job without a checkpoint
    return rm.recreateJob(ctx, job)
}

// recoverByResume restarts the job from its latest checkpoint
func (rm *RecoveryManager) recoverByResume(ctx context.Context, job *ElasticJob) error {
    // Load the latest checkpoint
    checkpoint, err := rm.checkpointer.LoadCheckpoint(ctx, job)
    if err != nil {
        // No usable checkpoint: fall back to a plain restart
        return rm.recoverByRestart(ctx, job)
    }

    // Recreate the job from the checkpoint
    return rm.recreateJobWithCheckpoint(ctx, job, checkpoint)
}

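// recoverGracefully keeps the job running on whatever GPUs are available
func (rm *RecoveryManager) recoverGracefully(ctx context.Context, job *ElasticJob) error {
    // Jobs without elastic scaling cannot shrink; restart them instead
    if !job.ElasticConfig.Enabled {
        return rm.recoverByRestart(ctx, job)
    }

    // Query the currently available GPUs
    availableGPUs := rm.getAvailableGPUs(ctx)

    // Use what is available, capped at the job's maximum
    newGPUs := availableGPUs
    if newGPUs > job.MaxGPUs {
        newGPUs = job.MaxGPUs
    }
    if newGPUs < job.MinGPUs {
        // Not enough GPUs even for the minimum; the caller should retry later
        return fmt.Errorf("insufficient GPUs for graceful recovery: %d available, %d required",
            availableGPUs, job.MinGPUs)
    }

    // Continue running, typically with fewer GPUs than before
    return rm.recreateJobWithGPUs(ctx, job, newGPUs)
}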

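// markFailed marks the job as failed and persists the status
func (rm *RecoveryManager) markFailed(ctx context.Context, job *ElasticJob) error {
    if job.Status == nil {
        job.Status = &ElasticJobStatus{}
    }
    job.Status.Phase = JobPhaseFailed
    return rm.updateJobStatus(ctx, job)
}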

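// recreateJob creates a new Kubernetes Job for the elastic job
func (rm *RecoveryManager) recreateJob(ctx context.Context, job *ElasticJob) error {
    // The preempted Job with the same name must already be deleted;
    // otherwise Create fails with an AlreadyExists error
    k8sJob := rm.buildK8sJob(job, nil)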
    _, err := rm.kubeClient.BatchV1().Jobs(job.Namespace).Create(ctx, k8sJob, metav1.CreateOptions{})
    return err
}

// recreateJobWithCheckpoint recreates the job and points it at a checkpoint
func (rm *RecoveryManager) recreateJobWithCheckpoint(ctx context.Context,
    job *ElasticJob, checkpoint *Checkpoint) error {

    // Build the Job spec
    k8sJob := rm.buildK8sJob(job, checkpoint)

    // Pass the checkpoint to every container via environment variables
    for i := range k8sJob.Spec.Template.Spec.Containers {
        k8sJob.Spec.Template.Spec.Containers[i].Env = append(
            k8sJob.Spec.Template.Spec.Containers[i].Env,
            v1.EnvVar{Name: "CHECKPOINT_PATH", Value: checkpoint.Path},
            v1.EnvVar{Name: "RESUME_STEP", Value: fmt.Sprintf("%d", checkpoint.Step)},
        )
    }

    _, err := rm.kubeClient.BatchV1().Jobs(job.Namespace).Create(ctx, k8sJob, metav1.CreateOptions{})
    return err
}

// recreateJobWithGPUs recreates the job with the given number of GPUs
func (rm *RecoveryManager) recreateJobWithGPUs(ctx context.Context,
    job *ElasticJob, gpus int32) error {

    // Record the new GPU count; buildK8sJob reads it from Status
    originalGPUs := job.Status.AllocatedGPUs
    job.Status.AllocatedGPUs = gpus

    // Build the Job spec
    k8sJob := rm.buildK8sJob(job, nil)

    // Tell the training process that it has been resized
    for i := range k8sJob.Spec.Template.Spec.Containers {
        k8sJob.Spec.Template.Spec.Containers[i].Env = append(
            k8sJob.Spec.Template.Spec.Containers[i].Env,
            v1.EnvVar{Name: "ELASTIC_GPUS", Value: fmt.Sprintf("%d", gpus)},
            v1.EnvVar{Name: "ORIGINAL_GPUS", Value: fmt.Sprintf("%d", originalGPUs)},
            v1.EnvVar{Name: "ELASTIC_RESIZE", Value: "true"},
        )
    }

    _, err := rm.kubeClient.BatchV1().Jobs(job.Namespace).Create(ctx, k8sJob, metav1.CreateOptions{})
    return err
}

// buildK8sJob builds a Kubernetes Job with one single-GPU Pod per GPU
func (rm *RecoveryManager) buildK8sJob(job *ElasticJob, checkpoint *Checkpoint) *batchv1.Job {
    // Use the allocated GPU count if known, otherwise the job's minimum
    var gpus int32
    if job.Status != nil {
        gpus = job.Status.AllocatedGPUs
    }
    if gpus == 0 {
        gpus = job.MinGPUs
    }

    parallelism := gpus // one Pod per GPU
    completions := parallelism

    // isProtected reads this annotation from Pods, so set it on the Pod
    // template as well as on the Job
    preemptible := fmt.Sprintf("%v", job.ElasticConfig != nil && job.ElasticConfig.Preemptible)

    return &batchv1.Job{
        ObjectMeta: metav1.ObjectMeta{
            Name:      job.Name,
            Namespace: job.Namespace,
            Labels: map[string]string{
                "elastic.gpu/job-name": job.Name,
                "elastic.gpu/queue":    job.Queue,
                "elastic.gpu/priority": string(job.Priority),
            },
            Annotations: map[string]string{
                "elastic.gpu/preemptible": preemptible,
            },
        },
        Spec: batchv1.JobSpec{
            Parallelism: &parallelism,
            Completions: &completions,
            Template: v1.PodTemplateSpec{
                ObjectMeta: metav1.ObjectMeta{
                    Labels: map[string]string{
                        "elastic.gpu/job-name": job.Name,
                    },
                    Annotations: map[string]string{
                        "elastic.gpu/preemptible": preemptible,
                    },
                },
                Spec: v1.PodSpec{
                    SchedulerName: "gpu-scheduler",
                    Containers: []v1.Container{
                        {
                            Name:  "main",
                            Image: "training-image:latest",
                            Resources: v1.ResourceRequirements{
                                Limits: v1.ResourceList{
                                    "nvidia.com/gpu": *resource.NewQuantity(1, resource.DecimalSI),
                                },
                                Requests: v1.ResourceList{
                                    "nvidia.com/gpu": *resource.NewQuantity(1, resource.DecimalSI),
                                },
                            },
                        },
                    },
                    RestartPolicy: v1.RestartPolicyNever,
                },
            },
        },
    }
}

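// getAvailableGPUs returns the number of allocatable GPUs in the cluster
func (rm *RecoveryManager) getAvailableGPUs(ctx context.Context) int32 {
    // Simplified stub: a real implementation would query node allocatable
    // GPUs minus the GPUs already requested by running Pods
    return 4
}

// updateJobStatus updates the job status
func (rm *RecoveryManager) updateJobStatus(ctx context.Context, job *ElasticJob) error {
    // Stub: write the status back to the CRD
    return nil
}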

4. Elastic Scaling Implementation

4.1 Autoscaling Controller

// pkg/elastic/autoscaler.go
package elastic

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// GPUAutoscaler scales the GPU count of elastic jobs based on collected metrics
type GPUAutoscaler struct {
    mu sync.RWMutex
    // Scaling targets, keyed by namespace/name
    targets map[string]*ScaleTarget
    // Metrics collector
    metricsCollector MetricsCollector
    // Scale executor
    scaleExecutor ScaleExecutor
    // Configuration
    config *AutoscalerConfig
}

// ScaleTarget tracks the scaling state of one job
type ScaleTarget struct {
    Job         *ElasticJob
    CurrentGPUs int32
    DesiredGPUs int32
    // Time of the last scaling action
    LastScaleTime time.Time
    // Direction of the last scaling action
    ScaleDirection ScaleDirection
    // Recent metrics samples (the stabilization window)
    MetricsHistory []MetricsSample
}

// ScaleDirection is the direction of a scaling action
type ScaleDirection string

const (
    ScaleUp   ScaleDirection = "up"
    ScaleDown ScaleDirection = "down"
    ScaleNone ScaleDirection = "none"
)

// MetricsSample is a single metrics sample
type MetricsSample struct {
    Timestamp      time.Time
    GPUUtilization float64
    MemoryUsage    float64
    Throughput     float64
    QueueLength    int32
}

// MetricsCollector collects metrics for a job
type MetricsCollector interface {
    CollectMetrics(ctx context.Context, job *ElasticJob) (*MetricsSample, error)
}

// ScaleExecutor applies a scaling decision
type ScaleExecutor interface {
    ScaleUp(ctx context.Context, job *ElasticJob, delta int32) error
    ScaleDown(ctx context.Context, job *ElasticJob, delta int32) error
}

// AutoscalerConfig configures the autoscaler
type AutoscalerConfig struct {
    // Scale up when average GPU utilization (0-1) exceeds this value
    ScaleUpThreshold float64
    // Scale down when average GPU utilization falls below this value
    ScaleDownThreshold float64
    // Minimum time between two scaling actions on the same job
    CooldownPeriod time.Duration
    // Metrics collection interval
    MetricsInterval time.Duration
    // Window over which utilization is averaged before deciding
    StabilizationWindow time.Duration
    // Maximum number of GPUs added or removed in one action
    MaxScaleStep int32
}

// NewGPUAutoscaler creates an autoscaler; a nil config selects the defaults below
func NewGPUAutoscaler(metricsCollector MetricsCollector,
    scaleExecutor ScaleExecutor, config *AutoscalerConfig) *GPUAutoscaler {

    if config == nil {
        config = &AutoscalerConfig{
            ScaleUpThreshold:    0.9,
            ScaleDownThreshold:  0.3,
            CooldownPeriod:      5 * time.Minute,
            MetricsInterval:     30 * time.Second,
            StabilizationWindow: 3 * time.Minute,
            MaxScaleStep:        4,
        }
    }

    return &GPUAutoscaler{
        targets:          make(map[string]*ScaleTarget),
        metricsCollector: metricsCollector,
        scaleExecutor:    scaleExecutor,
        config:           config,
    }
}

// Register adds a job as a scaling target
func (a *GPUAutoscaler) Register(job *ElasticJob) {
    a.mu.Lock()
    defer a.mu.Unlock()

    key := job.Namespace + "/" + job.Name
    a.targets[key] = &ScaleTarget{
        Job:         job,
        CurrentGPUs: job.Status.AllocatedGPUs,
        DesiredGPUs: job.Status.AllocatedGPUs,
    }
}

// Unregister removes a scaling target
func (a *GPUAutoscaler) Unregister(namespace, name string) {
    a.mu.Lock()
    defer a.mu.Unlock()
    delete(a.targets, namespace+"/"+name)
}

// Run runs the autoscaling loop until ctx is cancelled
func (a *GPUAutoscaler) Run(ctx context.Context) {
    ticker := time.NewTicker(a.config.MetricsInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            a.reconcile(ctx)
        }
    }
}

// reconcile snapshots the targets under the lock, then processes them
// without holding it
func (a *GPUAutoscaler) reconcile(ctx context.Context) {
    a.mu.RLock()
    targets := make([]*ScaleTarget, 0, len(a.targets))
    for _, t := range a.targets {
        targets = append(targets, t)
    }
    a.mu.RUnlock()

    for _, target := range targets {
        a.processTarget(ctx, target)
    }
}

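// processTarget evaluates one target and scales it if needed. It is only
// called from reconcile, so a target's history is not accessed concurrently.
func (a *GPUAutoscaler) processTarget(ctx context.Context, target *ScaleTarget) {
    job := target.Job

    // Skip jobs that do not have elasticity enabled
    if !job.ElasticConfig.Enabled {
        return
    }

    // Collect metrics; skip this round on error
    metrics, err := a.metricsCollector.CollectMetrics(ctx, job)
    if err != nil {
        return
    }

    // Append to the history
    target.MetricsHistory = append(target.MetricsHistory, *metrics)

    // Keep only the samples inside the stabilization window (at least one)
    windowSamples := int(a.config.StabilizationWindow / a.config.MetricsInterval)
    if windowSamples < 1 {
        windowSamples = 1
    }
    if len(target.MetricsHistory) > windowSamples {
        target.MetricsHistory = target.MetricsHistory[len(target.MetricsHistory)-windowSamples:]
    }

    // Do not decide until the window is full, so a single noisy sample
    // cannot trigger scaling
    if len(target.MetricsHistory) < windowSamples {
        return
    }

    // Compute the scaling decision
    decision := a.calculateDecision(target)

    // Apply it
    if decision.Direction != ScaleNone {
        a.executeScale(ctx, target, decision)
    }
}

// ScaleDecision is the result of a scaling evaluation
type ScaleDecision struct {
    Direction ScaleDirection
    Delta     int32
    Reason    string
}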

// calculateDecision computes the scaling decision for a target
func (a *GPUAutoscaler) calculateDecision(target *ScaleTarget) *ScaleDecision {
    job := target.Job

    // Respect the cooldown period after the last scaling action
    if time.Since(target.LastScaleTime) < a.config.CooldownPeriod {
        return &ScaleDecision{Direction: ScaleNone, Reason: "in cooldown period"}
    }

    // Average utilization over the stabilization window
    avgUtilization := a.calculateAverageUtilization(target.MetricsHistory)

    // Scale up
    if avgUtilization > a.config.ScaleUpThreshold {
        if target.CurrentGPUs >= job.MaxGPUs {
            return &ScaleDecision{Direction: ScaleNone, Reason: "at max GPUs"}
        }

        delta := a.calculateScaleUpDelta(target, avgUtilization)
        if delta <= 0 {
            return &ScaleDecision{Direction: ScaleNone, Reason: "computed scale-up step is zero"}
        }
        return &ScaleDecision{
            Direction: ScaleUp,
            Delta:     delta,
            Reason:    fmt.Sprintf("high utilization: %.2f%%", avgUtilization*100),
        }
    }

    // Scale down
    if avgUtilization < a.config.ScaleDownThreshold {
        if target.CurrentGPUs <= job.MinGPUs {
            return &ScaleDecision{Direction: ScaleNone, Reason: "at min GPUs"}
        }

        delta := a.calculateScaleDownDelta(target, avgUtilization)
        if delta <= 0 {
            return &ScaleDecision{Direction: ScaleNone, Reason: "computed scale-down step is zero"}
        }
        return &ScaleDecision{
            Direction: ScaleDown,
            Delta:     delta,
            Reason:    fmt.Sprintf("low utilization: %.2f%%", avgUtilization*100),
        }
    }

    return &ScaleDecision{Direction: ScaleNone, Reason: "utilization in normal range"}
}

// calculateAverageUtilization returns the mean GPU utilization over the window
func (a *GPUAutoscaler) calculateAverageUtilization(history []MetricsSample) float64 {
    if len(history) == 0 {
        return 0
    }

    var sum float64
    for _, sample := range history {
        sum += sample.GPUUtilization
    }
    return sum / float64(len(history))
}

// calculateScaleUpDelta computes how many GPUs to add under the job's scale-up policy
func (a *GPUAutoscaler) calculateScaleUpDelta(target *ScaleTarget, utilization float64) int32 {
    job := target.Job
    policy := job.ElasticConfig.ScaleUpPolicy

    var delta int32

    switch policy.Type {
    case ScalePolicyStep:
        delta = policy.Step
    case ScalePolicyPercent:
        // Percentage of the current size (integer division), at least 1
        delta = target.CurrentGPUs * policy.Percent / 100
        if delta < 1 {
            delta = 1
        }
    default:
        delta = 1
    }

    // Cap at the maximum step size
    if delta > a.config.MaxScaleStep {
        delta = a.config.MaxScaleStep
    }

    // Do not exceed MaxGPUs
    if target.CurrentGPUs+delta > job.MaxGPUs {
        delta = job.MaxGPUs - target.CurrentGPUs
    }

    return delta
}

// calculateScaleDownDelta computes how many GPUs to remove under the job's scale-down policy
func (a *GPUAutoscaler) calculateScaleDownDelta(target *ScaleTarget, utilization float64) int32 {
    job := target.Job
    policy := job.ElasticConfig.ScaleDownPolicy

    var delta int32

    switch policy.Type {
    case ScalePolicyStep:
        delta = policy.Step
    case ScalePolicyPercent:
        // Percentage of the current size (integer division), at least 1
        delta = target.CurrentGPUs * policy.Percent / 100
        if delta < 1 {
            delta = 1
        }
    default:
        delta = 1
    }

    // Cap at the maximum step size
    if delta > a.config.MaxScaleStep {
        delta = a.config.MaxScaleStep
    }

    // Do not go below MinGPUs
    if target.CurrentGPUs-delta < job.MinGPUs {
        delta = target.CurrentGPUs - job.MinGPUs
    }

    return delta
}

// executeScale applies a decision and records the new state
func (a *GPUAutoscaler) executeScale(ctx context.Context, target *ScaleTarget,
    decision *ScaleDecision) {

    var err error

    switch decision.Direction {
    case ScaleUp:
        err = a.scaleExecutor.ScaleUp(ctx, target.Job, decision.Delta)
    case ScaleDown:
        err = a.scaleExecutor.ScaleDown(ctx, target.Job, decision.Delta)
    }

    if err != nil {
        // Keep the old state; the next reconcile round re-evaluates the target
        return
    }

    // Update state
    a.mu.Lock()
    defer a.mu.Unlock()

    switch decision.Direction {
    case ScaleUp:
        target.CurrentGPUs += decision.Delta
    case ScaleDown:
        target.CurrentGPUs -= decision.Delta
    }
    target.LastScaleTime = time.Now()
    target.ScaleDirection = decision.Direction
}
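Stripped of the Kubernetes plumbing, the decision made by calculateDecision and the two delta functions is a pure function of the utilization window, the current GPU count and the job bounds. The sketch below is a simplified, hypothetical `decide` helper that assumes a fixed-step policy (the Percent policy is omitted): it averages the window, compares the average against the thresholds, then clamps the step to MaxScaleStep and to [MinGPUs, MaxGPUs]. The example values follow the default config (thresholds 0.9 and 0.3, a 3-minute window sampled every 30s, step 4).

```go
package main

import "fmt"

// decide is a hypothetical, simplified version of the autoscaler's decision.
// It returns the signed change in GPUs; 0 means "do not scale".
func decide(window []float64, current, minGPUs, maxGPUs, step, maxStep int32,
	upThreshold, downThreshold float64) int32 {
	if len(window) == 0 {
		return 0
	}
	var sum float64
	for _, u := range window {
		sum += u
	}
	avg := sum / float64(len(window))

	// Cap the step first, then clamp to the job's bounds
	if step > maxStep {
		step = maxStep
	}
	switch {
	case avg > upThreshold && current < maxGPUs:
		if current+step > maxGPUs {
			step = maxGPUs - current
		}
		return step
	case avg < downThreshold && current > minGPUs:
		if current-step < minGPUs {
			step = current - minGPUs
		}
		return -step
	}
	return 0
}

func main() {
	// 6 samples = a 3-minute window at a 30s interval
	busy := []float64{0.95, 0.93, 0.97, 0.92, 0.96, 0.94}
	idle := []float64{0.10, 0.20, 0.15, 0.25, 0.05, 0.10}
	fmt.Println(decide(busy, 4, 1, 6, 4, 4, 0.9, 0.3)) // 2: clamped by maxGPUs
	fmt.Println(decide(idle, 4, 1, 6, 4, 4, 0.9, 0.3)) // -3: clamped by minGPUs
}
```

In both cases the job bounds, not MaxScaleStep, limit the result, which is why the bound checks run after the step cap.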

4.2 Elastic Training Support

# elastic_training.py
# Intended to be started by an elastic launcher such as torchrun, which sets
# RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT for init_method='env://'
import os
import signal

import torch
import torch.distributed as dist

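class ElasticTrainer:
    """Elastic trainer: checkpoints on preemption or resize so training can resume at a new world size."""

    def __init__(self, model, optimizer, checkpoint_dir):
        self.model = model
        self.optimizer = optimizer
        self.checkpoint_dir = checkpoint_dir
        self.step = 0
        self.epoch = 0

        # Initialize the distributed environment first, so that self.rank
        # exists by the time a signal handler can run
        self._init_distributed()

        # Register signal handlers: SIGTERM for preemption, SIGUSR1 for resize
        signal.signal(signal.SIGTERM, self._handle_preemption)
        signal.signal(signal.SIGUSR1, self._handle_resize)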

    def _init_distributed(self):
        """Initialize the distributed environment."""
        if not dist.is_initialized():
            # Elastic launch: the launcher passes rendezvous information
            # through environment variables
            dist.init_process_group(
                backend='nccl',
                init_method='env://'
            )

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))

        # Bind this process to its GPU
        torch.cuda.set_device(self.local_rank)

        # Wrap the model with DDP
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[self.local_rank]
        )

        print(f"Initialized rank {self.rank}/{self.world_size}")

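    def _handle_preemption(self, signum, frame):
        """Handle a preemption signal: checkpoint, then exit cleanly."""
        print(f"Rank {self.rank}: Received preemption signal")
        self.save_checkpoint("preempt")
        dist.destroy_process_group()
        raise SystemExit(0)

    def _handle_resize(self, signum, frame):
        """Handle a resize signal: checkpoint, then exit so the worker is restarted.

        Re-initializing the process group in-process would reuse the old
        RANK/WORLD_SIZE environment and wrap the DDP model a second time.
        Instead the worker exits and relies on its launcher (for example
        torchrun, or the controller that recreates the Job with ELASTIC_GPUS)
        to restart it at the new world size; load_checkpoint() then resumes.
        """
        print(f"Rank {self.rank}: Received resize signal")
        self.save_checkpoint("resize")
        dist.destroy_process_group()
        raise SystemExit(1)  # non-zero so the launcher treats it as a restart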

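    def save_checkpoint(self, reason="periodic"):
        """Save a checkpoint (rank 0 only, since DDP keeps replicas identical)."""
        if self.rank != 0:
            return

        checkpoint = {
            'step': self.step,
            'epoch': self.epoch,
            'model_state_dict': self.model.module.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'world_size': self.world_size,
            'reason': reason
        }

        path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint_{self.step}.pt"
        )
        # Write to a temporary file and rename it, so a crash mid-write never
        # leaves a truncated checkpoint behind
        tmp_path = path + ".tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)

        # Point latest.pt at the new checkpoint. The link target is relative
        # (just the file name) so it stays valid when checkpoint_dir is a
        # relative path, and the link is swapped in atomically via os.replace
        latest_path = os.path.join(self.checkpoint_dir, "latest.pt")
        tmp_link = latest_path + ".tmp"
        if os.path.lexists(tmp_link):
            os.remove(tmp_link)
        os.symlink(os.path.basename(path), tmp_link)
        os.replace(tmp_link, latest_path)

        print(f"Saved checkpoint at step {self.step}, reason: {reason}")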

    def load_checkpoint(self):
        """Load the latest checkpoint, if any; returns True on success."""
        latest_path = os.path.join(self.checkpoint_dir, "latest.pt")
        if not os.path.exists(latest_path):
            return False

        # Map tensors onto this process's GPU, not the device they were saved from
        checkpoint = torch.load(latest_path, map_location=f'cuda:{self.local_rank}')

        self.step = checkpoint['step']
        self.epoch = checkpoint['epoch']
        self.model.module.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Handle a change in world_size since the checkpoint was written
        old_world_size = checkpoint['world_size']
        if old_world_size != self.world_size:
            print(f"World size changed: {old_world_size} -> {self.world_size}")
            # Rescale optimizer hyperparameters for the new world size
            self._adjust_optimizer_state(old_world_size)

        print(f"Loaded checkpoint from step {self.step}")
        return True

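    def _adjust_optimizer_state(self, old_world_size):
        """Rescale the learning rate for the new world_size.

        With a fixed per-GPU batch size the global batch grows with world_size,
        so the linear scaling rule gives lr_new = lr_old * new_world_size / old_world_size.
        Moment estimates of optimizers such as Adam are kept as-is here; some
        setups reset them after a large change.
        """
        scale = self.world_size / old_world_size
        for param_group in self.optimizer.param_groups:
            param_group['lr'] *= scale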

    def train_step(self, batch):
        """单步训练"""
        self.model.train()

        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        loss.backward()

        # DDP already all-reduced the gradients during backward(); apply the update
        self.optimizer.step()

        self.step += 1

        return loss.item()

    def train_epoch(self, dataloader, save_interval=100):
        """训练一个 epoch"""
        for batch_idx, batch in enumerate(dataloader):
            loss = self.train_step(batch)

            # 定期保存检查点
            if self.step % save_interval == 0:
                self.save_checkpoint("periodic")

            if self.rank == 0 and batch_idx % 10 == 0:
                print(f"Step {self.step}, Loss: {loss:.4f}")

        self.epoch += 1


def main():
    """弹性训练主函数"""
    # 配置
    checkpoint_dir = os.environ.get('CHECKPOINT_DIR', '/checkpoints')

    # 创建模型和优化器
    model = YourModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # 创建弹性训练器
    trainer = ElasticTrainer(model, optimizer, checkpoint_dir)

    # 尝试恢复检查点
    trainer.load_checkpoint()

    # 创建数据加载器(需要使用 DistributedSampler)
    dataset = YourDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=trainer.world_size,
        rank=trainer.rank
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4
    )

    # 训练
    for epoch in range(trainer.epoch, 100):
        sampler.set_epoch(epoch)
        trainer.train_epoch(dataloader)

    print("Training completed!")


if __name__ == '__main__':
    main()

5. Gang Scheduling

5.1 Gang Scheduler

// pkg/elastic/gang_scheduler.go
package elastic

import (
    "context"
    "fmt"
    "sort"
    "sync"
    "time"
)

// GangScheduler schedules groups of Pods that must start together
type GangScheduler struct {
    mu sync.RWMutex
    // Gangs waiting to be scheduled, keyed by namespace/name
    pendingGangs map[string]*Gang
    // GPU state manager
    gpuStateManager *GPUStateManager
    // Configuration
    config *GangConfig
}

// Gang is a group of Pods that must be scheduled together
type Gang struct {
    Name      string
    Namespace string
    // Members, one per Pod
    Members []*GangMember
    // Minimum number of members that must be placed (for elastic gangs)
    MinMembers int
    // Creation time
    CreateTime time.Time
    // How long the gang may stay pending before it times out
    Timeout time.Duration
    // Status
    Status GangStatus
}

// GangMember is one member of a gang
type GangMember struct {
    Name      string
    GPUs      int32
    NodeName  string // filled in after scheduling
    Scheduled bool
}

// GangStatus is the state of a gang
type GangStatus string

const (
    GangPending   GangStatus = "Pending"
    GangScheduled GangStatus = "Scheduled"
    GangPartial   GangStatus = "Partial" // partially scheduled (elastic gang)
    GangTimeout   GangStatus = "Timeout"
    GangFailed    GangStatus = "Failed"
)

// GangConfig holds gang scheduler settings
type GangConfig struct {
    // Default timeout
    DefaultTimeout time.Duration
    // Whether partial scheduling is allowed
    AllowPartial bool
    // Scheduling interval
    ScheduleInterval time.Duration
}

// NewGangScheduler creates a gang scheduler; a nil config selects the defaults below
func NewGangScheduler(gpuStateManager *GPUStateManager, config *GangConfig) *GangScheduler {
    if config == nil {
        config = &GangConfig{
            DefaultTimeout:   5 * time.Minute,
            AllowPartial:     false,
            ScheduleInterval: 10 * time.Second,
        }
    }

    return &GangScheduler{
        pendingGangs:    make(map[string]*Gang),
        gpuStateManager: gpuStateManager,
        config:          config,
    }
}

// Submit enqueues a gang for scheduling
func (gs *GangScheduler) Submit(gang *Gang) error {
    gs.mu.Lock()
    defer gs.mu.Unlock()

    key := gang.Namespace + "/" + gang.Name
    if _, exists := gs.pendingGangs[key]; exists {
        return fmt.Errorf("gang %s already exists", key)
    }

    gang.CreateTime = time.Now()
    if gang.Timeout == 0 {
        gang.Timeout = gs.config.DefaultTimeout
    }
    gang.Status = GangPending

    gs.pendingGangs[key] = gang
    return nil
}

// Run runs the scheduling loop until ctx is cancelled
func (gs *GangScheduler) Run(ctx context.Context) {
    ticker := time.NewTicker(gs.config.ScheduleInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            gs.scheduleGangs(ctx)
        }
    }
}

// scheduleGangs tries to schedule every pending gang
func (gs *GangScheduler) scheduleGangs(ctx context.Context) {
    // Snapshot the pending gangs under the lock, then work without holding it
    gs.mu.RLock()
    gangs := make([]*Gang, 0, len(gs.pendingGangs))
    for _, g := range gs.pendingGangs {
        gangs = append(gangs, g)
    }
    gs.mu.RUnlock()

    for _, gang := range gangs {
        // Drop gangs that have waited longer than their timeout
        if time.Since(gang.CreateTime) > gang.Timeout {
            gs.handleTimeout(gang)
            continue
        }

        // Try to schedule; remove the gang from the queue on success
        if gs.tryScheduleGang(ctx, gang) {
            gs.removeFromPending(gang)
        }
    }
}

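// tryScheduleGang tries to place all members of a gang at once (all-or-nothing),
// falling back to a partial placement for elastic gangs when allowed
func (gs *GangScheduler) tryScheduleGang(ctx context.Context, gang *Gang) bool {
    // Total GPU demand of the gang
    totalGPUs := int32(0)
    for _, member := range gang.Members {
        totalGPUs += member.GPUs
    }

    // Free GPUs per node
    nodeGPUs := gs.getAvailableGPUsPerNode()

    // Fast path: if the whole cluster has fewer free GPUs than the gang needs,
    // a full placement is impossible, so skip straight to the partial attempt
    clusterFree := int32(0)
    for _, free := range nodeGPUs {
        clusterFree += free
    }

    var allocation *GangAllocation
    if clusterFree >= totalGPUs {
        allocation = gs.allocateGang(gang, nodeGPUs)
    }
    if allocation == nil && gs.config.AllowPartial && gang.MinMembers < len(gang.Members) {
        // Elastic gang: accept a placement of at least MinMembers members
        allocation = gs.allocatePartialGang(gang, nodeGPUs)
    }

    if allocation == nil {
        return false
    }

    // Commit the placement
    return gs.executeAllocation(ctx, gang, allocation)
}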

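// GangAllocation is the result of placing a gang
type GangAllocation struct {
    // Member index -> node name
    MemberToNode map[int]string
    // GPUs allocated on each node
    NodeGPUs map[string]int32
}

// sortedNodeNames returns node names in a stable order so that greedy
// placement is deterministic (Go randomizes map iteration order)
func sortedNodeNames(nodeGPUs map[string]int32) []string {
    nodes := make([]string, 0, len(nodeGPUs))
    for node := range nodeGPUs {
        nodes = append(nodes, node)
    }
    sort.Strings(nodes)
    return nodes
}

// allocateGang places every member or returns nil (all-or-nothing)
func (gs *GangScheduler) allocateGang(gang *Gang, nodeGPUs map[string]int32) *GangAllocation {
    allocation := &GangAllocation{
        MemberToNode: make(map[int]string),
        NodeGPUs:     make(map[string]int32),
    }

    // Work on a copy so a failed attempt leaves the caller's map untouched
    available := make(map[string]int32, len(nodeGPUs))
    for node, gpus := range nodeGPUs {
        available[node] = gpus
    }
    nodes := sortedNodeNames(available)

    // Sort members by GPU demand, largest first (first-fit decreasing)
    type indexedMember struct {
        idx    int
        member *GangMember
    }
    members := make([]indexedMember, len(gang.Members))
    for i, m := range gang.Members {
        members[i] = indexedMember{idx: i, member: m}
    }
    sort.SliceStable(members, func(i, j int) bool {
        return members[i].member.GPUs > members[j].member.GPUs
    })

    // Greedy first-fit placement
    for _, m := range members {
        allocated := false
        for _, node := range nodes {
            if available[node] >= m.member.GPUs {
                allocation.MemberToNode[m.idx] = node
                allocation.NodeGPUs[node] += m.member.GPUs
                available[node] -= m.member.GPUs
                allocated = true
                break
            }
        }
        if !allocated {
            return nil // at least one member cannot be placed
        }
    }

    return allocation
}

// allocatePartialGang places as many members as possible and succeeds only
// if at least MinMembers (and at least one) members were placed
func (gs *GangScheduler) allocatePartialGang(gang *Gang, nodeGPUs map[string]int32) *GangAllocation {
    allocation := &GangAllocation{
        MemberToNode: make(map[int]string),
        NodeGPUs:     make(map[string]int32),
    }

    available := make(map[string]int32, len(nodeGPUs))
    for node, gpus := range nodeGPUs {
        available[node] = gpus
    }
    nodes := sortedNodeNames(available)

    allocatedCount := 0
    for i, member := range gang.Members {
        for _, node := range nodes {
            if available[node] >= member.GPUs {
                allocation.MemberToNode[i] = node
                allocation.NodeGPUs[node] += member.GPUs
                available[node] -= member.GPUs
                allocatedCount++
                break
            }
        }
    }

    // Require MinMembers placed members; an empty placement is never a success
    if allocatedCount == 0 || allocatedCount < gang.MinMembers {
        return nil
    }

    return allocation
}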

// executeAllocation commits a placement
func (gs *GangScheduler) executeAllocation(ctx context.Context,
    gang *Gang, allocation *GangAllocation) bool {

    // Update member state
    for idx, node := range allocation.MemberToNode {
        gang.Members[idx].NodeName = node
        gang.Members[idx].Scheduled = true
    }

    // Fully or partially scheduled?
    if len(allocation.MemberToNode) == len(gang.Members) {
        gang.Status = GangScheduled
    } else {
        gang.Status = GangPartial
    }

    // Actually create (bind) the Pods. A real implementation must also deduct
    // these GPUs from the GPU state manager so that gangs processed later in
    // the same round see the reduced capacity.
    // ...

    return true
}

// handleTimeout marks a gang as timed out and drops it from the queue
func (gs *GangScheduler) handleTimeout(gang *Gang) {
    gang.Status = GangTimeout
    gs.removeFromPending(gang)
    // Emit an event notification
}

// removeFromPending removes a gang from the pending set
func (gs *GangScheduler) removeFromPending(gang *Gang) {
    gs.mu.Lock()
    defer gs.mu.Unlock()
    delete(gs.pendingGangs, gang.Namespace+"/"+gang.Name)
}

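// getAvailableGPUsPerNode returns the free GPUs on each node
func (gs *GangScheduler) getAvailableGPUsPerNode() map[string]int32 {
    result := make(map[string]int32)
    // Stub: populate from the GPU state manager. While this returns an empty
    // map, no gang can be placed.
    return result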
}
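The key property of gang scheduling is that a gang which cannot be placed in full reserves nothing, so two half-placed training jobs cannot deadlock each other while each holds some of the GPUs. The standalone sketch below shows this with a hypothetical `placeGang` function over plain ints. It places members first-fit decreasing onto nodes in name order and writes the result back to the shared free map only if every member fits.

```go
package main

import (
	"fmt"
	"sort"
)

// placeGang is a hypothetical all-or-nothing placement. members holds each
// member's GPU demand; free maps node name to free GPUs and is updated only
// when the whole gang fits.
func placeGang(members []int, free map[string]int) (map[int]string, bool) {
	nodes := make([]string, 0, len(free))
	for n := range free {
		nodes = append(nodes, n)
	}
	sort.Strings(nodes)

	// Largest members first
	order := make([]int, len(members))
	for i := range order {
		order[i] = i
	}
	sort.SliceStable(order, func(a, b int) bool { return members[order[a]] > members[order[b]] })

	// Tentative placement on a copy of the free capacity
	trial := make(map[string]int, len(free))
	for n, g := range free {
		trial[n] = g
	}
	placement := make(map[int]string, len(members))
	for _, idx := range order {
		placed := false
		for _, n := range nodes {
			if trial[n] >= members[idx] {
				trial[n] -= members[idx]
				placement[idx] = n
				placed = true
				break
			}
		}
		if !placed {
			return nil, false // one member does not fit: reserve nothing
		}
	}
	for n, g := range trial {
		free[n] = g // commit only after the whole gang fits
	}
	return placement, true
}

func main() {
	free := map[string]int{"node-a": 4, "node-b": 4}

	// Gang 1: a 2-GPU worker and a 4-GPU worker fit across both nodes
	p1, ok1 := placeGang([]int{2, 4}, free)
	fmt.Println(p1, ok1, free)

	// Gang 2 needs 2+2 GPUs but only 2 are left: nothing is reserved
	p2, ok2 := placeGang([]int{2, 2}, free)
	fmt.Println(p2, ok2, free)
}
```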

6. Quota Management

6.1 Resource Quotas

# gpu-quota.yaml
apiVersion: v1
kind: ResourceQuota
metadata:
  name: gpu-quota
  namespace: ml-team-a
spec:
  hard:
    # Extended resources such as nvidia.com/gpu cannot be overcommitted, so a
    # ResourceQuota only accepts the requests. prefix for them (no limits. entry)
    requests.nvidia.com/gpu: "16"
    persistentvolumeclaims: "10"
    requests.storage: "1Ti"
---
# Custom GPU quota CRD
apiVersion: gpu.elastic.io/v1alpha1
kind: GPUQuota
metadata:
  name: ml-team-a-quota
  namespace: ml-team-a
spec:
  # Hard limits
  hard:
    # Total GPU quota
    gpus: 16
    # A100 quota
    gpus.nvidia.com/a100: 8
    # V100 quota
    gpus.nvidia.com/v100: 8
    # MIG quotas, per MIG profile
    mig-1g.5gb: 4
    mig-3g.40gb: 2
  # Borrowing limits
  borrowing:
    enabled: true
    maxBorrow: 8
    # Namespaces this namespace may borrow from
    borrowFrom:
      - ml-team-b
      - ml-shared
  # Per-priority quotas
  priorityQuotas:
    production:
      gpus: 8
    high-priority:
      gpus: 4
    normal:
      gpus: 4
  # Preemption settings
  preemption:
    # Workloads running on borrowed GPUs may be preempted
    allowBorrowedPreemption: true
    # Minimum number of GPUs protected from preemption
    protectedGPUs: 4

6.2 Quota Controller

// pkg/elastic/quota_controller.go
package elastic

import (
    "fmt"
    "sync"
)

// QuotaController enforces per-namespace GPU quotas, including borrowing
type QuotaController struct {
    mu sync.RWMutex
    // Quota per namespace
    quotas map[string]*GPUQuota
    // Current usage per namespace
    usage map[string]*GPUUsage
}

// GPUQuota is the GPU quota of a namespace
type GPUQuota struct {
    Namespace string
    // Hard limits
    Hard QuotaResources
    // Borrowing settings
    Borrowing *BorrowingConfig
    // Per-priority quotas
    PriorityQuotas map[PriorityClass]QuotaResources
    // Preemption settings
    Preemption *PreemptionConfig
}

// QuotaResources is a set of quota amounts
type QuotaResources struct {
    // Total GPUs
    GPUs int32
    // GPUs per model
    GPUsByModel map[string]int32
    // MIG instances per profile
    MIG map[string]int32
}

// BorrowingConfig configures borrowing
type BorrowingConfig struct {
    Enabled    bool
    MaxBorrow  int32
    BorrowFrom []string
}

// PreemptionConfig configures preemption
type PreemptionConfig struct {
    AllowBorrowedPreemption bool
    ProtectedGPUs           int32
}

// GPUUsage is the GPU usage of a namespace
type GPUUsage struct {
    Namespace string
    // GPUs in use, including borrowed ones
    Used QuotaResources
    // GPUs lent to other namespaces
    Lent int32
    // GPUs borrowed from other namespaces
    Borrowed int32
    // Usage per priority
    UsedByPriority map[PriorityClass]int32
}

// NewQuotaController creates a quota controller
func NewQuotaController() *QuotaController {
    return &QuotaController{
        quotas: make(map[string]*GPUQuota),
        usage:  make(map[string]*GPUUsage),
    }
}

// SetQuota sets (or replaces) a namespace's quota
func (qc *QuotaController) SetQuota(quota *GPUQuota) {
    qc.mu.Lock()
    defer qc.mu.Unlock()
    qc.quotas[quota.Namespace] = quota
}

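// CheckQuota reports whether a request would fit in the namespace's quota.
// It is read-only: GPUs that could be borrowed are taken into account, but
// nothing is actually borrowed (AllocateQuota does that).
func (qc *QuotaController) CheckQuota(namespace string, request QuotaRequest) error {
    qc.mu.RLock()
    defer qc.mu.RUnlock()

    _, err := qc.checkLocked(namespace, request)
    return err
}

// checkLocked validates a request against the total, per-model and
// per-priority quotas and returns how many GPUs must be borrowed to satisfy
// it. It modifies no state. The caller must hold qc.mu.
func (qc *QuotaController) checkLocked(namespace string, request QuotaRequest) (int32, error) {
    quota, ok := qc.quotas[namespace]
    if !ok {
        return 0, nil // no quota configured for this namespace
    }

    usage := qc.usage[namespace]
    if usage == nil {
        usage = &GPUUsage{Namespace: namespace}
    }

    // Total GPU quota. Used can already exceed Hard by the borrowed amount,
    // so clamp at zero to avoid over-counting how much must be borrowed.
    available := quota.Hard.GPUs - usage.Used.GPUs
    if available < 0 {
        available = 0
    }

    var toBorrow int32
    if request.GPUs > available {
        if quota.Borrowing == nil || !quota.Borrowing.Enabled {
            return 0, fmt.Errorf("GPU quota exceeded: requested %d, available %d",
                request.GPUs, available)
        }
        toBorrow = request.GPUs - available
        if usage.Borrowed+toBorrow > quota.Borrowing.MaxBorrow {
            return 0, fmt.Errorf("quota exceeded and cannot borrow: would exceed max borrow limit")
        }
        if qc.lendableLocked(quota.Borrowing.BorrowFrom) < toBorrow {
            return 0, fmt.Errorf("quota exceeded and cannot borrow: lenders have fewer than %d spare GPUs", toBorrow)
        }
    }

    // Per-model quota
    if request.GPUModel != "" {
        if modelQuota, ok := quota.Hard.GPUsByModel[request.GPUModel]; ok {
            modelUsed := usage.Used.GPUsByModel[request.GPUModel]
            if request.GPUs > modelQuota-modelUsed {
                return 0, fmt.Errorf("GPU model %s quota exceeded", request.GPUModel)
            }
        }
    }

    // Per-priority quota
    if request.Priority != "" {
        if priorityQuota, ok := quota.PriorityQuotas[request.Priority]; ok {
            priorityUsed := usage.UsedByPriority[request.Priority]
            if request.GPUs > priorityQuota.GPUs-priorityUsed {
                return 0, fmt.Errorf("priority %s GPU quota exceeded", request.Priority)
            }
        }
    }

    return toBorrow, nil
}

// lendableLocked returns how many spare GPUs the lender namespaces can lend
// in total. The caller must hold qc.mu.
func (qc *QuotaController) lendableLocked(lenders []string) int32 {
    var total int32
    for _, ns := range lenders {
        lenderQuota := qc.quotas[ns]
        if lenderQuota == nil {
            continue
        }
        var used, lent int32
        if u := qc.usage[ns]; u != nil {
            used, lent = u.Used.GPUs, u.Lent
        }
        if spare := lenderQuota.Hard.GPUs - used - lent; spare > 0 {
            total += spare
        }
    }
    return total
}

// usageLocked returns the stored usage record of a namespace, creating it if
// needed. The caller must hold qc.mu for writing.
func (qc *QuotaController) usageLocked(namespace string) *GPUUsage {
    usage := qc.usage[namespace]
    if usage == nil {
        usage = &GPUUsage{
            Namespace:      namespace,
            Used:           QuotaResources{GPUsByModel: make(map[string]int32)},
            UsedByPriority: make(map[PriorityClass]int32),
        }
        qc.usage[namespace] = usage
    }
    return usage
}

// QuotaRequest is a request for GPU quota
type QuotaRequest struct {
    GPUs       int32
    GPUModel   string
    MIGProfile string
    Priority   PriorityClass
}

// AllocateQuota checks a request and charges it to the namespace, borrowing
// from lender namespaces when its own quota is insufficient. Check, borrow
// and update happen under one write lock, so two concurrent requests cannot
// both pass the check and overcommit the quota.
func (qc *QuotaController) AllocateQuota(namespace string, request QuotaRequest) error {
    qc.mu.Lock()
    defer qc.mu.Unlock()

    toBorrow, err := qc.checkLocked(namespace, request)
    if err != nil {
        return err
    }

    usage := qc.usageLocked(namespace)

    // Borrow the shortfall from the lenders
    if toBorrow > 0 {
        borrowed, err := qc.tryBorrow(namespace, toBorrow)
        if err != nil {
            return fmt.Errorf("quota exceeded and cannot borrow: %w", err)
        }
        usage.Borrowed += borrowed
    }

    // Update usage
    usage.Used.GPUs += request.GPUs
    if request.GPUModel != "" {
        usage.Used.GPUsByModel[request.GPUModel] += request.GPUs
    }
    if request.Priority != "" {
        usage.UsedByPriority[request.Priority] += request.GPUs
    }

    return nil
}

// ReleaseQuota releases a request's usage, returning borrowed GPUs first
func (qc *QuotaController) ReleaseQuota(namespace string, request QuotaRequest) {
    qc.mu.Lock()
    defer qc.mu.Unlock()

    usage := qc.usage[namespace]
    if usage == nil {
        return
    }

    usage.Used.GPUs -= request.GPUs
    if request.GPUModel != "" {
        usage.Used.GPUsByModel[request.GPUModel] -= request.GPUs
    }
    if request.Priority != "" {
        usage.UsedByPriority[request.Priority] -= request.GPUs
    }

    // Give borrowed GPUs back to the lenders before freeing own quota
    if usage.Borrowed > 0 {
        released := min32(usage.Borrowed, request.GPUs)
        usage.Borrowed -= released
        qc.returnBorrowed(namespace, released)
    }
}

// tryBorrow borrows exactly `needed` GPUs from the configured lenders, all or
// nothing. The caller must hold qc.mu for writing.
func (qc *QuotaController) tryBorrow(namespace string, needed int32) (int32, error) {
    quota := qc.quotas[namespace]
    if quota == nil || quota.Borrowing == nil || !quota.Borrowing.Enabled {
        return 0, fmt.Errorf("borrowing not enabled")
    }

    currentBorrowed := int32(0)
    if usage := qc.usage[namespace]; usage != nil {
        currentBorrowed = usage.Borrowed
    }

    // Enforce the borrowing limit
    if currentBorrowed+needed > quota.Borrowing.MaxBorrow {
        return 0, fmt.Errorf("would exceed max borrow limit")
    }

    // Check before mutating, so a failed attempt leaves no partial Lent counts
    if qc.lendableLocked(quota.Borrowing.BorrowFrom) < needed {
        return 0, fmt.Errorf("cannot borrow enough GPUs")
    }

    // Borrow from the configured namespaces, in order
    var borrowed int32
    for _, lenderNS := range quota.Borrowing.BorrowFrom {
        if borrowed >= needed {
            break
        }
        lenderQuota := qc.quotas[lenderNS]
        if lenderQuota == nil {
            continue
        }

        // Use the stored usage record so the Lent update is not lost
        lenderUsage := qc.usageLocked(lenderNS)
        available := lenderQuota.Hard.GPUs - lenderUsage.Used.GPUs - lenderUsage.Lent
        if available <= 0 {
            continue
        }
        toBorrow := min32(needed-borrowed, available)
        lenderUsage.Lent += toBorrow
        borrowed += toBorrow
    }

    return borrowed, nil
}

// returnBorrowed gives `amount` borrowed GPUs back to the lenders. Lent is
// tracked per lender, not per borrower, so this bookkeeping is approximate
// when several namespaces borrow from the same lender.
// The caller must hold qc.mu for writing.
func (qc *QuotaController) returnBorrowed(namespace string, amount int32) {
    quota := qc.quotas[namespace]
    if quota == nil || quota.Borrowing == nil {
        return
    }

    // Return to the lenders in BorrowFrom order
    remaining := amount
    for _, lenderNS := range quota.Borrowing.BorrowFrom {
        if remaining == 0 {
            break
        }
        lenderUsage := qc.usage[lenderNS]
        if lenderUsage == nil || lenderUsage.Lent == 0 {
            continue
        }

        toReturn := min32(remaining, lenderUsage.Lent)
        lenderUsage.Lent -= toReturn
        remaining -= toReturn
    }
}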

// GetQuotaStatus returns a namespace's quota status, or nil if it has no quota
func (qc *QuotaController) GetQuotaStatus(namespace string) *QuotaStatus {
    qc.mu.RLock()
    defer qc.mu.RUnlock()

    quota := qc.quotas[namespace]
    usage := qc.usage[namespace]

    if quota == nil {
        return nil
    }

    if usage == nil {
        usage = &GPUUsage{Namespace: namespace}
    }

    return &QuotaStatus{
        Namespace: namespace,
        Hard:      quota.Hard,
        Used:      usage.Used,
        Borrowed:  usage.Borrowed,
        Lent:      usage.Lent,
    }
}

// QuotaStatus is a snapshot of a namespace's quota and usage
type QuotaStatus struct {
    Namespace string
    Hard      QuotaResources
    Used      QuotaResources
    Borrowed  int32
    Lent      int32
}

func min32(a, b int32) int32 {
    if a < b {
        return a
    }
    return b
}
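To make the borrowing arithmetic concrete, the sketch below replays the GPUQuota example from 6.1 (ml-team-a with maxBorrow: 8, borrowing from ml-team-b and then ml-shared) against a hypothetical, simplified `planBorrow` function. The lender usage numbers (ml-team-b using 6 of its 8 GPUs, ml-shared idle) are made up for illustration. Like tryBorrow, it takes GPUs from the lenders in order and either covers the whole shortfall or changes nothing.

```go
package main

import (
	"errors"
	"fmt"
)

// lender is a hypothetical, simplified view of a namespace that lends GPUs.
type lender struct {
	Name string
	Hard int32 // hard GPU quota
	Used int32 // GPUs used by the lender itself
	Lent int32 // GPUs already lent out
}

// planBorrow takes `needed` GPUs from lenders in order, all or nothing, and
// enforces the borrower's maxBorrow cap. Lenders are updated only on success.
func planBorrow(needed, alreadyBorrowed, maxBorrow int32, lenders []*lender) (map[string]int32, error) {
	if alreadyBorrowed+needed > maxBorrow {
		return nil, errors.New("would exceed max borrow limit")
	}
	plan := make(map[string]int32)
	remaining := needed
	for _, l := range lenders {
		if remaining == 0 {
			break
		}
		spare := l.Hard - l.Used - l.Lent
		if spare <= 0 {
			continue
		}
		take := spare
		if take > remaining {
			take = remaining
		}
		plan[l.Name] = take
		remaining -= take
	}
	if remaining > 0 {
		return nil, errors.New("cannot borrow enough GPUs")
	}
	for _, l := range lenders {
		l.Lent += plan[l.Name]
	}
	return plan, nil
}

func main() {
	teamB := &lender{Name: "ml-team-b", Hard: 8, Used: 6}
	shared := &lender{Name: "ml-shared", Hard: 8}
	lenders := []*lender{teamB, shared}

	// ml-team-a needs 5 extra GPUs: 2 come from ml-team-b, 3 from ml-shared
	plan, err := planBorrow(5, 0, 8, lenders)
	fmt.Println(plan, err)

	// Another 4 would bring the borrowed total to 9, above maxBorrow
	_, err = planBorrow(4, 5, 8, lenders)
	fmt.Println(err) // would exceed max borrow limit
}
```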

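7. Best Practices

7.1 Choosing an Elastic Scheduling Strategy

| Scenario | Recommended strategy | Priority | Notes |
|---|---|---|---|
| Production inference | Non-preemptible | Production | Guarantees the SLA |
| Large-scale training | Preemptible + checkpointing | High | Can resume after preemption |
| Experiments / development | Elastic scaling | Normal | Shares resources |
| Batch jobs | Best effort | BestEffort | Fills idle capacity |
| Ad hoc jobs | Spot instances | Spot | Cost optimization |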
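One way to read the table is as a total order on priorities plus a per-workload preemptible flag: a workload may only evict a strictly lower-priority workload that is marked preemptible. The sketch below encodes that rule. The `PriorityLevel` ordering (with Spot lowest) and the `canPreempt` helper are assumptions for illustration, not the PriorityClass type used earlier in this chapter.

```go
package main

import "fmt"

// PriorityLevel is a hypothetical ordering of the table's priorities, lowest first.
type PriorityLevel int

const (
	Spot PriorityLevel = iota
	BestEffort
	Normal
	High
	Production
)

// workload pairs a priority with whether the workload may be preempted.
type workload struct {
	Priority    PriorityLevel
	Preemptible bool
}

// canPreempt allows eviction only of a strictly lower-priority, preemptible workload.
func canPreempt(preemptor, victim workload) bool {
	return victim.Preemptible && preemptor.Priority > victim.Priority
}

func main() {
	inference := workload{Production, false} // production inference: never preempted
	training := workload{High, true}         // large-scale training: preemptible, checkpoints
	experiment := workload{Normal, true}
	batch := workload{BestEffort, true}

	fmt.Println(canPreempt(inference, training))  // true
	fmt.Println(canPreempt(training, inference))  // false
	fmt.Println(canPreempt(experiment, training)) // false
	fmt.Println(canPreempt(training, batch))      // true
}
```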

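7.2 Quota Planning Recommendations

# Example team quota plan
teams:
  - name: ml-platform
    # Platform team: guaranteed resources plus the ability to borrow
    quota:
      hard: 32
      borrowable: 16
      priority: production

  - name: research
    # Research team: elastic resources
    quota:
      hard: 16
      minGuaranteed: 8
      priority: high-priority

  - name: experiments
    # Experiment team: shared pool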
    quota:
      hard: 8
      priority: normal
      preemptible: true
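A plan like this can be sanity-checked before it is applied. The sketch below is a hypothetical validator for the fields in the example. It assumes that a preemptible team is guaranteed nothing, a team with minGuaranteed is guaranteed exactly that amount, and any other team is guaranteed its full hard quota. Under those assumptions the guaranteed GPUs must fit in the cluster, and anything above them can only be served by borrowing idle capacity.

```go
package main

import "fmt"

// teamQuota is a hypothetical Go view of one entry in the plan above.
type teamQuota struct {
	Name          string
	Hard          int
	MinGuaranteed int // 0 means "not set"
	Preemptible   bool
}

// guaranteed returns the GPUs a team is promised under the stated assumptions.
func guaranteed(t teamQuota) int {
	switch {
	case t.Preemptible:
		return 0
	case t.MinGuaranteed > 0:
		return t.MinGuaranteed
	default:
		return t.Hard
	}
}

// validatePlan checks that minGuaranteed never exceeds hard and that the
// guaranteed GPUs of all teams fit in the cluster.
func validatePlan(teams []teamQuota, clusterGPUs int) error {
	total := 0
	for _, t := range teams {
		if t.MinGuaranteed > t.Hard {
			return fmt.Errorf("%s: minGuaranteed %d exceeds hard %d", t.Name, t.MinGuaranteed, t.Hard)
		}
		total += guaranteed(t)
	}
	if total > clusterGPUs {
		return fmt.Errorf("guaranteed GPUs %d exceed cluster capacity %d", total, clusterGPUs)
	}
	return nil
}

func main() {
	plan := []teamQuota{
		{Name: "ml-platform", Hard: 32},
		{Name: "research", Hard: 16, MinGuaranteed: 8},
		{Name: "experiments", Hard: 8, Preemptible: true},
	}
	fmt.Println(validatePlan(plan, 48)) // <nil>: 32 + 8 + 0 = 40 guaranteed
	fmt.Println(validatePlan(plan, 32)) // guaranteed GPUs 40 exceed cluster capacity 32
}
```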

7.3 Monitoring Metrics

# prometheus-rules.yaml
groups:
  - name: elastic-gpu
    rules:
      # Preemption rate: preemptions per allocation over the last hour
      - record: gpu_preemption_rate
        expr: |
          sum(rate(elastic_gpu_preemptions_total[1h])) /
          sum(rate(elastic_gpu_allocations_total[1h]))

      # Quota utilization per namespace
      - record: gpu_quota_utilization
        expr: |
          sum(elastic_gpu_used) by (namespace) /
          sum(elastic_gpu_quota) by (namespace)

      # P99 scheduling wait time per queue
      - record: gpu_scheduling_wait_time_p99
        expr: |
          histogram_quantile(0.99,
            sum(rate(elastic_gpu_wait_seconds_bucket[5m])) by (le, queue))

      # Scale-up events: more than 10 scale-ups within an hour
      # (increase() counts events; rate() would be a per-second value)
      - alert: HighScaleUpRate
        expr: sum(increase(elastic_gpu_scale_up_total[1h])) > 10
        for: 30m
        labels:
          severity: warning

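Summary

This chapter covered the core techniques of elastic GPU scheduling:

  1. Priority scheduling: multi-level priority queues and fair-share scheduling
  2. Preemptive scheduling: multiple preemption strategies and recovery mechanisms
  3. Elastic scaling: metrics-driven autoscaling
  4. Gang scheduling: all-or-nothing scheduling for distributed training
  5. Quota management: multi-tenant resource isolation and borrowing

Elastic scheduling is key to using GPUs efficiently. With well-designed priorities and quotas, a cluster can meet the SLAs of critical workloads while maximizing overall utilization.

The next chapter starts the AI training platform part and covers the design and implementation of distributed training frameworks.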
