超参数优化
概述
超参数优化(Hyperparameter Optimization, HPO)是机器学习工作流中的关键环节。对于大模型训练,一次实验可能需要数天甚至数周时间,选择合适的超参数配置对于节省计算资源和提升模型效果至关重要。本章深入讲解超参数优化的算法原理、系统设计和实践方案。
超参数优化基础
问题形式化
┌─────────────────────────────────────────────────────────────────┐
│ 超参数优化问题定义 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 目标:找到最优超参数配置 x* 使得目标函数 f(x) 最小化 │
│ │
│ $x^* = \arg\min_{x \in X} f(x)$ │
│ │
│ 其中: │
│ • x:超参数配置向量 │
│ • X:超参数搜索空间 │
│ • f(x):目标函数(如验证损失) │
│ │
│ 挑战: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 1. 黑盒优化:f(x) 没有解析形式,只能通过训练评估 │ │
│ │ 2. 评估昂贵:每次评估需要完整训练,耗时数小时到数天 │ │
│ │ 3. 噪声:训练结果受随机性影响 │ │
│ │ 4. 混合空间:包含连续、离散、条件参数 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 超参数类型: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 连续:learning_rate ∈ [1e-6, 1e-2] │ │
│ │ • 离散:num_layers ∈ {6, 12, 24, 32} │ │
│ │ • 类别:optimizer ∈ {adam, sgd, adamw} │ │
│ │ • 条件:如果 optimizer=sgd,则 momentum ∈ [0, 1] │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
搜索空间定义
// search_space.go
package hpo
import (
"fmt"
"math"
"math/rand"
)
// SearchSpace describes the hyperparameters to search over.
// Parameters maps a parameter name to its definition; Conditions lists
// parameters that are only sampled when a parent parameter takes certain values.
type SearchSpace struct {
	Parameters map[string]Parameter
	Conditions []Condition
}

// Parameter is a single searchable hyperparameter.
type Parameter interface {
	// Sample draws a random value for this parameter using rng.
	Sample(rng *rand.Rand) interface{}
	// Contains reports whether value is a legal value of this parameter.
	Contains(value interface{}) bool
	// Type returns the parameter kind (continuous, discrete or categorical).
	Type() ParameterType
	// Name returns the parameter name.
	Name() string
}

// ParameterType identifies the kind of a Parameter.
type ParameterType string

const (
	ParamTypeContinuous  ParameterType = "continuous"  // real-valued range
	ParamTypeDiscrete    ParameterType = "discrete"    // stepped integer range
	ParamTypeCategorical ParameterType = "categorical" // finite set of choices
)
// ContinuousParameter is a real-valued parameter sampled from [low, high],
// optionally on a logarithmic scale (useful for learning rates).
type ContinuousParameter struct {
	name     string
	low      float64
	high     float64
	logScale bool // sample uniformly in log space; requires low > 0
}

// NewContinuousParameter creates a continuous parameter over [low, high].
// When logScale is true, low must be > 0 (log of a non-positive bound is NaN/-Inf).
func NewContinuousParameter(name string, low, high float64, logScale bool) *ContinuousParameter {
	return &ContinuousParameter{
		name:     name,
		low:      low,
		high:     high,
		logScale: logScale,
	}
}

// Sample draws a value uniformly from [low, high], or log-uniformly when
// logScale is set.
//
// The result is clamped to [low, high]: exp(log(x)) does not round-trip
// exactly in floating point, and a sample that lands one ULP outside the
// range would otherwise be rejected by Contains.
func (p *ContinuousParameter) Sample(rng *rand.Rand) interface{} {
	var v float64
	if p.logScale {
		logLow := math.Log(p.low)
		logHigh := math.Log(p.high)
		v = math.Exp(logLow + rng.Float64()*(logHigh-logLow))
	} else {
		v = p.low + rng.Float64()*(p.high-p.low)
	}
	return math.Min(math.Max(v, p.low), p.high)
}

// Contains reports whether value is a float64 within [low, high].
func (p *ContinuousParameter) Contains(value interface{}) bool {
	v, ok := value.(float64)
	if !ok {
		return false
	}
	return v >= p.low && v <= p.high
}

// Type returns ParamTypeContinuous.
func (p *ContinuousParameter) Type() ParameterType { return ParamTypeContinuous }

// Name returns the parameter name.
func (p *ContinuousParameter) Name() string { return p.name }

// Low returns the lower bound.
func (p *ContinuousParameter) Low() float64 { return p.low }

// High returns the upper bound.
func (p *ContinuousParameter) High() float64 { return p.high }

// LogScale reports whether sampling is log-uniform.
func (p *ContinuousParameter) LogScale() bool { return p.logScale }
// DiscreteParameter is an integer parameter taking the values
// low, low+step, low+2*step, ... that do not exceed high.
type DiscreteParameter struct {
	name string
	low  int
	high int
	step int
}

// NewDiscreteParameter creates a stepped integer parameter over [low, high].
// A non-positive step is treated as 1 by Sample and Contains.
func NewDiscreteParameter(name string, low, high, step int) *DiscreteParameter {
	return &DiscreteParameter{
		name: name,
		low:  low,
		high: high,
		step: step,
	}
}

// effectiveStep returns the step size, treating step <= 0 as 1 so that
// Sample and Contains never divide (or take a modulus) by zero.
func (p *DiscreteParameter) effectiveStep() int {
	if p.step <= 0 {
		return 1
	}
	return p.step
}

// Sample draws a value uniformly from the grid low, low+step, ... <= high.
// If the range is empty (high < low) it returns low instead of panicking
// inside rng.Intn.
func (p *DiscreteParameter) Sample(rng *rand.Rand) interface{} {
	if p.high < p.low {
		return p.low
	}
	step := p.effectiveStep()
	numSteps := (p.high - p.low) / step
	return p.low + rng.Intn(numSteps+1)*step
}

// Contains reports whether value is an int on the grid within [low, high].
func (p *DiscreteParameter) Contains(value interface{}) bool {
	v, ok := value.(int)
	if !ok {
		return false
	}
	return v >= p.low && v <= p.high && (v-p.low)%p.effectiveStep() == 0
}

// Type returns ParamTypeDiscrete.
func (p *DiscreteParameter) Type() ParameterType { return ParamTypeDiscrete }

// Name returns the parameter name.
func (p *DiscreteParameter) Name() string { return p.name }
// CategoricalParameter is a parameter whose value is one of a fixed set of
// choices. Choice values should be comparable with == (strings, numbers,
// bools); Contains compares with ==.
type CategoricalParameter struct {
	name    string
	choices []interface{}
}

// NewCategoricalParameter creates a categorical parameter. The choices slice
// is copied, so later mutation by the caller does not alter the search space.
func NewCategoricalParameter(name string, choices []interface{}) *CategoricalParameter {
	return &CategoricalParameter{
		name:    name,
		choices: append([]interface{}(nil), choices...),
	}
}

// Sample returns a uniformly random choice, or nil when there are no choices
// (rng.Intn(0) would otherwise panic).
func (p *CategoricalParameter) Sample(rng *rand.Rand) interface{} {
	if len(p.choices) == 0 {
		return nil
	}
	return p.choices[rng.Intn(len(p.choices))]
}

// Contains reports whether value equals one of the choices.
func (p *CategoricalParameter) Contains(value interface{}) bool {
	for _, c := range p.choices {
		if c == value {
			return true
		}
	}
	return false
}

// Type returns ParamTypeCategorical.
func (p *CategoricalParameter) Type() ParameterType { return ParamTypeCategorical }

// Name returns the parameter name.
func (p *CategoricalParameter) Name() string { return p.name }

// Choices returns a copy of the allowed values.
func (p *CategoricalParameter) Choices() []interface{} {
	return append([]interface{}(nil), p.choices...)
}
// Condition makes Child active only when Parent's sampled value equals one of Values.
type Condition struct {
	Child  string        // name of the dependent (child) parameter
	Parent string        // name of the parameter it depends on
	Values []interface{} // parent values for which the child is sampled
}

// Config maps parameter names to sampled values.
type Config map[string]interface{}
// SearchSpaceBuilder assembles a SearchSpace through chained Add* calls.
type SearchSpaceBuilder struct {
	space *SearchSpace
}

// NewSearchSpaceBuilder returns a builder wrapping an empty SearchSpace.
func NewSearchSpaceBuilder() *SearchSpaceBuilder {
	empty := &SearchSpace{
		Parameters: make(map[string]Parameter),
		Conditions: []Condition{},
	}
	return &SearchSpaceBuilder{space: empty}
}

// add registers p under name (replacing any earlier definition) and returns
// the builder for chaining.
func (b *SearchSpaceBuilder) add(name string, p Parameter) *SearchSpaceBuilder {
	b.space.Parameters[name] = p
	return b
}

// AddContinuous adds a real-valued parameter over [low, high].
func (b *SearchSpaceBuilder) AddContinuous(name string, low, high float64, logScale bool) *SearchSpaceBuilder {
	return b.add(name, NewContinuousParameter(name, low, high, logScale))
}

// AddDiscrete adds a stepped integer parameter over [low, high].
func (b *SearchSpaceBuilder) AddDiscrete(name string, low, high, step int) *SearchSpaceBuilder {
	return b.add(name, NewDiscreteParameter(name, low, high, step))
}

// AddCategorical adds a parameter restricted to the given choices.
func (b *SearchSpaceBuilder) AddCategorical(name string, choices []interface{}) *SearchSpaceBuilder {
	return b.add(name, NewCategoricalParameter(name, choices))
}

// AddCondition makes child active only when parent takes one of values.
func (b *SearchSpaceBuilder) AddCondition(child, parent string, values []interface{}) *SearchSpaceBuilder {
	cond := Condition{Child: child, Parent: parent, Values: values}
	b.space.Conditions = append(b.space.Conditions, cond)
	return b
}

// Build returns the assembled SearchSpace. The builder keeps a reference to
// it, so further Add* calls also modify the returned space.
func (b *SearchSpaceBuilder) Build() *SearchSpace {
	return b.space
}
// Sample draws one configuration from the search space.
//
// Unconditional parameters are sampled first, then every conditional
// parameter whose Condition is satisfied by its parent's sampled value.
// Parameters are visited in sorted name order. Go randomizes map iteration,
// so ranging over the map directly would consume rng values in a different
// order on every call and make results irreproducible for a fixed seed.
//
// Conditions are applied in declaration order, so a parent that is itself
// conditional must be declared before its children. A condition naming an
// undefined child parameter is skipped.
func (s *SearchSpace) Sample(rng *rand.Rand) Config {
	config := make(Config, len(s.Parameters))

	names := make([]string, 0, len(s.Parameters))
	for name := range s.Parameters {
		names = append(names, name)
	}
	// Small insertion sort; parameter counts are tiny and this keeps the
	// file's import set unchanged.
	for i := 1; i < len(names); i++ {
		for j := i; j > 0 && names[j] < names[j-1]; j-- {
			names[j], names[j-1] = names[j-1], names[j]
		}
	}

	for _, name := range names {
		if !s.hasCondition(name) {
			config[name] = s.Parameters[name].Sample(rng)
		}
	}

	for _, cond := range s.Conditions {
		child, ok := s.Parameters[cond.Child]
		if !ok || child == nil {
			continue // calling Sample on a nil Parameter would panic
		}
		if s.conditionSatisfied(config[cond.Parent], cond.Values) {
			config[cond.Child] = child.Sample(rng)
		}
	}
	return config
}
// hasCondition reports whether paramName is the child of any Condition.
func (s *SearchSpace) hasCondition(paramName string) bool {
	found := false
	for i := 0; i < len(s.Conditions) && !found; i++ {
		found = s.Conditions[i].Child == paramName
	}
	return found
}

// conditionSatisfied reports whether value equals (==) one of allowedValues.
func (s *SearchSpace) conditionSatisfied(value interface{}, allowedValues []interface{}) bool {
	for i := range allowedValues {
		if allowedValues[i] == value {
			return true
		}
	}
	return false
}
// CreateLLMSearchSpace returns an example search space for LLM training.
// The learning rate and gradient clip are sampled log-uniformly, and
// "momentum" is only sampled when optimizer == "sgd".
func CreateLLMSearchSpace() *SearchSpace {
	b := NewSearchSpaceBuilder()

	// Optimization
	b.AddContinuous("learning_rate", 1e-6, 1e-3, true) // log scale
	b.AddDiscrete("batch_size", 8, 128, 8)
	b.AddDiscrete("warmup_steps", 100, 5000, 100)
	b.AddContinuous("weight_decay", 0.0, 0.3, false)
	b.AddCategorical("optimizer", []interface{}{"adam", "adamw", "sgd"})
	b.AddCategorical("lr_scheduler", []interface{}{"linear", "cosine", "constant"})

	// Regularization and stability
	b.AddContinuous("dropout", 0.0, 0.5, false)
	b.AddContinuous("gradient_clip", 0.1, 10.0, true) // log scale

	// SGD only: momentum is a conditional parameter
	b.AddContinuous("momentum", 0.0, 0.99, false)
	b.AddCondition("momentum", "optimizer", []interface{}{"sgd"})

	return b.Build()
}
搜索算法实现
网格搜索与随机搜索
// grid_random_search.go
package hpo
import (
"context"
"math/rand"
"sync"
)
// GridSearch enumerates every combination of explicitly listed values.
type GridSearch struct {
	space      *SearchSpace
	gridPoints map[string][]interface{}
}

// NewGridSearch creates a grid search over gridPoints (parameter name ->
// values to try).
func NewGridSearch(space *SearchSpace, gridPoints map[string][]interface{}) *GridSearch {
	return &GridSearch{
		space:      space,
		gridPoints: gridPoints,
	}
}

// GenerateConfigs returns the Cartesian product of all grid values.
//
// Parameters are expanded in sorted name order, so the output order is
// deterministic; Go map iteration order is randomized and previously changed
// the order on every call. If any parameter has no values the product is
// empty; if gridPoints itself is empty a single empty Config is returned.
func (g *GridSearch) GenerateConfigs() []Config {
	paramNames := make([]string, 0, len(g.gridPoints))
	for name := range g.gridPoints {
		paramNames = append(paramNames, name)
	}
	// Small insertion sort; grids have few parameters and this keeps the
	// file's import set unchanged.
	for i := 1; i < len(paramNames); i++ {
		for j := i; j > 0 && paramNames[j] < paramNames[j-1]; j-- {
			paramNames[j], paramNames[j-1] = paramNames[j-1], paramNames[j]
		}
	}

	var configs []Config
	g.cartesianProduct(paramNames, 0, make(Config, len(paramNames)), &configs)
	return configs
}

// cartesianProduct assigns every value of names[idx] into current and
// recurses; at full depth it appends a copy of current to result.
func (g *GridSearch) cartesianProduct(names []string, idx int, current Config, result *[]Config) {
	if idx == len(names) {
		configCopy := make(Config, len(current))
		for k, v := range current {
			configCopy[k] = v
		}
		*result = append(*result, configCopy)
		return
	}
	name := names[idx]
	for _, value := range g.gridPoints[name] {
		current[name] = value
		g.cartesianProduct(names, idx+1, current, result)
	}
}
// RandomSearch samples configurations independently and uniformly from the
// search space. It implements Sampler (Suggest/Observe).
type RandomSearch struct {
	space     *SearchSpace
	rng       *rand.Rand
	numTrials int
}

// NewRandomSearch creates a random search that draws numTrials configurations
// in GenerateConfigs, using a generator seeded with seed.
func NewRandomSearch(space *SearchSpace, numTrials int, seed int64) *RandomSearch {
	return &RandomSearch{
		space:     space,
		rng:       rand.New(rand.NewSource(seed)),
		numTrials: numTrials,
	}
}

// GenerateConfigs draws numTrials configurations (none when numTrials <= 0,
// where make would otherwise panic on a negative length).
func (r *RandomSearch) GenerateConfigs() []Config {
	n := r.numTrials
	if n < 0 {
		n = 0
	}
	configs := make([]Config, n)
	for i := range configs {
		configs[i] = r.space.Sample(r.rng)
	}
	return configs
}

// Suggest draws the next configuration.
func (r *RandomSearch) Suggest() Config {
	return r.space.Sample(r.rng)
}

// Observe satisfies the Sampler interface. Random search does not learn from
// results, so the observation is ignored. Without this method *RandomSearch
// does not implement Sampler, although HPOController.createSampler returns it
// as one.
func (r *RandomSearch) Observe(config Config, value float64) {}
贝叶斯优化
// bayesian_optimization.go
package hpo
import (
"context"
"math"
"math/rand"
"sort"
)
// BayesianOptimization is a sequential model-based optimizer. It fits a
// surrogate model to past observations and picks the next configuration by
// maximizing an acquisition function. Lower objective values are better.
type BayesianOptimization struct {
	space           *SearchSpace
	rng             *rand.Rand
	surrogate       SurrogateModel
	acquisition     AcquisitionFunction
	observations    []Observation
	numInitSamples  int     // number of purely random suggestions before the surrogate is used
	explorationRate float64 // probability of a random suggestion once the surrogate is in use
}

// Observation is one evaluated configuration and its objective value.
type Observation struct {
	Config Config
	Value  float64 // objective value (smaller is better)
}

// SurrogateModel approximates the objective function from observations.
type SurrogateModel interface {
	// Fit trains the model on the observations passed in.
	Fit(observations []Observation) error
	// Predict returns the predicted mean and standard deviation at config.
	Predict(config Config) (mean, std float64)
}

// AcquisitionFunction scores a candidate from its predicted mean/std and the
// best value observed so far; higher scores are preferred.
type AcquisitionFunction interface {
	Evaluate(mean, std, bestValue float64) float64
}
// NewBayesianOptimization returns an optimizer with a Gaussian-process
// surrogate and Expected Improvement acquisition. The first numInitSamples
// suggestions are random; after that, 10% of suggestions stay random for
// exploration.
func NewBayesianOptimization(
	space *SearchSpace,
	seed int64,
	numInitSamples int,
) *BayesianOptimization {
	bo := &BayesianOptimization{
		space:          space,
		numInitSamples: numInitSamples,
	}
	bo.rng = rand.New(rand.NewSource(seed))
	bo.surrogate = NewGaussianProcess()
	bo.acquisition = NewExpectedImprovement()
	bo.observations = []Observation{}
	bo.explorationRate = 0.1
	return bo
}
// Suggest returns the next configuration to evaluate.
//
// It samples uniformly from the search space when nothing has been observed
// yet, while fewer than numInitSamples observations exist, and with
// probability explorationRate afterwards. Otherwise it fits the surrogate and
// maximizes the acquisition function. If fitting fails (for example a
// singular kernel matrix), it falls back to a random sample instead of
// optimizing against a stale or half-built model.
func (bo *BayesianOptimization) Suggest() Config {
	if len(bo.observations) == 0 || len(bo.observations) < bo.numInitSamples {
		return bo.space.Sample(bo.rng)
	}
	if bo.rng.Float64() < bo.explorationRate {
		return bo.space.Sample(bo.rng)
	}
	if err := bo.surrogate.Fit(bo.observations); err != nil {
		return bo.space.Sample(bo.rng)
	}
	return bo.optimizeAcquisition(bo.getBestValue())
}
// Observe records the objective value obtained for config.
func (bo *BayesianOptimization) Observe(config Config, value float64) {
	obs := Observation{Config: config, Value: value}
	bo.observations = append(bo.observations, obs)
}

// GetBest returns the observed configuration with the smallest value, or
// (nil, +Inf) when nothing has been observed. On ties the earliest wins.
func (bo *BayesianOptimization) GetBest() (Config, float64) {
	bestIdx := -1
	for i := range bo.observations {
		if bestIdx < 0 || bo.observations[i].Value < bo.observations[bestIdx].Value {
			bestIdx = i
		}
	}
	if bestIdx < 0 {
		return nil, math.Inf(1)
	}
	return bo.observations[bestIdx].Config, bo.observations[bestIdx].Value
}

// getBestValue returns the smallest observed value, or +Inf if there is none.
// NaN values never compare smaller and are therefore skipped.
func (bo *BayesianOptimization) getBestValue() float64 {
	best := math.Inf(1)
	for i := range bo.observations {
		if v := bo.observations[i].Value; v < best {
			best = v
		}
	}
	return best
}
// optimizeAcquisition maximizes the acquisition function by random search.
// It scores 1000 random candidates and returns the highest-scoring one. A
// fallback sample is drawn first and returned when no candidate scores above
// -Inf (for example when every score is NaN).
func (bo *BayesianOptimization) optimizeAcquisition(bestValue float64) Config {
	const numCandidates = 1000
	winner := bo.space.Sample(bo.rng)
	winnerScore := math.Inf(-1)
	for n := 0; n < numCandidates; n++ {
		candidate := bo.space.Sample(bo.rng)
		mean, std := bo.surrogate.Predict(candidate)
		if score := bo.acquisition.Evaluate(mean, std, bestValue); score > winnerScore {
			winner, winnerScore = candidate, score
		}
	}
	return winner
}
// GaussianProcess is a Gaussian-process surrogate with an RBF kernel and a
// zero prior mean.
//
// Only float64-valued config entries contribute to the kernel distance. Ints
// and categorical values are ignored, and inputs are not normalized, so
// lengthScale is expressed in the raw units of each parameter.
type GaussianProcess struct {
	observations []Observation
	lengthScale  float64
	noiseVar     float64
	K            [][]float64 // kernel matrix K(X, X) + noiseVar*I
	Kinv         [][]float64 // inverse of K
	alpha        []float64   // K^{-1} * y
}

// NewGaussianProcess returns a GP with lengthScale 1 and a small diagonal
// noise (jitter) term of 1e-6.
func NewGaussianProcess() *GaussianProcess {
	return &GaussianProcess{
		lengthScale: 1.0,
		noiseVar:    1e-6,
	}
}

// singularMatrixError is returned by Fit when the kernel matrix cannot be inverted.
type singularMatrixError struct{}

func (singularMatrixError) Error() string {
	return "hpo: gaussian process kernel matrix is singular"
}

// Fit conditions the GP on observations. On failure the model is reset to the
// prior (Predict then returns mean 0, std 1) and the error is returned.
func (gp *GaussianProcess) Fit(observations []Observation) error {
	n := len(observations)

	K := make([][]float64, n)
	for i := 0; i < n; i++ {
		K[i] = make([]float64, n)
		for j := 0; j < n; j++ {
			K[i][j] = gp.kernel(observations[i].Config, observations[j].Config)
		}
		K[i][i] += gp.noiseVar
	}

	Kinv, err := gp.invertMatrix(K)
	if err != nil {
		gp.observations, gp.K, gp.Kinv, gp.alpha = nil, nil, nil, nil
		return err
	}

	y := make([]float64, n)
	for i := range observations {
		y[i] = observations[i].Value
	}

	gp.observations = observations
	gp.K = K
	gp.Kinv = Kinv
	gp.alpha = gp.matVecMul(Kinv, y)
	return nil
}

// Predict returns the posterior mean and standard deviation at config.
// Before a successful Fit it returns the prior (0, 1).
func (gp *GaussianProcess) Predict(config Config) (mean, std float64) {
	n := len(gp.observations)
	if n == 0 {
		return 0, 1.0
	}

	// k_*: kernel between the query point and every observation.
	kStar := make([]float64, n)
	for i, obs := range gp.observations {
		kStar[i] = gp.kernel(config, obs.Config)
	}

	// Mean: k_*^T * alpha.
	for i := 0; i < n; i++ {
		mean += kStar[i] * gp.alpha[i]
	}

	// Variance: k_** - k_*^T * K^{-1} * k_*.
	v := gp.matVecMul(gp.Kinv, kStar)
	variance := gp.kernel(config, config)
	for i := 0; i < n; i++ {
		variance -= kStar[i] * v[i]
	}
	if variance < 0 {
		variance = 0 // clamp round-off
	}
	return mean, math.Sqrt(variance)
}

// kernel is the RBF kernel exp(-0.5 * ||x1 - x2||^2 / lengthScale^2), computed
// over keys of x1 whose values are float64 in both configs.
func (gp *GaussianProcess) kernel(x1, x2 Config) float64 {
	var dist float64
	for key := range x1 {
		v1, ok1 := x1[key].(float64)
		v2, ok2 := x2[key].(float64)
		if ok1 && ok2 {
			diff := (v1 - v2) / gp.lengthScale
			dist += diff * diff
		}
	}
	return math.Exp(-0.5 * dist)
}

// invertMatrix returns A^{-1} using Gauss-Jordan elimination with partial
// pivoting. A is not modified. It returns singularMatrixError when a pivot is
// numerically zero.
func (gp *GaussianProcess) invertMatrix(A [][]float64) ([][]float64, error) {
	n := len(A)

	// Augmented matrix [A | I].
	aug := make([][]float64, n)
	for i := range aug {
		aug[i] = make([]float64, 2*n)
		copy(aug[i], A[i])
		aug[i][n+i] = 1
	}

	for col := 0; col < n; col++ {
		// Partial pivoting: use the row with the largest magnitude in this column.
		pivot := col
		for r := col + 1; r < n; r++ {
			if math.Abs(aug[r][col]) > math.Abs(aug[pivot][col]) {
				pivot = r
			}
		}
		if math.Abs(aug[pivot][col]) < 1e-12 {
			return nil, singularMatrixError{}
		}
		aug[col], aug[pivot] = aug[pivot], aug[col]

		p := aug[col][col]
		for j := range aug[col] {
			aug[col][j] /= p
		}
		for r := 0; r < n; r++ {
			if r == col {
				continue
			}
			f := aug[r][col]
			if f == 0 {
				continue
			}
			for j := range aug[r] {
				aug[r][j] -= f * aug[col][j]
			}
		}
	}

	inv := make([][]float64, n)
	for i := range inv {
		inv[i] = append([]float64(nil), aug[i][n:]...)
	}
	return inv, nil
}

// matVecMul returns A * x for a square matrix A.
func (gp *GaussianProcess) matVecMul(A [][]float64, x []float64) []float64 {
	n := len(A)
	result := make([]float64, n)
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			result[i] += A[i][j] * x[j]
		}
	}
	return result
}
// ExpectedImprovement is the EI acquisition function for minimization:
//
//	EI = (best - mu - xi) * Phi(z) + sigma * phi(z),  z = (best - mu - xi) / sigma
//
// Phi and phi are the standard normal CDF and PDF, and xi is a small margin
// that favors exploration.
type ExpectedImprovement struct {
	xi float64 // exploration margin
}

// NewExpectedImprovement returns EI with xi = 0.01.
func NewExpectedImprovement() *ExpectedImprovement {
	ei := new(ExpectedImprovement)
	ei.xi = 0.01
	return ei
}

// Evaluate returns the expected improvement over bestValue for a prediction
// with the given mean and std. It returns 0 when std is exactly 0.
func (ei *ExpectedImprovement) Evaluate(mean, std, bestValue float64) float64 {
	if std == 0 {
		return 0
	}
	improvement := bestValue - mean - ei.xi
	z := improvement / std
	density := math.Exp(-z*z/2) / math.Sqrt(2*math.Pi)
	cumulative := 0.5 * (1 + math.Erf(z/math.Sqrt2))
	return improvement*cumulative + std*density
}
TPE (Tree-structured Parzen Estimator)
// tpe.go
package hpo
import (
"math"
"math/rand"
"sort"
)
// TPE implements the Tree-structured Parzen Estimator. Observations are split
// into the best gamma fraction ("good", density l(x)) and the rest ("bad",
// density g(x)); candidates drawn from l(x) are ranked by l(x)/g(x).
type TPE struct {
	space           *SearchSpace
	rng             *rand.Rand
	observations    []Observation
	gamma           float64 // fraction of observations treated as "good"
	numEICandidates int     // candidates drawn from l(x) per suggestion
	numInitSamples  int     // random suggestions before modelling starts
}

// NewTPE returns a TPE sampler with gamma = 0.25 (the best 25% form l(x)),
// 24 candidates per suggestion and 10 random warm-up samples.
func NewTPE(space *SearchSpace, seed int64) *TPE {
	t := &TPE{space: space, observations: []Observation{}}
	t.rng = rand.New(rand.NewSource(seed))
	t.gamma = 0.25
	t.numEICandidates = 24
	t.numInitSamples = 10
	return t
}
// Suggest returns the next configuration to evaluate.
//
// During warm-up (no observations, or fewer than numInitSamples) it samples
// uniformly. Afterwards it sorts the observations, takes the best gamma
// fraction (at least one, at most all) as "good", draws numEICandidates
// candidates from l(x), and returns the one that maximizes l(x)/g(x).
// Maximizing this ratio is equivalent to maximizing TPE's expected
// improvement.
//
// If g(x) underflows to 0 while l(x) > 0, the candidate is far from every bad
// point, so it gets +Inf instead of 0. With zero observations the split
// would index sorted[:1] out of range, and gamma > 1 would overrun the
// slice; both are guarded.
func (t *TPE) Suggest() Config {
	n := len(t.observations)
	if n == 0 || n < t.numInitSamples {
		return t.space.Sample(t.rng)
	}

	sorted := make([]Observation, n)
	copy(sorted, t.observations)
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].Value < sorted[j].Value
	})

	splitIdx := int(float64(n) * t.gamma)
	if splitIdx < 1 {
		splitIdx = 1
	}
	if splitIdx > n {
		splitIdx = n
	}
	good, bad := sorted[:splitIdx], sorted[splitIdx:]

	bestConfig := t.space.Sample(t.rng)
	bestScore := math.Inf(-1)
	for i := 0; i < t.numEICandidates; i++ {
		candidate := t.sampleFromGood(good)
		lx := t.estimateDensity(candidate, good)
		gx := t.estimateDensity(candidate, bad)

		var score float64
		switch {
		case gx > 0:
			score = lx / gx
		case lx > 0:
			score = math.Inf(1)
		}
		if score > bestScore {
			bestScore, bestConfig = score, candidate
		}
	}
	return bestConfig
}
// Observe records the objective value (smaller is better) obtained for config.
func (t *TPE) Observe(config Config, value float64) {
	var obs Observation
	obs.Config = config
	obs.Value = value
	t.observations = append(t.observations, obs)
}
// sampleFromGood draws a candidate configuration from the "good" density l(x).
//
// Continuous parameters use a Gaussian KDE over the good values. Discrete
// parameters resample an observed value, and categorical parameters are drawn
// by Laplace-smoothed frequency.
//
// Observations that lack a parameter (an inactive conditional parameter) or
// hold a value of another type are skipped. An unchecked type assertion on
// them would panic. Afterwards, conditional parameters are removed unless at
// least one of their Conditions is satisfied by the sampled parent value, so
// candidates match what SearchSpace.Sample would produce.
func (t *TPE) sampleFromGood(good []Observation) Config {
	config := make(Config, len(t.space.Parameters))
	for name, param := range t.space.Parameters {
		switch p := param.(type) {
		case *ContinuousParameter:
			values := make([]float64, 0, len(good))
			for _, obs := range good {
				if v, ok := obs.Config[name].(float64); ok {
					values = append(values, v)
				}
			}
			config[name] = t.sampleKDE(values, p.Low(), p.High())
		case *CategoricalParameter:
			counts := make(map[interface{}]int)
			for _, obs := range good {
				if v, ok := obs.Config[name]; ok {
					counts[v]++
				}
			}
			config[name] = t.sampleCategorical(counts, p.Choices())
		case *DiscreteParameter:
			values := make([]int, 0, len(good))
			for _, obs := range good {
				if v, ok := obs.Config[name].(int); ok {
					values = append(values, v)
				}
			}
			config[name] = t.sampleDiscreteKDE(values, p.low, p.high)
		}
	}

	// A child is active if any of its conditions holds.
	active := make(map[string]bool)
	for _, cond := range t.space.Conditions {
		if t.space.conditionSatisfied(config[cond.Parent], cond.Values) {
			active[cond.Child] = true
		}
	}
	for _, cond := range t.space.Conditions {
		if !active[cond.Child] {
			delete(config, cond.Child)
		}
	}
	return config
}
// estimateDensity returns the product, over parameters, of the density of
// config's value under the given observations. This is a Parzen estimate
// that treats parameters as independent.
//
// Parameters missing from config (inactive conditional parameters) are
// skipped, as are observations that lack the parameter; an unchecked type
// assertion on those would panic. Discrete parameters use the same Gaussian
// KDE as continuous ones, on their integer values. A single-valued discrete
// range carries no information and is skipped.
func (t *TPE) estimateDensity(config Config, observations []Observation) float64 {
	density := 1.0
	for name, param := range t.space.Parameters {
		x, present := config[name]
		if !present {
			continue
		}
		switch p := param.(type) {
		case *ContinuousParameter:
			xv, ok := x.(float64)
			if !ok {
				continue
			}
			values := make([]float64, 0, len(observations))
			for _, obs := range observations {
				if v, ok := obs.Config[name].(float64); ok {
					values = append(values, v)
				}
			}
			density *= t.kdeProb(xv, values, p.Low(), p.High())
		case *DiscreteParameter:
			xv, ok := x.(int)
			if !ok || p.high <= p.low {
				continue
			}
			values := make([]float64, 0, len(observations))
			for _, obs := range observations {
				if v, ok := obs.Config[name].(int); ok {
					values = append(values, float64(v))
				}
			}
			density *= t.kdeProb(float64(xv), values, float64(p.low), float64(p.high))
		case *CategoricalParameter:
			counts := make(map[interface{}]int)
			for _, obs := range observations {
				if v, ok := obs.Config[name]; ok {
					counts[v]++
				}
			}
			density *= t.categoricalProb(x, counts, len(p.Choices()))
		}
	}
	return density
}
// kdeProb evaluates a Gaussian kernel density estimate, built from samples, at x.
//
// The bandwidth is the normal-reference rule of thumb h = 1.06 * sigma * n^(-1/5),
// falling back to 10% of the range when the samples have no spread. With no
// samples it returns the uniform density 1/(high-low). A degenerate range
// (high <= low) carries no information and would divide by zero, so it
// returns 1, which is neutral in the product formed by estimateDensity.
func (t *TPE) kdeProb(x float64, samples []float64, low, high float64) float64 {
	width := high - low
	if !(width > 0) {
		return 1.0
	}
	if len(samples) == 0 {
		return 1.0 / width
	}

	h := 1.06 * t.std(samples) * math.Pow(float64(len(samples)), -0.2)
	if h == 0 {
		h = 0.1 * width
	}

	norm := h * math.Sqrt(2*math.Pi)
	prob := 0.0
	for _, s := range samples {
		u := (x - s) / h
		prob += math.Exp(-0.5*u*u) / norm
	}
	return prob / float64(len(samples))
}
// sampleKDE draws from the KDE defined by samples. It picks one sample
// uniformly, adds Gaussian noise whose bandwidth follows the same rule as
// kdeProb (1.06 * sigma * n^(-1/5), or 10% of the range when sigma is 0), and
// clamps the result to [low, high]. With no samples it draws uniformly from
// [low, high].
func (t *TPE) sampleKDE(samples []float64, low, high float64) float64 {
	n := len(samples)
	if n == 0 {
		return low + t.rng.Float64()*(high-low)
	}

	center := samples[t.rng.Intn(n)]
	bandwidth := 1.06 * t.std(samples) * math.Pow(float64(n), -0.2)
	if bandwidth == 0 {
		bandwidth = 0.1 * (high - low)
	}

	x := center + t.rng.NormFloat64()*bandwidth
	if x < low {
		x = low
	}
	if x > high {
		x = high
	}
	return x
}
// sampleDiscreteKDE resamples one of the observed integer values uniformly;
// no kernel smoothing is applied. With no samples it draws uniformly from
// every integer in [low, high], ignoring the parameter's step.
func (t *TPE) sampleDiscreteKDE(samples []int, low, high int) int {
	if n := len(samples); n > 0 {
		return samples[t.rng.Intn(n)]
	}
	return low + t.rng.Intn(high-low+1)
}
// categoricalProb returns the Laplace-smoothed frequency of value, computed as
// (count+1) / (total+numChoices), or 1/numChoices when there are no counts.
func (t *TPE) categoricalProb(value interface{}, counts map[interface{}]int, numChoices int) float64 {
	var total int
	for _, c := range counts {
		total += c
	}
	k := float64(numChoices)
	if total == 0 {
		return 1.0 / k
	}
	return (float64(counts[value]) + 1) / (float64(total) + k)
}
// sampleCategorical draws one of choices with probability proportional to
// count+1 (Laplace-smoothed frequencies).
//
// The normalizer is summed over choices only. Counts for values that are not
// in choices (for example nil, tallied from observations where a conditional
// parameter was inactive) must not inflate the total. If they did, most draws
// would fall past the cumulative sum and land on the last choice. It returns
// nil when there are no choices.
func (t *TPE) sampleCategorical(counts map[interface{}]int, choices []interface{}) interface{} {
	if len(choices) == 0 {
		return nil
	}
	if len(counts) == 0 {
		return choices[t.rng.Intn(len(choices))]
	}

	total := 0.0
	for _, choice := range choices {
		total += float64(counts[choice]) + 1
	}

	r := t.rng.Float64() * total
	cumSum := 0.0
	for _, choice := range choices {
		cumSum += float64(counts[choice]) + 1
		if r < cumSum {
			return choice
		}
	}
	return choices[len(choices)-1]
}
// std returns the sample standard deviation (n-1 denominator) of values, or 0
// when fewer than two values are given.
func (t *TPE) std(values []float64) float64 {
	n := len(values)
	if n < 2 {
		return 0
	}
	var sum float64
	for i := 0; i < n; i++ {
		sum += values[i]
	}
	mean := sum / float64(n)

	var squares float64
	for i := 0; i < n; i++ {
		d := values[i] - mean
		squares += d * d
	}
	return math.Sqrt(squares / float64(n-1))
}
Hyperband 和 ASHA
// hyperband.go
package hpo
import (
"context"
"math"
"sort"
"strconv"
"sync/atomic"
)
// Hyperband is a Hyperband scheduler. It runs successive-halving brackets
// that trade off the number of configurations against the resource given to
// each.
type Hyperband struct {
	space           *SearchSpace
	sampler         Sampler
	maxResource     int     // maximum resource per configuration (e.g. epochs)
	reductionFactor float64 // eta: keep roughly 1/eta of configurations per round (commonly 3)
	minResource     int     // minimum resource per configuration
}

// Sampler proposes configurations and learns from their results.
type Sampler interface {
	// Suggest returns the next configuration to evaluate.
	Suggest() Config
	// Observe records the objective value (smaller is better) for config.
	Observe(config Config, value float64)
}

// Bracket holds the configurations, resources and results of one bracket.
// NOTE(review): not referenced by Hyperband.Run in this file.
type Bracket struct {
	Configs   []Config
	Resources []int
	Results   []float64
}
// NewHyperband creates a Hyperband scheduler with eta = 3 and a minimum
// resource of 1.
func NewHyperband(space *SearchSpace, sampler Sampler, maxResource int) *Hyperband {
	h := new(Hyperband)
	h.space = space
	h.sampler = sampler
	h.maxResource = maxResource
	h.reductionFactor = 3
	h.minResource = 1
	return h
}
// Run executes Hyperband and returns the best configuration and value seen.
//
// For each bracket s = s_max..0 it samples n configurations, evaluates them
// with resource r, keeps the best 1/eta, multiplies the resource by eta, and
// repeats until the bracket's final rung (about maxResource). Every
// successful result is reported to the sampler, and values are treated as
// "smaller is better". Failed evaluations count as +Inf and are not reported.
// reductionFactor must be > 1; otherwise a single bracket is run.
//
// Logs and powers are nudged by a small epsilon before rounding. For example
// log(243)/log(3) evaluates to 4.999… and 27*3^-3 to 0.999…; plain truncation
// would drop the most aggressive bracket or request a resource of 0.
func (h *Hyperband) Run(ctx context.Context, objective ObjectiveFunc) (Config, float64, error) {
	const eps = 1e-9

	eta := h.reductionFactor
	minRes := h.minResource
	if minRes < 1 {
		minRes = 1
	}

	sMax := 0
	if eta > 1 {
		sMax = int(math.Floor(math.Log(float64(h.maxResource)/float64(minRes))/math.Log(eta) + eps))
		if sMax < 0 {
			sMax = 0
		}
	}

	type result struct {
		config Config
		value  float64
	}

	var bestConfig Config
	bestValue := math.Inf(1)

	for s := sMax; s >= 0; s-- {
		// Initial number of configurations and per-configuration resource.
		n := int(math.Ceil(float64(sMax+1)/float64(s+1)*math.Pow(eta, float64(s)) - eps))
		r := int(math.Floor(float64(h.maxResource)*math.Pow(eta, float64(-s)) + eps))
		if r < minRes {
			r = minRes
		}

		configs := make([]Config, n)
		for i := range configs {
			configs[i] = h.sampler.Suggest()
		}

		// Successive halving.
		for i := 0; i <= s; i++ {
			resource := int(math.Floor(float64(r)*math.Pow(eta, float64(i)) + eps))
			if h.maxResource >= minRes && resource > h.maxResource {
				resource = h.maxResource
			}

			results := make([]result, len(configs))
			for j, config := range configs {
				select {
				case <-ctx.Done():
					return bestConfig, bestValue, ctx.Err()
				default:
				}
				value, err := objective(ctx, config, resource)
				if err != nil {
					results[j] = result{config, math.Inf(1)}
					continue
				}
				results[j] = result{config, value}
				h.sampler.Observe(config, value)
				if value < bestValue {
					bestValue, bestConfig = value, config
				}
			}

			if i == s {
				break // final rung of this bracket: nothing left to promote into
			}

			// Keep the best 1/eta of the configurations actually evaluated.
			sort.Slice(results, func(a, b int) bool {
				return results[a].value < results[b].value
			})
			keep := int(math.Floor(float64(len(results))/eta + eps))
			if keep < 1 {
				keep = 1
			}
			if keep > len(results) {
				keep = len(results)
			}
			configs = make([]Config, keep)
			for j := range configs {
				configs[j] = results[j].config
			}
		}
	}
	return bestConfig, bestValue, nil
}
// ASHA implements the Asynchronous Successive Halving Algorithm. Workers ask
// for work with GetNextConfig and report results with Report. GetNextConfig
// promotes a configuration to the next rung whenever findPromotable reports
// one; otherwise it starts a new configuration at the lowest rung.
type ASHA struct {
	space           *SearchSpace
	sampler         Sampler
	maxResource     int
	reductionFactor float64              // eta
	minResource     int                  // resource of rung 0
	brackets        map[int]*ASHABracket // rung index -> results recorded at that rung
	promotionRule   string               // "sync" or "async"; NOTE(review): not read in this file
}

// ASHABracket holds the results recorded at one rung.
type ASHABracket struct {
	Rung     int
	Configs  []ASHAConfig
	Promoted map[string]bool // IDs already promoted out of this rung
}

// ASHAConfig is one configuration tracked by ASHA.
type ASHAConfig struct {
	ID     string
	Config Config
	Values map[int]float64 // resource -> objective value
}
// NewASHA creates an ASHA scheduler with eta = 4, a minimum resource of 1 and
// the "async" promotion rule.
func NewASHA(space *SearchSpace, sampler Sampler, maxResource int) *ASHA {
	a := &ASHA{space: space, sampler: sampler, maxResource: maxResource}
	a.reductionFactor = 4
	a.minResource = 1
	a.brackets = map[int]*ASHABracket{}
	a.promotionRule = "async"
	return a
}
// GetNextConfig returns the next job to run: a configuration, the resource to
// train it with, and its ID.
//
// As in ASHA, rungs are scanned from the highest promotable rung down to
// rung 0, so configurations closest to full training advance first. The
// first promotable configuration found is marked promoted and returned with
// the next rung's resource, minResource * eta^(rung+1). If nothing can be
// promoted, a fresh configuration is drawn from the sampler at minResource.
// The rung scan is capped so that minResource <= 0 or eta <= 1 cannot cause
// an endless loop.
func (a *ASHA) GetNextConfig() (Config, int, string) {
	resourceAt := func(rung int) int {
		return a.minResource * int(math.Pow(a.reductionFactor, float64(rung)))
	}

	// Highest rung whose resource still fits within maxResource.
	const maxRungs = 64
	topRung := 0
	for topRung < maxRungs && resourceAt(topRung+1) <= a.maxResource {
		topRung++
	}

	for rung := topRung - 1; rung >= 0; rung-- {
		bracket := a.brackets[rung]
		if bracket == nil {
			continue
		}
		if promotable := a.findPromotable(bracket); len(promotable) > 0 {
			candidate := promotable[0]
			bracket.Promoted[candidate.ID] = true
			return candidate.Config, resourceAt(rung + 1), candidate.ID
		}
	}

	// Nothing to promote: start a new configuration at the lowest rung.
	newConfig := a.sampler.Suggest()
	return newConfig, a.minResource, generateConfigID()
}
// Report records that configuration configID was evaluated with resource and
// obtained value, and forwards the result to the sampler.
//
// The rung index is log_eta(resource/minResource) rounded to the nearest
// integer. Resources are produced as minResource*eta^k, and truncating the
// floating-point log (e.g. log(243)/log(3) = 4.999…) would file the result
// under the wrong rung. Resources below minResource go to rung 0. The brackets
// map is created lazily, so a zero-value ASHA does not panic.
func (a *ASHA) Report(configID string, config Config, resource int, value float64) {
	rung := int(math.Round(math.Log(float64(resource)/float64(a.minResource)) / math.Log(a.reductionFactor)))
	if rung < 0 {
		rung = 0
	}

	if a.brackets == nil {
		a.brackets = make(map[int]*ASHABracket)
	}
	bracket := a.brackets[rung]
	if bracket == nil {
		bracket = &ASHABracket{
			Rung:     rung,
			Configs:  []ASHAConfig{},
			Promoted: make(map[string]bool),
		}
		a.brackets[rung] = bracket
	}

	// Update an existing entry for this ID, or add a new one.
	found := false
	for i := range bracket.Configs {
		if bracket.Configs[i].ID == configID {
			if bracket.Configs[i].Values == nil {
				bracket.Configs[i].Values = make(map[int]float64)
			}
			bracket.Configs[i].Values[resource] = value
			found = true
			break
		}
	}
	if !found {
		bracket.Configs = append(bracket.Configs, ASHAConfig{
			ID:     configID,
			Config: config,
			Values: map[int]float64{resource: value},
		})
	}

	a.sampler.Observe(config, value)
}
// findPromotable returns the configurations ASHA may promote out of bracket's
// rung, best first. These are the entries ranked in the top floor(n/eta) of
// all n results recorded at this rung's resource that have not yet been
// promoted.
//
// The top-k is taken over all completed results, including those already
// promoted. Taking it over only the unpromoted remainder would keep
// promoting configurations that are no longer in the rung's top 1/eta.
func (a *ASHA) findPromotable(bracket *ASHABracket) []ASHAConfig {
	resource := a.minResource * int(math.Pow(a.reductionFactor, float64(bracket.Rung)))

	completed := make([]ASHAConfig, 0, len(bracket.Configs))
	for _, config := range bracket.Configs {
		if _, ok := config.Values[resource]; ok {
			completed = append(completed, config)
		}
	}

	k := int(float64(len(completed)) / a.reductionFactor)
	if k < 1 {
		return nil
	}
	if k > len(completed) {
		k = len(completed) // only possible when eta < 1
	}

	sort.Slice(completed, func(i, j int) bool {
		return completed[i].Values[resource] < completed[j].Values[resource]
	})

	var promotable []ASHAConfig
	for _, config := range completed[:k] {
		if !bracket.Promoted[config.ID] {
			promotable = append(promotable, config)
		}
	}
	return promotable
}
// GetBest returns the configuration with the lowest value recorded at any
// rung and resource, or (nil, +Inf) if nothing has been reported.
// NOTE(review): values from low-resource rungs are compared directly with
// full-resource ones.
func (a *ASHA) GetBest() (Config, float64) {
	bestValue := math.Inf(1)
	var bestConfig Config
	for _, b := range a.brackets {
		for i := range b.Configs {
			c := &b.Configs[i]
			for _, v := range c.Values {
				if v < bestValue {
					bestConfig, bestValue = c.Config, v
				}
			}
		}
	}
	return bestConfig, bestValue
}
// ObjectiveFunc trains and evaluates config with the given resource budget
// (e.g. epochs) and returns the objective value (smaller is better).
type ObjectiveFunc func(ctx context.Context, config Config, resource int) (float64, error)

// max returns the larger of a and b.
// NOTE: shadows the Go 1.21 built-in max within this package.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// configIDCounter is incremented atomically by generateConfigID.
var configIDCounter uint64

// generateConfigID returns a process-unique configuration ID such as "cfg_1".
// IDs must be distinct: ASHA.Report merges results that share an ID, so a
// constant ID would collapse every configuration into a single entry.
// Safe for concurrent use.
func generateConfigID() string {
	return "cfg_" + strconv.FormatUint(atomic.AddUint64(&configIDCounter, 1), 10)
}
HPO 服务实现
HPO 服务架构
┌─────────────────────────────────────────────────────────────────┐
│ HPO 服务架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ HPO Controller │ │
│ │ • Study 管理 │ │
│ │ • Trial 调度 │ │
│ │ • 资源分配 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────┼─────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │
│ │ Sampler │ │ Scheduler │ │ Pruner │ │
│ │ Service │ │ Service │ │ Service │ │
│ │ │ │ │ │ │ │
│ │ • Random │ │ • FIFO │ │ • Median │ │
│ │ • TPE │ │ • Priority│ │ • SHA │ │
│ │ • BO │ │ • Fair │ │ • ASHA │ │
│ └───────────┘ └───────────┘ └───────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Trial Executor │ │
│ │ • Kubernetes Job │ │
│ │ • 分布式训练支持 │ │
│ │ • 检查点恢复 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
HPO 控制器
// hpo_controller.go
package hpo
import (
"context"
"fmt"
"sync"
"time"
)
// Study is one hyperparameter-optimization job: a search space, an
// algorithm, limits, and the trials run so far.
type Study struct {
	ID          string       `json:"id"`
	Name        string       `json:"name"`
	Objective   string       `json:"objective"` // minimize, maximize
	SearchSpace *SearchSpace `json:"search_space"`
	Algorithm   string       `json:"algorithm"` // random, tpe, bayesian, hyperband; NOTE(review): createSampler has no "hyperband" case and falls back to random search
	Status      StudyStatus  `json:"status"`
	// Limits
	MaxTrials   int           `json:"max_trials"`
	MaxDuration time.Duration `json:"max_duration"` // 0 disables the time limit
	Parallelism int           `json:"parallelism"`  // maximum number of concurrently running trials
	// Hyperband/ASHA settings
	MaxResource     int     `json:"max_resource,omitempty"`
	ReductionFactor float64 `json:"reduction_factor,omitempty"`
	// Runtime state
	Trials    []*Trial   `json:"trials"`
	BestTrial *Trial     `json:"best_trial"`
	StartTime *time.Time `json:"start_time"`
	EndTime   *time.Time `json:"end_time"`
	CreatedAt time.Time  `json:"created_at"`
}

// StudyStatus is the lifecycle state of a Study.
type StudyStatus string

const (
	StudyStatusPending   StudyStatus = "pending"
	StudyStatusRunning   StudyStatus = "running"
	StudyStatusCompleted StudyStatus = "completed"
	StudyStatusFailed    StudyStatus = "failed"
	StudyStatusStopped   StudyStatus = "stopped"
)

// Trial is one evaluation of a configuration within a Study.
type Trial struct {
	ID       string      `json:"id"`
	StudyID  string      `json:"study_id"`
	Number   int         `json:"number"`
	Config   Config      `json:"config"`
	Status   TrialStatus `json:"status"`
	Value    *float64    `json:"value,omitempty"` // final metric value, in the study's own direction (runTrial negates it only when feeding the sampler)
	Resource int         `json:"resource"`        // resource budget (e.g. epochs); runTrial sets it from Study.MaxResource
	// Intermediate results
	IntermediateValues []IntermediateValue `json:"intermediate_values"`
	// Execution details
	JobName   string     `json:"job_name"`
	StartTime *time.Time `json:"start_time"`
	EndTime   *time.Time `json:"end_time"`
	Error     string     `json:"error,omitempty"`
}

// TrialStatus is the lifecycle state of a Trial.
type TrialStatus string

const (
	TrialStatusPending   TrialStatus = "pending"
	TrialStatusRunning   TrialStatus = "running"
	TrialStatusCompleted TrialStatus = "completed"
	TrialStatusFailed    TrialStatus = "failed"
	TrialStatusPruned    TrialStatus = "pruned"
)

// IntermediateValue is a metric reported during training at a given step.
type IntermediateValue struct {
	Step  int       `json:"step"`
	Value float64   `json:"value"`
	Time  time.Time `json:"time"`
}
// HPOController manages studies. It persists them, runs their trials through
// a TrialExecutor, and consults per-study samplers and pruners.
type HPOController struct {
	// Persistence
	studyRepo StudyRepository
	trialRepo TrialRepository
	// Runs trials
	executor TrialExecutor
	// Algorithm instances. NOTE(review): not read anywhere in this file;
	// per-study instances live in studyRunner.
	samplers map[string]Sampler
	pruners  map[string]Pruner
	// Runtime state: running studies keyed by study ID, guarded by mu.
	activeStudies map[string]*studyRunner
	mu            sync.RWMutex
}

// studyRunner is the runtime state of one running study; cancel stops its
// scheduling loop (runStudy).
type studyRunner struct {
	study   *Study
	sampler Sampler
	pruner  Pruner
	cancel  context.CancelFunc
}

// Pruner decides whether a running trial should be stopped early.
type Pruner interface {
	// ShouldPrune reports whether trial should be stopped, based on its
	// recorded state (e.g. intermediate values).
	ShouldPrune(trial *Trial) bool
}

// TrialExecutor runs trials.
type TrialExecutor interface {
	// Execute runs trial for study. runTrial treats its return as the end of
	// the trial, so it is expected to block until the trial finishes.
	Execute(ctx context.Context, trial *Trial, study *Study) error
	// Stop stops a running trial. NOTE(review): not called in this file.
	Stop(trial *Trial) error
}
// NewHPOController wires a controller to its repositories and trial executor.
func NewHPOController(
	studyRepo StudyRepository,
	trialRepo TrialRepository,
	executor TrialExecutor,
) *HPOController {
	c := &HPOController{
		studyRepo: studyRepo,
		trialRepo: trialRepo,
		executor:  executor,
	}
	c.samplers = map[string]Sampler{}
	c.pruners = map[string]Pruner{}
	c.activeStudies = map[string]*studyRunner{}
	return c
}
// CreateStudy assigns an ID, marks the study pending with no trials, stamps
// CreatedAt, and persists it.
func (c *HPOController) CreateStudy(ctx context.Context, study *Study) error {
	study.ID, study.Status = generateStudyID(), StudyStatusPending
	study.Trials = []*Trial{}
	study.CreatedAt = time.Now()

	err := c.studyRepo.Create(ctx, study)
	if err == nil {
		return nil
	}
	return fmt.Errorf("create study: %w", err)
}
// StartStudy loads a study, builds its sampler and pruner, marks it running,
// and launches its scheduling loop in the background.
//
// The loop's context is deliberately not derived from ctx. ctx usually
// belongs to the request that started the study, and cancelling it when that
// request returns would silently kill the study; the runner's cancel function
// stops it instead. Starting a study that is already running returns an error,
// as does failing to persist the running status.
func (c *HPOController) StartStudy(ctx context.Context, studyID string) error {
	study, err := c.studyRepo.Get(ctx, studyID)
	if err != nil {
		return err
	}

	runCtx, cancel := context.WithCancel(context.Background())
	runner := &studyRunner{
		study:   study,
		sampler: c.createSampler(study),
		pruner:  c.createPruner(study),
		cancel:  cancel,
	}

	c.mu.Lock()
	if _, running := c.activeStudies[studyID]; running {
		c.mu.Unlock()
		cancel()
		return fmt.Errorf("start study %q: already running", studyID)
	}
	c.activeStudies[studyID] = runner
	c.mu.Unlock()

	now := time.Now()
	study.Status = StudyStatusRunning
	study.StartTime = &now
	if err := c.studyRepo.Update(ctx, study); err != nil {
		c.mu.Lock()
		delete(c.activeStudies, studyID)
		c.mu.Unlock()
		cancel()
		return fmt.Errorf("start study %q: %w", studyID, err)
	}

	go c.runStudy(runCtx, runner)
	return nil
}
// runStudy is the scheduling loop for one study. It launches trials, with at
// most study.Parallelism (at least 1) running concurrently, until MaxTrials
// trials have started, MaxDuration has elapsed, or ctx is cancelled. It then
// waits for in-flight trials and records the final status: stopped if ctx was
// cancelled, completed otherwise (unless the status was already changed).
func (c *HPOController) runStudy(ctx context.Context, runner *studyRunner) {
	study := runner.study

	// Registered first, so it runs last (after the wg.Wait deferred below).
	defer func() {
		c.mu.Lock()
		delete(c.activeStudies, study.ID)
		c.mu.Unlock()

		now := time.Now()
		study.EndTime = &now
		if study.Status == StudyStatusRunning {
			if ctx.Err() != nil {
				study.Status = StudyStatusStopped
			} else {
				study.Status = StudyStatusCompleted
			}
		}
		// Release the run context now that the loop is done.
		if runner.cancel != nil {
			runner.cancel()
		}
		// Best effort: ctx may be cancelled, so persist with a fresh context.
		_ = c.studyRepo.Update(context.Background(), study)
	}()

	parallelism := study.Parallelism
	if parallelism < 1 {
		parallelism = 1 // a zero-capacity semaphore would block forever
	}
	sem := make(chan struct{}, parallelism)

	var wg sync.WaitGroup
	defer wg.Wait()

	started := time.Now()
	if study.StartTime != nil {
		started = *study.StartTime
	}
	timedOut := func() bool {
		return study.MaxDuration > 0 && time.Since(started) > study.MaxDuration
	}

	for trialNumber := 1; trialNumber <= study.MaxTrials; trialNumber++ {
		if ctx.Err() != nil || timedOut() {
			return
		}
		// Wait for a free slot, but give up promptly on cancellation.
		select {
		case <-ctx.Done():
			return
		case sem <- struct{}{}:
		}
		// Waiting for a slot may have taken a while; re-check the time limit.
		if timedOut() {
			<-sem
			return
		}

		wg.Add(1)
		go func(num int) {
			defer wg.Done()
			defer func() { <-sem }()
			c.runTrial(ctx, runner, num)
		}(trialNumber)
	}
}
// runTrial creates, executes and finalizes one trial.
//
// runStudy calls runTrial from several goroutines, but samplers (random, TPE,
// BO) keep unsynchronized state and a *rand.Rand. Every sampler call and the
// best-trial update are therefore serialized with c.mu. The repository update
// inside updateBestTrial also runs under that lock.
//
// After Execute returns, the trial is re-read from the repository. This picks
// up the value and pruned status written by CompleteTrial and
// ReportIntermediateValue, which may operate on a different copy of the trial
// than the one this goroutine holds. A pruned trial stays pruned even if the
// executor reports an error for the stopped job.
func (c *HPOController) runTrial(ctx context.Context, runner *studyRunner, number int) {
	study := runner.study

	c.mu.Lock()
	config := runner.sampler.Suggest()
	c.mu.Unlock()

	trial := &Trial{
		ID:       generateTrialID(),
		StudyID:  study.ID,
		Number:   number,
		Config:   config,
		Status:   TrialStatusPending,
		Resource: study.MaxResource,
	}
	if err := c.trialRepo.Create(ctx, trial); err != nil {
		return
	}

	start := time.Now()
	trial.Status = TrialStatusRunning
	trial.StartTime = &start
	_ = c.trialRepo.Update(ctx, trial) // best effort; the final Update below records the outcome

	execErr := c.executor.Execute(ctx, trial, study)

	// Merge state recorded through CompleteTrial / ReportIntermediateValue.
	if latest, err := c.trialRepo.Get(ctx, trial.ID); err == nil && latest != nil {
		if trial.Value == nil {
			trial.Value = latest.Value
		}
		if latest.Status == TrialStatusPruned {
			trial.Status = TrialStatusPruned
		}
		if len(latest.IntermediateValues) > len(trial.IntermediateValues) {
			trial.IntermediateValues = latest.IntermediateValues
		}
	}

	end := time.Now()
	trial.EndTime = &end
	if execErr != nil {
		trial.Error = execErr.Error()
	}
	switch {
	case trial.Status == TrialStatusPruned:
		// keep the pruned status
	case execErr != nil:
		trial.Status = TrialStatusFailed
	default:
		trial.Status = TrialStatusCompleted
	}
	_ = c.trialRepo.Update(ctx, trial)

	c.mu.Lock()
	defer c.mu.Unlock()

	// Samplers minimize, so maximization objectives are negated.
	if trial.Value != nil {
		value := *trial.Value
		if study.Objective == "maximize" {
			value = -value
		}
		runner.sampler.Observe(config, value)
	}
	c.updateBestTrial(ctx, study, trial)
}
// ReportIntermediateValue records an intermediate metric for a trial and
// reports whether training should continue. It returns false once the trial
// has been pruned.
//
// A trial that is already pruned keeps returning false without recording
// further values. Repository update errors are returned to the caller instead
// of being dropped; the bool is only meaningful when the error is nil.
func (c *HPOController) ReportIntermediateValue(
	ctx context.Context,
	trialID string,
	step int,
	value float64,
) (bool, error) {
	trial, err := c.trialRepo.Get(ctx, trialID)
	if err != nil {
		return false, err
	}
	if trial.Status == TrialStatusPruned {
		return false, nil
	}

	trial.IntermediateValues = append(trial.IntermediateValues, IntermediateValue{
		Step:  step,
		Value: value,
		Time:  time.Now(),
	})
	if err := c.trialRepo.Update(ctx, trial); err != nil {
		return false, fmt.Errorf("record intermediate value for trial %q: %w", trialID, err)
	}

	c.mu.RLock()
	runner, ok := c.activeStudies[trial.StudyID]
	c.mu.RUnlock()

	if !ok || runner.pruner == nil || !runner.pruner.ShouldPrune(trial) {
		return true, nil
	}

	trial.Status = TrialStatusPruned
	if err := c.trialRepo.Update(ctx, trial); err != nil {
		return false, fmt.Errorf("mark trial %q pruned: %w", trialID, err)
	}
	return false, nil
}
// CompleteTrial stores the final objective value for a trial, marks it
// completed, stamps its end time, and persists it.
func (c *HPOController) CompleteTrial(ctx context.Context, trialID string, value float64) error {
	trial, err := c.trialRepo.Get(ctx, trialID)
	if err != nil {
		return err
	}
	finished := time.Now()
	trial.Value, trial.Status, trial.EndTime = &value, TrialStatusCompleted, &finished
	return c.trialRepo.Update(ctx, trial)
}
// updateBestTrial makes trial the study's best trial when it has an objective
// value that strictly beats the current best (according to study.Objective),
// and then persists the study. Trials without a value are ignored.
//
// runTrial runs in one goroutine per concurrent trial. The compare-and-replace
// of study.BestTrial and the persist are therefore serialized under c.mu.
func (c *HPOController) updateBestTrial(ctx context.Context, study *Study, trial *Trial) {
	if trial.Value == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	best := study.BestTrial
	// A stored best trial without a value cannot be compared, so any valued trial replaces it.
	if best != nil && best.Value != nil && !isBetter(*trial.Value, *best.Value, study.Objective) {
		return
	}
	study.BestTrial = trial
	c.studyRepo.Update(ctx, study)
}
// isBetter reports whether candidate strictly improves on incumbent.
// The objective "maximize" prefers larger values. Any other objective string
// is treated as minimization. Ties are never an improvement.
func isBetter(candidate, incumbent float64, objective string) bool {
	switch objective {
	case "maximize":
		return candidate > incumbent
	default:
		return candidate < incumbent
	}
}
// createSampler builds the Sampler named by study.Algorithm, seeded from the
// current wall-clock time in nanoseconds.
//
// "tpe" and "bayesian" select their dedicated samplers. "random" and every
// other value fall back to random search; this includes "hyperband" and
// "asha", which the Study CRD's enum accepts.
func (c *HPOController) createSampler(study *Study) Sampler {
	space := study.SearchSpace
	seed := time.Now().UnixNano()

	switch study.Algorithm {
	case "tpe":
		return NewTPE(space, seed)
	case "bayesian":
		// NOTE(review): the third argument (10) is presumably the number of
		// initial random evaluations; confirm against NewBayesianOptimization.
		return NewBayesianOptimization(space, seed, 10)
	default:
		return NewRandomSearch(space, 0, seed)
	}
}
// createPruner returns the pruner used for every study: a MedianPruner that
// needs 5 trials to compare against and uses the 0.5 quantile (the median).
// The study argument is currently ignored.
func (c *HPOController) createPruner(_ *Study) Pruner {
	const (
		pruneMinTrials = 5
		pruneQuantile  = 0.5
	)
	return NewMedianPruner(pruneMinTrials, pruneQuantile)
}
// generateStudyID returns a study identifier of the form "study_<unix-nanoseconds>".
func generateStudyID() string {
	stamp := time.Now().UnixNano()
	return fmt.Sprint("study_", stamp)
}
// generateTrialID returns a trial identifier of the form "trial_<unix-nanoseconds>".
func generateTrialID() string {
	stamp := time.Now().UnixNano()
	return fmt.Sprint("trial_", stamp)
}
中位数剪枝器
// median_pruner.go
package hpo
import (
	"sort"
)
// MedianPruner decides whether a running trial should stop early. It compares
// the trial's latest intermediate value with a quantile of the values other
// trials reported at the same step.
type MedianPruner struct {
	minTrials  int     // minimum number of completed trials (and same-step values) required before pruning
	percentile float64 // quantile in [0, 1] used as the pruning threshold; 0.5 is the median
	minSteps   int     // a trial must report at least this many intermediate values before it can be pruned
}
// NewMedianPruner returns a MedianPruner that waits for minTrials comparable
// trials and prunes against the given quantile (0.5 for the median). A trial
// is only considered for pruning after it has reported 5 intermediate values.
func NewMedianPruner(minTrials int, percentile float64) *MedianPruner {
	const defaultMinSteps = 5
	p := &MedianPruner{minSteps: defaultMinSteps}
	p.minTrials = minTrials
	p.percentile = percentile
	return p
}
// ShouldPrune implements the single-trial pruning check. The median rule needs
// the other trials' intermediate values at the same step, and this method does
// not receive them, so it never prunes.
//
// Callers that can load the other trials (for example from storage) should
// use ShouldPruneWithHistory instead.
func (p *MedianPruner) ShouldPrune(trial *Trial) bool {
	return false
}
// ShouldPruneWithHistory reports whether trial should be stopped early. It
// compares the trial's latest intermediate value with the p.percentile
// quantile of the values completedTrials reported at the same step. Lower
// values are treated as better (minimization).
//
// Nothing is pruned until the trial has p.minSteps intermediate values and at
// least p.minTrials trials have a value at that step.
//
// Once those guards pass, a NaN latest value is always pruned, because a
// diverged run cannot recover a meaningful score. NaN values from other
// trials are ignored so they cannot poison the threshold.
func (p *MedianPruner) ShouldPruneWithHistory(trial *Trial, completedTrials []*Trial) bool {
	if trial == nil || len(trial.IntermediateValues) < p.minSteps {
		return false
	}
	if len(completedTrials) < p.minTrials {
		return false
	}

	latest := trial.IntermediateValues[len(trial.IntermediateValues)-1]
	// x != x holds only for NaN (the same test math.IsNaN performs).
	if latest.Value != latest.Value {
		return true
	}

	// Collect, per completed trial, the first value reported at the same step.
	historyValues := make([]float64, 0, len(completedTrials))
	for _, other := range completedTrials {
		if other == nil {
			continue
		}
		for _, iv := range other.IntermediateValues {
			if iv.Step != latest.Step {
				continue
			}
			if iv.Value == iv.Value { // skip NaN
				historyValues = append(historyValues, iv.Value)
			}
			break
		}
	}
	// An empty history has no meaningful threshold, even when minTrials <= 0.
	if len(historyValues) == 0 || len(historyValues) < p.minTrials {
		return false
	}

	sort.Float64s(historyValues)
	threshold := percentile(historyValues, p.percentile)

	// Worse than the threshold (higher loss) means prune.
	return latest.Value > threshold
}
// percentile returns the p-quantile of an ascending-sorted slice, linearly
// interpolating between the two nearest ranks.
//
// p is clamped to [0, 1], and a NaN p is treated as 0. Previously a negative p
// extrapolated below the minimum and could index out of range. An empty slice
// yields 0.
func percentile(sorted []float64, p float64) float64 {
	n := len(sorted)
	if n == 0 {
		return 0
	}
	// !(p > 0) also catches NaN.
	if !(p > 0) {
		return sorted[0]
	}
	if p >= 1 {
		return sorted[n-1]
	}

	idx := p * float64(n-1)
	lower := int(idx)
	upper := lower + 1
	if upper >= n {
		return sorted[n-1]
	}
	frac := idx - float64(lower)
	return sorted[lower]*(1-frac) + sorted[upper]*frac
}
Kubernetes 集成
HPO CRD
# hpo-crd.yaml
# CustomResourceDefinition for the namespaced Study resource (group hpo.ai.io, version v1).
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  name: studies.hpo.ai.io
spec:
  group: hpo.ai.io
  versions:
    - name: v1
      served: true
      storage: true
      schema:
        openAPIV3Schema:
          type: object
          properties:
            spec:
              type: object
              properties:
                objective:
                  type: string
                  enum: [minimize, maximize]
                algorithm:
                  type: string
                  enum: [random, tpe, bayesian, hyperband, asha]
                maxTrials:
                  type: integer
                  minimum: 1
                parallelism:
                  # Must be >= 1. The controller acquires one concurrency slot per running
                  # trial (presumably a channel sized by parallelism), so 0 would block forever.
                  type: integer
                  minimum: 1
                maxDuration:
                  # Duration string such as "48h".
                  type: string
                searchSpace:
                  type: object
                  x-kubernetes-preserve-unknown-fields: true
                trialTemplate:
                  type: object
                  properties:
                    image:
                      type: string
                    command:
                      type: array
                      items:
                        type: string
                    resources:
                      type: object
                      x-kubernetes-preserve-unknown-fields: true
                metricsCollector:
                  type: object
                  properties:
                    type:
                      type: string
                    source:
                      type: string
            status:
              type: object
              properties:
                phase:
                  type: string
                trials:
                  type: integer
                completedTrials:
                  type: integer
                bestTrial:
                  type: object
                  properties:
                    name:
                      type: string
                    value:
                      type: number
                    config:
                      type: object
                      x-kubernetes-preserve-unknown-fields: true
      subresources:
        status: {}
  scope: Namespaced
  names:
    plural: studies
    singular: study
    kind: Study
    shortNames:
      - hpo
---
# Example Study
apiVersion: hpo.ai.io/v1
kind: Study
metadata:
  name: llm-lr-optimization
  namespace: ai-training
spec:
  objective: minimize
  algorithm: tpe
  maxTrials: 50
  parallelism: 4
  maxDuration: "48h"
  searchSpace:
    parameters:
      learning_rate:
        type: continuous
        # Written as 1.0e-6 rather than 1e-6: YAML 1.1 parsers (e.g. PyYAML) only
        # recognise floats that contain a '.', and would read 1e-6 as a string.
        low: 1.0e-6
        high: 1.0e-3
        logScale: true
      batch_size:
        type: discrete
        low: 8
        high: 64
        step: 8
      warmup_ratio:
        type: continuous
        low: 0.0
        high: 0.2
      weight_decay:
        type: continuous
        low: 0.0
        high: 0.3
      optimizer:
        type: categorical
        choices: ["adamw", "adam", "sgd"]
  trialTemplate:
    image: training:v1.0
    # ${name} placeholders are replaced with the trial's sampled values.
    command:
      - python
      - train.py
      - --learning-rate=${learning_rate}
      - --batch-size=${batch_size}
      - --warmup-ratio=${warmup_ratio}
      - --weight-decay=${weight_decay}
      - --optimizer=${optimizer}
    resources:
      limits:
        nvidia.com/gpu: 8
        memory: "128Gi"
      requests:
        cpu: "16"
  metricsCollector:
    type: prometheus
    source: "pod_annotation"
试验执行器
// k8s_executor.go
package hpo
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)
// K8sTrialExecutor runs each HPO trial as a Kubernetes batch Job.
type K8sTrialExecutor struct {
	client    kubernetes.Interface // API client used to create, poll and delete trial Jobs
	namespace string               // namespace in which trial Jobs are created
}
// NewK8sTrialExecutor returns an executor that manages trial Jobs in
// namespace through client.
func NewK8sTrialExecutor(client kubernetes.Interface, namespace string) *K8sTrialExecutor {
	e := new(K8sTrialExecutor)
	e.client = client
	e.namespace = namespace
	return e
}
// Execute runs trial as a Kubernetes Job in e.namespace and blocks until the
// Job succeeds, fails, or ctx is done. trial.JobName is set once the Job has
// been created.
//
// If ctx ends while waiting, the Job is deleted so an abandoned trial does not
// keep holding cluster resources such as GPUs. The returned error still wraps
// ctx.Err().
func (e *K8sTrialExecutor) Execute(ctx context.Context, trial *Trial, study *Study) error {
	job := e.buildJob(trial, study)
	if _, err := e.client.BatchV1().Jobs(e.namespace).Create(ctx, job, metav1.CreateOptions{}); err != nil {
		return fmt.Errorf("create job: %w", err)
	}
	trial.JobName = job.Name

	err := e.waitForCompletion(ctx, trial, study)
	if err != nil && ctx.Err() != nil {
		// Stop uses its own background context, so it works although ctx is done.
		if stopErr := e.Stop(trial); stopErr != nil {
			return fmt.Errorf("%w (delete job %s: %v)", err, trial.JobName, stopErr)
		}
	}
	return err
}
// buildJob renders the batch Job that runs one trial.
//
// The container command is study.TrialTemplate.Command with "${name}"
// placeholders filled from trial.Config. Every hyperparameter is also exposed
// as an HP_<NAME> environment variable, next to TRIAL_ID, STUDY_ID and
// HPO_REPORT_URL.
//
// BackoffLimit is 0 and RestartPolicy is Never, so a failed trial is not
// retried.
func (e *K8sTrialExecutor) buildJob(trial *Trial, study *Study) *batchv1.Job {
	jobName := trialJobName(trial.ID)

	// Fill ${name} placeholders in the command.
	command := make([]string, len(study.TrialTemplate.Command))
	for i, cmd := range study.TrialTemplate.Command {
		command[i] = e.substituteParams(cmd, trial.Config)
	}

	env := []corev1.EnvVar{
		{Name: "TRIAL_ID", Value: trial.ID},
		{Name: "STUDY_ID", Value: study.ID},
		{Name: "HPO_REPORT_URL", Value: e.getReportURL()},
	}
	// Expose each hyperparameter as HP_<NAME>.
	for key, value := range trial.Config {
		env = append(env, corev1.EnvVar{
			Name:  fmt.Sprintf("HP_%s", strings.ToUpper(key)),
			Value: fmt.Sprintf("%v", value),
		})
	}

	backoffLimit := int32(0)
	return &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{
			Name:      jobName,
			Namespace: e.namespace,
			Labels: map[string]string{
				"hpo.ai.io/study": study.ID,
				"hpo.ai.io/trial": trial.ID,
			},
			Annotations: map[string]string{
				"hpo.ai.io/config": serializeConfig(trial.Config),
			},
		},
		Spec: batchv1.JobSpec{
			BackoffLimit: &backoffLimit,
			Template: corev1.PodTemplateSpec{
				ObjectMeta: metav1.ObjectMeta{
					Labels: map[string]string{
						"hpo.ai.io/study": study.ID,
						"hpo.ai.io/trial": trial.ID,
					},
				},
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever,
					Containers: []corev1.Container{
						{
							Name:    "trial",
							Image:   study.TrialTemplate.Image,
							Command: command,
							Env:     env,
							Resources: corev1.ResourceRequirements{
								Limits:   study.TrialTemplate.Resources.Limits,
								Requests: study.TrialTemplate.Resources.Requests,
							},
						},
					},
				},
			},
		},
	}
}

// trialJobName derives a valid Job name ("trial-<sanitized id>") from a trial ID.
//
// Job names must be DNS-1123 compliant: lowercase alphanumerics and '-'. The
// Job controller also copies the name into a pod label, which limits it to 63
// characters. IDs such as "trial_<n>" contain '_' and would be rejected by the
// API server, so every other character is mapped to '-'.
func trialJobName(trialID string) string {
	const (
		prefix = "trial-"
		maxLen = 63
	)
	id := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			return r
		case r >= 'A' && r <= 'Z':
			return r + ('a' - 'A')
		default:
			return '-'
		}
	}, trialID)
	// Keep the tail of over-long IDs: time-based IDs differ in their last digits.
	if limit := maxLen - len(prefix); len(id) > limit {
		id = id[len(id)-limit:]
	}
	// The name must end with an alphanumeric character.
	return prefix + strings.Trim(id, "-")
}
// substituteParams returns template with every "${name}" placeholder replaced
// by the %v rendering of config[name]. Placeholders without a matching config
// key are left as they are.
func (e *K8sTrialExecutor) substituteParams(template string, config Config) string {
	out := template
	for name, v := range config {
		out = strings.ReplaceAll(out, "${"+name+"}", fmt.Sprintf("%v", v))
	}
	return out
}
// waitForCompletion polls the trial's Job every 10 seconds until the Job has
// finished or ctx is done.
//
// Errors from the Get call are treated as transient and retried on the next
// tick. A failed Job yields an error naming the Job and, when available, the
// reason and message from its Failed condition.
func (e *K8sTrialExecutor) waitForCompletion(ctx context.Context, trial *Trial, study *Study) error {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}

		job, err := e.client.BatchV1().Jobs(e.namespace).Get(ctx, trial.JobName, metav1.GetOptions{})
		if err != nil {
			continue
		}
		if done, outcome := jobOutcome(job); done {
			return outcome
		}
	}
}

// jobOutcome reports whether job has finished and, if it failed, why.
// Job conditions are checked first; the Succeeded/Failed pod counters are the
// fallback.
func jobOutcome(job *batchv1.Job) (bool, error) {
	for _, cond := range job.Status.Conditions {
		if cond.Status != corev1.ConditionTrue {
			continue
		}
		switch cond.Type {
		case batchv1.JobComplete:
			return true, nil
		case batchv1.JobFailed:
			return true, fmt.Errorf("job %s failed: %s: %s", job.Name, cond.Reason, cond.Message)
		}
	}
	if job.Status.Succeeded > 0 {
		return true, nil
	}
	if job.Status.Failed > 0 {
		return true, fmt.Errorf("job %s failed", job.Name)
	}
	return false, nil
}
// Stop deletes the trial's Job with background propagation, so the call
// returns without waiting for the Job's pods to be removed. It uses
// context.Background() and does not depend on any caller context.
func (e *K8sTrialExecutor) Stop(trial *Trial) error {
	policy := metav1.DeletePropagationBackground
	opts := metav1.DeleteOptions{PropagationPolicy: &policy}
	jobs := e.client.BatchV1().Jobs(e.namespace)
	return jobs.Delete(context.Background(), trial.JobName, opts)
}
// getReportURL returns the controller address that trial pods receive as
// HPO_REPORT_URL.
//
// It must be a base URL (scheme, host, port, no path). The Python HPO client
// appends "/api/v1/trials/<id>/..." to it, so a path such as "/api/v1/report"
// here would be duplicated into an invalid endpoint.
func (e *K8sTrialExecutor) getReportURL() string {
	return "http://hpo-controller:8080"
}
// serializeConfig renders config as JSON for the hpo.ai.io/config annotation.
// Map keys come out sorted, so the result is deterministic.
//
// It returns "" when config cannot be encoded (for example, a NaN float
// value). Nothing in this file reads the annotation back, so the failure is
// not treated as an error.
func serializeConfig(config Config) string {
	data, err := json.Marshal(config)
	if err != nil {
		return ""
	}
	return string(data)
}
Python 客户端
# hpo_client.py
import os
import requests
from typing import Any, Dict, Optional, Callable
from dataclasses import dataclass
import time
@dataclass
class TrialContext:
    """Identity, hyperparameters and reporting endpoint of a single trial.

    Attributes are positional in the order declared below.
    """

    # Identifier of the trial.
    trial_id: str
    # Identifier of the study the trial belongs to.
    study_id: str
    # Hyperparameter name -> sampled value.
    config: Dict[str, Any]
    # URL of the controller endpoint results are reported to.
    report_url: str
class HPOClient:
    """Client used inside a trial process to read its hyperparameters and
    report metrics to the HPO controller.

    Hyperparameters come from ``HP_<NAME>`` environment variables. The trial and
    study IDs come from ``TRIAL_ID`` / ``STUDY_ID``. When ``TRIAL_ID`` is unset
    (for example, when running outside the HPO system), reporting is a no-op.
    """

    def __init__(self, server_url: Optional[str] = None, timeout: float = 10.0):
        """Create a client.

        Args:
            server_url: Base URL of the controller (scheme://host:port). Falls
                back to ``$HPO_REPORT_URL`` and then ``http://localhost:8080``.
                A trailing slash is removed.
            timeout: Seconds to wait for each HTTP request. Without a timeout,
                a stalled controller would block the training loop forever.
        """
        url = server_url or os.environ.get("HPO_REPORT_URL", "http://localhost:8080")
        self.server_url = url.rstrip("/")
        self.trial_id = os.environ.get("TRIAL_ID")
        self.study_id = os.environ.get("STUDY_ID")
        self.timeout = timeout
        self._config: Optional[Dict[str, Any]] = None

    @property
    def config(self) -> Dict[str, Any]:
        """Hyperparameters of the current trial, loaded from the environment on first access."""
        if self._config is None:
            self._config = self._load_config()
        return self._config

    def _load_config(self) -> Dict[str, Any]:
        """Collect ``HP_*`` environment variables into a dict.

        Keys are the variable names without the ``HP_`` prefix, lower-cased.
        """
        config = {}
        for key, value in os.environ.items():
            if key.startswith("HP_"):
                config[key[3:].lower()] = self._parse_value(value)
        return config

    def _parse_value(self, value: str) -> Any:
        """Convert an environment-variable string back to int or float when possible.

        Values are written by Go's ``%v`` formatting, so small floats arrive in
        exponent form without a '.', e.g. ``1e-06``; those must still become
        floats. Strings without any digit (``nan``, ``inf``, ``adamw``, ``true``)
        are kept as strings, so categorical choices are never coerced.
        """
        try:
            return int(value)
        except ValueError:
            pass
        if not any(ch.isdigit() for ch in value):
            return value
        try:
            return float(value)
        except ValueError:
            return value

    def _post(self, action: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` to this trial's ``action`` endpoint and return the response.

        Raises:
            Any exception from ``requests`` (connection errors, timeouts) and
            ``requests.HTTPError`` for 4xx/5xx statuses.
        """
        response = requests.post(
            f"{self.server_url}/api/v1/trials/{self.trial_id}/{action}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response

    def report_intermediate(self, step: int, value: float) -> bool:
        """Report an intermediate metric; return whether training should continue.

        Any reporting failure is printed and treated as "continue", so a
        controller problem never kills the training run.
        """
        if not self.trial_id:
            return True
        try:
            result = self._post("intermediate", {"step": step, "value": value}).json()
            return result.get("should_continue", True)
        except Exception as e:
            print(f"Failed to report intermediate value: {e}")
            return True

    def report_final(self, value: float) -> None:
        """Report the trial's final objective value (best effort; failures are printed)."""
        if not self.trial_id:
            return
        try:
            self._post("complete", {"value": value})
        except Exception as e:
            print(f"Failed to report final value: {e}")

    def suggest(self, param_name: str) -> Any:
        """Return the value of hyperparameter ``param_name``, or None if it is not set."""
        return self.config.get(param_name)
# Convenience functions backed by a single, lazily created module-level client.
_client: Optional[HPOClient] = None
def get_client() -> HPOClient:
    """Return the process-wide HPOClient, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = HPOClient()
    return _client
def suggest(param_name: str) -> Any:
    """Return the value of ``param_name`` using the shared client."""
    client = get_client()
    return client.suggest(param_name)
def report_intermediate(step: int, value: float) -> bool:
    """Report ``value`` at ``step`` via the shared client; return whether to keep training."""
    client = get_client()
    return client.report_intermediate(step, value)
def report_final(value: float) -> None:
    """Report the trial's final objective ``value`` via the shared client."""
    client = get_client()
    client.report_final(value)
# Usage example
def train_with_hpo():
    """Example training loop integrated with HPO.

    ``train_epoch`` and ``evaluate`` are not defined in this module; they stand
    in for the user's own training and evaluation code.
    """
    import hpo_client as hpo

    # Hyperparameters chosen by the controller for this trial.
    learning_rate = hpo.suggest("learning_rate")
    batch_size = hpo.suggest("batch_size")
    warmup_ratio = hpo.suggest("warmup_ratio")
    print(f"Training with lr={learning_rate}, bs={batch_size}, warmup={warmup_ratio}")

    for epoch in range(10):
        train_epoch(...)
        eval_loss = evaluate(...)
        # Report progress; the controller may tell us to stop early.
        should_continue = hpo.report_intermediate(epoch, eval_loss)
        if not should_continue:
            print(f"Trial pruned at epoch {epoch}")
            # The run was cut short, so its last loss is not a final result: do not report one.
            return

    final_loss = evaluate(...)
    hpo.report_final(final_loss)
最佳实践
HPO 策略选择
┌─────────────────────────────────────────────────────────────────┐
│ HPO 策略选择指南 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 根据场景选择算法: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 预算充足 + 参数极少 (≤3,组合数随参数个数指数增长) → 网格搜索 │ │
│ │ │ │
│ │ 预算有限 + 参数少 (<10) → 贝叶斯优化 │ │
│ │ │ │
│ │ 预算有限 + 参数多 (10-50) → TPE │ │
│ │ │ │
│ │ 单次评估昂贵 + 需要早停 → ASHA / Hyperband │ │
│ │ │ │
│ │ 大规模并行 + 分布式环境 → 异步 ASHA │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 搜索空间设计: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 学习率:使用对数尺度 [1e-6, 1e-2] │ │
│ │ • 批量大小:使用 2 的幂次 [8, 16, 32, 64, 128] │ │
│ │ • Dropout:线性尺度 [0, 0.5] │ │
│ │ • 层数/隐藏维度:离散值或少量选项 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 效率优化: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 使用早停避免浪费资源 │ │
│ │ • 先在小数据集/少epoch上筛选 │ │
│ │ • 利用迁移学习:从小模型迁移好的超参数 │ │
│ │ • 设置合理的并行度(通常 4-16) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
小结
本章详细介绍了超参数优化系统的设计与实现:
- 基础概念:搜索空间定义、参数类型、条件依赖
- 搜索算法:网格搜索、随机搜索、贝叶斯优化、TPE
- 早停策略:Hyperband、ASHA、中位数剪枝
- 服务架构:HPO 控制器、试验执行器、剪枝器
- Kubernetes 集成:Study CRD、试验 Job 管理
- Python 客户端:训练代码集成方式
通过本章的学习,你应该能够:
- 理解各种 HPO 算法的原理和适用场景
- 设计合理的超参数搜索空间
- 在 Kubernetes 上部署 HPO 服务
- 将 HPO 集成到训练代码中
下一章我们将进入「推理服务」部分,讲解如何高效地部署 AI 模型并对外提供服务。