Training Job Scheduling

Overview

Scheduling large-scale AI training jobs is a complex systems-engineering problem. Unlike conventional microservice scheduling, training jobs are long-running and resource-intensive, and they require gang scheduling and elastic scaling. This chapter walks through the design and implementation of a training-job scheduling system, covering the core mechanisms: scheduling policies, queue management, fair sharing, and preemption with recovery.

1. Training Job Scheduling Architecture

1.1 Overall Architecture

┌──────────────────────────────────────────────────────────────────────────────┐
│                      Training Job Scheduling Architecture                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                               │
│  ┌─────────────────────────────────────────────────────────────────────────┐ │
│  │                           Job Submission Layer                          │ │
│  │                                                                         │ │
│  │  ┌───────────────┐  ┌───────────────┐  ┌───────────────┐               │ │
│  │  │   CLI/SDK     │  │   Web UI      │  │  Jupyter Hub  │               │ │
│  │  └───────┬───────┘  └───────┬───────┘  └───────┬───────┘               │ │
│  └──────────┼──────────────────┼──────────────────┼─────────────────────────┘ │
│             │                  │                  │                           │
│             ▼                  ▼                  ▼                           │
│  ┌─────────────────────────────────────────────────────────────────────────┐ │
│  │                         Job Controller Layer                            │ │
│  │                                                                         │ │
│  │  ┌─────────────────────────────────────────────────────────────────┐   │ │
│  │  │                    Training Job Controller                       │   │ │
│  │  │                                                                  │   │ │
│  │  │  ┌──────────────┐ ┌──────────────┐ ┌──────────────┐             │   │ │
│  │  │  │ PyTorchJob   │ │   TFJob      │ │   MPIJob     │             │   │ │
│  │  │  │ Controller   │ │ Controller   │ │ Controller   │             │   │ │
│  │  │  └──────────────┘ └──────────────┘ └──────────────┘             │   │ │
│  │  └─────────────────────────────────────────────────────────────────┘   │ │
│  └─────────────────────────────────────────────────────────────────────────┘ │
│                                      │                                        │
│                                      ▼                                        │
│  ┌─────────────────────────────────────────────────────────────────────────┐ │
│  │                         Scheduling Layer                                │ │
│  │                                                                         │ │
│  │  ┌────────────────┐  ┌────────────────┐  ┌────────────────┐            │ │
│  │  │  Queue Manager │  │ Gang Scheduler │  │Priority Manager│            │ │
│  │  └────────┬───────┘  └────────┬───────┘  └────────┬───────┘            │ │
│  │           │                   │                   │                     │ │
│  │           ▼                   ▼                   ▼                     │ │
│  │  ┌─────────────────────────────────────────────────────────────────┐   │ │
│  │  │                    Volcano/Kueue Scheduler                       │   │ │
│  │  │                                                                  │   │ │
│  │  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐            │   │ │
│  │  │  │ Fair     │ │ Capacity │ │ Binpack  │ │ Topology │            │   │ │
│  │  │  │ Share    │ │ Planning │ │ Strategy │ │ Aware    │            │   │ │
│  │  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘            │   │ │
│  │  └─────────────────────────────────────────────────────────────────┘   │ │
│  └─────────────────────────────────────────────────────────────────────────┘ │
│                                      │                                        │
│                                      ▼                                        │
│  ┌─────────────────────────────────────────────────────────────────────────┐ │
│  │                         Resource Layer                                  │ │
│  │                                                                         │ │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐    │ │
│  │  │  GPU Pool   │  │  CPU Pool   │  │ Memory Pool │  │ Storage Pool│    │ │
│  │  │             │  │             │  │             │  │             │    │ │
│  │  │ A100 x 64   │  │ 1024 cores  │  │   8 TiB     │  │  100 TiB    │    │ │
│  │  │ V100 x 32   │  │             │  │             │  │             │    │ │
│  │  └─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘    │ │
│  └─────────────────────────────────────────────────────────────────────────┘ │
│                                                                               │
└──────────────────────────────────────────────────────────────────────────────┘

1.2 Core Data Structures

// pkg/scheduler/types.go
package scheduler

import (
    "time"

    corev1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/resource"
)

// TrainingJob 训练任务定义
type TrainingJob struct {
    // 基本信息
    Name      string            `json:"name"`
    Namespace string            `json:"namespace"`
    UID       string            `json:"uid"`
    Labels    map[string]string `json:"labels"`

    // 任务类型
    JobType JobType `json:"jobType"`

    // 资源需求
    Resources ResourceRequirements `json:"resources"`

    // 调度配置
    SchedulingConfig SchedulingConfig `json:"schedulingConfig"`

    // 弹性配置
    ElasticConfig *ElasticConfig `json:"elasticConfig,omitempty"`

    // 状态
    Status TrainingJobStatus `json:"status"`

    // 时间戳
    CreationTime time.Time  `json:"creationTime"`
    StartTime    *time.Time `json:"startTime,omitempty"`
    EndTime      *time.Time `json:"endTime,omitempty"`
}

// JobType 任务类型
type JobType string

const (
    JobTypePyTorch    JobType = "PyTorchJob"
    JobTypeTensorFlow JobType = "TFJob"
    JobTypeMPI        JobType = "MPIJob"
    JobTypeXGBoost    JobType = "XGBoostJob"
    JobTypePaddlePaddle JobType = "PaddleJob"
)

// ResourceRequirements 资源需求
type ResourceRequirements struct {
    // 任务级别需求
    MinReplicas int32 `json:"minReplicas"`
    MaxReplicas int32 `json:"maxReplicas"`

    // 每个 replica 的需求
    ReplicaResources ReplicaResources `json:"replicaResources"`

    // GPU 类型要求
    GPUType string `json:"gpuType,omitempty"`

    // 拓扑约束
    TopologyConstraint string `json:"topologyConstraint,omitempty"`
}

// ReplicaResources 单个副本资源
type ReplicaResources struct {
    GPU    int32              `json:"gpu"`
    CPU    resource.Quantity  `json:"cpu"`
    Memory resource.Quantity  `json:"memory"`
}

// SchedulingConfig 调度配置
type SchedulingConfig struct {
    // 队列名
    Queue string `json:"queue"`

    // 优先级
    Priority int32 `json:"priority"`

    // 是否需要 Gang 调度
    GangScheduling bool `json:"gangScheduling"`

    // 最大等待时间
    MaxWaitTime time.Duration `json:"maxWaitTime"`

    // 抢占策略
    PreemptionPolicy PreemptionPolicy `json:"preemptionPolicy"`

    // 调度器名称
    SchedulerName string `json:"schedulerName"`
}

// PreemptionPolicy 抢占策略
type PreemptionPolicy string

const (
    PreemptionPolicyNever         PreemptionPolicy = "Never"
    PreemptionPolicyPreemptLower  PreemptionPolicy = "PreemptLower"
    PreemptionPolicyPreemptAlways PreemptionPolicy = "PreemptAlways"
)

// ElasticConfig 弹性配置
type ElasticConfig struct {
    // 是否启用弹性
    Enabled bool `json:"enabled"`

    // 最小 worker 数
    MinWorkers int32 `json:"minWorkers"`

    // 最大 worker 数
    MaxWorkers int32 `json:"maxWorkers"`

    // 伸缩策略
    ScalePolicy ScalePolicy `json:"scalePolicy"`
}

// ScalePolicy 伸缩策略
type ScalePolicy struct {
    // 扩容冷却时间
    ScaleUpCooldown time.Duration `json:"scaleUpCooldown"`

    // 缩容冷却时间
    ScaleDownCooldown time.Duration `json:"scaleDownCooldown"`

    // 扩容步长
    ScaleUpStep int32 `json:"scaleUpStep"`

    // 缩容步长
    ScaleDownStep int32 `json:"scaleDownStep"`
}

// TrainingJobStatus 任务状态
type TrainingJobStatus struct {
    // 当前阶段
    Phase JobPhase `json:"phase"`

    // 分配的资源
    AllocatedResources AllocatedResources `json:"allocatedResources"`

    // 副本状态
    ReplicaStatuses map[string]ReplicaStatus `json:"replicaStatuses"`

    // 条件
    Conditions []JobCondition `json:"conditions"`

    // 历史事件
    Events []JobEvent `json:"events"`
}

// JobPhase 任务阶段
type JobPhase string

const (
    JobPhasePending   JobPhase = "Pending"
    JobPhaseQueued    JobPhase = "Queued"
    JobPhaseRunning   JobPhase = "Running"
    JobPhaseSucceeded JobPhase = "Succeeded"
    JobPhaseFailed    JobPhase = "Failed"
    JobPhaseSuspended JobPhase = "Suspended"
)

// AllocatedResources 已分配资源
type AllocatedResources struct {
    GPUs     int32             `json:"gpus"`
    CPU      resource.Quantity `json:"cpu"`
    Memory   resource.Quantity `json:"memory"`
    Nodes    []string          `json:"nodes"`
}

// ReplicaStatus 副本状态
type ReplicaStatus struct {
    Running   int32 `json:"running"`
    Succeeded int32 `json:"succeeded"`
    Failed    int32 `json:"failed"`
    Pending   int32 `json:"pending"`
}

// JobCondition 任务条件
type JobCondition struct {
    Type               string             `json:"type"`
    Status             corev1.ConditionStatus `json:"status"`
    LastTransitionTime time.Time          `json:"lastTransitionTime"`
    Reason             string             `json:"reason"`
    Message            string             `json:"message"`
}

// JobEvent 任务事件
type JobEvent struct {
    Type      string    `json:"type"`
    Reason    string    `json:"reason"`
    Message   string    `json:"message"`
    Timestamp time.Time `json:"timestamp"`
}
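
A minimal usage sketch of these types, for orientation. The module path, job name, queue, and resource figures below are invented for illustration; only the field names come from the definitions above.

// example_job.go -- illustrative only, not part of the scheduler package
package main

import (
    "fmt"
    "time"

    "k8s.io/apimachinery/pkg/api/resource"

    scheduler "example.com/ai-platform/pkg/scheduler" // hypothetical module path
)

func main() {
    // A 4-replica PyTorch job, 8 GPUs per replica, gang-scheduled.
    job := &scheduler.TrainingJob{
        Name:      "llm-pretrain",
        Namespace: "ml-research",
        JobType:   scheduler.JobTypePyTorch,
        Resources: scheduler.ResourceRequirements{
            MinReplicas: 4,
            MaxReplicas: 4,
            ReplicaResources: scheduler.ReplicaResources{
                GPU:    8,
                CPU:    resource.MustParse("32"),
                Memory: resource.MustParse("256Gi"),
            },
            GPUType: "A100",
        },
        SchedulingConfig: scheduler.SchedulingConfig{
            Queue:            "training-queue",
            Priority:         100,
            GangScheduling:   true,
            MaxWaitTime:      30 * time.Minute,
            PreemptionPolicy: scheduler.PreemptionPolicyPreemptLower,
            SchedulerName:    "volcano",
        },
        CreationTime: time.Now(),
    }

    fmt.Printf("%s/%s requests %d GPUs in total\n",
        job.Namespace, job.Name,
        job.Resources.ReplicaResources.GPU*job.Resources.MinReplicas)
}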

2. Queue Management System

2.1 Multi-Level Queue Implementation

// pkg/scheduler/queue/manager.go
package queue

import (
    "fmt"
    "sort"
    "sync"
    "time"
)

// QueueSpec 队列规格
type QueueSpec struct {
    Name string `json:"name"`

    // 父队列(用于层级队列)
    Parent string `json:"parent,omitempty"`

    // 权重(用于公平调度)
    Weight int32 `json:"weight"`

    // 资源配额
    Quota ResourceQuota `json:"quota"`

    // 借用配置
    Borrowing *BorrowingSpec `json:"borrowing,omitempty"`

    // 抢占配置
    Preemption *PreemptionSpec `json:"preemption,omitempty"`

    // 准入控制
    AdmissionControl *AdmissionSpec `json:"admissionControl,omitempty"`
}

// ResourceQuota 资源配额
type ResourceQuota struct {
    // 硬限制
    Hard ResourceList `json:"hard"`

    // 软限制(可以借用超过)
    Soft ResourceList `json:"soft,omitempty"`

    // 保证资源(不可被借用)
    Guaranteed ResourceList `json:"guaranteed,omitempty"`
}

// ResourceList 资源列表
type ResourceList struct {
    GPU    int32  `json:"gpu"`
    CPU    string `json:"cpu"`
    Memory string `json:"memory"`
}

// BorrowingSpec 借用配置
type BorrowingSpec struct {
    // 是否允许借用
    Enabled bool `json:"enabled"`

    // 可借用的队列
    AllowedQueues []string `json:"allowedQueues,omitempty"`

    // 最大借用比例
    MaxBorrowRatio float64 `json:"maxBorrowRatio"`
}

// PreemptionSpec 抢占配置
type PreemptionSpec struct {
    // 是否允许被抢占
    AllowPreemption bool `json:"allowPreemption"`

    // 允许抢占的队列
    AllowPreemptFrom []string `json:"allowPreemptFrom,omitempty"`

    // 抢占优雅期
    GracePeriod time.Duration `json:"gracePeriod"`
}

// AdmissionSpec 准入控制配置
type AdmissionSpec struct {
    // 最大作业数
    MaxJobs int32 `json:"maxJobs,omitempty"`

    // 最大 GPU 数
    MaxGPUsPerJob int32 `json:"maxGPUsPerJob,omitempty"`

    // 允许的优先级范围
    PriorityRange *PriorityRange `json:"priorityRange,omitempty"`
}

// PriorityRange 优先级范围
type PriorityRange struct {
    Min int32 `json:"min"`
    Max int32 `json:"max"`
}

// Queue 队列实例
type Queue struct {
    Spec   QueueSpec
    Status QueueStatus

    // 子队列
    children []*Queue
    parent   *Queue

    // 作业列表
    jobs     []*TrainingJob
    jobIndex map[string]int

    mu sync.RWMutex
}

// QueueStatus 队列状态
type QueueStatus struct {
    // 已使用资源
    Used ResourceList `json:"used"`

    // 借入资源
    Borrowed ResourceList `json:"borrowed"`

    // 借出资源
    Lent ResourceList `json:"lent"`

    // 排队作业数
    PendingJobs int32 `json:"pendingJobs"`

    // 运行作业数
    RunningJobs int32 `json:"runningJobs"`
}

// QueueManager 队列管理器
type QueueManager struct {
    queues     map[string]*Queue
    rootQueues []*Queue
    mu         sync.RWMutex
}

// NewQueueManager 创建队列管理器
func NewQueueManager() *QueueManager {
    return &QueueManager{
        queues:     make(map[string]*Queue),
        rootQueues: make([]*Queue, 0),
    }
}

// CreateQueue 创建队列
func (m *QueueManager) CreateQueue(spec QueueSpec) error {
    m.mu.Lock()
    defer m.mu.Unlock()

    if _, exists := m.queues[spec.Name]; exists {
        return fmt.Errorf("queue %s already exists", spec.Name)
    }

    queue := &Queue{
        Spec:     spec,
        Status:   QueueStatus{},
        children: make([]*Queue, 0),
        jobs:     make([]*TrainingJob, 0),
        jobIndex: make(map[string]int),
    }

    // 处理层级关系
    if spec.Parent != "" {
        parent, ok := m.queues[spec.Parent]
        if !ok {
            return fmt.Errorf("parent queue %s not found", spec.Parent)
        }
        queue.parent = parent
        parent.children = append(parent.children, queue)
    } else {
        m.rootQueues = append(m.rootQueues, queue)
    }

    m.queues[spec.Name] = queue
    return nil
}

// EnqueueJob 将作业加入队列
func (m *QueueManager) EnqueueJob(job *TrainingJob) error {
    m.mu.Lock()
    defer m.mu.Unlock()

    queueName := job.SchedulingConfig.Queue
    queue, ok := m.queues[queueName]
    if !ok {
        return fmt.Errorf("queue %s not found", queueName)
    }

    // 检查准入控制
    if err := m.checkAdmission(queue, job); err != nil {
        return fmt.Errorf("admission check failed: %v", err)
    }

    // 加入队列
    queue.mu.Lock()
    defer queue.mu.Unlock()

    jobKey := job.Namespace + "/" + job.Name
    if _, exists := queue.jobIndex[jobKey]; exists {
        return fmt.Errorf("job %s already in queue", jobKey)
    }

    queue.jobs = append(queue.jobs, job)
    queue.jobIndex[jobKey] = len(queue.jobs) - 1
    queue.Status.PendingJobs++

    // 按优先级排序
    m.sortJobsByPriority(queue)

    return nil
}

// DequeueJob 从队列取出作业
func (m *QueueManager) DequeueJob(queueName string) *TrainingJob {
    m.mu.Lock()
    defer m.mu.Unlock()

    queue, ok := m.queues[queueName]
    if !ok {
        return nil
    }

    queue.mu.Lock()
    defer queue.mu.Unlock()

    if len(queue.jobs) == 0 {
        return nil
    }

    // 取出最高优先级的作业
    job := queue.jobs[0]
    queue.jobs = queue.jobs[1:]

    // 更新索引
    delete(queue.jobIndex, job.Namespace+"/"+job.Name)
    for i, j := range queue.jobs {
        queue.jobIndex[j.Namespace+"/"+j.Name] = i
    }

    queue.Status.PendingJobs--
    return job
}

// checkAdmission 检查准入控制
func (m *QueueManager) checkAdmission(queue *Queue, job *TrainingJob) error {
    if queue.Spec.AdmissionControl == nil {
        return nil
    }

    ac := queue.Spec.AdmissionControl

    // 检查最大作业数
    if ac.MaxJobs > 0 {
        totalJobs := queue.Status.PendingJobs + queue.Status.RunningJobs
        if totalJobs >= ac.MaxJobs {
            return fmt.Errorf("queue job limit reached: %d", ac.MaxJobs)
        }
    }

    // 检查单作业 GPU 限制
    if ac.MaxGPUsPerJob > 0 {
        requestedGPUs := job.Resources.ReplicaResources.GPU * job.Resources.MaxReplicas
        if requestedGPUs > ac.MaxGPUsPerJob {
            return fmt.Errorf("job GPU request %d exceeds limit %d",
                requestedGPUs, ac.MaxGPUsPerJob)
        }
    }

    // 检查优先级范围
    if ac.PriorityRange != nil {
        if job.SchedulingConfig.Priority < ac.PriorityRange.Min ||
            job.SchedulingConfig.Priority > ac.PriorityRange.Max {
            return fmt.Errorf("job priority %d out of allowed range [%d, %d]",
                job.SchedulingConfig.Priority,
                ac.PriorityRange.Min, ac.PriorityRange.Max)
        }
    }

    return nil
}

// sortJobsByPriority 按优先级排序
func (m *QueueManager) sortJobsByPriority(queue *Queue) {
    // 使用稳定排序保持相同优先级的 FIFO 顺序
    sort.SliceStable(queue.jobs, func(i, j int) bool {
        return queue.jobs[i].SchedulingConfig.Priority >
            queue.jobs[j].SchedulingConfig.Priority
    })

    // 更新索引
    for i, job := range queue.jobs {
        queue.jobIndex[job.Namespace+"/"+job.Name] = i
    }
}

// SelectNextJob 选择下一个要调度的作业(公平调度)
func (m *QueueManager) SelectNextJob(availableResources ResourceList) *TrainingJob {
    m.mu.RLock()
    defer m.mu.RUnlock()

    // 使用加权公平队列算法,同时记录作业实际所在的队列,便于后续移除
    var selectedJob *TrainingJob
    var selectedQueue *Queue
    bestScore := float64(-1)

    for _, queue := range m.rootQueues {
        job, owner, score := m.selectFromQueueTree(queue, availableResources)
        if job != nil && score > bestScore {
            bestScore = score
            selectedJob = job
            selectedQueue = owner
        }
    }

    if selectedJob != nil {
        // 从作业所在的队列(可能是某个子队列)中移除
        m.removeJobFromQueue(selectedQueue, selectedJob)
    }

    return selectedJob
}

// selectFromQueueTree 从队列树中选择作业,返回作业、其所在队列与公平分数
func (m *QueueManager) selectFromQueueTree(queue *Queue, available ResourceList) (*TrainingJob, *Queue, float64) {
    queue.mu.RLock()
    defer queue.mu.RUnlock()

    // 递归检查子队列
    var bestJob *TrainingJob
    var bestQueue *Queue
    bestScore := float64(-1)

    for _, child := range queue.children {
        job, owner, score := m.selectFromQueueTree(child, available)
        if job != nil && score > bestScore {
            bestScore = score
            bestJob = job
            bestQueue = owner
        }
    }

    // 检查当前队列的作业
    for _, job := range queue.jobs {
        if m.canSchedule(job, available) {
            score := m.calculateFairScore(queue, job)
            if score > bestScore {
                bestScore = score
                bestJob = job
                bestQueue = queue
            }
        }
    }

    return bestJob, bestQueue, bestScore
}

// calculateFairScore 计算公平调度分数
func (m *QueueManager) calculateFairScore(queue *Queue, job *TrainingJob) float64 {
    // 公平份额 = 权重 / (已使用 + 1)
    // 这样使用越少的队列分数越高
    usedGPUs := float64(queue.Status.Used.GPU)
    weight := float64(queue.Spec.Weight)

    fairShare := weight / (usedGPUs + 1)

    // 优先级加成
    priorityBonus := float64(job.SchedulingConfig.Priority) / 1000000.0

    // 等待时间加成(防止饥饿)
    waitTime := time.Since(job.CreationTime).Minutes()
    waitBonus := waitTime / 60.0 // 每小时加 1 分

    return fairShare + priorityBonus + waitBonus
}

// canSchedule 检查是否可以调度
func (m *QueueManager) canSchedule(job *TrainingJob, available ResourceList) bool {
    requiredGPUs := job.Resources.ReplicaResources.GPU * job.Resources.MinReplicas
    return available.GPU >= requiredGPUs
}

// removeJobFromQueue 从队列移除作业
func (m *QueueManager) removeJobFromQueue(queue *Queue, job *TrainingJob) {
    queue.mu.Lock()
    defer queue.mu.Unlock()

    jobKey := job.Namespace + "/" + job.Name
    idx, ok := queue.jobIndex[jobKey]
    if !ok {
        return
    }

    queue.jobs = append(queue.jobs[:idx], queue.jobs[idx+1:]...)
    delete(queue.jobIndex, jobKey)

    // 更新索引
    for i := idx; i < len(queue.jobs); i++ {
        queue.jobIndex[queue.jobs[i].Namespace+"/"+queue.jobs[i].Name] = i
    }

    queue.Status.PendingJobs--
}
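
A short usage sketch of the queue manager above, assuming it lives in the same package so the TrainingJob types from section 1.2 resolve; the queue names, weights, and resource figures are illustrative.

// Minimal usage sketch (illustrative, not production code): build a small
// queue hierarchy and enqueue one job.
func exampleQueueManagerUsage() error {
    m := NewQueueManager()

    // Root queue plus two child queues; weights drive fair sharing.
    if err := m.CreateQueue(QueueSpec{Name: "root", Weight: 1}); err != nil {
        return err
    }
    if err := m.CreateQueue(QueueSpec{Name: "research", Parent: "root", Weight: 3}); err != nil {
        return err
    }
    if err := m.CreateQueue(QueueSpec{Name: "production", Parent: "root", Weight: 7}); err != nil {
        return err
    }

    // Enqueue a job that needs 2 replicas x 8 GPUs.
    job := &TrainingJob{
        Name:      "bert-finetune",
        Namespace: "ml-research",
        Resources: ResourceRequirements{
            MinReplicas:      2,
            MaxReplicas:      2,
            ReplicaResources: ReplicaResources{GPU: 8},
        },
        SchedulingConfig: SchedulingConfig{Queue: "research", Priority: 50},
        CreationTime:     time.Now(),
    }
    if err := m.EnqueueJob(job); err != nil {
        return err
    }

    // With 16 idle GPUs, pick the next job by fair-share score.
    next := m.SelectNextJob(ResourceList{GPU: 16})
    fmt.Printf("next job selected: %v\n", next != nil)
    return nil
}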

2.2 Kueue Integration

# kueue-resources.yaml
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
  name: gpu-a100
spec:
  nodeLabels:
    nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
  tolerations:
    - key: "nvidia.com/gpu"
      operator: "Exists"
      effect: "NoSchedule"
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
  name: gpu-v100
spec:
  nodeLabels:
    nvidia.com/gpu.product: "Tesla-V100-SXM2-32GB"
  tolerations:
    - key: "nvidia.com/gpu"
      operator: "Exists"
      effect: "NoSchedule"
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: ClusterQueue
metadata:
  name: training-cluster-queue
spec:
  namespaceSelector: {}
  resourceGroups:
    - coveredResources: ["cpu", "memory", "nvidia.com/gpu"]
      flavors:
        - name: gpu-a100
          resources:
            - name: "cpu"
              nominalQuota: 256
            - name: "memory"
              nominalQuota: 2Ti
            - name: "nvidia.com/gpu"
              nominalQuota: 64
              borrowingLimit: 16
        - name: gpu-v100
          resources:
            - name: "cpu"
              nominalQuota: 128
            - name: "memory"
              nominalQuota: 1Ti
            - name: "nvidia.com/gpu"
              nominalQuota: 32
  cohort: ai-platform
  preemption:
    reclaimWithinCohort: Any
    withinClusterQueue: LowerPriority
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
  name: ml-research-queue
  namespace: ml-research
spec:
  clusterQueue: training-cluster-queue
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
  name: ml-production-queue
  namespace: ml-production
spec:
  clusterQueue: training-cluster-queue

3. Gang Scheduling Implementation

3.1 Gang Scheduler

// pkg/scheduler/gang/scheduler.go
package gang

import (
    "context"
    "fmt"
    "sort"
    "sync"
    "time"

    corev1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/resource"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// GangScheduler Gang 调度器
type GangScheduler struct {
    mu sync.RWMutex

    // 待调度的 Gang
    pendingGangs map[string]*PodGroup

    // 资源管理
    resourceManager *ResourceManager

    // 配置
    config *GangConfig
}

// PodGroup Pod 组(Gang)
type PodGroup struct {
    // 基本信息
    Name      string
    Namespace string
    UID       string

    // 成员 Pod
    Members []*PodMember

    // 最小成员数(用于弹性 Gang)
    MinMember int32

    // 调度策略
    SchedulePolicy SchedulePolicy

    // 创建时间
    CreationTime time.Time

    // 调度超时
    ScheduleTimeout time.Duration

    // 状态
    Status PodGroupStatus
}

// PodMember Gang 成员
type PodMember struct {
    Name      string
    Namespace string
    Pod       *corev1.Pod
    Resources corev1.ResourceRequirements
    NodeName  string // 调度后填充
    Status    MemberStatus
}

// MemberStatus 成员状态
type MemberStatus string

const (
    MemberPending   MemberStatus = "Pending"
    MemberScheduled MemberStatus = "Scheduled"
    MemberRunning   MemberStatus = "Running"
    MemberFailed    MemberStatus = "Failed"
)

// SchedulePolicy 调度策略
type SchedulePolicy struct {
    // 是否严格 Gang(必须所有成员同时调度)
    Strict bool

    // 拓扑约束
    TopologyKey string

    // 亲和性
    Affinity *corev1.Affinity

    // 容忍
    Tolerations []corev1.Toleration
}

// PodGroupStatus PodGroup 状态
type PodGroupStatus struct {
    Phase      PodGroupPhase
    Scheduled  int32
    Running    int32
    Succeeded  int32
    Failed     int32
    Conditions []PodGroupCondition
}

// PodGroupPhase PodGroup 阶段
type PodGroupPhase string

const (
    PodGroupPending   PodGroupPhase = "Pending"
    PodGroupScheduled PodGroupPhase = "Scheduled"
    PodGroupRunning   PodGroupPhase = "Running"
    PodGroupSucceeded PodGroupPhase = "Succeeded"
    PodGroupFailed    PodGroupPhase = "Failed"
    PodGroupTimeout   PodGroupPhase = "Timeout"
)

// PodGroupCondition 条件
type PodGroupCondition struct {
    Type               string
    Status             corev1.ConditionStatus
    LastTransitionTime metav1.Time
    Reason             string
    Message            string
}

// GangConfig 配置
type GangConfig struct {
    // 默认超时时间
    DefaultTimeout time.Duration

    // 调度间隔
    ScheduleInterval time.Duration

    // 是否允许部分调度
    AllowPartialSchedule bool
}

// NewGangScheduler 创建 Gang 调度器
func NewGangScheduler(rm *ResourceManager, config *GangConfig) *GangScheduler {
    if config == nil {
        config = &GangConfig{
            DefaultTimeout:       10 * time.Minute,
            ScheduleInterval:     5 * time.Second,
            AllowPartialSchedule: false,
        }
    }

    return &GangScheduler{
        pendingGangs:    make(map[string]*PodGroup),
        resourceManager: rm,
        config:          config,
    }
}

// AddPodGroup 添加 PodGroup
func (s *GangScheduler) AddPodGroup(pg *PodGroup) error {
    s.mu.Lock()
    defer s.mu.Unlock()

    key := pg.Namespace + "/" + pg.Name
    if _, exists := s.pendingGangs[key]; exists {
        return fmt.Errorf("pod group %s already exists", key)
    }

    if pg.ScheduleTimeout == 0 {
        pg.ScheduleTimeout = s.config.DefaultTimeout
    }
    pg.CreationTime = time.Now()
    pg.Status.Phase = PodGroupPending

    s.pendingGangs[key] = pg
    return nil
}

// Run 运行调度循环
func (s *GangScheduler) Run(ctx context.Context) {
    ticker := time.NewTicker(s.config.ScheduleInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            s.scheduleAll(ctx)
        }
    }
}

// scheduleAll 调度所有待处理的 Gang
func (s *GangScheduler) scheduleAll(ctx context.Context) {
    s.mu.Lock()
    gangs := make([]*PodGroup, 0, len(s.pendingGangs))
    for _, pg := range s.pendingGangs {
        gangs = append(gangs, pg)
    }
    s.mu.Unlock()

    // 按创建时间排序(FIFO),先提交的 Gang 先尝试调度
    sort.Slice(gangs, func(i, j int) bool {
        return gangs[i].CreationTime.Before(gangs[j].CreationTime)
    })

    for _, pg := range gangs {
        // 检查超时
        if time.Since(pg.CreationTime) > pg.ScheduleTimeout {
            s.handleTimeout(pg)
            continue
        }

        // 尝试调度
        if s.trySchedule(ctx, pg) {
            s.markScheduled(pg)
        }
    }
}

// trySchedule 尝试调度 PodGroup
func (s *GangScheduler) trySchedule(ctx context.Context, pg *PodGroup) bool {
    // 收集资源需求
    totalRequirements := s.calculateTotalRequirements(pg)

    // 检查是否有足够资源
    allocation, err := s.resourceManager.TryAllocate(totalRequirements)
    if err != nil {
        return false
    }

    // 分配 Pod 到节点
    if !s.assignPodsToNodes(pg, allocation) {
        s.resourceManager.ReleaseAllocation(allocation)
        return false
    }

    // 提交分配
    return s.commitAllocation(ctx, pg, allocation)
}

// calculateTotalRequirements 计算总资源需求
func (s *GangScheduler) calculateTotalRequirements(pg *PodGroup) *ResourceRequirements {
    total := &ResourceRequirements{
        GPU:    0,
        CPU:    resource.Quantity{},
        Memory: resource.Quantity{},
    }

    for _, member := range pg.Members {
        if member.Status == MemberPending {
            if gpu, ok := member.Resources.Limits["nvidia.com/gpu"]; ok {
                total.GPU += int32(gpu.Value())
            }
            if cpu, ok := member.Resources.Requests[corev1.ResourceCPU]; ok {
                total.CPU.Add(cpu)
            }
            if mem, ok := member.Resources.Requests[corev1.ResourceMemory]; ok {
                total.Memory.Add(mem)
            }
        }
    }

    return total
}

// assignPodsToNodes 分配 Pod 到节点
func (s *GangScheduler) assignPodsToNodes(pg *PodGroup, allocation *Allocation) bool {
    // 获取可用节点
    nodes := allocation.GetAllocatedNodes()

    // 使用拓扑感知分配
    if pg.SchedulePolicy.TopologyKey != "" {
        return s.assignWithTopology(pg, nodes)
    }

    // 简单分配:贪心算法
    nodeIndex := 0
    for _, member := range pg.Members {
        if member.Status != MemberPending {
            continue
        }

        if nodeIndex >= len(nodes) {
            return false
        }

        // 检查节点是否有足够资源
        node := nodes[nodeIndex]
        if s.canFitOnNode(member, node) {
            member.NodeName = node.Name
            nodeIndex++
        } else {
            // 尝试下一个节点
            found := false
            for i := nodeIndex + 1; i < len(nodes); i++ {
                if s.canFitOnNode(member, nodes[i]) {
                    member.NodeName = nodes[i].Name
                    found = true
                    break
                }
            }
            if !found {
                return false
            }
        }
    }

    return true
}

// assignWithTopology 拓扑感知分配
func (s *GangScheduler) assignWithTopology(pg *PodGroup, nodes []*NodeInfo) bool {
    topologyKey := pg.SchedulePolicy.TopologyKey

    // 按拓扑域分组节点
    topologyGroups := make(map[string][]*NodeInfo)
    for _, node := range nodes {
        domain := node.Labels[topologyKey]
        topologyGroups[domain] = append(topologyGroups[domain], node)
    }

    // 尝试在同一拓扑域内分配所有 Pod
    for _, domainNodes := range topologyGroups {
        if s.tryAssignToDomain(pg, domainNodes) {
            return true
        }
    }

    // 如果单个域不够,尝试跨域分配
    if !pg.SchedulePolicy.Strict {
        return s.assignAcrossDomains(pg, nodes)
    }

    return false
}

// tryAssignToDomain 尝试在单个拓扑域分配
func (s *GangScheduler) tryAssignToDomain(pg *PodGroup, nodes []*NodeInfo) bool {
    // 检查域内资源是否足够
    totalCapacity := s.calculateDomainCapacity(nodes)
    requirements := s.calculateTotalRequirements(pg)

    if totalCapacity.GPU < requirements.GPU {
        return false
    }

    // 分配
    nodeIndex := 0
    for _, member := range pg.Members {
        if member.Status != MemberPending {
            continue
        }

        for nodeIndex < len(nodes) {
            if s.canFitOnNode(member, nodes[nodeIndex]) {
                member.NodeName = nodes[nodeIndex].Name
                break
            }
            nodeIndex++
        }

        if nodeIndex >= len(nodes) {
            // 回滚
            for _, m := range pg.Members {
                m.NodeName = ""
            }
            return false
        }
    }

    return true
}

// commitAllocation 提交分配
func (s *GangScheduler) commitAllocation(ctx context.Context, pg *PodGroup,
    allocation *Allocation) bool {

    // 创建绑定
    for _, member := range pg.Members {
        if member.NodeName == "" {
            continue
        }

        binding := &corev1.Binding{
            ObjectMeta: metav1.ObjectMeta{
                Name:      member.Name,
                Namespace: member.Namespace,
            },
            Target: corev1.ObjectReference{
                Kind: "Node",
                Name: member.NodeName,
            },
        }

        // 执行绑定
        if err := s.bindPod(ctx, binding); err != nil {
            // 绑定失败,回滚
            s.resourceManager.ReleaseAllocation(allocation)
            return false
        }

        member.Status = MemberScheduled
    }

    return true
}

// handleTimeout 处理超时
func (s *GangScheduler) handleTimeout(pg *PodGroup) {
    s.mu.Lock()
    defer s.mu.Unlock()

    pg.Status.Phase = PodGroupTimeout
    pg.Status.Conditions = append(pg.Status.Conditions, PodGroupCondition{
        Type:               "SchedulingTimeout",
        Status:             corev1.ConditionTrue,
        LastTransitionTime: metav1.Now(),
        Reason:             "Timeout",
        Message:            fmt.Sprintf("Failed to schedule within %v", pg.ScheduleTimeout),
    })

    // 从待处理列表移除
    delete(s.pendingGangs, pg.Namespace+"/"+pg.Name)
}

// markScheduled 标记已调度
func (s *GangScheduler) markScheduled(pg *PodGroup) {
    s.mu.Lock()
    defer s.mu.Unlock()

    pg.Status.Phase = PodGroupScheduled
    pg.Status.Scheduled = int32(len(pg.Members))
    pg.Status.Conditions = append(pg.Status.Conditions, PodGroupCondition{
        Type:               "Scheduled",
        Status:             corev1.ConditionTrue,
        LastTransitionTime: metav1.Now(),
        Reason:             "AllMembersScheduled",
        Message:            "All pod group members have been scheduled",
    })

    delete(s.pendingGangs, pg.Namespace+"/"+pg.Name)
}

// Helper functions
func (s *GangScheduler) canFitOnNode(member *PodMember, node *NodeInfo) bool {
    // 检查 GPU
    if gpu, ok := member.Resources.Limits["nvidia.com/gpu"]; ok {
        if node.AvailableGPU < int32(gpu.Value()) {
            return false
        }
    }
    // 检查 CPU 和内存...
    return true
}

func (s *GangScheduler) calculateDomainCapacity(nodes []*NodeInfo) *ResourceCapacity {
    capacity := &ResourceCapacity{}
    for _, node := range nodes {
        capacity.GPU += node.AvailableGPU
    }
    return capacity
}

func (s *GangScheduler) bindPod(ctx context.Context, binding *corev1.Binding) error {
    // 实际调用 Kubernetes API
    return nil
}
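
A brief wiring sketch for the gang scheduler above, assuming a working ResourceManager (only referenced, not defined, in this chapter) and pods supplied by the job controller; the names and topology key are illustrative.

// Illustrative wiring of the gang scheduler (a sketch, not production code).
func runGangSchedulerExample(ctx context.Context, rm *ResourceManager, pods []*corev1.Pod) error {
    gs := NewGangScheduler(rm, &GangConfig{
        DefaultTimeout:       5 * time.Minute,
        ScheduleInterval:     2 * time.Second,
        AllowPartialSchedule: false,
    })

    // Wrap every pod of the job into one PodGroup so they are admitted
    // all-or-nothing.
    members := make([]*PodMember, 0, len(pods))
    for _, p := range pods {
        members = append(members, &PodMember{
            Name:      p.Name,
            Namespace: p.Namespace,
            Pod:       p,
            Resources: p.Spec.Containers[0].Resources,
            Status:    MemberPending,
        })
    }

    pg := &PodGroup{
        Name:      "distributed-training",
        Namespace: "ml-workloads",
        Members:   members,
        MinMember: int32(len(members)),
        SchedulePolicy: SchedulePolicy{
            Strict:      true,
            TopologyKey: "topology.kubernetes.io/zone",
        },
    }
    if err := gs.AddPodGroup(pg); err != nil {
        return err
    }

    // The scheduling loop keeps retrying until the gang fits or times out.
    go gs.Run(ctx)
    return nil
}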

3.2 Volcano Integration

# volcano-job.yaml
apiVersion: batch.volcano.sh/v1alpha1
kind: Job
metadata:
  name: distributed-training
  namespace: ml-workloads
spec:
  schedulerName: volcano
  minAvailable: 4  # gang scheduling: at least 4 pods must be scheduled together
  queue: training-queue
  policies:
    - event: PodEvicted
      action: RestartJob
    - event: PodFailed
      action: RestartJob
  plugins:
    ssh: []
    svc: []
    env: []
  maxRetry: 3
  ttlSecondsAfterFinished: 3600
  tasks:
    - name: master
      replicas: 1
      template:
        spec:
          containers:
            - name: pytorch
              image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
              command:
                - /bin/bash
                - -c
                - |
                  python -m torch.distributed.run \
                    --nproc_per_node=8 \
                    --nnodes=$VC_WORKER_NUM \
                    --node_rank=$VC_TASK_INDEX \
                    --master_addr=$VC_MASTER_ADDR \
                    --master_port=29500 \
                    train.py
              resources:
                limits:
                  nvidia.com/gpu: 8
                requests:
                  nvidia.com/gpu: 8
                  cpu: "32"
                  memory: 256Gi
              volumeMounts:
                - name: shm
                  mountPath: /dev/shm
          volumes:
            - name: shm
              emptyDir:
                medium: Memory
                sizeLimit: 64Gi
          restartPolicy: OnFailure

    - name: worker
      replicas: 3
      template:
        spec:
          containers:
            - name: pytorch
              image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
              command:
                - /bin/bash
                - -c
                - |
                  python -m torch.distributed.run \
                    --nproc_per_node=8 \
                    --nnodes=$VC_WORKER_NUM \
                    --node_rank=$VC_TASK_INDEX \
                    --master_addr=$VC_MASTER_ADDR \
                    --master_port=29500 \
                    train.py
              resources:
                limits:
                  nvidia.com/gpu: 8
                requests:
                  nvidia.com/gpu: 8
                  cpu: "32"
                  memory: 256Gi
              volumeMounts:
                - name: shm
                  mountPath: /dev/shm
          volumes:
            - name: shm
              emptyDir:
                medium: Memory
                sizeLimit: 64Gi
          restartPolicy: OnFailure
---
apiVersion: scheduling.volcano.sh/v1beta1
kind: Queue
metadata:
  name: training-queue
spec:
  weight: 10
  capability:
    cpu: "512"
    memory: "2Ti"
    nvidia.com/gpu: "64"
  reclaimable: true
  guarantee:
    resource:
      cpu: "128"
      memory: "512Gi"
      nvidia.com/gpu: "16"

4. Fair Scheduling Algorithms

4.1 Dominant Resource Fairness (DRF)

// pkg/scheduler/fairshare/drf.go
package fairshare

import (
    "sync"
)

// DRFScheduler Dominant Resource Fairness 调度器
type DRFScheduler struct {
    mu sync.RWMutex

    // 队列
    queues map[string]*DRFQueue

    // 总资源
    totalResources Resources

    // 已分配资源
    allocatedResources Resources
}

// DRFQueue DRF 队列
type DRFQueue struct {
    Name   string
    Weight float64

    // 已分配资源
    Allocated Resources

    // 主导份额(最大资源占用比)
    DominantShare float64

    // 等待的作业
    PendingJobs []*TrainingJob
}

// Resources 资源
type Resources struct {
    GPU    float64
    CPU    float64
    Memory float64
}

// NewDRFScheduler 创建 DRF 调度器
func NewDRFScheduler(totalResources Resources) *DRFScheduler {
    return &DRFScheduler{
        queues:         make(map[string]*DRFQueue),
        totalResources: totalResources,
    }
}

// AddQueue 添加队列
func (s *DRFScheduler) AddQueue(name string, weight float64) {
    s.mu.Lock()
    defer s.mu.Unlock()

    s.queues[name] = &DRFQueue{
        Name:        name,
        Weight:      weight,
        PendingJobs: make([]*TrainingJob, 0),
    }
}

// SubmitJob 提交作业
func (s *DRFScheduler) SubmitJob(queueName string, job *TrainingJob) {
    s.mu.Lock()
    defer s.mu.Unlock()

    queue, ok := s.queues[queueName]
    if !ok {
        return
    }

    queue.PendingJobs = append(queue.PendingJobs, job)
}

// Schedule 调度一个作业
func (s *DRFScheduler) Schedule() *ScheduleResult {
    s.mu.Lock()
    defer s.mu.Unlock()

    // 计算每个队列的主导份额
    for _, queue := range s.queues {
        queue.DominantShare = s.calculateDominantShare(queue)
    }

    // 选择主导份额最小的队列(加权后)
    var selectedQueue *DRFQueue
    minWeightedShare := float64(1e9)

    for _, queue := range s.queues {
        if len(queue.PendingJobs) == 0 {
            continue
        }

        weightedShare := queue.DominantShare / queue.Weight
        if weightedShare < minWeightedShare {
            minWeightedShare = weightedShare
            selectedQueue = queue
        }
    }

    if selectedQueue == nil {
        return nil
    }

    // 从选中的队列取出作业
    job := selectedQueue.PendingJobs[0]

    // 检查资源是否足够
    jobResources := s.getJobResources(job)
    if !s.hasEnoughResources(jobResources) {
        return nil
    }

    // 分配资源
    selectedQueue.PendingJobs = selectedQueue.PendingJobs[1:]
    selectedQueue.Allocated = s.addResources(selectedQueue.Allocated, jobResources)
    s.allocatedResources = s.addResources(s.allocatedResources, jobResources)

    // 更新主导份额
    selectedQueue.DominantShare = s.calculateDominantShare(selectedQueue)

    return &ScheduleResult{
        Job:       job,
        Queue:     selectedQueue.Name,
        Resources: jobResources,
    }
}

// calculateDominantShare 计算主导份额
func (s *DRFScheduler) calculateDominantShare(queue *DRFQueue) float64 {
    if s.totalResources.GPU == 0 && s.totalResources.CPU == 0 &&
       s.totalResources.Memory == 0 {
        return 0
    }

    // 计算每种资源的份额
    gpuShare := float64(0)
    if s.totalResources.GPU > 0 {
        gpuShare = queue.Allocated.GPU / s.totalResources.GPU
    }

    cpuShare := float64(0)
    if s.totalResources.CPU > 0 {
        cpuShare = queue.Allocated.CPU / s.totalResources.CPU
    }

    memShare := float64(0)
    if s.totalResources.Memory > 0 {
        memShare = queue.Allocated.Memory / s.totalResources.Memory
    }

    // 返回最大份额(主导资源)
    return max(gpuShare, max(cpuShare, memShare))
}

// hasEnoughResources 检查是否有足够资源
func (s *DRFScheduler) hasEnoughResources(required Resources) bool {
    available := s.subtractResources(s.totalResources, s.allocatedResources)
    return available.GPU >= required.GPU &&
        available.CPU >= required.CPU &&
        available.Memory >= required.Memory
}

// getJobResources 获取作业资源需求
func (s *DRFScheduler) getJobResources(job *TrainingJob) Resources {
    replicas := float64(job.Resources.MinReplicas)
    return Resources{
        GPU:    float64(job.Resources.ReplicaResources.GPU) * replicas,
        CPU:    job.Resources.ReplicaResources.CPU.AsApproximateFloat64() * replicas,
        Memory: job.Resources.ReplicaResources.Memory.AsApproximateFloat64() * replicas,
    }
}

// addResources 资源相加
func (s *DRFScheduler) addResources(a, b Resources) Resources {
    return Resources{
        GPU:    a.GPU + b.GPU,
        CPU:    a.CPU + b.CPU,
        Memory: a.Memory + b.Memory,
    }
}

// subtractResources 资源相减
func (s *DRFScheduler) subtractResources(a, b Resources) Resources {
    return Resources{
        GPU:    a.GPU - b.GPU,
        CPU:    a.CPU - b.CPU,
        Memory: a.Memory - b.Memory,
    }
}

// ScheduleResult 调度结果
type ScheduleResult struct {
    Job       *TrainingJob
    Queue     string
    Resources Resources
}

func max(a, b float64) float64 {
    if a > b {
        return a
    }
    return b
}
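
A small walkthrough of how the DRF scheduler above behaves with two queues, assuming the TrainingJob types from section 1.2 and fmt are available in this package; the cluster size and job shapes are invented.

// Illustrative DRF walkthrough (a sketch, not production code).
func exampleDRF() {
    // 64 GPUs, 1024 CPU cores, 8192 GiB of memory in total.
    drf := NewDRFScheduler(Resources{GPU: 64, CPU: 1024, Memory: 8192})
    drf.AddQueue("research", 1.0)
    drf.AddQueue("production", 2.0)

    // Two pending jobs: 2x8 GPUs for research, 4x1 GPUs for production.
    drf.SubmitJob("research", &TrainingJob{Resources: ResourceRequirements{
        MinReplicas: 2, ReplicaResources: ReplicaResources{GPU: 8},
    }})
    drf.SubmitJob("production", &TrainingJob{Resources: ResourceRequirements{
        MinReplicas: 4, ReplicaResources: ReplicaResources{GPU: 1},
    }})

    // First pick: both dominant shares are 0, so iteration order decides.
    // If research goes first, its dominant share becomes 16/64 = 0.25
    // (0.25/1.0 after weighting) while production stays at 0, so production
    // is favored on the next round -- the essence of weighted DRF.
    for r := drf.Schedule(); r != nil; r = drf.Schedule() {
        fmt.Printf("queue %s scheduled a job using %.0f GPUs\n", r.Queue, r.Resources.GPU)
    }
}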

4.2 Hierarchical Fair Scheduling

// pkg/scheduler/fairshare/hierarchical.go
package fairshare

import (
    "container/heap"
    "fmt"
)

// HierarchicalFairScheduler 层级公平调度器
type HierarchicalFairScheduler struct {
    // 根队列
    root *HierarchicalQueue

    // 所有队列映射
    allQueues map[string]*HierarchicalQueue

    // 总资源
    totalResources Resources
}

// HierarchicalQueue 层级队列
type HierarchicalQueue struct {
    Name     string
    Parent   *HierarchicalQueue
    Children []*HierarchicalQueue

    // 权重
    Weight float64

    // 最小保证
    MinShare Resources

    // 最大限制
    MaxShare Resources

    // 已分配
    Allocated Resources

    // 公平份额
    FairShare Resources

    // 作业
    Jobs *JobHeap
}

// JobHeap 作业优先级堆
type JobHeap []*TrainingJob

func (h JobHeap) Len() int           { return len(h) }
func (h JobHeap) Less(i, j int) bool {
    return h[i].SchedulingConfig.Priority > h[j].SchedulingConfig.Priority
}
func (h JobHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *JobHeap) Push(x interface{}) { *h = append(*h, x.(*TrainingJob)) }
func (h *JobHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}

// NewHierarchicalFairScheduler 创建层级公平调度器
func NewHierarchicalFairScheduler(totalResources Resources) *HierarchicalFairScheduler {
    root := &HierarchicalQueue{
        Name:      "root",
        Weight:    1.0,
        MaxShare:  totalResources,
        Jobs:      &JobHeap{},
    }
    heap.Init(root.Jobs)

    return &HierarchicalFairScheduler{
        root:           root,
        allQueues:      map[string]*HierarchicalQueue{"root": root},
        totalResources: totalResources,
    }
}

// AddQueue 添加队列
func (s *HierarchicalFairScheduler) AddQueue(name, parentName string,
    weight float64, minShare, maxShare Resources) error {

    parent, ok := s.allQueues[parentName]
    if !ok {
        return fmt.Errorf("parent queue %s not found", parentName)
    }

    queue := &HierarchicalQueue{
        Name:     name,
        Parent:   parent,
        Weight:   weight,
        MinShare: minShare,
        MaxShare: maxShare,
        Jobs:     &JobHeap{},
    }
    heap.Init(queue.Jobs)

    parent.Children = append(parent.Children, queue)
    s.allQueues[name] = queue

    return nil
}

// CalculateFairShares 计算公平份额
func (s *HierarchicalFairScheduler) CalculateFairShares() {
    // 自顶向下计算公平份额
    s.calculateFairShareRecursive(s.root, s.totalResources)
}

// calculateFairShareRecursive 递归计算公平份额
func (s *HierarchicalFairScheduler) calculateFairShareRecursive(
    queue *HierarchicalQueue, available Resources) {

    if len(queue.Children) == 0 {
        // 叶子队列:公平份额 = min(可用, 最大限制)
        queue.FairShare = s.minResources(available, queue.MaxShare)
        return
    }

    // 计算子队列的需求
    demands := make([]Resources, len(queue.Children))
    totalWeight := float64(0)
    for i, child := range queue.Children {
        demands[i] = s.calculateDemand(child)
        totalWeight += child.Weight
    }

    // 分配资源给子队列
    remaining := available
    for _, child := range queue.Children {
        // 首先分配最小保证
        minAlloc := s.minResources(child.MinShare, remaining)
        remaining = s.subtractResources(remaining, minAlloc)
    }

    // 按权重分配剩余资源
    for i, child := range queue.Children {
        weightRatio := child.Weight / totalWeight
        share := s.multiplyResources(remaining, weightRatio)

        // 加上最小保证
        share = s.addResources(share, child.MinShare)

        // 不超过需求
        share = s.minResources(share, demands[i])

        // 不超过最大限制
        share = s.minResources(share, child.MaxShare)

        // 递归计算子队列
        s.calculateFairShareRecursive(child, share)
    }
}

// calculateDemand 计算队列需求
func (s *HierarchicalFairScheduler) calculateDemand(queue *HierarchicalQueue) Resources {
    demand := Resources{}

    // 已运行的作业
    demand = s.addResources(demand, queue.Allocated)

    // 等待的作业
    for _, job := range *queue.Jobs {
        jobRes := s.getJobResources(job)
        demand = s.addResources(demand, jobRes)
    }

    // 子队列需求
    for _, child := range queue.Children {
        childDemand := s.calculateDemand(child)
        demand = s.addResources(demand, childDemand)
    }

    return demand
}

// Schedule 调度作业
func (s *HierarchicalFairScheduler) Schedule() *ScheduleResult {
    // 先计算公平份额
    s.CalculateFairShares()

    // 选择最需要资源的队列
    queue := s.selectQueue(s.root)
    if queue == nil || queue.Jobs.Len() == 0 {
        return nil
    }

    // 取出最高优先级作业
    job := heap.Pop(queue.Jobs).(*TrainingJob)
    jobRes := s.getJobResources(job)

    // 检查资源
    if !s.canAllocate(queue, jobRes) {
        heap.Push(queue.Jobs, job)
        return nil
    }

    // 分配资源
    s.allocate(queue, jobRes)

    return &ScheduleResult{
        Job:       job,
        Queue:     queue.Name,
        Resources: jobRes,
    }
}

// selectQueue 选择队列
func (s *HierarchicalFairScheduler) selectQueue(queue *HierarchicalQueue) *HierarchicalQueue {
    if len(queue.Children) == 0 {
        return queue
    }

    // 选择最低于公平份额的子队列
    var selected *HierarchicalQueue
    maxDeficit := float64(-1e9)

    for _, child := range queue.Children {
        deficit := s.calculateDeficit(child)
        if deficit > maxDeficit {
            maxDeficit = deficit
            selected = child
        }
    }

    if selected == nil {
        return nil
    }

    return s.selectQueue(selected)
}

// calculateDeficit 计算资源缺口
func (s *HierarchicalFairScheduler) calculateDeficit(queue *HierarchicalQueue) float64 {
    // 缺口 = 公平份额 - 已分配
    gpuDeficit := queue.FairShare.GPU - queue.Allocated.GPU
    cpuDeficit := queue.FairShare.CPU - queue.Allocated.CPU
    memDeficit := queue.FairShare.Memory - queue.Allocated.Memory

    // 返回主导资源的缺口
    return max(gpuDeficit, max(cpuDeficit, memDeficit))
}

// canAllocate 检查是否可分配
func (s *HierarchicalFairScheduler) canAllocate(queue *HierarchicalQueue, res Resources) bool {
    // 检查队列限制
    newAlloc := s.addResources(queue.Allocated, res)
    if newAlloc.GPU > queue.MaxShare.GPU ||
        newAlloc.CPU > queue.MaxShare.CPU ||
        newAlloc.Memory > queue.MaxShare.Memory {
        return false
    }

    // 检查父队列
    if queue.Parent != nil {
        return s.canAllocate(queue.Parent, res)
    }

    return true
}

// allocate 分配资源
func (s *HierarchicalFairScheduler) allocate(queue *HierarchicalQueue, res Resources) {
    queue.Allocated = s.addResources(queue.Allocated, res)
    if queue.Parent != nil {
        s.allocate(queue.Parent, res)
    }
}

// 辅助函数
func (s *HierarchicalFairScheduler) minResources(a, b Resources) Resources {
    return Resources{
        GPU:    min(a.GPU, b.GPU),
        CPU:    min(a.CPU, b.CPU),
        Memory: min(a.Memory, b.Memory),
    }
}

func (s *HierarchicalFairScheduler) multiplyResources(r Resources, factor float64) Resources {
    return Resources{
        GPU:    r.GPU * factor,
        CPU:    r.CPU * factor,
        Memory: r.Memory * factor,
    }
}

// addResources 资源相加
func (s *HierarchicalFairScheduler) addResources(a, b Resources) Resources {
    return Resources{
        GPU:    a.GPU + b.GPU,
        CPU:    a.CPU + b.CPU,
        Memory: a.Memory + b.Memory,
    }
}

// subtractResources 资源相减
func (s *HierarchicalFairScheduler) subtractResources(a, b Resources) Resources {
    return Resources{
        GPU:    a.GPU - b.GPU,
        CPU:    a.CPU - b.CPU,
        Memory: a.Memory - b.Memory,
    }
}

// getJobResources 获取作业资源需求
func (s *HierarchicalFairScheduler) getJobResources(job *TrainingJob) Resources {
    replicas := float64(job.Resources.MinReplicas)
    return Resources{
        GPU:    float64(job.Resources.ReplicaResources.GPU) * replicas,
        CPU:    job.Resources.ReplicaResources.CPU.AsApproximateFloat64() * replicas,
        Memory: job.Resources.ReplicaResources.Memory.AsApproximateFloat64() * replicas,
    }
}

func min(a, b float64) float64 {
    if a < b {
        return a
    }
    return b
}
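
A short sketch of wiring the hierarchical scheduler, assuming the TrainingJob types from section 1.2 are visible in this package; the queue layout and quantities are illustrative.

// Illustrative setup (a sketch, not production code): a two-team hierarchy
// under root, then one scheduling round. Quantities are arbitrary; memory is
// tracked as a plain float64, matching the Resources type above.
func exampleHierarchicalFairShare() {
    total := Resources{GPU: 64, CPU: 1024, Memory: 8192}
    hfs := NewHierarchicalFairScheduler(total)

    // Two child queues: nlp gets 3x the weight of cv, each with a small
    // guaranteed minimum and a cap at the full cluster.
    _ = hfs.AddQueue("nlp", "root", 3.0, Resources{GPU: 8}, total)
    _ = hfs.AddQueue("cv", "root", 1.0, Resources{GPU: 8}, total)

    // Push one pending job into the nlp queue's priority heap.
    heap.Push(hfs.allQueues["nlp"].Jobs, &TrainingJob{
        Resources: ResourceRequirements{
            MinReplicas:      2,
            ReplicaResources: ReplicaResources{GPU: 8},
        },
        SchedulingConfig: SchedulingConfig{Priority: 100},
    })

    // CalculateFairShares runs inside Schedule; nlp has the largest deficit,
    // so its job is picked and 16 GPUs are charged up the hierarchy.
    if r := hfs.Schedule(); r != nil {
        fmt.Printf("scheduled from %s: %.0f GPUs\n", r.Queue, r.Resources.GPU)
    }
}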

5. Preemption and Recovery

5.1 Preemption Scheduler

// pkg/scheduler/preemption/scheduler.go
package preemption

import (
    "context"
    "fmt"
    "sort"
    "time"
)

// PreemptionScheduler 抢占调度器
type PreemptionScheduler struct {
    // 运行中的作业
    runningJobs map[string]*RunningJob

    // 抢占策略
    strategy PreemptionStrategy

    // 检查点管理器
    checkpointManager CheckpointManager
}

// RunningJob 运行中的作业
type RunningJob struct {
    Job          *TrainingJob
    StartTime    time.Time
    Resources    Resources
    Nodes        []string
    Preemptible  bool
    Priority     int32
    LastCheckpoint time.Time
}

// PreemptionStrategy 抢占策略
type PreemptionStrategy interface {
    // SelectVictims 选择被抢占的作业
    SelectVictims(running []*RunningJob, required Resources) []*RunningJob

    // ShouldPreempt 是否应该抢占
    ShouldPreempt(preemptor *TrainingJob, victim *RunningJob) bool
}

// PriorityBasedStrategy 基于优先级的抢占策略
type PriorityBasedStrategy struct {
    // 最小优先级差
    MinPriorityDiff int32

    // 最小运行时间保护
    MinRuntime time.Duration

    // 最大被抢占作业数
    MaxVictims int
}

// SelectVictims 选择被抢占者
func (s *PriorityBasedStrategy) SelectVictims(running []*RunningJob,
    required Resources) []*RunningJob {

    // 过滤可抢占的作业
    candidates := make([]*RunningJob, 0)
    for _, job := range running {
        if job.Preemptible && time.Since(job.StartTime) > s.MinRuntime {
            candidates = append(candidates, job)
        }
    }

    // 按优先级升序排序(低优先级先被抢占)
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].Priority < candidates[j].Priority
    })

    // 选择足够的作业以释放所需资源
    victims := make([]*RunningJob, 0)
    freed := Resources{}

    for _, job := range candidates {
        if len(victims) >= s.MaxVictims {
            break
        }

        victims = append(victims, job)
        freed = addResources(freed, job.Resources)

        // 检查是否已经够了
        if freed.GPU >= required.GPU &&
           freed.CPU >= required.CPU &&
           freed.Memory >= required.Memory {
            break
        }
    }

    // 检查释放的资源是否足够
    if freed.GPU < required.GPU || freed.CPU < required.CPU ||
       freed.Memory < required.Memory {
        return nil
    }

    return victims
}

// ShouldPreempt 判断是否应该抢占
func (s *PriorityBasedStrategy) ShouldPreempt(preemptor *TrainingJob,
    victim *RunningJob) bool {

    // 优先级差必须足够大
    if preemptor.SchedulingConfig.Priority - victim.Priority < s.MinPriorityDiff {
        return false
    }

    // 被抢占者必须已运行足够长时间
    if time.Since(victim.StartTime) < s.MinRuntime {
        return false
    }

    return true
}

// ExecutePreemption 执行抢占
func (s *PreemptionScheduler) ExecutePreemption(ctx context.Context,
    preemptor *TrainingJob, victims []*RunningJob) error {

    for _, victim := range victims {
        // 1. 发送抢占通知(给作业机会保存检查点)
        if err := s.notifyPreemption(ctx, victim); err != nil {
            return err
        }

        // 2. 等待检查点保存
        if err := s.waitForCheckpoint(ctx, victim, 30*time.Second); err != nil {
            // 超时,强制抢占
        }

        // 3. 终止作业
        if err := s.terminateJob(ctx, victim); err != nil {
            return err
        }

        // 4. 记录抢占事件
        s.recordPreemptionEvent(victim, preemptor)
    }

    return nil
}

// notifyPreemption 通知即将被抢占
func (s *PreemptionScheduler) notifyPreemption(ctx context.Context,
    victim *RunningJob) error {

    // 发送 SIGTERM 信号
    // 或通过 annotation/label 通知
    return nil
}

// waitForCheckpoint 等待检查点保存
func (s *PreemptionScheduler) waitForCheckpoint(ctx context.Context,
    victim *RunningJob, timeout time.Duration) error {

    deadline := time.Now().Add(timeout)

    for time.Now().Before(deadline) {
        // 检查是否有新检查点
        checkpoint, err := s.checkpointManager.GetLatest(victim.Job)
        if err == nil && checkpoint.Timestamp.After(victim.LastCheckpoint) {
            victim.LastCheckpoint = checkpoint.Timestamp
            return nil
        }

        time.Sleep(time.Second)
    }

    return fmt.Errorf("checkpoint timeout")
}

// terminateJob 终止作业
func (s *PreemptionScheduler) terminateJob(ctx context.Context,
    victim *RunningJob) error {

    // 删除所有 Pod
    // 更新作业状态
    return nil
}

// RecoverPreemptedJob 恢复被抢占的作业
func (s *PreemptionScheduler) RecoverPreemptedJob(ctx context.Context,
    job *TrainingJob) error {

    // 获取最新检查点
    checkpoint, err := s.checkpointManager.GetLatest(job)
    if err != nil {
        // 没有检查点,从头开始
        return s.restartJob(ctx, job)
    }

    // 从检查点恢复
    return s.resumeFromCheckpoint(ctx, job, checkpoint)
}

// restartJob 重新启动作业
func (s *PreemptionScheduler) restartJob(ctx context.Context,
    job *TrainingJob) error {

    // 重新创建 Pod
    return nil
}

// resumeFromCheckpoint 从检查点恢复
func (s *PreemptionScheduler) resumeFromCheckpoint(ctx context.Context,
    job *TrainingJob, checkpoint *Checkpoint) error {

    // 创建 Pod 并配置从检查点恢复
    return nil
}
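
A compact sketch of victim selection with the priority-based strategy above, assuming the Resources type and fmt are available in this package; all jobs, priorities, and thresholds are invented.

// Illustrative victim selection (a sketch, not production code).
func examplePreemption() {
    strategy := &PriorityBasedStrategy{
        MinPriorityDiff: 50,               // preemptor must outrank victim by >= 50
        MinRuntime:      10 * time.Minute, // protect jobs that just started
        MaxVictims:      2,
    }

    running := []*RunningJob{
        {Priority: 10, Preemptible: true, StartTime: time.Now().Add(-2 * time.Hour), Resources: Resources{GPU: 8}},
        {Priority: 20, Preemptible: true, StartTime: time.Now().Add(-30 * time.Minute), Resources: Resources{GPU: 8}},
        {Priority: 90, Preemptible: false, StartTime: time.Now().Add(-1 * time.Hour), Resources: Resources{GPU: 16}},
    }

    // A high-priority job needs 16 GPUs: the two lowest-priority preemptible
    // jobs are chosen; the non-preemptible one is skipped.
    victims := strategy.SelectVictims(running, Resources{GPU: 16})
    fmt.Printf("selected %d victims\n", len(victims))
}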

5.2 Checkpoint Management

// pkg/scheduler/checkpoint/manager.go
package checkpoint

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "path/filepath"
    "sort"
    "time"

    "github.com/aws/aws-sdk-go-v2/service/s3"
)

// CheckpointManager 检查点管理器
type CheckpointManager struct {
    // 存储后端
    storage CheckpointStorage

    // 配置
    config CheckpointConfig
}

// CheckpointStorage 检查点存储接口
type CheckpointStorage interface {
    // Save 保存检查点
    Save(ctx context.Context, checkpoint *Checkpoint, data []byte) error

    // Load 加载检查点
    Load(ctx context.Context, checkpoint *Checkpoint) ([]byte, error)

    // List 列出检查点
    List(ctx context.Context, jobKey string) ([]*Checkpoint, error)

    // Delete 删除检查点
    Delete(ctx context.Context, checkpoint *Checkpoint) error
}

// CheckpointConfig 配置
type CheckpointConfig struct {
    // 检查点保存路径
    BasePath string

    // 保留的检查点数量
    MaxCheckpoints int

    // 检查点间隔
    Interval time.Duration
}

// Checkpoint 检查点
type Checkpoint struct {
    // 唯一标识
    ID string `json:"id"`

    // 作业标识
    JobKey string `json:"jobKey"`

    // 时间戳
    Timestamp time.Time `json:"timestamp"`

    // 训练步数
    Step int64 `json:"step"`

    // Epoch
    Epoch int32 `json:"epoch"`

    // 文件路径
    Path string `json:"path"`

    // 大小(字节)
    Size int64 `json:"size"`

    // 元数据
    Metadata map[string]string `json:"metadata"`
}

// NewCheckpointManager 创建检查点管理器
func NewCheckpointManager(storage CheckpointStorage, config CheckpointConfig) *CheckpointManager {
    return &CheckpointManager{
        storage: storage,
        config:  config,
    }
}

// SaveCheckpoint 保存检查点
func (m *CheckpointManager) SaveCheckpoint(ctx context.Context,
    job *TrainingJob, step int64, epoch int32, data []byte) (*Checkpoint, error) {

    jobKey := job.Namespace + "/" + job.Name
    checkpointID := fmt.Sprintf("%s-%d", jobKey, time.Now().UnixNano())

    checkpoint := &Checkpoint{
        ID:        checkpointID,
        JobKey:    jobKey,
        Timestamp: time.Now(),
        Step:      step,
        Epoch:     epoch,
        Path:      filepath.Join(m.config.BasePath, jobKey, checkpointID),
        Size:      int64(len(data)),
        Metadata:  make(map[string]string),
    }

    // 保存到存储
    if err := m.storage.Save(ctx, checkpoint, data); err != nil {
        return nil, err
    }

    // 清理旧检查点
    if err := m.cleanupOldCheckpoints(ctx, jobKey); err != nil {
        // 日志记录,但不影响保存
    }

    return checkpoint, nil
}

// GetLatest 获取最新检查点
func (m *CheckpointManager) GetLatest(job *TrainingJob) (*Checkpoint, error) {
    jobKey := job.Namespace + "/" + job.Name
    checkpoints, err := m.storage.List(context.Background(), jobKey)
    if err != nil {
        return nil, err
    }

    if len(checkpoints) == 0 {
        return nil, fmt.Errorf("no checkpoint found for job %s", jobKey)
    }

    // 按时间排序
    sort.Slice(checkpoints, func(i, j int) bool {
        return checkpoints[i].Timestamp.After(checkpoints[j].Timestamp)
    })

    return checkpoints[0], nil
}

// LoadCheckpoint 加载检查点
func (m *CheckpointManager) LoadCheckpoint(ctx context.Context,
    checkpoint *Checkpoint) ([]byte, error) {

    return m.storage.Load(ctx, checkpoint)
}

// cleanupOldCheckpoints 清理旧检查点
func (m *CheckpointManager) cleanupOldCheckpoints(ctx context.Context,
    jobKey string) error {

    checkpoints, err := m.storage.List(ctx, jobKey)
    if err != nil {
        return err
    }

    if len(checkpoints) <= m.config.MaxCheckpoints {
        return nil
    }

    // 按时间排序
    sort.Slice(checkpoints, func(i, j int) bool {
        return checkpoints[i].Timestamp.After(checkpoints[j].Timestamp)
    })

    // 删除多余的检查点
    for i := m.config.MaxCheckpoints; i < len(checkpoints); i++ {
        if err := m.storage.Delete(ctx, checkpoints[i]); err != nil {
            return err
        }
    }

    return nil
}

// S3CheckpointStorage S3 存储实现
type S3CheckpointStorage struct {
    bucket string
    // basePath 需与 CheckpointConfig.BasePath 保持一致,用于在 List 时对齐对象 Key 前缀
    basePath string
    client   *s3.Client
}

// Save 保存到 S3
func (s *S3CheckpointStorage) Save(ctx context.Context,
    checkpoint *Checkpoint, data []byte) error {

    _, err := s.client.PutObject(ctx, &s3.PutObjectInput{
        Bucket: &s.bucket,
        Key:    &checkpoint.Path,
        Body:   bytes.NewReader(data),
    })

    return err
}

// Load 从 S3 加载
func (s *S3CheckpointStorage) Load(ctx context.Context,
    checkpoint *Checkpoint) ([]byte, error) {

    result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
        Bucket: &s.bucket,
        Key:    &checkpoint.Path,
    })
    if err != nil {
        return nil, err
    }
    defer result.Body.Close()

    return io.ReadAll(result.Body)
}

// List 列出检查点
func (s *S3CheckpointStorage) List(ctx context.Context,
    jobKey string) ([]*Checkpoint, error) {

    // 前缀必须与 SaveCheckpoint 写入的 Key 对齐(包含 BasePath 部分)
    prefix := filepath.Join(s.basePath, jobKey) + "/"
    // 注意:此处未处理分页,检查点对象较多时应使用 Paginator 连续列举
    result, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
        Bucket: &s.bucket,
        Prefix: &prefix,
    })
    if err != nil {
        return nil, err
    }

    checkpoints := make([]*Checkpoint, 0, len(result.Contents))
    for _, obj := range result.Contents {
        checkpoint := &Checkpoint{
            ID:        *obj.Key,
            JobKey:    jobKey,
            Path:      *obj.Key,
            Size:      *obj.Size,
            Timestamp: *obj.LastModified,
        }
        checkpoints = append(checkpoints, checkpoint)
    }

    return checkpoints, nil
}

// Delete 删除检查点
func (s *S3CheckpointStorage) Delete(ctx context.Context,
    checkpoint *Checkpoint) error {

    _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
        Bucket: &s.bucket,
        Key:    &checkpoint.Path,
    })

    return err
}
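
下面给出一个最小的使用示意(并非本章的正式实现),演示如何把 S3CheckpointStorage 与 CheckpointManager 组装起来,并按 Interval 周期性保存检查点。其中 runCheckpointLoop、serializeModelState 回调以及 bucket 名称均为示例中的假设;S3 客户端通过 aws-sdk-go-v2 的默认凭证链创建,示例假设与上文 CheckpointManager 位于同一包内。

// 示例:组装 S3 存储与 CheckpointManager,周期性保存检查点
// (示意代码,假设与上文同包;serializeModelState 为假设的模型序列化回调)
import (
    "context"
    "fmt"
    "time"

    awsconfig "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/service/s3"
)

func runCheckpointLoop(ctx context.Context, job *TrainingJob,
    serializeModelState func() (data []byte, step int64, epoch int32)) error {

    // 使用默认凭证链创建 S3 客户端
    awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
    if err != nil {
        return err
    }

    cfg := CheckpointConfig{
        BasePath:       "checkpoints",
        MaxCheckpoints: 5,
        Interval:       10 * time.Minute,
    }
    storage := &S3CheckpointStorage{
        bucket:   "training-checkpoints", // 示例 bucket 名称
        basePath: cfg.BasePath,           // 与 BasePath 对齐,保证 List 能找到已保存的对象
        client:   s3.NewFromConfig(awsCfg),
    }
    manager := NewCheckpointManager(storage, cfg)

    ticker := time.NewTicker(cfg.Interval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-ticker.C:
            data, step, epoch := serializeModelState()
            ckpt, err := manager.SaveCheckpoint(ctx, job, step, epoch, data)
            if err != nil {
                fmt.Printf("save checkpoint failed: %v\n", err)
                continue
            }
            fmt.Printf("checkpoint %s saved, step=%d size=%d\n", ckpt.ID, ckpt.Step, ckpt.Size)
        }
    }
}

实际训练作业中,序列化逻辑通常由训练框架提供(例如 PyTorch 的 state_dict 序列化结果),这里仅以回调函数示意存储侧的调用方式。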

6. 监控与可观测性

6.1 调度指标

// pkg/scheduler/metrics/metrics.go
package metrics

import (
    "fmt"

    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

var (
    // 调度延迟
    SchedulingLatency = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "training_job_scheduling_duration_seconds",
            Help:    "Time taken to schedule a training job",
            Buckets: prometheus.ExponentialBuckets(0.1, 2, 15),
        },
        []string{"queue", "job_type", "result"},
    )

    // 队列长度
    QueueLength = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "training_job_queue_length",
            Help: "Number of jobs waiting in queue",
        },
        []string{"queue", "priority"},
    )

    // 资源利用率
    ResourceUtilization = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "training_cluster_resource_utilization",
            Help: "Resource utilization ratio",
        },
        []string{"resource_type", "queue"},
    )

    // 抢占计数
    PreemptionCounter = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "training_job_preemptions_total",
            Help: "Total number of job preemptions",
        },
        []string{"preemptor_queue", "victim_queue", "reason"},
    )

    // 作业完成时间
    JobCompletionTime = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "training_job_completion_seconds",
            Help:    "Time taken to complete a training job",
            Buckets: prometheus.ExponentialBuckets(60, 2, 15), // 1分钟到约11天
        },
        []string{"queue", "job_type", "status"},
    )

    // 公平份额偏差
    FairShareDeviation = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "training_queue_fairshare_deviation",
            Help: "Deviation from fair share",
        },
        []string{"queue"},
    )

    // Gang 调度成功率
    GangSchedulingSuccess = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "training_gang_scheduling_total",
            Help: "Total gang scheduling attempts",
        },
        []string{"result", "timeout"},
    )

    // 检查点保存
    CheckpointSaves = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "training_checkpoint_saves_total",
            Help: "Total checkpoint saves",
        },
        []string{"job", "reason", "status"},
    )
)

// RecordSchedulingLatency 记录调度延迟
func RecordSchedulingLatency(queue, jobType, result string, duration float64) {
    SchedulingLatency.WithLabelValues(queue, jobType, result).Observe(duration)
}

// UpdateQueueLength 更新队列长度
func UpdateQueueLength(queue string, priority int32, length int) {
    QueueLength.WithLabelValues(queue, fmt.Sprintf("%d", priority)).Set(float64(length))
}

// UpdateResourceUtilization 更新资源利用率
func UpdateResourceUtilization(resourceType, queue string, ratio float64) {
    ResourceUtilization.WithLabelValues(resourceType, queue).Set(ratio)
}

// RecordPreemption 记录抢占事件
func RecordPreemption(preemptorQueue, victimQueue, reason string) {
    PreemptionCounter.WithLabelValues(preemptorQueue, victimQueue, reason).Inc()
}
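
下面是一个最小的接入示意(假设代码),演示调度主循环如何围绕一次调度尝试记录调度延迟,并通过 promhttp 暴露 /metrics 端点供 Prometheus 抓取。其中模块路径 example.com/platform、队列名 team-a 以及 scheduleOne 函数均为示例假设,并非本章正式实现。

// 示例:在调度循环中接入指标并暴露 /metrics(示意代码)
package main

import (
    "log"
    "net/http"
    "time"

    "github.com/prometheus/client_golang/prometheus/promhttp"

    // 假设的模块路径,对应上文 pkg/scheduler/metrics
    "example.com/platform/pkg/scheduler/metrics"
)

func main() {
    // 暴露 /metrics 端点,供 Prometheus 抓取上文定义的调度指标
    http.Handle("/metrics", promhttp.Handler())
    go func() {
        log.Fatal(http.ListenAndServe(":8080", nil))
    }()

    for {
        start := time.Now()
        result := scheduleOne("team-a", "PyTorchJob") // 一次调度尝试(占位)
        metrics.RecordSchedulingLatency("team-a", "PyTorchJob", result,
            time.Since(start).Seconds())
        time.Sleep(time.Second)
    }
}

// scheduleOne 为占位实现,仅用于演示指标接入位置
func scheduleOne(queue, jobType string) string {
    return "success"
}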

6.2 Grafana Dashboard

{
  "dashboard": {
    "title": "Training Job Scheduler Dashboard",
    "panels": [
      {
        "title": "Queue Length by Priority",
        "type": "timeseries",
        "targets": [
          {
            "expr": "sum(training_job_queue_length) by (queue, priority)",
            "legendFormat": "{{queue}} - P{{priority}}"
          }
        ]
      },
      {
        "title": "Scheduling Latency (P99)",
        "type": "gauge",
        "targets": [
          {
            "expr": "histogram_quantile(0.99, sum(rate(training_job_scheduling_duration_seconds_bucket[5m])) by (le))"
          }
        ],
        "options": {
          "maxValue": 60,
          "thresholds": [
            {"value": 0, "color": "green"},
            {"value": 10, "color": "yellow"},
            {"value": 30, "color": "red"}
          ]
        }
      },
      {
        "title": "Resource Utilization",
        "type": "heatmap",
        "targets": [
          {
            "expr": "training_cluster_resource_utilization",
            "format": "heatmap"
          }
        ]
      },
      {
        "title": "Fair Share Deviation",
        "type": "bargauge",
        "targets": [
          {
            "expr": "training_queue_fairshare_deviation",
            "legendFormat": "{{queue}}"
          }
        ]
      },
      {
        "title": "Preemption Events",
        "type": "timeseries",
        "targets": [
          {
            "expr": "sum(rate(training_job_preemptions_total[1h])) by (victim_queue)",
            "legendFormat": "{{victim_queue}}"
          }
        ]
      },
      {
        "title": "Job Completion Time Distribution",
        "type": "histogram",
        "targets": [
          {
            "expr": "sum(rate(training_job_completion_seconds_bucket[1h])) by (le, status)",
            "format": "heatmap"
          }
        ]
      }
    ]
  }
}

总结

本章深入讲解了训练任务调度系统的核心技术:

  1. 队列管理:多级优先级队列、层级队列、Kueue 集成
  2. Gang 调度:确保分布式训练任务整组调度、Volcano 集成
  3. 公平调度:DRF 算法、层级公平共享
  4. 抢占机制:优先级抢占、检查点恢复
  5. 可观测性:调度指标、Grafana 监控

高效的训练任务调度是 AI 平台的核心能力,直接影响集群资源利用率和用户体验。

下一章我们将探讨模型存储与管理,讲解如何高效管理训练过程中产生的模型文件和检查点。
