HiHuo
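Hyperparameter Optimization

Overview

Hyperparameter optimization (HPO) is a key step in the machine learning workflow. In large-model training a single experiment can run for days or even weeks, so choosing a good hyperparameter configuration matters both for saving compute and for final model quality. This chapter covers the algorithms behind HPO, the design of an HPO system, and how to put both into practice.

Hyperparameter Optimization Basics

Problem Formulation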

┌─────────────────────────────────────────────────────────────────┐
│                     HPO Problem Definition                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Goal: find the configuration x* that minimizes f(x)            │
│                                                                 │
│         x* = argmin f(x)                                        │
│             x ∈ X                                               │
│                                                                 │
│  where:                                                         │
│  • x: hyperparameter configuration vector                       │
│  • X: search space                                              │
│  • f(x): objective, e.g. validation loss                        │
│                                                                 │
│  Challenges:                                                    │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ 1. Black box: no analytic f(x), evaluate by training     │   │
│  │ 2. Expensive: one full training run, hours to days       │   │
│  │ 3. Noisy: results vary with random seeds                 │   │
│  │ 4. Mixed space: continuous, discrete, conditional        │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Hyperparameter types:                                          │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Continuous: learning_rate ∈ [1e-6, 1e-2]               │   │
│  │ • Discrete: num_layers ∈ {6, 12, 24, 32}                 │   │
│  │ • Categorical: optimizer ∈ {adam, sgd, adamw}            │   │
│  │ • Conditional: if optimizer=sgd, momentum ∈ [0, 1]       │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Search Space Definition

// search_space.go
package hpo

import (
    "math"
    "math/rand"
)

// SearchSpace is a hyperparameter search space: the parameters plus the
// conditions that activate some of them.
type SearchSpace struct {
    Parameters map[string]Parameter
    Conditions []Condition
}

// Parameter is a single hyperparameter definition.
type Parameter interface {
    Sample(rng *rand.Rand) interface{}
    Contains(value interface{}) bool
    Type() ParameterType
    Name() string
}

// ParameterType identifies the kind of parameter.
type ParameterType string

const (
    ParamTypeContinuous  ParameterType = "continuous"
    ParamTypeDiscrete    ParameterType = "discrete"
    ParamTypeCategorical ParameterType = "categorical"
)

// ContinuousParameter is a real-valued parameter in [low, high].
type ContinuousParameter struct {
    name     string
    low      float64
    high     float64
    logScale bool // sample uniformly in log space (requires low > 0)
}

func NewContinuousParameter(name string, low, high float64, logScale bool) *ContinuousParameter {
    return &ContinuousParameter{
        name:     name,
        low:      low,
        high:     high,
        logScale: logScale,
    }
}

func (p *ContinuousParameter) Sample(rng *rand.Rand) interface{} {
    if p.logScale {
        // Log-uniform sampling: uniform in [log(low), log(high)], then exponentiate.
        logLow := math.Log(p.low)
        logHigh := math.Log(p.high)
        return math.Exp(logLow + rng.Float64()*(logHigh-logLow))
    }
    return p.low + rng.Float64()*(p.high-p.low)
}

func (p *ContinuousParameter) Contains(value interface{}) bool {
    v, ok := value.(float64)
    if !ok {
        return false
    }
    return v >= p.low && v <= p.high
}

func (p *ContinuousParameter) Type() ParameterType { return ParamTypeContinuous }
func (p *ContinuousParameter) Name() string        { return p.name }
func (p *ContinuousParameter) Low() float64        { return p.low }
func (p *ContinuousParameter) High() float64       { return p.high }
func (p *ContinuousParameter) LogScale() bool      { return p.logScale }

// DiscreteParameter is an integer parameter taking the values low, low+step, ... up to high.
type DiscreteParameter struct {
    name string
    low  int
    high int
    step int
}

func NewDiscreteParameter(name string, low, high, step int) *DiscreteParameter {
    return &DiscreteParameter{
        name: name,
        low:  low,
        high: high,
        step: step,
    }
}

func (p *DiscreteParameter) Sample(rng *rand.Rand) interface{} {
    numSteps := (p.high - p.low) / p.step
    return p.low + rng.Intn(numSteps+1)*p.step
}

func (p *DiscreteParameter) Contains(value interface{}) bool {
    v, ok := value.(int)
    if !ok {
        return false
    }
    return v >= p.low && v <= p.high && (v-p.low)%p.step == 0
}

func (p *DiscreteParameter) Type() ParameterType { return ParamTypeDiscrete }
func (p *DiscreteParameter) Name() string        { return p.name }

// CategoricalParameter is a parameter chosen from a fixed set of values.
type CategoricalParameter struct {
    name    string
    choices []interface{}
}

func NewCategoricalParameter(name string, choices []interface{}) *CategoricalParameter {
    return &CategoricalParameter{
        name:    name,
        choices: choices,
    }
}

func (p *CategoricalParameter) Sample(rng *rand.Rand) interface{} {
    return p.choices[rng.Intn(len(p.choices))]
}

func (p *CategoricalParameter) Contains(value interface{}) bool {
    for _, c := range p.choices {
        if c == value {
            return true
        }
    }
    return false
}

func (p *CategoricalParameter) Type() ParameterType  { return ParamTypeCategorical }
func (p *CategoricalParameter) Name() string         { return p.name }
func (p *CategoricalParameter) Choices() []interface{} { return p.choices }

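// Condition makes Child active only when Parent takes one of Values.
type Condition struct {
    Child  string        // conditional (child) parameter
    Parent string        // parent parameter
    Values []interface{} // parent values for which the child is active
}

// Config is one concrete assignment of parameter values.
type Config map[string]interface{}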

// SearchSpaceBuilder builds a SearchSpace through a fluent API.
type SearchSpaceBuilder struct {
    space *SearchSpace
}

func NewSearchSpaceBuilder() *SearchSpaceBuilder {
    return &SearchSpaceBuilder{
        space: &SearchSpace{
            Parameters: make(map[string]Parameter),
            Conditions: []Condition{},
        },
    }
}

func (b *SearchSpaceBuilder) AddContinuous(name string, low, high float64, logScale bool) *SearchSpaceBuilder {
    b.space.Parameters[name] = NewContinuousParameter(name, low, high, logScale)
    return b
}

func (b *SearchSpaceBuilder) AddDiscrete(name string, low, high, step int) *SearchSpaceBuilder {
    b.space.Parameters[name] = NewDiscreteParameter(name, low, high, step)
    return b
}

func (b *SearchSpaceBuilder) AddCategorical(name string, choices []interface{}) *SearchSpaceBuilder {
    b.space.Parameters[name] = NewCategoricalParameter(name, choices)
    return b
}

func (b *SearchSpaceBuilder) AddCondition(child, parent string, values []interface{}) *SearchSpaceBuilder {
    b.space.Conditions = append(b.space.Conditions, Condition{
        Child:  child,
        Parent: parent,
        Values: values,
    })
    return b
}

func (b *SearchSpaceBuilder) Build() *SearchSpace {
    return b.space
}

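// Sample draws a configuration from the search space. Parameters are visited
// in sorted name order: Go randomizes map iteration, so iterating the map
// directly would make the same seed produce different configurations.
func (s *SearchSpace) Sample(rng *rand.Rand) Config {
    config := make(Config)

    // First sample every unconditional parameter.
    for _, name := range s.sortedNames() {
        if !s.hasCondition(name) {
            config[name] = s.Parameters[name].Sample(rng)
        }
    }

    // Then sample each conditional parameter whose condition holds.
    for _, cond := range s.Conditions {
        parentValue := config[cond.Parent]
        if s.conditionSatisfied(parentValue, cond.Values) {
            config[cond.Child] = s.Parameters[cond.Child].Sample(rng)
        }
    }

    return config
}

// sortedNames returns the parameter names in lexical order (insertion sort;
// search spaces are small).
func (s *SearchSpace) sortedNames() []string {
    names := make([]string, 0, len(s.Parameters))
    for name := range s.Parameters {
        names = append(names, name)
    }
    for i := 1; i < len(names); i++ {
        for j := i; j > 0 && names[j] < names[j-1]; j-- {
            names[j], names[j-1] = names[j-1], names[j]
        }
    }
    return names
}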

func (s *SearchSpace) hasCondition(paramName string) bool {
    for _, cond := range s.Conditions {
        if cond.Child == paramName {
            return true
        }
    }
    return false
}

func (s *SearchSpace) conditionSatisfied(value interface{}, allowedValues []interface{}) bool {
    for _, v := range allowedValues {
        if v == value {
            return true
        }
    }
    return false
}

// CreateLLMSearchSpace builds an example search space for LLM training.
func CreateLLMSearchSpace() *SearchSpace {
    return NewSearchSpaceBuilder().
        // Learning rate (log scale)
        AddContinuous("learning_rate", 1e-6, 1e-3, true).
        // Batch size
        AddDiscrete("batch_size", 8, 128, 8).
        // Warmup steps
        AddDiscrete("warmup_steps", 100, 5000, 100).
        // Weight decay
        AddContinuous("weight_decay", 0.0, 0.3, false).
        // Optimizer
        AddCategorical("optimizer", []interface{}{"adam", "adamw", "sgd"}).
        // Learning-rate scheduler
        AddCategorical("lr_scheduler", []interface{}{"linear", "cosine", "constant"}).
        // Dropout
        AddContinuous("dropout", 0.0, 0.5, false).
        // Gradient clipping
        AddContinuous("gradient_clip", 0.1, 10.0, true).
        // SGD only: momentum (conditional parameter)
        AddContinuous("momentum", 0.0, 0.99, false).
        AddCondition("momentum", "optimizer", []interface{}{"sgd"}).
        Build()
}

Search Algorithms

Grid Search and Random Search

// grid_random_search.go
package hpo

import (
    "math/rand"
)

// GridSearch enumerates the Cartesian product of user-specified grid points.
type GridSearch struct {
    space      *SearchSpace
    gridPoints map[string][]interface{}
}

func NewGridSearch(space *SearchSpace, gridPoints map[string][]interface{}) *GridSearch {
    return &GridSearch{
        space:      space,
        gridPoints: gridPoints,
    }
}

// GenerateConfigs returns every grid configuration.
func (g *GridSearch) GenerateConfigs() []Config {
    // Enumerate the Cartesian product of the grid points.
    paramNames := make([]string, 0, len(g.gridPoints))
    for name := range g.gridPoints {
        paramNames = append(paramNames, name)
    }

    var configs []Config
    g.cartesianProduct(paramNames, 0, make(Config), &configs)
    return configs
}

func (g *GridSearch) cartesianProduct(names []string, idx int, current Config, result *[]Config) {
    if idx == len(names) {
        configCopy := make(Config)
        for k, v := range current {
            configCopy[k] = v
        }
        *result = append(*result, configCopy)
        return
    }

    name := names[idx]
    for _, value := range g.gridPoints[name] {
        current[name] = value
        g.cartesianProduct(names, idx+1, current, result)
    }
}

// RandomSearch samples configurations independently from the search space.
type RandomSearch struct {
    space     *SearchSpace
    rng       *rand.Rand
    numTrials int
}

func NewRandomSearch(space *SearchSpace, numTrials int, seed int64) *RandomSearch {
    return &RandomSearch{
        space:     space,
        rng:       rand.New(rand.NewSource(seed)),
        numTrials: numTrials,
    }
}

// GenerateConfigs samples numTrials random configurations.
func (r *RandomSearch) GenerateConfigs() []Config {
    configs := make([]Config, r.numTrials)
    for i := 0; i < r.numTrials; i++ {
        configs[i] = r.space.Sample(r.rng)
    }
    return configs
}

// Suggest returns the next random configuration.
func (r *RandomSearch) Suggest() Config {
    return r.space.Sample(r.rng)
}
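Why random search often beats grid search on the same budget: when only one of two parameters really matters, a 3×3 grid tries just 3 distinct values of that parameter, while 9 random trials try 9 (the argument made by Bergstra and Bengio, 2012). The sketch below counts distinct values of parameter `a` in both designs; the 2-D unit search space and the helper names are illustrative assumptions.

```go
package main

import (
	"fmt"
	"math/rand"
)

// point is one trial in a 2-D unit search space, e.g. (scaled learning rate, scaled dropout).
type point struct{ a, b float64 }

// gridPoints returns the k×k grid over [0,1]^2 (k >= 2).
func gridPoints(k int) []point {
	var pts []point
	for i := 0; i < k; i++ {
		for j := 0; j < k; j++ {
			pts = append(pts, point{float64(i) / float64(k-1), float64(j) / float64(k-1)})
		}
	}
	return pts
}

// randomPoints returns n points drawn uniformly from [0,1]^2.
func randomPoints(n int, seed int64) []point {
	rng := rand.New(rand.NewSource(seed))
	pts := make([]point, n)
	for i := range pts {
		pts[i] = point{rng.Float64(), rng.Float64()}
	}
	return pts
}

// distinctA counts how many different values of parameter a the trials cover.
func distinctA(pts []point) int {
	seen := make(map[float64]bool)
	for _, p := range pts {
		seen[p.a] = true
	}
	return len(seen)
}

func main() {
	fmt.Println("grid 3x3:", distinctA(gridPoints(3)), "distinct values of a")
	fmt.Println("random 9:", distinctA(randomPoints(9, 1)), "distinct values of a")
}
```

Both designs cost nine training runs; only the random one explores the important axis at nine different values.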

Bayesian Optimization

// bayesian_optimization.go
package hpo

import (
    "math"
    "math/rand"
)

// BayesianOptimization implements Bayesian optimization with a Gaussian-process
// surrogate and an Expected Improvement acquisition function.
type BayesianOptimization struct {
    space           *SearchSpace
    rng             *rand.Rand
    surrogate       SurrogateModel
    acquisition     AcquisitionFunction
    observations    []Observation
    numInitSamples  int
    explorationRate float64
}

// Observation is one evaluated configuration.
type Observation struct {
    Config Config
    Value  float64 // objective value (lower is better)
}

// SurrogateModel approximates the objective and reports its uncertainty.
type SurrogateModel interface {
    Fit(observations []Observation) error
    Predict(config Config) (mean, std float64)
}

// AcquisitionFunction scores a candidate from the surrogate's prediction.
type AcquisitionFunction interface {
    Evaluate(mean, std, bestValue float64) float64
}

// NewBayesianOptimization creates a Bayesian optimizer.
func NewBayesianOptimization(
    space *SearchSpace,
    seed int64,
    numInitSamples int,
) *BayesianOptimization {
    return &BayesianOptimization{
        space:           space,
        rng:             rand.New(rand.NewSource(seed)),
        surrogate:       NewGaussianProcess(),
        acquisition:     NewExpectedImprovement(),
        observations:    []Observation{},
        numInitSamples:  numInitSamples,
        explorationRate: 0.1,
    }
}

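// Suggest proposes the next configuration.
func (bo *BayesianOptimization) Suggest() Config {
    // Too few observations to fit a useful surrogate: sample at random.
    if len(bo.observations) < bo.numInitSamples {
        return bo.space.Sample(bo.rng)
    }

    // Exploration vs. exploitation: occasionally sample at random (epsilon-greedy).
    if bo.rng.Float64() < bo.explorationRate {
        return bo.space.Sample(bo.rng)
    }

    // Fit the surrogate; fall back to random sampling if fitting fails.
    if err := bo.surrogate.Fit(bo.observations); err != nil {
        return bo.space.Sample(bo.rng)
    }

    // Current best (lowest) observed value.
    bestValue := bo.getBestValue()

    // Maximize the acquisition function to pick the next configuration.
    return bo.optimizeAcquisition(bestValue)
}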

// Observe records an evaluated configuration and its value.
func (bo *BayesianOptimization) Observe(config Config, value float64) {
    bo.observations = append(bo.observations, Observation{
        Config: config,
        Value:  value,
    })
}

// GetBest returns the best configuration observed so far and its value.
func (bo *BayesianOptimization) GetBest() (Config, float64) {
    if len(bo.observations) == 0 {
        return nil, math.Inf(1)
    }

    best := bo.observations[0]
    for _, obs := range bo.observations[1:] {
        if obs.Value < best.Value {
            best = obs
        }
    }

    return best.Config, best.Value
}

func (bo *BayesianOptimization) getBestValue() float64 {
    best := math.Inf(1)
    for _, obs := range bo.observations {
        if obs.Value < best {
            best = obs.Value
        }
    }
    return best
}

func (bo *BayesianOptimization) optimizeAcquisition(bestValue float64) Config {
    // Maximize the acquisition function by random search over candidate configs.
    numCandidates := 1000
    bestConfig := bo.space.Sample(bo.rng)
    bestAcq := math.Inf(-1)

    for i := 0; i < numCandidates; i++ {
        config := bo.space.Sample(bo.rng)
        mean, std := bo.surrogate.Predict(config)
        acq := bo.acquisition.Evaluate(mean, std, bestValue)

        if acq > bestAcq {
            bestAcq = acq
            bestConfig = config
        }
    }

    return bestConfig
}

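// GaussianProcess is a GP surrogate with an RBF kernel. Targets are centered
// on their mean before fitting (zero prior mean on the centered values).
type GaussianProcess struct {
    observations []Observation
    lengthScale  float64
    noiseVar     float64
    yMean        float64     // mean of the observed values
    K            [][]float64 // kernel matrix (noise variance on the diagonal)
    Kinv         [][]float64 // inverse of K
    alpha        []float64   // K^{-1} * (y - yMean)
}

func NewGaussianProcess() *GaussianProcess {
    return &GaussianProcess{
        lengthScale: 1.0,
        noiseVar:    1e-6,
    }
}

func (gp *GaussianProcess) Fit(observations []Observation) error {
    gp.observations = observations
    n := len(observations)

    // Build the kernel matrix, adding the noise variance on the diagonal.
    gp.K = make([][]float64, n)
    for i := 0; i < n; i++ {
        gp.K[i] = make([]float64, n)
        for j := 0; j < n; j++ {
            gp.K[i][j] = gp.kernel(observations[i].Config, observations[j].Config)
            if i == j {
                gp.K[i][j] += gp.noiseVar
            }
        }
    }

    // Center the targets. Without this, predictions far from the data revert
    // to 0, which is below any positive loss and makes EI favor regions that
    // merely lack observations.
    gp.yMean = 0
    for _, obs := range observations {
        gp.yMean += obs.Value
    }
    if n > 0 {
        gp.yMean /= float64(n)
    }
    y := make([]float64, n)
    for i, obs := range observations {
        y[i] = obs.Value - gp.yMean
    }

    // Invert K and compute alpha = K^{-1} y.
    gp.Kinv = gp.invertMatrix(gp.K)
    gp.alpha = gp.matVecMul(gp.Kinv, y)

    return nil
}

func (gp *GaussianProcess) Predict(config Config) (mean, std float64) {
    if len(gp.observations) == 0 {
        return 0, 1.0 // prior: zero mean, unit variance
    }

    n := len(gp.observations)

    // k_*: kernel between the new point and every observed point.
    kStar := make([]float64, n)
    for i, obs := range gp.observations {
        kStar[i] = gp.kernel(config, obs.Config)
    }

    // Predictive mean: yMean + k_*^T * alpha.
    mean = gp.yMean
    for i := 0; i < n; i++ {
        mean += kStar[i] * gp.alpha[i]
    }

    // Predictive variance: k_** - k_*^T * K^{-1} * k_*.
    kStarStar := gp.kernel(config, config)
    v := gp.matVecMul(gp.Kinv, kStar)
    variance := kStarStar
    for i := 0; i < n; i++ {
        variance -= kStar[i] * v[i]
    }

    // Clamp small negative values caused by rounding.
    if variance < 0 {
        variance = 0
    }
    std = math.Sqrt(variance)

    return mean, std
}

// kernel is the RBF (squared-exponential) kernel
// k(x1, x2) = exp(-||x1 - x2||^2 / (2 * lengthScale^2)).
func (gp *GaussianProcess) kernel(x1, x2 Config) float64 {
    // Simplified: only float64 parameters contribute and they are not rescaled.
    // A real implementation should normalize every dimension (log-transforming
    // log-scale ones such as learning_rate) and encode int/categorical values.
    var dist float64
    for key := range x1 {
        v1, ok1 := x1[key].(float64)
        v2, ok2 := x2[key].(float64)
        if ok1 && ok2 {
            diff := (v1 - v2) / gp.lengthScale
            dist += diff * diff
        }
    }
    return math.Exp(-0.5 * dist)
}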

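// invertMatrix inverts A with Gauss-Jordan elimination and partial pivoting.
// K + noiseVar*I is symmetric positive definite, so a non-zero pivot exists in
// exact arithmetic. A production GP would rather use a Cholesky solve (e.g.
// gonum's mat.Cholesky), which is cheaper and numerically more stable.
func (gp *GaussianProcess) invertMatrix(A [][]float64) [][]float64 {
    n := len(A)
    // Work on a copy of A alongside the identity: [a | inv].
    a := make([][]float64, n)
    inv := make([][]float64, n)
    for i := range A {
        a[i] = append([]float64(nil), A[i]...)
        inv[i] = make([]float64, n)
        inv[i][i] = 1
    }

    for col := 0; col < n; col++ {
        // Partial pivoting: bring the row with the largest |value| into place.
        p := col
        for r := col + 1; r < n; r++ {
            if math.Abs(a[r][col]) > math.Abs(a[p][col]) {
                p = r
            }
        }
        a[col], a[p] = a[p], a[col]
        inv[col], inv[p] = inv[p], inv[col]

        // Normalize the pivot row.
        pivot := a[col][col]
        for j := 0; j < n; j++ {
            a[col][j] /= pivot
            inv[col][j] /= pivot
        }

        // Eliminate this column from every other row.
        for r := 0; r < n; r++ {
            if r == col {
                continue
            }
            f := a[r][col]
            for j := 0; j < n; j++ {
                a[r][j] -= f * a[col][j]
                inv[r][j] -= f * inv[col][j]
            }
        }
    }
    return inv
}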

func (gp *GaussianProcess) matVecMul(A [][]float64, x []float64) []float64 {
    n := len(A)
    result := make([]float64, n)
    for i := 0; i < n; i++ {
        for j := 0; j < n; j++ {
            result[i] += A[i][j] * x[j]
        }
    }
    return result
}

// ExpectedImprovement is the EI acquisition function for minimization.
type ExpectedImprovement struct {
    xi float64 // exploration margin: improvements smaller than xi are discounted
}

func NewExpectedImprovement() *ExpectedImprovement {
    return &ExpectedImprovement{xi: 0.01}
}

func (ei *ExpectedImprovement) Evaluate(mean, std, bestValue float64) float64 {
    if std == 0 {
        return 0
    }

    z := (bestValue - mean - ei.xi) / std
    cdf := 0.5 * (1 + math.Erf(z/math.Sqrt2))
    pdf := math.Exp(-z*z/2) / math.Sqrt(2*math.Pi)

    return (bestValue - mean - ei.xi)*cdf + std*pdf
}

TPE (Tree-structured Parzen Estimator)

// tpe.go
package hpo

import (
    "math"
    "math/rand"
    "sort"
)

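// TPE implements the Tree-structured Parzen Estimator.
type TPE struct {
    space           *SearchSpace
    rng             *rand.Rand
    observations    []Observation
    gamma           float64 // quantile separating "good" from "bad" observations
    numEICandidates int     // candidates drawn from l(x) per suggestion
    numInitSamples  int     // random suggestions before the model is used
}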

func NewTPE(space *SearchSpace, seed int64) *TPE {
    return &TPE{
        space:           space,
        rng:             rand.New(rand.NewSource(seed)),
        observations:    []Observation{},
        gamma:           0.25, // the best 25% of observations define l(x)
        numEICandidates: 24,
        numInitSamples:  10,
    }
}

// Suggest proposes the next configuration.
func (t *TPE) Suggest() Config {
    if len(t.observations) < t.numInitSamples {
        return t.space.Sample(t.rng)
    }

    // Sort observations by objective value (lower is better).
    sorted := make([]Observation, len(t.observations))
    copy(sorted, t.observations)
    sort.Slice(sorted, func(i, j int) bool {
        return sorted[i].Value < sorted[j].Value
    })

    // Split into the best gamma fraction ("good", modeled by l(x)) and the
    // rest ("bad", modeled by g(x)).
    splitIdx := int(float64(len(sorted)) * t.gamma)
    if splitIdx < 1 {
        splitIdx = 1
    }

    good := sorted[:splitIdx]
    bad := sorted[splitIdx:]

    // Draw candidates from l(x) and keep the one with the largest l(x)/g(x);
    // under TPE's model, EI increases monotonically with this ratio.
    bestConfig := t.space.Sample(t.rng)
    bestEI := math.Inf(-1)

    for i := 0; i < t.numEICandidates; i++ {
        config := t.sampleFromGood(good)

        lx := t.estimateDensity(config, good)
        gx := t.estimateDensity(config, bad)

        // g(x) == 0 means no bad observation supports x: rank it first.
        ei := math.Inf(1)
        if gx > 0 {
            ei = lx / gx
        }

        if ei > bestEI {
            bestEI = ei
            bestConfig = config
        }
    }

    return bestConfig
}

// Observe records the result of a trial.
func (t *TPE) Observe(config Config, value float64) {
    t.observations = append(t.observations, Observation{
        Config: config,
        Value:  value,
    })
}

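// sampleFromGood draws a candidate from l(x), the density of the good
// observations. Conditional parameters (e.g. momentum) are absent from some
// configs, so only observations that contain a parameter contribute to it.
// This sketch samples every parameter and does not re-apply the conditions.
func (t *TPE) sampleFromGood(good []Observation) Config {
    config := make(Config)

    for name, param := range t.space.Parameters {
        switch p := param.(type) {
        case *ContinuousParameter:
            // Sample from a Gaussian KDE over the observed values.
            config[name] = t.sampleKDE(floatValues(good, name), p.Low(), p.High())

        case *CategoricalParameter:
            // Sample in proportion to the smoothed category frequencies.
            config[name] = t.sampleCategorical(categoryCounts(good, name), p.Choices())

        case *DiscreteParameter:
            config[name] = t.sampleDiscreteKDE(intValues(good, name), p.low, p.high)
        }
    }

    return config
}

// estimateDensity returns the product of the per-parameter densities of
// config under the given observations (parameters are treated as independent).
// Discrete parameters are ignored in this simplified version.
func (t *TPE) estimateDensity(config Config, observations []Observation) float64 {
    density := 1.0

    for name, param := range t.space.Parameters {
        switch p := param.(type) {
        case *ContinuousParameter:
            if x, ok := config[name].(float64); ok {
                density *= t.kdeProb(x, floatValues(observations, name), p.Low(), p.High())
            }

        case *CategoricalParameter:
            if v, ok := config[name]; ok {
                density *= t.categoricalProb(v, categoryCounts(observations, name), len(p.Choices()))
            }
        }
    }

    return density
}

// floatValues collects the float64 values of a parameter, skipping
// observations in which the (conditional) parameter is absent.
func floatValues(observations []Observation, name string) []float64 {
    values := make([]float64, 0, len(observations))
    for _, obs := range observations {
        if v, ok := obs.Config[name].(float64); ok {
            values = append(values, v)
        }
    }
    return values
}

// intValues is the int counterpart of floatValues.
func intValues(observations []Observation, name string) []int {
    values := make([]int, 0, len(observations))
    for _, obs := range observations {
        if v, ok := obs.Config[name].(int); ok {
            values = append(values, v)
        }
    }
    return values
}

// categoryCounts counts how often each value of a categorical parameter occurs.
func categoryCounts(observations []Observation, name string) map[interface{}]int {
    counts := make(map[interface{}]int)
    for _, obs := range observations {
        if v, ok := obs.Config[name]; ok {
            counts[v]++
        }
    }
    return counts
}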

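// kdeProb evaluates a Gaussian kernel density estimate at x.
func (t *TPE) kdeProb(x float64, samples []float64, low, high float64) float64 {
    if len(samples) == 0 {
        return 1.0 / (high - low) // no data: uniform density over [low, high]
    }

    // Bandwidth from Silverman's rule of thumb: h = 1.06 * σ * n^(-1/5).
    std := t.std(samples)
    h := 1.06 * std * math.Pow(float64(len(samples)), -0.2)
    if h == 0 {
        h = 0.1 * (high - low)
    }

    // Average the Gaussian kernels centered on the samples.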
    prob := 0.0
    for _, s := range samples {
        prob += math.Exp(-0.5*math.Pow((x-s)/h, 2)) / (h * math.Sqrt(2*math.Pi))
    }
    prob /= float64(len(samples))

    return prob
}

func (t *TPE) sampleKDE(samples []float64, low, high float64) float64 {
    if len(samples) == 0 {
        return low + t.rng.Float64()*(high-low)
    }

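    // Pick one observed value as the kernel center...
    base := samples[t.rng.Intn(len(samples))]

    // ...and add Gaussian noise with the same bandwidth as kdeProb.
    std := t.std(samples)
    h := 1.06 * std * math.Pow(float64(len(samples)), -0.2)
    if h == 0 {
        h = 0.1 * (high - low)
    }

    sample := base + t.rng.NormFloat64()*h

    // Clip to the parameter bounds.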
    if sample < low {
        sample = low
    }
    if sample > high {
        sample = high
    }

    return sample
}

func (t *TPE) sampleDiscreteKDE(samples []int, low, high int) int {
    if len(samples) == 0 {
        return low + t.rng.Intn(high-low+1)
    }

    // Simplified: resample one of the good values directly (no kernel smoothing).
    return samples[t.rng.Intn(len(samples))]
}

func (t *TPE) categoricalProb(value interface{}, counts map[interface{}]int, numChoices int) float64 {
    total := 0
    for _, c := range counts {
        total += c
    }

    if total == 0 {
        return 1.0 / float64(numChoices)
    }

    count := counts[value]
    // Laplace (add-one) smoothing keeps unseen categories at non-zero probability.
    return (float64(count) + 1) / (float64(total) + float64(numChoices))
}

func (t *TPE) sampleCategorical(counts map[interface{}]int, choices []interface{}) interface{} {
    if len(counts) == 0 {
        return choices[t.rng.Intn(len(choices))]
    }

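    // Sample in proportion to the smoothed counts (count + 1 for every choice).
    total := float64(len(choices)) // the "+1" of Laplace smoothing for each choice
    for _, c := range counts {
        total += float64(c)
    }

    r := t.rng.Float64() * total
    cumSum := 0.0

    for _, choice := range choices {
        count := float64(counts[choice]) + 1 // +1 smoothing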
        cumSum += count
        if r < cumSum {
            return choice
        }
    }

    return choices[len(choices)-1]
}

func (t *TPE) std(values []float64) float64 {
    if len(values) < 2 {
        return 0
    }

    mean := 0.0
    for _, v := range values {
        mean += v
    }
    mean /= float64(len(values))

    sumSq := 0.0
    for _, v := range values {
        sumSq += (v - mean) * (v - mean)
    }

    return math.Sqrt(sumSq / float64(len(values)-1))
}
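The core TPE step is easiest to see in one dimension. The sketch below is a simplified, hypothetical variant of the `Suggest` logic above: it splits the history at the γ-quantile, fits a Gaussian Parzen estimator with a fixed bandwidth to each group, and scores a deterministic grid of candidates by l(x)/g(x) instead of sampling candidates from l(x). It assumes enough observations that both groups are non-empty. On 20 observations of the toy objective (x − 0.3)², the proposed point should fall near the optimum at 0.3.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// obs is one evaluated point of a hypothetical 1-D objective.
type obs struct{ x, y float64 }

// parzen evaluates a Gaussian Parzen (kernel density) estimator with bandwidth h at x.
func parzen(x float64, centers []float64, h float64) float64 {
	p := 0.0
	for _, c := range centers {
		d := (x - c) / h
		p += math.Exp(-0.5*d*d) / (h * math.Sqrt(2*math.Pi))
	}
	return p / float64(len(centers))
}

// tpeSuggest1D splits the history at the gamma quantile of y into good (l) and
// bad (g) points and returns the candidate with the largest l(x)/g(x).
func tpeSuggest1D(history []obs, gamma, h float64, candidates []float64) float64 {
	sorted := append([]obs(nil), history...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i].y < sorted[j].y })
	split := int(float64(len(sorted)) * gamma)
	if split < 1 {
		split = 1
	}
	var good, bad []float64
	for i, o := range sorted {
		if i < split {
			good = append(good, o.x)
		} else {
			bad = append(bad, o.x)
		}
	}
	best, bestScore := candidates[0], math.Inf(-1)
	for _, x := range candidates {
		if score := parzen(x, good, h) / parzen(x, bad, h); score > bestScore {
			best, bestScore = x, score
		}
	}
	return best
}

func main() {
	// Hypothetical objective (x - 0.3)^2, observed at 20 evenly spaced points.
	var history []obs
	for i := 0; i < 20; i++ {
		x := float64(i) / 19
		history = append(history, obs{x, (x - 0.3) * (x - 0.3)})
	}
	var candidates []float64
	for i := 0; i <= 100; i++ {
		candidates = append(candidates, float64(i)/100)
	}
	fmt.Printf("next x = %.2f\n", tpeSuggest1D(history, 0.25, 0.1, candidates))
}
```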

Hyperband 和 ASHA

// hyperband.go
package hpo

import (
    "context"
    "math"
    "sort"
)

// Hyperband Hyperband 调度器
type Hyperband struct {
    space           *SearchSpace
    sampler         Sampler
    maxResource     int     // 最大资源(如 epochs)
    reductionFactor float64 // 减少因子(通常为 3)
    minResource     int     // 最小资源
}

// Sampler 采样器接口
type Sampler interface {
    Suggest() Config
    Observe(config Config, value float64)
}

// Bracket 括号
type Bracket struct {
    Configs   []Config
    Resources []int
    Results   []float64
}

func NewHyperband(space *SearchSpace, sampler Sampler, maxResource int) *Hyperband {
    return &Hyperband{
        space:           space,
        sampler:         sampler,
        maxResource:     maxResource,
        reductionFactor: 3,
        minResource:     1,
    }
}

// Run 运行 Hyperband
func (h *Hyperband) Run(ctx context.Context, objective ObjectiveFunc) (Config, float64, error) {
    // 计算 s_max
    sMax := int(math.Log(float64(h.maxResource)/float64(h.minResource)) / math.Log(h.reductionFactor))

    var bestConfig Config
    bestValue := math.Inf(1)

    // 外循环:从 s_max 到 0
    for s := sMax; s >= 0; s-- {
        // 计算初始配置数量和资源
        n := int(math.Ceil(float64(sMax+1) / float64(s+1) * math.Pow(h.reductionFactor, float64(s))))
        r := int(float64(h.maxResource) * math.Pow(h.reductionFactor, float64(-s)))

        // 采样初始配置
        configs := make([]Config, n)
        for i := 0; i < n; i++ {
            configs[i] = h.sampler.Suggest()
        }

        // 逐步淘汰
        for i := 0; i <= s; i++ {
            // 当前资源
            resource := r * int(math.Pow(h.reductionFactor, float64(i)))
            nConfigs := int(float64(n) * math.Pow(h.reductionFactor, float64(-i)))

            // 评估所有配置
            type result struct {
                config Config
                value  float64
            }
            results := make([]result, len(configs))

            for j, config := range configs {
                select {
                case <-ctx.Done():
                    return bestConfig, bestValue, ctx.Err()
                default:
                }

                value, err := objective(ctx, config, resource)
                if err != nil {
                    results[j] = result{config, math.Inf(1)}
                    continue
                }

                results[j] = result{config, value}
                h.sampler.Observe(config, value)

                if value < bestValue {
                    bestValue = value
                    bestConfig = config
                }
            }

            // 选择最好的 n/eta 个
            sort.Slice(results, func(a, b int) bool {
                return results[a].value < results[b].value
            })

            nextN := max(1, int(float64(nConfigs)/h.reductionFactor))
            configs = make([]Config, nextN)
            for j := 0; j < nextN; j++ {
                configs[j] = results[j].config
            }
        }
    }

    return bestConfig, bestValue, nil
}

// ASHA Asynchronous Successive Halving Algorithm
type ASHA struct {
    space           *SearchSpace
    sampler         Sampler
    maxResource     int
    reductionFactor float64
    minResource     int
    brackets        map[int]*ASHABracket // rung -> bracket
    promotionRule   string               // "sync" or "async"
}

// ASHABracket ASHA 括号
type ASHABracket struct {
    Rung      int
    Configs   []ASHAConfig
    Promoted  map[string]bool
}

// ASHAConfig ASHA 配置
type ASHAConfig struct {
    ID     string
    Config Config
    Values map[int]float64 // resource -> value
}

func NewASHA(space *SearchSpace, sampler Sampler, maxResource int) *ASHA {
    return &ASHA{
        space:           space,
        sampler:         sampler,
        maxResource:     maxResource,
        reductionFactor: 4,
        minResource:     1,
        brackets:        make(map[int]*ASHABracket),
        promotionRule:   "async",
    }
}

// GetNextConfig 获取下一个要评估的配置和资源
func (a *ASHA) GetNextConfig() (Config, int, string) {
    // 检查是否有可晋升的配置
    for rung := 0; ; rung++ {
        resource := a.minResource * int(math.Pow(a.reductionFactor, float64(rung)))
        if resource > a.maxResource {
            break
        }

        bracket := a.brackets[rung]
        if bracket == nil {
            continue
        }

        // 检查是否有配置可以晋升
        promotable := a.findPromotable(bracket)
        if len(promotable) > 0 {
            // 返回第一个可晋升的配置
            config := promotable[0]
            nextRung := rung + 1
            nextResource := a.minResource * int(math.Pow(a.reductionFactor, float64(nextRung)))

            if nextResource <= a.maxResource {
                bracket.Promoted[config.ID] = true
                return config.Config, nextResource, config.ID
            }
        }
    }

    // 没有可晋升的,采样新配置
    newConfig := a.sampler.Suggest()
    configID := generateConfigID()

    return newConfig, a.minResource, configID
}

// Report 报告评估结果
func (a *ASHA) Report(configID string, config Config, resource int, value float64) {
    // 找到对应的 rung
    rung := int(math.Log(float64(resource)/float64(a.minResource)) / math.Log(a.reductionFactor))

    // 初始化 bracket
    if a.brackets[rung] == nil {
        a.brackets[rung] = &ASHABracket{
            Rung:     rung,
            Configs:  []ASHAConfig{},
            Promoted: make(map[string]bool),
        }
    }

    bracket := a.brackets[rung]

    // 查找或创建配置
    found := false
    for i := range bracket.Configs {
        if bracket.Configs[i].ID == configID {
            bracket.Configs[i].Values[resource] = value
            found = true
            break
        }
    }

    if !found {
        bracket.Configs = append(bracket.Configs, ASHAConfig{
            ID:     configID,
            Config: config,
            Values: map[int]float64{resource: value},
        })
    }

    // 通知采样器
    a.sampler.Observe(config, value)
}

// findPromotable 找到可晋升的配置
func (a *ASHA) findPromotable(bracket *ASHABracket) []ASHAConfig {
    resource := a.minResource * int(math.Pow(a.reductionFactor, float64(bracket.Rung)))

    // 收集已完成的配置
    completed := make([]ASHAConfig, 0)
    for _, config := range bracket.Configs {
        if _, ok := config.Values[resource]; ok && !bracket.Promoted[config.ID] {
            completed = append(completed, config)
        }
    }

    // 计算晋升数量
    numPromote := int(float64(len(completed)) / a.reductionFactor)
    if numPromote < 1 {
        return nil
    }

    // 按值排序
    sort.Slice(completed, func(i, j int) bool {
        return completed[i].Values[resource] < completed[j].Values[resource]
    })

    return completed[:numPromote]
}

// GetBest 获取最优配置
func (a *ASHA) GetBest() (Config, float64) {
    var bestConfig Config
    bestValue := math.Inf(1)

    for _, bracket := range a.brackets {
        for _, config := range bracket.Configs {
            for _, value := range config.Values {
                if value < bestValue {
                    bestValue = value
                    bestConfig = config.Config
                }
            }
        }
    }

    return bestConfig, bestValue
}

// ObjectiveFunc evaluates a configuration with the given resource budget
type ObjectiveFunc func(ctx context.Context, config Config, resource int) (float64, error)

func max(a, b int) int {
    if a > b {
        return a
    }
    return b
}

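func generateConfigID() string {
    // Placeholder: a real implementation must return a unique ID (e.g. a UUID);
    // returning "" makes every configuration share the same ID
    return ""
}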
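The rung arithmetic in `Report` and `findPromotable` maps a resource level to a rung through a ratio of logarithms, which is where floating-point error can bite: truncating a value such as 2.9999999999999996 lands in the wrong rung. The standalone sketch below assumes the same `minResource`/`reductionFactor` (η) semantics and lower-is-better values, rounds instead of truncating, and shows the top-1/η promotion rule. `rungIndex`, `rungResource` and `promote` are illustrative helpers, not part of the ASHA type above.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// rungIndex maps a resource level back to its rung: resource = minResource * eta^rung.
// math.Round absorbs log-ratio results that land slightly below the integer.
func rungIndex(resource, minResource int, eta float64) int {
	return int(math.Round(math.Log(float64(resource)/float64(minResource)) / math.Log(eta)))
}

// rungResource is the inverse mapping: the budget assigned to a given rung.
func rungResource(rung, minResource int, eta float64) int {
	return int(math.Round(float64(minResource) * math.Pow(eta, float64(rung))))
}

// promote returns the IDs of the best floor(n/eta) results, lowest loss first.
func promote(results map[string]float64, eta float64) []string {
	ids := make([]string, 0, len(results))
	for id := range results {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return results[ids[i]] < results[ids[j]] })
	k := int(float64(len(ids)) / eta)
	return ids[:k]
}

func main() {
	fmt.Println(rungIndex(27, 1, 3), rungResource(3, 1, 3)) // 3 27
	losses := map[string]float64{"a": 0.9, "b": 0.4, "c": 0.7, "d": 0.2, "e": 0.8, "f": 0.5}
	fmt.Println(promote(losses, 3)) // [d b]
}
```

With six finished configurations and η = 3, only the two lowest losses move on, exactly as `findPromotable` computes `len(completed) / reductionFactor`.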

HPO Service Implementation

HPO Service Architecture

┌─────────────────────────────────────────────────────────────────┐
│                    HPO Service Architecture                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │                      HPO Controller                       │  │
│  │  • Study management                                       │  │
│  │  • Trial scheduling                                       │  │
│  │  • Resource allocation                                    │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                │                                │
│            ┌───────────────────┼───────────────────┐            │
│            │                   │                   │            │
│            ▼                   ▼                   ▼            │
│      ┌───────────┐       ┌───────────┐       ┌───────────┐      │
│      │  Sampler  │       │ Scheduler │       │  Pruner   │      │
│      │  Service  │       │  Service  │       │  Service  │      │
│      │           │       │           │       │           │      │
│      │ • Random  │       │ • FIFO    │       │ • Median  │      │
│      │ • TPE     │       │ • Priority│       │ • SHA     │      │
│      │ • BO      │       │ • Fair    │       │ • ASHA    │      │
│      └───────────┘       └───────────┘       └───────────┘      │
│                                │                                │
│                                ▼                                │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │                      Trial Executor                       │  │
│  │  • Kubernetes Job                                         │  │
│  │  • Distributed training support                           │  │
│  │  • Checkpoint recovery                                    │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

HPO Controller

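// hpo_controller.go
package hpo

import (
    "context"
    "fmt"
    "sync"
    "time"

    corev1 "k8s.io/api/core/v1"
)

// TrialTemplate describes the container each trial runs
// (mirrors spec.trialTemplate in the Study CRD; used by the Kubernetes executor)
type TrialTemplate struct {
    Image     string                      `json:"image"`
    Command   []string                    `json:"command"`
    Resources corev1.ResourceRequirements `json:"resources"`
}

// Study is one HPO job: a search space, an objective, and the trials run against it
type Study struct {
    ID          string        `json:"id"`
    Name        string        `json:"name"`
    Objective   string        `json:"objective"` // minimize, maximize
    SearchSpace *SearchSpace  `json:"search_space"`
    Algorithm   string        `json:"algorithm"` // random, tpe, bayesian, hyperband
    Status      StudyStatus   `json:"status"`

    // Template used to launch each trial
    TrialTemplate TrialTemplate `json:"trial_template"`

    // Budget and concurrency limits
    MaxTrials    int           `json:"max_trials"`
    MaxDuration  time.Duration `json:"max_duration"`
    Parallelism  int           `json:"parallelism"`

    // Hyperband/ASHA settings
    MaxResource      int     `json:"max_resource,omitempty"`
    ReductionFactor  float64 `json:"reduction_factor,omitempty"`

    // Runtime state
    Trials       []*Trial      `json:"trials"`
    BestTrial    *Trial        `json:"best_trial"`
    StartTime    *time.Time    `json:"start_time"`
    EndTime      *time.Time    `json:"end_time"`

    CreatedAt    time.Time     `json:"created_at"`
}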

type StudyStatus string

const (
    StudyStatusPending   StudyStatus = "pending"
    StudyStatusRunning   StudyStatus = "running"
    StudyStatusCompleted StudyStatus = "completed"
    StudyStatusFailed    StudyStatus = "failed"
    StudyStatusStopped   StudyStatus = "stopped"
)

// Trial is a single evaluation of one configuration
type Trial struct {
    ID        string       `json:"id"`
    StudyID   string       `json:"study_id"`
    Number    int          `json:"number"`
    Config    Config       `json:"config"`
    Status    TrialStatus  `json:"status"`
    Value     *float64     `json:"value,omitempty"` // final metric value
    Resource  int          `json:"resource"`         // resource budget used (e.g. epochs)

    // Intermediate results
    IntermediateValues []IntermediateValue `json:"intermediate_values"`

    // Execution info
    JobName   string     `json:"job_name"`
    StartTime *time.Time `json:"start_time"`
    EndTime   *time.Time `json:"end_time"`
    Error     string     `json:"error,omitempty"`
}

type TrialStatus string

const (
    TrialStatusPending   TrialStatus = "pending"
    TrialStatusRunning   TrialStatus = "running"
    TrialStatusCompleted TrialStatus = "completed"
    TrialStatusFailed    TrialStatus = "failed"
    TrialStatusPruned    TrialStatus = "pruned"
)

// IntermediateValue is a metric value reported during training
type IntermediateValue struct {
    Step  int       `json:"step"`
    Value float64   `json:"value"`
    Time  time.Time `json:"time"`
}

// HPOController manages studies and schedules their trials
type HPOController struct {
    // Storage
    studyRepo StudyRepository
    trialRepo TrialRepository

    // Trial executor
    executor TrialExecutor

    // Algorithm instances
    samplers map[string]Sampler
    pruners  map[string]Pruner

    // Runtime state
    activeStudies map[string]*studyRunner
    mu            sync.RWMutex
}

type studyRunner struct {
    study   *Study
    sampler Sampler
    pruner  Pruner
    cancel  context.CancelFunc
}

// Pruner decides whether a running trial should be stopped early
type Pruner interface {
    ShouldPrune(trial *Trial) bool
}

// TrialExecutor runs trials on some backend and can stop them
type TrialExecutor interface {
    Execute(ctx context.Context, trial *Trial, study *Study) error
    Stop(trial *Trial) error
}

// NewHPOController creates an HPO controller
func NewHPOController(
    studyRepo StudyRepository,
    trialRepo TrialRepository,
    executor TrialExecutor,
) *HPOController {
    return &HPOController{
        studyRepo:     studyRepo,
        trialRepo:     trialRepo,
        executor:      executor,
        samplers:      make(map[string]Sampler),
        pruners:       make(map[string]Pruner),
        activeStudies: make(map[string]*studyRunner),
    }
}

// CreateStudy registers a new study in the pending state
func (c *HPOController) CreateStudy(ctx context.Context, study *Study) error {
    study.ID = generateStudyID()
    study.Status = StudyStatusPending
    study.Trials = []*Trial{}
    study.CreatedAt = time.Now()

    if err := c.studyRepo.Create(ctx, study); err != nil {
        return fmt.Errorf("create study: %w", err)
    }

    return nil
}

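// StartStudy starts a study; trials are scheduled in the background
func (c *HPOController) StartStudy(ctx context.Context, studyID string) error {
    study, err := c.studyRepo.Get(ctx, studyID)
    if err != nil {
        return err
    }

    // Create the sampler
    sampler := c.createSampler(study)

    // Create the pruner
    pruner := c.createPruner(study)

    // Derive the run context from Background, not ctx: ctx usually belongs to the
    // API request and would cancel the whole study as soon as the request returns.
    // The study is stopped explicitly through runner.cancel instead.
    runCtx, cancel := context.WithCancel(context.Background())

    runner := &studyRunner{
        study:   study,
        sampler: sampler,
        pruner:  pruner,
        cancel:  cancel,
    }

    c.mu.Lock()
    c.activeStudies[studyID] = runner
    c.mu.Unlock()

    // Mark the study as running
    now := time.Now()
    study.Status = StudyStatusRunning
    study.StartTime = &now
    c.studyRepo.Update(ctx, study)

    // Start the scheduling loop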
    go c.runStudy(runCtx, runner)

    return nil
}

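// runStudy is the scheduling loop: it launches trials until a stop condition is met
func (c *HPOController) runStudy(ctx context.Context, runner *studyRunner) {
    study := runner.study

    defer func() {
        c.mu.Lock()
        delete(c.activeStudies, study.ID)
        c.mu.Unlock()

        now := time.Now()
        study.EndTime = &now
        if study.Status == StudyStatusRunning {
            study.Status = StudyStatusCompleted
        }
        c.studyRepo.Update(context.Background(), study)
    }()

    // Bound concurrency with a buffered channel used as a semaphore.
    // A zero-capacity channel would block forever, so require at least one slot.
    parallelism := study.Parallelism
    if parallelism < 1 {
        parallelism = 1
    }
    sem := make(chan struct{}, parallelism)
    var wg sync.WaitGroup

    trialNumber := 0

loop:
    for {
        // Stop conditions: cancellation, trial budget, wall-clock budget
        if ctx.Err() != nil {
            break
        }
        if trialNumber >= study.MaxTrials {
            break
        }
        if study.MaxDuration > 0 && time.Since(*study.StartTime) > study.MaxDuration {
            break
        }

        // Wait for a free slot, but give up if the study is cancelled meanwhile
        select {
        case <-ctx.Done():
            break loop
        case sem <- struct{}{}:
        }
        trialNumber++

        wg.Add(1)
        go func(num int) {
            defer wg.Done()
            defer func() { <-sem }()

            c.runTrial(ctx, runner, num)
        }(trialNumber)
    }

    // Let in-flight trials finish before the deferred finalization runs
    wg.Wait()
}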

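// runTrial samples one configuration, runs it and records the outcome
func (c *HPOController) runTrial(ctx context.Context, runner *studyRunner, number int) {
    study := runner.study

    // Ask the sampler for a configuration
    config := runner.sampler.Suggest()

    // Create the trial record
    trial := &Trial{
        ID:       generateTrialID(),
        StudyID:  study.ID,
        Number:   number,
        Config:   config,
        Status:   TrialStatusPending,
        Resource: study.MaxResource,
    }

    if err := c.trialRepo.Create(ctx, trial); err != nil {
        return
    }

    now := time.Now()
    trial.Status = TrialStatusRunning
    trial.StartTime = &now
    c.trialRepo.Update(ctx, trial)

    // Run the trial; this blocks until the Job finishes
    err := c.executor.Execute(ctx, trial, study)

    // While the Job ran, ReportIntermediateValue and CompleteTrial updated the
    // stored record (pruned status, final value), not this local copy, so reload
    // it. JobName was only set locally by the executor, so keep it.
    jobName := trial.JobName
    if latest, getErr := c.trialRepo.Get(ctx, trial.ID); getErr == nil {
        trial = latest
        trial.JobName = jobName
    }

    endTime := time.Now()
    trial.EndTime = &endTime

    if err != nil {
        trial.Status = TrialStatusFailed
        trial.Error = err.Error()
    } else if trial.Status != TrialStatusPruned {
        trial.Status = TrialStatusCompleted
    }

    c.trialRepo.Update(ctx, trial)

    // Feed completed (not pruned or failed) results back to the sampler
    if trial.Value != nil && trial.Status == TrialStatusCompleted {
        value := *trial.Value
        if study.Objective == "maximize" {
            value = -value // samplers minimize, so negate a maximized metric
        }
        runner.sampler.Observe(config, value)
    }

    // Update the study's best trial
    c.updateBestTrial(ctx, study, trial)
}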

// ReportIntermediateValue records an intermediate metric and returns whether
// the trial should keep training (false means it has been pruned)
func (c *HPOController) ReportIntermediateValue(
    ctx context.Context,
    trialID string,
    step int,
    value float64,
) (bool, error) {
    trial, err := c.trialRepo.Get(ctx, trialID)
    if err != nil {
        return false, err
    }

    // Record the intermediate result
    trial.IntermediateValues = append(trial.IntermediateValues, IntermediateValue{
        Step:  step,
        Value: value,
        Time:  time.Now(),
    })
    c.trialRepo.Update(ctx, trial)

    // Ask the study's pruner whether to stop this trial
    c.mu.RLock()
    runner, ok := c.activeStudies[trial.StudyID]
    c.mu.RUnlock()

    if ok && runner.pruner != nil {
        if runner.pruner.ShouldPrune(trial) {
            trial.Status = TrialStatusPruned
            c.trialRepo.Update(ctx, trial)
            return false, nil
        }
    }

    return true, nil
}

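// CompleteTrial stores the final metric reported by the training process
func (c *HPOController) CompleteTrial(ctx context.Context, trialID string, value float64) error {
    trial, err := c.trialRepo.Get(ctx, trialID)
    if err != nil {
        return err
    }

    trial.Value = &value
    // Keep a pruned status: a pruned trial must not be reported as completed
    if trial.Status != TrialStatusPruned {
        trial.Status = TrialStatusCompleted
    }
    now := time.Now()
    trial.EndTime = &now

    return c.trialRepo.Update(ctx, trial)
}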

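// updateBestTrial replaces the study's best trial if this one is better.
// Trials finish concurrently, so the compare-and-update runs under c.mu.
func (c *HPOController) updateBestTrial(ctx context.Context, study *Study, trial *Trial) {
    if trial.Value == nil || trial.Status != TrialStatusCompleted {
        return
    }

    c.mu.Lock()
    defer c.mu.Unlock()
    if study.BestTrial == nil || isBetter(*trial.Value, *study.BestTrial.Value, study.Objective) {
        study.BestTrial = trial
        c.studyRepo.Update(ctx, study)
    }
}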

func isBetter(new, old float64, objective string) bool {
    if objective == "maximize" {
        return new > old
    }
    return new < old
}

// createSampler picks the sampler for the study's algorithm (random search by default)
func (c *HPOController) createSampler(study *Study) Sampler {
    switch study.Algorithm {
    case "random":
        return NewRandomSearch(study.SearchSpace, 0, time.Now().UnixNano())
    case "tpe":
        return NewTPE(study.SearchSpace, time.Now().UnixNano())
    case "bayesian":
        return NewBayesianOptimization(study.SearchSpace, time.Now().UnixNano(), 10)
    default:
        return NewRandomSearch(study.SearchSpace, 0, time.Now().UnixNano())
    }
}

// createPruner creates the pruner (currently always a median pruner)
func (c *HPOController) createPruner(study *Study) Pruner {
    return NewMedianPruner(5, 0.5)
}

func generateStudyID() string {
    return fmt.Sprintf("study_%d", time.Now().UnixNano())
}

func generateTrialID() string {
    return fmt.Sprintf("trial_%d", time.Now().UnixNano())
}

Median Pruner

// median_pruner.go
package hpo

import (
    "sort"
)

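// MedianPruner stops a trial whose latest intermediate value is worse than a
// percentile (0.5 = median) of the values other trials reported at the same step.
// Lower values are treated as better (minimization).
type MedianPruner struct {
    minTrials   int     // minimum number of comparable trials before pruning
    percentile  float64 // percentile used as the pruning threshold
    minSteps    int     // never prune before this many steps have been reported

    // History returns a study's trials to compare against (ideally excluding the
    // trial being judged). Hypothetical hook, e.g. backed by TrialRepository;
    // when nil, ShouldPrune never prunes.
    History func(studyID string) []*Trial
}

func NewMedianPruner(minTrials int, percentile float64) *MedianPruner {
    return &MedianPruner{
        minTrials:  minTrials,
        percentile: percentile,
        minSteps:   5,
    }
}

// ShouldPrune implements Pruner. The decision needs other trials' intermediate
// values, so it delegates to ShouldPruneWithHistory through the History hook.
func (p *MedianPruner) ShouldPrune(trial *Trial) bool {
    if p.History == nil {
        return false
    }
    return p.ShouldPruneWithHistory(trial, p.History(trial.StudyID))
}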

// ShouldPruneWithHistory makes the pruning decision given other trials' histories
func (p *MedianPruner) ShouldPruneWithHistory(trial *Trial, completedTrials []*Trial) bool {
    if len(trial.IntermediateValues) < p.minSteps {
        return false
    }

    if len(completedTrials) < p.minTrials {
        return false
    }

    // Latest reported step and value
    currentStep := trial.IntermediateValues[len(trial.IntermediateValues)-1].Step
    currentValue := trial.IntermediateValues[len(trial.IntermediateValues)-1].Value

    // Collect other trials' values reported at the same step
    var historyValues []float64
    for _, t := range completedTrials {
        for _, iv := range t.IntermediateValues {
            if iv.Step == currentStep {
                historyValues = append(historyValues, iv.Value)
                break
            }
        }
    }

    if len(historyValues) < p.minTrials {
        return false
    }

    // Threshold = requested percentile of those values
    sort.Float64s(historyValues)
    threshold := percentile(historyValues, p.percentile)

    // Prune if the current value is worse (higher) than the threshold
    return currentValue > threshold
}

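// percentile returns the p-th percentile (0 <= p <= 1) of an ascending slice,
// interpolating linearly between the two nearest ranks
func percentile(sorted []float64, p float64) float64 {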
    if len(sorted) == 0 {
        return 0
    }

    idx := p * float64(len(sorted)-1)
    lower := int(idx)
    upper := lower + 1

    if upper >= len(sorted) {
        return sorted[len(sorted)-1]
    }

    frac := idx - float64(lower)
    return sorted[lower]*(1-frac) + sorted[upper]*frac
}
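`ShouldPruneWithHistory` compares the trial's latest value against a percentile of the values other trials reported at the same step. Stripped of the `Trial` bookkeeping, the decision looks like the sketch below. It assumes a loss to minimize (for a maximized metric the comparison flips) and reuses the same linear-interpolation `percentile`; `shouldPrune` is a hypothetical helper.

```go
package main

import (
	"fmt"
	"sort"
)

// percentile uses linear interpolation between the two nearest ranks,
// matching the helper in median_pruner.go. sorted must be ascending.
func percentile(sorted []float64, p float64) float64 {
	if len(sorted) == 0 {
		return 0
	}
	idx := p * float64(len(sorted)-1)
	lower := int(idx)
	if lower+1 >= len(sorted) {
		return sorted[len(sorted)-1]
	}
	frac := idx - float64(lower)
	return sorted[lower]*(1-frac) + sorted[lower+1]*frac
}

// shouldPrune reports whether current (a loss) is worse than the p-th percentile
// of the values other trials reported at the same step.
func shouldPrune(current float64, sameStep []float64, p float64, minTrials int) bool {
	if len(sameStep) < minTrials {
		return false // not enough history to judge yet
	}
	vals := append([]float64(nil), sameStep...) // copy so the caller's slice is not reordered
	sort.Float64s(vals)
	return current > percentile(vals, p)
}

func main() {
	history := []float64{0.9, 0.5, 0.7, 0.6, 0.8} // losses of five trials at the same step
	fmt.Println(percentile([]float64{0.5, 0.6, 0.7, 0.8, 0.9}, 0.5)) // 0.7
	fmt.Println(shouldPrune(0.75, history, 0.5, 5))                   // true
	fmt.Println(shouldPrune(0.65, history, 0.5, 5))                   // false
}
```

With five trials reporting losses from 0.5 to 0.9 at a step, the median is 0.7, so a trial at 0.75 is pruned and one at 0.65 keeps training.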

Kubernetes Integration

HPO CRD

# hpo-crd.yaml
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  name: studies.hpo.ai.io
spec:
  group: hpo.ai.io
  versions:
  - name: v1
    served: true
    storage: true
    schema:
      openAPIV3Schema:
        type: object
        properties:
          spec:
            type: object
            properties:
              objective:
                type: string
                enum: [minimize, maximize]
              algorithm:
                type: string
                enum: [random, tpe, bayesian, hyperband, asha]
              maxTrials:
                type: integer
              parallelism:
                type: integer
              maxDuration:
                type: string
              searchSpace:
                type: object
                x-kubernetes-preserve-unknown-fields: true
              trialTemplate:
                type: object
                properties:
                  image:
                    type: string
                  command:
                    type: array
                    items:
                      type: string
                  resources:
                    type: object
                    x-kubernetes-preserve-unknown-fields: true
              metricsCollector:
                type: object
                properties:
                  type:
                    type: string
                  source:
                    type: string
          status:
            type: object
            properties:
              phase:
                type: string
              trials:
                type: integer
              completedTrials:
                type: integer
              bestTrial:
                type: object
                properties:
                  name:
                    type: string
                  value:
                    type: number
                  config:
                    type: object
                    x-kubernetes-preserve-unknown-fields: true
    subresources:
      status: {}
  scope: Namespaced
  names:
    plural: studies
    singular: study
    kind: Study
    shortNames:
    - hpo

---
# Example Study
apiVersion: hpo.ai.io/v1
kind: Study
metadata:
  name: llm-lr-optimization
  namespace: ai-training
spec:
  objective: minimize
  algorithm: tpe
  maxTrials: 50
  parallelism: 4
  maxDuration: "48h"

  searchSpace:
    parameters:
      learning_rate:
        type: continuous
        low: 1e-6
        high: 1e-3
        logScale: true
      batch_size:
        type: discrete
        low: 8
        high: 64
        step: 8
      warmup_ratio:
        type: continuous
        low: 0.0
        high: 0.2
      weight_decay:
        type: continuous
        low: 0.0
        high: 0.3
      optimizer:
        type: categorical
        choices: ["adamw", "adam", "sgd"]

  trialTemplate:
    image: training:v1.0
    command:
    - python
    - train.py
    - --learning-rate=${learning_rate}
    - --batch-size=${batch_size}
    - --warmup-ratio=${warmup_ratio}
    - --weight-decay=${weight_decay}
    - --optimizer=${optimizer}
    resources:
      limits:
        nvidia.com/gpu: 8
        memory: "128Gi"
      requests:
        cpu: "16"

  metricsCollector:
    type: prometheus
    source: "pod_annotation"

Trial Executor

// k8s_executor.go
package hpo

import (
    "context"
    "fmt"
    "strings"
    "time"

    batchv1 "k8s.io/api/batch/v1"
    corev1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/client-go/kubernetes"
)

// K8sTrialExecutor runs each trial as a Kubernetes Job
type K8sTrialExecutor struct {
    client    kubernetes.Interface
    namespace string
}

func NewK8sTrialExecutor(client kubernetes.Interface, namespace string) *K8sTrialExecutor {
    return &K8sTrialExecutor{
        client:    client,
        namespace: namespace,
    }
}

// Execute creates the trial's Job and blocks until it finishes
func (e *K8sTrialExecutor) Execute(ctx context.Context, trial *Trial, study *Study) error {
    // Build the Job spec
    job := e.buildJob(trial, study)

    // Create the Job
    _, err := e.client.BatchV1().Jobs(e.namespace).Create(ctx, job, metav1.CreateOptions{})
    if err != nil {
        return fmt.Errorf("create job: %w", err)
    }

    trial.JobName = job.Name

    // Wait for the Job to succeed or fail
    return e.waitForCompletion(ctx, trial, study)
}

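// buildJob renders the study's trial template into a Job for this trial
func (e *K8sTrialExecutor) buildJob(trial *Trial, study *Study) *batchv1.Job {
    // Job names must be DNS-1123 compliant (no '_'), so sanitize IDs like "trial_123"
    jobName := "hpo-" + strings.ReplaceAll(strings.ToLower(trial.ID), "_", "-")

    // Substitute ${param} placeholders in the command
    command := make([]string, len(study.TrialTemplate.Command))
    for i, cmd := range study.TrialTemplate.Command {
        command[i] = e.substituteParams(cmd, trial.Config)
    }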

    // Base environment: trial identity and where to report metrics
    env := []corev1.EnvVar{
        {Name: "TRIAL_ID", Value: trial.ID},
        {Name: "STUDY_ID", Value: study.ID},
        {Name: "HPO_REPORT_URL", Value: e.getReportURL()},
    }

    // Expose each hyperparameter as HP_<NAME> (read back by the Python client)
    for key, value := range trial.Config {
        env = append(env, corev1.EnvVar{
            Name:  fmt.Sprintf("HP_%s", strings.ToUpper(key)),
            Value: fmt.Sprintf("%v", value),
        })
    }

    backoffLimit := int32(0)

    return &batchv1.Job{
        ObjectMeta: metav1.ObjectMeta{
            Name:      jobName,
            Namespace: e.namespace,
            Labels: map[string]string{
                "hpo.ai.io/study": study.ID,
                "hpo.ai.io/trial": trial.ID,
            },
            Annotations: map[string]string{
                "hpo.ai.io/config": serializeConfig(trial.Config),
            },
        },
        Spec: batchv1.JobSpec{
            BackoffLimit: &backoffLimit,
            Template: corev1.PodTemplateSpec{
                ObjectMeta: metav1.ObjectMeta{
                    Labels: map[string]string{
                        "hpo.ai.io/study": study.ID,
                        "hpo.ai.io/trial": trial.ID,
                    },
                },
                Spec: corev1.PodSpec{
                    RestartPolicy: corev1.RestartPolicyNever,
                    Containers: []corev1.Container{
                        {
                            Name:    "trial",
                            Image:   study.TrialTemplate.Image,
                            Command: command,
                            Env:     env,
                            Resources: corev1.ResourceRequirements{
                                Limits:   study.TrialTemplate.Resources.Limits,
                                Requests: study.TrialTemplate.Resources.Requests,
                            },
                        },
                    },
                },
            },
        },
    }
}

func (e *K8sTrialExecutor) substituteParams(template string, config Config) string {
    result := template
    for key, value := range config {
        placeholder := fmt.Sprintf("${%s}", key)
        result = strings.ReplaceAll(result, placeholder, fmt.Sprintf("%v", value))
    }
    return result
}

// waitForCompletion polls the Job every 10 seconds until it succeeds, fails, or ctx is cancelled
func (e *K8sTrialExecutor) waitForCompletion(ctx context.Context, trial *Trial, study *Study) error {
    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-ticker.C:
            job, err := e.client.BatchV1().Jobs(e.namespace).Get(ctx, trial.JobName, metav1.GetOptions{})
            if err != nil {
                continue
            }

            if job.Status.Succeeded > 0 {
                return nil
            }

            if job.Status.Failed > 0 {
                return fmt.Errorf("job %s failed", trial.JobName)
            }
        }
    }
}

// Stop deletes the trial's Job; background propagation also removes its Pods
func (e *K8sTrialExecutor) Stop(trial *Trial) error {
    propagation := metav1.DeletePropagationBackground
    return e.client.BatchV1().Jobs(e.namespace).Delete(
        context.Background(),
        trial.JobName,
        metav1.DeleteOptions{
            PropagationPolicy: &propagation,
        },
    )
}

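// getReportURL returns the controller's base URL; the Python client appends
// paths such as /api/v1/trials/{id}/intermediate to it
func (e *K8sTrialExecutor) getReportURL() string {
    return "http://hpo-controller:8080"
}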

func serializeConfig(config Config) string {
    // Placeholder: should return the config JSON-encoded (e.g. via encoding/json)
    return ""
}

Python Client

# hpo_client.py
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional

import requests


@dataclass
class TrialContext:
    """Identity and configuration of the current trial."""
    trial_id: str
    study_id: str
    config: Dict[str, Any]
    report_url: str

class HPOClient:
    """Client used inside a trial to read hyperparameters and report metrics."""

    def __init__(self, server_url: Optional[str] = None):
        self.server_url = server_url or os.environ.get("HPO_REPORT_URL", "http://localhost:8080")
        self.trial_id = os.environ.get("TRIAL_ID")
        self.study_id = os.environ.get("STUDY_ID")
        self._config = None

    @property
    def config(self) -> Dict[str, Any]:
        """Hyperparameter configuration of the current trial."""
        if self._config is None:
            self._config = self._load_config()
        return self._config

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration from HP_* environment variables."""
        config = {}
        for key, value in os.environ.items():
            if key.startswith("HP_"):
                param_name = key[3:].lower()
                config[param_name] = self._parse_value(value)
        return config

    def _parse_value(self, value: str) -> Any:
        """Parse as int, then float (including exponent forms such as "1e-05"), else keep the string."""
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value

    def report_intermediate(self, step: int, value: float) -> bool:
        """Report an intermediate result; returns whether training should continue."""
        if not self.trial_id:
            return True

        try:
            response = requests.post(
                f"{self.server_url}/api/v1/trials/{self.trial_id}/intermediate",
                json={
                    "step": step,
                    "value": value,
                },
                timeout=10,
            )
            response.raise_for_status()
            return response.json().get("should_continue", True)
        except Exception as e:
            # Fail open: a reporting error should not kill the training run
            print(f"Failed to report intermediate value: {e}")
            return True

    def report_final(self, value: float) -> None:
        """Report the final result of the trial."""
        if not self.trial_id:
            return

        try:
            response = requests.post(
                f"{self.server_url}/api/v1/trials/{self.trial_id}/complete",
                json={"value": value},
                timeout=10,
            )
            response.raise_for_status()
        except Exception as e:
            print(f"Failed to report final value: {e}")

    def suggest(self, param_name: str) -> Any:
        """Return the suggested value for param_name (None if it is not set)."""
        return self.config.get(param_name)


# Module-level convenience functions
_client: Optional[HPOClient] = None

def get_client() -> HPOClient:
    global _client
    if _client is None:
        _client = HPOClient()
    return _client

def suggest(param_name: str) -> Any:
    """Return the suggested value for a parameter."""
    return get_client().suggest(param_name)

def report_intermediate(step: int, value: float) -> bool:
    """Report an intermediate result."""
    return get_client().report_intermediate(step, value)

def report_final(value: float) -> None:
    """Report the final result."""
    get_client().report_final(value)


# Usage example
def train_with_hpo():
    """Example training loop integrated with HPO."""
    import hpo_client as hpo

    # Read the hyperparameters for this trial
    learning_rate = hpo.suggest("learning_rate")
    batch_size = hpo.suggest("batch_size")
    warmup_ratio = hpo.suggest("warmup_ratio")

    print(f"Training with lr={learning_rate}, bs={batch_size}, warmup={warmup_ratio}")

    # Training loop (train_epoch / evaluate stand in for your own code)
    for epoch in range(10):
        train_loss = train_epoch(...)
        eval_loss = evaluate(...)

        # Report the intermediate result; the controller may prune the trial
        should_continue = hpo.report_intermediate(epoch, eval_loss)

        if not should_continue:
            print(f"Trial pruned at epoch {epoch}")
            return  # a pruned trial does not report a final value

    # Report the final result
    final_loss = evaluate(...)
    hpo.report_final(final_loss)

Best Practices

Choosing an HPO Strategy

┌─────────────────────────────────────────────────────────────────┐
│                  HPO Strategy Selection Guide                   │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Choose an algorithm by scenario:                               │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │                                                          │   │
│  │  Ample budget + <10 params      → Grid search            │   │
│  │                                                          │   │
│  │  Tight budget + <10 params      → Bayesian optimization  │   │
│  │                                                          │   │
│  │  Tight budget + 10-50 params    → TPE                    │   │
│  │                                                          │   │
│  │  Costly evals + early stopping  → ASHA / Hyperband       │   │
│  │                                                          │   │
│  │  Highly parallel + distributed  → ASHA (asynchronous)    │   │
│  │                                                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Search space design:                                           │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Learning rate: log scale [1e-6, 1e-2]                  │   │
│  │ • Batch size: powers of 2 [8, 16, 32, 64, 128]           │   │
│  │ • Dropout: linear scale [0, 0.5]                         │   │
│  │ • Layers/hidden dim: discrete values or a few options    │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  Efficiency optimization:                                       │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Use early stopping to avoid wasting resources          │   │
│  │ • Screen first on a small dataset / few epochs           │   │
│  │ • Transfer good hyperparameters from smaller models      │   │
│  │ • Set reasonable parallelism (typically 4-16)            │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
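The early-stopping advice is easiest to see with numbers. The sketch below computes the epoch budget of a single successive-halving bracket (no Hyperband outer loop), assuming each rung trains its survivors for that rung's full epoch count and keeps the best 1/η; with checkpoint resumption the real cost is lower still. `shaBudget` is a hypothetical helper.

```go
package main

import "fmt"

// shaBudget returns the total epochs spent by one successive-halving bracket
// and its per-rung schedule as (configs, epochs) pairs.
func shaBudget(n, minResource, maxResource, eta int) (total int, schedule [][2]int) {
	configs, epochs := n, minResource
	for configs >= 1 && epochs <= maxResource {
		schedule = append(schedule, [2]int{configs, epochs})
		total += configs * epochs
		configs /= eta // keep the best 1/eta
		epochs *= eta  // give survivors eta times more budget
	}
	return total, schedule
}

func main() {
	total, sched := shaBudget(27, 1, 27, 3)
	fmt.Println(sched)        // [[27 1] [9 3] [3 9] [1 27]]
	fmt.Println(total, 27*27) // 108 729
}
```

For 27 configurations, 1 to 27 epochs and η = 3, the bracket spends 108 epochs, versus 729 epochs for training all 27 configurations to 27 epochs.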

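Summary

This chapter covered the design and implementation of a hyperparameter optimization system:

  1. Core concepts: search space definition, parameter types, conditional dependencies
  2. Search algorithms: grid search, random search, Bayesian optimization, TPE
  3. Early stopping: Hyperband, ASHA, median pruning
  4. Service architecture: HPO controller, trial executor, pruner
  5. Kubernetes integration: Study CRD, trial Job management
  6. Python client: integrating HPO into training code

After working through this chapter, you should be able to:

  • Explain how the main HPO algorithms work and when to use each
  • Design a sensible hyperparameter search space
  • Deploy an HPO service on Kubernetes
  • Integrate HPO into training code

The next chapter begins the Inference Serving part, covering how to deploy and serve AI models efficiently.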
