训练任务调度
概述
大规模 AI 训练任务的调度是一个复杂的系统工程问题。与传统的微服务调度不同,训练任务具有长时运行、资源密集的特点,并且通常需要 Gang 调度与弹性伸缩能力。本章深入讲解训练任务调度系统的设计与实现,包括调度策略、队列管理、公平共享、抢占恢复等核心机制。
1. 训练任务调度架构
1.1 整体架构
┌──────────────────────────────────────────────────────────────────────────────┐
│ Training Job Scheduling Architecture │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │ Job Submission Layer │ │
│ │ │ │
│ │ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │ │
│ │ │ CLI/SDK │ │ Web UI │ │ Jupyter Hub │ │ │
│ │ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │ │
│ └──────────┼──────────────────┼──────────────────┼─────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │ Job Controller Layer │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────────────────────────┐ │ │
│ │ │ Training Job Controller │ │ │
│ │ │ │ │ │
│ │ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ │ │
│ │ │ │ PyTorchJob │ │ TFJob │ │ MPIJob │ │ │ │
│ │ │ │ Controller │ │ Controller │ │ Controller │ │ │ │
│ │ │ └──────────────┘ └──────────────┘ └──────────────┘ │ │ │
│ │ └─────────────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │ Scheduling Layer │ │
│ │ │ │
│ │ ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ │ │
│ │ │ Queue Manager │ │ Gang Scheduler │ │Priority Manager│ │ │
│ │ └────────┬───────┘ └────────┬───────┘ └────────┬───────┘ │ │
│ │ │ │ │ │ │
│ │ ▼ ▼ ▼ │ │
│ │ ┌─────────────────────────────────────────────────────────────────┐ │ │
│ │ │ Volcano/Kueue Scheduler │ │ │
│ │ │ │ │ │
│ │ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ │
│ │ │ │ Fair │ │ Capacity │ │ Binpack │ │ Topology │ │ │ │
│ │ │ │ Share │ │ Planning │ │ Strategy │ │ Aware │ │ │ │
│ │ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │ │
│ │ └─────────────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │ Resource Layer │ │
│ │ │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ GPU Pool │ │ CPU Pool │ │ Memory Pool │ │ Storage Pool│ │ │
│ │ │ │ │ │ │ │ │ │ │ │
│ │ │ A100 x 64 │ │ 1024 cores │ │ 8 TiB │ │ 100 TiB │ │ │
│ │ │ V100 x 32 │ │ │ │ │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────┘
1.2 核心数据结构
// pkg/scheduler/types.go
package scheduler
import (
"time"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)
// TrainingJob 训练任务定义
type TrainingJob struct {
// 基本信息
Name string `json:"name"`
Namespace string `json:"namespace"`
UID string `json:"uid"`
Labels map[string]string `json:"labels"`
// 任务类型
JobType JobType `json:"jobType"`
// 资源需求
Resources ResourceRequirements `json:"resources"`
// 调度配置
SchedulingConfig SchedulingConfig `json:"schedulingConfig"`
// 弹性配置
ElasticConfig *ElasticConfig `json:"elasticConfig,omitempty"`
// 状态
Status TrainingJobStatus `json:"status"`
// 时间戳
CreationTime time.Time `json:"creationTime"`
StartTime *time.Time `json:"startTime,omitempty"`
EndTime *time.Time `json:"endTime,omitempty"`
}
// JobType 任务类型
type JobType string
const (
JobTypePyTorch JobType = "PyTorchJob"
JobTypeTensorFlow JobType = "TFJob"
JobTypeMPI JobType = "MPIJob"
JobTypeXGBoost JobType = "XGBoostJob"
JobTypePaddlePaddle JobType = "PaddleJob"
)
// ResourceRequirements 资源需求
type ResourceRequirements struct {
// 任务级别需求
MinReplicas int32 `json:"minReplicas"`
MaxReplicas int32 `json:"maxReplicas"`
// 每个 replica 的需求
ReplicaResources ReplicaResources `json:"replicaResources"`
// GPU 类型要求
GPUType string `json:"gpuType,omitempty"`
// 拓扑约束
TopologyConstraint string `json:"topologyConstraint,omitempty"`
}
// ReplicaResources 单个副本资源
type ReplicaResources struct {
GPU int32 `json:"gpu"`
CPU resource.Quantity `json:"cpu"`
Memory resource.Quantity `json:"memory"`
}
// SchedulingConfig 调度配置
type SchedulingConfig struct {
// 队列名
Queue string `json:"queue"`
// 优先级
Priority int32 `json:"priority"`
// 是否需要 Gang 调度
GangScheduling bool `json:"gangScheduling"`
// 最大等待时间
MaxWaitTime time.Duration `json:"maxWaitTime"`
// 抢占策略
PreemptionPolicy PreemptionPolicy `json:"preemptionPolicy"`
// 调度器名称
SchedulerName string `json:"schedulerName"`
}
// PreemptionPolicy 抢占策略
type PreemptionPolicy string
const (
PreemptionPolicyNever PreemptionPolicy = "Never"
PreemptionPolicyPreemptLower PreemptionPolicy = "PreemptLower"
PreemptionPolicyPreemptAlways PreemptionPolicy = "PreemptAlways"
)
// ElasticConfig 弹性配置
type ElasticConfig struct {
// 是否启用弹性
Enabled bool `json:"enabled"`
// 最小 worker 数
MinWorkers int32 `json:"minWorkers"`
// 最大 worker 数
MaxWorkers int32 `json:"maxWorkers"`
// 伸缩策略
ScalePolicy ScalePolicy `json:"scalePolicy"`
}
// ScalePolicy 伸缩策略
type ScalePolicy struct {
// 扩容冷却时间
ScaleUpCooldown time.Duration `json:"scaleUpCooldown"`
// 缩容冷却时间
ScaleDownCooldown time.Duration `json:"scaleDownCooldown"`
// 扩容步长
ScaleUpStep int32 `json:"scaleUpStep"`
// 缩容步长
ScaleDownStep int32 `json:"scaleDownStep"`
}
// TrainingJobStatus 任务状态
type TrainingJobStatus struct {
// 当前阶段
Phase JobPhase `json:"phase"`
// 分配的资源
AllocatedResources AllocatedResources `json:"allocatedResources"`
// 副本状态
ReplicaStatuses map[string]ReplicaStatus `json:"replicaStatuses"`
// 条件
Conditions []JobCondition `json:"conditions"`
// 历史事件
Events []JobEvent `json:"events"`
}
// JobPhase 任务阶段
type JobPhase string
const (
JobPhasePending JobPhase = "Pending"
JobPhaseQueued JobPhase = "Queued"
JobPhaseRunning JobPhase = "Running"
JobPhaseSucceeded JobPhase = "Succeeded"
JobPhaseFailed JobPhase = "Failed"
JobPhaseSuspended JobPhase = "Suspended"
)
// AllocatedResources 已分配资源
type AllocatedResources struct {
GPUs int32 `json:"gpus"`
CPU resource.Quantity `json:"cpu"`
Memory resource.Quantity `json:"memory"`
Nodes []string `json:"nodes"`
}
// ReplicaStatus 副本状态
type ReplicaStatus struct {
Running int32 `json:"running"`
Succeeded int32 `json:"succeeded"`
Failed int32 `json:"failed"`
Pending int32 `json:"pending"`
}
// JobCondition 任务条件
type JobCondition struct {
Type string `json:"type"`
Status corev1.ConditionStatus `json:"status"`
LastTransitionTime time.Time `json:"lastTransitionTime"`
Reason string `json:"reason"`
Message string `json:"message"`
}
// JobEvent 任务事件
type JobEvent struct {
Type string `json:"type"`
Reason string `json:"reason"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
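在实际使用中,上述结构体由作业提交层(CLI/SDK 或 Web UI)填充。下面给出一个构造 TrainingJob 的最小示例(字段取值与模块导入路径均为假设,仅作演示):
// 示例:提交一个需要 Gang 调度的 PyTorch 分布式训练任务
package main
import (
	"time"
	"k8s.io/apimachinery/pkg/api/resource"
	"example.com/ai-platform/pkg/scheduler" // 假设的模块路径
)
func main() {
	job := &scheduler.TrainingJob{
		Name:      "llama-sft-demo",
		Namespace: "ml-research",
		JobType:   scheduler.JobTypePyTorch,
		Resources: scheduler.ResourceRequirements{
			MinReplicas: 4,
			MaxReplicas: 8,
			ReplicaResources: scheduler.ReplicaResources{
				GPU:    8,
				CPU:    resource.MustParse("32"),
				Memory: resource.MustParse("256Gi"),
			},
			GPUType:            "A100",
			TopologyConstraint: "topology.kubernetes.io/zone",
		},
		SchedulingConfig: scheduler.SchedulingConfig{
			Queue:            "ml-research",
			Priority:         100,
			GangScheduling:   true,
			MaxWaitTime:      30 * time.Minute,
			PreemptionPolicy: scheduler.PreemptionPolicyPreemptLower,
			SchedulerName:    "volcano",
		},
		CreationTime: time.Now(),
	}
	_ = job // 后续交给队列管理器 EnqueueJob(见下一节)
}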
2. 队列管理系统
2.1 多级队列实现
// pkg/scheduler/queue/manager.go
package queue
import (
	"fmt"
	"sort"
	"sync"
	"time"
)
// 注:TrainingJob 类型沿用前文 scheduler 包中的定义,此处假设已在同一模块内引入。
// QueueSpec 队列规格
type QueueSpec struct {
Name string `json:"name"`
// 父队列(用于层级队列)
Parent string `json:"parent,omitempty"`
// 权重(用于公平调度)
Weight int32 `json:"weight"`
// 资源配额
Quota ResourceQuota `json:"quota"`
// 借用配置
Borrowing *BorrowingSpec `json:"borrowing,omitempty"`
// 抢占配置
Preemption *PreemptionSpec `json:"preemption,omitempty"`
// 准入控制
AdmissionControl *AdmissionSpec `json:"admissionControl,omitempty"`
}
// ResourceQuota 资源配额
type ResourceQuota struct {
// 硬限制
Hard ResourceList `json:"hard"`
// 软限制(可以借用超过)
Soft ResourceList `json:"soft,omitempty"`
// 保证资源(不可被借用)
Guaranteed ResourceList `json:"guaranteed,omitempty"`
}
// ResourceList 资源列表
type ResourceList struct {
GPU int32 `json:"gpu"`
CPU string `json:"cpu"`
Memory string `json:"memory"`
}
// BorrowingSpec 借用配置
type BorrowingSpec struct {
// 是否允许借用
Enabled bool `json:"enabled"`
// 可借用的队列
AllowedQueues []string `json:"allowedQueues,omitempty"`
// 最大借用比例
MaxBorrowRatio float64 `json:"maxBorrowRatio"`
}
// PreemptionSpec 抢占配置
type PreemptionSpec struct {
// 是否允许被抢占
AllowPreemption bool `json:"allowPreemption"`
// 允许抢占的队列
AllowPreemptFrom []string `json:"allowPreemptFrom,omitempty"`
// 抢占优雅期
GracePeriod time.Duration `json:"gracePeriod"`
}
// AdmissionSpec 准入控制配置
type AdmissionSpec struct {
// 最大作业数
MaxJobs int32 `json:"maxJobs,omitempty"`
// 最大 GPU 数
MaxGPUsPerJob int32 `json:"maxGPUsPerJob,omitempty"`
// 允许的优先级范围
PriorityRange *PriorityRange `json:"priorityRange,omitempty"`
}
// PriorityRange 优先级范围
type PriorityRange struct {
Min int32 `json:"min"`
Max int32 `json:"max"`
}
// Queue 队列实例
type Queue struct {
Spec QueueSpec
Status QueueStatus
// 子队列
children []*Queue
parent *Queue
// 作业列表
jobs []*TrainingJob
jobIndex map[string]int
mu sync.RWMutex
}
// QueueStatus 队列状态
type QueueStatus struct {
// 已使用资源
Used ResourceList `json:"used"`
// 借入资源
Borrowed ResourceList `json:"borrowed"`
// 借出资源
Lent ResourceList `json:"lent"`
// 排队作业数
PendingJobs int32 `json:"pendingJobs"`
// 运行作业数
RunningJobs int32 `json:"runningJobs"`
}
// QueueManager 队列管理器
type QueueManager struct {
queues map[string]*Queue
rootQueues []*Queue
mu sync.RWMutex
}
// NewQueueManager 创建队列管理器
func NewQueueManager() *QueueManager {
return &QueueManager{
queues: make(map[string]*Queue),
rootQueues: make([]*Queue, 0),
}
}
// CreateQueue 创建队列
func (m *QueueManager) CreateQueue(spec QueueSpec) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.queues[spec.Name]; exists {
return fmt.Errorf("queue %s already exists", spec.Name)
}
queue := &Queue{
Spec: spec,
Status: QueueStatus{},
children: make([]*Queue, 0),
jobs: make([]*TrainingJob, 0),
jobIndex: make(map[string]int),
}
// 处理层级关系
if spec.Parent != "" {
parent, ok := m.queues[spec.Parent]
if !ok {
return fmt.Errorf("parent queue %s not found", spec.Parent)
}
queue.parent = parent
parent.children = append(parent.children, queue)
} else {
m.rootQueues = append(m.rootQueues, queue)
}
m.queues[spec.Name] = queue
return nil
}
// EnqueueJob 将作业加入队列
func (m *QueueManager) EnqueueJob(job *TrainingJob) error {
m.mu.Lock()
defer m.mu.Unlock()
queueName := job.SchedulingConfig.Queue
queue, ok := m.queues[queueName]
if !ok {
return fmt.Errorf("queue %s not found", queueName)
}
// 检查准入控制
if err := m.checkAdmission(queue, job); err != nil {
return fmt.Errorf("admission check failed: %v", err)
}
// 加入队列
queue.mu.Lock()
defer queue.mu.Unlock()
jobKey := job.Namespace + "/" + job.Name
if _, exists := queue.jobIndex[jobKey]; exists {
return fmt.Errorf("job %s already in queue", jobKey)
}
queue.jobs = append(queue.jobs, job)
queue.jobIndex[jobKey] = len(queue.jobs) - 1
queue.Status.PendingJobs++
// 按优先级排序
m.sortJobsByPriority(queue)
return nil
}
// DequeueJob 从队列取出作业
func (m *QueueManager) DequeueJob(queueName string) *TrainingJob {
m.mu.Lock()
defer m.mu.Unlock()
queue, ok := m.queues[queueName]
if !ok {
return nil
}
queue.mu.Lock()
defer queue.mu.Unlock()
if len(queue.jobs) == 0 {
return nil
}
// 取出最高优先级的作业
job := queue.jobs[0]
queue.jobs = queue.jobs[1:]
// 更新索引
delete(queue.jobIndex, job.Namespace+"/"+job.Name)
for i, j := range queue.jobs {
queue.jobIndex[j.Namespace+"/"+j.Name] = i
}
queue.Status.PendingJobs--
return job
}
// checkAdmission 检查准入控制
func (m *QueueManager) checkAdmission(queue *Queue, job *TrainingJob) error {
if queue.Spec.AdmissionControl == nil {
return nil
}
ac := queue.Spec.AdmissionControl
// 检查最大作业数
if ac.MaxJobs > 0 {
totalJobs := queue.Status.PendingJobs + queue.Status.RunningJobs
if totalJobs >= ac.MaxJobs {
return fmt.Errorf("queue job limit reached: %d", ac.MaxJobs)
}
}
// 检查单作业 GPU 限制
if ac.MaxGPUsPerJob > 0 {
requestedGPUs := job.Resources.ReplicaResources.GPU * job.Resources.MaxReplicas
if requestedGPUs > ac.MaxGPUsPerJob {
return fmt.Errorf("job GPU request %d exceeds limit %d",
requestedGPUs, ac.MaxGPUsPerJob)
}
}
// 检查优先级范围
if ac.PriorityRange != nil {
if job.SchedulingConfig.Priority < ac.PriorityRange.Min ||
job.SchedulingConfig.Priority > ac.PriorityRange.Max {
return fmt.Errorf("job priority %d out of allowed range [%d, %d]",
job.SchedulingConfig.Priority,
ac.PriorityRange.Min, ac.PriorityRange.Max)
}
}
return nil
}
// sortJobsByPriority 按优先级排序
func (m *QueueManager) sortJobsByPriority(queue *Queue) {
// 使用稳定排序保持相同优先级的 FIFO 顺序
sort.SliceStable(queue.jobs, func(i, j int) bool {
return queue.jobs[i].SchedulingConfig.Priority >
queue.jobs[j].SchedulingConfig.Priority
})
// 更新索引
for i, job := range queue.jobs {
queue.jobIndex[job.Namespace+"/"+job.Name] = i
}
}
// SelectNextJob 选择下一个要调度的作业(公平调度)
func (m *QueueManager) SelectNextJob(availableResources ResourceList) *TrainingJob {
	m.mu.RLock()
	defer m.mu.RUnlock()
	// 使用加权公平队列算法
	var selectedJob *TrainingJob
	var selectedQueue *Queue
	bestScore := float64(-1)
	for _, queue := range m.rootQueues {
		job, owner, score := m.selectFromQueueTree(queue, availableResources)
		if job != nil && score > bestScore {
			bestScore = score
			selectedJob = job
			selectedQueue = owner
		}
	}
	if selectedJob != nil {
		// 从作业实际所在的队列中移除(作业可能位于子队列,而非根队列)
		m.removeJobFromQueue(selectedQueue, selectedJob)
	}
	return selectedJob
}
// selectFromQueueTree 从队列树中选择作业,同时返回该作业所在的队列与公平分数
func (m *QueueManager) selectFromQueueTree(queue *Queue, available ResourceList) (*TrainingJob, *Queue, float64) {
	queue.mu.RLock()
	defer queue.mu.RUnlock()
	// 递归检查子队列
	var bestJob *TrainingJob
	var bestQueue *Queue
	bestScore := float64(-1)
	for _, child := range queue.children {
		job, owner, score := m.selectFromQueueTree(child, available)
		if job != nil && score > bestScore {
			bestScore = score
			bestJob = job
			bestQueue = owner
		}
	}
	// 检查当前队列的作业
	for _, job := range queue.jobs {
		if m.canSchedule(job, available) {
			score := m.calculateFairScore(queue, job)
			if score > bestScore {
				bestScore = score
				bestJob = job
				bestQueue = queue
			}
		}
	}
	return bestJob, bestQueue, bestScore
}
// calculateFairScore 计算公平调度分数
func (m *QueueManager) calculateFairScore(queue *Queue, job *TrainingJob) float64 {
// 公平份额 = 权重 / (已使用 + 1)
// 这样使用越少的队列分数越高
usedGPUs := float64(queue.Status.Used.GPU)
weight := float64(queue.Spec.Weight)
fairShare := weight / (usedGPUs + 1)
// 优先级加成
priorityBonus := float64(job.SchedulingConfig.Priority) / 1000000.0
// 等待时间加成(防止饥饿)
waitTime := time.Since(job.CreationTime).Minutes()
waitBonus := waitTime / 60.0 // 每小时加 1 分
return fairShare + priorityBonus + waitBonus
}
// canSchedule 检查是否可以调度
func (m *QueueManager) canSchedule(job *TrainingJob, available ResourceList) bool {
requiredGPUs := job.Resources.ReplicaResources.GPU * job.Resources.MinReplicas
return available.GPU >= requiredGPUs
}
// removeJobFromQueue 从队列移除作业
func (m *QueueManager) removeJobFromQueue(queue *Queue, job *TrainingJob) {
queue.mu.Lock()
defer queue.mu.Unlock()
jobKey := job.Namespace + "/" + job.Name
idx, ok := queue.jobIndex[jobKey]
if !ok {
return
}
queue.jobs = append(queue.jobs[:idx], queue.jobs[idx+1:]...)
delete(queue.jobIndex, jobKey)
// 更新索引
for i := idx; i < len(queue.jobs); i++ {
queue.jobIndex[queue.jobs[i].Namespace+"/"+queue.jobs[i].Name] = i
}
queue.Status.PendingJobs--
}
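下面给出一个使用 QueueManager 的最小示例(导入路径与配额数值均为假设),演示如何构建 root → 团队 的两级队列,以及调度循环如何取出下一个作业:
package main
import (
	"log"
	"example.com/ai-platform/pkg/scheduler/queue" // 假设的模块路径
)
func main() {
	m := queue.NewQueueManager()
	mustCreate := func(spec queue.QueueSpec) {
		if err := m.CreateQueue(spec); err != nil {
			log.Fatal(err)
		}
	}
	// 根队列:整个训练集群的资源总量
	mustCreate(queue.QueueSpec{
		Name:   "root",
		Weight: 1,
		Quota:  queue.ResourceQuota{Hard: queue.ResourceList{GPU: 96, CPU: "1024", Memory: "8Ti"}},
	})
	// 子队列:研究与生产团队按 1:2 权重共享,研究队列附带准入限制
	mustCreate(queue.QueueSpec{
		Name: "ml-research", Parent: "root", Weight: 1,
		Quota:            queue.ResourceQuota{Hard: queue.ResourceList{GPU: 32, CPU: "256", Memory: "2Ti"}},
		AdmissionControl: &queue.AdmissionSpec{MaxJobs: 50, MaxGPUsPerJob: 32},
	})
	mustCreate(queue.QueueSpec{
		Name: "ml-production", Parent: "root", Weight: 2,
		Quota: queue.ResourceQuota{Hard: queue.ResourceList{GPU: 64, CPU: "512", Memory: "4Ti"}},
	})
	// 作业入队后,调度循环周期性调用 SelectNextJob,取出公平分数最高且资源可满足的作业
	next := m.SelectNextJob(queue.ResourceList{GPU: 16})
	_ = next
}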
2.2 Kueue 集成
# kueue-resources.yaml
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
name: gpu-a100
spec:
nodeLabels:
nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
name: gpu-v100
spec:
nodeLabels:
nvidia.com/gpu.product: "Tesla-V100-SXM2-32GB"
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: ClusterQueue
metadata:
name: training-cluster-queue
spec:
namespaceSelector: {}
resourceGroups:
- coveredResources: ["cpu", "memory", "nvidia.com/gpu"]
flavors:
- name: gpu-a100
resources:
- name: "cpu"
nominalQuota: 256
- name: "memory"
nominalQuota: 2Ti
- name: "nvidia.com/gpu"
nominalQuota: 64
borrowingLimit: 16
- name: gpu-v100
resources:
- name: "cpu"
nominalQuota: 128
- name: "memory"
nominalQuota: 1Ti
- name: "nvidia.com/gpu"
nominalQuota: 32
cohort: ai-platform
preemption:
reclaimWithinCohort: Any
withinClusterQueue: LowerPriority
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
name: ml-research-queue
namespace: ml-research
spec:
clusterQueue: training-cluster-queue
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
name: ml-production-queue
namespace: ml-production
spec:
clusterQueue: training-cluster-queue
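在作业侧,接入 Kueue 只需要在工作负载的元数据上加上 kueue.x-k8s.io/queue-name 标签,指向对应命名空间下的 LocalQueue。下面是一个用 Go 构造 batch/v1 Job 的简化示例(镜像与资源数值为假设):作业以挂起状态创建,配额允许时由 Kueue 解除挂起并准入。
package main
import (
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/utils/ptr"
)
// newKueueJob 构造一个提交到 ml-research-queue 的示例 Job
func newKueueJob() *batchv1.Job {
	return &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "bert-finetune-demo",
			Namespace: "ml-research",
			Labels: map[string]string{
				"kueue.x-k8s.io/queue-name": "ml-research-queue",
			},
		},
		Spec: batchv1.JobSpec{
			Suspend:     ptr.To(true),
			Parallelism: ptr.To[int32](4),
			Completions: ptr.To[int32](4),
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever,
					Containers: []corev1.Container{{
						Name:  "trainer",
						Image: "pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime",
						Resources: corev1.ResourceRequirements{
							Limits: corev1.ResourceList{
								"nvidia.com/gpu": resource.MustParse("8"),
							},
						},
					}},
				},
			},
		},
	}
}
func main() {
	_ = newKueueJob() // 实际使用时通过 client-go 创建该 Job
}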
3. Gang 调度实现
3.1 Gang 调度器
// pkg/scheduler/gang/scheduler.go
package gang
import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// 注:ResourceManager、Allocation、NodeInfo、ResourceCapacity、ResourceRequirements 等
// 为本包假设存在的资源管理辅助类型,篇幅所限省略其定义。
// GangScheduler Gang 调度器
type GangScheduler struct {
mu sync.RWMutex
// 待调度的 Gang
pendingGangs map[string]*PodGroup
// 资源管理
resourceManager *ResourceManager
// 配置
config *GangConfig
}
// PodGroup Pod 组(Gang)
type PodGroup struct {
// 基本信息
Name string
Namespace string
UID string
// 成员 Pod
Members []*PodMember
// 最小成员数(用于弹性 Gang)
MinMember int32
// 调度策略
SchedulePolicy SchedulePolicy
// 创建时间
CreationTime time.Time
// 调度超时
ScheduleTimeout time.Duration
// 状态
Status PodGroupStatus
}
// PodMember Gang 成员
type PodMember struct {
Name string
Namespace string
Pod *corev1.Pod
Resources corev1.ResourceRequirements
NodeName string // 调度后填充
Status MemberStatus
}
// MemberStatus 成员状态
type MemberStatus string
const (
MemberPending MemberStatus = "Pending"
MemberScheduled MemberStatus = "Scheduled"
MemberRunning MemberStatus = "Running"
MemberFailed MemberStatus = "Failed"
)
// SchedulePolicy 调度策略
type SchedulePolicy struct {
// 是否严格 Gang(必须所有成员同时调度)
Strict bool
// 拓扑约束
TopologyKey string
// 亲和性
Affinity *corev1.Affinity
// 容忍
Tolerations []corev1.Toleration
}
// PodGroupStatus PodGroup 状态
type PodGroupStatus struct {
Phase PodGroupPhase
Scheduled int32
Running int32
Succeeded int32
Failed int32
Conditions []PodGroupCondition
}
// PodGroupPhase PodGroup 阶段
type PodGroupPhase string
const (
PodGroupPending PodGroupPhase = "Pending"
PodGroupScheduled PodGroupPhase = "Scheduled"
PodGroupRunning PodGroupPhase = "Running"
PodGroupSucceeded PodGroupPhase = "Succeeded"
PodGroupFailed PodGroupPhase = "Failed"
PodGroupTimeout PodGroupPhase = "Timeout"
)
// PodGroupCondition 条件
type PodGroupCondition struct {
Type string
Status corev1.ConditionStatus
LastTransitionTime metav1.Time
Reason string
Message string
}
// GangConfig 配置
type GangConfig struct {
// 默认超时时间
DefaultTimeout time.Duration
// 调度间隔
ScheduleInterval time.Duration
// 是否允许部分调度
AllowPartialSchedule bool
}
// NewGangScheduler 创建 Gang 调度器
func NewGangScheduler(rm *ResourceManager, config *GangConfig) *GangScheduler {
if config == nil {
config = &GangConfig{
DefaultTimeout: 10 * time.Minute,
ScheduleInterval: 5 * time.Second,
AllowPartialSchedule: false,
}
}
return &GangScheduler{
pendingGangs: make(map[string]*PodGroup),
resourceManager: rm,
config: config,
}
}
// AddPodGroup 添加 PodGroup
func (s *GangScheduler) AddPodGroup(pg *PodGroup) error {
s.mu.Lock()
defer s.mu.Unlock()
key := pg.Namespace + "/" + pg.Name
if _, exists := s.pendingGangs[key]; exists {
return fmt.Errorf("pod group %s already exists", key)
}
if pg.ScheduleTimeout == 0 {
pg.ScheduleTimeout = s.config.DefaultTimeout
}
pg.CreationTime = time.Now()
pg.Status.Phase = PodGroupPending
s.pendingGangs[key] = pg
return nil
}
// Run 运行调度循环
func (s *GangScheduler) Run(ctx context.Context) {
ticker := time.NewTicker(s.config.ScheduleInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.scheduleAll(ctx)
}
}
}
// scheduleAll 调度所有待处理的 Gang
func (s *GangScheduler) scheduleAll(ctx context.Context) {
s.mu.Lock()
gangs := make([]*PodGroup, 0, len(s.pendingGangs))
for _, pg := range s.pendingGangs {
gangs = append(gangs, pg)
}
s.mu.Unlock()
	// 按创建时间排序(FIFO,先提交的 Gang 先尝试调度)
sort.Slice(gangs, func(i, j int) bool {
return gangs[i].CreationTime.Before(gangs[j].CreationTime)
})
for _, pg := range gangs {
// 检查超时
if time.Since(pg.CreationTime) > pg.ScheduleTimeout {
s.handleTimeout(pg)
continue
}
// 尝试调度
if s.trySchedule(ctx, pg) {
s.markScheduled(pg)
}
}
}
// trySchedule 尝试调度 PodGroup
func (s *GangScheduler) trySchedule(ctx context.Context, pg *PodGroup) bool {
// 收集资源需求
totalRequirements := s.calculateTotalRequirements(pg)
// 检查是否有足够资源
allocation, err := s.resourceManager.TryAllocate(totalRequirements)
if err != nil {
return false
}
// 分配 Pod 到节点
if !s.assignPodsToNodes(pg, allocation) {
s.resourceManager.ReleaseAllocation(allocation)
return false
}
// 提交分配
return s.commitAllocation(ctx, pg, allocation)
}
// calculateTotalRequirements 计算总资源需求
func (s *GangScheduler) calculateTotalRequirements(pg *PodGroup) *ResourceRequirements {
total := &ResourceRequirements{
GPU: 0,
CPU: resource.Quantity{},
Memory: resource.Quantity{},
}
for _, member := range pg.Members {
if member.Status == MemberPending {
if gpu, ok := member.Resources.Limits["nvidia.com/gpu"]; ok {
total.GPU += int32(gpu.Value())
}
if cpu, ok := member.Resources.Requests[corev1.ResourceCPU]; ok {
total.CPU.Add(cpu)
}
if mem, ok := member.Resources.Requests[corev1.ResourceMemory]; ok {
total.Memory.Add(mem)
}
}
}
return total
}
// assignPodsToNodes 分配 Pod 到节点
func (s *GangScheduler) assignPodsToNodes(pg *PodGroup, allocation *Allocation) bool {
// 获取可用节点
nodes := allocation.GetAllocatedNodes()
// 使用拓扑感知分配
if pg.SchedulePolicy.TopologyKey != "" {
return s.assignWithTopology(pg, nodes)
}
// 简单分配:贪心算法
nodeIndex := 0
for _, member := range pg.Members {
if member.Status != MemberPending {
continue
}
if nodeIndex >= len(nodes) {
return false
}
// 检查节点是否有足够资源
node := nodes[nodeIndex]
if s.canFitOnNode(member, node) {
member.NodeName = node.Name
nodeIndex++
} else {
// 尝试下一个节点
found := false
for i := nodeIndex + 1; i < len(nodes); i++ {
if s.canFitOnNode(member, nodes[i]) {
member.NodeName = nodes[i].Name
found = true
break
}
}
if !found {
return false
}
}
}
return true
}
// assignWithTopology 拓扑感知分配
func (s *GangScheduler) assignWithTopology(pg *PodGroup, nodes []*NodeInfo) bool {
topologyKey := pg.SchedulePolicy.TopologyKey
// 按拓扑域分组节点
topologyGroups := make(map[string][]*NodeInfo)
for _, node := range nodes {
domain := node.Labels[topologyKey]
topologyGroups[domain] = append(topologyGroups[domain], node)
}
// 尝试在同一拓扑域内分配所有 Pod
	for _, domainNodes := range topologyGroups {
if s.tryAssignToDomain(pg, domainNodes) {
return true
}
}
	// 如果单个域容量不够且策略允许,则退化为跨域分配
	// (assignAcrossDomains 与上文 assignPodsToNodes 的贪心逻辑类似,篇幅所限此处省略)
	if !pg.SchedulePolicy.Strict {
		return s.assignAcrossDomains(pg, nodes)
	}
}
return false
}
// tryAssignToDomain 尝试在单个拓扑域分配
func (s *GangScheduler) tryAssignToDomain(pg *PodGroup, nodes []*NodeInfo) bool {
// 检查域内资源是否足够
totalCapacity := s.calculateDomainCapacity(nodes)
requirements := s.calculateTotalRequirements(pg)
if totalCapacity.GPU < requirements.GPU {
return false
}
// 分配
nodeIndex := 0
for _, member := range pg.Members {
if member.Status != MemberPending {
continue
}
for nodeIndex < len(nodes) {
if s.canFitOnNode(member, nodes[nodeIndex]) {
member.NodeName = nodes[nodeIndex].Name
break
}
nodeIndex++
}
if nodeIndex >= len(nodes) {
// 回滚
for _, m := range pg.Members {
m.NodeName = ""
}
return false
}
}
return true
}
// commitAllocation 提交分配
func (s *GangScheduler) commitAllocation(ctx context.Context, pg *PodGroup,
allocation *Allocation) bool {
// 创建绑定
for _, member := range pg.Members {
if member.NodeName == "" {
continue
}
binding := &corev1.Binding{
ObjectMeta: metav1.ObjectMeta{
Name: member.Name,
Namespace: member.Namespace,
},
Target: corev1.ObjectReference{
Kind: "Node",
Name: member.NodeName,
},
}
// 执行绑定
if err := s.bindPod(ctx, binding); err != nil {
// 绑定失败,回滚
s.resourceManager.ReleaseAllocation(allocation)
return false
}
member.Status = MemberScheduled
}
return true
}
// handleTimeout 处理超时
func (s *GangScheduler) handleTimeout(pg *PodGroup) {
s.mu.Lock()
defer s.mu.Unlock()
pg.Status.Phase = PodGroupTimeout
pg.Status.Conditions = append(pg.Status.Conditions, PodGroupCondition{
Type: "SchedulingTimeout",
Status: corev1.ConditionTrue,
LastTransitionTime: metav1.Now(),
Reason: "Timeout",
Message: fmt.Sprintf("Failed to schedule within %v", pg.ScheduleTimeout),
})
// 从待处理列表移除
delete(s.pendingGangs, pg.Namespace+"/"+pg.Name)
}
// markScheduled 标记已调度
func (s *GangScheduler) markScheduled(pg *PodGroup) {
s.mu.Lock()
defer s.mu.Unlock()
pg.Status.Phase = PodGroupScheduled
pg.Status.Scheduled = int32(len(pg.Members))
pg.Status.Conditions = append(pg.Status.Conditions, PodGroupCondition{
Type: "Scheduled",
Status: corev1.ConditionTrue,
LastTransitionTime: metav1.Now(),
Reason: "AllMembersScheduled",
Message: "All pod group members have been scheduled",
})
delete(s.pendingGangs, pg.Namespace+"/"+pg.Name)
}
// Helper functions
func (s *GangScheduler) canFitOnNode(member *PodMember, node *NodeInfo) bool {
// 检查 GPU
if gpu, ok := member.Resources.Limits["nvidia.com/gpu"]; ok {
if node.AvailableGPU < int32(gpu.Value()) {
return false
}
}
// 检查 CPU 和内存...
return true
}
func (s *GangScheduler) calculateDomainCapacity(nodes []*NodeInfo) *ResourceCapacity {
capacity := &ResourceCapacity{}
for _, node := range nodes {
capacity.GPU += node.AvailableGPU
}
return capacity
}
func (s *GangScheduler) bindPod(ctx context.Context, binding *corev1.Binding) error {
// 实际调用 Kubernetes API
return nil
}
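下面给出一个使用 GangScheduler 的示意片段(导入路径为假设,ResourceManager 的构造依赖具体实现),演示如何把一组 Pod 注册为严格 Gang 并启动调度循环:
package demo
import (
	"context"
	"time"
	corev1 "k8s.io/api/core/v1"
	"example.com/ai-platform/pkg/scheduler/gang" // 假设的模块路径
)
// runGangDemo 把一组已有的 Pod 注册为一个严格 Gang 并启动调度循环
func runGangDemo(ctx context.Context, rm *gang.ResourceManager, pods []*corev1.Pod) error {
	sched := gang.NewGangScheduler(rm, &gang.GangConfig{
		DefaultTimeout:       10 * time.Minute,
		ScheduleInterval:     5 * time.Second,
		AllowPartialSchedule: false,
	})
	members := make([]*gang.PodMember, 0, len(pods))
	for _, p := range pods {
		members = append(members, &gang.PodMember{
			Name:      p.Name,
			Namespace: p.Namespace,
			Pod:       p,
			Resources: p.Spec.Containers[0].Resources,
			Status:    gang.MemberPending,
		})
	}
	if err := sched.AddPodGroup(&gang.PodGroup{
		Name:      "llama-sft-demo",
		Namespace: "ml-research",
		Members:   members,
		MinMember: int32(len(members)), // 严格 Gang:所有成员必须同时可调度
		SchedulePolicy: gang.SchedulePolicy{
			Strict:      true,
			TopologyKey: "topology.kubernetes.io/zone",
		},
		ScheduleTimeout: 15 * time.Minute,
	}); err != nil {
		return err
	}
	go sched.Run(ctx) // 按 ScheduleInterval 周期性尝试整组调度,直到成功或超时
	return nil
}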
3.2 Volcano 集成
# volcano-job.yaml
apiVersion: batch.volcano.sh/v1alpha1
kind: Job
metadata:
name: distributed-training
namespace: ml-workloads
spec:
schedulerName: volcano
minAvailable: 4 # Gang 调度:至少 4 个 Pod 同时调度
queue: training-queue
policies:
- event: PodEvicted
action: RestartJob
- event: PodFailed
action: RestartJob
plugins:
ssh: []
svc: []
env: []
maxRetry: 3
ttlSecondsAfterFinished: 3600
tasks:
- name: master
replicas: 1
template:
spec:
containers:
- name: pytorch
image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
command:
- /bin/bash
- -c
- |
python -m torch.distributed.run \
--nproc_per_node=8 \
--nnodes=$VC_WORKER_NUM \
--node_rank=$VC_TASK_INDEX \
--master_addr=$VC_MASTER_ADDR \
--master_port=29500 \
train.py
resources:
limits:
nvidia.com/gpu: 8
requests:
nvidia.com/gpu: 8
cpu: "32"
memory: 256Gi
volumeMounts:
- name: shm
mountPath: /dev/shm
volumes:
- name: shm
emptyDir:
medium: Memory
sizeLimit: 64Gi
restartPolicy: OnFailure
- name: worker
replicas: 3
template:
spec:
containers:
- name: pytorch
image: pytorch/pytorch:2.0-cuda12.1-cudnn8-runtime
command:
- /bin/bash
- -c
- |
python -m torch.distributed.run \
--nproc_per_node=8 \
--nnodes=$VC_WORKER_NUM \
--node_rank=$VC_TASK_INDEX \
--master_addr=$VC_MASTER_ADDR \
--master_port=29500 \
train.py
resources:
limits:
nvidia.com/gpu: 8
requests:
nvidia.com/gpu: 8
cpu: "32"
memory: 256Gi
volumeMounts:
- name: shm
mountPath: /dev/shm
volumes:
- name: shm
emptyDir:
medium: Memory
sizeLimit: 64Gi
restartPolicy: OnFailure
---
apiVersion: scheduling.volcano.sh/v1beta1
kind: Queue
metadata:
name: training-queue
spec:
weight: 10
capability:
cpu: "512"
memory: "2Ti"
nvidia.com/gpu: "64"
reclaimable: true
guarantee:
resource:
cpu: "128"
memory: "512Gi"
nvidia.com/gpu: "16"
4. 公平调度算法
4.1 Dominant Resource Fairness (DRF)
// pkg/scheduler/fairshare/drf.go
package fairshare
import (
	"sync"
)
// 注:TrainingJob 类型沿用前文 scheduler 包中的定义,此处假设已在同一模块内引入。
// DRFScheduler Dominant Resource Fairness 调度器
type DRFScheduler struct {
mu sync.RWMutex
// 队列
queues map[string]*DRFQueue
// 总资源
totalResources Resources
// 已分配资源
allocatedResources Resources
}
// DRFQueue DRF 队列
type DRFQueue struct {
Name string
Weight float64
// 已分配资源
Allocated Resources
// 主导份额(最大资源占用比)
DominantShare float64
// 等待的作业
PendingJobs []*TrainingJob
}
// Resources 资源
type Resources struct {
GPU float64
CPU float64
Memory float64
}
// NewDRFScheduler 创建 DRF 调度器
func NewDRFScheduler(totalResources Resources) *DRFScheduler {
return &DRFScheduler{
queues: make(map[string]*DRFQueue),
totalResources: totalResources,
}
}
// AddQueue 添加队列
func (s *DRFScheduler) AddQueue(name string, weight float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.queues[name] = &DRFQueue{
Name: name,
Weight: weight,
PendingJobs: make([]*TrainingJob, 0),
}
}
// SubmitJob 提交作业
func (s *DRFScheduler) SubmitJob(queueName string, job *TrainingJob) {
s.mu.Lock()
defer s.mu.Unlock()
queue, ok := s.queues[queueName]
if !ok {
return
}
queue.PendingJobs = append(queue.PendingJobs, job)
}
// Schedule 调度一个作业
func (s *DRFScheduler) Schedule() *ScheduleResult {
s.mu.Lock()
defer s.mu.Unlock()
// 计算每个队列的主导份额
for _, queue := range s.queues {
queue.DominantShare = s.calculateDominantShare(queue)
}
// 选择主导份额最小的队列(加权后)
var selectedQueue *DRFQueue
minWeightedShare := float64(1e9)
for _, queue := range s.queues {
if len(queue.PendingJobs) == 0 {
continue
}
weightedShare := queue.DominantShare / queue.Weight
if weightedShare < minWeightedShare {
minWeightedShare = weightedShare
selectedQueue = queue
}
}
if selectedQueue == nil {
return nil
}
// 从选中的队列取出作业
job := selectedQueue.PendingJobs[0]
// 检查资源是否足够
jobResources := s.getJobResources(job)
if !s.hasEnoughResources(jobResources) {
return nil
}
// 分配资源
selectedQueue.PendingJobs = selectedQueue.PendingJobs[1:]
selectedQueue.Allocated = s.addResources(selectedQueue.Allocated, jobResources)
s.allocatedResources = s.addResources(s.allocatedResources, jobResources)
// 更新主导份额
selectedQueue.DominantShare = s.calculateDominantShare(selectedQueue)
return &ScheduleResult{
Job: job,
Queue: selectedQueue.Name,
Resources: jobResources,
}
}
// calculateDominantShare 计算主导份额
func (s *DRFScheduler) calculateDominantShare(queue *DRFQueue) float64 {
if s.totalResources.GPU == 0 && s.totalResources.CPU == 0 &&
s.totalResources.Memory == 0 {
return 0
}
// 计算每种资源的份额
gpuShare := float64(0)
if s.totalResources.GPU > 0 {
gpuShare = queue.Allocated.GPU / s.totalResources.GPU
}
cpuShare := float64(0)
if s.totalResources.CPU > 0 {
cpuShare = queue.Allocated.CPU / s.totalResources.CPU
}
memShare := float64(0)
if s.totalResources.Memory > 0 {
memShare = queue.Allocated.Memory / s.totalResources.Memory
}
// 返回最大份额(主导资源)
return max(gpuShare, max(cpuShare, memShare))
}
// hasEnoughResources 检查是否有足够资源
func (s *DRFScheduler) hasEnoughResources(required Resources) bool {
available := s.subtractResources(s.totalResources, s.allocatedResources)
return available.GPU >= required.GPU &&
available.CPU >= required.CPU &&
available.Memory >= required.Memory
}
// getJobResources 获取作业资源需求
func (s *DRFScheduler) getJobResources(job *TrainingJob) Resources {
replicas := float64(job.Resources.MinReplicas)
return Resources{
GPU: float64(job.Resources.ReplicaResources.GPU) * replicas,
CPU: job.Resources.ReplicaResources.CPU.AsApproximateFloat64() * replicas,
Memory: job.Resources.ReplicaResources.Memory.AsApproximateFloat64() * replicas,
}
}
// addResources 资源相加
func (s *DRFScheduler) addResources(a, b Resources) Resources {
return Resources{
GPU: a.GPU + b.GPU,
CPU: a.CPU + b.CPU,
Memory: a.Memory + b.Memory,
}
}
// subtractResources 资源相减
func (s *DRFScheduler) subtractResources(a, b Resources) Resources {
return Resources{
GPU: a.GPU - b.GPU,
CPU: a.CPU - b.CPU,
Memory: a.Memory - b.Memory,
}
}
// ScheduleResult 调度结果
type ScheduleResult struct {
Job *TrainingJob
Queue string
Resources Resources
}
func max(a, b float64) float64 {
if a > b {
return a
}
return b
}
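DRF 的核心是比较各队列的"主导份额"。例如集群共有 64 GPU、1024 CPU:队列 A 已占用 16 GPU、64 CPU,份额分别为 0.25 和 0.0625,主导份额为 0.25;队列 B 已占用 4 GPU、128 CPU,主导份额为 0.125。权重相同时,下一个作业会优先分给主导份额更小的队列 B。下面是一个调用 DRFScheduler 的最小示例(导入路径与资源总量为假设):
package main
import (
	"fmt"
	"example.com/ai-platform/pkg/scheduler/fairshare" // 假设的模块路径
)
func main() {
	// 集群总量:64 GPU / 1024 CPU / 8192 GiB 内存
	drf := fairshare.NewDRFScheduler(fairshare.Resources{GPU: 64, CPU: 1024, Memory: 8192})
	drf.AddQueue("team-a", 1.0)
	drf.AddQueue("team-b", 1.0)
	// 实际使用时先通过 drf.SubmitJob("team-a", job) 等把 TrainingJob 提交到各队列
	for {
		result := drf.Schedule()
		if result == nil {
			break // 没有待调度作业,或剩余资源不足
		}
		fmt.Printf("schedule %s from queue %s, GPUs=%.0f\n",
			result.Job.Name, result.Queue, result.Resources.GPU)
	}
}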
4.2 层级公平调度
// pkg/scheduler/fairshare/hierarchical.go
package fairshare
import (
	"container/heap"
	"fmt"
)
// HierarchicalFairScheduler 层级公平调度器
type HierarchicalFairScheduler struct {
// 根队列
root *HierarchicalQueue
// 所有队列映射
allQueues map[string]*HierarchicalQueue
// 总资源
totalResources Resources
}
// HierarchicalQueue 层级队列
type HierarchicalQueue struct {
Name string
Parent *HierarchicalQueue
Children []*HierarchicalQueue
// 权重
Weight float64
// 最小保证
MinShare Resources
// 最大限制
MaxShare Resources
// 已分配
Allocated Resources
// 公平份额
FairShare Resources
// 作业
Jobs *JobHeap
}
// JobHeap 作业优先级堆
type JobHeap []*TrainingJob
func (h JobHeap) Len() int { return len(h) }
func (h JobHeap) Less(i, j int) bool {
return h[i].SchedulingConfig.Priority > h[j].SchedulingConfig.Priority
}
func (h JobHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *JobHeap) Push(x interface{}) { *h = append(*h, x.(*TrainingJob)) }
func (h *JobHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// NewHierarchicalFairScheduler 创建层级公平调度器
func NewHierarchicalFairScheduler(totalResources Resources) *HierarchicalFairScheduler {
root := &HierarchicalQueue{
Name: "root",
Weight: 1.0,
MaxShare: totalResources,
Jobs: &JobHeap{},
}
heap.Init(root.Jobs)
return &HierarchicalFairScheduler{
root: root,
allQueues: map[string]*HierarchicalQueue{"root": root},
totalResources: totalResources,
}
}
// AddQueue 添加队列
func (s *HierarchicalFairScheduler) AddQueue(name, parentName string,
weight float64, minShare, maxShare Resources) error {
parent, ok := s.allQueues[parentName]
if !ok {
return fmt.Errorf("parent queue %s not found", parentName)
}
queue := &HierarchicalQueue{
Name: name,
Parent: parent,
Weight: weight,
MinShare: minShare,
MaxShare: maxShare,
Jobs: &JobHeap{},
}
heap.Init(queue.Jobs)
parent.Children = append(parent.Children, queue)
s.allQueues[name] = queue
return nil
}
// CalculateFairShares 计算公平份额
func (s *HierarchicalFairScheduler) CalculateFairShares() {
// 自顶向下计算公平份额
s.calculateFairShareRecursive(s.root, s.totalResources)
}
// calculateFairShareRecursive 递归计算公平份额
func (s *HierarchicalFairScheduler) calculateFairShareRecursive(
queue *HierarchicalQueue, available Resources) {
if len(queue.Children) == 0 {
// 叶子队列:公平份额 = min(可用, 最大限制)
queue.FairShare = s.minResources(available, queue.MaxShare)
return
}
// 计算子队列的需求
demands := make([]Resources, len(queue.Children))
totalWeight := float64(0)
for i, child := range queue.Children {
demands[i] = s.calculateDemand(child)
totalWeight += child.Weight
}
// 分配资源给子队列
remaining := available
	for _, child := range queue.Children {
// 首先分配最小保证
minAlloc := s.minResources(child.MinShare, remaining)
remaining = s.subtractResources(remaining, minAlloc)
}
// 按权重分配剩余资源
for i, child := range queue.Children {
weightRatio := child.Weight / totalWeight
share := s.multiplyResources(remaining, weightRatio)
// 加上最小保证
share = s.addResources(share, child.MinShare)
// 不超过需求
share = s.minResources(share, demands[i])
// 不超过最大限制
share = s.minResources(share, child.MaxShare)
// 递归计算子队列
s.calculateFairShareRecursive(child, share)
}
}
// calculateDemand 计算队列需求
func (s *HierarchicalFairScheduler) calculateDemand(queue *HierarchicalQueue) Resources {
demand := Resources{}
// 已运行的作业
demand = s.addResources(demand, queue.Allocated)
// 等待的作业
for _, job := range *queue.Jobs {
jobRes := s.getJobResources(job)
demand = s.addResources(demand, jobRes)
}
// 子队列需求
for _, child := range queue.Children {
childDemand := s.calculateDemand(child)
demand = s.addResources(demand, childDemand)
}
return demand
}
// Schedule 调度作业
func (s *HierarchicalFairScheduler) Schedule() *ScheduleResult {
// 先计算公平份额
s.CalculateFairShares()
// 选择最需要资源的队列
queue := s.selectQueue(s.root)
if queue == nil || queue.Jobs.Len() == 0 {
return nil
}
// 取出最高优先级作业
job := heap.Pop(queue.Jobs).(*TrainingJob)
jobRes := s.getJobResources(job)
// 检查资源
if !s.canAllocate(queue, jobRes) {
heap.Push(queue.Jobs, job)
return nil
}
// 分配资源
s.allocate(queue, jobRes)
return &ScheduleResult{
Job: job,
Queue: queue.Name,
Resources: jobRes,
}
}
// selectQueue 选择队列
func (s *HierarchicalFairScheduler) selectQueue(queue *HierarchicalQueue) *HierarchicalQueue {
if len(queue.Children) == 0 {
return queue
}
// 选择最低于公平份额的子队列
var selected *HierarchicalQueue
maxDeficit := float64(-1e9)
for _, child := range queue.Children {
deficit := s.calculateDeficit(child)
if deficit > maxDeficit {
maxDeficit = deficit
selected = child
}
}
if selected == nil {
return nil
}
return s.selectQueue(selected)
}
// calculateDeficit 计算资源缺口
func (s *HierarchicalFairScheduler) calculateDeficit(queue *HierarchicalQueue) float64 {
// 缺口 = 公平份额 - 已分配
gpuDeficit := queue.FairShare.GPU - queue.Allocated.GPU
cpuDeficit := queue.FairShare.CPU - queue.Allocated.CPU
memDeficit := queue.FairShare.Memory - queue.Allocated.Memory
// 返回主导资源的缺口
return max(gpuDeficit, max(cpuDeficit, memDeficit))
}
// canAllocate 检查是否可分配
func (s *HierarchicalFairScheduler) canAllocate(queue *HierarchicalQueue, res Resources) bool {
// 检查队列限制
newAlloc := s.addResources(queue.Allocated, res)
if newAlloc.GPU > queue.MaxShare.GPU ||
newAlloc.CPU > queue.MaxShare.CPU ||
newAlloc.Memory > queue.MaxShare.Memory {
return false
}
// 检查父队列
if queue.Parent != nil {
return s.canAllocate(queue.Parent, res)
}
return true
}
// allocate 分配资源
func (s *HierarchicalFairScheduler) allocate(queue *HierarchicalQueue, res Resources) {
queue.Allocated = s.addResources(queue.Allocated, res)
if queue.Parent != nil {
s.allocate(queue.Parent, res)
}
}
// 辅助函数
func (s *HierarchicalFairScheduler) minResources(a, b Resources) Resources {
	return Resources{
		GPU:    min(a.GPU, b.GPU),
		CPU:    min(a.CPU, b.CPU),
		Memory: min(a.Memory, b.Memory),
	}
}
func (s *HierarchicalFairScheduler) multiplyResources(r Resources, factor float64) Resources {
	return Resources{
		GPU:    r.GPU * factor,
		CPU:    r.CPU * factor,
		Memory: r.Memory * factor,
	}
}
func (s *HierarchicalFairScheduler) addResources(a, b Resources) Resources {
	return Resources{GPU: a.GPU + b.GPU, CPU: a.CPU + b.CPU, Memory: a.Memory + b.Memory}
}
func (s *HierarchicalFairScheduler) subtractResources(a, b Resources) Resources {
	return Resources{GPU: a.GPU - b.GPU, CPU: a.CPU - b.CPU, Memory: a.Memory - b.Memory}
}
func (s *HierarchicalFairScheduler) getJobResources(job *TrainingJob) Resources {
	replicas := float64(job.Resources.MinReplicas)
	return Resources{
		GPU:    float64(job.Resources.ReplicaResources.GPU) * replicas,
		CPU:    job.Resources.ReplicaResources.CPU.AsApproximateFloat64() * replicas,
		Memory: job.Resources.ReplicaResources.Memory.AsApproximateFloat64() * replicas,
	}
}
func min(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}
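下面的示意代码(导入路径与配额数值为假设)展示了如何搭建 root → 业务线 → 小组 的队列层级;作业入队后,每个调度周期先刷新公平份额,再从缺口最大的叶子队列取作业:
package main
import (
	"example.com/ai-platform/pkg/scheduler/fairshare" // 假设的模块路径
)
func main() {
	total := fairshare.Resources{GPU: 96, CPU: 1024, Memory: 8192}
	h := fairshare.NewHierarchicalFairScheduler(total)
	// root 下划分两个业务线(权重 1:2),research 下再细分两个小组
	_ = h.AddQueue("research", "root", 1.0,
		fairshare.Resources{GPU: 16}, fairshare.Resources{GPU: 48, CPU: 512, Memory: 4096})
	_ = h.AddQueue("production", "root", 2.0,
		fairshare.Resources{GPU: 32}, total)
	_ = h.AddQueue("research-nlp", "research", 1.0,
		fairshare.Resources{}, fairshare.Resources{GPU: 32, CPU: 256, Memory: 2048})
	_ = h.AddQueue("research-cv", "research", 1.0,
		fairshare.Resources{}, fairshare.Resources{GPU: 32, CPU: 256, Memory: 2048})
	// 作业通过 heap.Push(leaf.Jobs, job) 进入对应叶子队列的优先级堆(此处略去作业构造)
	// 每个调度周期:先自顶向下刷新公平份额,再从缺口最大的叶子队列取最高优先级作业
	if result := h.Schedule(); result != nil {
		_ = result // result.Job 即本轮被调度的作业
	}
}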
5. 抢占与恢复
5.1 抢占调度器
// pkg/scheduler/preemption/scheduler.go
package preemption
import (
	"context"
	"fmt"
	"sort"
	"time"
)
// 注:TrainingJob、Resources、Checkpoint、CheckpointManager 等类型以及 addResources 等
// 资源运算辅助函数沿用本章其他小节中的定义,此处假设已在本包内引入。
// PreemptionScheduler 抢占调度器
type PreemptionScheduler struct {
// 运行中的作业
runningJobs map[string]*RunningJob
// 抢占策略
strategy PreemptionStrategy
// 检查点管理器
checkpointManager CheckpointManager
}
// RunningJob 运行中的作业
type RunningJob struct {
Job *TrainingJob
StartTime time.Time
Resources Resources
Nodes []string
Preemptible bool
Priority int32
LastCheckpoint time.Time
}
// PreemptionStrategy 抢占策略
type PreemptionStrategy interface {
// SelectVictims 选择被抢占的作业
SelectVictims(running []*RunningJob, required Resources) []*RunningJob
// ShouldPreempt 是否应该抢占
ShouldPreempt(preemptor *TrainingJob, victim *RunningJob) bool
}
// PriorityBasedStrategy 基于优先级的抢占策略
type PriorityBasedStrategy struct {
// 最小优先级差
MinPriorityDiff int32
// 最小运行时间保护
MinRuntime time.Duration
// 最大被抢占作业数
MaxVictims int
}
// SelectVictims 选择被抢占者
func (s *PriorityBasedStrategy) SelectVictims(running []*RunningJob,
required Resources) []*RunningJob {
// 过滤可抢占的作业
candidates := make([]*RunningJob, 0)
for _, job := range running {
if job.Preemptible && time.Since(job.StartTime) > s.MinRuntime {
candidates = append(candidates, job)
}
}
// 按优先级升序排序(低优先级先被抢占)
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Priority < candidates[j].Priority
})
// 选择足够的作业以释放所需资源
victims := make([]*RunningJob, 0)
freed := Resources{}
for _, job := range candidates {
if len(victims) >= s.MaxVictims {
break
}
victims = append(victims, job)
freed = addResources(freed, job.Resources)
// 检查是否已经够了
if freed.GPU >= required.GPU &&
freed.CPU >= required.CPU &&
freed.Memory >= required.Memory {
break
}
}
// 检查释放的资源是否足够
if freed.GPU < required.GPU || freed.CPU < required.CPU ||
freed.Memory < required.Memory {
return nil
}
return victims
}
// ShouldPreempt 判断是否应该抢占
func (s *PriorityBasedStrategy) ShouldPreempt(preemptor *TrainingJob,
victim *RunningJob) bool {
// 优先级差必须足够大
if preemptor.SchedulingConfig.Priority - victim.Priority < s.MinPriorityDiff {
return false
}
// 被抢占者必须已运行足够长时间
if time.Since(victim.StartTime) < s.MinRuntime {
return false
}
return true
}
// ExecutePreemption 执行抢占
func (s *PreemptionScheduler) ExecutePreemption(ctx context.Context,
preemptor *TrainingJob, victims []*RunningJob) error {
for _, victim := range victims {
// 1. 发送抢占通知(给作业机会保存检查点)
if err := s.notifyPreemption(ctx, victim); err != nil {
return err
}
		// 2. 等待检查点保存;若超时则不再等待,直接进入后续的强制终止流程
		if err := s.waitForCheckpoint(ctx, victim, 30*time.Second); err != nil {
			// 超时:放弃等待,继续执行终止
}
// 3. 终止作业
if err := s.terminateJob(ctx, victim); err != nil {
return err
}
// 4. 记录抢占事件
s.recordPreemptionEvent(victim, preemptor)
}
return nil
}
// notifyPreemption 通知即将被抢占
func (s *PreemptionScheduler) notifyPreemption(ctx context.Context,
victim *RunningJob) error {
// 发送 SIGTERM 信号
// 或通过 annotation/label 通知
return nil
}
// waitForCheckpoint 等待检查点保存
func (s *PreemptionScheduler) waitForCheckpoint(ctx context.Context,
victim *RunningJob, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
// 检查是否有新检查点
checkpoint, err := s.checkpointManager.GetLatest(victim.Job)
if err == nil && checkpoint.Timestamp.After(victim.LastCheckpoint) {
victim.LastCheckpoint = checkpoint.Timestamp
return nil
}
time.Sleep(time.Second)
}
return fmt.Errorf("checkpoint timeout")
}
// terminateJob 终止作业
func (s *PreemptionScheduler) terminateJob(ctx context.Context,
victim *RunningJob) error {
// 删除所有 Pod
// 更新作业状态
return nil
}
// RecoverPreemptedJob 恢复被抢占的作业
func (s *PreemptionScheduler) RecoverPreemptedJob(ctx context.Context,
job *TrainingJob) error {
// 获取最新检查点
checkpoint, err := s.checkpointManager.GetLatest(job)
if err != nil {
// 没有检查点,从头开始
return s.restartJob(ctx, job)
}
// 从检查点恢复
return s.resumeFromCheckpoint(ctx, job, checkpoint)
}
// restartJob 重新启动作业
func (s *PreemptionScheduler) restartJob(ctx context.Context,
job *TrainingJob) error {
// 重新创建 Pod
return nil
}
// resumeFromCheckpoint 从检查点恢复
func (s *PreemptionScheduler) resumeFromCheckpoint(ctx context.Context,
job *TrainingJob, checkpoint *Checkpoint) error {
// 创建 Pod 并配置从检查点恢复
return nil
}
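把抢占策略与执行流程串起来,大致如下面的示意方法所示(方法名 PreemptFor 为本文虚构,置于 preemption 包内,阈值均为假设值):
// PreemptFor 示意:为高优先级作业挑选牺牲者并执行抢占
func (s *PreemptionScheduler) PreemptFor(ctx context.Context, preemptor *TrainingJob,
	required Resources) error {
	strategy := &PriorityBasedStrategy{
		MinPriorityDiff: 100,              // 抢占者至少比被抢占者高 100 优先级
		MinRuntime:      10 * time.Minute, // 刚启动不久的作业受保护
		MaxVictims:      3,
	}
	running := make([]*RunningJob, 0, len(s.runningJobs))
	for _, rj := range s.runningJobs {
		running = append(running, rj)
	}
	victims := strategy.SelectVictims(running, required)
	if victims == nil {
		return fmt.Errorf("cannot free enough resources for %s/%s",
			preemptor.Namespace, preemptor.Name)
	}
	return s.ExecutePreemption(ctx, preemptor, victims)
}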
5.2 检查点管理
// pkg/scheduler/checkpoint/manager.go
package checkpoint
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"path/filepath"
	"sort"
	"time"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)
// CheckpointManager 检查点管理器
type CheckpointManager struct {
// 存储后端
storage CheckpointStorage
// 配置
config CheckpointConfig
}
// CheckpointStorage 检查点存储接口
type CheckpointStorage interface {
// Save 保存检查点
Save(ctx context.Context, checkpoint *Checkpoint, data []byte) error
// Load 加载检查点
Load(ctx context.Context, checkpoint *Checkpoint) ([]byte, error)
// List 列出检查点
List(ctx context.Context, jobKey string) ([]*Checkpoint, error)
// Delete 删除检查点
Delete(ctx context.Context, checkpoint *Checkpoint) error
}
// CheckpointConfig 配置
type CheckpointConfig struct {
// 检查点保存路径
BasePath string
// 保留的检查点数量
MaxCheckpoints int
// 检查点间隔
Interval time.Duration
}
// Checkpoint 检查点
type Checkpoint struct {
// 唯一标识
ID string `json:"id"`
// 作业标识
JobKey string `json:"jobKey"`
// 时间戳
Timestamp time.Time `json:"timestamp"`
// 训练步数
Step int64 `json:"step"`
// Epoch
Epoch int32 `json:"epoch"`
// 文件路径
Path string `json:"path"`
// 大小(字节)
Size int64 `json:"size"`
// 元数据
Metadata map[string]string `json:"metadata"`
}
// NewCheckpointManager 创建检查点管理器
func NewCheckpointManager(storage CheckpointStorage, config CheckpointConfig) *CheckpointManager {
return &CheckpointManager{
storage: storage,
config: config,
}
}
// SaveCheckpoint 保存检查点
func (m *CheckpointManager) SaveCheckpoint(ctx context.Context,
job *TrainingJob, step int64, epoch int32, data []byte) (*Checkpoint, error) {
jobKey := job.Namespace + "/" + job.Name
checkpointID := fmt.Sprintf("%s-%d", jobKey, time.Now().UnixNano())
checkpoint := &Checkpoint{
ID: checkpointID,
JobKey: jobKey,
Timestamp: time.Now(),
Step: step,
Epoch: epoch,
Path: filepath.Join(m.config.BasePath, jobKey, checkpointID),
Size: int64(len(data)),
Metadata: make(map[string]string),
}
// 保存到存储
if err := m.storage.Save(ctx, checkpoint, data); err != nil {
return nil, err
}
// 清理旧检查点
if err := m.cleanupOldCheckpoints(ctx, jobKey); err != nil {
// 日志记录,但不影响保存
}
return checkpoint, nil
}
// GetLatest 获取最新检查点
func (m *CheckpointManager) GetLatest(job *TrainingJob) (*Checkpoint, error) {
jobKey := job.Namespace + "/" + job.Name
checkpoints, err := m.storage.List(context.Background(), jobKey)
if err != nil {
return nil, err
}
if len(checkpoints) == 0 {
return nil, fmt.Errorf("no checkpoint found for job %s", jobKey)
}
// 按时间排序
sort.Slice(checkpoints, func(i, j int) bool {
return checkpoints[i].Timestamp.After(checkpoints[j].Timestamp)
})
return checkpoints[0], nil
}
// LoadCheckpoint 加载检查点
func (m *CheckpointManager) LoadCheckpoint(ctx context.Context,
checkpoint *Checkpoint) ([]byte, error) {
return m.storage.Load(ctx, checkpoint)
}
// cleanupOldCheckpoints 清理旧检查点
func (m *CheckpointManager) cleanupOldCheckpoints(ctx context.Context,
jobKey string) error {
checkpoints, err := m.storage.List(ctx, jobKey)
if err != nil {
return err
}
if len(checkpoints) <= m.config.MaxCheckpoints {
return nil
}
// 按时间排序
sort.Slice(checkpoints, func(i, j int) bool {
return checkpoints[i].Timestamp.After(checkpoints[j].Timestamp)
})
// 删除多余的检查点
for i := m.config.MaxCheckpoints; i < len(checkpoints); i++ {
if err := m.storage.Delete(ctx, checkpoints[i]); err != nil {
return err
}
}
return nil
}
// S3CheckpointStorage S3 存储实现
type S3CheckpointStorage struct {
bucket string
client *s3.Client
}
// Save 保存到 S3
func (s *S3CheckpointStorage) Save(ctx context.Context,
checkpoint *Checkpoint, data []byte) error {
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &s.bucket,
Key: &checkpoint.Path,
Body: bytes.NewReader(data),
})
return err
}
// Load 从 S3 加载
func (s *S3CheckpointStorage) Load(ctx context.Context,
checkpoint *Checkpoint) ([]byte, error) {
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &checkpoint.Path,
})
if err != nil {
return nil, err
}
defer result.Body.Close()
return io.ReadAll(result.Body)
}
// List 列出检查点
func (s *S3CheckpointStorage) List(ctx context.Context,
jobKey string) ([]*Checkpoint, error) {
prefix := jobKey + "/"
result, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: &s.bucket,
Prefix: &prefix,
})
if err != nil {
return nil, err
}
checkpoints := make([]*Checkpoint, 0, len(result.Contents))
for _, obj := range result.Contents {
checkpoint := &Checkpoint{
ID: *obj.Key,
JobKey: jobKey,
Path: *obj.Key,
Size: *obj.Size,
Timestamp: *obj.LastModified,
}
checkpoints = append(checkpoints, checkpoint)
}
return checkpoints, nil
}
// Delete 删除检查点
func (s *S3CheckpointStorage) Delete(ctx context.Context,
checkpoint *Checkpoint) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &s.bucket,
Key: &checkpoint.Path,
})
return err
}
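除 S3 外,也可以把检查点写到本地磁盘或挂载的共享文件系统。下面是一个实现同一 CheckpointStorage 接口的最小草图(目录布局为假设,Step/Epoch 等元数据需另行持久化):
// pkg/scheduler/checkpoint/local.go(示意)
package checkpoint
import (
	"context"
	"os"
	"path/filepath"
)
// LocalCheckpointStorage 基于本地(或挂载的共享)文件系统的存储后端
type LocalCheckpointStorage struct {
	root string // 检查点根目录
}
func (l *LocalCheckpointStorage) Save(ctx context.Context, cp *Checkpoint, data []byte) error {
	path := filepath.Join(l.root, cp.Path)
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	return os.WriteFile(path, data, 0o644)
}
func (l *LocalCheckpointStorage) Load(ctx context.Context, cp *Checkpoint) ([]byte, error) {
	return os.ReadFile(filepath.Join(l.root, cp.Path))
}
func (l *LocalCheckpointStorage) List(ctx context.Context, jobKey string) ([]*Checkpoint, error) {
	dir := filepath.Join(l.root, jobKey)
	entries, err := os.ReadDir(dir)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, err
	}
	checkpoints := make([]*Checkpoint, 0, len(entries))
	for _, e := range entries {
		info, err := e.Info()
		if err != nil {
			continue
		}
		checkpoints = append(checkpoints, &Checkpoint{
			ID:        e.Name(),
			JobKey:    jobKey,
			Path:      filepath.Join(jobKey, e.Name()),
			Size:      info.Size(),
			Timestamp: info.ModTime(),
		})
	}
	return checkpoints, nil
}
func (l *LocalCheckpointStorage) Delete(ctx context.Context, cp *Checkpoint) error {
	return os.Remove(filepath.Join(l.root, cp.Path))
}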
6. 监控与可观测性
6.1 调度指标
// pkg/scheduler/metrics/metrics.go
package metrics
import (
	"fmt"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
// 调度延迟
SchedulingLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "training_job_scheduling_duration_seconds",
Help: "Time taken to schedule a training job",
Buckets: prometheus.ExponentialBuckets(0.1, 2, 15),
},
[]string{"queue", "job_type", "result"},
)
// 队列长度
QueueLength = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "training_job_queue_length",
Help: "Number of jobs waiting in queue",
},
[]string{"queue", "priority"},
)
// 资源利用率
ResourceUtilization = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "training_cluster_resource_utilization",
Help: "Resource utilization ratio",
},
[]string{"resource_type", "queue"},
)
// 抢占计数
PreemptionCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "training_job_preemptions_total",
Help: "Total number of job preemptions",
},
[]string{"preemptor_queue", "victim_queue", "reason"},
)
// 作业完成时间
JobCompletionTime = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "training_job_completion_seconds",
Help: "Time taken to complete a training job",
Buckets: prometheus.ExponentialBuckets(60, 2, 15), // 1 分钟到约 11 天
},
[]string{"queue", "job_type", "status"},
)
// 公平份额偏差
FairShareDeviation = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "training_queue_fairshare_deviation",
Help: "Deviation from fair share",
},
[]string{"queue"},
)
// Gang 调度成功率
GangSchedulingSuccess = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "training_gang_scheduling_total",
Help: "Total gang scheduling attempts",
},
[]string{"result", "timeout"},
)
// 检查点保存
CheckpointSaves = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "training_checkpoint_saves_total",
Help: "Total checkpoint saves",
},
[]string{"job", "reason", "status"},
)
)
// RecordSchedulingLatency 记录调度延迟
func RecordSchedulingLatency(queue, jobType, result string, duration float64) {
SchedulingLatency.WithLabelValues(queue, jobType, result).Observe(duration)
}
// UpdateQueueLength 更新队列长度
func UpdateQueueLength(queue string, priority int32, length int) {
QueueLength.WithLabelValues(queue, fmt.Sprintf("%d", priority)).Set(float64(length))
}
// UpdateResourceUtilization 更新资源利用率
func UpdateResourceUtilization(resourceType, queue string, ratio float64) {
ResourceUtilization.WithLabelValues(resourceType, queue).Set(ratio)
}
// RecordPreemption 记录抢占事件
func RecordPreemption(preemptorQueue, victimQueue, reason string) {
PreemptionCounter.WithLabelValues(preemptorQueue, victimQueue, reason).Inc()
}
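这些指标需要在调度路径上主动打点。下面是一个包装调度调用并上报延迟的示意片段(包名与导入路径为假设):
package demo
import (
	"time"
	"example.com/ai-platform/pkg/scheduler"         // 假设的模块路径
	"example.com/ai-platform/pkg/scheduler/metrics" // 即上文的 metrics 包
)
// scheduleWithMetrics 包装一次调度尝试:无论成功与否都上报调度延迟
func scheduleWithMetrics(job *scheduler.TrainingJob, schedule func() error) error {
	start := time.Now()
	err := schedule()
	result := "success"
	if err != nil {
		result = "failure"
	}
	metrics.RecordSchedulingLatency(
		job.SchedulingConfig.Queue, string(job.JobType), result,
		time.Since(start).Seconds())
	return err
}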
6.2 Grafana Dashboard
{
"dashboard": {
"title": "Training Job Scheduler Dashboard",
"panels": [
{
"title": "Queue Length by Priority",
"type": "timeseries",
"targets": [
{
"expr": "sum(training_job_queue_length) by (queue, priority)",
"legendFormat": "{{queue}} - P{{priority}}"
}
]
},
{
"title": "Scheduling Latency (P99)",
"type": "gauge",
"targets": [
{
"expr": "histogram_quantile(0.99, sum(rate(training_job_scheduling_duration_seconds_bucket[5m])) by (le))"
}
],
"options": {
"maxValue": 60,
"thresholds": [
{"value": 0, "color": "green"},
{"value": 10, "color": "yellow"},
{"value": 30, "color": "red"}
]
}
},
{
"title": "Resource Utilization",
"type": "heatmap",
"targets": [
{
"expr": "training_cluster_resource_utilization",
"format": "heatmap"
}
]
},
{
"title": "Fair Share Deviation",
"type": "bargauge",
"targets": [
{
"expr": "training_queue_fairshare_deviation",
"legendFormat": "{{queue}}"
}
]
},
{
"title": "Preemption Events",
"type": "timeseries",
"targets": [
{
"expr": "sum(rate(training_job_preemptions_total[1h])) by (victim_queue)",
"legendFormat": "{{victim_queue}}"
}
]
},
{
"title": "Job Completion Time Distribution",
"type": "histogram",
"targets": [
{
"expr": "sum(rate(training_job_completion_seconds_bucket[1h])) by (le, status)",
"format": "heatmap"
}
]
}
]
}
}
总结
本章深入讲解了训练任务调度系统的核心技术:
- 队列管理:多级优先级队列、层级队列、Kueue 集成
- Gang 调度:确保分布式训练任务整组调度、Volcano 集成
- 公平调度:DRF 算法、层级公平共享
- 抢占机制:优先级抢占、检查点恢复
- 可观测性:调度指标、Grafana 监控
高效的训练任务调度是 AI 平台的核心能力,直接影响集群资源利用率和用户体验。
下一章我们将探讨模型存储与管理,讲解如何高效管理训练过程中产生的模型文件和检查点。