实验管理
概述
机器学习实验管理是 AI 平台的核心能力之一。一个典型的大模型训练项目可能涉及数百次实验，每次实验都有不同的超参数配置、数据集版本和代码版本。如何系统化地追踪、比较和复现这些实验，是提升研发效率的关键。本章深入讲解实验管理系统的设计与实现。
实验管理核心概念
实验层次结构
┌─────────────────────────────────────────────────────────────────┐
│ 实验管理层次结构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Project │ │
│ │ (项目:如 LLM-7B-Training) │ │
│ │ ┌───────────────────────────────────────────────────┐ │ │
│ │ │ Experiment │ │ │
│ │ │ (实验:如 learning-rate-sweep) │ │ │
│ │ │ ┌─────────────────────────────────────────────┐ │ │ │
│ │ │ │ Run │ │ │ │
│ │ │ │ (运行:如 run-20241201-001) │ │ │ │
│ │ │ │ │ │ │ │
│ │ │ │ • Parameters (超参数) │ │ │ │
│ │ │ │ • Metrics (指标) │ │ │ │
│ │ │ │ • Artifacts (产物) │ │ │ │
│ │ │ │ • Source (代码版本) │ │ │ │
│ │ │ │ • Environment (环境) │ │ │ │
│ │ │ │ │ │ │ │
│ │ │ └─────────────────────────────────────────────┘ │ │ │
│ │ └───────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 关键关系: │
│ • 一个 Project 包含多个 Experiment │
│ • 一个 Experiment 包含多个 Run │
│ • 每个 Run 记录完整的实验上下文 │
│ │
└─────────────────────────────────────────────────────────────────┘
核心数据模型
// experiment_models.go
package experiment
import (
"time"
)
// Project is the top level of the Project -> Experiment -> Run hierarchy
// (e.g. "LLM-7B-Training"). It groups related Experiments and carries
// project-wide settings.
type Project struct {
ID string `json:"id" db:"id"`
Name string `json:"name" db:"name"`
Description string `json:"description" db:"description"`
Owner string `json:"owner" db:"owner"`
Team string `json:"team" db:"team"`
// NOTE(review): Tags ([]string) and Settings (struct) need a column encoding
// (JSON, array type, ...) to round-trip through the `db` tag — confirm how
// the repository stores them.
Tags []string `json:"tags" db:"tags"`
Settings ProjectSettings `json:"settings" db:"settings"`
// CreatedAt/UpdatedAt are stamped by ExperimentService.CreateProject.
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// ProjectSettings holds per-project configuration.
type ProjectSettings struct {
// DefaultExperimentConfig is the default experiment configuration, stored
// as a free-form key/value map.
DefaultExperimentConfig map[string]interface{} `json:"default_experiment_config"`
// MetricDisplayConfig controls how metrics are visualised.
MetricDisplayConfig MetricDisplayConfig `json:"metric_display_config"`
// ArchivePolicy controls automatic archiving.
ArchivePolicy ArchivePolicy `json:"archive_policy"`
}
// MetricDisplayConfig describes how a project's metrics are displayed.
type MetricDisplayConfig struct {
PrimaryMetrics []string `json:"primary_metrics"` // metrics to highlight
HigherIsBetter []string `json:"higher_is_better"` // metrics for which a larger value is better
ChartTypes map[string]string `json:"chart_types"` // chart type per metric (presumably keyed by metric name — confirm)
}
// ArchivePolicy configures automatic archiving. Nothing in this file enforces
// it; the values are only stored.
type ArchivePolicy struct {
AutoArchiveDays int `json:"auto_archive_days"` // archive automatically after this many days
KeepBestRuns int `json:"keep_best_runs"` // number of best runs to keep
KeepRecentRuns int `json:"keep_recent_runs"` // number of most recent runs to keep
}
// Experiment groups related Runs inside a Project, e.g. a learning-rate
// sweep. ExperimentService.CreateExperiment assigns ID, sets Status to
// ExperimentStatusActive and stamps CreatedAt/UpdatedAt.
type Experiment struct {
ID string `json:"id" db:"id"`
ProjectID string `json:"project_id" db:"project_id"`
Name string `json:"name" db:"name"`
Description string `json:"description" db:"description"`
Hypothesis string `json:"hypothesis" db:"hypothesis"` // the hypothesis the experiment is meant to test
Status ExperimentStatus `json:"status" db:"status"`
Tags []string `json:"tags" db:"tags"`
CreatedBy string `json:"created_by" db:"created_by"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
// ArchivedAt is nil until the experiment is archived (presumably set
// together with ExperimentStatusArchived — confirm).
ArchivedAt *time.Time `json:"archived_at" db:"archived_at"`
}
// ExperimentStatus is the lifecycle state of an Experiment.
type ExperimentStatus string
// Known ExperimentStatus values.
const (
ExperimentStatusActive ExperimentStatus = "active"
ExperimentStatusCompleted ExperimentStatus = "completed"
ExperimentStatusArchived ExperimentStatus = "archived"
)
// Run is a single execution inside an Experiment. It records the full context
// needed to reproduce it: parameters, code source, environment, timing and
// resource usage.
//
// NOTE(review): the Python SDK's start_run sends "tags" as a JSON object
// (RunConfig.tags is a dict), which cannot decode into Tags []string —
// confirm the intended wire format.
type Run struct {
ID string `json:"id" db:"id"`
ExperimentID string `json:"experiment_id" db:"experiment_id"`
Name string `json:"name" db:"name"`
Status RunStatus `json:"status" db:"status"`
// Source tracking (git / entry point).
Source RunSource `json:"source" db:"source"`
// Configuration.
Parameters Parameters `json:"parameters" db:"parameters"`
Environment Environment `json:"environment" db:"environment"`
// Runtime information. StartTime/EndTime are nil until set; Duration is
// EndTime - StartTime.
StartTime *time.Time `json:"start_time" db:"start_time"`
EndTime *time.Time `json:"end_time" db:"end_time"`
Duration int64 `json:"duration" db:"duration"` // seconds
// Resource usage.
Resources ResourceUsage `json:"resources" db:"resources"`
// Metadata.
Tags []string `json:"tags" db:"tags"`
Notes string `json:"notes" db:"notes"`
CreatedBy string `json:"created_by" db:"created_by"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// RunStatus is the lifecycle state of a Run.
type RunStatus string
// Known RunStatus values. Both the tracking client and the service create
// runs as RunStatusRunning; the terminal status is supplied to EndRun.
const (
RunStatusPending RunStatus = "pending"
RunStatusRunning RunStatus = "running"
RunStatusCompleted RunStatus = "completed"
RunStatusFailed RunStatus = "failed"
RunStatusKilled RunStatus = "killed"
)
// RunSource records where a Run's code came from.
type RunSource struct {
// Git information. The tracking client fills URL, commit, branch and dirty
// state; it does not populate GitTag.
GitRepoURL string `json:"git_repo_url"`
GitCommit string `json:"git_commit"`
GitBranch string `json:"git_branch"`
GitTag string `json:"git_tag"`
GitDirty bool `json:"git_dirty"` // true when the working tree has uncommitted changes
// Code entry point.
EntryPoint string `json:"entry_point"` // script that was executed
SourceType string `json:"source_type"` // "local", "git", "notebook"
// Code snapshot (for small projects).
CodeSnapshot string `json:"code_snapshot"` // path of the archived code
}
// Parameters holds a Run's configuration. The tracking client and the
// service's LogParams only ever populate Flat; the grouped maps are optional.
type Parameters struct {
// Model parameters.
ModelConfig map[string]interface{} `json:"model_config"`
// Training parameters.
TrainingConfig map[string]interface{} `json:"training_config"`
// Data parameters.
DataConfig map[string]interface{} `json:"data_config"`
// All parameters flattened into one map (the Python SDK uses dotted keys
// such as "training.learning_rate").
Flat map[string]interface{} `json:"flat"`
}
// Environment describes the software and hardware a Run executed on.
type Environment struct {
// Python environment.
PythonVersion string `json:"python_version"`
PipPackages map[string]string `json:"pip_packages"` // package name -> version
CondaEnvironment string `json:"conda_environment"`
// Hardware information.
GPUType string `json:"gpu_type"`
GPUCount int `json:"gpu_count"`
// GPUMemory is per-GPU memory as reported by nvidia-smi's memory.total with
// nounits (presumably MiB — confirm).
GPUMemory int64 `json:"gpu_memory"`
CPUCount int `json:"cpu_count"`
MemoryGB int `json:"memory_gb"`
// System information.
Platform string `json:"platform"`
Hostname string `json:"hostname"`
DockerImage string `json:"docker_image"`
// Framework versions.
FrameworkVersions map[string]string `json:"framework_versions"`
}
// ResourceUsage summarises the resources a Run consumed.
type ResourceUsage struct {
GPUHours float64 `json:"gpu_hours"`
CPUHours float64 `json:"cpu_hours"`
MemoryGBHours float64 `json:"memory_gb_hours"`
StorageGB float64 `json:"storage_gb"`
Cost float64 `json:"cost"` // estimated cost (currency not specified here)
}
// Metric is one logged data point of a named metric at a given step.
type Metric struct {
RunID string `json:"run_id" db:"run_id"`
Key string `json:"key" db:"key"`
Value float64 `json:"value" db:"value"`
Step int64 `json:"step" db:"step"`
Timestamp time.Time `json:"timestamp" db:"timestamp"`
// Context separates phases; the server's latest-value cache key includes
// it, and CompareRuns reads the "train" context.
Context string `json:"context" db:"context"` // train, eval, test
}
// Artifact is a file produced by a Run (model, data, figure, ...) and stored
// in object storage.
type Artifact struct {
ID string `json:"id" db:"id"`
RunID string `json:"run_id" db:"run_id"`
Name string `json:"name" db:"name"`
Type string `json:"type" db:"type"` // model, data, figure, etc.
Path string `json:"path" db:"path"` // object-storage key
Size int64 `json:"size" db:"size"` // bytes
Checksum string `json:"checksum" db:"checksum"`
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
实验追踪系统
追踪客户端
// tracking_client.go
package experiment
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"runtime"
"strings"
"sync"
"time"
)
// TrackingClient reports runs, parameters, metrics and artifacts to the
// experiment tracking server over HTTP. It tracks one active run at a time
// and buffers metrics so they can be sent in batches.
type TrackingClient struct {
serverURL string
httpClient *http.Client
// activeRun is the run Log* calls attach to; guarded by runMu.
activeRun *Run
runMu sync.RWMutex
// metricBuffer holds metrics not yet sent to the server; guarded by bufferMu.
metricBuffer []*Metric
bufferMu sync.Mutex
// flushTicker drives the periodic background flush.
flushTicker *time.Ticker
// Client configuration (see ClientConfig).
config ClientConfig
}
// ClientConfig configures a TrackingClient.
type ClientConfig struct {
BatchSize int // number of buffered metrics that triggers an immediate flush
FlushInterval time.Duration // period of the background flush
AutoLogEnv bool // attach environment info to runs created by StartRun
AutoLogGit bool // attach git info to runs created by StartRun
}
// NewTrackingClient creates a client for the tracking server at serverURL and
// starts a background goroutine that flushes buffered metrics every
// config.FlushInterval.
//
// A non-positive BatchSize or FlushInterval is replaced by a default (100
// metrics / 5s, the same defaults the Python SDK uses). time.NewTicker panics
// on a non-positive interval and make panics on a negative capacity, so the
// zero ClientConfig used to crash the caller.
//
// NOTE(review): the flush goroutine and its ticker are never stopped; the
// client is expected to live for the whole process.
func NewTrackingClient(serverURL string, config ClientConfig) *TrackingClient {
	const (
		defaultBatchSize     = 100
		defaultFlushInterval = 5 * time.Second
	)
	if config.BatchSize <= 0 {
		config.BatchSize = defaultBatchSize
	}
	if config.FlushInterval <= 0 {
		config.FlushInterval = defaultFlushInterval
	}

	client := &TrackingClient{
		serverURL:    serverURL,
		httpClient:   &http.Client{Timeout: 30 * time.Second},
		metricBuffer: make([]*Metric, 0, config.BatchSize),
		config:       config,
		flushTicker:  time.NewTicker(config.FlushInterval),
	}
	go client.flushLoop()
	return client
}
// StartRun creates a run in the given experiment on the server and makes it
// the client's active run, replacing any previously active run.
//
// params become the run's flattened parameters. When enabled in the config,
// environment and git information are collected and attached before the run
// is sent. The server's response is decoded into the returned Run, so fields
// the server assigns (such as ID) are filled in.
//
// It returns an error if the request fails, the server answers with a
// non-2xx status, or the response cannot be decoded.
func (c *TrackingClient) StartRun(
	ctx context.Context,
	experimentID string,
	runName string,
	params map[string]interface{},
) (*Run, error) {
	now := time.Now()
	run := &Run{
		ExperimentID: experimentID,
		Name:         runName,
		Status:       RunStatusRunning,
		Parameters:   Parameters{Flat: params},
		StartTime:    &now,
		CreatedAt:    now,
	}

	if c.config.AutoLogEnv {
		run.Environment = c.collectEnvironment()
	}
	if c.config.AutoLogGit {
		run.Source = c.collectGitInfo()
	}

	resp, err := c.post(ctx, "/api/v1/runs", run)
	if err != nil {
		return nil, fmt.Errorf("create run: %w", err)
	}
	// Always release the connection, including on the error paths below.
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("create run: unexpected status %s", resp.Status)
	}
	if err := json.NewDecoder(resp.Body).Decode(run); err != nil {
		return nil, fmt.Errorf("decode response: %w", err)
	}
	// A server that echoes "start_time": null would reset StartTime to nil;
	// keep the local value so EndRun can compute a duration.
	if run.StartTime == nil {
		run.StartTime = &now
	}

	c.runMu.Lock()
	c.activeRun = run
	c.runMu.Unlock()

	return run, nil
}
// LogParam records a single parameter on the active run. It is a convenience
// wrapper around LogParams and fails with "no active run" when no run has
// been started.
func (c *TrackingClient) LogParam(key string, value interface{}) error {
	c.runMu.RLock()
	hasRun := c.activeRun != nil
	c.runMu.RUnlock()

	if !hasRun {
		return fmt.Errorf("no active run")
	}
	single := map[string]interface{}{key: value}
	return c.LogParams(single)
}
// LogParams sends params to the server for the active run. It fails with
// "no active run" when no run has been started, and returns an error if the
// request fails or the server responds with a non-2xx status.
func (c *TrackingClient) LogParams(params map[string]interface{}) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	resp, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/params", run.ID), params)
	if err != nil {
		return fmt.Errorf("log params: %w", err)
	}
	// The body must be closed or the underlying connection leaks.
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("log params: unexpected status %s", resp.Status)
	}
	return nil
}
// LogMetric buffers one metric value for the active run at the given step.
// Nothing is sent synchronously. Once the buffer holds config.BatchSize
// metrics, a flush is started in a new goroutine; otherwise the periodic
// flush delivers it.
//
// The metric is tagged with the "train" context. The server keys its
// latest-value cache by context and CompareRuns reads the "train" entry, and
// the Python SDK uses the same default. Without the tag, metrics logged from
// Go never appeared in comparisons.
func (c *TrackingClient) LogMetric(key string, value float64, step int64) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	metric := &Metric{
		RunID:     run.ID,
		Key:       key,
		Value:     value,
		Step:      step,
		Timestamp: time.Now(),
		Context:   "train",
	}

	c.bufferMu.Lock()
	c.metricBuffer = append(c.metricBuffer, metric)
	full := len(c.metricBuffer) >= c.config.BatchSize
	c.bufferMu.Unlock()

	if full {
		go c.flush()
	}
	return nil
}
// LogMetrics buffers every key/value pair in metrics at the same step by
// calling LogMetric for each entry. It stops at, and returns, the first error.
// Map iteration order is unspecified, so entries are buffered in no
// particular order.
func (c *TrackingClient) LogMetrics(metrics map[string]float64, step int64) error {
	var err error
	for name, v := range metrics {
		if err = c.LogMetric(name, v, step); err != nil {
			break
		}
	}
	return err
}
// LogArtifact uploads the file at localPath to the active run, stored under
// artifactPath on the server (sent as the "path" query parameter).
//
// The file is streamed rather than read into memory, because checkpoints can
// be many gigabytes. artifactPath is query-escaped, so names containing
// spaces, '&' or '#' reach the server intact. Any 2xx response counts as
// success. Directories are rejected; upload their files one by one.
//
// NOTE(review): the shared http.Client has a 30s overall timeout, which may
// be too short for very large uploads.
func (c *TrackingClient) LogArtifact(localPath string, artifactPath string) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	f, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("read file: %w", err)
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return fmt.Errorf("read file: %w", err)
	}
	if info.IsDir() {
		return fmt.Errorf("read file: %q is a directory", localPath)
	}

	endpoint := fmt.Sprintf("%s/api/v1/runs/%s/artifacts", c.serverURL, run.ID)
	req, err := http.NewRequest(http.MethodPut, endpoint, f)
	if err != nil {
		return err
	}
	query := req.URL.Query()
	query.Set("path", artifactPath)
	req.URL.RawQuery = query.Encode()

	// A known length avoids chunked transfer encoding. For client requests a
	// zero length with a non-nil body means "unknown", so send NoBody instead.
	req.ContentLength = info.Size()
	if info.Size() == 0 {
		req.Body = http.NoBody
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("upload failed: %s", resp.Status)
	}
	return nil
}
// LogModel registers a model artifact for the active run. Only the record is
// sent (name, type "model", modelPath and metadata); the model files are not
// uploaded by this call. It returns an error if the request fails or the
// server responds with a non-2xx status.
func (c *TrackingClient) LogModel(modelPath string, modelName string, metadata map[string]interface{}) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	artifact := &Artifact{
		RunID:    run.ID,
		Name:     modelName,
		Type:     "model",
		Path:     modelPath,
		Metadata: metadata,
	}

	resp, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/models", run.ID), artifact)
	if err != nil {
		return fmt.Errorf("log model: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("log model: unexpected status %s", resp.Status)
	}
	return nil
}
// SetTag sets a key/value tag on the active run. It returns an error if no
// run is active, the request fails, or the server responds with a non-2xx
// status.
func (c *TrackingClient) SetTag(key, value string) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	resp, err := c.post(context.Background(),
		fmt.Sprintf("/api/v1/runs/%s/tags", run.ID),
		map[string]string{key: value})
	if err != nil {
		return fmt.Errorf("set tag: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("set tag: unexpected status %s", resp.Status)
	}
	return nil
}
// EndRun finishes the active run with the given status. It flushes buffered
// metrics, detaches the run from the client, and reports end time, status and
// duration to the server.
//
// Metrics are flushed while the run is still active. Clearing activeRun first
// made the flush find no run and silently drop every buffered metric. If
// another caller ended or replaced the run in the meantime, EndRun reports
// "no active run" so the run is ended only once.
func (c *TrackingClient) EndRun(status RunStatus) error {
	c.runMu.RLock()
	run := c.activeRun
	c.runMu.RUnlock()
	if run == nil {
		return fmt.Errorf("no active run")
	}

	c.flush()

	c.runMu.Lock()
	if c.activeRun != run {
		c.runMu.Unlock()
		return fmt.Errorf("no active run")
	}
	c.activeRun = nil
	c.runMu.Unlock()

	now := time.Now()
	run.EndTime = &now
	run.Status = status
	// StartTime can be nil if the server returned a null start_time.
	if run.StartTime != nil {
		run.Duration = int64(now.Sub(*run.StartTime).Seconds())
	}

	resp, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/end", run.ID), run)
	if err != nil {
		return fmt.Errorf("end run: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("end run: unexpected status %s", resp.Status)
	}
	return nil
}
// flush drains the metric buffer and posts the drained metrics in batches,
// one request per run.
//
// Each metric is sent to the run recorded in its own RunID, not to whichever
// run happens to be active at flush time. Previously, metrics buffered before
// EndRun were dropped (no active run), and metrics buffered before a new
// StartRun were posted to the wrong run.
//
// Delivery is best-effort: a failed request drops that run's batch, which
// matches the original fire-and-forget behaviour.
func (c *TrackingClient) flush() {
	c.bufferMu.Lock()
	if len(c.metricBuffer) == 0 {
		c.bufferMu.Unlock()
		return
	}
	metrics := c.metricBuffer
	c.metricBuffer = make([]*Metric, 0, len(metrics))
	c.bufferMu.Unlock()

	// Group by run, keeping both run order and per-run metric order.
	var runOrder []string
	byRun := make(map[string][]*Metric)
	for _, m := range metrics {
		if _, seen := byRun[m.RunID]; !seen {
			runOrder = append(runOrder, m.RunID)
		}
		byRun[m.RunID] = append(byRun[m.RunID], m)
	}

	for _, runID := range runOrder {
		resp, err := c.post(context.Background(), fmt.Sprintf("/api/v1/runs/%s/metrics/batch", runID), byRun[runID])
		if err != nil {
			continue // best-effort; see doc comment
		}
		resp.Body.Close()
	}
}
// flushLoop drains the metric buffer on every tick of flushTicker.
//
// NOTE(review): the ticker channel is never closed and nothing stops the
// ticker, so this goroutine runs for the rest of the process.
func (c *TrackingClient) flushLoop() {
	ticks := c.flushTicker.C
	for {
		<-ticks
		c.flush()
	}
}
// collectEnvironment gathers a best-effort description of the local
// environment: Python version and pip packages, OS and CPU count, GPU info
// (when nvidia-smi is available), hostname, and the DOCKER_IMAGE environment
// variable. Anything that cannot be determined is left at its zero value.
func (c *TrackingClient) collectEnvironment() Environment {
	var env Environment
	env.PythonVersion = getPythonVersion()
	env.PipPackages = getPipPackages()
	env.Platform = runtime.GOOS
	env.CPUCount = runtime.NumCPU()
	env.FrameworkVersions = make(map[string]string)

	if gpu := getGPUInfo(); gpu != nil {
		env.GPUType, env.GPUCount, env.GPUMemory = gpu.Type, gpu.Count, gpu.Memory
	}

	// Best-effort: a hostname lookup failure just leaves the field as returned.
	host, _ := os.Hostname()
	env.Hostname = host

	// Only set when the container image exports it.
	env.DockerImage = os.Getenv("DOCKER_IMAGE")
	return env
}
// collectGitInfo reads git metadata for the current working directory by
// shelling out to git: origin URL, HEAD commit, current branch, and whether
// `git status --porcelain` reports any changes. A command that fails leaves
// its field empty. SourceType is always "git".
func (c *TrackingClient) collectGitInfo() RunSource {
	git := func(args ...string) (string, bool) {
		out, err := exec.Command("git", args...).Output()
		if err != nil {
			return "", false
		}
		return string(out), true
	}

	source := RunSource{SourceType: "git"}
	if out, ok := git("remote", "get-url", "origin"); ok {
		source.GitRepoURL = strings.TrimSpace(out)
	}
	if out, ok := git("rev-parse", "HEAD"); ok {
		source.GitCommit = strings.TrimSpace(out)
	}
	if out, ok := git("rev-parse", "--abbrev-ref", "HEAD"); ok {
		source.GitBranch = strings.TrimSpace(out)
	}
	if out, ok := git("status", "--porcelain"); ok {
		source.GitDirty = len(out) > 0
	}
	return source
}
// post JSON-encodes body and POSTs it to serverURL+path. The raw response is
// returned without checking its status; the caller owns resp.Body and must
// close it.
func (c *TrackingClient) post(ctx context.Context, path string, body interface{}) (*http.Response, error) {
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}

	endpoint := c.serverURL + path
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	return c.httpClient.Do(req)
}
// Helper functions.

// getPythonVersion returns the interpreter's `--version` string (for example
// "Python 3.11.4"), or "" when no interpreter can be run.
//
// "python" is tried first, so an activated virtualenv wins. "python3" is the
// fallback for systems that only ship that name. Output is read with
// CombinedOutput because Python versions before 3.4 print the version to
// stderr, which Output() discards.
func getPythonVersion() string {
	for _, name := range []string{"python", "python3"} {
		out, err := exec.Command(name, "--version").CombinedOutput()
		if err != nil {
			continue
		}
		if v := strings.TrimSpace(string(out)); v != "" {
			return v
		}
	}
	return ""
}
// getPipPackages returns name -> version for every `name==version` line of
// `pip freeze`. Lines in other forms (editable installs, `name @ url`) are
// skipped. It never returns nil: when neither pip nor pip3 can be run, the
// map is empty.
//
// Each line is trimmed before parsing, so CRLF output (e.g. on Windows) does
// not leave a trailing '\r' in the version.
func getPipPackages() map[string]string {
	packages := make(map[string]string)

	var out []byte
	var err error
	for _, pip := range []string{"pip", "pip3"} {
		if out, err = exec.Command(pip, "freeze").Output(); err == nil {
			break
		}
	}
	if err != nil {
		return packages
	}

	for _, line := range strings.Split(string(out), "\n") {
		parts := strings.Split(strings.TrimSpace(line), "==")
		if len(parts) == 2 {
			packages[parts[0]] = parts[1]
		}
	}
	return packages
}
// GPUInfo summarises the GPUs reported by nvidia-smi.
type GPUInfo struct {
Type string // name of the first GPU listed
Count int // number of GPUs listed
Memory int64 // memory.total of the first GPU (nounits; presumably MiB — confirm)
}
// getGPUInfo queries nvidia-smi for GPU name and total memory. It returns nil
// when nvidia-smi is unavailable or its first output row has fewer than three
// ", "-separated fields. Type and Memory come from the first GPU; Count is the
// number of output rows. If the memory field does not parse, Memory stays 0.
func getGPUInfo() *GPUInfo {
	raw, err := exec.Command("nvidia-smi", "--query-gpu=name,count,memory.total", "--format=csv,noheader,nounits").Output()
	if err != nil {
		return nil
	}

	// strings.Split always returns at least one element, so rows[0] is safe.
	rows := strings.Split(strings.TrimSpace(string(raw)), "\n")
	fields := strings.Split(rows[0], ", ")
	if len(fields) < 3 {
		return nil
	}

	info := &GPUInfo{Type: fields[0], Count: len(rows)}
	// Best-effort parse; leaves Memory at 0 for values like "[N/A]".
	_, _ = fmt.Sscanf(fields[2], "%d", &info.Memory)
	return info
}
Python SDK
# experiment_sdk.py
import os
import json
import time
import hashlib
import threading
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
import requests
@dataclass
class RunConfig:
    """Settings used to start a run.

    Attributes:
        experiment_id: ID of the experiment the run belongs to.
        run_name: Optional display name. When omitted,
            ``ExperimentTracker.start_run`` generates ``run-YYYYmmdd-HHMMSS``.
        tags: Key/value tags attached to the run (a fresh dict per instance).
        description: Free-text description of the run.
    """

    experiment_id: str
    run_name: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    description: str = ""
class ExperimentTracker:
    """HTTP client for the experiment tracking server.

    A tracker has at most one active run. Parameters, tags, artifacts and
    models are sent immediately. Metrics are buffered and sent in batches,
    both when the buffer reaches ``batch_size`` entries and every
    ``flush_interval`` seconds from a background daemon thread.

    Args:
        tracking_uri: Base URL of the tracking server (no trailing slash).
        batch_size: Buffered-metric count that triggers an immediate flush.
        flush_interval: Seconds between background metric flushes.
        auto_log_env: Attach interpreter, package and GPU info to new runs.
        auto_log_git: Attach git remote/commit/branch/dirty state to new runs.
        timeout: Timeout in seconds applied to every HTTP request. Without
            one, ``requests`` can wait forever on an unresponsive server.
    """

    def __init__(
        self,
        tracking_uri: str,
        batch_size: int = 100,
        flush_interval: float = 5.0,
        auto_log_env: bool = True,
        auto_log_git: bool = True,
        timeout: float = 30.0,
    ):
        self.tracking_uri = tracking_uri
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.auto_log_env = auto_log_env
        self.auto_log_git = auto_log_git
        self.timeout = timeout
        # Run record returned by the server for the current run, or None.
        self._active_run: Optional[Dict] = None
        # Metrics waiting to be sent; guarded by _buffer_lock (NOT reentrant).
        self._metric_buffer: List[Dict] = []
        self._buffer_lock = threading.Lock()
        self._flush_thread: Optional[threading.Thread] = None
        self._stop_flush = threading.Event()

    # ------------------------------------------------------------ helpers

    @staticmethod
    def _now_iso() -> str:
        """Return the current time as ISO-8601 with an explicit UTC offset.

        The Go server decodes timestamps into ``time.Time``, which needs a
        timezone. ``datetime.now().isoformat()`` produces a naive timestamp.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).isoformat()

    def _require_run(self) -> Dict:
        """Return the active run record, or raise ``RuntimeError``."""
        if not self._active_run:
            raise RuntimeError("No active run")
        return self._active_run

    def _run_url(self, run_id: str, suffix: str) -> str:
        """Return the URL of the endpoint ``suffix`` under run ``run_id``."""
        return f"{self.tracking_uri}/api/v1/runs/{run_id}{suffix}"

    def _post(self, url: str, **kwargs: Any) -> "requests.Response":
        """POST with the configured timeout; raise ``HTTPError`` on non-2xx."""
        response = requests.post(url, timeout=self.timeout, **kwargs)
        response.raise_for_status()
        return response

    # ---------------------------------------------------------------- runs

    def start_run(
        self,
        config: "RunConfig",
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict:
        """Create a run on the server and make it the active run.

        Args:
            config: Experiment ID, optional run name, tags and description.
            params: Optional (possibly nested) parameters, logged right after
                the run is created.

        Returns:
            The run record returned by the server (must contain ``"id"``).

        Raises:
            requests.HTTPError: If the server rejects the run or the params.
        """
        run_data = {
            "experiment_id": config.experiment_id,
            "name": config.run_name or f"run-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            "status": "running",
            "tags": config.tags,
            "description": config.description,
            "start_time": self._now_iso(),
        }

        if self.auto_log_env:
            run_data["environment"] = self._collect_environment()
        if self.auto_log_git:
            run_data["source"] = self._collect_git_info()

        response = self._post(f"{self.tracking_uri}/api/v1/runs", json=run_data)
        self._active_run = response.json()

        self._start_flush_thread()
        if params:
            self.log_params(params)
        return self._active_run

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter on the active run."""
        self.log_params({key: value})

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters on the active run.

        Nested dicts are flattened into dotted keys (``{"a": {"b": 1}}`` is
        sent as ``{"a.b": 1}``).

        Raises:
            RuntimeError: If no run is active.
            requests.HTTPError: If the server rejects the request.
        """
        run = self._require_run()
        flat_params = self._flatten_params(params)
        self._post(self._run_url(run["id"], "/params"), json=flat_params)

    def log_metric(
        self,
        key: str,
        value: float,
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """Buffer one metric value for the active run.

        ``step=None`` is recorded as step 0. The buffer is flushed once it
        holds ``batch_size`` metrics.

        Raises:
            RuntimeError: If no run is active.
        """
        run = self._require_run()
        metric = {
            "run_id": run["id"],
            "key": key,
            "value": float(value),
            "step": step or 0,
            "timestamp": self._now_iso(),
            "context": context,
        }

        with self._buffer_lock:
            self._metric_buffer.append(metric)
            should_flush = len(self._metric_buffer) >= self.batch_size
        # Flush outside the lock: _flush_metrics takes the same non-reentrant
        # lock, and calling it while holding the lock deadlocks.
        if should_flush:
            self._flush_metrics()

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """Buffer several metric values sharing the same step and context."""
        for key, value in metrics.items():
            self.log_metric(key, value, step, context)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None,
        artifact_type: str = "file",
    ) -> None:
        """Upload a file, or every file under a directory, to the active run.

        Args:
            local_path: File or directory to upload.
            artifact_path: Destination name on the server. Defaults to the
                base name of ``local_path``. For directories it is used as
                the prefix of each file's relative path.
            artifact_type: Type recorded with the artifact.

        Raises:
            RuntimeError: If no run is active.
            requests.HTTPError: If the server rejects an upload.
        """
        run = self._require_run()

        if os.path.isdir(local_path):
            prefix = artifact_path or os.path.basename(os.path.normpath(local_path))
            self._upload_directory(local_path, prefix, artifact_type)
            return

        artifact_path = artifact_path or os.path.basename(local_path)
        with open(local_path, "rb") as f:
            self._post(
                self._run_url(run["id"], "/artifacts"),
                files={"file": (artifact_path, f)},
                data={"type": artifact_type},
            )

    def log_model(
        self,
        model: Any,
        model_name: str,
        signature: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Save a model, upload it under ``models/<model_name>`` and register it.

        HuggingFace-style models (``save_pretrained``) are saved as a
        directory. PyTorch modules (``state_dict``) are saved as a single
        file with ``torch.save``. Both layouts are uploaded; previously the
        single-file case was silently skipped.

        Raises:
            RuntimeError: If no run is active.
            ValueError: If the model type is not supported.
            requests.HTTPError: If the server rejects an upload or the record.
        """
        run = self._require_run()

        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = os.path.join(tmpdir, model_name)

            if hasattr(model, "save_pretrained"):
                # HuggingFace model: writes a directory.
                model.save_pretrained(model_path)
            elif hasattr(model, "state_dict"):
                # PyTorch model: writes a single file.
                import torch

                torch.save(model.state_dict(), model_path)
            else:
                raise ValueError(f"Unsupported model type: {type(model)}")

            if os.path.isdir(model_path):
                self._upload_directory(model_path, f"models/{model_name}")
            else:
                self.log_artifact(model_path, f"models/{model_name}")

        model_info = {
            "name": model_name,
            "signature": signature,
            "metadata": metadata or {},
            "framework": self._detect_framework(model),
        }
        self._post(self._run_url(run["id"], "/models"), json=model_info)

    def log_figure(
        self,
        figure: Any,
        name: str,
        format: str = "png",
    ) -> None:
        """Render a Matplotlib figure and upload it as ``figures/<name>.<format>``.

        The temporary image is always removed, even if saving or uploading
        fails.

        Raises:
            ValueError: If ``figure`` has no ``savefig`` method.
            RuntimeError: If no run is active.
        """
        import tempfile

        if not hasattr(figure, "savefig"):
            raise ValueError(f"Unsupported figure type: {type(figure)}")

        fd, tmp_path = tempfile.mkstemp(suffix=f".{format}")
        # Close our handle so savefig can reopen the path (required on Windows).
        os.close(fd)
        try:
            figure.savefig(tmp_path, format=format, bbox_inches="tight")
            self.log_artifact(tmp_path, f"figures/{name}.{format}", "figure")
        finally:
            os.unlink(tmp_path)

    def set_tag(self, key: str, value: str) -> None:
        """Set a key/value tag on the active run.

        Raises:
            RuntimeError: If no run is active.
            requests.HTTPError: If the server rejects the request.
        """
        run = self._require_run()
        self._post(self._run_url(run["id"], "/tags"), json={key: value})

    def end_run(self, status: str = "completed") -> None:
        """End the active run with ``status``; a no-op when no run is active.

        Stops the background flush thread and flushes the remaining metrics
        before reporting the end. The run is detached even if the final
        request fails, so a new run can be started afterwards.
        """
        run = self._active_run
        if not run:
            return

        self._stop_flush.set()
        if self._flush_thread is not None:
            self._flush_thread.join()
            self._flush_thread = None

        try:
            self._flush_metrics()
            self._post(
                self._run_url(run["id"], "/end"),
                json={"status": status, "end_time": self._now_iso()},
            )
        finally:
            self._active_run = None

    # ------------------------------------------------------------ internals

    def _flush_metrics(self) -> None:
        """Send all buffered metrics, one batch per run.

        Metrics are routed by their own ``run_id``, so a flush that races
        with ``end_run``/``start_run`` cannot drop or misroute them. Delivery
        is best-effort: a failed batch is reported and discarded so that
        logging never crashes the training loop.
        """
        with self._buffer_lock:
            if not self._metric_buffer:
                return
            metrics = self._metric_buffer
            self._metric_buffer = []

        by_run: Dict[str, List[Dict]] = {}
        for metric in metrics:
            by_run.setdefault(metric["run_id"], []).append(metric)

        for run_id, batch in by_run.items():
            try:
                self._post(self._run_url(run_id, "/metrics/batch"), json=batch)
            except Exception as e:
                print(f"Failed to flush metrics: {e}")

    def _start_flush_thread(self) -> None:
        """Start the periodic flush thread unless one is already running."""
        if self._flush_thread is not None and self._flush_thread.is_alive():
            return
        self._stop_flush.clear()

        def flush_loop() -> None:
            # Event.wait returns True once end_run sets the stop flag.
            while not self._stop_flush.wait(self.flush_interval):
                self._flush_metrics()

        self._flush_thread = threading.Thread(target=flush_loop, daemon=True)
        self._flush_thread.start()

    def _flatten_params(
        self,
        params: Dict[str, Any],
        prefix: str = "",
    ) -> Dict[str, Any]:
        """Flatten nested dicts into dotted keys, e.g. ``{"a": {"b": 1}} -> {"a.b": 1}``."""
        flat: Dict[str, Any] = {}
        for key, value in params.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(self._flatten_params(value, full_key))
            else:
                flat[full_key] = value
        return flat

    @staticmethod
    def _installed_packages() -> Optional[Dict[str, str]]:
        """Return ``{project_key: version}`` of installed distributions, or None.

        Keys use pkg_resources' normalisation (runs of characters other than
        letters, digits and '.' become '-', lower-cased). This keeps them
        stable regardless of which backend answered. ``importlib.metadata``
        is preferred because ``pkg_resources`` is deprecated.
        """
        import re

        try:
            from importlib import metadata

            packages: Dict[str, str] = {}
            for dist in metadata.distributions():
                name = dist.metadata["Name"]
                if name:
                    packages[re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()] = dist.version
            return packages
        except Exception:
            pass
        try:
            import pkg_resources

            return {pkg.key: pkg.version for pkg in pkg_resources.working_set}
        except Exception:
            return None

    def _collect_environment(self) -> Dict:
        """Collect best-effort interpreter, package and GPU information."""
        import sys
        import platform

        env: Dict[str, Any] = {
            "python_version": sys.version,
            "platform": platform.platform(),
            "hostname": platform.node(),
        }

        packages = self._installed_packages()
        if packages is not None:
            env["pip_packages"] = packages

        # GPU information is optional: torch may be missing or CUDA unusable.
        try:
            import torch

            if torch.cuda.is_available():
                env["gpu_count"] = torch.cuda.device_count()
                env["gpu_type"] = torch.cuda.get_device_name(0)
                env["cuda_version"] = torch.version.cuda
        except Exception:
            pass

        return env

    @staticmethod
    def _git(*args: str) -> Optional[str]:
        """Run ``git <args>`` and return stripped stdout, or None on failure."""
        import subprocess

        try:
            return subprocess.check_output(
                ["git", *args],
                stderr=subprocess.DEVNULL,
            ).decode().strip()
        except (OSError, subprocess.SubprocessError, ValueError):
            # git missing, non-zero exit, or undecodable output.
            return None

    def _collect_git_info(self) -> Dict:
        """Collect origin URL, HEAD commit, branch and dirty state; missing items are omitted."""
        info: Dict[str, Any] = {"source_type": "git"}

        queries = (
            ("git_repo_url", ("remote", "get-url", "origin")),
            ("git_commit", ("rev-parse", "HEAD")),
            ("git_branch", ("rev-parse", "--abbrev-ref", "HEAD")),
        )
        for field_name, args in queries:
            value = self._git(*args)
            if value is not None:
                info[field_name] = value

        status = self._git("status", "--porcelain")
        if status is not None:
            info["git_dirty"] = len(status) > 0

        return info

    def _detect_framework(self, model: Any) -> str:
        """Guess the model framework from the module that defines its class."""
        module = type(model).__module__
        if "torch" in module:
            return "pytorch"
        elif "tensorflow" in module or "keras" in module:
            return "tensorflow"
        elif "transformers" in module:
            return "transformers"
        return "unknown"

    def _upload_directory(
        self,
        local_dir: str,
        artifact_dir: str,
        artifact_type: str = "file",
    ) -> None:
        """Upload every file under ``local_dir`` to ``artifact_dir/<relative path>``.

        Artifact paths always use '/' separators, so uploads from Windows land
        at the same server paths as uploads from POSIX hosts.
        """
        import posixpath

        for root, _dirs, files in os.walk(local_dir):
            for file_name in files:
                local_path = os.path.join(root, file_name)
                rel_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
                self.log_artifact(local_path, posixpath.join(artifact_dir, rel_path), artifact_type)
# Context manager
class ExperimentRun:
    """Context manager that scopes a run to a ``with`` block.

    On entry it starts a run via ``tracker.start_run``. On exit it ends the
    run with status ``"failed"`` if the block raised and ``"completed"``
    otherwise. Exceptions are never suppressed.

    Example::

        with ExperimentRun(tracker, RunConfig("exp-1"), {"lr": 1e-4}) as run:
            run.log_metric("train/loss", 0.42, step=100)
    """

    def __init__(
        self,
        tracker: "ExperimentTracker",
        config: "RunConfig",
        params: Optional[Dict[str, Any]] = None,
    ):
        self.tracker = tracker
        self.config = config
        self.params = params
        # Run record returned by start_run; None until the block is entered.
        self.run: Optional[Dict] = None

    def __enter__(self) -> "ExperimentRun":
        self.run = self.tracker.start_run(self.config, self.params)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        status = "failed" if exc_type else "completed"
        self.tracker.end_run(status)

    def log_metric(
        self,
        key: str,
        value: float,
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """Log one metric value; ``context`` distinguishes train/eval/test."""
        self.tracker.log_metric(key, value, step, context)

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        context: str = "train",
    ) -> None:
        """Log several metric values sharing the same step and context."""
        self.tracker.log_metrics(metrics, step, context)

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter."""
        self.tracker.log_param(key, value)

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log several (possibly nested) parameters."""
        self.tracker.log_params(params)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None,
        artifact_type: str = "file",
    ) -> None:
        """Upload a file as an artifact of this run."""
        self.tracker.log_artifact(local_path, artifact_path, artifact_type)

    def set_tag(self, key: str, value: str) -> None:
        """Set a key/value tag on this run."""
        self.tracker.set_tag(key, value)
# Global tracker instance used by the module-level convenience functions.
# Only start_run() requires init(); the log/end helpers are silent no-ops
# until init() has been called.
_global_tracker: Optional["ExperimentTracker"] = None


def init(tracking_uri: str, **kwargs: Any) -> "ExperimentTracker":
    """Create the global tracker; ``kwargs`` go to ``ExperimentTracker``."""
    global _global_tracker
    _global_tracker = ExperimentTracker(tracking_uri, **kwargs)
    return _global_tracker


def start_run(
    experiment_id: str,
    run_name: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Dict:
    """Start a run on the global tracker; ``kwargs`` go to ``RunConfig``.

    Raises:
        RuntimeError: If ``init()`` has not been called.
    """
    if _global_tracker is None:
        raise RuntimeError("Tracker not initialized. Call init() first.")
    config = RunConfig(experiment_id=experiment_id, run_name=run_name, **kwargs)
    return _global_tracker.start_run(config, params)


def log_param(key: str, value: Any) -> None:
    """Log a single parameter on the active run."""
    if _global_tracker is not None:
        _global_tracker.log_param(key, value)


def log_params(params: Dict[str, Any]) -> None:
    """Log several (possibly nested) parameters on the active run."""
    if _global_tracker is not None:
        _global_tracker.log_params(params)


def log_metric(
    key: str,
    value: float,
    step: Optional[int] = None,
    context: str = "train",
) -> None:
    """Log one metric value; ``context`` distinguishes train/eval/test."""
    if _global_tracker is not None:
        _global_tracker.log_metric(key, value, step, context)


def log_metrics(
    metrics: Dict[str, float],
    step: Optional[int] = None,
    context: str = "train",
) -> None:
    """Log several metric values sharing the same step and context."""
    if _global_tracker is not None:
        _global_tracker.log_metrics(metrics, step, context)


def log_artifact(
    local_path: str,
    artifact_path: Optional[str] = None,
    artifact_type: str = "file",
) -> None:
    """Upload a file as an artifact of the active run."""
    if _global_tracker is not None:
        _global_tracker.log_artifact(local_path, artifact_path, artifact_type)


def log_model(
    model: Any,
    model_name: str,
    signature: Optional[Dict] = None,
    metadata: Optional[Dict] = None,
) -> None:
    """Save, upload and register a model on the active run."""
    if _global_tracker is not None:
        _global_tracker.log_model(model, model_name, signature, metadata)


def log_figure(figure: Any, name: str, format: str = "png") -> None:
    """Upload a Matplotlib figure as an artifact of the active run."""
    if _global_tracker is not None:
        _global_tracker.log_figure(figure, name, format)


def set_tag(key: str, value: str) -> None:
    """Set a key/value tag on the active run."""
    if _global_tracker is not None:
        _global_tracker.set_tag(key, value)


def end_run(status: str = "completed") -> None:
    """End the active run with ``status``."""
    if _global_tracker is not None:
        _global_tracker.end_run(status)
使用示例
# training_with_tracking.py
#
# model, scheduler, train_dataloader, eval_dataloader, train_step, evaluate
# and save_checkpoint are assumed to be provided by the surrounding training
# script.
import experiment_sdk as exp

# Initialise the global tracker once per process.
exp.init("http://experiment-server:8080")

# Training parameters (nested dicts are logged as dotted keys, e.g.
# "training.learning_rate").
params = {
    "model": {
        "name": "llama-7b",
        "hidden_size": 4096,
        "num_layers": 32,
        "num_heads": 32,
    },
    "training": {
        "learning_rate": 1e-4,
        "batch_size": 32,
        "epochs": 3,
        "warmup_steps": 1000,
        "weight_decay": 0.01,
    },
    "data": {
        "dataset": "openwebtext",
        "max_length": 2048,
    },
}

# Start the run.
run = exp.start_run(
    experiment_id="llm-pretraining",
    run_name="lr-1e4-bs32",
    params=params,
)

status = "failed"
# Optimizer steps across all epochs; used as the x-axis of every metric.
global_step = 0
try:
    for epoch in range(params["training"]["epochs"]):
        for batch in train_dataloader:
            loss = train_step(model, batch)
            global_step += 1

            # Log training metrics every 100 steps.
            if global_step % 100 == 0:
                exp.log_metrics({
                    "train/loss": loss.item(),
                    "train/learning_rate": scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                }, step=global_step)

        # Evaluate at the end of each epoch.
        eval_metrics = evaluate(model, eval_dataloader)
        exp.log_metrics({
            "eval/loss": eval_metrics["loss"],
            "eval/perplexity": eval_metrics["perplexity"],
        }, step=global_step)

        # Save a single-file checkpoint; log_artifact uploads one file.
        checkpoint_path = f"checkpoints/epoch_{epoch}.pt"
        save_checkpoint(model, checkpoint_path)
        exp.log_artifact(checkpoint_path, f"checkpoints/epoch_{epoch}.pt")

    # Save and upload the final model.
    final_model_path = "checkpoints/final_model.pt"
    save_checkpoint(model, final_model_path)
    exp.log_artifact(final_model_path, "models/final_model.pt")
    status = "completed"
except Exception as e:
    exp.log_param("error", str(e))
    raise
finally:
    # Always close the run, also on KeyboardInterrupt and other non-Exception exits.
    exp.end_run(status)
实验服务端
服务架构
┌─────────────────────────────────────────────────────────────────┐
│ 实验管理服务架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ API Gateway │ │
│ │ (REST + gRPC) │ │
│ └─────────────────────────┬───────────────────────────────┘ │
│ │ │
│ ┌─────────────┬───────────┴───────────┬─────────────┐ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Run │ │ Metric │ │Artifact │ │ Query │ │
│ │ Svc │ │ Service │ │ Service │ │ Service │ │
│ └──┬──┘ └────┬────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │ │
│ ┌──▼───────────▼─────────────────────▼─────────────▼──┐ │
│ │ Data Layer │ │
│ ├──────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │PostgreSQL│ │TimescaleDB│ │ S3 │ │ │
│ │ │(Metadata)│ │ (Metrics) │ │(Artifact)│ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
服务实现
// experiment_service.go
package service
import (
"context"
"fmt"
"time"
)
// ExperimentService implements project, experiment, run, metric and artifact
// management on top of the repositories and backends it is constructed with.
//
// NOTE(review): this file declares `package service` but refers to Project,
// Experiment, Run, Metric, Artifact, ... unqualified, while those types are
// declared in package experiment — confirm the intended package or imports.
type ExperimentService struct {
// Metadata repositories.
projectRepo ProjectRepository
experimentRepo ExperimentRepository
runRepo RunRepository
metricRepo MetricRepository
artifactRepo ArtifactRepository
// Object storage for artifact contents.
objectStorage ObjectStorage
// Cache for latest metric values (written by LogMetrics, read by CompareRuns).
cache Cache
// Publisher for lifecycle events such as "run.started".
eventPublisher EventPublisher
}
// NewExperimentService wires an ExperimentService from its repositories,
// object storage, cache and event publisher. The dependencies are stored
// as-is; none are validated.
func NewExperimentService(
	projectRepo ProjectRepository,
	experimentRepo ExperimentRepository,
	runRepo RunRepository,
	metricRepo MetricRepository,
	artifactRepo ArtifactRepository,
	objectStorage ObjectStorage,
	cache Cache,
	eventPublisher EventPublisher,
) *ExperimentService {
	svc := new(ExperimentService)
	svc.projectRepo = projectRepo
	svc.experimentRepo = experimentRepo
	svc.runRepo = runRepo
	svc.metricRepo = metricRepo
	svc.artifactRepo = artifactRepo
	svc.objectStorage = objectStorage
	svc.cache = cache
	svc.eventPublisher = eventPublisher
	return svc
}
// CreateProject assigns a new ID to project, stamps CreatedAt and UpdatedAt
// with the same instant, persists it and publishes a "project.created"
// event. The passed struct is updated in place.
//
// A single timestamp is used so a freshly created project has
// CreatedAt == UpdatedAt. Two time.Now() calls differed by a few
// nanoseconds, which made the project look already modified.
func (s *ExperimentService) CreateProject(ctx context.Context, project *Project) error {
	if project == nil {
		return fmt.Errorf("create project: nil project")
	}

	now := time.Now()
	project.ID = generateID("proj")
	project.CreatedAt = now
	project.UpdatedAt = now

	if err := s.projectRepo.Create(ctx, project); err != nil {
		return fmt.Errorf("create project: %w", err)
	}

	s.eventPublisher.Publish(ctx, "project.created", project)
	return nil
}
// CreateExperiment checks that experiment.ProjectID refers to an existing
// project, assigns a new ID, marks the experiment active, stamps
// CreatedAt/UpdatedAt with one timestamp, persists it and publishes an
// "experiment.created" event. The passed struct is updated in place.
//
// Any error from the project lookup is reported as "project not found".
func (s *ExperimentService) CreateExperiment(ctx context.Context, experiment *Experiment) error {
	if experiment == nil {
		return fmt.Errorf("create experiment: nil experiment")
	}
	if _, err := s.projectRepo.Get(ctx, experiment.ProjectID); err != nil {
		return fmt.Errorf("project not found: %w", err)
	}

	now := time.Now()
	experiment.ID = generateID("exp")
	experiment.Status = ExperimentStatusActive
	experiment.CreatedAt = now
	experiment.UpdatedAt = now

	if err := s.experimentRepo.Create(ctx, experiment); err != nil {
		return fmt.Errorf("create experiment: %w", err)
	}

	s.eventPublisher.Publish(ctx, "experiment.created", experiment)
	return nil
}
// CreateRun checks that run.ExperimentID exists, assigns a new ID, forces the
// status to RunStatusRunning, persists the run and publishes "run.started".
// The passed struct is updated in place.
//
// StartTime defaults to the creation time when the client did not send one.
// Otherwise EndRun cannot compute a Duration and reports 0.
func (s *ExperimentService) CreateRun(ctx context.Context, run *Run) error {
	if run == nil {
		return fmt.Errorf("create run: nil run")
	}
	if _, err := s.experimentRepo.Get(ctx, run.ExperimentID); err != nil {
		return fmt.Errorf("experiment not found: %w", err)
	}

	now := time.Now()
	run.ID = generateID("run")
	run.Status = RunStatusRunning
	run.CreatedAt = now
	if run.StartTime == nil {
		run.StartTime = &now
	}

	if err := s.runRepo.Create(ctx, run); err != nil {
		return fmt.Errorf("create run: %w", err)
	}

	s.eventPublisher.Publish(ctx, "run.started", run)
	return nil
}
// LogParams merges params into the run's flattened parameter map (incoming
// values overwrite existing keys) and persists the run. Update failures are
// wrapped like the other service methods do.
//
// NOTE(review): this is a read-modify-write of the whole run record. Two
// concurrent calls for the same run can lose each other's keys unless the
// repository serialises updates.
func (s *ExperimentService) LogParams(ctx context.Context, runID string, params map[string]interface{}) error {
	run, err := s.runRepo.Get(ctx, runID)
	if err != nil {
		return fmt.Errorf("run not found: %w", err)
	}

	if run.Parameters.Flat == nil {
		run.Parameters.Flat = make(map[string]interface{}, len(params))
	}
	for k, v := range params {
		run.Parameters.Flat[k] = v
	}

	if err := s.runRepo.Update(ctx, run); err != nil {
		return fmt.Errorf("update run: %w", err)
	}
	return nil
}
// LogMetrics stores a batch of metric points and refreshes the cached
// "latest value" for every (run, key, context) in the batch. Cache entries
// expire after one hour.
//
// Each cache key is built from the metric's own RunID. The batch API does not
// require all points to belong to one run, and keying everything by
// metrics[0].RunID cached other runs' values under the wrong run. On equal
// steps, the later entry in the batch wins, since it is the most recently
// logged value (the Python SDK records a missing step as 0). Nil entries are
// skipped.
//
// NOTE(review): a batch that arrives late with lower steps still overwrites
// a newer cached value; preventing that needs a compare against the cache.
func (s *ExperimentService) LogMetrics(ctx context.Context, metrics []*Metric) error {
	if len(metrics) == 0 {
		return nil
	}

	if err := s.metricRepo.BatchCreate(ctx, metrics); err != nil {
		return fmt.Errorf("batch create metrics: %w", err)
	}

	latest := make(map[string]*Metric)
	for _, m := range metrics {
		if m == nil {
			continue
		}
		cacheKey := fmt.Sprintf("run:%s:metric:%s:%s:latest", m.RunID, m.Key, m.Context)
		if prev, ok := latest[cacheKey]; !ok || m.Step >= prev.Step {
			latest[cacheKey] = m
		}
	}

	for cacheKey, m := range latest {
		s.cache.Set(ctx, cacheKey, m, time.Hour)
	}
	return nil
}
// LogArtifact uploads data to object storage under
// runs/<runID>/artifacts/<name> and records an Artifact with its size and
// checksum.
//
// name comes from the client and is spliced into the storage key, so it is
// validated first. Empty names, absolute paths and any ".." segment are
// rejected, so a request cannot write outside the run's artifact prefix.
//
// NOTE(review): if creating the record fails, the uploaded object is left
// behind.
func (s *ExperimentService) LogArtifact(ctx context.Context, runID string, name string, data []byte, artifactType string) error {
	if !validArtifactName(name) {
		return fmt.Errorf("invalid artifact name %q", name)
	}

	run, err := s.runRepo.Get(ctx, runID)
	if err != nil {
		return fmt.Errorf("run not found: %w", err)
	}

	storagePath := fmt.Sprintf("runs/%s/artifacts/%s", runID, name)
	if err := s.objectStorage.Upload(ctx, storagePath, data); err != nil {
		return fmt.Errorf("upload artifact: %w", err)
	}

	artifact := &Artifact{
		ID:        generateID("art"),
		RunID:     run.ID,
		Name:      name,
		Type:      artifactType,
		Path:      storagePath,
		Size:      int64(len(data)),
		Checksum:  calculateChecksum(data),
		CreatedAt: time.Now(),
	}
	if err := s.artifactRepo.Create(ctx, artifact); err != nil {
		return fmt.Errorf("create artifact record: %w", err)
	}
	return nil
}

// validArtifactName reports whether name is a non-empty relative path with no
// ".." segment; both '/' and '\\' are treated as separators.
func validArtifactName(name string) bool {
	if name == "" || name[0] == '/' || name[0] == '\\' {
		return false
	}
	start := 0
	for i := 0; i <= len(name); i++ {
		if i < len(name) && name[i] != '/' && name[i] != '\\' {
			continue
		}
		if name[start:i] == ".." {
			return false
		}
		start = i + 1
	}
	return true
}
// EndRun marks the run as finished with the given status. It stamps EndTime,
// computes Duration in whole seconds when StartTime is known, records
// resource usage, persists the run and publishes "run.ended".
func (s *ExperimentService) EndRun(ctx context.Context, runID string, status RunStatus) error {
	run, err := s.runRepo.Get(ctx, runID)
	if err != nil {
		return fmt.Errorf("run not found: %w", err)
	}

	finished := time.Now()
	run.Status, run.EndTime = status, &finished
	if started := run.StartTime; started != nil {
		run.Duration = int64(finished.Sub(*started).Seconds())
	}
	run.Resources = s.calculateResourceUsage(ctx, run)

	if err = s.runRepo.Update(ctx, run); err != nil {
		return fmt.Errorf("update run: %w", err)
	}

	s.eventPublisher.Publish(ctx, "run.ended", run)
	return nil
}
// GetRunMetrics returns the metric points of runID for the given keys and step
// range by delegating to the metric repository. In the TimescaleDB repository an
// empty keys slice or a nil stepRange means "no filter" on that dimension.
func (s *ExperimentService) GetRunMetrics(
	ctx context.Context,
	runID string,
	keys []string,
	stepRange *StepRange,
) ([]*Metric, error) {
	q := MetricQuery{RunID: runID, Keys: keys, StepRange: stepRange}
	return s.metricRepo.Query(ctx, q)
}
// CompareRuns loads the given runs and gathers, for each requested metric key,
// the latest "train"-context value of every run.
//
// The result holds:
//   - Runs: the runs in input order (any missing run aborts the comparison);
//   - Params: run ID -> the run's flat parameter map;
//   - Metrics: metric key -> run ID -> latest train value. Runs without a value
//     for a key are simply absent from that key's map.
//
// Values come from the cache written by LogMetrics. Those entries expire after
// an hour, so on a cache miss the metric store is queried instead. Otherwise,
// comparisons made more than an hour after logging would silently contain no metrics.
func (s *ExperimentService) CompareRuns(
	ctx context.Context,
	runIDs []string,
	metricKeys []string,
) (*RunComparison, error) {
	comparison := &RunComparison{
		Runs:    make([]*Run, 0, len(runIDs)),
		Metrics: make(map[string]map[string]float64),
		Params:  make(map[string]map[string]interface{}),
	}

	for _, runID := range runIDs {
		run, err := s.runRepo.Get(ctx, runID)
		if err != nil {
			return nil, err
		}
		comparison.Runs = append(comparison.Runs, run)
		comparison.Params[runID] = run.Parameters.Flat

		for _, key := range metricKeys {
			value, found, err := s.latestTrainMetric(ctx, runID, key)
			if err != nil {
				return nil, fmt.Errorf("latest metric %q for run %s: %w", key, runID, err)
			}
			if !found {
				continue
			}
			if comparison.Metrics[key] == nil {
				comparison.Metrics[key] = make(map[string]float64)
			}
			comparison.Metrics[key][runID] = value
		}
	}
	return comparison, nil
}

// latestTrainMetric returns the most recent "train"-context value of key for
// runID. It tries the LogMetrics cache first and falls back to the metric
// store; found is false when neither has a train point for the key.
func (s *ExperimentService) latestTrainMetric(ctx context.Context, runID, key string) (value float64, found bool, err error) {
	cacheKey := fmt.Sprintf("run:%s:metric:%s:train:latest", runID, key)
	var cached *Metric
	if cacheErr := s.cache.Get(ctx, cacheKey, &cached); cacheErr == nil && cached != nil {
		return cached.Value, true, nil
	}

	points, err := s.metricRepo.Query(ctx, MetricQuery{
		RunID: runID,
		Keys:  []string{key},
	})
	if err != nil {
		return 0, false, err
	}
	// The TimescaleDB repository returns points ordered by step ascending, so
	// scan backwards for the newest point in the train context.
	for i := len(points) - 1; i >= 0; i-- {
		if points[i].Context == "train" {
			return points[i].Value, true, nil
		}
	}
	return 0, false, nil
}
// SearchRuns forwards query to the run repository and returns the matching
// runs together with the total count reported by the repository.
func (s *ExperimentService) SearchRuns(
	ctx context.Context,
	query RunSearchQuery,
) ([]*Run, int64, error) {
	runs, total, err := s.runRepo.Search(ctx, query)
	return runs, total, err
}
// calculateResourceUsage 计算资源使用
func (s *ExperimentService) calculateResourceUsage(ctx context.Context, run *Run) ResourceUsage {
usage := ResourceUsage{}
if run.StartTime == nil || run.EndTime == nil {
return usage
}
duration := run.EndTime.Sub(*run.StartTime)
hours := duration.Hours()
// GPU 小时
usage.GPUHours = float64(run.Environment.GPUCount) * hours
// CPU 小时
usage.CPUHours = float64(run.Environment.CPUCount) * hours
// 内存 GB 小时
usage.MemoryGBHours = float64(run.Environment.MemoryGB) * hours
// 预估成本(示例价格)
gpuPrice := 2.0 // $/GPU/hour
cpuPrice := 0.05 // $/CPU/hour
memPrice := 0.01 // $/GB/hour
usage.Cost = usage.GPUHours*gpuPrice + usage.CPUHours*cpuPrice + usage.MemoryGBHours*memPrice
return usage
}
// RunSearchQuery carries the filters, ordering and pagination passed to the run
// repository's Search. How each field is interpreted (for example whether an
// empty Status slice means "any status") is decided by the repository
// implementation, which is not shown here.
type RunSearchQuery struct {
	ExperimentID string
	Status       []RunStatus
	Tags         map[string]string
	ParamFilter  map[string]interface{}
	MetricFilter *MetricFilter
	TimeRange    *TimeRange
	OrderBy      string
	Limit        int
	Offset       int
}

// MetricFilter restricts runs by comparing the metric named Key against Value.
type MetricFilter struct {
	Key      string
	Operator string // one of ">", "<", ">=", "<=", "="
	Value    float64
}

// TimeRange is a time window bounded by Start and End; whether the bounds are
// inclusive depends on the consumer.
type TimeRange struct {
	Start time.Time
	End   time.Time
}

// StepRange is a window of training steps. TimescaleMetricRepository.Query
// applies it inclusively (Start <= step <= End).
type StepRange struct {
	Start int64
	End   int64
}

// RunComparison is the result of ExperimentService.CompareRuns.
type RunComparison struct {
	Runs    []*Run
	Metrics map[string]map[string]float64     // metric key -> run ID -> value
	Params  map[string]map[string]interface{} // run ID -> flat parameters
}
// generateID returns prefix followed by "_" and the current Unix time in
// nanoseconds, e.g. "art_1733040000000000000".
// NOTE(review): two calls within the same clock tick (concurrent callers, or
// platforms with coarse timers) return identical IDs; add a random or counter
// suffix if IDs must be unique.
func generateID(prefix string) string {
	nanos := time.Now().UnixNano()
	return prefix + "_" + fmt.Sprint(nanos)
}
// calculateChecksum is a placeholder for a content digest (the intent is MD5 or
// SHA-256). It currently ignores data and always returns "", so artifacts are
// recorded with an empty checksum.
// TODO: implement with crypto/sha256 + encoding/hex.
func calculateChecksum(data []byte) string {
	var digest string
	return digest
}
指标存储优化
// metric_repository.go
package repository
import (
"context"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// TimescaleMetricRepository stores metric points in a TimescaleDB hypertable,
// accessed through a pgx connection pool.
type TimescaleMetricRepository struct {
	pool *pgxpool.Pool
}

// NewTimescaleMetricRepository wraps pool. It does not touch the database;
// call Init to create the schema.
func NewTimescaleMetricRepository(pool *pgxpool.Pool) *TimescaleMetricRepository {
	repo := new(TimescaleMetricRepository)
	repo.pool = pool
	return repo
}
// Init creates the metrics schema. It creates:
//   - the metrics table;
//   - its TimescaleDB hypertable conversion (1-day chunks on timestamp);
//   - lookup indexes;
//   - native compression segmented by (run_id, key);
//   - a 7-day compression policy and a 90-day retention policy.
//
// Statements run in order because later ones depend on earlier ones (the table
// must exist before create_hypertable, compression must be enabled before its
// policy is added).
//
// Any error whose text contains "already exists" is ignored so Init can be
// re-run; any other error aborts with "init query failed".
// NOTE(review): matching on error text is brittle. Re-running the
// ALTER TABLE ... SET (timescaledb.compress ...) statement after chunks have been
// compressed may fail with a different message — confirm against the TimescaleDB
// version in use.
// NOTE(review): idx_metrics_run_id is a prefix of the two composite indexes and
// is probably redundant.
// NOTE(review): the retention policy permanently drops rows older than 90 days.
func (r *TimescaleMetricRepository) Init(ctx context.Context) error {
	// Create the table first; it is converted into a hypertable below.
	queries := []string{
		`CREATE TABLE IF NOT EXISTS metrics (
run_id TEXT NOT NULL,
key TEXT NOT NULL,
value DOUBLE PRECISION NOT NULL,
step BIGINT NOT NULL,
context TEXT DEFAULT 'train',
timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
)`,
		// Convert to a hypertable (TimescaleDB feature), partitioned by timestamp.
		`SELECT create_hypertable('metrics', 'timestamp',
if_not_exists => TRUE,
chunk_time_interval => INTERVAL '1 day'
)`,
		// Indexes for per-run, per-key and per-key-by-step lookups.
		`CREATE INDEX IF NOT EXISTS idx_metrics_run_id ON metrics (run_id)`,
		`CREATE INDEX IF NOT EXISTS idx_metrics_run_key ON metrics (run_id, key)`,
		`CREATE INDEX IF NOT EXISTS idx_metrics_run_key_step ON metrics (run_id, key, step DESC)`,
		// Enable native compression for older chunks.
		`ALTER TABLE metrics SET (
timescaledb.compress,
timescaledb.compress_segmentby = 'run_id,key'
)`,
		// Compress chunks older than 7 days automatically.
		`SELECT add_compression_policy('metrics', INTERVAL '7 days', if_not_exists => TRUE)`,
		// Drop data older than 90 days.
		`SELECT add_retention_policy('metrics', INTERVAL '90 days', if_not_exists => TRUE)`,
	}
	for _, q := range queries {
		if _, err := r.pool.Exec(ctx, q); err != nil {
			// Tolerate "already exists" so Init is idempotent; fail on anything else.
			if !strings.Contains(err.Error(), "already exists") {
				return fmt.Errorf("init query failed: %w", err)
			}
		}
	}
	return nil
}
const (
	// metricInsertColumns is the number of bind parameters each metric row uses.
	metricInsertColumns = 6
	// maxMetricRowsPerInsert keeps a single INSERT under PostgreSQL's limit of
	// 65535 bind parameters per statement.
	maxMetricRowsPerInsert = 65535 / metricInsertColumns
)

// BatchCreate inserts metrics using multi-row INSERT statements (not COPY).
//
// A batch that fits in one statement is sent as a single Exec. Larger batches
// would exceed PostgreSQL's bind-parameter limit. They are split into
// statements of at most maxMetricRowsPerInsert rows, executed in one transaction
// so the call stays all-or-nothing.
//
// Context and Timestamp are written exactly as given. The column defaults
// declared in Init do not apply, because both columns are always listed.
func (r *TimescaleMetricRepository) BatchCreate(ctx context.Context, metrics []*Metric) error {
	if len(metrics) == 0 {
		return nil
	}

	if len(metrics) <= maxMetricRowsPerInsert {
		query, args := buildMetricInsert(metrics)
		_, err := r.pool.Exec(ctx, query, args...)
		return err
	}

	tx, err := r.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin metrics batch: %w", err)
	}
	// After a successful Commit, Rollback only reports that the tx is closed;
	// that error is intentionally ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	for start := 0; start < len(metrics); start += maxMetricRowsPerInsert {
		end := start + maxMetricRowsPerInsert
		if end > len(metrics) {
			end = len(metrics)
		}
		query, args := buildMetricInsert(metrics[start:end])
		if _, err := tx.Exec(ctx, query, args...); err != nil {
			return fmt.Errorf("insert metrics %d-%d: %w", start, end, err)
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit metrics batch: %w", err)
	}
	return nil
}

// buildMetricInsert renders one multi-row INSERT for batch plus its positional
// arguments ($1..$6 for the first row, $7..$12 for the second, and so on).
func buildMetricInsert(batch []*Metric) (string, []interface{}) {
	placeholders := make([]string, 0, len(batch))
	args := make([]interface{}, 0, len(batch)*metricInsertColumns)
	for i, m := range batch {
		base := i * metricInsertColumns
		placeholders = append(placeholders,
			fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d)",
				base+1, base+2, base+3, base+4, base+5, base+6))
		args = append(args, m.RunID, m.Key, m.Value, m.Step, m.Context, m.Timestamp)
	}
	query := `INSERT INTO metrics (run_id, key, value, step, context, timestamp) VALUES ` +
		strings.Join(placeholders, ", ")
	return query, args
}
// Query returns the metric points of one run, ordered by step ascending.
//
// Filters:
//   - RunID: always applied (exact match).
//   - Keys: when non-empty, key IN (...).
//   - StepRange: when non-nil, Start <= step <= End (inclusive).
//   - Context: when non-empty, exact match on the context column. An empty
//     Context leaves every context in the result.
//
// All values are passed as bind parameters; nothing from query is spliced
// into the SQL text.
func (r *TimescaleMetricRepository) Query(ctx context.Context, query MetricQuery) ([]*Metric, error) {
	var args []interface{}
	// bind records v as the next positional argument and returns its placeholder.
	bind := func(v interface{}) string {
		args = append(args, v)
		return fmt.Sprintf("$%d", len(args))
	}

	conditions := []string{"run_id = " + bind(query.RunID)}

	if len(query.Keys) > 0 {
		placeholders := make([]string, len(query.Keys))
		for i, key := range query.Keys {
			placeholders[i] = bind(key)
		}
		conditions = append(conditions, fmt.Sprintf("key IN (%s)", strings.Join(placeholders, ",")))
	}

	if query.StepRange != nil {
		lo := bind(query.StepRange.Start)
		hi := bind(query.StepRange.End)
		conditions = append(conditions, fmt.Sprintf("step >= %s AND step <= %s", lo, hi))
	}

	if query.Context != "" {
		conditions = append(conditions, "context = "+bind(query.Context))
	}

	sql := fmt.Sprintf(`
		SELECT run_id, key, value, step, context, timestamp
		FROM metrics
		WHERE %s
		ORDER BY step ASC
	`, strings.Join(conditions, " AND "))

	rows, err := r.pool.Query(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var metrics []*Metric
	for rows.Next() {
		m := &Metric{}
		if err := rows.Scan(&m.RunID, &m.Key, &m.Value, &m.Step, &m.Context, &m.Timestamp); err != nil {
			return nil, err
		}
		metrics = append(metrics, m)
	}
	// rows.Next returns false both at end of data and on error; without this
	// check a mid-stream failure would come back as a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return metrics, nil
}
// GetAggregated downsamples one metric series of a run into step buckets.
//
// Each row covers the steps [bucket_step, bucket_step+bucketSize), where
// bucket_step = (step / bucketSize) * bucketSize (integer division). The row
// holds the aggregated value and the number of points in the bucket, ordered
// by bucket_step ascending.
//
// aggregation is one of "min", "max", "avg" or "last". "last" uses TimescaleDB's
// last(value, step), i.e. the value at the highest step in the bucket. Any
// other string falls back to "avg". bucketSize must be positive.
func (r *TimescaleMetricRepository) GetAggregated(
	ctx context.Context,
	runID string,
	key string,
	aggregation string, // min, max, avg, last
	bucketSize int64, // bucket width in steps
) ([]*AggregatedMetric, error) {
	// A zero bucket would make PostgreSQL fail with "division by zero"; reject it
	// (and negative widths) with a clear error instead.
	if bucketSize <= 0 {
		return nil, fmt.Errorf("bucket size must be positive, got %d", bucketSize)
	}

	// Only these fixed fragments are interpolated into the SQL text; the
	// caller-supplied aggregation string itself never is.
	var aggFunc string
	switch aggregation {
	case "min":
		aggFunc = "MIN(value)"
	case "max":
		aggFunc = "MAX(value)"
	case "last":
		aggFunc = "last(value, step)"
	default: // "avg" and anything unrecognised
		aggFunc = "AVG(value)"
	}

	sql := fmt.Sprintf(`
		SELECT
			(step / $3) * $3 AS bucket_step,
			%s AS value,
			COUNT(*) AS count
		FROM metrics
		WHERE run_id = $1 AND key = $2
		GROUP BY bucket_step
		ORDER BY bucket_step ASC
	`, aggFunc)

	rows, err := r.pool.Query(ctx, sql, runID, key, bucketSize)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []*AggregatedMetric
	for rows.Next() {
		m := &AggregatedMetric{}
		if err := rows.Scan(&m.Step, &m.Value, &m.Count); err != nil {
			return nil, err
		}
		results = append(results, m)
	}
	// Surface errors that ended iteration early instead of returning partial data.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// AggregatedMetric is one bucket produced by GetAggregated.
type AggregatedMetric struct {
	Step  int64   // first step of the bucket: (step / bucketSize) * bucketSize
	Value float64 // aggregated value (min, max, avg or last) over the bucket
	Count int64   // number of raw points in the bucket
}

// MetricQuery selects metric points of a single run.
type MetricQuery struct {
	RunID     string
	Keys      []string   // metric names to include
	StepRange *StepRange // optional step window
	Context   string     // metric context label, e.g. "train"
}
实验比较与分析
比较分析服务
// analysis_service.go
package service
import (
"context"
"math"
"sort"
)
// AnalysisService builds cross-run and cross-experiment analyses (experiment
// comparison, parameter importance). Run and experiment lookups go through
// experimentService; metric points come from metricRepo.
type AnalysisService struct {
	experimentService *ExperimentService
	metricRepo        MetricRepository
}
// CompareExperiments summarizes each experiment (in input order) over
// metricKeys and then picks, per metric, the experiment with the highest Best
// value. The first experiment that cannot be summarized aborts the comparison.
func (s *AnalysisService) CompareExperiments(
	ctx context.Context,
	experimentIDs []string,
	metricKeys []string,
) (*ExperimentComparison, error) {
	summaries := make([]*ExperimentSummary, 0, len(experimentIDs))
	for _, id := range experimentIDs {
		summary, err := s.getExperimentSummary(ctx, id, metricKeys)
		if err != nil {
			return nil, err
		}
		summaries = append(summaries, summary)
	}

	return &ExperimentComparison{
		Experiments: summaries,
		Statistics:  s.calculateComparisonStatistics(summaries, metricKeys),
	}, nil
}
// ExperimentComparison is the result of AnalysisService.CompareExperiments.
type ExperimentComparison struct {
	Experiments []*ExperimentSummary
	Statistics  map[string]*MetricStatistics // metric key -> cross-experiment winner
}

// ExperimentSummary aggregates the completed runs of one experiment.
type ExperimentSummary struct {
	Experiment *Experiment
	BestRun    *Run
	RunCount   int                       // number of completed runs found
	Metrics    map[string]*MetricSummary // metric key -> summary; keys without data are absent
}

// MetricSummary describes the final values of one metric across runs.
// getExperimentSummary treats higher values as better: Best is the maximum
// and Worst the minimum.
type MetricSummary struct {
	Best    float64
	Worst   float64
	Mean    float64
	Std     float64 // sample standard deviation (n-1 denominator)
	Median  float64
	BestRun string // ID of the run holding the best value
}

// MetricStatistics records which experiment achieved the highest Best value
// for a metric.
type MetricStatistics struct {
	BestExperiment string
	BestValue      float64
	Improvement    float64 // improvement over a baseline; not populated by calculateComparisonStatistics
	Significance   float64 // statistical significance; not populated by calculateComparisonStatistics
}
// getExperimentSummary loads an experiment and summarizes the final value of
// each metric in metricKeys across the experiment's completed runs.
//
// For every run, a metric's final value is the last point returned by
// metricRepo.Query; the TimescaleDB repository orders points by step
// ascending. A run whose query fails or returns no points is skipped for that
// metric (best-effort). Higher values are treated as better.
//
// BestRun is the best run of the first metric key, in metricKeys order, that
// has any data. Values of different metrics are not comparable, so later keys
// never override that choice.
func (s *AnalysisService) getExperimentSummary(
	ctx context.Context,
	experimentID string,
	metricKeys []string,
) (*ExperimentSummary, error) {
	exp, err := s.experimentService.experimentRepo.Get(ctx, experimentID)
	if err != nil {
		return nil, err
	}

	runs, _, err := s.experimentService.SearchRuns(ctx, RunSearchQuery{
		ExperimentID: experimentID,
		Status:       []RunStatus{RunStatusCompleted},
	})
	if err != nil {
		return nil, err
	}

	summary := &ExperimentSummary{
		Experiment: exp,
		RunCount:   len(runs),
		Metrics:    make(map[string]*MetricSummary),
	}

	for _, key := range metricKeys {
		values := make([]float64, 0, len(runs))
		var bestRun *Run
		var bestValue float64

		for _, run := range runs {
			metrics, err := s.metricRepo.Query(ctx, MetricQuery{
				RunID: run.ID,
				Keys:  []string{key},
			})
			if err != nil || len(metrics) == 0 {
				continue
			}
			final := metrics[len(metrics)-1].Value
			values = append(values, final)
			// Seeding from the first observed run (rather than comparing against
			// -Inf) guarantees bestRun is set whenever values is non-empty, even
			// if every value is NaN or -Inf.
			if bestRun == nil || final > bestValue {
				bestValue = final
				bestRun = run
			}
		}

		if len(values) == 0 {
			continue
		}

		summary.Metrics[key] = &MetricSummary{
			Best:    maxFloat(values),
			Worst:   minFloat(values),
			Mean:    meanFloat(values),
			Std:     stdFloat(values),
			Median:  medianFloat(values),
			BestRun: bestRun.ID,
		}
		// The first metric with data is the primary metric for BestRun.
		if summary.BestRun == nil {
			summary.BestRun = bestRun
		}
	}
	return summary, nil
}
// calculateComparisonStatistics picks, for each metric key, the experiment whose
// summary has the highest Best value. Ties keep the earliest experiment.
// Experiments lacking the metric, and NaN values, never win. Keys that no
// experiment has are omitted from the result.
func (s *AnalysisService) calculateComparisonStatistics(
	experiments []*ExperimentSummary,
	metricKeys []string,
) map[string]*MetricStatistics {
	stats := make(map[string]*MetricStatistics)
	for _, key := range metricKeys {
		var winner *ExperimentSummary
		top := math.Inf(-1)

		for _, candidate := range experiments {
			metric, ok := candidate.Metrics[key]
			if !ok || !(metric.Best > top) {
				continue
			}
			top, winner = metric.Best, candidate
		}

		if winner == nil {
			continue
		}
		stats[key] = &MetricStatistics{
			BestExperiment: winner.Experiment.ID,
			BestValue:      top,
		}
	}
	return stats
}
// ParameterImportance 参数重要性分析
func (s *AnalysisService) ParameterImportance(
ctx context.Context,
experimentID string,
targetMetric string,
) (*ParameterImportanceResult, error) {
// 获取所有完成的运行
runs, _, err := s.experimentService.SearchRuns(ctx, RunSearchQuery{
ExperimentID: experimentID,
Status: []RunStatus{RunStatusCompleted},
})
if err != nil {
return nil, err
}
// 收集参数和指标
type dataPoint struct {
params map[string]interface{}
metric float64
}
var data []dataPoint
allParams := make(map[string]bool)
for _, run := range runs {
metrics, err := s.metricRepo.Query(ctx, MetricQuery{
RunID: run.ID,
Keys: []string{targetMetric},
})
if err != nil || len(metrics) == 0 {
continue
}
dp := dataPoint{
params: run.Parameters.Flat,
metric: metrics[len(metrics)-1].Value,
}
data = append(data, dp)
for k := range run.Parameters.Flat {
allParams[k] = true
}
}
// 计算每个参数的重要性(使用简单的相关性分析)
result := &ParameterImportanceResult{
Importance: make(map[string]float64),
}
for param := range allParams {
importance := s.calculateParameterCorrelation(data, param)
result.Importance[param] = importance
}
// 排序
type paramScore struct {
name string
score float64
}
var scores []paramScore
for name, score := range result.Importance {
scores = append(scores, paramScore{name, score})
}
sort.Slice(scores, func(i, j int) bool {
return math.Abs(scores[i].score) > math.Abs(scores[j].score)
})
result.Ranking = make([]string, len(scores))
for i, s := range scores {
result.Ranking[i] = s.name
}
return result, nil
}
// ParameterImportanceResult is the output of AnalysisService.ParameterImportance.
type ParameterImportanceResult struct {
	Importance map[string]float64 // parameter name -> correlation with the target metric
	Ranking    []string           // parameter names ordered by descending |Importance|
}
// calculateParameterCorrelation returns the Pearson correlation between
// parameter param and the target metric over data.
//
// Only data points whose param value is numeric (as accepted by toFloat64:
// floats, signed/unsigned integers, json.Number depending on that helper) are
// used. The result is 0 when fewer than three such points exist.
// A rank-based measure such as Spearman would be more robust to outliers and
// log-scaled hyperparameters such as learning rates.
func (s *AnalysisService) calculateParameterCorrelation(
	data []dataPoint,
	param string,
) float64 {
	var paramValues, metricValues []float64
	for _, dp := range data {
		raw, ok := dp.params[param]
		if !ok {
			continue
		}
		v, ok := toFloat64(raw)
		if !ok {
			continue // non-numeric parameter (string, bool, ...)
		}
		paramValues = append(paramValues, v)
		metricValues = append(metricValues, dp.metric)
	}

	if len(paramValues) < 3 {
		return 0
	}
	return pearsonCorrelation(paramValues, metricValues)
}
// Statistical helpers.

// maxFloat returns the largest element of values, or 0 for an empty slice.
// Elements are compared with >, so a NaN is only returned if it is values[0].
func maxFloat(values []float64) float64 {
	if len(values) == 0 {
		return 0
	}
	best := values[0]
	for i := 1; i < len(values); i++ {
		if values[i] > best {
			best = values[i]
		}
	}
	return best
}
// minFloat returns the smallest element of values, or 0 for an empty slice.
// Elements are compared with <, so a NaN is only returned if it is values[0].
func minFloat(values []float64) float64 {
	if len(values) == 0 {
		return 0
	}
	lowest := values[0]
	for i := 1; i < len(values); i++ {
		if values[i] < lowest {
			lowest = values[i]
		}
	}
	return lowest
}
// meanFloat returns the arithmetic mean of values, or 0 for an empty slice.
func meanFloat(values []float64) float64 {
	n := len(values)
	if n == 0 {
		return 0
	}
	var total float64
	for i := 0; i < n; i++ {
		total += values[i]
	}
	return total / float64(n)
}
func stdFloat(values []float64) float64 {
if len(values) < 2 {
return 0
}
mean := meanFloat(values)
sumSq := 0.0
for _, v := range values {
sumSq += (v - mean) * (v - mean)
}
return math.Sqrt(sumSq / float64(len(values)-1))
}
// medianFloat returns the median of values, or 0 for an empty slice. For an
// even count it averages the two middle elements. The input is not modified;
// a sorted copy is used instead.
func medianFloat(values []float64) float64 {
	n := len(values)
	if n == 0 {
		return 0
	}
	ordered := append([]float64(nil), values...)
	sort.Float64s(ordered)
	half := n / 2
	if n%2 != 0 {
		return ordered[half]
	}
	return (ordered[half-1] + ordered[half]) / 2
}
func pearsonCorrelation(x, y []float64) float64 {
if len(x) != len(y) || len(x) == 0 {
return 0
}
n := float64(len(x))
sumX, sumY, sumXY, sumX2, sumY2 := 0.0, 0.0, 0.0, 0.0, 0.0
for i := range x {
sumX += x[i]
sumY += y[i]
sumXY += x[i] * y[i]
sumX2 += x[i] * x[i]
sumY2 += y[i] * y[i]
}
numerator := n*sumXY - sumX*sumY
denominator := math.Sqrt((n*sumX2 - sumX*sumX) * (n*sumY2 - sumY*sumY))
if denominator == 0 {
return 0
}
return numerator / denominator
}
可视化组件
指标图表服务
// visualization_service.go
package service
import (
"context"
"encoding/json"
)
// VisualizationService turns run and metric data into chart-ready structures.
type VisualizationService struct {
	metricRepo MetricRepository
	runRepo    RunRepository
}

// ChartData is a generic chart payload serialized to JSON for the frontend.
type ChartData struct {
	Type    string                 `json:"type"` // line, scatter, bar, heatmap
	Title   string                 `json:"title"`
	XLabel  string                 `json:"x_label"`
	YLabel  string                 `json:"y_label"`
	Series  []ChartSeries          `json:"series"`
	Options map[string]interface{} `json:"options"`
}

// ChartSeries is one named series of (x, y) points.
type ChartSeries struct {
	Name  string      `json:"name"`
	Data  [][]float64 `json:"data"` // [[x, y], ...]
	Color string      `json:"color,omitempty"`
	Style string      `json:"style,omitempty"` // solid, dashed
}
// GetMetricChart builds a chart with one (step, value) series per run for
// metricKey. Series are named after the run and colored from a fixed 5-color
// palette by the run's position in runIDs. Runs that cannot be loaded, or whose
// metric query fails, are skipped (best-effort), so the method never returns an error.
func (s *VisualizationService) GetMetricChart(
	ctx context.Context,
	runIDs []string,
	metricKey string,
	chartType string,
) (*ChartData, error) {
	palette := []string{"#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"}

	chart := &ChartData{
		Type:   chartType,
		Title:  metricKey,
		XLabel: "Step",
		YLabel: metricKey,
		Series: make([]ChartSeries, 0, len(runIDs)),
	}

	for idx, runID := range runIDs {
		run, err := s.runRepo.Get(ctx, runID)
		if err != nil {
			continue
		}
		points, err := s.metricRepo.Query(ctx, MetricQuery{
			RunID: runID,
			Keys:  []string{metricKey},
		})
		if err != nil {
			continue
		}

		xy := make([][]float64, 0, len(points))
		for _, p := range points {
			xy = append(xy, []float64{float64(p.Step), p.Value})
		}

		chart.Series = append(chart.Series, ChartSeries{
			Name:  run.Name,
			Data:  xy,
			Color: palette[idx%len(palette)],
		})
	}
	return chart, nil
}
// GetParallelCoordinatesData 获取平行坐标图数据
func (s *VisualizationService) GetParallelCoordinatesData(
ctx context.Context,
experimentID string,
paramKeys []string,
metricKey string,
) (*ParallelCoordinatesData, error) {
runs, _, err := s.runRepo.Search(ctx, RunSearchQuery{
ExperimentID: experimentID,
Status: []RunStatus{RunStatusCompleted},
})
if err != nil {
return nil, err
}
data := &ParallelCoordinatesData{
Dimensions: make([]Dimension, 0, len(paramKeys)+1),
Data: make([][]interface{}, 0, len(runs)),
}
// 添加参数维度
for _, key := range paramKeys {
data.Dimensions = append(data.Dimensions, Dimension{
Name: key,
Type: "param",
Range: s.calculateRange(runs, key),
})
}
// 添加指标维度
data.Dimensions = append(data.Dimensions, Dimension{
Name: metricKey,
Type: "metric",
})
// 填充数据
for _, run := range runs {
row := make([]interface{}, 0, len(paramKeys)+1)
for _, key := range paramKeys {
if v, ok := run.Parameters.Flat[key]; ok {
row = append(row, v)
} else {
row = append(row, nil)
}
}
// 获取指标值
metrics, _ := s.metricRepo.Query(ctx, MetricQuery{
RunID: run.ID,
Keys: []string{metricKey},
})
if len(metrics) > 0 {
row = append(row, metrics[len(metrics)-1].Value)
} else {
row = append(row, nil)
}
data.Data = append(data.Data, row)
}
return data, nil
}
// ParallelCoordinatesData is the payload of a parallel-coordinates plot: the
// axis definitions plus one row per run, with cells in Dimensions order.
type ParallelCoordinatesData struct {
	Dimensions []Dimension     `json:"dimensions"`
	Data       [][]interface{} `json:"data"`
}

// Dimension describes one axis of a parallel-coordinates plot.
type Dimension struct {
	Name  string    `json:"name"`
	Type  string    `json:"type"`            // param, metric
	Range []float64 `json:"range,omitempty"` // [min, max] when known
}
// calculateRange returns [min, max] of the numeric values of paramKey across
// runs, or nil when no run has a numeric (non-NaN) value for it. The nil case
// is dropped from JSON via the Range field's omitempty tag.
//
// Values are converted with toFloat64, so integer-typed parameters count too.
// Bounds are seeded from the first value seen rather than from ±1e10
// sentinels, which produced an inverted range when nothing matched and a wrong
// minimum for values above 1e10.
func (s *VisualizationService) calculateRange(runs []*Run, paramKey string) []float64 {
	var lo, hi float64
	found := false
	for _, run := range runs {
		raw, ok := run.Parameters.Flat[paramKey]
		if !ok {
			continue
		}
		v, ok := toFloat64(raw)
		if !ok || v != v { // skip non-numeric values and NaN
			continue
		}
		if !found || v < lo {
			lo = v
		}
		if !found || v > hi {
			hi = v
		}
		found = true
	}
	if !found {
		return nil
	}
	return []float64{lo, hi}
}
// GetContourPlotData collects (paramX, paramY, metric) points from the
// experiment's completed runs for a contour plot of how two hyperparameters
// relate to a metric.
//
// A run contributes a point only if both parameters are present and numeric
// (per toFloat64) and the metric query returns at least one value; Z is the
// last value returned. Non-numeric parameters are skipped instead of being
// silently plotted at 0. Metric query errors are treated like "no data".
func (s *VisualizationService) GetContourPlotData(
	ctx context.Context,
	experimentID string,
	paramX, paramY string,
	metricKey string,
) (*ContourPlotData, error) {
	runs, _, err := s.runRepo.Search(ctx, RunSearchQuery{
		ExperimentID: experimentID,
		Status:       []RunStatus{RunStatusCompleted},
	})
	if err != nil {
		return nil, err
	}

	data := &ContourPlotData{
		XLabel: paramX,
		YLabel: paramY,
		ZLabel: metricKey,
		Points: make([]ContourPoint, 0, len(runs)),
	}

	for _, run := range runs {
		// A missing key yields nil, which toFloat64 rejects like any non-number.
		x, xOK := toFloat64(run.Parameters.Flat[paramX])
		y, yOK := toFloat64(run.Parameters.Flat[paramY])
		if !xOK || !yOK {
			continue
		}

		metrics, _ := s.metricRepo.Query(ctx, MetricQuery{
			RunID: run.ID,
			Keys:  []string{metricKey},
		})
		if len(metrics) == 0 {
			continue
		}

		data.Points = append(data.Points, ContourPoint{
			X:     x,
			Y:     y,
			Z:     metrics[len(metrics)-1].Value,
			RunID: run.ID,
		})
	}
	return data, nil
}
// ContourPlotData is the payload of a contour plot: axis labels plus the
// scattered (x, y, z) samples the frontend interpolates.
type ContourPlotData struct {
	XLabel string         `json:"x_label"`
	YLabel string         `json:"y_label"`
	ZLabel string         `json:"z_label"`
	Points []ContourPoint `json:"points"`
}

// ContourPoint is one run's sample: two parameter values and a metric value.
type ContourPoint struct {
	X     float64 `json:"x"`
	Y     float64 `json:"y"`
	Z     float64 `json:"z"`
	RunID string  `json:"run_id"`
}
// toFloat64 converts a numeric parameter value to float64 and reports whether
// the conversion succeeded.
//
// Accepted inputs:
//   - all built-in float and integer kinds;
//   - json.Number, which is how numbers arrive when JSON is decoded with
//     UseNumber.
//
// Everything else (strings, bools, nil, nested maps) yields (0, false).
// Very large 64-bit integers may lose precision in the conversion.
func toFloat64(v interface{}) (float64, bool) {
	switch val := v.(type) {
	case float64:
		return val, true
	case float32:
		return float64(val), true
	case int:
		return float64(val), true
	case int8:
		return float64(val), true
	case int16:
		return float64(val), true
	case int32:
		return float64(val), true
	case int64:
		return float64(val), true
	case uint:
		return float64(val), true
	case uint8:
		return float64(val), true
	case uint16:
		return float64(val), true
	case uint32:
		return float64(val), true
	case uint64:
		return float64(val), true
	case json.Number:
		f, err := val.Float64()
		if err != nil {
			return 0, false
		}
		return f, true
	default:
		return 0, false
	}
}
Kubernetes 集成
实验 CRD
# experiment-crd.yaml
# CustomResourceDefinition for the namespaced Experiment resource
# (group ai.platform.io, version v1, short name "exp").
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
  name: experiments.ai.platform.io
spec:
  group: ai.platform.io
  versions:
    - name: v1
      served: true
      storage: true
      schema:
        openAPIV3Schema:
          type: object
          properties:
            # Desired state: owning project, a pod template shared by all runs,
            # and the list of runs to launch.
            spec:
              type: object
              properties:
                projectRef:
                  type: string
                name:
                  type: string
                description:
                  type: string
                hypothesis:
                  type: string
                # Container settings applied to every run's Job.
                template:
                  type: object
                  properties:
                    image:
                      type: string
                    command:
                      type: array
                      items:
                        type: string
                    resources:
                      type: object
                      properties:
                        gpus:
                          type: integer
                        memory:
                          type: string
                        cpu:
                          type: string
                    env:
                      type: array
                      items:
                        type: object
                        properties:
                          name:
                            type: string
                          value:
                            type: string
                # One entry per run; the controller creates a Job named
                # "<experiment name>-<run name>" for each.
                runs:
                  type: array
                  items:
                    type: object
                    properties:
                      name:
                        type: string
                      # Free-form hyperparameters; unknown fields are kept as-is.
                      parameters:
                        type: object
                        x-kubernetes-preserve-unknown-fields: true
            # Observed state, written through the status subresource.
            status:
              type: object
              properties:
                phase:
                  type: string
                activeRuns:
                  type: integer
                completedRuns:
                  type: integer
                failedRuns:
                  type: integer
                bestRun:
                  type: object
                  properties:
                    name:
                      type: string
                    metric:
                      type: number
      subresources:
        status: {}
  scope: Namespaced
  names:
    plural: experiments
    singular: experiment
    kind: Experiment
    shortNames:
      - exp
---
# Example experiment: a learning-rate sweep with four runs. Each run becomes
# one Job, and its parameters are exposed to the container as PARAM_<name>
# environment variables.
apiVersion: ai.platform.io/v1
kind: Experiment
metadata:
  name: lr-sweep-llm-7b
  namespace: ai-experiments
spec:
  projectRef: llm-pretraining
  name: Learning Rate Sweep
  description: "探索不同学习率对 LLM-7B 训练效果的影响"
  hypothesis: "较小的学习率(1e-5)可能比较大的学习率(1e-4)获得更好的验证损失"
  template:
    image: training:v1.0
    command:
      - python
      - train.py
      - --config=/config/training.yaml
    # Applied as container limits on every run's pod.
    resources:
      gpus: 8
      memory: "128Gi"
      cpu: "32"
    env:
      - name: EXPERIMENT_TRACKING_URI
        value: "http://experiment-server:8080"
  runs:
    - name: lr-1e-5
      parameters:
        learning_rate: 0.00001
        warmup_steps: 1000
    - name: lr-5e-5
      parameters:
        learning_rate: 0.00005
        warmup_steps: 1000
    - name: lr-1e-4
      parameters:
        learning_rate: 0.0001
        warmup_steps: 2000
    - name: lr-5e-4
      parameters:
        learning_rate: 0.0005
        warmup_steps: 2000
实验控制器
// experiment_controller.go
package controller
import (
	"context"
	"fmt"
	"sort"

	aiv1 "ai.platform.io/api/v1"
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/log"
)
// ExperimentReconciler reconciles Experiment resources into one batch/v1 Job
// per configured run. The embedded client is used for reads and writes; Scheme
// is the runtime scheme the manager provides.
type ExperimentReconciler struct {
	client.Client
	Scheme *runtime.Scheme
}
// Reconcile makes sure every run in experiment.Spec.Runs has a Job named
// "<experiment name>-<run name>", and refreshes run and experiment status.
//
// A deleted Experiment is ignored. When a Job lookup fails with anything other
// than NotFound (API timeout, RBAC, unsynced cache), the error is returned so
// controller-runtime requeues the request. Such a failure says nothing about
// whether the Job exists, so it must not trigger a Create.
// NOTE(review): long experiment/run names can exceed Kubernetes name/label
// length limits — consider validating them.
func (r *ExperimentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	logger := log.FromContext(ctx)

	var experiment aiv1.Experiment
	if err := r.Get(ctx, req.NamespacedName, &experiment); err != nil {
		return ctrl.Result{}, client.IgnoreNotFound(err)
	}

	for _, runConfig := range experiment.Spec.Runs {
		jobName := fmt.Sprintf("%s-%s", experiment.Name, runConfig.Name)

		var existingJob batchv1.Job
		err := r.Get(ctx, client.ObjectKey{
			Namespace: experiment.Namespace,
			Name:      jobName,
		}, &existingJob)
		if err == nil {
			// The Job already exists; only its status needs syncing.
			r.updateRunStatus(ctx, &experiment, runConfig.Name, &existingJob)
			continue
		}
		// Only NotFound means "create it"; any other error is retried.
		if client.IgnoreNotFound(err) != nil {
			return ctrl.Result{}, fmt.Errorf("get job %s: %w", jobName, err)
		}

		job := r.buildJob(&experiment, runConfig)
		if err := r.Create(ctx, job); err != nil {
			logger.Error(err, "Failed to create job", "job", jobName)
			return ctrl.Result{}, err
		}
		logger.Info("Created job", "job", jobName)
	}

	r.updateExperimentStatus(ctx, &experiment)
	return ctrl.Result{}, nil
}
// buildJob renders the Job for one run of experiment.
//
// The Job:
//   - is named "<experiment name>-<run name>";
//   - is labelled with the experiment and run names;
//   - is controller-owned by the Experiment, so it is garbage-collected with it.
//
// It runs a single "training" container (RestartPolicy Never) whose environment
// is built from:
//   - the template env;
//   - RUN_NAME and EXPERIMENT_ID;
//   - one PARAM_<key> variable per run parameter (formatted with %v), in sorted
//     key order so the generated spec is deterministic.
func (r *ExperimentReconciler) buildJob(experiment *aiv1.Experiment, runConfig aiv1.RunConfig) *batchv1.Job {
	jobName := fmt.Sprintf("%s-%s", experiment.Name, runConfig.Name)

	// Build env in a fresh slice. Appending directly to
	// experiment.Spec.Template.Env could write into that slice's backing array,
	// which belongs to the (possibly cached) Experiment object.
	templateEnv := experiment.Spec.Template.Env
	env := make([]corev1.EnvVar, 0, len(templateEnv)+2+len(runConfig.Parameters))
	env = append(env, templateEnv...)
	env = append(env,
		corev1.EnvVar{
			Name:  "RUN_NAME",
			Value: runConfig.Name,
		},
		corev1.EnvVar{
			Name:  "EXPERIMENT_ID",
			Value: experiment.Name,
		},
	)

	// Map iteration order is randomized; sort keys so identical specs render
	// identical Jobs.
	paramKeys := make([]string, 0, len(runConfig.Parameters))
	for key := range runConfig.Parameters {
		paramKeys = append(paramKeys, key)
	}
	sort.Strings(paramKeys)
	for _, key := range paramKeys {
		env = append(env, corev1.EnvVar{
			Name:  fmt.Sprintf("PARAM_%s", key),
			Value: fmt.Sprintf("%v", runConfig.Parameters[key]),
		})
	}

	job := &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{
			Name:      jobName,
			Namespace: experiment.Namespace,
			Labels: map[string]string{
				"experiment":          experiment.Name,
				"run":                 runConfig.Name,
				"ai.platform.io/type": "experiment-run",
			},
			OwnerReferences: []metav1.OwnerReference{
				*metav1.NewControllerRef(experiment, aiv1.GroupVersion.WithKind("Experiment")),
			},
		},
		Spec: batchv1.JobSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever,
					Containers: []corev1.Container{
						{
							Name:    "training",
							Image:   experiment.Spec.Template.Image,
							Command: experiment.Spec.Template.Command,
							Env:     env,
							Resources: corev1.ResourceRequirements{
								// NOTE(review): parseQuantity is not visible here. Its result
								// is dereferenced for the GPU limit but not for memory/CPU,
								// so only one of these forms can match its return type —
								// confirm the signature.
								Limits: corev1.ResourceList{
									"nvidia.com/gpu":      *parseQuantity(experiment.Spec.Template.Resources.GPUs),
									corev1.ResourceMemory: parseQuantity(experiment.Spec.Template.Resources.Memory),
									corev1.ResourceCPU:    parseQuantity(experiment.Spec.Template.Resources.CPU),
								},
							},
						},
					},
				},
			},
		},
	}
	return job
}
// updateRunStatus is a stub. It is meant to derive the status of run runName
// in experiment from job's status, but its body is currently empty.
func (r *ExperimentReconciler) updateRunStatus(ctx context.Context, experiment *aiv1.Experiment, runName string, job *batchv1.Job) {
	// TODO: update the run's status from the Job status (not implemented).
}
// updateExperimentStatus is a stub. It is meant to count run states and write
// the experiment's status (phase, run counters, best run), but its body is
// currently empty.
func (r *ExperimentReconciler) updateExperimentStatus(ctx context.Context, experiment *aiv1.Experiment) {
	// TODO: aggregate run states into experiment.Status (not implemented).
}
// SetupWithManager registers the reconciler with mgr. It watches Experiments
// and the Jobs they own, so a Job change re-triggers reconciliation of its
// parent Experiment.
func (r *ExperimentReconciler) SetupWithManager(mgr ctrl.Manager) error {
	builder := ctrl.NewControllerManagedBy(mgr)
	builder = builder.For(&aiv1.Experiment{})
	builder = builder.Owns(&batchv1.Job{})
	return builder.Complete(r)
}
最佳实践
实验管理规范
┌─────────────────────────────────────────────────────────────────┐
│ 实验管理最佳实践 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 实验设计原则 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 明确假设:每个实验应有清晰的假设和预期结果 │ │
│ │ • 控制变量:每次只改变少量变量,便于分析影响 │ │
│ │ • 可复现性:记录所有必要信息以复现实验 │ │
│ │ • 基线对比:始终与基线模型进行对比 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 2. 命名规范 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 项目:{team}-{model}-{task} │ │
│ │ 例如:nlp-llama-pretraining │ │
│ │ │ │
│ │ • 实验:{variable}-{sweep/ablation/comparison} │ │
│ │ 例如:lr-sweep, attention-ablation │ │
│ │ │ │
│ │ • 运行:{experiment}-{config-summary} │ │
│ │ 例如:lr-sweep-1e4-warmup1k │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 3. 必记参数 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 模型配置:架构、参数量、hidden_size 等 │ │
│ │ • 训练配置:learning_rate, batch_size, epochs 等 │ │
│ │ • 数据配置:数据集、预处理方式、数据增强 │ │
│ │ • 环境配置:框架版本、GPU 类型、随机种子 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 4. 必记指标 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 训练指标:loss, gradient_norm, learning_rate │ │
│ │ • 验证指标:eval_loss, perplexity, accuracy │ │
│ │ • 资源指标:gpu_memory, throughput, step_time │ │
│ │ • 系统指标:gpu_utilization, memory_usage │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 5. 记录频率建议 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 训练 loss:每 10-100 步 │ │
│ │ • 学习率:每步或每 100 步 │ │
│ │ • 验证指标:每 epoch 或每 1000 步 │ │
│ │ • 检查点:每 epoch 或每 5000 步 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
小结
本章详细介绍了 AI 训练平台的实验管理系统:
- 核心概念:项目、实验、运行的层次结构和数据模型
- 追踪系统:Go 和 Python SDK 的实现,支持参数、指标、产物追踪
- 服务架构:实验服务、指标存储、产物管理的设计
- 比较分析:实验比较、参数重要性分析的实现
- 可视化:指标图表、平行坐标图、等高线图的数据生成
- Kubernetes 集成:实验 CRD 和控制器实现
下一章我们将探讨超参数优化,讲解如何自动化地搜索最优超参数配置。