Experiment Management

Overview

Machine learning experiment management is one of the core capabilities of an AI platform. A typical large-model training project can involve hundreds of experiments, each with its own hyperparameter configuration, dataset version, and code version. Systematically tracking, comparing, and reproducing these experiments is key to development efficiency. This chapter takes a deep dive into the design and implementation of an experiment management system.

Core Concepts of Experiment Management

Experiment Hierarchy

┌─────────────────────────────────────────────────────────────────┐
│                 Experiment Management Hierarchy                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                         Project                         │    │
│  │                 (e.g. LLM-7B-Training)                  │    │
│  │  ┌───────────────────────────────────────────────────┐  │    │
│  │  │                    Experiment                     │  │    │
│  │  │            (e.g. learning-rate-sweep)             │  │    │
│  │  │  ┌─────────────────────────────────────────────┐  │  │    │
│  │  │  │                     Run                     │  │  │    │
│  │  │  │           (e.g. run-20241201-001)           │  │  │    │
│  │  │  │                                             │  │  │    │
│  │  │  │  • Parameters (hyperparameters)             │  │  │    │
│  │  │  │  • Metrics                                  │  │  │    │
│  │  │  │  • Artifacts                                │  │  │    │
│  │  │  │  • Source (code version)                    │  │  │    │
│  │  │  │  • Environment                              │  │  │    │
│  │  │  │                                             │  │  │    │
│  │  │  └─────────────────────────────────────────────┘  │  │    │
│  │  └───────────────────────────────────────────────────┘  │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                 │
│  Key relationships:                                             │
│  • A Project contains multiple Experiments                      │
│  • An Experiment contains multiple Runs                         │
│  • Each Run records the full experiment context                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Core Data Model

// experiment_models.go
package experiment

import (
    "time"
)

// Project 项目
type Project struct {
    ID          string            `json:"id" db:"id"`
    Name        string            `json:"name" db:"name"`
    Description string            `json:"description" db:"description"`
    Owner       string            `json:"owner" db:"owner"`
    Team        string            `json:"team" db:"team"`
    Tags        []string          `json:"tags" db:"tags"`
    Settings    ProjectSettings   `json:"settings" db:"settings"`
    CreatedAt   time.Time         `json:"created_at" db:"created_at"`
    UpdatedAt   time.Time         `json:"updated_at" db:"updated_at"`
}

// ProjectSettings 项目设置
type ProjectSettings struct {
    // 默认实验配置
    DefaultExperimentConfig map[string]interface{} `json:"default_experiment_config"`
    // 指标可视化配置
    MetricDisplayConfig     MetricDisplayConfig    `json:"metric_display_config"`
    // 归档策略
    ArchivePolicy           ArchivePolicy          `json:"archive_policy"`
}

// MetricDisplayConfig 指标显示配置
type MetricDisplayConfig struct {
    PrimaryMetrics   []string `json:"primary_metrics"`   // 主要指标
    HigherIsBetter   []string `json:"higher_is_better"`  // 越高越好的指标
    ChartTypes       map[string]string `json:"chart_types"` // 图表类型
}

// ArchivePolicy 归档策略
type ArchivePolicy struct {
    AutoArchiveDays int  `json:"auto_archive_days"` // 自动归档天数
    KeepBestRuns    int  `json:"keep_best_runs"`    // 保留最佳运行数
    KeepRecentRuns  int  `json:"keep_recent_runs"`  // 保留最近运行数
}

// Experiment 实验
type Experiment struct {
    ID          string           `json:"id" db:"id"`
    ProjectID   string           `json:"project_id" db:"project_id"`
    Name        string           `json:"name" db:"name"`
    Description string           `json:"description" db:"description"`
    Hypothesis  string           `json:"hypothesis" db:"hypothesis"` // 实验假设
    Status      ExperimentStatus `json:"status" db:"status"`
    Tags        []string         `json:"tags" db:"tags"`
    CreatedBy   string           `json:"created_by" db:"created_by"`
    CreatedAt   time.Time        `json:"created_at" db:"created_at"`
    UpdatedAt   time.Time        `json:"updated_at" db:"updated_at"`
    ArchivedAt  *time.Time       `json:"archived_at" db:"archived_at"`
}

// ExperimentStatus 实验状态
type ExperimentStatus string

const (
    ExperimentStatusActive    ExperimentStatus = "active"
    ExperimentStatusCompleted ExperimentStatus = "completed"
    ExperimentStatusArchived  ExperimentStatus = "archived"
)

// Run 运行
type Run struct {
    ID           string        `json:"id" db:"id"`
    ExperimentID string        `json:"experiment_id" db:"experiment_id"`
    Name         string        `json:"name" db:"name"`
    Status       RunStatus     `json:"status" db:"status"`

    // 来源追踪
    Source       RunSource     `json:"source" db:"source"`

    // 配置
    Parameters   Parameters    `json:"parameters" db:"parameters"`
    Environment  Environment   `json:"environment" db:"environment"`

    // 运行时信息
    StartTime    *time.Time    `json:"start_time" db:"start_time"`
    EndTime      *time.Time    `json:"end_time" db:"end_time"`
    Duration     int64         `json:"duration" db:"duration"` // 秒

    // 资源使用
    Resources    ResourceUsage `json:"resources" db:"resources"`

    // 元信息
    Tags         []string      `json:"tags" db:"tags"`
    Notes        string        `json:"notes" db:"notes"`
    CreatedBy    string        `json:"created_by" db:"created_by"`
    CreatedAt    time.Time     `json:"created_at" db:"created_at"`
}

// RunStatus 运行状态
type RunStatus string

const (
    RunStatusPending   RunStatus = "pending"
    RunStatusRunning   RunStatus = "running"
    RunStatusCompleted RunStatus = "completed"
    RunStatusFailed    RunStatus = "failed"
    RunStatusKilled    RunStatus = "killed"
)

// RunSource 运行来源
type RunSource struct {
    // Git 信息
    GitRepoURL    string `json:"git_repo_url"`
    GitCommit     string `json:"git_commit"`
    GitBranch     string `json:"git_branch"`
    GitTag        string `json:"git_tag"`
    GitDirty      bool   `json:"git_dirty"` // 是否有未提交更改

    // 代码入口
    EntryPoint    string `json:"entry_point"`    // 执行脚本
    SourceType    string `json:"source_type"`    // "local", "git", "notebook"

    // 代码快照(小项目)
    CodeSnapshot  string `json:"code_snapshot"`  // 代码归档路径
}

// Parameters 参数
type Parameters struct {
    // 模型参数
    ModelConfig   map[string]interface{} `json:"model_config"`

    // 训练参数
    TrainingConfig map[string]interface{} `json:"training_config"`

    // 数据参数
    DataConfig    map[string]interface{} `json:"data_config"`

    // 原始参数(全部展开)
    Flat          map[string]interface{} `json:"flat"`
}

// Environment 环境信息
type Environment struct {
    // Python 环境
    PythonVersion    string            `json:"python_version"`
    PipPackages      map[string]string `json:"pip_packages"`
    CondaEnvironment string            `json:"conda_environment"`

    // 硬件信息
    GPUType          string   `json:"gpu_type"`
    GPUCount         int      `json:"gpu_count"`
    GPUMemory        int64    `json:"gpu_memory"`
    CPUCount         int      `json:"cpu_count"`
    MemoryGB         int      `json:"memory_gb"`

    // 系统信息
    Platform         string   `json:"platform"`
    Hostname         string   `json:"hostname"`
    DockerImage      string   `json:"docker_image"`

    // 框架版本
    FrameworkVersions map[string]string `json:"framework_versions"`
}

// ResourceUsage 资源使用
type ResourceUsage struct {
    GPUHours     float64 `json:"gpu_hours"`
    CPUHours     float64 `json:"cpu_hours"`
    MemoryGBHours float64 `json:"memory_gb_hours"`
    StorageGB    float64 `json:"storage_gb"`
    Cost         float64 `json:"cost"` // 预估成本
}

// Metric 指标
type Metric struct {
    RunID     string    `json:"run_id" db:"run_id"`
    Key       string    `json:"key" db:"key"`
    Value     float64   `json:"value" db:"value"`
    Step      int64     `json:"step" db:"step"`
    Timestamp time.Time `json:"timestamp" db:"timestamp"`
    Context   string    `json:"context" db:"context"` // train, eval, test
}

// Artifact 产物
type Artifact struct {
    ID        string    `json:"id" db:"id"`
    RunID     string    `json:"run_id" db:"run_id"`
    Name      string    `json:"name" db:"name"`
    Type      string    `json:"type" db:"type"` // model, data, figure, etc.
    Path      string    `json:"path" db:"path"`
    Size      int64     `json:"size" db:"size"`
    Checksum  string    `json:"checksum" db:"checksum"`
    Metadata  map[string]interface{} `json:"metadata" db:"metadata"`
    CreatedAt time.Time `json:"created_at" db:"created_at"`
}
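
To make the data model concrete, the sketch below (illustrative values only, not taken from the original text) shows how a single Run might be populated, with the nested model/training configs mirrored into the flattened Flat map that the search and comparison APIs query later. It assumes it lives in the same experiment package as the types above.

// exampleRun returns an illustrative, fully populated Run for a learning-rate sweep.
func exampleRun() Run {
    return Run{
        ExperimentID: "exp_learning_rate_sweep",
        Name:         "run-20241201-001",
        Status:       RunStatusRunning,
        Source: RunSource{
            SourceType: "git",
            GitBranch:  "main",
            GitCommit:  "3f9c2ab0", // short hash shown for brevity
        },
        Parameters: Parameters{
            ModelConfig:    map[string]interface{}{"hidden_size": 4096, "num_layers": 32},
            TrainingConfig: map[string]interface{}{"learning_rate": 1e-4, "batch_size": 32},
            // Flat mirrors the nested configs with dotted keys so runs can be filtered and compared.
            Flat: map[string]interface{}{
                "model.hidden_size":      4096,
                "model.num_layers":       32,
                "training.learning_rate": 1e-4,
                "training.batch_size":    32,
            },
        },
        Environment: Environment{GPUType: "NVIDIA A100", GPUCount: 8},
        Tags:        []string{"lr-sweep"},
    }
}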

Experiment Tracking System

Tracking Client

// tracking_client.go
package experiment

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "os"
    "os/exec"
    "runtime"
    "strings"
    "sync"
    "time"
)

// TrackingClient 追踪客户端
type TrackingClient struct {
    serverURL    string
    httpClient   *http.Client

    // 当前运行上下文
    activeRun    *Run
    runMu        sync.RWMutex

    // 批量写入缓冲
    metricBuffer []*Metric
    bufferMu     sync.Mutex
    flushTicker  *time.Ticker

    // 配置
    config       ClientConfig
}

// ClientConfig 客户端配置
type ClientConfig struct {
    BatchSize      int           // 批量大小
    FlushInterval  time.Duration // 刷新间隔
    AutoLogEnv     bool          // 自动记录环境
    AutoLogGit     bool          // 自动记录 Git 信息
}

// NewTrackingClient 创建追踪客户端
func NewTrackingClient(serverURL string, config ClientConfig) *TrackingClient {
    client := &TrackingClient{
        serverURL:    serverURL,
        httpClient:   &http.Client{Timeout: 30 * time.Second},
        metricBuffer: make([]*Metric, 0, config.BatchSize),
        config:       config,
    }

    // Start the periodic flush loop; default the interval because time.NewTicker
    // panics on a non-positive duration.
    if config.FlushInterval <= 0 {
        config.FlushInterval = 5 * time.Second
    }
    client.flushTicker = time.NewTicker(config.FlushInterval)
    go client.flushLoop()

    return client
}

// StartRun 开始运行
func (c *TrackingClient) StartRun(
    ctx context.Context,
    experimentID string,
    runName string,
    params map[string]interface{},
) (*Run, error) {

    // 创建运行
    run := &Run{
        ExperimentID: experimentID,
        Name:         runName,
        Status:       RunStatusRunning,
        Parameters: Parameters{
            Flat: params,
        },
        CreatedAt: time.Now(),
    }

    now := time.Now()
    run.StartTime = &now

    // 自动收集信息
    if c.config.AutoLogEnv {
        run.Environment = c.collectEnvironment()
    }

    if c.config.AutoLogGit {
        run.Source = c.collectGitInfo()
    }

    // 发送到服务器
    resp, err := c.post(ctx, "/api/v1/runs", run)
    if err != nil {
        return nil, fmt.Errorf("create run: %w", err)
    }
    defer resp.Body.Close()

    if err := json.NewDecoder(resp.Body).Decode(run); err != nil {
        return nil, fmt.Errorf("decode response: %w", err)
    }

    // 设置为活跃运行
    c.runMu.Lock()
    c.activeRun = run
    c.runMu.Unlock()

    return run, nil
}

// LogParam 记录参数
func (c *TrackingClient) LogParam(key string, value interface{}) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    return c.LogParams(map[string]interface{}{key: value})
}

// LogParams 批量记录参数
func (c *TrackingClient) LogParams(params map[string]interface{}) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    _, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/params", run.ID), params)
    return err
}

// LogMetric 记录指标
func (c *TrackingClient) LogMetric(key string, value float64, step int64) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    metric := &Metric{
        RunID:     run.ID,
        Key:       key,
        Value:     value,
        Step:      step,
        Timestamp: time.Now(),
    }

    // 添加到缓冲
    c.bufferMu.Lock()
    c.metricBuffer = append(c.metricBuffer, metric)

    // 如果达到批量大小,立即刷新
    if len(c.metricBuffer) >= c.config.BatchSize {
        go c.flush()
    }
    c.bufferMu.Unlock()

    return nil
}

// LogMetrics 批量记录指标
func (c *TrackingClient) LogMetrics(metrics map[string]float64, step int64) error {
    for key, value := range metrics {
        if err := c.LogMetric(key, value, step); err != nil {
            return err
        }
    }
    return nil
}

// LogArtifact 记录产物
func (c *TrackingClient) LogArtifact(localPath string, artifactPath string) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    // 读取文件
    data, err := os.ReadFile(localPath)
    if err != nil {
        return fmt.Errorf("read file: %w", err)
    }

    // 上传
    url := fmt.Sprintf("%s/api/v1/runs/%s/artifacts?path=%s", c.serverURL, run.ID, artifactPath)
    req, err := http.NewRequest("PUT", url, bytes.NewReader(data))
    if err != nil {
        return err
    }

    resp, err := c.httpClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("upload failed: %s", resp.Status)
    }

    return nil
}

// LogModel 记录模型
func (c *TrackingClient) LogModel(modelPath string, modelName string, metadata map[string]interface{}) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    // 创建模型产物记录
    artifact := &Artifact{
        RunID:    run.ID,
        Name:     modelName,
        Type:     "model",
        Path:     modelPath,
        Metadata: metadata,
    }

    _, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/models", run.ID), artifact)
    return err
}

// SetTag 设置标签
func (c *TrackingClient) SetTag(key, value string) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    _, err := c.post(context.Background(),
        fmt.Sprintf("/api/v1/runs/%s/tags", run.ID),
        map[string]string{key: value})
    return err
}

// EndRun 结束运行
func (c *TrackingClient) EndRun(status RunStatus) error {
    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return fmt.Errorf("no active run")
    }

    // Flush remaining metrics while activeRun is still set; flush() reads activeRun
    // and would otherwise silently drop the buffered metrics.
    c.flush()

    c.runMu.Lock()
    c.activeRun = nil
    c.runMu.Unlock()

    // 更新运行状态
    now := time.Now()
    run.EndTime = &now
    run.Status = status
    run.Duration = int64(now.Sub(*run.StartTime).Seconds())

    _, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/end", run.ID), run)
    return err
}

// flush 刷新指标缓冲
func (c *TrackingClient) flush() {
    c.bufferMu.Lock()
    if len(c.metricBuffer) == 0 {
        c.bufferMu.Unlock()
        return
    }

    metrics := c.metricBuffer
    c.metricBuffer = make([]*Metric, 0, c.config.BatchSize)
    c.bufferMu.Unlock()

    c.runMu.RLock()
    run := c.activeRun
    c.runMu.RUnlock()

    if run == nil {
        return
    }

    // Batch send (best effort); close the response body so connections are not leaked.
    if resp, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/metrics/batch", run.ID), metrics); err == nil {
        resp.Body.Close()
    }
}

func (c *TrackingClient) flushLoop() {
    for range c.flushTicker.C {
        c.flush()
    }
}

// collectEnvironment 收集环境信息
func (c *TrackingClient) collectEnvironment() Environment {
    env := Environment{
        PythonVersion:     getPythonVersion(),
        PipPackages:       getPipPackages(),
        Platform:          runtime.GOOS,
        CPUCount:          runtime.NumCPU(),
        FrameworkVersions: make(map[string]string),
    }

    // 获取 GPU 信息
    if gpuInfo := getGPUInfo(); gpuInfo != nil {
        env.GPUType = gpuInfo.Type
        env.GPUCount = gpuInfo.Count
        env.GPUMemory = gpuInfo.Memory
    }

    // 获取主机名
    env.Hostname, _ = os.Hostname()

    // 获取 Docker 镜像(如果在容器中)
    env.DockerImage = os.Getenv("DOCKER_IMAGE")

    return env
}

// collectGitInfo 收集 Git 信息
func (c *TrackingClient) collectGitInfo() RunSource {
    source := RunSource{
        SourceType: "git",
    }

    // 获取远程仓库 URL
    if out, err := exec.Command("git", "remote", "get-url", "origin").Output(); err == nil {
        source.GitRepoURL = strings.TrimSpace(string(out))
    }

    // 获取当前提交
    if out, err := exec.Command("git", "rev-parse", "HEAD").Output(); err == nil {
        source.GitCommit = strings.TrimSpace(string(out))
    }

    // 获取当前分支
    if out, err := exec.Command("git", "rev-parse", "--abbrev-ref", "HEAD").Output(); err == nil {
        source.GitBranch = strings.TrimSpace(string(out))
    }

    // 检查是否有未提交更改
    if out, err := exec.Command("git", "status", "--porcelain").Output(); err == nil {
        source.GitDirty = len(out) > 0
    }

    return source
}

func (c *TrackingClient) post(ctx context.Context, path string, body interface{}) (*http.Response, error) {
    data, err := json.Marshal(body)
    if err != nil {
        return nil, err
    }

    req, err := http.NewRequestWithContext(ctx, "POST", c.serverURL+path, bytes.NewReader(data))
    if err != nil {
        return nil, err
    }
    req.Header.Set("Content-Type", "application/json")

    return c.httpClient.Do(req)
}

// 辅助函数
func getPythonVersion() string {
    out, err := exec.Command("python", "--version").Output()
    if err != nil {
        return ""
    }
    return strings.TrimSpace(string(out))
}

func getPipPackages() map[string]string {
    packages := make(map[string]string)
    out, err := exec.Command("pip", "freeze").Output()
    if err != nil {
        return packages
    }

    for _, line := range strings.Split(string(out), "\n") {
        parts := strings.Split(line, "==")
        if len(parts) == 2 {
            packages[parts[0]] = parts[1]
        }
    }
    return packages
}

type GPUInfo struct {
    Type   string
    Count  int
    Memory int64
}

func getGPUInfo() *GPUInfo {
    // 使用 nvidia-smi 获取 GPU 信息
    out, err := exec.Command("nvidia-smi", "--query-gpu=name,count,memory.total", "--format=csv,noheader,nounits").Output()
    if err != nil {
        return nil
    }

    lines := strings.Split(strings.TrimSpace(string(out)), "\n")
    if len(lines) == 0 {
        return nil
    }

    // 解析第一行
    parts := strings.Split(lines[0], ", ")
    if len(parts) < 3 {
        return nil
    }

    info := &GPUInfo{
        Type:  parts[0],
        Count: len(lines),
    }

    fmt.Sscanf(parts[2], "%d", &info.Memory)
    return info
}
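
A minimal usage sketch of the client above, from the point of view of a training driver. The server URL, experiment ID, and metric values are placeholders, and the module path for the experiment package is an assumption:

package main

import (
    "context"
    "log"
    "time"

    "example.com/aiplatform/experiment" // hypothetical module path for the package above
)

func main() {
    client := experiment.NewTrackingClient("http://experiment-server:8080", experiment.ClientConfig{
        BatchSize:     100,
        FlushInterval: 5 * time.Second,
        AutoLogEnv:    true,
        AutoLogGit:    true,
    })

    run, err := client.StartRun(context.Background(), "llm-pretraining", "lr-1e4-bs32",
        map[string]interface{}{"learning_rate": 1e-4, "batch_size": 32})
    if err != nil {
        log.Fatalf("start run: %v", err)
    }
    log.Printf("started run %s", run.ID)

    // Simulated training loop: metrics are buffered client-side and flushed in batches.
    for step := int64(0); step < 1000; step++ {
        _ = client.LogMetric("train/loss", 2.5-0.001*float64(step), step)
    }

    _ = client.LogArtifact("checkpoints/final.pt", "checkpoints/final.pt")
    _ = client.EndRun(experiment.RunStatusCompleted)
}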

Python SDK

# experiment_sdk.py
import os
import threading
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
import requests

@dataclass
class RunConfig:
    """运行配置"""
    experiment_id: str
    run_name: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    description: str = ""

class ExperimentTracker:
    """实验追踪器"""

    def __init__(
        self,
        tracking_uri: str,
        batch_size: int = 100,
        flush_interval: float = 5.0,
        auto_log_env: bool = True,
        auto_log_git: bool = True,
    ):
        self.tracking_uri = tracking_uri
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.auto_log_env = auto_log_env
        self.auto_log_git = auto_log_git

        self._active_run: Optional[Dict] = None
        self._metric_buffer: List[Dict] = []
        self._buffer_lock = threading.Lock()
        self._flush_thread: Optional[threading.Thread] = None
        self._stop_flush = threading.Event()

    def start_run(
        self,
        config: RunConfig,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict:
        """开始运行"""
        run_data = {
            "experiment_id": config.experiment_id,
            "name": config.run_name or f"run-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            "status": "running",
            "tags": config.tags,
            "description": config.description,
            "start_time": datetime.now().isoformat(),
        }

        # 自动收集环境信息
        if self.auto_log_env:
            run_data["environment"] = self._collect_environment()

        # 自动收集 Git 信息
        if self.auto_log_git:
            run_data["source"] = self._collect_git_info()

        # 创建运行
        response = requests.post(
            f"{self.tracking_uri}/api/v1/runs",
            json=run_data,
        )
        response.raise_for_status()

        self._active_run = response.json()

        # 记录参数
        if params:
            self.log_params(params)

        # 启动后台刷新线程
        self._start_flush_thread()

        return self._active_run

    def log_param(self, key: str, value: Any) -> None:
        """记录单个参数"""
        self.log_params({key: value})

    def log_params(self, params: Dict[str, Any]) -> None:
        """记录多个参数"""
        if not self._active_run:
            raise RuntimeError("No active run")

        # 扁平化嵌套参数
        flat_params = self._flatten_params(params)

        requests.post(
            f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/params",
            json=flat_params,
        )

    def log_metric(
        self,
        key: str,
        value: float,
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """记录指标"""
        if not self._active_run:
            raise RuntimeError("No active run")

        metric = {
            "run_id": self._active_run["id"],
            "key": key,
            "value": float(value),
            "step": step or 0,
            "timestamp": datetime.now().isoformat(),
            "context": context,
        }

        with self._buffer_lock:
            self._metric_buffer.append(metric)
            should_flush = len(self._metric_buffer) >= self.batch_size

        # Flush outside the lock: _flush_metrics re-acquires _buffer_lock,
        # and threading.Lock is not reentrant.
        if should_flush:
            self._flush_metrics()

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """记录多个指标"""
        for key, value in metrics.items():
            self.log_metric(key, value, step, context)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None,
        artifact_type: str = "file",
    ) -> None:
        """记录产物"""
        if not self._active_run:
            raise RuntimeError("No active run")

        artifact_path = artifact_path or os.path.basename(local_path)

        with open(local_path, "rb") as f:
            files = {"file": (artifact_path, f)}
            requests.post(
                f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/artifacts",
                files=files,
                data={"type": artifact_type},
            )

    def log_model(
        self,
        model: Any,
        model_name: str,
        signature: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        """记录模型"""
        if not self._active_run:
            raise RuntimeError("No active run")

        # 保存模型到临时文件
        import tempfile
        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = os.path.join(tmpdir, model_name)

            # 根据模型类型保存
            if hasattr(model, "save_pretrained"):
                # HuggingFace 模型
                model.save_pretrained(model_path)
            elif hasattr(model, "state_dict"):
                # PyTorch 模型
                import torch
                torch.save(model.state_dict(), model_path)
            else:
                raise ValueError(f"Unsupported model type: {type(model)}")

            # Upload the model: save_pretrained writes a directory, torch.save writes a single file
            if os.path.isdir(model_path):
                self._upload_directory(model_path, f"models/{model_name}")
            else:
                self.log_artifact(model_path, f"models/{model_name}")

        # 记录模型元数据
        model_info = {
            "name": model_name,
            "signature": signature,
            "metadata": metadata or {},
            "framework": self._detect_framework(model),
        }

        requests.post(
            f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/models",
            json=model_info,
        )

    def log_figure(
        self,
        figure: Any,
        name: str,
        format: str = "png",
    ) -> None:
        """记录图表"""
        import tempfile
        with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as f:
            if hasattr(figure, "savefig"):
                # Matplotlib figure
                figure.savefig(f.name, format=format, bbox_inches="tight")
            else:
                raise ValueError(f"Unsupported figure type: {type(figure)}")

            self.log_artifact(f.name, f"figures/{name}.{format}", "figure")

        os.unlink(f.name)

    def set_tag(self, key: str, value: str) -> None:
        """设置标签"""
        if not self._active_run:
            raise RuntimeError("No active run")

        requests.post(
            f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/tags",
            json={key: value},
        )

    def end_run(self, status: str = "completed") -> None:
        """结束运行"""
        if not self._active_run:
            return

        # 停止刷新线程
        self._stop_flush.set()
        if self._flush_thread:
            self._flush_thread.join()

        # 刷新剩余指标
        self._flush_metrics()

        # 更新运行状态
        requests.post(
            f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/end",
            json={
                "status": status,
                "end_time": datetime.now().isoformat(),
            },
        )

        self._active_run = None

    def _flush_metrics(self) -> None:
        """刷新指标缓冲"""
        with self._buffer_lock:
            if not self._metric_buffer:
                return

            metrics = self._metric_buffer
            self._metric_buffer = []

        if not self._active_run:
            return

        try:
            requests.post(
                f"{self.tracking_uri}/api/v1/runs/{self._active_run['id']}/metrics/batch",
                json=metrics,
            )
        except Exception as e:
            print(f"Failed to flush metrics: {e}")

    def _start_flush_thread(self) -> None:
        """启动后台刷新线程"""
        self._stop_flush.clear()

        def flush_loop():
            while not self._stop_flush.wait(self.flush_interval):
                self._flush_metrics()

        self._flush_thread = threading.Thread(target=flush_loop, daemon=True)
        self._flush_thread.start()

    def _flatten_params(
        self,
        params: Dict[str, Any],
        prefix: str = "",
    ) -> Dict[str, Any]:
        """扁平化嵌套参数"""
        flat = {}
        for key, value in params.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(self._flatten_params(value, full_key))
            else:
                flat[full_key] = value
        return flat

    def _collect_environment(self) -> Dict:
        """收集环境信息"""
        import sys
        import platform

        env = {
            "python_version": sys.version,
            "platform": platform.platform(),
            "hostname": platform.node(),
        }

        # 收集 pip 包
        try:
            import pkg_resources
            env["pip_packages"] = {
                pkg.key: pkg.version
                for pkg in pkg_resources.working_set
            }
        except Exception:
            pass

        # 收集 GPU 信息
        try:
            import torch
            if torch.cuda.is_available():
                env["gpu_count"] = torch.cuda.device_count()
                env["gpu_type"] = torch.cuda.get_device_name(0)
                env["cuda_version"] = torch.version.cuda
        except Exception:
            pass

        return env

    def _collect_git_info(self) -> Dict:
        """收集 Git 信息"""
        import subprocess

        info = {"source_type": "git"}

        try:
            info["git_repo_url"] = subprocess.check_output(
                ["git", "remote", "get-url", "origin"],
                stderr=subprocess.DEVNULL,
            ).decode().strip()
        except Exception:
            pass

        try:
            info["git_commit"] = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                stderr=subprocess.DEVNULL,
            ).decode().strip()
        except Exception:
            pass

        try:
            info["git_branch"] = subprocess.check_output(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"],
                stderr=subprocess.DEVNULL,
            ).decode().strip()
        except Exception:
            pass

        try:
            status = subprocess.check_output(
                ["git", "status", "--porcelain"],
                stderr=subprocess.DEVNULL,
            ).decode().strip()
            info["git_dirty"] = len(status) > 0
        except Exception:
            pass

        return info

    def _detect_framework(self, model: Any) -> str:
        """检测模型框架"""
        module = type(model).__module__
        if "torch" in module:
            return "pytorch"
        elif "tensorflow" in module or "keras" in module:
            return "tensorflow"
        elif "transformers" in module:
            return "transformers"
        return "unknown"

    def _upload_directory(self, local_dir: str, artifact_dir: str) -> None:
        """上传目录"""
        for root, dirs, files in os.walk(local_dir):
            for file in files:
                local_path = os.path.join(root, file)
                rel_path = os.path.relpath(local_path, local_dir)
                artifact_path = os.path.join(artifact_dir, rel_path)
                self.log_artifact(local_path, artifact_path)


# 上下文管理器
class ExperimentRun:
    """实验运行上下文管理器"""

    def __init__(
        self,
        tracker: ExperimentTracker,
        config: RunConfig,
        params: Optional[Dict[str, Any]] = None,
    ):
        self.tracker = tracker
        self.config = config
        self.params = params
        self.run: Optional[Dict] = None

    def __enter__(self) -> "ExperimentRun":
        self.run = self.tracker.start_run(self.config, self.params)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        status = "failed" if exc_type else "completed"
        self.tracker.end_run(status)

    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        self.tracker.log_metric(key, value, step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.tracker.log_metrics(metrics, step)

    def log_param(self, key: str, value: Any) -> None:
        self.tracker.log_param(key, value)

    def log_params(self, params: Dict[str, Any]) -> None:
        self.tracker.log_params(params)


# 全局 tracker 实例
_global_tracker: Optional[ExperimentTracker] = None

def init(tracking_uri: str, **kwargs) -> ExperimentTracker:
    """初始化全局 tracker"""
    global _global_tracker
    _global_tracker = ExperimentTracker(tracking_uri, **kwargs)
    return _global_tracker

def start_run(
    experiment_id: str,
    run_name: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Dict:
    """开始运行"""
    if not _global_tracker:
        raise RuntimeError("Tracker not initialized. Call init() first.")
    config = RunConfig(experiment_id=experiment_id, run_name=run_name, **kwargs)
    return _global_tracker.start_run(config, params)

def log_param(key: str, value: Any) -> None:
    """记录参数"""
    if _global_tracker:
        _global_tracker.log_param(key, value)

def log_params(params: Dict[str, Any]) -> None:
    """记录多个参数"""
    if _global_tracker:
        _global_tracker.log_params(params)

def log_metric(key: str, value: float, step: Optional[int] = None) -> None:
    """记录指标"""
    if _global_tracker:
        _global_tracker.log_metric(key, value, step)

def log_metrics(metrics: Dict[str, float], step: Optional[int] = None) -> None:
    """记录多个指标"""
    if _global_tracker:
        _global_tracker.log_metrics(metrics, step)

def log_artifact(local_path: str, artifact_path: Optional[str] = None) -> None:
    """记录产物"""
    if _global_tracker:
        _global_tracker.log_artifact(local_path, artifact_path)

def log_model(model: Any, model_name: str, metadata: Optional[Dict] = None) -> None:
    """记录模型"""
    if _global_tracker:
        _global_tracker.log_model(model, model_name, metadata=metadata)

def end_run(status: str = "completed") -> None:
    """结束运行"""
    if _global_tracker:
        _global_tracker.end_run(status)

Usage Example

# training_with_tracking.py
import experiment_sdk as exp

# 初始化
exp.init("http://experiment-server:8080")

# 定义训练参数
params = {
    "model": {
        "name": "llama-7b",
        "hidden_size": 4096,
        "num_layers": 32,
        "num_heads": 32,
    },
    "training": {
        "learning_rate": 1e-4,
        "batch_size": 32,
        "epochs": 3,
        "warmup_steps": 1000,
        "weight_decay": 0.01,
    },
    "data": {
        "dataset": "openwebtext",
        "max_length": 2048,
    },
}

# 开始运行
run = exp.start_run(
    experiment_id="llm-pretraining",
    run_name="lr-1e4-bs32",
    params=params,
)

try:
    # 训练循环
    global_step = 0
    for epoch in range(params["training"]["epochs"]):
        for step, batch in enumerate(train_dataloader):
            # 训练步骤
            loss = train_step(model, batch)
            global_step += 1

            # 记录指标
            if global_step % 100 == 0:
                exp.log_metrics({
                    "train/loss": loss.item(),
                    "train/learning_rate": scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                }, step=global_step)

        # 评估
        eval_metrics = evaluate(model, eval_dataloader)
        exp.log_metrics({
            "eval/loss": eval_metrics["loss"],
            "eval/perplexity": eval_metrics["perplexity"],
        }, step=global_step)

        # 保存检查点
        checkpoint_path = f"checkpoints/epoch_{epoch}"
        save_checkpoint(model, checkpoint_path)
        exp.log_artifact(checkpoint_path, f"checkpoints/epoch_{epoch}")

    # 保存最终模型
    exp.log_model(model, "final_model", metadata={"epochs": params["training"]["epochs"]})

    exp.end_run("completed")

except Exception as e:
    exp.log_param("error", str(e))
    exp.end_run("failed")
    raise

Experiment Server

Service Architecture

┌─────────────────────────────────────────────────────────────────┐
│           Experiment Management Service Architecture            │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                     API Gateway                          │    │
│  │                   (REST + gRPC)                          │    │
│  └─────────────────────────┬───────────────────────────────┘    │
│                            │                                     │
│  ┌─────────────┬───────────┴───────────┬─────────────┐          │
│  │             │                       │             │          │
│  ▼             ▼                       ▼             ▼          │
│  ┌─────┐   ┌─────────┐           ┌─────────┐   ┌─────────┐     │
│  │ Run │   │ Metric  │           │Artifact │   │ Query   │     │
│  │ Svc │   │ Service │           │ Service │   │ Service │     │
│  └──┬──┘   └────┬────┘           └────┬────┘   └────┬────┘     │
│     │           │                     │             │           │
│  ┌──▼───────────▼─────────────────────▼─────────────▼──┐       │
│  │                    Data Layer                        │       │
│  ├──────────────────────────────────────────────────────┤       │
│  │                                                      │       │
│  │  ┌──────────┐  ┌───────────┐  ┌──────────┐          │       │
│  │  │PostgreSQL│  │TimescaleDB│  │    S3    │          │       │
│  │  │(Metadata)│  │ (Metrics) │  │(Artifact)│          │       │
│  │  └──────────┘  └───────────┘  └──────────┘          │       │
│  │                                                      │       │
│  └──────────────────────────────────────────────────────┘       │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
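
The data layer in this diagram maps onto a set of Go abstractions that the service code below depends on but that are not defined in this chapter. A minimal sketch of what those interfaces might look like, assuming one repository per entity plus generic object storage, cache, and event publishing (the names follow the fields of ExperimentService; the exact shapes are an assumption, not the original definitions):

// storage_interfaces.go (assumed shapes, matching how ExperimentService uses them below)
package service

import (
    "context"
    "time"
)

// Entity repositories, backed by PostgreSQL in the architecture above.
type ProjectRepository interface {
    Create(ctx context.Context, p *Project) error
    Get(ctx context.Context, id string) (*Project, error)
}

type ExperimentRepository interface {
    Create(ctx context.Context, e *Experiment) error
    Get(ctx context.Context, id string) (*Experiment, error)
}

type RunRepository interface {
    Create(ctx context.Context, r *Run) error
    Get(ctx context.Context, id string) (*Run, error)
    Update(ctx context.Context, r *Run) error
    Search(ctx context.Context, q RunSearchQuery) ([]*Run, int64, error)
}

// MetricRepository is backed by TimescaleDB (see the metric storage section below).
type MetricRepository interface {
    BatchCreate(ctx context.Context, metrics []*Metric) error
    Query(ctx context.Context, q MetricQuery) ([]*Metric, error)
}

type ArtifactRepository interface {
    Create(ctx context.Context, a *Artifact) error
}

// ObjectStorage abstracts S3-compatible blob storage for artifacts.
type ObjectStorage interface {
    Upload(ctx context.Context, path string, data []byte) error
}

// Cache holds hot values such as the latest metric per run.
type Cache interface {
    Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
    Get(ctx context.Context, key string, out interface{}) error
}

// EventPublisher fans out lifecycle events such as "run.started" and "run.ended".
type EventPublisher interface {
    Publish(ctx context.Context, topic string, payload interface{}) error
}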

Service Implementation

// experiment_service.go
package service

import (
    "context"
    "crypto/sha256"
    "encoding/hex"
    "fmt"
    "time"
)

// ExperimentService 实验服务
type ExperimentService struct {
    // 存储
    projectRepo    ProjectRepository
    experimentRepo ExperimentRepository
    runRepo        RunRepository
    metricRepo     MetricRepository
    artifactRepo   ArtifactRepository

    // 存储后端
    objectStorage  ObjectStorage

    // 缓存
    cache          Cache

    // 事件发布
    eventPublisher EventPublisher
}

// NewExperimentService 创建服务
func NewExperimentService(
    projectRepo ProjectRepository,
    experimentRepo ExperimentRepository,
    runRepo RunRepository,
    metricRepo MetricRepository,
    artifactRepo ArtifactRepository,
    objectStorage ObjectStorage,
    cache Cache,
    eventPublisher EventPublisher,
) *ExperimentService {
    return &ExperimentService{
        projectRepo:    projectRepo,
        experimentRepo: experimentRepo,
        runRepo:        runRepo,
        metricRepo:     metricRepo,
        artifactRepo:   artifactRepo,
        objectStorage:  objectStorage,
        cache:          cache,
        eventPublisher: eventPublisher,
    }
}

// CreateProject 创建项目
func (s *ExperimentService) CreateProject(ctx context.Context, project *Project) error {
    project.ID = generateID("proj")
    project.CreatedAt = time.Now()
    project.UpdatedAt = time.Now()

    if err := s.projectRepo.Create(ctx, project); err != nil {
        return fmt.Errorf("create project: %w", err)
    }

    s.eventPublisher.Publish(ctx, "project.created", project)
    return nil
}

// CreateExperiment 创建实验
func (s *ExperimentService) CreateExperiment(ctx context.Context, experiment *Experiment) error {
    // 验证项目存在
    if _, err := s.projectRepo.Get(ctx, experiment.ProjectID); err != nil {
        return fmt.Errorf("project not found: %w", err)
    }

    experiment.ID = generateID("exp")
    experiment.Status = ExperimentStatusActive
    experiment.CreatedAt = time.Now()
    experiment.UpdatedAt = time.Now()

    if err := s.experimentRepo.Create(ctx, experiment); err != nil {
        return fmt.Errorf("create experiment: %w", err)
    }

    s.eventPublisher.Publish(ctx, "experiment.created", experiment)
    return nil
}

// CreateRun 创建运行
func (s *ExperimentService) CreateRun(ctx context.Context, run *Run) error {
    // 验证实验存在
    if _, err := s.experimentRepo.Get(ctx, run.ExperimentID); err != nil {
        return fmt.Errorf("experiment not found: %w", err)
    }

    run.ID = generateID("run")
    run.Status = RunStatusRunning
    run.CreatedAt = time.Now()

    if err := s.runRepo.Create(ctx, run); err != nil {
        return fmt.Errorf("create run: %w", err)
    }

    s.eventPublisher.Publish(ctx, "run.started", run)
    return nil
}

// LogParams 记录参数
func (s *ExperimentService) LogParams(ctx context.Context, runID string, params map[string]interface{}) error {
    run, err := s.runRepo.Get(ctx, runID)
    if err != nil {
        return fmt.Errorf("run not found: %w", err)
    }

    // 合并参数
    if run.Parameters.Flat == nil {
        run.Parameters.Flat = make(map[string]interface{})
    }
    for k, v := range params {
        run.Parameters.Flat[k] = v
    }

    return s.runRepo.Update(ctx, run)
}

// LogMetrics 记录指标
func (s *ExperimentService) LogMetrics(ctx context.Context, metrics []*Metric) error {
    if len(metrics) == 0 {
        return nil
    }

    // 批量写入
    if err := s.metricRepo.BatchCreate(ctx, metrics); err != nil {
        return fmt.Errorf("batch create metrics: %w", err)
    }

    // 更新缓存中的最新指标
    runID := metrics[0].RunID
    latestMetrics := make(map[string]*Metric)
    for _, m := range metrics {
        key := fmt.Sprintf("%s:%s", m.Key, m.Context)
        if existing, ok := latestMetrics[key]; !ok || m.Step > existing.Step {
            latestMetrics[key] = m
        }
    }

    for _, m := range latestMetrics {
        cacheKey := fmt.Sprintf("run:%s:metric:%s:%s:latest", runID, m.Key, m.Context)
        s.cache.Set(ctx, cacheKey, m, time.Hour)
    }

    return nil
}

// LogArtifact 记录产物
func (s *ExperimentService) LogArtifact(ctx context.Context, runID string, name string, data []byte, artifactType string) error {
    run, err := s.runRepo.Get(ctx, runID)
    if err != nil {
        return fmt.Errorf("run not found: %w", err)
    }

    // 上传到对象存储
    storagePath := fmt.Sprintf("runs/%s/artifacts/%s", runID, name)
    if err := s.objectStorage.Upload(ctx, storagePath, data); err != nil {
        return fmt.Errorf("upload artifact: %w", err)
    }

    // 计算校验和
    checksum := calculateChecksum(data)

    // 创建产物记录
    artifact := &Artifact{
        ID:        generateID("art"),
        RunID:     run.ID,
        Name:      name,
        Type:      artifactType,
        Path:      storagePath,
        Size:      int64(len(data)),
        Checksum:  checksum,
        CreatedAt: time.Now(),
    }

    if err := s.artifactRepo.Create(ctx, artifact); err != nil {
        return fmt.Errorf("create artifact record: %w", err)
    }

    return nil
}

// EndRun 结束运行
func (s *ExperimentService) EndRun(ctx context.Context, runID string, status RunStatus) error {
    run, err := s.runRepo.Get(ctx, runID)
    if err != nil {
        return fmt.Errorf("run not found: %w", err)
    }

    now := time.Now()
    run.Status = status
    run.EndTime = &now

    if run.StartTime != nil {
        run.Duration = int64(now.Sub(*run.StartTime).Seconds())
    }

    // 计算资源使用
    run.Resources = s.calculateResourceUsage(ctx, run)

    if err := s.runRepo.Update(ctx, run); err != nil {
        return fmt.Errorf("update run: %w", err)
    }

    s.eventPublisher.Publish(ctx, "run.ended", run)
    return nil
}

// GetRunMetrics 获取运行指标
func (s *ExperimentService) GetRunMetrics(
    ctx context.Context,
    runID string,
    keys []string,
    stepRange *StepRange,
) ([]*Metric, error) {
    return s.metricRepo.Query(ctx, MetricQuery{
        RunID:     runID,
        Keys:      keys,
        StepRange: stepRange,
    })
}

// CompareRuns 比较运行
func (s *ExperimentService) CompareRuns(
    ctx context.Context,
    runIDs []string,
    metricKeys []string,
) (*RunComparison, error) {
    comparison := &RunComparison{
        Runs:    make([]*Run, 0, len(runIDs)),
        Metrics: make(map[string]map[string]float64),
        Params:  make(map[string]map[string]interface{}),
    }

    for _, runID := range runIDs {
        run, err := s.runRepo.Get(ctx, runID)
        if err != nil {
            return nil, err
        }
        comparison.Runs = append(comparison.Runs, run)
        comparison.Params[runID] = run.Parameters.Flat

        // 获取最终指标
        for _, key := range metricKeys {
            cacheKey := fmt.Sprintf("run:%s:metric:%s:train:latest", runID, key)
            var metric *Metric
            if err := s.cache.Get(ctx, cacheKey, &metric); err == nil && metric != nil {
                if comparison.Metrics[key] == nil {
                    comparison.Metrics[key] = make(map[string]float64)
                }
                comparison.Metrics[key][runID] = metric.Value
            }
        }
    }

    return comparison, nil
}

// SearchRuns 搜索运行
func (s *ExperimentService) SearchRuns(
    ctx context.Context,
    query RunSearchQuery,
) ([]*Run, int64, error) {
    return s.runRepo.Search(ctx, query)
}

// calculateResourceUsage 计算资源使用
func (s *ExperimentService) calculateResourceUsage(ctx context.Context, run *Run) ResourceUsage {
    usage := ResourceUsage{}

    if run.StartTime == nil || run.EndTime == nil {
        return usage
    }

    duration := run.EndTime.Sub(*run.StartTime)
    hours := duration.Hours()

    // GPU 小时
    usage.GPUHours = float64(run.Environment.GPUCount) * hours

    // CPU 小时
    usage.CPUHours = float64(run.Environment.CPUCount) * hours

    // 内存 GB 小时
    usage.MemoryGBHours = float64(run.Environment.MemoryGB) * hours

    // 预估成本(示例价格)
    gpuPrice := 2.0  // $/GPU/hour
    cpuPrice := 0.05 // $/CPU/hour
    memPrice := 0.01 // $/GB/hour

    usage.Cost = usage.GPUHours*gpuPrice + usage.CPUHours*cpuPrice + usage.MemoryGBHours*memPrice

    return usage
}

// RunSearchQuery 运行搜索查询
type RunSearchQuery struct {
    ExperimentID string
    Status       []RunStatus
    Tags         map[string]string
    ParamFilter  map[string]interface{}
    MetricFilter *MetricFilter
    TimeRange    *TimeRange
    OrderBy      string
    Limit        int
    Offset       int
}

// MetricFilter 指标过滤
type MetricFilter struct {
    Key      string
    Operator string // ">", "<", ">=", "<=", "="
    Value    float64
}

// TimeRange 时间范围
type TimeRange struct {
    Start time.Time
    End   time.Time
}

// StepRange 步数范围
type StepRange struct {
    Start int64
    End   int64
}

// RunComparison 运行比较结果
type RunComparison struct {
    Runs    []*Run
    Metrics map[string]map[string]float64      // metric -> runID -> value
    Params  map[string]map[string]interface{} // runID -> params
}

func generateID(prefix string) string {
    return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}

func calculateChecksum(data []byte) string {
    sum := sha256.Sum256(data)
    return hex.EncodeToString(sum[:])
}
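
As a usage sketch (illustrative filter values; assumed to sit in the same package as ExperimentService), a query-side caller can chain SearchRuns and CompareRuns to shortlist completed runs and then compare their final metrics side by side:

// findAndCompare is illustrative only: shortlist completed runs of an experiment
// whose eval loss beat a threshold, then compare their final metrics.
func findAndCompare(ctx context.Context, svc *ExperimentService, experimentID string) (*RunComparison, error) {
    runs, total, err := svc.SearchRuns(ctx, RunSearchQuery{
        ExperimentID: experimentID,
        Status:       []RunStatus{RunStatusCompleted},
        MetricFilter: &MetricFilter{Key: "eval/loss", Operator: "<", Value: 2.0},
        OrderBy:      "created_at DESC",
        Limit:        10,
    })
    if err != nil {
        return nil, err
    }
    _ = total // total match count, useful for pagination

    runIDs := make([]string, 0, len(runs))
    for _, r := range runs {
        runIDs = append(runIDs, r.ID)
    }
    return svc.CompareRuns(ctx, runIDs, []string{"eval/loss", "eval/perplexity"})
}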

Metric Storage Optimization

// metric_repository.go
package repository

import (
    "context"
    "fmt"
    "strings"
    "time"

    "github.com/jackc/pgx/v5/pgxpool"
)

// TimescaleMetricRepository 使用 TimescaleDB 的指标存储
type TimescaleMetricRepository struct {
    pool *pgxpool.Pool
}

func NewTimescaleMetricRepository(pool *pgxpool.Pool) *TimescaleMetricRepository {
    return &TimescaleMetricRepository{pool: pool}
}

// 初始化表结构
func (r *TimescaleMetricRepository) Init(ctx context.Context) error {
    // 创建 hypertable
    queries := []string{
        `CREATE TABLE IF NOT EXISTS metrics (
            run_id TEXT NOT NULL,
            key TEXT NOT NULL,
            value DOUBLE PRECISION NOT NULL,
            step BIGINT NOT NULL,
            context TEXT DEFAULT 'train',
            timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )`,

        // 转换为 hypertable(TimescaleDB 特性)
        `SELECT create_hypertable('metrics', 'timestamp',
            if_not_exists => TRUE,
            chunk_time_interval => INTERVAL '1 day'
        )`,

        // 创建索引
        `CREATE INDEX IF NOT EXISTS idx_metrics_run_id ON metrics (run_id)`,
        `CREATE INDEX IF NOT EXISTS idx_metrics_run_key ON metrics (run_id, key)`,
        `CREATE INDEX IF NOT EXISTS idx_metrics_run_key_step ON metrics (run_id, key, step DESC)`,

        // 启用压缩(旧数据压缩)
        `ALTER TABLE metrics SET (
            timescaledb.compress,
            timescaledb.compress_segmentby = 'run_id,key'
        )`,

        // 自动压缩策略(7天后压缩)
        `SELECT add_compression_policy('metrics', INTERVAL '7 days', if_not_exists => TRUE)`,

        // 数据保留策略(90天后删除)
        `SELECT add_retention_policy('metrics', INTERVAL '90 days', if_not_exists => TRUE)`,
    }

    for _, q := range queries {
        if _, err := r.pool.Exec(ctx, q); err != nil {
            // 忽略某些预期错误
            if !strings.Contains(err.Error(), "already exists") {
                return fmt.Errorf("init query failed: %w", err)
            }
        }
    }

    return nil
}

// BatchCreate 批量创建指标
func (r *TimescaleMetricRepository) BatchCreate(ctx context.Context, metrics []*Metric) error {
    if len(metrics) == 0 {
        return nil
    }

    // Multi-row INSERT. Note: the PostgreSQL protocol allows at most 65535 bind parameters,
    // i.e. roughly 10k metrics per call here; for larger batches see the COPY sketch below.
    query := `INSERT INTO metrics (run_id, key, value, step, context, timestamp) VALUES `
    values := make([]interface{}, 0, len(metrics)*6)
    placeholders := make([]string, 0, len(metrics))

    for i, m := range metrics {
        base := i * 6
        placeholders = append(placeholders,
            fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d)",
                base+1, base+2, base+3, base+4, base+5, base+6))
        values = append(values, m.RunID, m.Key, m.Value, m.Step, m.Context, m.Timestamp)
    }

    query += strings.Join(placeholders, ", ")

    _, err := r.pool.Exec(ctx, query, values...)
    return err
}
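
// BatchCreateCopy is a hedged alternative sketch (not part of the original listing): for
// batches that would exceed the bind-parameter limit of the multi-row INSERT above, pgx's
// COPY support streams rows over the PostgreSQL COPY protocol instead. It assumes the same
// metrics table and requires importing "github.com/jackc/pgx/v5".
func (r *TimescaleMetricRepository) BatchCreateCopy(ctx context.Context, metrics []*Metric) error {
    rows := make([][]interface{}, 0, len(metrics))
    for _, m := range metrics {
        rows = append(rows, []interface{}{m.RunID, m.Key, m.Value, m.Step, m.Context, m.Timestamp})
    }

    _, err := r.pool.CopyFrom(ctx,
        pgx.Identifier{"metrics"},
        []string{"run_id", "key", "value", "step", "context", "timestamp"},
        pgx.CopyFromRows(rows),
    )
    return err
}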

// Query 查询指标
func (r *TimescaleMetricRepository) Query(ctx context.Context, query MetricQuery) ([]*Metric, error) {
    var conditions []string
    var args []interface{}
    argIdx := 1

    conditions = append(conditions, fmt.Sprintf("run_id = $%d", argIdx))
    args = append(args, query.RunID)
    argIdx++

    if len(query.Keys) > 0 {
        placeholders := make([]string, len(query.Keys))
        for i, key := range query.Keys {
            placeholders[i] = fmt.Sprintf("$%d", argIdx)
            args = append(args, key)
            argIdx++
        }
        conditions = append(conditions, fmt.Sprintf("key IN (%s)", strings.Join(placeholders, ",")))
    }

    if query.StepRange != nil {
        conditions = append(conditions, fmt.Sprintf("step >= $%d AND step <= $%d", argIdx, argIdx+1))
        args = append(args, query.StepRange.Start, query.StepRange.End)
        argIdx += 2
    }

    sql := fmt.Sprintf(`
        SELECT run_id, key, value, step, context, timestamp
        FROM metrics
        WHERE %s
        ORDER BY step ASC
    `, strings.Join(conditions, " AND "))

    rows, err := r.pool.Query(ctx, sql, args...)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    var metrics []*Metric
    for rows.Next() {
        m := &Metric{}
        if err := rows.Scan(&m.RunID, &m.Key, &m.Value, &m.Step, &m.Context, &m.Timestamp); err != nil {
            return nil, err
        }
        metrics = append(metrics, m)
    }

    return metrics, nil
}

// GetAggregated 获取聚合指标
func (r *TimescaleMetricRepository) GetAggregated(
    ctx context.Context,
    runID string,
    key string,
    aggregation string, // min, max, avg, last
    bucketSize int64, // 步数间隔
) ([]*AggregatedMetric, error) {

    var aggFunc string
    switch aggregation {
    case "min":
        aggFunc = "MIN(value)"
    case "max":
        aggFunc = "MAX(value)"
    case "avg":
        aggFunc = "AVG(value)"
    case "last":
        aggFunc = "last(value, step)"
    default:
        aggFunc = "AVG(value)"
    }

    sql := fmt.Sprintf(`
        SELECT
            (step / $3) * $3 as bucket_step,
            %s as value,
            COUNT(*) as count
        FROM metrics
        WHERE run_id = $1 AND key = $2
        GROUP BY bucket_step
        ORDER BY bucket_step ASC
    `, aggFunc)

    rows, err := r.pool.Query(ctx, sql, runID, key, bucketSize)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    var results []*AggregatedMetric
    for rows.Next() {
        m := &AggregatedMetric{}
        if err := rows.Scan(&m.Step, &m.Value, &m.Count); err != nil {
            return nil, err
        }
        results = append(results, m)
    }

    return results, nil
}

// AggregatedMetric 聚合指标
type AggregatedMetric struct {
    Step  int64
    Value float64
    Count int64
}

// MetricQuery 指标查询
type MetricQuery struct {
    RunID     string
    Keys      []string
    StepRange *StepRange
    Context   string
}
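
On the read side, a typical consumer of GetAggregated is a charting endpoint that downsamples a long training curve to a fixed number of points. A hedged sketch (the target point count is illustrative; assumed to live alongside the repository):

// lossCurveForChart is illustrative only: downsample a run's loss curve to roughly
// 1000 buckets so the UI does not have to render hundreds of thousands of raw points.
func lossCurveForChart(ctx context.Context, repo *TimescaleMetricRepository, runID string, totalSteps int64) ([]*AggregatedMetric, error) {
    const targetPoints = 1000
    bucket := totalSteps / targetPoints
    if bucket < 1 {
        bucket = 1 // already fewer points than the target: keep raw resolution
    }
    // "avg" smooths noisy per-step losses; "last" would keep the final value in each bucket.
    return repo.GetAggregated(ctx, runID, "train/loss", "avg", bucket)
}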

Experiment Comparison and Analysis

Comparison and Analysis Service

// analysis_service.go
package service

import (
    "context"
    "math"
    "sort"
)

// AnalysisService 分析服务
type AnalysisService struct {
    experimentService *ExperimentService
    metricRepo        MetricRepository
}

// CompareExperiments 比较实验
func (s *AnalysisService) CompareExperiments(
    ctx context.Context,
    experimentIDs []string,
    metricKeys []string,
) (*ExperimentComparison, error) {

    comparison := &ExperimentComparison{
        Experiments: make([]*ExperimentSummary, 0, len(experimentIDs)),
    }

    for _, expID := range experimentIDs {
        summary, err := s.getExperimentSummary(ctx, expID, metricKeys)
        if err != nil {
            return nil, err
        }
        comparison.Experiments = append(comparison.Experiments, summary)
    }

    // 计算统计信息
    comparison.Statistics = s.calculateComparisonStatistics(comparison.Experiments, metricKeys)

    return comparison, nil
}

// ExperimentComparison 实验比较结果
type ExperimentComparison struct {
    Experiments []*ExperimentSummary
    Statistics  map[string]*MetricStatistics
}

// ExperimentSummary 实验摘要
type ExperimentSummary struct {
    Experiment *Experiment
    BestRun    *Run
    RunCount   int
    Metrics    map[string]*MetricSummary
}

// MetricSummary 指标摘要
type MetricSummary struct {
    Best     float64
    Worst    float64
    Mean     float64
    Std      float64
    Median   float64
    BestRun  string
}

// MetricStatistics 指标统计
type MetricStatistics struct {
    BestExperiment  string
    BestValue       float64
    Improvement     float64 // 相对于基准的提升
    Significance    float64 // 统计显著性
}

func (s *AnalysisService) getExperimentSummary(
    ctx context.Context,
    experimentID string,
    metricKeys []string,
) (*ExperimentSummary, error) {

    // 获取实验
    exp, err := s.experimentService.experimentRepo.Get(ctx, experimentID)
    if err != nil {
        return nil, err
    }

    // 获取所有运行
    runs, _, err := s.experimentService.SearchRuns(ctx, RunSearchQuery{
        ExperimentID: experimentID,
        Status:       []RunStatus{RunStatusCompleted},
    })
    if err != nil {
        return nil, err
    }

    summary := &ExperimentSummary{
        Experiment: exp,
        RunCount:   len(runs),
        Metrics:    make(map[string]*MetricSummary),
    }

    // 计算每个指标的统计信息
    for _, key := range metricKeys {
        values := make([]float64, 0, len(runs))
        var bestRun *Run
        var bestValue float64 = math.Inf(-1)

        for _, run := range runs {
            metrics, err := s.metricRepo.Query(ctx, MetricQuery{
                RunID: run.ID,
                Keys:  []string{key},
            })
            if err != nil || len(metrics) == 0 {
                continue
            }

            // 取最后一个值
            lastMetric := metrics[len(metrics)-1]
            values = append(values, lastMetric.Value)

            if lastMetric.Value > bestValue {
                bestValue = lastMetric.Value
                bestRun = run
            }
        }

        if len(values) > 0 {
            summary.Metrics[key] = &MetricSummary{
                Best:    maxFloat(values),
                Worst:   minFloat(values),
                Mean:    meanFloat(values),
                Std:     stdFloat(values),
                Median:  medianFloat(values),
                BestRun: bestRun.ID,
            }

            // Pick BestRun by the first (primary) metric only; assumes higher is better.
            if key == metricKeys[0] {
                summary.BestRun = bestRun
            }
        }
    }

    return summary, nil
}

func (s *AnalysisService) calculateComparisonStatistics(
    experiments []*ExperimentSummary,
    metricKeys []string,
) map[string]*MetricStatistics {

    stats := make(map[string]*MetricStatistics)

    for _, key := range metricKeys {
        var bestExp *ExperimentSummary
        var bestValue float64 = math.Inf(-1)

        for _, exp := range experiments {
            if m, ok := exp.Metrics[key]; ok && m.Best > bestValue {
                bestValue = m.Best
                bestExp = exp
            }
        }

        if bestExp != nil {
            stats[key] = &MetricStatistics{
                BestExperiment: bestExp.Experiment.ID,
                BestValue:      bestValue,
            }
        }
    }

    return stats
}
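
calculateComparisonStatistics 目前只填充了 BestExperiment 和 BestValue,MetricStatistics 中的 Improvement 与 Significance 仍为空。下面给出补充 Improvement 的一种示意写法:把传入的第一个实验当作基准,计算最优值相对基准的相对提升;Significance 则需要真正的统计检验(例如对两组运行的指标值做 Welch t 检验),这里只留 TODO。"以第一个实验为基准"与"指标越大越好"都是此处的假设,实际应由调用方显式指定:

// fillImprovement 以 experiments[0] 作为基准实验,为每个指标补充相对提升(示意实现)
func fillImprovement(stats map[string]*MetricStatistics, experiments []*ExperimentSummary, metricKeys []string) {
    if len(experiments) == 0 {
        return
    }
    baseline := experiments[0]

    for _, key := range metricKeys {
        st, ok := stats[key]
        if !ok {
            continue
        }
        base, ok := baseline.Metrics[key]
        if !ok || base.Best == 0 {
            continue
        }
        // 相对提升 = (最优值 - 基准最优值) / |基准最优值|
        st.Improvement = (st.BestValue - base.Best) / math.Abs(base.Best)
        // TODO: Significance 需要基于各实验的运行分布做统计检验,这里暂不实现
    }
}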

// ParameterImportance 参数重要性分析
func (s *AnalysisService) ParameterImportance(
    ctx context.Context,
    experimentID string,
    targetMetric string,
) (*ParameterImportanceResult, error) {

    // 获取所有完成的运行
    runs, _, err := s.experimentService.SearchRuns(ctx, RunSearchQuery{
        ExperimentID: experimentID,
        Status:       []RunStatus{RunStatusCompleted},
    })
    if err != nil {
        return nil, err
    }

    // 收集参数和指标(dataPoint 为包级类型,定义见 calculateParameterCorrelation 之前;
    // 函数内的局部类型无法出现在其他函数的签名中)

    var data []dataPoint
    allParams := make(map[string]bool)

    for _, run := range runs {
        metrics, err := s.metricRepo.Query(ctx, MetricQuery{
            RunID: run.ID,
            Keys:  []string{targetMetric},
        })
        if err != nil || len(metrics) == 0 {
            continue
        }

        dp := dataPoint{
            params: run.Parameters.Flat,
            metric: metrics[len(metrics)-1].Value,
        }
        data = append(data, dp)

        for k := range run.Parameters.Flat {
            allParams[k] = true
        }
    }

    // 计算每个参数的重要性(使用简单的相关性分析)
    result := &ParameterImportanceResult{
        Importance: make(map[string]float64),
    }

    for param := range allParams {
        importance := s.calculateParameterCorrelation(data, param)
        result.Importance[param] = importance
    }

    // 排序
    type paramScore struct {
        name  string
        score float64
    }
    var scores []paramScore
    for name, score := range result.Importance {
        scores = append(scores, paramScore{name, score})
    }
    sort.Slice(scores, func(i, j int) bool {
        return math.Abs(scores[i].score) > math.Abs(scores[j].score)
    })

    result.Ranking = make([]string, len(scores))
    for i, s := range scores {
        result.Ranking[i] = s.name
    }

    return result, nil
}

// ParameterImportanceResult 参数重要性结果
type ParameterImportanceResult struct {
    Importance map[string]float64
    Ranking    []string
}

// dataPoint 参数重要性分析用的样本:一组参数取值及其对应的目标指标值
// (定义为包级类型,便于在 ParameterImportance 与 calculateParameterCorrelation 之间共享)
type dataPoint struct {
    params map[string]interface{}
    metric float64
}

func (s *AnalysisService) calculateParameterCorrelation(
    data []dataPoint,
    param string,
) float64 {
    // 简化的相关性计算
    // 实际应使用更复杂的方法(如 Spearman 相关系数)

    var paramValues, metricValues []float64

    for _, dp := range data {
        if v, ok := dp.params[param]; ok {
            // 尝试转换为数值
            switch val := v.(type) {
            case float64:
                paramValues = append(paramValues, val)
                metricValues = append(metricValues, dp.metric)
            case int:
                paramValues = append(paramValues, float64(val))
                metricValues = append(metricValues, dp.metric)
            }
        }
    }

    if len(paramValues) < 3 {
        return 0
    }

    return pearsonCorrelation(paramValues, metricValues)
}

// 辅助统计函数
func maxFloat(values []float64) float64 {
    if len(values) == 0 {
        return 0
    }
    max := values[0]
    for _, v := range values[1:] {
        if v > max {
            max = v
        }
    }
    return max
}

func minFloat(values []float64) float64 {
    if len(values) == 0 {
        return 0
    }
    min := values[0]
    for _, v := range values[1:] {
        if v < min {
            min = v
        }
    }
    return min
}

func meanFloat(values []float64) float64 {
    if len(values) == 0 {
        return 0
    }
    sum := 0.0
    for _, v := range values {
        sum += v
    }
    return sum / float64(len(values))
}

func stdFloat(values []float64) float64 {
    if len(values) < 2 {
        return 0
    }
    mean := meanFloat(values)
    sumSq := 0.0
    for _, v := range values {
        sumSq += (v - mean) * (v - mean)
    }
    return math.Sqrt(sumSq / float64(len(values)-1))
}

func medianFloat(values []float64) float64 {
    if len(values) == 0 {
        return 0
    }
    sorted := make([]float64, len(values))
    copy(sorted, values)
    sort.Float64s(sorted)

    mid := len(sorted) / 2
    if len(sorted)%2 == 0 {
        return (sorted[mid-1] + sorted[mid]) / 2
    }
    return sorted[mid]
}

func pearsonCorrelation(x, y []float64) float64 {
    if len(x) != len(y) || len(x) == 0 {
        return 0
    }

    n := float64(len(x))
    sumX, sumY, sumXY, sumX2, sumY2 := 0.0, 0.0, 0.0, 0.0, 0.0

    for i := range x {
        sumX += x[i]
        sumY += y[i]
        sumXY += x[i] * y[i]
        sumX2 += x[i] * x[i]
        sumY2 += y[i] * y[i]
    }

    numerator := n*sumXY - sumX*sumY
    denominator := math.Sqrt((n*sumX2 - sumX*sumX) * (n*sumY2 - sumY*sumY))

    if denominator == 0 {
        return 0
    }

    return numerator / denominator
}
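
calculateParameterCorrelation 的注释中提到实际应使用 Spearman 相关系数:它只依赖单调关系,对离群值更稳健。一个简单的实现思路是先把两组数据分别替换为各自的秩(rank),再对秩计算 Pearson 相关系数。下面是一个示意实现,未对并列值(ties)做平均秩处理:

// spearmanCorrelation 通过"先转秩、再算 Pearson"的方式近似计算 Spearman 相关系数(示意实现)
func spearmanCorrelation(x, y []float64) float64 {
    if len(x) != len(y) || len(x) == 0 {
        return 0
    }
    return pearsonCorrelation(toRanks(x), toRanks(y))
}

// toRanks 把数值序列转换为秩序列(最小值的秩为 1),未处理并列值
func toRanks(values []float64) []float64 {
    idx := make([]int, len(values))
    for i := range idx {
        idx[i] = i
    }
    sort.Slice(idx, func(a, b int) bool {
        return values[idx[a]] < values[idx[b]]
    })

    ranks := make([]float64, len(values))
    for r, i := range idx {
        ranks[i] = float64(r + 1)
    }
    return ranks
}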

可视化组件

指标图表服务

// visualization_service.go
package service

import (
    "context"
)

// VisualizationService 可视化服务
type VisualizationService struct {
    metricRepo MetricRepository
    runRepo    RunRepository
}

// ChartData 图表数据
type ChartData struct {
    Type    string                 `json:"type"` // line, scatter, bar, heatmap
    Title   string                 `json:"title"`
    XLabel  string                 `json:"x_label"`
    YLabel  string                 `json:"y_label"`
    Series  []ChartSeries          `json:"series"`
    Options map[string]interface{} `json:"options"`
}

// ChartSeries 图表系列
type ChartSeries struct {
    Name   string      `json:"name"`
    Data   [][]float64 `json:"data"` // [[x, y], ...]
    Color  string      `json:"color,omitempty"`
    Style  string      `json:"style,omitempty"` // solid, dashed
}

// GetMetricChart 获取指标图表
func (s *VisualizationService) GetMetricChart(
    ctx context.Context,
    runIDs []string,
    metricKey string,
    chartType string,
) (*ChartData, error) {

    chart := &ChartData{
        Type:   chartType,
        Title:  metricKey,
        XLabel: "Step",
        YLabel: metricKey,
        Series: make([]ChartSeries, 0, len(runIDs)),
    }

    colors := []string{"#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"}

    for i, runID := range runIDs {
        run, err := s.runRepo.Get(ctx, runID)
        if err != nil {
            continue
        }

        metrics, err := s.metricRepo.Query(ctx, MetricQuery{
            RunID: runID,
            Keys:  []string{metricKey},
        })
        if err != nil {
            continue
        }

        series := ChartSeries{
            Name:  run.Name,
            Data:  make([][]float64, 0, len(metrics)),
            Color: colors[i%len(colors)],
        }

        for _, m := range metrics {
            series.Data = append(series.Data, []float64{float64(m.Step), m.Value})
        }

        chart.Series = append(chart.Series, series)
    }

    return chart, nil
}

// GetParallelCoordinatesData 获取平行坐标图数据
func (s *VisualizationService) GetParallelCoordinatesData(
    ctx context.Context,
    experimentID string,
    paramKeys []string,
    metricKey string,
) (*ParallelCoordinatesData, error) {

    runs, _, err := s.runRepo.Search(ctx, RunSearchQuery{
        ExperimentID: experimentID,
        Status:       []RunStatus{RunStatusCompleted},
    })
    if err != nil {
        return nil, err
    }

    data := &ParallelCoordinatesData{
        Dimensions: make([]Dimension, 0, len(paramKeys)+1),
        Data:       make([][]interface{}, 0, len(runs)),
    }

    // 添加参数维度
    for _, key := range paramKeys {
        data.Dimensions = append(data.Dimensions, Dimension{
            Name:  key,
            Type:  "param",
            Range: s.calculateRange(runs, key),
        })
    }

    // 添加指标维度
    data.Dimensions = append(data.Dimensions, Dimension{
        Name: metricKey,
        Type: "metric",
    })

    // 填充数据
    for _, run := range runs {
        row := make([]interface{}, 0, len(paramKeys)+1)

        for _, key := range paramKeys {
            if v, ok := run.Parameters.Flat[key]; ok {
                row = append(row, v)
            } else {
                row = append(row, nil)
            }
        }

        // 获取指标值
        metrics, _ := s.metricRepo.Query(ctx, MetricQuery{
            RunID: run.ID,
            Keys:  []string{metricKey},
        })
        if len(metrics) > 0 {
            row = append(row, metrics[len(metrics)-1].Value)
        } else {
            row = append(row, nil)
        }

        data.Data = append(data.Data, row)
    }

    return data, nil
}

// ParallelCoordinatesData 平行坐标图数据
type ParallelCoordinatesData struct {
    Dimensions []Dimension       `json:"dimensions"`
    Data       [][]interface{}   `json:"data"`
}

// Dimension 维度定义
type Dimension struct {
    Name  string    `json:"name"`
    Type  string    `json:"type"` // param, metric
    Range []float64 `json:"range,omitempty"`
}

func (s *VisualizationService) calculateRange(runs []*Run, paramKey string) []float64 {
    var min, max float64
    found := false
    for _, run := range runs {
        v, ok := run.Parameters.Flat[paramKey]
        if !ok {
            continue
        }
        // 复用 toFloat64,同时覆盖 float64 / int / int64 类型的参数
        val, numeric := toFloat64(v)
        if !numeric {
            continue
        }
        if !found {
            min, max = val, val
            found = true
            continue
        }
        if val < min {
            min = val
        }
        if val > max {
            max = val
        }
    }
    if !found {
        // 没有数值型取值时不返回范围,由前端自行推断
        return nil
    }
    return []float64{min, max}
}

// GetContourPlotData 获取等高线图数据(超参数与指标关系)
func (s *VisualizationService) GetContourPlotData(
    ctx context.Context,
    experimentID string,
    paramX, paramY string,
    metricKey string,
) (*ContourPlotData, error) {

    runs, _, err := s.runRepo.Search(ctx, RunSearchQuery{
        ExperimentID: experimentID,
        Status:       []RunStatus{RunStatusCompleted},
    })
    if err != nil {
        return nil, err
    }

    data := &ContourPlotData{
        XLabel: paramX,
        YLabel: paramY,
        ZLabel: metricKey,
        Points: make([]ContourPoint, 0, len(runs)),
    }

    for _, run := range runs {
        xVal, xOK := run.Parameters.Flat[paramX]
        yVal, yOK := run.Parameters.Flat[paramY]

        if !xOK || !yOK {
            continue
        }

        metrics, _ := s.metricRepo.Query(ctx, MetricQuery{
            RunID: run.ID,
            Keys:  []string{metricKey},
        })
        if len(metrics) == 0 {
            continue
        }

        x, _ := toFloat64(xVal)
        y, _ := toFloat64(yVal)
        z := metrics[len(metrics)-1].Value

        data.Points = append(data.Points, ContourPoint{
            X:     x,
            Y:     y,
            Z:     z,
            RunID: run.ID,
        })
    }

    return data, nil
}
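
GetContourPlotData 返回的是离散的 (x, y, z) 散点,前端绘制等高线前通常还要把散点聚合成规则网格(由前端插值或后端分箱均可)。下面是一个后端按网格分箱取均值的示意实现(需要额外 import "math";网格大小由调用方指定,空格子用 NaN 标记):

// binToGrid 把散点聚合为 ny 行 × nx 列的网格,每个格子取 Z 的均值,空格子为 NaN(示意实现)
func binToGrid(points []ContourPoint, nx, ny int) [][]float64 {
    if len(points) == 0 || nx <= 0 || ny <= 0 {
        return nil
    }

    // 计算散点的坐标范围
    minX, maxX := points[0].X, points[0].X
    minY, maxY := points[0].Y, points[0].Y
    for _, p := range points[1:] {
        minX, maxX = math.Min(minX, p.X), math.Max(maxX, p.X)
        minY, maxY = math.Min(minY, p.Y), math.Max(maxY, p.Y)
    }

    sum := make([][]float64, ny)
    cnt := make([][]int, ny)
    for j := range sum {
        sum[j] = make([]float64, nx)
        cnt[j] = make([]int, nx)
    }

    for _, p := range points {
        i := gridIndex(p.X, minX, maxX, nx)
        j := gridIndex(p.Y, minY, maxY, ny)
        sum[j][i] += p.Z
        cnt[j][i]++
    }

    grid := make([][]float64, ny)
    for j := range grid {
        grid[j] = make([]float64, nx)
        for i := range grid[j] {
            if cnt[j][i] == 0 {
                grid[j][i] = math.NaN()
            } else {
                grid[j][i] = sum[j][i] / float64(cnt[j][i])
            }
        }
    }
    return grid
}

// gridIndex 把坐标映射到 [0, n) 的格子下标
func gridIndex(v, min, max float64, n int) int {
    if max == min {
        return 0
    }
    i := int((v - min) / (max - min) * float64(n))
    if i >= n {
        i = n - 1
    }
    return i
}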

// ContourPlotData 等高线图数据
type ContourPlotData struct {
    XLabel string         `json:"x_label"`
    YLabel string         `json:"y_label"`
    ZLabel string         `json:"z_label"`
    Points []ContourPoint `json:"points"`
}

// ContourPoint 数据点
type ContourPoint struct {
    X     float64 `json:"x"`
    Y     float64 `json:"y"`
    Z     float64 `json:"z"`
    RunID string  `json:"run_id"`
}

func toFloat64(v interface{}) (float64, bool) {
    switch val := v.(type) {
    case float64:
        return val, true
    case int:
        return float64(val), true
    case int64:
        return float64(val), true
    default:
        return 0, false
    }
}
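
ChartData 等结构上的 json 标签是为了直接序列化后交给前端图表库渲染。下面是一个对接 HTTP 接口的示意写法(单独放在一个假设的 visualization_handler.go 中;查询参数的约定与路由注册方式均为假设,按实际 Web 框架调整):

// visualization_handler.go(示意)
package service

import (
    "encoding/json"
    "net/http"
)

// HandleMetricChart 解析查询参数,调用 GetMetricChart 并以 JSON 返回图表数据
func (s *VisualizationService) HandleMetricChart(w http.ResponseWriter, r *http.Request) {
    q := r.URL.Query()
    runIDs := q["run_id"]        // ?run_id=a&run_id=b 可传多个
    metricKey := q.Get("metric") // 例如 eval_loss
    chartType := q.Get("type")
    if chartType == "" {
        chartType = "line"
    }

    chart, err := s.GetMetricChart(r.Context(), runIDs, metricKey, chartType)
    if err != nil {
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    _ = json.NewEncoder(w).Encode(chart)
}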

Kubernetes 集成

实验 CRD

# experiment-crd.yaml
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  name: experiments.ai.platform.io
spec:
  group: ai.platform.io
  versions:
  - name: v1
    served: true
    storage: true
    schema:
      openAPIV3Schema:
        type: object
        properties:
          spec:
            type: object
            properties:
              projectRef:
                type: string
              name:
                type: string
              description:
                type: string
              hypothesis:
                type: string
              template:
                type: object
                properties:
                  image:
                    type: string
                  command:
                    type: array
                    items:
                      type: string
                  resources:
                    type: object
                    properties:
                      gpus:
                        type: integer
                      memory:
                        type: string
                      cpu:
                        type: string
                  env:
                    type: array
                    items:
                      type: object
                      properties:
                        name:
                          type: string
                        value:
                          type: string
              runs:
                type: array
                items:
                  type: object
                  properties:
                    name:
                      type: string
                    parameters:
                      type: object
                      x-kubernetes-preserve-unknown-fields: true
          status:
            type: object
            properties:
              phase:
                type: string
              activeRuns:
                type: integer
              completedRuns:
                type: integer
              failedRuns:
                type: integer
              bestRun:
                type: object
                properties:
                  name:
                    type: string
                  metric:
                    type: number
    subresources:
      status: {}
  scope: Namespaced
  names:
    plural: experiments
    singular: experiment
    kind: Experiment
    shortNames:
    - exp

---
# 示例实验
apiVersion: ai.platform.io/v1
kind: Experiment
metadata:
  name: lr-sweep-llm-7b
  namespace: ai-experiments
spec:
  projectRef: llm-pretraining
  name: Learning Rate Sweep
  description: "探索不同学习率对 LLM-7B 训练效果的影响"
  hypothesis: "较小的学习率(1e-5)可能比较大的学习率(1e-4)获得更好的验证损失"

  template:
    image: training:v1.0
    command:
    - python
    - train.py
    - --config=/config/training.yaml
    resources:
      gpus: 8
      memory: "128Gi"
      cpu: "32"
    env:
    - name: EXPERIMENT_TRACKING_URI
      value: "http://experiment-server:8080"

  runs:
  - name: lr-1e-5
    parameters:
      learning_rate: 0.00001
      warmup_steps: 1000
  - name: lr-5e-5
    parameters:
      learning_rate: 0.00005
      warmup_steps: 1000
  - name: lr-1e-4
    parameters:
      learning_rate: 0.0001
      warmup_steps: 2000
  - name: lr-5e-4
    parameters:
      learning_rate: 0.0005
      warmup_steps: 2000

实验控制器

// experiment_controller.go
package controller

import (
    "context"
    "fmt"

    batchv1 "k8s.io/api/batch/v1"
    corev1 "k8s.io/api/core/v1"
    apierrors "k8s.io/apimachinery/pkg/api/errors"
    "k8s.io/apimachinery/pkg/api/resource"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
    ctrl "sigs.k8s.io/controller-runtime"
    "sigs.k8s.io/controller-runtime/pkg/client"
    "sigs.k8s.io/controller-runtime/pkg/log"

    aiv1 "ai.platform.io/api/v1"
)

// ExperimentReconciler 实验协调器
type ExperimentReconciler struct {
    client.Client
    Scheme *runtime.Scheme
}

// Reconcile 协调实验状态
func (r *ExperimentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
    logger := log.FromContext(ctx)

    // 获取实验
    var experiment aiv1.Experiment
    if err := r.Get(ctx, req.NamespacedName, &experiment); err != nil {
        return ctrl.Result{}, client.IgnoreNotFound(err)
    }

    // 处理每个运行配置
    for _, runConfig := range experiment.Spec.Runs {
        // 检查是否已创建
        jobName := fmt.Sprintf("%s-%s", experiment.Name, runConfig.Name)

        var existingJob batchv1.Job
        err := r.Get(ctx, client.ObjectKey{
            Namespace: experiment.Namespace,
            Name:      jobName,
        }, &existingJob)

        if err == nil {
            // Job 已存在,更新状态
            r.updateRunStatus(ctx, &experiment, runConfig.Name, &existingJob)
            continue
        }
        if !apierrors.IsNotFound(err) {
            // 非 NotFound 错误(如 API Server 暂时不可用)应返回并等待重试,而不是盲目创建
            return ctrl.Result{}, err
        }

        // 创建新 Job
        job := r.buildJob(&experiment, runConfig)
        if err := r.Create(ctx, job); err != nil {
            logger.Error(err, "Failed to create job", "job", jobName)
            return ctrl.Result{}, err
        }

        logger.Info("Created job", "job", jobName)
    }

    // 更新实验状态
    r.updateExperimentStatus(ctx, &experiment)

    return ctrl.Result{}, nil
}

// buildJob 构建 Job
func (r *ExperimentReconciler) buildJob(experiment *aiv1.Experiment, runConfig aiv1.RunConfig) *batchv1.Job {
    jobName := fmt.Sprintf("%s-%s", experiment.Name, runConfig.Name)

    // 构建环境变量
    env := experiment.Spec.Template.Env
    env = append(env,
        corev1.EnvVar{
            Name:  "RUN_NAME",
            Value: runConfig.Name,
        },
        corev1.EnvVar{
            Name:  "EXPERIMENT_ID",
            Value: experiment.Name,
        },
    )

    // 添加参数作为环境变量
    for key, value := range runConfig.Parameters {
        env = append(env, corev1.EnvVar{
            Name:  fmt.Sprintf("PARAM_%s", key),
            Value: fmt.Sprintf("%v", value),
        })
    }

    job := &batchv1.Job{
        ObjectMeta: metav1.ObjectMeta{
            Name:      jobName,
            Namespace: experiment.Namespace,
            Labels: map[string]string{
                "experiment":    experiment.Name,
                "run":          runConfig.Name,
                "ai.platform.io/type": "experiment-run",
            },
            OwnerReferences: []metav1.OwnerReference{
                *metav1.NewControllerRef(experiment, aiv1.GroupVersion.WithKind("Experiment")),
            },
        },
        Spec: batchv1.JobSpec{
            Template: corev1.PodTemplateSpec{
                Spec: corev1.PodSpec{
                    RestartPolicy: corev1.RestartPolicyNever,
                    Containers: []corev1.Container{
                        {
                            Name:    "training",
                            Image:   experiment.Spec.Template.Image,
                            Command: experiment.Spec.Template.Command,
                            Env:     env,
                            Resources: corev1.ResourceRequirements{
                                Limits: corev1.ResourceList{
                                    // CRD 中 gpus 是整数,memory/cpu 是字符串,分别用对应的构造方式
                                    "nvidia.com/gpu":      *resource.NewQuantity(int64(experiment.Spec.Template.Resources.GPUs), resource.DecimalSI),
                                    corev1.ResourceMemory: resource.MustParse(experiment.Spec.Template.Resources.Memory),
                                    corev1.ResourceCPU:    resource.MustParse(experiment.Spec.Template.Resources.CPU),
                                },
                            },
                        },
                    },
                },
            },
        },
    }

    return job
}

func (r *ExperimentReconciler) updateRunStatus(ctx context.Context, experiment *aiv1.Experiment, runName string, job *batchv1.Job) {
    // 根据 Job 状态更新运行状态
}

func (r *ExperimentReconciler) updateExperimentStatus(ctx context.Context, experiment *aiv1.Experiment) {
    // 统计运行状态并更新实验状态
}
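
上面两个状态更新方法留作空实现。下面给出其中一种示意写法:列出归属于该实验的 Job,按成功/失败/运行中统计数量并写回 status 子资源。其中 status 字段名(ActiveRuns、CompletedRuns、FailedRuns、Phase)按前面 CRD 的 status 定义假设,实际以生成的 API 类型为准:

// 示意实现:按标签筛选子 Job,统计运行状态并更新 Experiment 的 status 子资源
func (r *ExperimentReconciler) reconcileExperimentStatus(ctx context.Context, experiment *aiv1.Experiment) error {
    var jobs batchv1.JobList
    if err := r.List(ctx, &jobs,
        client.InNamespace(experiment.Namespace),
        client.MatchingLabels{"experiment": experiment.Name},
    ); err != nil {
        return err
    }

    var active, completed, failed int
    for _, job := range jobs.Items {
        switch {
        case job.Status.Succeeded > 0:
            completed++
        case job.Status.Failed > 0:
            failed++
        default:
            active++
        }
    }

    experiment.Status.ActiveRuns = active
    experiment.Status.CompletedRuns = completed
    experiment.Status.FailedRuns = failed
    if active == 0 && completed+failed == len(experiment.Spec.Runs) {
        experiment.Status.Phase = "Completed"
    } else {
        experiment.Status.Phase = "Running"
    }

    return r.Status().Update(ctx, experiment)
}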

// SetupWithManager 设置控制器
func (r *ExperimentReconciler) SetupWithManager(mgr ctrl.Manager) error {
    return ctrl.NewControllerManagedBy(mgr).
        For(&aiv1.Experiment{}).
        Owns(&batchv1.Job{}).
        Complete(r)
}

最佳实践

实验管理规范

┌─────────────────────────────────────────────────────────────────┐
│                    实验管理最佳实践                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  1. 实验设计原则                                                 │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 明确假设:每个实验应有清晰的假设和预期结果                │   │
│  │ • 控制变量:每次只改变少量变量,便于分析影响               │   │
│  │ • 可复现性:记录所有必要信息以复现实验                      │   │
│  │ • 基线对比:始终与基线模型进行对比                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  2. 命名规范                                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 项目:{team}-{model}-{task}                              │   │
│  │   例如:nlp-llama-pretraining                              │   │
│  │                                                            │   │
│  │ • 实验:{variable}-{sweep/ablation/comparison}             │   │
│  │   例如:lr-sweep, attention-ablation                       │   │
│  │                                                            │   │
│  │ • 运行:{experiment}-{config-summary}                      │   │
│  │   例如:lr-sweep-1e4-warmup1k                              │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  3. 必记参数                                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 模型配置:架构、参数量、hidden_size 等                   │   │
│  │ • 训练配置:learning_rate, batch_size, epochs 等          │   │
│  │ • 数据配置:数据集、预处理方式、数据增强                   │   │
│  │ • 环境配置:框架版本、GPU 类型、随机种子                   │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  4. 必记指标                                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 训练指标:loss, gradient_norm, learning_rate            │   │
│  │ • 验证指标:eval_loss, perplexity, accuracy               │   │
│  │ • 资源指标:gpu_memory, throughput, step_time             │   │
│  │ • 系统指标:gpu_utilization, memory_usage                 │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  5. 记录频率建议                                                 │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 训练 loss:每 10-100 步                                  │   │
│  │ • 学习率:每步或每 100 步                                  │   │
│  │ • 验证指标:每 epoch 或每 1000 步                          │   │
│  │ • 检查点:每 epoch 或每 5000 步                            │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
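
按照上面的记录频率建议,训练循环里通常用"步数取模"控制不同指标的上报节奏。下面用 Go 写一个示意片段,其中 Tracker 接口、trainStep/evalStep 回调均为演示而假设,对应实际追踪 SDK 的记录接口:

// training_loop_sketch.go(示意)
package sketch

import "fmt"

// Tracker 为假设的实验追踪接口,仅用于演示记录频率的控制方式
type Tracker interface {
    LogMetric(key string, value float64, step int)
    SaveCheckpoint(name string)
}

// runTrainingLoop 按不同的步数间隔记录训练/验证指标并保存检查点
func runTrainingLoop(tracker Tracker, totalSteps int,
    trainStep func(step int) float64, // 返回当前步的训练 loss
    evalStep func() float64, // 返回当前的验证 loss
) {
    for step := 1; step <= totalSteps; step++ {
        loss := trainStep(step)

        if step%10 == 0 { // 训练 loss:每 10-100 步
            tracker.LogMetric("train_loss", loss, step)
        }
        if step%1000 == 0 { // 验证指标:每 1000 步
            tracker.LogMetric("eval_loss", evalStep(), step)
        }
        if step%5000 == 0 { // 检查点:每 5000 步
            tracker.SaveCheckpoint(fmt.Sprintf("checkpoint-step-%d", step))
        }
    }
}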

小结

本章详细介绍了 AI 训练平台的实验管理系统:

  1. 核心概念:项目、实验、运行的层次结构和数据模型
  2. 追踪系统:Go 和 Python SDK 的实现,支持参数、指标、产物追踪
  3. 服务架构:实验服务、指标存储、产物管理的设计
  4. 比较分析:实验比较、参数重要性分析的实现
  5. 可视化:指标图表、平行坐标图、等高线图的数据生成
  6. Kubernetes 集成:实验 CRD 和控制器实现

下一章我们将探讨 超参数优化,讲解如何自动化地搜索最优超参数配置。
