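05 - Experiment Tracking and Model Registry
Overview
Experiment tracking and model registration are core parts of MLOps. Tracking makes the model development process traceable, and a registry makes production deployments manageable. This article looks at how MLflow, Weights & Biases (W&B), and a self-built model registry work, and how to use them in practice.
Experiment Tracking System Architecture
Why Experiment Tracking Is Needed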
┌─────────────────────────────────────────────────────────────────────┐
│  ML Experiment Management Challenges                                │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  Typical ML development loop                                  │  │
│  │                                                               │  │
│  │  Tune ──► Train ──► Evaluate ──► Tune ──► Train ──► ...       │  │
│  │   │         │          │          │         │                 │  │
│  │   ▼         ▼          ▼          ▼         ▼                 │  │
│  │ params1   model1    metrics1   paramsN   modelN               │  │
│  │                                                               │  │
│  │  Problem: which parameter combination gave the best result?   │  │
│  │  Problem: how to reproduce an experiment from two weeks ago?  │  │
│  │  Problem: how do team members share their experiments?        │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                  │                                  │
│                                  ▼                                  │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  Experiment tracking system                                   │  │
│  │                                                               │  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │  │
│  │  │ Parameter│  │ Metric   │  │ Model    │  │ Visual   │       │  │
│  │  │ logging  │  │ tracking │  │ versions │  │ compare  │       │  │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘       │  │
│  │                                                               │  │
│  │  Solves: full records + versioning + teamwork + visualization │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
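The three questions in the diagram can be made concrete with a deliberately tiny in-memory sketch. It shows what a tracker stores for each run and how that store answers "which parameters produced the best result?". `InMemoryTracker` and `Run` are hypothetical names used only for illustration. A real system such as MLflow persists the same kinds of information (params, metrics, tags) in a database rather than a Python list.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Run:
    """One training run: its inputs (params), outputs (metrics) and context (tags)."""
    run_id: str
    params: Dict[str, object]
    metrics: Dict[str, float] = field(default_factory=dict)
    tags: Dict[str, str] = field(default_factory=dict)


class InMemoryTracker:
    """Hypothetical toy tracker: keeps every run so runs can be compared later."""

    def __init__(self) -> None:
        self.runs: List[Run] = []

    def log_run(self, run_id: str, params: Dict[str, object],
                metrics: Dict[str, float], **tags: str) -> Run:
        # Copy the dicts so that later changes by the caller cannot rewrite history
        run = Run(run_id, dict(params), dict(metrics), dict(tags))
        self.runs.append(run)
        return run

    def best_run(self, metric: str, maximize: bool = True) -> Optional[Run]:
        """Answer 'which parameter combination gave the best result?'"""
        candidates = [r for r in self.runs if metric in r.metrics]
        if not candidates:
            return None
        pick = max if maximize else min
        return pick(candidates, key=lambda r: r.metrics[metric])


tracker = InMemoryTracker()
# Tagging the code version (e.g. a git commit) is what makes a run reproducible later
tracker.log_run("r1", {"max_depth": 3, "n_estimators": 50}, {"accuracy": 0.91}, git_commit="abc123")
tracker.log_run("r2", {"max_depth": 5, "n_estimators": 100}, {"accuracy": 0.95}, git_commit="abc123")
tracker.log_run("r3", {"max_depth": 8, "n_estimators": 100}, {"accuracy": 0.93}, git_commit="def456")

best = tracker.best_run("accuracy")
print(best.run_id, best.params, best.tags)
# r2 {'max_depth': 5, 'n_estimators': 100} {'git_commit': 'abc123'}
```

Sharing works the same way: if the list is replaced by a shared database, every team member queries the same runs.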
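Experiment Tracking Tools Compared
| Feature | MLflow | W&B | Neptune | Self-built |
|---|---|---|---|---|
| Open source | ✓ | Partial | Partial | ✓ |
| Managed service | ✓ | ✓ | ✓ | ✗ |
| Experiment tracking | ✓ | ✓ | ✓ | ✓ |
| Model registry | ✓ | ✓ | ✓ | ✓ |
| Visualization | Basic | Rich | Rich | Custom |
| Collaboration | Basic | Strong | Strong | Custom |
| K8s integration | ✓ | ✓ | ✓ | ✓ |
| Cost | Low | Medium-high | Medium | Medium |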
MLflow in Depth
MLflow Architecture
┌─────────────────────────────────────────────────────────────────────┐
│  MLflow Architecture                                                │
├─────────────────────────────────────────────────────────────────────┤
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  MLflow Components                                            │  │
│  │                                                               │  │
│  │  ┌────────────────┐  ┌────────────────┐  ┌────────────────┐   │  │
│  │  │ Tracking       │  │ Projects       │  │ Models         │   │  │
│  │  │                │  │                │  │                │   │  │
│  │  │ • Params       │  │ • Packaging    │  │ • Model format │   │  │
│  │  │ • Metrics      │  │ • Environment  │  │ • Deploy API   │   │  │
│  │  │ • Artifacts    │  │ • Reproducible │  │ • Serving      │   │  │
│  │  └────────────────┘  └────────────────┘  └────────────────┘   │  │
│  │                                                               │  │
│  │  ┌────────────────────────────────────────────────────────┐   │  │
│  │  │ Model Registry                                         │   │  │
│  │  │                                                        │   │  │
│  │  │ • Version management  • Stage transitions  • Lineage   │   │  │
│  │  └────────────────────────────────────────────────────────┘   │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                  │                                  │
│                                  ▼                                  │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  Backend Storage                                              │  │
│  │  ┌────────────────┐  ┌────────────────┐                       │  │
│  │  │ Tracking Store │  │ Artifact Store │                       │  │
│  │  │                │  │                │                       │  │
│  │  │ • SQLite       │  │ • Local files  │                       │  │
│  │  │ • MySQL        │  │ • S3/GCS       │                       │  │
│  │  │ • PostgreSQL   │  │ • HDFS         │                       │  │
│  │  └────────────────┘  └────────────────┘                       │  │
│  └───────────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────────┘
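The split between Tracking Store and Artifact Store shows up directly on the `mlflow server` command line. `--backend-store-uri` selects the metadata database. When the server proxies artifacts (`--serve-artifacts`), `--artifacts-destination` selects where the artifact bytes are stored. The helper below is a hypothetical convenience that only assembles the argument list and does not launch anything. The flags are the standard MLflow 2.x server options, but it is worth checking `mlflow server --help` for the version you actually run.

```python
from typing import List, Optional


def mlflow_server_args(
    backend_store_uri: str,
    artifacts_destination: Optional[str] = None,
    host: str = "0.0.0.0",
    port: int = 5000,
) -> List[str]:
    """Assemble an `mlflow server` command for one Tracking Store / Artifact Store pair.

    backend_store_uri:     metadata database, e.g. sqlite:///mlflow.db or postgresql://...
    artifacts_destination: artifact location behind the proxy, e.g. s3://bucket or ./mlartifacts
    """
    args = [
        "mlflow", "server",
        f"--host={host}",
        f"--port={port}",
        f"--backend-store-uri={backend_store_uri}",
    ]
    if artifacts_destination is not None:
        # Clients upload/download through the server, so they need no storage credentials
        args += ["--serve-artifacts", f"--artifacts-destination={artifacts_destination}"]
    return args


# Local development: SQLite metadata plus a local artifact directory
local = mlflow_server_args("sqlite:///mlflow.db", "./mlartifacts")
# Team setup mirroring the diagram: PostgreSQL metadata plus S3-compatible artifacts
team = mlflow_server_args("postgresql://user:pw@db:5432/mlflow", "s3://mlflow-artifacts/")
print(" ".join(local))
```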
Deploying the MLflow Server
# mlflow-server.yaml - Kubernetes deployment
# Assumes the Secrets mlflow-postgres-secret (keys: username, password) and
# mlflow-minio-secret (keys: access-key, secret-key) already exist in the mlflow
# namespace, and that the mlflow-artifacts bucket has been created in MinIO.
apiVersion: v1
kind: Namespace
metadata:
  name: mlflow
---
# PostgreSQL database (Tracking Store)
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: mlflow-postgres
  namespace: mlflow
spec:
  serviceName: mlflow-postgres
  replicas: 1
  selector:
    matchLabels:
      app: mlflow-postgres
  template:
    metadata:
      labels:
        app: mlflow-postgres
    spec:
      containers:
        - name: postgres
          image: postgres:14
          ports:
            - containerPort: 5432
          env:
            - name: POSTGRES_DB
              value: mlflow
            - name: POSTGRES_USER
              valueFrom:
                secretKeyRef:
                  name: mlflow-postgres-secret
                  key: username
            - name: POSTGRES_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: mlflow-postgres-secret
                  key: password
            # Keep the data in a subdirectory: a freshly provisioned volume may contain
            # lost+found, and initdb refuses to initialize a non-empty directory
            - name: PGDATA
              value: /var/lib/postgresql/data/pgdata
          volumeMounts:
            - name: postgres-data
              mountPath: /var/lib/postgresql/data
          resources:
            requests:
              cpu: 500m
              memory: 1Gi
            limits:
              cpu: 2
              memory: 4Gi
  volumeClaimTemplates:
    - metadata:
        name: postgres-data
      spec:
        accessModes: ["ReadWriteOnce"]
        resources:
          requests:
            storage: 50Gi
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-postgres
  namespace: mlflow
spec:
  ports:
    - port: 5432
  selector:
    app: mlflow-postgres
  clusterIP: None
---
# MinIO object storage (Artifact Store)
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: mlflow-minio
  namespace: mlflow
spec:
  serviceName: mlflow-minio
  replicas: 1
  selector:
    matchLabels:
      app: mlflow-minio
  template:
    metadata:
      labels:
        app: mlflow-minio
    spec:
      containers:
        - name: minio
          # Pin a specific release tag in production instead of latest
          image: minio/minio:latest
          args:
            - server
            - /data
            - --console-address
            - ":9001"
          ports:
            - containerPort: 9000
              name: api
            - containerPort: 9001
              name: console
          env:
            - name: MINIO_ROOT_USER
              valueFrom:
                secretKeyRef:
                  name: mlflow-minio-secret
                  key: access-key
            - name: MINIO_ROOT_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: mlflow-minio-secret
                  key: secret-key
          volumeMounts:
            - name: minio-data
              mountPath: /data
          resources:
            requests:
              cpu: 500m
              memory: 1Gi
            limits:
              cpu: 2
              memory: 4Gi
  volumeClaimTemplates:
    - metadata:
        name: minio-data
      spec:
        accessModes: ["ReadWriteOnce"]
        resources:
          requests:
            storage: 100Gi
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-minio
  namespace: mlflow
spec:
  ports:
    - port: 9000
      name: api
    - port: 9001
      name: console
  selector:
    app: mlflow-minio
---
# MLflow Server
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mlflow-server
  namespace: mlflow
spec:
  replicas: 2
  selector:
    matchLabels:
      app: mlflow-server
  template:
    metadata:
      labels:
        app: mlflow-server
    spec:
      containers:
        - name: mlflow
          # The official image does not bundle the PostgreSQL driver (psycopg2) or boto3;
          # build a derived image that installs them (e.g. psycopg2-binary and boto3)
          image: ghcr.io/mlflow/mlflow:v2.9.0
          command:
            - mlflow
            - server
            - --host=0.0.0.0
            - --port=5000
            # Kubernetes expands $(VAR) from the env entries below; the password is inserted
            # verbatim, so it must not contain URI-reserved characters such as @ : /
            - --backend-store-uri=postgresql://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@mlflow-postgres:5432/mlflow
            # Proxy artifact traffic through the server and store it in MinIO;
            # clients then need no S3 credentials of their own
            - --serve-artifacts
            - --artifacts-destination=s3://mlflow-artifacts/
          ports:
            - containerPort: 5000
          env:
            - name: POSTGRES_USER
              valueFrom:
                secretKeyRef:
                  name: mlflow-postgres-secret
                  key: username
            - name: POSTGRES_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: mlflow-postgres-secret
                  key: password
            - name: MLFLOW_S3_ENDPOINT_URL
              value: http://mlflow-minio:9000
            - name: AWS_ACCESS_KEY_ID
              valueFrom:
                secretKeyRef:
                  name: mlflow-minio-secret
                  key: access-key
            - name: AWS_SECRET_ACCESS_KEY
              valueFrom:
                secretKeyRef:
                  name: mlflow-minio-secret
                  key: secret-key
          resources:
            requests:
              cpu: 500m
              memory: 1Gi
            limits:
              cpu: 2
              memory: 4Gi
          readinessProbe:
            httpGet:
              path: /health
              port: 5000
            initialDelaySeconds: 10
            periodSeconds: 10
          livenessProbe:
            httpGet:
              path: /health
              port: 5000
            initialDelaySeconds: 30
            periodSeconds: 30
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-server
  namespace: mlflow
spec:
  ports:
    - port: 5000
      targetPort: 5000
  selector:
    app: mlflow-server
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: mlflow-ingress
  namespace: mlflow
  annotations:
    # "0" disables the nginx request-body size limit so large model artifacts can be uploaded
    nginx.ingress.kubernetes.io/proxy-body-size: "0"
spec:
  ingressClassName: nginx
  rules:
    - host: mlflow.example.com
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: mlflow-server
                port:
                  number: 5000
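In the manifest above, `$(POSTGRES_PASSWORD)` is pasted into `--backend-store-uri` exactly as stored. A password containing `@`, `:` or `/` therefore silently corrupts the URI. One option is to generate a percent-encoded URI once and store the whole string in a Secret. This is an assumption about your setup and not part of the manifest. The hypothetical helper below sketches that encoding step using only the standard library.

```python
from urllib.parse import quote, unquote, urlsplit


def postgres_backend_uri(user: str, password: str, host: str,
                         port: int = 5432, database: str = "mlflow") -> str:
    """Build a --backend-store-uri value with percent-encoded credentials.

    quote(..., safe="") encodes every reserved character (@ : / ? # ...), so a
    password such as "p@ss:w/rd" can no longer be mistaken for URI delimiters.
    """
    return (f"postgresql://{quote(user, safe='')}:{quote(password, safe='')}"
            f"@{host}:{port}/{database}")


uri = postgres_backend_uri("mlflow", "p@ss:w/rd", "mlflow-postgres")
print(uri)  # postgresql://mlflow:p%40ss%3Aw%2Frd@mlflow-postgres:5432/mlflow

# Round trip: a URI parser recovers the original host, port and password
parts = urlsplit(uri)
assert parts.hostname == "mlflow-postgres" and parts.port == 5432
assert unquote(parts.password) == "p@ss:w/rd"
```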
Wrapping the MLflow Python SDK
# mlflow_manager.py
"""
MLflow experiment tracking and model registry manager.
Provides one interface for managing experiments, runs and registered models.
"""
import os
import tempfile
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime
from contextlib import contextmanager

import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
from mlflow.models.signature import ModelSignature, infer_signature
import pandas as pd


@dataclass
class ExperimentConfig:
    """Experiment configuration"""
    name: str
    tracking_uri: str
    artifact_location: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)


@dataclass
class RunInfo:
    """Run information"""
    run_id: str
    experiment_id: str
    status: str
    start_time: datetime
    end_time: Optional[datetime]
    artifact_uri: str
    metrics: Dict[str, float]
    params: Dict[str, str]
    tags: Dict[str, str]


@dataclass
class ModelVersion:
    """Model version information"""
    name: str
    version: str
    stage: str
    source: str
    run_id: str
    status: str
    creation_timestamp: datetime
    description: Optional[str] = None


class MLflowExperimentTracker:
    """MLflow experiment tracker"""

    def __init__(
        self,
        tracking_uri: str,
        default_experiment: Optional[str] = None
    ):
        self.tracking_uri = tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient(tracking_uri)
        if default_experiment:
            self.set_experiment(default_experiment)

    # ==================== Experiment management ====================
    def set_experiment(
        self,
        name: str,
        artifact_location: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> str:
        """Make `name` the active experiment, creating it first if it does not exist"""
        # mlflow.set_experiment() only accepts a name or an ID, so artifact_location
        # and tags have to be applied through create_experiment()
        experiment = self.client.get_experiment_by_name(name)
        if experiment is None:
            experiment_id = self.client.create_experiment(
                name,
                artifact_location=artifact_location,
                tags=tags
            )
        else:
            experiment_id = experiment.experiment_id
        mlflow.set_experiment(experiment_id=experiment_id)
        return experiment_id

    def get_experiment(self, name: str) -> Optional[Dict[str, Any]]:
        """Get experiment details, or None if the experiment does not exist"""
        experiment = self.client.get_experiment_by_name(name)
        if experiment:
            return {
                "experiment_id": experiment.experiment_id,
                "name": experiment.name,
                "artifact_location": experiment.artifact_location,
                "lifecycle_stage": experiment.lifecycle_stage,
                "tags": experiment.tags
            }
        return None

    def list_experiments(
        self,
        view_type: str = "ACTIVE_ONLY"
    ) -> List[Dict[str, Any]]:
        """List experiments (view_type: ACTIVE_ONLY, DELETED_ONLY or ALL)"""
        view = getattr(ViewType, view_type)
        experiments = self.client.search_experiments(view_type=view)
        return [
            {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage
            }
            for exp in experiments
        ]

    def delete_experiment(self, name: str):
        """Delete an experiment (a soft delete; it can be restored later)"""
        experiment = self.client.get_experiment_by_name(name)
        if experiment:
            self.client.delete_experiment(experiment.experiment_id)

    # ==================== Run management ====================
    @contextmanager
    def start_run(
        self,
        run_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        nested: bool = False
    ):
        """Start a run as a context manager; the run ends when the block exits"""
        if experiment_name:
            self.set_experiment(experiment_name)
        with mlflow.start_run(
            run_name=run_name,
            tags=tags,
            nested=nested
        ) as run:
            # RunContext is defined below; the name is only resolved at call time
            yield RunContext(run, self)

    @staticmethod
    def _to_run_info(run) -> RunInfo:
        """Convert an MLflow Run entity to RunInfo (MLflow timestamps are in milliseconds)"""
        return RunInfo(
            run_id=run.info.run_id,
            experiment_id=run.info.experiment_id,
            status=run.info.status,
            start_time=datetime.fromtimestamp(run.info.start_time / 1000),
            end_time=datetime.fromtimestamp(run.info.end_time / 1000) if run.info.end_time else None,
            artifact_uri=run.info.artifact_uri,
            metrics=run.data.metrics,
            params=run.data.params,
            tags=run.data.tags
        )

    def get_run(self, run_id: str) -> RunInfo:
        """Get run details"""
        return self._to_run_info(self.client.get_run(run_id))

    def search_runs(
        self,
        experiment_names: Optional[List[str]] = None,
        filter_string: str = "",
        max_results: int = 100,
        order_by: Optional[List[str]] = None
    ) -> List[RunInfo]:
        """Search runs; searches all active experiments when experiment_names is omitted"""
        if experiment_names:
            experiment_ids = []
            for name in experiment_names:
                experiment = self.client.get_experiment_by_name(name)
                if experiment:
                    experiment_ids.append(experiment.experiment_id)
        else:
            # MlflowClient.search_runs() requires explicit experiment IDs
            experiment_ids = [exp.experiment_id for exp in self.client.search_experiments()]
        if not experiment_ids:
            return []
        runs = self.client.search_runs(
            experiment_ids=experiment_ids,
            filter_string=filter_string,
            max_results=max_results,
            order_by=order_by
        )
        return [self._to_run_info(run) for run in runs]

    def delete_run(self, run_id: str):
        """Delete a run (soft delete)"""
        self.client.delete_run(run_id)

    # ==================== Metrics and parameters ====================
    def log_param(self, key: str, value: Any):
        """Log a single parameter"""
        mlflow.log_param(key, value)

    def log_params(self, params: Dict[str, Any]):
        """Log several parameters at once"""
        mlflow.log_params(params)

    def log_metric(self, key: str, value: float, step: Optional[int] = None):
        """Log a single metric, optionally at a training step"""
        mlflow.log_metric(key, value, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log several metrics at once"""
        mlflow.log_metrics(metrics, step=step)

    # ==================== Artifact management ====================
    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        """Log a local file as an artifact"""
        mlflow.log_artifact(local_path, artifact_path)

    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):
        """Log every file in a local directory"""
        mlflow.log_artifacts(local_dir, artifact_path)

    def log_figure(self, figure: Any, artifact_file: str):
        """Log a matplotlib or plotly figure"""
        mlflow.log_figure(figure, artifact_file)

    def log_dict(self, dictionary: Dict[str, Any], artifact_file: str):
        """Log a dictionary as a JSON (or YAML) artifact"""
        mlflow.log_dict(dictionary, artifact_file)

    def log_dataframe(
        self,
        df: pd.DataFrame,
        artifact_file: str,
        format: str = "parquet"
    ):
        """Log a DataFrame as data/<artifact_file>"""
        if format not in ("parquet", "csv"):
            raise ValueError(f"Unsupported format: {format}")
        # Write under the requested file name (not a random temp name) and let the
        # temporary directory clean itself up
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, os.path.basename(artifact_file))
            if format == "parquet":
                df.to_parquet(local_path, index=False)
            else:
                df.to_csv(local_path, index=False)
            mlflow.log_artifact(local_path, "data")

    def download_artifacts(self, run_id: str, path: str, dst_path: str) -> str:
        """Download artifacts of a run and return the local path"""
        return self.client.download_artifacts(run_id, path, dst_path)


class RunContext:
    """Run context: a thin handle on the active run"""

    def __init__(self, run: mlflow.ActiveRun, tracker: MLflowExperimentTracker):
        self.run = run
        self.tracker = tracker
        self.run_id = run.info.run_id

    def log_param(self, key: str, value: Any):
        self.tracker.log_param(key, value)

    def log_params(self, params: Dict[str, Any]):
        self.tracker.log_params(params)

    def log_metric(self, key: str, value: float, step: Optional[int] = None):
        self.tracker.log_metric(key, value, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        self.tracker.log_metrics(metrics, step=step)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        self.tracker.log_artifact(local_path, artifact_path)

    def log_model(
        self,
        model: Any,
        artifact_path: str,
        flavor: str = "sklearn",
        signature: Optional[ModelSignature] = None,
        input_example: Optional[Any] = None,
        registered_model_name: Optional[str] = None
    ):
        """Log a model with the given flavor (sklearn, pytorch, xgboost, ...)"""
        # e.g. flavor="sklearn" resolves to the mlflow.sklearn module
        log_func = getattr(mlflow, flavor)
        log_func.log_model(
            model,
            artifact_path,
            signature=signature,
            input_example=input_example,
            registered_model_name=registered_model_name
        )

    def set_tag(self, key: str, value: str):
        """Set a tag on the run"""
        mlflow.set_tag(key, value)

    def set_tags(self, tags: Dict[str, str]):
        """Set several tags at once"""
        mlflow.set_tags(tags)


class MLflowModelRegistry:
    """MLflow model registry client"""

    def __init__(self, tracking_uri: str):
        self.tracking_uri = tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient(tracking_uri)

    @staticmethod
    def _to_model_version(v) -> ModelVersion:
        """Convert an MLflow ModelVersion entity to the local dataclass"""
        return ModelVersion(
            name=v.name,
            version=v.version,
            stage=v.current_stage,
            source=v.source,
            run_id=v.run_id,
            status=v.status,
            creation_timestamp=datetime.fromtimestamp(v.creation_timestamp / 1000),
            description=v.description
        )

    # ==================== Model management ====================
    def create_registered_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ):
        """Create a registered model"""
        # MlflowClient expects tags as a plain dict, not a list of key/value records
        self.client.create_registered_model(
            name,
            tags=tags,
            description=description
        )

    def list_registered_models(
        self,
        filter_string: str = "",
        max_results: int = 100
    ) -> List[Dict[str, Any]]:
        """List registered models"""
        models = self.client.search_registered_models(
            filter_string=filter_string,
            max_results=max_results
        )
        return [
            {
                "name": model.name,
                "description": model.description,
                "creation_timestamp": model.creation_timestamp,
                "last_updated_timestamp": model.last_updated_timestamp,
                "latest_versions": [
                    {
                        "version": v.version,
                        "stage": v.current_stage,
                        "status": v.status
                    }
                    for v in model.latest_versions
                ]
            }
            for model in models
        ]

    def get_registered_model(self, name: str) -> Dict[str, Any]:
        """Get registered model details"""
        model = self.client.get_registered_model(name)
        return {
            "name": model.name,
            "description": model.description,
            "creation_timestamp": model.creation_timestamp,
            "last_updated_timestamp": model.last_updated_timestamp,
            "tags": model.tags,
            "latest_versions": [
                {
                    "version": v.version,
                    "stage": v.current_stage,
                    "status": v.status,
                    "run_id": v.run_id
                }
                for v in model.latest_versions
            ]
        }

    def update_registered_model(
        self,
        name: str,
        description: Optional[str] = None
    ):
        """Update a registered model's description"""
        self.client.update_registered_model(name, description=description)

    def delete_registered_model(self, name: str):
        """Delete a registered model and all of its versions"""
        self.client.delete_registered_model(name)

    # ==================== Version management ====================
    def create_model_version(
        self,
        name: str,
        source: str,
        run_id: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> ModelVersion:
        """Create a model version from an artifact source URI"""
        version = self.client.create_model_version(
            name=name,
            source=source,
            run_id=run_id,
            description=description,
            tags=tags
        )
        return self._to_model_version(version)

    def get_model_version(self, name: str, version: str) -> ModelVersion:
        """Get a single model version"""
        return self._to_model_version(self.client.get_model_version(name, version))

    def search_model_versions(
        self,
        filter_string: str = "",
        max_results: int = 100
    ) -> List[ModelVersion]:
        """Search model versions, e.g. filter_string="name='iris-classifier'" """
        versions = self.client.search_model_versions(
            filter_string=filter_string,
            max_results=max_results
        )
        return [self._to_model_version(v) for v in versions]

    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None
    ):
        """Update a model version's description"""
        self.client.update_model_version(name, version, description=description)

    def delete_model_version(self, name: str, version: str):
        """Delete a model version"""
        self.client.delete_model_version(name, version)

    # ==================== Stage transitions ====================
    # Note: model stages are deprecated since MLflow 2.9 in favor of model version aliases
    def transition_model_version_stage(
        self,
        name: str,
        version: str,
        stage: str,
        archive_existing_versions: bool = True
    ):
        """Move a version to Staging, Production or Archived"""
        self.client.transition_model_version_stage(
            name=name,
            version=version,
            stage=stage,
            archive_existing_versions=archive_existing_versions
        )

    def get_latest_versions(
        self,
        name: str,
        stages: Optional[List[str]] = None
    ) -> List[ModelVersion]:
        """Get the latest version for each requested stage"""
        versions = self.client.get_latest_versions(name, stages=stages)
        return [self._to_model_version(v) for v in versions]

    # ==================== Model loading ====================
    def load_model(
        self,
        name: str,
        version: Optional[str] = None,
        stage: Optional[str] = None
    ) -> Any:
        """Load a model as a pyfunc, by version, by stage, or the latest version"""
        if version:
            model_uri = f"models:/{name}/{version}"
        elif stage:
            model_uri = f"models:/{name}/{stage}"
        else:
            model_uri = f"models:/{name}/latest"
        return mlflow.pyfunc.load_model(model_uri)


class MLflowAutolog:
    """Helpers for MLflow autologging"""

    @staticmethod
    def enable_sklearn():
        """Enable scikit-learn autologging"""
        mlflow.sklearn.autolog()

    @staticmethod
    def enable_pytorch():
        """Enable PyTorch autologging"""
        mlflow.pytorch.autolog()

    @staticmethod
    def enable_tensorflow():
        """Enable TensorFlow autologging"""
        mlflow.tensorflow.autolog()

    @staticmethod
    def enable_xgboost():
        """Enable XGBoost autologging"""
        mlflow.xgboost.autolog()

    @staticmethod
    def enable_lightgbm():
        """Enable LightGBM autologging"""
        mlflow.lightgbm.autolog()

    @staticmethod
    def enable_all():
        """Enable autologging for every supported, installed library"""
        mlflow.autolog()


# ==================== Usage examples ====================
def training_example():
    """Training example"""
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score, f1_score

    # Initialize the tracker
    tracker = MLflowExperimentTracker(
        tracking_uri="http://mlflow-server:5000",
        default_experiment="iris-classification"
    )

    # Load the data
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42
    )

    # Training parameters
    params = {
        "n_estimators": 100,
        "max_depth": 5,
        "random_state": 42
    }

    # Start a run
    with tracker.start_run(run_name="rf-baseline") as run:
        # Log parameters
        run.log_params(params)
        run.set_tag("model_type", "random_forest")

        # Train the model
        model = RandomForestClassifier(**params)
        model.fit(X_train, y_train)

        # Predict and evaluate
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average="weighted")

        # Log metrics
        run.log_metrics({
            "accuracy": accuracy,
            "f1_score": f1
        })

        # Infer the input/output signature
        signature = infer_signature(X_train, model.predict(X_train))

        # Log the model and register it in the same call
        run.log_model(
            model,
            "model",
            flavor="sklearn",
            signature=signature,
            input_example=X_train[:5],
            registered_model_name="iris-classifier"
        )

        print(f"Run ID: {run.run_id}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"F1 Score: {f1:.4f}")


def model_registry_example():
    """Model registry example"""
    registry = MLflowModelRegistry("http://mlflow-server:5000")

    # Get model information
    model_info = registry.get_registered_model("iris-classifier")
    print(f"Model: {model_info['name']}")
    print(f"Latest versions: {model_info['latest_versions']}")

    # Get the latest Staging version
    staging_versions = registry.get_latest_versions(
        "iris-classifier",
        stages=["Staging"]
    )
    if staging_versions:
        version = staging_versions[0]
        print(f"Staging version: {version.version}")
        # Promote it to Production, archiving the current Production version
        registry.transition_model_version_stage(
            "iris-classifier",
            version.version,
            "Production",
            archive_existing_versions=True
        )
        print(f"Promoted version {version.version} to Production")

    # Load the Production model
    model = registry.load_model("iris-classifier", stage="Production")
    print("Loaded production model")


if __name__ == "__main__":
    training_example()
    model_registry_example()
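`transition_model_version_stage` and `get_latest_versions`, used above, rely on model stages. MLflow deprecated stages in 2.9 in favor of model version aliases, which are addressed as `models:/<name>@<alias>`. The sketch below makes the URI choice in `load_model` explicit and adds the alias form. `build_model_uri` and the alias name `champion` are hypothetical. `MlflowClient.set_registered_model_alias` is a real client method, available since MLflow 2.3. The lines that call it are commented out because they need a running tracking server.

```python
from typing import Optional


def build_model_uri(name: str, version: Optional[str] = None,
                    stage: Optional[str] = None, alias: Optional[str] = None) -> str:
    """Build an MLflow models:/ URI from at most one selector."""
    selectors = [s for s in (version, stage, alias) if s is not None]
    if len(selectors) > 1:
        raise ValueError("Pass at most one of version, stage or alias")
    if alias is not None:
        return f"models:/{name}@{alias}"       # alias: the replacement for stages
    if version is not None:
        return f"models:/{name}/{version}"     # pinned version number
    if stage is not None:
        return f"models:/{name}/{stage}"       # legacy stage, e.g. Production
    return f"models:/{name}/latest"            # newest registered version


# With a tracking server (not executed here), promotion becomes moving an alias:
#   client = MlflowClient("http://mlflow-server:5000")
#   client.set_registered_model_alias("iris-classifier", "champion", "3")
#   model = mlflow.pyfunc.load_model(build_model_uri("iris-classifier", alias="champion"))
print(build_model_uri("iris-classifier", alias="champion"))  # models:/iris-classifier@champion
```

With aliases, rollback amounts to pointing `champion` back at the previous version number. Consumers keep loading the same URI.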
Building Your Own Model Registry
Model Registry Architecture
┌─────────────────────────────────────────────────────────────────────┐
│  Self-Built Model Registry Architecture                             │
├─────────────────────────────────────────────────────────────────────┤
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  API Gateway                                                  │  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │  │
│  │  │ Auth     │  │ Throttle │  │ Routing  │  │ Logging  │       │  │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘       │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                  │                                  │
│                                  ▼                                  │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │  Model Registry Service                                       │  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │  │
│  │  │ Models   │  │ Versions │  │ Stages   │  │ Access   │       │  │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘       │  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │  │
│  │  │ Metadata │  │ Lineage  │  │ Approvals│  │ Notify   │       │  │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘       │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                  │                                  │
│                                  ▼                                  │
│  ┌──────────────────┐  ┌──────────────────┐  ┌──────────────────┐   │
│  │ Metadata DB      │  │ Artifact Store   │  │ Event Bus        │   │
│  │ (PostgreSQL)     │  │ (S3/MinIO)       │  │ (Kafka)          │   │
│  └──────────────────┘  └──────────────────┘  └──────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘
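The "Stages" and "Approvals" boxes interact in the registry service. A request to move a version is first checked against the transitions that are allowed at all. It is then either applied immediately or recorded as a pending approval, which is the role of the `ModelApproval` table in the implementation below. The sketch models just that decision. The transition table and the rule that only Production needs sign-off are hypothetical policy choices, not something the registry code defines.

```python
from enum import Enum
from typing import Dict, FrozenSet


class Stage(str, Enum):
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"


# Hypothetical policy: which stage changes are legal at all
ALLOWED: Dict[Stage, FrozenSet[Stage]] = {
    Stage.NONE: frozenset({Stage.STAGING, Stage.ARCHIVED}),
    Stage.STAGING: frozenset({Stage.PRODUCTION, Stage.NONE, Stage.ARCHIVED}),
    Stage.PRODUCTION: frozenset({Stage.STAGING, Stage.ARCHIVED}),
    Stage.ARCHIVED: frozenset({Stage.NONE}),
}
# Hypothetical policy: moves into these stages must go through the approval workflow
REQUIRES_APPROVAL: FrozenSet[Stage] = frozenset({Stage.PRODUCTION})


def plan_transition(current: Stage, target: Stage) -> str:
    """Return 'applied' for an immediate change, 'pending_approval' if a reviewer must sign off."""
    if target not in ALLOWED[current]:
        raise ValueError(f"Illegal transition: {current.value} -> {target.value}")
    return "pending_approval" if target in REQUIRES_APPROVAL else "applied"


print(plan_transition(Stage.NONE, Stage.STAGING))        # applied
print(plan_transition(Stage.STAGING, Stage.PRODUCTION))  # pending_approval
```

Keeping the policy in plain data (`ALLOWED`, `REQUIRES_APPROVAL`) makes it easy to review and to change per team without touching the service code.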
Implementing the Model Registry
# model_registry.py
"""
Self-built model registry.
Supports model version management, an approval workflow, and lineage tracking.
"""
import os
import uuid
from typing import Dict, List, Optional, Any, BinaryIO
from dataclasses import dataclass
from datetime import datetime
from enum import Enum

from sqlalchemy import (
    create_engine, Column, String, Integer, DateTime, Text, ForeignKey,
    UniqueConstraint, func, Enum as SQLEnum,
)
# declarative_base lives in sqlalchemy.orm since 1.4 (the ext.declarative path is deprecated)
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from sqlalchemy.dialects.postgresql import JSONB
import boto3
from botocore.exceptions import ClientError

Base = declarative_base()


class ModelStage(str, Enum):
    """Model stage"""
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"


class ApprovalStatus(str, Enum):
    """Approval status"""
    PENDING = "Pending"
    APPROVED = "Approved"
    REJECTED = "Rejected"


# ==================== Data models ====================
class RegisteredModel(Base):
    """Registered models table"""
    __tablename__ = "registered_models"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String(255), unique=True, nullable=False)
    description = Column(Text)
    owner = Column(String(255))
    # A callable default (dict) builds a fresh dict per row instead of reusing one literal
    tags = Column(JSONB, default=dict)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    versions = relationship("ModelVersion", back_populates="model", cascade="all, delete-orphan")


class ModelVersion(Base):
    """Model versions table"""
    __tablename__ = "model_versions"
    # Each model may use a given version number only once
    __table_args__ = (UniqueConstraint("model_id", "version", name="uq_model_version"),)

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    model_id = Column(String(36), ForeignKey("registered_models.id"), nullable=False)
    version = Column(Integer, nullable=False)
    stage = Column(SQLEnum(ModelStage), default=ModelStage.NONE)
    description = Column(Text)
    source_uri = Column(String(1024), nullable=False)
    artifact_uri = Column(String(1024))
    run_id = Column(String(255))
    experiment_id = Column(String(255))
    framework = Column(String(100))
    framework_version = Column(String(50))
    signature = Column(JSONB)
    input_example = Column(JSONB)
    metrics = Column(JSONB, default=dict)
    params = Column(JSONB, default=dict)
    tags = Column(JSONB, default=dict)
    created_at = Column(DateTime, default=datetime.utcnow)
    created_by = Column(String(255))

    model = relationship("RegisteredModel", back_populates="versions")
    approvals = relationship("ModelApproval", back_populates="model_version", cascade="all, delete-orphan")
    lineage = relationship("ModelLineage", back_populates="model_version", cascade="all, delete-orphan")


class ModelApproval(Base):
    """Model approvals table"""
    __tablename__ = "model_approvals"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    version_id = Column(String(36), ForeignKey("model_versions.id"), nullable=False)
    from_stage = Column(SQLEnum(ModelStage), nullable=False)
    to_stage = Column(SQLEnum(ModelStage), nullable=False)
    status = Column(SQLEnum(ApprovalStatus), default=ApprovalStatus.PENDING)
    requester = Column(String(255), nullable=False)
    approver = Column(String(255))
    comment = Column(Text)
    requested_at = Column(DateTime, default=datetime.utcnow)
    resolved_at = Column(DateTime)

    model_version = relationship("ModelVersion", back_populates="approvals")


class ModelLineage(Base):
    """Model lineage table"""
    __tablename__ = "model_lineage"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    version_id = Column(String(36), ForeignKey("model_versions.id"), nullable=False)
    lineage_type = Column(String(50))  # dataset, parent_model, code
    source_uri = Column(String(1024))
    source_version = Column(String(255))
    # "metadata" is a reserved attribute name in the Declarative API, so the Python
    # attribute is metadata_ while the database column keeps the name "metadata"
    metadata_ = Column("metadata", JSONB, default=dict)
    created_at = Column(DateTime, default=datetime.utcnow)

    model_version = relationship("ModelVersion", back_populates="lineage")


# ==================== Storage service ====================
class ArtifactStore:
    """Artifact storage service (any S3-compatible store, e.g. MinIO)"""

    def __init__(
        self,
        endpoint_url: str,
        bucket: str,
        access_key: str,
        secret_key: str
    ):
        self.bucket = bucket
        self.client = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key
        )
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Create the bucket if it does not exist yet"""
        try:
            self.client.head_bucket(Bucket=self.bucket)
        except ClientError:
            self.client.create_bucket(Bucket=self.bucket)

    def upload_model(
        self,
        model_name: str,
        version: int,
        file_obj: BinaryIO,
        filename: str
    ) -> str:
        """Upload a model file and return its s3:// URI"""
        key = f"models/{model_name}/{version}/{filename}"
        self.client.upload_fileobj(file_obj, self.bucket, key)
        return f"s3://{self.bucket}/{key}"

    def download_model(self, artifact_uri: str, local_path: str):
        """Download a model file to local_path"""
        # Split s3://bucket/key into bucket and key
        bucket, key = artifact_uri.replace("s3://", "", 1).split("/", 1)
        # dirname() is "" for a bare file name; fall back to the current directory
        os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
        self.client.download_file(bucket, key, local_path)

    def delete_model(self, artifact_uri: str):
        """Delete every object under the URI's key prefix"""
        bucket, key = artifact_uri.replace("s3://", "", 1).split("/", 1)
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=key):
            for obj in page.get("Contents", []):
                self.client.delete_object(Bucket=bucket, Key=obj["Key"])


# ==================== Registry service ====================
@dataclass
class ModelVersionInfo:
    """Model version information returned by the service"""
    id: str
    model_name: str
    version: int
    stage: ModelStage
    description: Optional[str]
    artifact_uri: str
    framework: Optional[str]
    metrics: Dict[str, float]
    params: Dict[str, Any]
    tags: Dict[str, str]
    created_at: datetime
    created_by: str


class ModelRegistryService:
    """Model registry service"""

    def __init__(
        self,
        database_url: str,
        artifact_store: ArtifactStore
    ):
        self.engine = create_engine(database_url)
        # Creates missing tables only; existing tables are not altered
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        self.artifact_store = artifact_store

    # ==================== Model management ====================
    def create_model(
        self,
        name: str,
        description: Optional[str] = None,
        owner: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> str:
        """Create a registered model and return its ID"""
        session = self.Session()
        try:
            model = RegisteredModel(
                name=name,
                description=description,
                owner=owner,
                tags=tags or {}
            )
            session.add(model)
            session.commit()
            return model.id
        finally:
            session.close()

    def get_model(self, name: str) -> Optional[Dict[str, Any]]:
        """Get model details, or None if the model does not exist"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                return None
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "owner": model.owner,
                "tags": model.tags,
                "created_at": model.created_at,
                "updated_at": model.updated_at,
                "version_count": len(model.versions)
            }
        finally:
            session.close()

    def list_models(
        self,
        filter_tags: Optional[Dict[str, str]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List models, optionally filtered by exact tag values"""
        session = self.Session()
        try:
            query = session.query(RegisteredModel)
            if filter_tags:
                for key, value in filter_tags.items():
                    # JSONB lookup: tags->>'key' = value
                    query = query.filter(
                        RegisteredModel.tags[key].astext == value
                    )
            models = query.offset(offset).limit(limit).all()
            return [
                {
                    "id": m.id,
                    "name": m.name,
                    "description": m.description,
                    "owner": m.owner,
                    "tags": m.tags,
                    "created_at": m.created_at,
                    "version_count": len(m.versions)
                }
                for m in models
            ]
        finally:
            session.close()

    def update_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ):
        """Update a model's description and merge in new tags"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                raise ValueError(f"Model not found: {name}")
            if description is not None:
                model.description = description
            if tags is not None:
                # Assign a new dict so SQLAlchemy detects the change to the JSONB column
                model.tags = {**model.tags, **tags}
            session.commit()
        finally:
            session.close()

    def delete_model(self, name: str):
        """Delete a model together with all of its versions and artifacts"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                raise ValueError(f"Model not found: {name}")
            artifact_uris = [v.artifact_uri for v in model.versions if v.artifact_uri]
            # Commit the database delete first, so a failed commit never leaves rows
            # that point at artifacts which have already been removed
            session.delete(model)
            session.commit()
            for uri in artifact_uris:
                self.artifact_store.delete_model(uri)
        finally:
            session.close()
    # ==================== Version management ====================
    def create_version(
        self,
        model_name: str,
        source_uri: str,
        model_file: Optional[BinaryIO] = None,
        filename: str = "model.pkl",
        description: Optional[str] = None,
        run_id: Optional[str] = None,
        experiment_id: Optional[str] = None,
        framework: Optional[str] = None,
        framework_version: Optional[str] = None,
        signature: Optional[Dict[str, Any]] = None,
        metrics: Optional[Dict[str, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, str]] = None,
        created_by: str = "system"
    ) -> ModelVersionInfo:
        """Create a new version of a registered model, optionally uploading its file."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")
            # Next version number = highest existing number + 1. Using count() + 1
            # would collide with an existing number once an older version is deleted.
            latest = (
                session.query(ModelVersion.version)
                .filter(ModelVersion.model_id == model.id)
                .order_by(ModelVersion.version.desc())
                .first()
            )
            new_version = (latest[0] if latest else 0) + 1
            # Upload the model file to the artifact store, if one was provided
            artifact_uri = None
            if model_file is not None:
                artifact_uri = self.artifact_store.upload_model(
                    model_name, new_version, model_file, filename
                )
            # Create the version record; new versions start in the NONE stage
            version = ModelVersion(
                model_id=model.id,
                version=new_version,
                stage=ModelStage.NONE,
                description=description,
                source_uri=source_uri,
                artifact_uri=artifact_uri,
                run_id=run_id,
                experiment_id=experiment_id,
                framework=framework,
                framework_version=framework_version,
                signature=signature,
                metrics=metrics or {},
                params=params or {},
                tags=tags or {},
                created_by=created_by
            )
            session.add(version)
            session.commit()
            return ModelVersionInfo(
                id=version.id,
                model_name=model_name,
                version=version.version,
                stage=version.stage,
                description=version.description,
                artifact_uri=version.artifact_uri or "",
                framework=version.framework,
                metrics=version.metrics,
                params=version.params,
                tags=version.tags,
                created_at=version.created_at,
                created_by=version.created_by
            )
        finally:
            session.close()
    def get_version(self, model_name: str, version: int) -> Optional[ModelVersionInfo]:
        """Get one model version; returns None if the model or version does not exist."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return None
            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()
            if not v:
                return None
            return ModelVersionInfo(
                id=v.id,
                model_name=model_name,
                version=v.version,
                stage=v.stage,
                description=v.description,
                artifact_uri=v.artifact_uri or "",
                framework=v.framework,
                metrics=v.metrics,
                params=v.params,
                tags=v.tags,
                created_at=v.created_at,
                created_by=v.created_by
            )
        finally:
            session.close()
    def get_latest_version(
        self,
        model_name: str,
        stage: Optional[ModelStage] = None
    ) -> Optional[ModelVersionInfo]:
        """Get the highest-numbered version, optionally restricted to one stage."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return None
            query = session.query(ModelVersion).filter_by(model_id=model.id)
            if stage is not None:
                query = query.filter_by(stage=stage)
            v = query.order_by(ModelVersion.version.desc()).first()
            if not v:
                return None
            return ModelVersionInfo(
                id=v.id,
                model_name=model_name,
                version=v.version,
                stage=v.stage,
                description=v.description,
                artifact_uri=v.artifact_uri or "",
                framework=v.framework,
                metrics=v.metrics,
                params=v.params,
                tags=v.tags,
                created_at=v.created_at,
                created_by=v.created_by
            )
        finally:
            session.close()
    def list_versions(
        self,
        model_name: str,
        stage: Optional[ModelStage] = None
    ) -> List[ModelVersionInfo]:
        """List a model's versions, newest first, optionally filtered by stage."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return []
            query = session.query(ModelVersion).filter_by(model_id=model.id)
            if stage is not None:
                query = query.filter_by(stage=stage)
            versions = query.order_by(ModelVersion.version.desc()).all()
            return [
                ModelVersionInfo(
                    id=v.id,
                    model_name=model_name,
                    version=v.version,
                    stage=v.stage,
                    description=v.description,
                    artifact_uri=v.artifact_uri or "",
                    framework=v.framework,
                    metrics=v.metrics,
                    params=v.params,
                    tags=v.tags,
                    created_at=v.created_at,
                    created_by=v.created_by
                )
                for v in versions
            ]
        finally:
            session.close()
    # ==================== Stage transitions ====================
    def transition_stage(
        self,
        model_name: str,
        version: int,
        to_stage: ModelStage,
        requester: str,
        require_approval: bool = True
    ) -> Optional[str]:
        """Move a version to another stage.

        A move to Production with require_approval=True only records a pending
        approval and returns its ID; any other transition is applied immediately
        and None is returned.
        """
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")
            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()
            if not v:
                raise ValueError(f"Version not found: {model_name}/{version}")
            from_stage = v.stage
            if require_approval and to_stage == ModelStage.PRODUCTION:
                # Approval required: record the request; the stage is unchanged until approved
                approval = ModelApproval(
                    version_id=v.id,
                    from_stage=from_stage,
                    to_stage=to_stage,
                    requester=requester
                )
                session.add(approval)
                session.commit()
                return approval.id
            else:
                # No approval needed: apply the transition directly
                v.stage = to_stage
                # Archive any other version currently in Production
                if to_stage == ModelStage.PRODUCTION:
                    old_prod = session.query(ModelVersion).filter_by(
                        model_id=model.id,
                        stage=ModelStage.PRODUCTION
                    ).filter(ModelVersion.id != v.id).all()
                    for old_v in old_prod:
                        old_v.stage = ModelStage.ARCHIVED
                session.commit()
                return None
        finally:
            session.close()
    def approve_transition(
        self,
        approval_id: str,
        approver: str,
        approved: bool,
        comment: Optional[str] = None
    ):
        """Approve or reject a pending stage transition; approving also applies it."""
        session = self.Session()
        try:
            approval = session.query(ModelApproval).filter_by(id=approval_id).first()
            if not approval:
                raise ValueError(f"Approval not found: {approval_id}")
            if approval.status != ApprovalStatus.PENDING:
                raise ValueError(f"Approval already resolved: {approval_id}")
            approval.status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED
            approval.approver = approver
            approval.comment = comment
            approval.resolved_at = datetime.utcnow()
            if approved:
                # Apply the requested stage transition
                version = approval.model_version
                version.stage = approval.to_stage
                # Archive any other version currently in Production
                if approval.to_stage == ModelStage.PRODUCTION:
                    old_prod = session.query(ModelVersion).filter_by(
                        model_id=version.model_id,
                        stage=ModelStage.PRODUCTION
                    ).filter(ModelVersion.id != version.id).all()
                    for old_v in old_prod:
                        old_v.stage = ModelStage.ARCHIVED
            session.commit()
        finally:
            session.close()
    # ==================== Lineage tracking ====================
    def add_lineage(
        self,
        model_name: str,
        version: int,
        lineage_type: str,
        source_uri: str,
        source_version: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Record a lineage entry (e.g. the training dataset) for a model version."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")
            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()
            if not v:
                raise ValueError(f"Version not found: {model_name}/{version}")
            # Caution: SQLAlchemy's Declarative API reserves the attribute name `metadata`,
            # so ModelLineage must map this column under another attribute name;
            # adjust the keyword below to whatever name the mapping actually uses.
            lineage = ModelLineage(
                version_id=v.id,
                lineage_type=lineage_type,
                source_uri=source_uri,
                source_version=source_version,
                metadata=metadata or {}
            )
            session.add(lineage)
            session.commit()
        finally:
            session.close()
    def get_lineage(
        self,
        model_name: str,
        version: int
    ) -> List[Dict[str, Any]]:
        """Return the lineage entries of a model version (empty if not found)."""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return []
            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()
            if not v:
                return []
            return [
                {
                    "type": entry.lineage_type,
                    "source_uri": entry.source_uri,
                    "source_version": entry.source_version,
                    # `metadata` is reserved by Declarative; use the attribute ModelLineage maps
                    "metadata": entry.metadata,
                    "created_at": entry.created_at
                }
                for entry in v.lineage
            ]
        finally:
            session.close()
# ==================== Usage example ====================
if __name__ == "__main__":
    # Initialize the artifact store (MinIO endpoint)
    # Example credentials only; load real ones from the environment or a secret store
    artifact_store = ArtifactStore(
        endpoint_url="http://minio:9000",
        bucket="model-registry",
        access_key="admin",
        secret_key="password"
    )
    # Initialize the registry service
    registry = ModelRegistryService(
        database_url="postgresql://user:pass@postgres:5432/model_registry",
        artifact_store=artifact_store
    )
    # Register the model
    registry.create_model(
        name="fraud-detector",
        description="Fraud detection model",
        owner="ml-team",
        tags={"domain": "finance", "type": "classification"}
    )
    # Create a version (metadata only: no model_file is uploaded in this example)
    version = registry.create_version(
        model_name="fraud-detector",
        source_uri="mlflow://experiment/run123",
        description="Baseline model with XGBoost",
        framework="xgboost",
        framework_version="1.7.0",
        metrics={"auc": 0.95, "precision": 0.88},
        params={"max_depth": 6, "learning_rate": 0.1},
        created_by="alice"
    )
    print(f"Created version: {version.version}")
    # Record lineage: the dataset this version was trained on
    registry.add_lineage(
        "fraud-detector",
        version.version,
        lineage_type="dataset",
        source_uri="s3://data/fraud-dataset-v2",
        source_version="2024-01-01",
        metadata={"rows": 1000000, "features": 50}
    )
    # Request promotion to Production (goes through approval)
    approval_id = registry.transition_stage(
        "fraud-detector",
        version.version,
        ModelStage.PRODUCTION,
        requester="alice",
        require_approval=True
    )
    print(f"Approval requested: {approval_id}")
    # Approve the request, which performs the promotion
    if approval_id:
        registry.approve_transition(
            approval_id,
            approver="bob",
            approved=True,
            comment="Metrics look good"
        )
        print("Approved and promoted to Production")
    # Fetch the current Production version (None if nothing is in Production)
    prod_version = registry.get_latest_version("fraud-detector", ModelStage.PRODUCTION)
    if prod_version is not None:
        print(f"Production version: {prod_version.version}")
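The promotion rules in `transition_stage` and `approve_transition` (a Production request only creates a pending approval; approving it moves the version and archives whichever version was previously in Production) can be exercised without PostgreSQL or MinIO. The following is a minimal in-memory sketch of those rules for a single model; `InMemoryRegistry`, `Stage`, `PendingApproval`, and their methods are hypothetical stand-ins that only mirror the service above, not part of any library. Version numbers are assigned as the highest existing number plus one.

```python
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional


class Stage(str, Enum):
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"


@dataclass
class PendingApproval:
    version: int
    to_stage: Stage
    requester: str
    status: str = "pending"  # "pending" | "approved" | "rejected"


@dataclass
class InMemoryRegistry:
    """Hypothetical single-model stand-in for ModelRegistryService."""
    stages: Dict[int, Stage] = field(default_factory=dict)
    approvals: Dict[str, PendingApproval] = field(default_factory=dict)

    def create_version(self) -> int:
        # Highest existing number + 1, so a new version never collides with
        # an existing one even after an older version has been deleted
        new_version = max(self.stages, default=0) + 1
        self.stages[new_version] = Stage.NONE
        return new_version

    def _apply(self, version: int, to_stage: Stage) -> None:
        self.stages[version] = to_stage
        if to_stage == Stage.PRODUCTION:
            # Keep at most one Production version: archive every other one
            for v, s in self.stages.items():
                if v != version and s == Stage.PRODUCTION:
                    self.stages[v] = Stage.ARCHIVED

    def transition(self, version: int, to_stage: Stage, requester: str,
                   require_approval: bool = True) -> Optional[str]:
        if require_approval and to_stage == Stage.PRODUCTION:
            approval_id = str(uuid.uuid4())
            self.approvals[approval_id] = PendingApproval(version, to_stage, requester)
            return approval_id  # stage stays unchanged until approved
        self._apply(version, to_stage)
        return None

    def approve(self, approval_id: str, approved: bool) -> None:
        approval = self.approvals[approval_id]
        if approval.status != "pending":
            raise ValueError(f"Approval already resolved: {approval_id}")
        approval.status = "approved" if approved else "rejected"
        if approved:
            self._apply(approval.version, approval.to_stage)

    def production_versions(self) -> List[int]:
        return [v for v, s in self.stages.items() if s == Stage.PRODUCTION]
```

Promoting version 1 and then version 2 through approval leaves only version 2 in Production, with version 1 archived; a rejected request leaves the stage untouched.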
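Summary
Experiment tracking and the model registry are core components of MLOps:
Experiment tracking
- Complete records of parameters, metrics, and artifacts
- Side-by-side comparison and visualization of experiments
- Reproducible experiments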
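At its core, tracking means persisting each run's parameters and metrics in a form that can be queried later, which is what answers "which parameter combination gave the best result?". A minimal stdlib-only sketch, assuming runs are appended as JSON lines to a local file; `log_run`, `load_runs`, and `best_run` are hypothetical helpers, not MLflow or W&B APIs:

```python
import json
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional


def log_run(store: Path, params: Dict[str, Any], metrics: Dict[str, float],
            tags: Optional[Dict[str, str]] = None) -> str:
    """Append one run record (params, metrics, tags) as a JSON line; return its ID."""
    run_id = uuid.uuid4().hex
    record = {
        "run_id": run_id,
        "timestamp": time.time(),
        "params": params,
        "metrics": metrics,
        "tags": tags or {},
    }
    with store.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return run_id


def load_runs(store: Path) -> List[Dict[str, Any]]:
    """Read every recorded run back from the JSON-lines file."""
    with store.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def best_run(store: Path, metric: str, maximize: bool = True) -> Dict[str, Any]:
    """Return the run with the best value of `metric`; runs without it are skipped."""
    runs = [r for r in load_runs(store) if metric in r["metrics"]]
    if not runs:
        raise ValueError(f"No run has metric {metric!r}")
    if maximize:
        return max(runs, key=lambda r: r["metrics"][metric])
    return min(runs, key=lambda r: r["metrics"][metric])
```

An append-only log keeps every run, including the poor ones, which is what makes later comparison and reproduction possible.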
Model registry
- Unified model version management
- Stage transitions (Staging → Production)
- Approval workflows and lineage tracking
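Lineage records pay off when a data source changes or turns out to be faulty: you can ask which model versions were built from it. A minimal in-memory sketch of that reverse lookup, assuming lineage rows carry the owning model name and version alongside the fields recorded by `add_lineage`; `LineageRecord` and `versions_using_source` are hypothetical names:

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple


@dataclass(frozen=True)
class LineageRecord:
    model_name: str
    version: int
    lineage_type: str              # e.g. "dataset"
    source_uri: str
    source_version: Optional[str] = None


def versions_using_source(records: Iterable[LineageRecord], source_uri: str,
                          source_version: Optional[str] = None) -> List[Tuple[str, int]]:
    """Return sorted (model_name, version) pairs built from the given source.

    If source_version is None, any version of the source matches.
    """
    hits = {
        (r.model_name, r.version)
        for r in records
        if r.source_uri == source_uri
        and (source_version is None or r.source_version == source_version)
    }
    return sorted(hits)
```

In the database-backed registry the same question becomes a join from the lineage table to the versions table, filtered on `source_uri`.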
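Choosing a tool
- MLflow: open source, comprehensive feature set, active community
- W&B: powerful visualization, rich collaboration features
- Self-built: for custom requirements or enterprise compliance constraints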
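Best practices
- Automate experiment logging (autolog)
- Standardize the model packaging format
- Establish thorough approval and governance processes
- Track lineage end to end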
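Automated logging means the training code does not call the logger by hand. MLflow's `autolog()` achieves this by hooking into the training calls of supported libraries; the sketch below shows the same idea in a framework-agnostic way, as a hypothetical `autolog` decorator that records a training function's arguments (including defaults) as params and its returned dict as metrics. The module-level `RUNS` list is an assumed stand-in for a real tracking backend, and `train` is a placeholder.

```python
import functools
import inspect
import time
from typing import Any, Callable, Dict, List

RUNS: List[Dict[str, Any]] = []  # hypothetical stand-in for a tracking backend


def autolog(fn: Callable[..., Dict[str, float]]) -> Callable[..., Dict[str, float]]:
    """Record params (bound arguments), returned metrics, and duration of each call."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Dict[str, float]:
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()  # defaults are params too; needed to reproduce the run
        start = time.perf_counter()
        metrics = fn(*args, **kwargs)  # if this raises, nothing is recorded
        RUNS.append({
            "function": fn.__name__,
            "params": dict(bound.arguments),
            "metrics": dict(metrics),
            "duration_s": time.perf_counter() - start,
        })
        return metrics

    return wrapper


@autolog
def train(max_depth: int = 6, learning_rate: float = 0.1) -> Dict[str, float]:
    # Placeholder "training": a real function would fit and evaluate a model here
    return {"auc": 0.9 + 0.001 * max_depth}
```

Capturing defaults via `apply_defaults()` matters for reproducibility: a run launched with no arguments still records the exact hyperparameters it used.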
The next chapter turns to MLOps in practice, exploring how to integrate these components into a complete machine learning platform.