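05-Experiment Tracking and Model Registry

Overview

Experiment tracking and model registration are core parts of MLOps: they make the model development process traceable and production deployments manageable. This article covers the principles and practice of MLflow, Weights & Biases, and a self-hosted model registry.

Experiment Tracking System Architecture

Why Experiment Tracking Is Needed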

┌─────────────────────────────────────────────────────────────────────┐
│                    ML Experiment Management Challenges              │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                   Typical ML Development Loop                 │  │
│  │                                                               │  │
│  │   tune ──► train ──► evaluate ──► tune ──► train ──► ...      │  │
│  │    │         │          │          │         │                │  │
│  │    ▼         ▼          ▼          ▼         ▼                │  │
│  │ params1   model1    metrics1    paramsN   modelN              │  │
│  │                                                               │  │
│  │   Q: Which parameter combination produced the best result?    │  │
│  │   Q: How do we reproduce an experiment from two weeks ago?    │  │
│  │   Q: How do team members share their experiments?             │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                              │                                      │
│                              ▼                                      │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                   Experiment Tracking System                  │  │
│  │                                                               │  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │  │
│  │  │ Params   │  │ Metrics  │  │ Model    │  │ Visual   │       │  │
│  │  │ logging  │  │ tracking │  │ versions │  │ compare  │       │  │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘       │  │
│  │                                                               │  │
│  │  Gives: run history, versioning, sharing, visual analysis     │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
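The three questions in the diagram (which parameters gave the best result, how to reproduce an old run, how to share runs) all reduce to storing every run as a record of parameters, metrics, and a code/config fingerprint. The following is a minimal in-memory sketch of that data model; `RunRecord`, `ExperimentLog`, and the SHA-256 config fingerprint are illustrative assumptions, not the schema of MLflow or any other tool.

```python
import hashlib
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class RunRecord:
    """One training run (hypothetical minimal schema, for illustration only)."""
    run_id: str
    params: Dict[str, object]
    metrics: Dict[str, float]
    code_version: str  # e.g. a git commit hash, supplied by the caller

    def fingerprint(self) -> str:
        # Canonical JSON (sorted keys) so the same configuration always hashes
        # to the same value, regardless of dict insertion order.
        payload = json.dumps(
            {"params": self.params, "code_version": self.code_version},
            sort_keys=True,
        )
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()


@dataclass
class ExperimentLog:
    """Append-only list of runs with two typical queries."""
    runs: List[RunRecord] = field(default_factory=list)

    def log(self, run: RunRecord) -> None:
        self.runs.append(run)

    def best_run(self, metric: str, higher_is_better: bool = True) -> Optional[RunRecord]:
        # "Which parameters gave the best result?" Only runs that logged the metric count.
        candidates = [r for r in self.runs if metric in r.metrics]
        if not candidates:
            return None
        pick = max if higher_is_better else min
        return pick(candidates, key=lambda r: r.metrics[metric])

    def find_by_fingerprint(self, fp: str) -> List[RunRecord]:
        # "How do I reproduce that run?" Find every run with the identical config.
        return [r for r in self.runs if r.fingerprint() == fp]


if __name__ == "__main__":
    log = ExperimentLog()
    log.log(RunRecord("r1", {"lr": 0.1, "max_depth": 5}, {"accuracy": 0.91}, "abc123"))
    log.log(RunRecord("r2", {"lr": 0.01, "max_depth": 5}, {"accuracy": 0.94}, "abc123"))
    best = log.best_run("accuracy")
    print(best.run_id, best.params)
```

Real trackers add persistence, a UI, and access control on top, but the core query ("sort runs by a metric, filter by config") is the same.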

Experiment Tracking Tools Compared

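| Feature | MLflow | W&B | Neptune | Self-hosted |
|---|---|---|---|---|
| Open source | ✓ | Partial | Partial | ✓ |
| Managed service | ✓ | ✓ | ✓ | ✗ |
| Experiment tracking | ✓ | ✓ | ✓ | ✓ |
| Model registry | ✓ | ✓ | ✓ | ✓ |
| Visualization | Basic | Rich | Rich | Custom |
| Collaboration | Basic | Strong | Strong | Custom |
| K8s integration | ✓ | ✓ | ✓ | ✓ |
| Cost | Low | Medium-high | Medium | Medium |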

MLflow in Practice

MLflow Architecture

┌─────────────────────────────────────────────────────────────────────┐
│                        MLflow Architecture                          │
├─────────────────────────────────────────────────────────────────────┤
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                     MLflow Components                         │  │
│  │                                                               │  │
│  │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐         │  │
│  │  │   Tracking   │  │   Projects   │  │    Models    │         │  │
│  │  │              │  │              │  │              │         │  │
│  │  │ • params     │  │ • packaging  │  │ • formats    │         │  │
│  │  │ • metrics    │  │ • env spec   │  │ • deploy API │         │  │
│  │  │ • artifacts  │  │ • repro runs │  │ • serving    │         │  │
│  │  └──────────────┘  └──────────────┘  └──────────────┘         │  │
│  │                                                               │  │
│  │  ┌──────────────────────────────────────────────────┐         │  │
│  │  │                  Model Registry                  │         │  │
│  │  │ • versioning   • stage transitions   • lineage   │         │  │
│  │  └──────────────────────────────────────────────────┘         │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                              │                                      │
│                              ▼                                      │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                      Backend Storage                          │  │
│  │  ┌────────────────┐            ┌────────────────┐             │  │
│  │  │ Tracking Store │            │ Artifact Store │             │  │
│  │  │                │            │                │             │  │
│  │  │ • SQLite       │            │ • local files  │             │  │
│  │  │ • MySQL        │            │ • S3/GCS       │             │  │
│  │  │ • PostgreSQL   │            │ • HDFS         │             │  │
│  │  └────────────────┘            └────────────────┘             │  │
│  └───────────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────────┘
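The backend splits storage in two: the Tracking Store keeps small, queryable metadata (params, metrics, tags) in a relational database, while the Artifact Store keeps large binary files in object storage, and the database only holds pointers to them. The sketch below imitates that split with stdlib pieces, an in-memory `sqlite3` database standing in for the tracking database and a local directory with content-addressed file names standing in for S3/MinIO. `LocalTrackingBackend` and its table layout are hypothetical; MLflow's real schema is different.

```python
import hashlib
import os
import sqlite3
import tempfile


class LocalTrackingBackend:
    """Toy backend: metadata in SQLite, artifact bytes in a directory (hypothetical)."""

    def __init__(self, artifact_root: str):
        self.db = sqlite3.connect(":memory:")
        self.db.execute("CREATE TABLE metrics (run_id TEXT, key TEXT, value REAL, step INTEGER)")
        self.db.execute("CREATE TABLE artifacts (run_id TEXT, path TEXT, sha256 TEXT)")
        self.artifact_root = artifact_root

    def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None:
        self.db.execute("INSERT INTO metrics VALUES (?, ?, ?, ?)", (run_id, key, value, step))

    def log_artifact(self, run_id: str, path: str, data: bytes) -> str:
        # Bytes go to the "object store" under their SHA-256 digest;
        # the database row only records which run/path points at which blob.
        digest = hashlib.sha256(data).hexdigest()
        blob_path = os.path.join(self.artifact_root, digest)
        if not os.path.exists(blob_path):  # identical content is stored once
            with open(blob_path, "wb") as f:
                f.write(data)
        self.db.execute("INSERT INTO artifacts VALUES (?, ?, ?)", (run_id, path, digest))
        return digest

    def metric_history(self, run_id: str, key: str):
        rows = self.db.execute(
            "SELECT step, value FROM metrics WHERE run_id = ? AND key = ? ORDER BY step",
            (run_id, key),
        )
        return rows.fetchall()

    def read_artifact(self, run_id: str, path: str) -> bytes:
        row = self.db.execute(
            "SELECT sha256 FROM artifacts WHERE run_id = ? AND path = ?", (run_id, path)
        ).fetchone()
        if row is None:
            raise FileNotFoundError(path)
        with open(os.path.join(self.artifact_root, row[0]), "rb") as f:
            return f.read()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        backend = LocalTrackingBackend(root)
        for step, loss in enumerate([0.9, 0.6, 0.4]):
            backend.log_metric("run-1", "loss", loss, step=step)
        backend.log_artifact("run-1", "model/weights.bin", b"\x00\x01")
        print(backend.metric_history("run-1", "loss"))
        print(backend.read_artifact("run-1", "model/weights.bin"))
```

Keeping blobs out of the database is what lets the Tracking Store stay small enough for fast metric queries while the Artifact Store scales independently.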

Deploying the MLflow Server

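# mlflow-server.yaml - Kubernetes deployment
# Prerequisites (not created by this manifest): in namespace mlflow, the Secrets
# mlflow-postgres-secret (keys: username, password) and mlflow-minio-secret
# (keys: access-key, secret-key), plus a MinIO bucket named mlflow-artifacts.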
apiVersion: v1
kind: Namespace
metadata:
  name: mlflow
---
# PostgreSQL database (tracking store)
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: mlflow-postgres
  namespace: mlflow
spec:
  serviceName: mlflow-postgres
  replicas: 1
  selector:
    matchLabels:
      app: mlflow-postgres
  template:
    metadata:
      labels:
        app: mlflow-postgres
    spec:
      containers:
      - name: postgres
        image: postgres:14
        ports:
        - containerPort: 5432
        env:
        - name: POSTGRES_DB
          value: mlflow
        - name: POSTGRES_USER
          valueFrom:
            secretKeyRef:
              name: mlflow-postgres-secret
              key: username
        - name: POSTGRES_PASSWORD
          valueFrom:
            secretKeyRef:
              name: mlflow-postgres-secret
              key: password
        volumeMounts:
        - name: postgres-data
          mountPath: /var/lib/postgresql/data
        resources:
          requests:
            cpu: 500m
            memory: 1Gi
          limits:
            cpu: 2
            memory: 4Gi
  volumeClaimTemplates:
  - metadata:
      name: postgres-data
    spec:
      accessModes: ["ReadWriteOnce"]
      resources:
        requests:
          storage: 50Gi
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-postgres
  namespace: mlflow
spec:
  ports:
  - port: 5432
  selector:
    app: mlflow-postgres
  clusterIP: None
---
# MinIO object storage (artifact store)
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: mlflow-minio
  namespace: mlflow
spec:
  serviceName: mlflow-minio
  replicas: 1
  selector:
    matchLabels:
      app: mlflow-minio
  template:
    metadata:
      labels:
        app: mlflow-minio
    spec:
      containers:
      - name: minio
        image: minio/minio:latest
        args:
        - server
        - /data
        - --console-address
        - ":9001"
        ports:
        - containerPort: 9000
          name: api
        - containerPort: 9001
          name: console
        env:
        - name: MINIO_ROOT_USER
          valueFrom:
            secretKeyRef:
              name: mlflow-minio-secret
              key: access-key
        - name: MINIO_ROOT_PASSWORD
          valueFrom:
            secretKeyRef:
              name: mlflow-minio-secret
              key: secret-key
        volumeMounts:
        - name: minio-data
          mountPath: /data
        resources:
          requests:
            cpu: 500m
            memory: 1Gi
          limits:
            cpu: 2
            memory: 4Gi
  volumeClaimTemplates:
  - metadata:
      name: minio-data
    spec:
      accessModes: ["ReadWriteOnce"]
      resources:
        requests:
          storage: 100Gi
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-minio
  namespace: mlflow
spec:
  ports:
  - port: 9000
    name: api
  - port: 9001
    name: console
  selector:
    app: mlflow-minio
---
# MLflow Server
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mlflow-server
  namespace: mlflow
spec:
  replicas: 2
  selector:
    matchLabels:
      app: mlflow-server
  template:
    metadata:
      labels:
        app: mlflow-server
    spec:
      containers:
      - name: mlflow
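        # The official image may not bundle the PostgreSQL driver (psycopg2) or boto3;
        # if it does not, build a derived image that installs them
        image: ghcr.io/mlflow/mlflow:v2.9.0
        command:
        - mlflow
        - server
        - --host=0.0.0.0
        - --port=5000
        # Kubernetes expands $(VAR) from the env entries below and inserts the raw
        # Secret value, so the password must not contain URI-reserved characters (@ : / %)
        - --backend-store-uri=postgresql://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@mlflow-postgres:5432/mlflow
        # With --serve-artifacts the server proxies artifact uploads/downloads and stores
        # them under --artifacts-destination, so clients need no S3 credentials of their own
        - --artifacts-destination=s3://mlflow-artifacts
        - --serve-artifacts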
        ports:
        - containerPort: 5000
        env:
        - name: POSTGRES_USER
          valueFrom:
            secretKeyRef:
              name: mlflow-postgres-secret
              key: username
        - name: POSTGRES_PASSWORD
          valueFrom:
            secretKeyRef:
              name: mlflow-postgres-secret
              key: password
        - name: MLFLOW_S3_ENDPOINT_URL
          value: http://mlflow-minio:9000
        - name: AWS_ACCESS_KEY_ID
          valueFrom:
            secretKeyRef:
              name: mlflow-minio-secret
              key: access-key
        - name: AWS_SECRET_ACCESS_KEY
          valueFrom:
            secretKeyRef:
              name: mlflow-minio-secret
              key: secret-key
        resources:
          requests:
            cpu: 500m
            memory: 1Gi
          limits:
            cpu: 2
            memory: 4Gi
        readinessProbe:
          httpGet:
            path: /health
            port: 5000
          initialDelaySeconds: 10
          periodSeconds: 10
        livenessProbe:
          httpGet:
            path: /health
            port: 5000
          initialDelaySeconds: 30
          periodSeconds: 30
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-server
  namespace: mlflow
spec:
  ports:
  - port: 5000
    targetPort: 5000
  selector:
    app: mlflow-server
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: mlflow-ingress
  namespace: mlflow
  annotations:
    nginx.ingress.kubernetes.io/proxy-body-size: "0"
spec:
  ingressClassName: nginx
  rules:
  - host: mlflow.example.com
    http:
      paths:
      - path: /
        pathType: Prefix
        backend:
          service:
            name: mlflow-server
            port:
              number: 5000
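The manifest builds `--backend-store-uri` by plain string substitution, so a password containing `@`, `:`, `/` or `%` would produce a malformed URI. If the URI is generated instead (for example by a bootstrap script), the credentials can be percent-encoded first; SQLAlchemy's documentation recommends `urllib.parse.quote_plus` for this. A minimal sketch using only the standard library; `build_backend_store_uri` is a hypothetical helper name:

```python
from urllib.parse import quote_plus, unquote_plus, urlsplit


def build_backend_store_uri(user: str, password: str, host: str,
                            port: int = 5432, database: str = "mlflow") -> str:
    """Build a PostgreSQL backend-store URI with percent-encoded credentials."""
    # quote_plus() escapes URI delimiters such as '@', ':', '/' and '%', so they
    # cannot be mistaken for the host separator or the port separator.
    return (f"postgresql://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{database}")


if __name__ == "__main__":
    uri = build_backend_store_uri("mlflow", "p@ss:w/rd", "mlflow-postgres")
    parts = urlsplit(uri)
    # Round trip: the parser finds the right host and port and recovers the raw password.
    print(parts.hostname, parts.port, unquote_plus(parts.password))
```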

Wrapping the MLflow Python SDK

# mlflow_manager.py
"""
MLflow experiment tracking and model registry manager
Provides a unified interface for managing experiments
"""

import os
import json
import tempfile
from typing import Dict, List, Optional, Any, Union, Callable
from dataclasses import dataclass, field
from datetime import datetime
from contextlib import contextmanager
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
from mlflow.models.signature import ModelSignature, infer_signature
from mlflow.types.schema import Schema, ColSpec
import pandas as pd
import numpy as np


@dataclass
class ExperimentConfig:
    """实验配置"""
    name: str
    tracking_uri: str
    artifact_location: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)


@dataclass
class RunInfo:
    """运行信息"""
    run_id: str
    experiment_id: str
    status: str
    start_time: datetime
    end_time: Optional[datetime]
    artifact_uri: str
    metrics: Dict[str, float]
    params: Dict[str, str]
    tags: Dict[str, str]


@dataclass
class ModelVersion:
    """模型版本信息"""
    name: str
    version: str
    stage: str
    source: str
    run_id: str
    status: str
    creation_timestamp: datetime
    description: Optional[str] = None


class MLflowExperimentTracker:
    """MLflow 实验跟踪器"""

    def __init__(
        self,
        tracking_uri: str,
        default_experiment: Optional[str] = None
    ):
        self.tracking_uri = tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient(tracking_uri)

        if default_experiment:
            self.set_experiment(default_experiment)

    # ==================== Experiment management ====================

    def set_experiment(
        self,
        name: str,
        artifact_location: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> str:
        """设置当前实验"""
        experiment = mlflow.set_experiment(
            name,
            artifact_location=artifact_location,
            tags=tags
        )
        return experiment.experiment_id

    def get_experiment(self, name: str) -> Optional[Dict[str, Any]]:
        """获取实验信息"""
        experiment = self.client.get_experiment_by_name(name)
        if experiment:
            return {
                "experiment_id": experiment.experiment_id,
                "name": experiment.name,
                "artifact_location": experiment.artifact_location,
                "lifecycle_stage": experiment.lifecycle_stage,
                "tags": experiment.tags
            }
        return None

    def list_experiments(
        self,
        view_type: str = "ACTIVE_ONLY"
    ) -> List[Dict[str, Any]]:
        """列出实验"""
        view = getattr(ViewType, view_type)
        experiments = self.client.search_experiments(view_type=view)
        return [
            {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage
            }
            for exp in experiments
        ]

    def delete_experiment(self, name: str):
        """删除实验"""
        experiment = self.client.get_experiment_by_name(name)
        if experiment:
            self.client.delete_experiment(experiment.experiment_id)

    # ==================== Run management ====================

    @contextmanager
    def start_run(
        self,
        run_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        nested: bool = False
    ):
        """开始运行(上下文管理器)"""
        if experiment_name:
            self.set_experiment(experiment_name)

        with mlflow.start_run(
            run_name=run_name,
            tags=tags,
            nested=nested
        ) as run:
            yield RunContext(run, self)

    def get_run(self, run_id: str) -> RunInfo:
        """获取运行信息"""
        run = self.client.get_run(run_id)
        return RunInfo(
            run_id=run.info.run_id,
            experiment_id=run.info.experiment_id,
            status=run.info.status,
            start_time=datetime.fromtimestamp(run.info.start_time / 1000),
            end_time=datetime.fromtimestamp(run.info.end_time / 1000) if run.info.end_time else None,
            artifact_uri=run.info.artifact_uri,
            metrics=run.data.metrics,
            params=run.data.params,
            tags=run.data.tags
        )

    def search_runs(
        self,
        experiment_names: Optional[List[str]] = None,
        filter_string: str = "",
        max_results: int = 100,
        order_by: Optional[List[str]] = None
    ) -> List[RunInfo]:
        """搜索运行"""
        experiment_ids = None
        if experiment_names:
            experiment_ids = [
                self.client.get_experiment_by_name(name).experiment_id
                for name in experiment_names
                if self.client.get_experiment_by_name(name)
            ]

        runs = self.client.search_runs(
            experiment_ids=experiment_ids,
            filter_string=filter_string,
            max_results=max_results,
            order_by=order_by
        )

        return [
            RunInfo(
                run_id=run.info.run_id,
                experiment_id=run.info.experiment_id,
                status=run.info.status,
                start_time=datetime.fromtimestamp(run.info.start_time / 1000),
                end_time=datetime.fromtimestamp(run.info.end_time / 1000) if run.info.end_time else None,
                artifact_uri=run.info.artifact_uri,
                metrics=run.data.metrics,
                params=run.data.params,
                tags=run.data.tags
            )
            for run in runs
        ]

    def delete_run(self, run_id: str):
        """删除运行"""
        self.client.delete_run(run_id)

    # ==================== Metrics and parameters ====================

    def log_param(self, key: str, value: Any):
        """记录参数"""
        mlflow.log_param(key, value)

    def log_params(self, params: Dict[str, Any]):
        """批量记录参数"""
        mlflow.log_params(params)

    def log_metric(self, key: str, value: float, step: Optional[int] = None):
        """记录指标"""
        mlflow.log_metric(key, value, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """批量记录指标"""
        mlflow.log_metrics(metrics, step=step)

    # ==================== Artifact management ====================

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        """记录 artifact"""
        mlflow.log_artifact(local_path, artifact_path)

    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None):
        """记录目录"""
        mlflow.log_artifacts(local_dir, artifact_path)

    def log_figure(self, figure: Any, artifact_file: str):
        """记录图表"""
        mlflow.log_figure(figure, artifact_file)

    def log_dict(self, dictionary: Dict[str, Any], artifact_file: str):
        """记录字典为 JSON"""
        mlflow.log_dict(dictionary, artifact_file)

    def log_dataframe(
        self,
        df: pd.DataFrame,
        artifact_file: str,
        format: str = "parquet"
    ):
        """记录 DataFrame"""
        with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as f:
            if format == "parquet":
                df.to_parquet(f.name, index=False)
            elif format == "csv":
                df.to_csv(f.name, index=False)
            mlflow.log_artifact(f.name, "data")
            os.unlink(f.name)

    def download_artifacts(self, run_id: str, path: str, dst_path: str) -> str:
        """下载 artifacts"""
        return self.client.download_artifacts(run_id, path, dst_path)


class RunContext:
    """运行上下文"""

    def __init__(self, run: mlflow.ActiveRun, tracker: MLflowExperimentTracker):
        self.run = run
        self.tracker = tracker
        self.run_id = run.info.run_id

    def log_param(self, key: str, value: Any):
        self.tracker.log_param(key, value)

    def log_params(self, params: Dict[str, Any]):
        self.tracker.log_params(params)

    def log_metric(self, key: str, value: float, step: Optional[int] = None):
        self.tracker.log_metric(key, value, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        self.tracker.log_metrics(metrics, step=step)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        self.tracker.log_artifact(local_path, artifact_path)

    def log_model(
        self,
        model: Any,
        artifact_path: str,
        flavor: str = "sklearn",
        signature: Optional[ModelSignature] = None,
        input_example: Optional[Any] = None,
        registered_model_name: Optional[str] = None
    ):
        """记录模型"""
        log_func = getattr(mlflow, flavor)
        log_func.log_model(
            model,
            artifact_path,
            signature=signature,
            input_example=input_example,
            registered_model_name=registered_model_name
        )

    def set_tag(self, key: str, value: str):
        """设置标签"""
        mlflow.set_tag(key, value)

    def set_tags(self, tags: Dict[str, str]):
        """批量设置标签"""
        mlflow.set_tags(tags)


class MLflowModelRegistry:
    """MLflow 模型注册中心"""

    def __init__(self, tracking_uri: str):
        self.tracking_uri = tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient(tracking_uri)

    # ==================== Registered model management ====================

    def create_registered_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ):
        """创建注册模型"""
        self.client.create_registered_model(
            name,
            description=description,
            tags=[{"key": k, "value": v} for k, v in (tags or {}).items()]
        )

    def list_registered_models(
        self,
        filter_string: str = "",
        max_results: int = 100
    ) -> List[Dict[str, Any]]:
        """列出注册模型"""
        models = self.client.search_registered_models(
            filter_string=filter_string,
            max_results=max_results
        )
        return [
            {
                "name": model.name,
                "description": model.description,
                "creation_timestamp": model.creation_timestamp,
                "last_updated_timestamp": model.last_updated_timestamp,
                "latest_versions": [
                    {
                        "version": v.version,
                        "stage": v.current_stage,
                        "status": v.status
                    }
                    for v in model.latest_versions
                ]
            }
            for model in models
        ]

    def get_registered_model(self, name: str) -> Dict[str, Any]:
        """获取注册模型详情"""
        model = self.client.get_registered_model(name)
        return {
            "name": model.name,
            "description": model.description,
            "creation_timestamp": model.creation_timestamp,
            "last_updated_timestamp": model.last_updated_timestamp,
            "tags": model.tags,
            "latest_versions": [
                {
                    "version": v.version,
                    "stage": v.current_stage,
                    "status": v.status,
                    "run_id": v.run_id
                }
                for v in model.latest_versions
            ]
        }

    def update_registered_model(
        self,
        name: str,
        description: Optional[str] = None
    ):
        """更新注册模型"""
        self.client.update_registered_model(name, description=description)

    def delete_registered_model(self, name: str):
        """删除注册模型"""
        self.client.delete_registered_model(name)

    # ==================== Version management ====================

    def create_model_version(
        self,
        name: str,
        source: str,
        run_id: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> ModelVersion:
        """创建模型版本"""
        version = self.client.create_model_version(
            name=name,
            source=source,
            run_id=run_id,
            description=description,
            tags=[{"key": k, "value": v} for k, v in (tags or {}).items()]
        )
        return ModelVersion(
            name=version.name,
            version=version.version,
            stage=version.current_stage,
            source=version.source,
            run_id=version.run_id,
            status=version.status,
            creation_timestamp=datetime.fromtimestamp(version.creation_timestamp / 1000),
            description=version.description
        )

    def get_model_version(self, name: str, version: str) -> ModelVersion:
        """获取模型版本"""
        v = self.client.get_model_version(name, version)
        return ModelVersion(
            name=v.name,
            version=v.version,
            stage=v.current_stage,
            source=v.source,
            run_id=v.run_id,
            status=v.status,
            creation_timestamp=datetime.fromtimestamp(v.creation_timestamp / 1000),
            description=v.description
        )

    def search_model_versions(
        self,
        filter_string: str = "",
        max_results: int = 100
    ) -> List[ModelVersion]:
        """搜索模型版本"""
        versions = self.client.search_model_versions(
            filter_string=filter_string,
            max_results=max_results
        )
        return [
            ModelVersion(
                name=v.name,
                version=v.version,
                stage=v.current_stage,
                source=v.source,
                run_id=v.run_id,
                status=v.status,
                creation_timestamp=datetime.fromtimestamp(v.creation_timestamp / 1000),
                description=v.description
            )
            for v in versions
        ]

    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None
    ):
        """更新模型版本"""
        self.client.update_model_version(name, version, description=description)

    def delete_model_version(self, name: str, version: str):
        """删除模型版本"""
        self.client.delete_model_version(name, version)

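    # ==================== Stage transitions ====================
    # Note: model version stages are deprecated since MLflow 2.9 in favor of aliases
    # (MlflowClient.set_registered_model_alias, URIs such as models:/<name>@<alias>)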

    def transition_model_version_stage(
        self,
        name: str,
        version: str,
        stage: str,
        archive_existing_versions: bool = True
    ):
        """转换模型版本阶段"""
        self.client.transition_model_version_stage(
            name=name,
            version=version,
            stage=stage,
            archive_existing_versions=archive_existing_versions
        )

    def get_latest_versions(
        self,
        name: str,
        stages: Optional[List[str]] = None
    ) -> List[ModelVersion]:
        """获取最新版本"""
        versions = self.client.get_latest_versions(name, stages=stages)
        return [
            ModelVersion(
                name=v.name,
                version=v.version,
                stage=v.current_stage,
                source=v.source,
                run_id=v.run_id,
                status=v.status,
                creation_timestamp=datetime.fromtimestamp(v.creation_timestamp / 1000),
                description=v.description
            )
            for v in versions
        ]

    # ==================== Model loading ====================

    def load_model(
        self,
        name: str,
        version: Optional[str] = None,
        stage: Optional[str] = None
    ) -> Any:
        """加载模型"""
        if version:
            model_uri = f"models:/{name}/{version}"
        elif stage:
            model_uri = f"models:/{name}/{stage}"
        else:
            model_uri = f"models:/{name}/latest"

        return mlflow.pyfunc.load_model(model_uri)


class MLflowAutolog:
    """MLflow 自动日志"""

    @staticmethod
    def enable_sklearn():
        """启用 sklearn 自动日志"""
        mlflow.sklearn.autolog()

    @staticmethod
    def enable_pytorch():
        """启用 PyTorch 自动日志"""
        mlflow.pytorch.autolog()

    @staticmethod
    def enable_tensorflow():
        """启用 TensorFlow 自动日志"""
        mlflow.tensorflow.autolog()

    @staticmethod
    def enable_xgboost():
        """启用 XGBoost 自动日志"""
        mlflow.xgboost.autolog()

    @staticmethod
    def enable_lightgbm():
        """启用 LightGBM 自动日志"""
        mlflow.lightgbm.autolog()

    @staticmethod
    def enable_all():
        """启用所有自动日志"""
        mlflow.autolog()


# ==================== Usage examples ====================

def training_example():
    """训练示例"""
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score, f1_score

    # Initialize the tracker
    tracker = MLflowExperimentTracker(
        tracking_uri="http://mlflow-server:5000",
        default_experiment="iris-classification"
    )

    # Load data
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42
    )

    # Training parameters
    params = {
        "n_estimators": 100,
        "max_depth": 5,
        "random_state": 42
    }

    # Start a run
    with tracker.start_run(run_name="rf-baseline") as run:
        # Log parameters
        run.log_params(params)
        run.set_tag("model_type", "random_forest")

        # Train the model
        model = RandomForestClassifier(**params)
        model.fit(X_train, y_train)

        # Predict and evaluate
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average="weighted")

        # Log metrics
        run.log_metrics({
            "accuracy": accuracy,
            "f1_score": f1
        })

        # Infer the model signature from example inputs and outputs
        signature = infer_signature(X_train, model.predict(X_train))

        # Log the model and register it as a new version of "iris-classifier"
        run.log_model(
            model,
            "model",
            flavor="sklearn",
            signature=signature,
            input_example=X_train[:5],
            registered_model_name="iris-classifier"
        )

        print(f"Run ID: {run.run_id}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"F1 Score: {f1:.4f}")


def model_registry_example():
    """模型注册示例"""
    registry = MLflowModelRegistry("http://mlflow-server:5000")

    # Get model information
    model_info = registry.get_registered_model("iris-classifier")
    print(f"Model: {model_info['name']}")
    print(f"Latest versions: {model_info['latest_versions']}")

    # Get the latest Staging version (newly registered versions start in stage "None")
    staging_versions = registry.get_latest_versions(
        "iris-classifier",
        stages=["Staging"]
    )

    if staging_versions:
        version = staging_versions[0]
        print(f"Staging version: {version.version}")

        # Promote to Production, archiving the version currently in Production
        registry.transition_model_version_stage(
            "iris-classifier",
            version.version,
            "Production",
            archive_existing_versions=True
        )
        print(f"Promoted version {version.version} to Production")

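    # Load the Production model, if a Production version exists
    if registry.get_latest_versions("iris-classifier", stages=["Production"]):
        model = registry.load_model("iris-classifier", stage="Production")
        print("Loaded production model")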


if __name__ == "__main__":
    training_example()
    model_registry_example()
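To make the registry semantics used above concrete, namely `archive_existing_versions=True` in `transition_model_version_stage` and the `models:/<name>/<version|stage|latest>` URIs built by `load_model`, here is a toy in-memory model of that behavior. It is an illustration under simplifying assumptions, not MLflow's implementation: `ToyRegistry` is hypothetical, `latest` simply means the highest version number, and version status, permissions, and the alias syntax `models:/<name>@<alias>` are ignored.

```python
from dataclasses import dataclass
from typing import Dict, List

STAGES = ("None", "Staging", "Production", "Archived")


@dataclass
class ToyVersion:
    version: int
    stage: str = "None"  # newly registered versions start in stage "None"


class ToyRegistry:
    """In-memory stand-in for a registry's stage logic (illustration only)."""

    def __init__(self) -> None:
        self.models: Dict[str, List[ToyVersion]] = {}

    def register(self, name: str) -> int:
        # Versions are numbered 1, 2, 3, ... per registered model.
        versions = self.models.setdefault(name, [])
        versions.append(ToyVersion(version=len(versions) + 1))
        return versions[-1].version

    def transition(self, name: str, version: int, stage: str,
                   archive_existing_versions: bool = False) -> None:
        if stage not in STAGES:
            raise ValueError(f"unknown stage: {stage}")
        versions = self.models.get(name, [])
        if not any(v.version == version for v in versions):
            raise LookupError(f"{name} has no version {version}")
        for v in versions:
            if v.version == version:
                v.stage = stage
            elif (archive_existing_versions and stage in ("Staging", "Production")
                  and v.stage == stage):
                # Versions already in the target stage are archived, so the
                # promoted version becomes the only one in that stage.
                v.stage = "Archived"

    def resolve(self, uri: str) -> int:
        """Resolve models:/<name>/<version|stage|latest> to a version number."""
        prefix = "models:/"
        if not uri.startswith(prefix):
            raise ValueError(f"not a models:/ URI: {uri}")
        name, _, ref = uri[len(prefix):].partition("/")
        versions = self.models.get(name, [])
        if ref.isdigit():
            candidates = [v for v in versions if v.version == int(ref)]
        elif ref.lower() == "latest":
            candidates = versions
        else:
            # Stage names are matched case-insensitively in this toy.
            candidates = [v for v in versions if v.stage.lower() == ref.lower()]
        if not candidates:
            raise LookupError(f"no version matches {uri}")
        # Highest version number among the matches.
        return max(v.version for v in candidates)
```

Promoting a Staging version with `archive_existing_versions=True` therefore leaves exactly one Production version, which is what makes a stage URI such as `models:/iris-classifier/Production` unambiguous for serving.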

Self-Hosted Model Registry

Registry Architecture

┌─────────────────────────────────────────────────────────────────────┐
│               Self-Hosted Model Registry Architecture               │
├─────────────────────────────────────────────────────────────────────┤
│  ┌──────────────────────────────────────────────────────────────┐   │
│  │                         API Gateway                          │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐      │   │
│  │  │ Auth     │  │ RateLimit│  │ Routing  │  │ Logging  │      │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘      │   │
│  └──────────────────────────────────────────────────────────────┘   │
│                               │                                     │
│                               ▼                                     │
│  ┌──────────────────────────────────────────────────────────────┐   │
│  │                    Model Registry Service                    │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐      │   │
│  │  │ Models   │  │ Versions │  │ Stages   │  │ Access   │      │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘      │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐      │   │
│  │  │ Metadata │  │ Lineage  │  │ Approvals│  │ Notify   │      │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘      │   │
│  └──────────────────────────────────────────────────────────────┘   │
│                               │                                     │
│                               ▼                                     │
│  ┌───────────────────┐  ┌───────────────────┐  ┌────────────────┐   │
│  │   Metadata DB     │  │  Artifact Store   │  │   Event Bus    │   │
│  │   (PostgreSQL)    │  │  (S3/MinIO)       │  │   (Kafka)      │   │
│  └───────────────────┘  └───────────────────┘  └────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘
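The "Stages" and "Approvals" boxes in the diagram reduce to a small state machine. The sketch below is an illustrative assumption, not part of the registry code that follows: the `ALLOWED` transition table (for example, no jumping from None straight to Production) is a hypothetical policy, while the rules that only promotions into Production are gated and that promoting archives the previous Production version mirror what `transition_stage` does.

```python
from enum import Enum
from typing import Dict, Set


class Stage(str, Enum):
    """Mirrors the registry's stages; redefined so the sketch is self-contained."""
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"


# Hypothetical policy: which transitions a version may take at all
# (e.g. a version must pass through Staging before reaching Production).
ALLOWED: Dict[Stage, Set[Stage]] = {
    Stage.NONE: {Stage.STAGING, Stage.ARCHIVED},
    Stage.STAGING: {Stage.NONE, Stage.PRODUCTION, Stage.ARCHIVED},
    Stage.PRODUCTION: {Stage.STAGING, Stage.ARCHIVED},
    Stage.ARCHIVED: {Stage.STAGING},
}


def needs_approval(from_stage: Stage, to_stage: Stage) -> bool:
    """Validate a transition; return True if it must go through approval."""
    if to_stage not in ALLOWED[from_stage]:
        raise ValueError(f"{from_stage.value} -> {to_stage.value} is not allowed")
    # As in transition_stage(), only promotions into Production are gated.
    return to_stage == Stage.PRODUCTION


def apply_transition(stages: Dict[int, Stage], version: int, to_stage: Stage) -> None:
    """Apply an (already approved) transition to a {version: stage} map.

    Promoting a version to Production archives every other Production
    version, so at most one version serves traffic at a time.
    """
    needs_approval(stages[version], to_stage)  # raises if the move is not allowed
    if to_stage == Stage.PRODUCTION:
        for v, s in stages.items():
            if v != version and s == Stage.PRODUCTION:
                stages[v] = Stage.ARCHIVED
    stages[version] = to_stage


if __name__ == "__main__":
    stages = {1: Stage.PRODUCTION, 2: Stage.STAGING}
    print(needs_approval(Stage.STAGING, Stage.PRODUCTION))  # True
    apply_transition(stages, 2, Stage.PRODUCTION)
    print({v: s.value for v, s in stages.items()})  # {1: 'Archived', 2: 'Production'}
```

Keeping the policy in one table makes it easy to audit and to extend, for example by also gating Archived → Staging.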

Model Registry Implementation

# model_registry.py
"""
Self-hosted model registry
Supports model version management, an approval workflow, and lineage tracking
"""

import os
import hashlib
import json
from typing import Dict, List, Optional, Any, BinaryIO
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import uuid
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Text, ForeignKey, Enum as SQLEnum
# declarative_base lives in sqlalchemy.orm since 1.4; the ext.declarative path is deprecated
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from sqlalchemy.dialects.postgresql import JSONB
import boto3
from botocore.exceptions import ClientError


Base = declarative_base()


class ModelStage(str, Enum):
    """Model lifecycle stage"""
    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"


class ApprovalStatus(str, Enum):
    """Approval status"""
    PENDING = "Pending"
    APPROVED = "Approved"
    REJECTED = "Rejected"


# ==================== Data models ====================

class RegisteredModel(Base):
    """Registered models table"""
    __tablename__ = "registered_models"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String(255), unique=True, nullable=False)
    description = Column(Text)
    owner = Column(String(255))
    tags = Column(JSONB, default={})
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    versions = relationship("ModelVersion", back_populates="model", cascade="all, delete-orphan")


class ModelVersion(Base):
    """Model versions table"""
    __tablename__ = "model_versions"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    model_id = Column(String(36), ForeignKey("registered_models.id"), nullable=False)
    version = Column(Integer, nullable=False)
    stage = Column(SQLEnum(ModelStage), default=ModelStage.NONE)
    description = Column(Text)
    source_uri = Column(String(1024), nullable=False)
    artifact_uri = Column(String(1024))
    run_id = Column(String(255))
    experiment_id = Column(String(255))
    framework = Column(String(100))
    framework_version = Column(String(50))
    signature = Column(JSONB)
    input_example = Column(JSONB)
    metrics = Column(JSONB, default={})
    params = Column(JSONB, default={})
    tags = Column(JSONB, default={})
    created_at = Column(DateTime, default=datetime.utcnow)
    created_by = Column(String(255))

    model = relationship("RegisteredModel", back_populates="versions")
    approvals = relationship("ModelApproval", back_populates="model_version", cascade="all, delete-orphan")
    lineage = relationship("ModelLineage", back_populates="model_version", cascade="all, delete-orphan")


class ModelApproval(Base):
    """Model approvals table"""
    __tablename__ = "model_approvals"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    version_id = Column(String(36), ForeignKey("model_versions.id"), nullable=False)
    from_stage = Column(SQLEnum(ModelStage), nullable=False)
    to_stage = Column(SQLEnum(ModelStage), nullable=False)
    status = Column(SQLEnum(ApprovalStatus), default=ApprovalStatus.PENDING)
    requester = Column(String(255), nullable=False)
    approver = Column(String(255))
    comment = Column(Text)
    requested_at = Column(DateTime, default=datetime.utcnow)
    resolved_at = Column(DateTime)

    model_version = relationship("ModelVersion", back_populates="approvals")


class ModelLineage(Base):
    """Model lineage table"""
    __tablename__ = "model_lineage"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    version_id = Column(String(36), ForeignKey("model_versions.id"), nullable=False)
    lineage_type = Column(String(50))  # dataset, parent_model, code
    source_uri = Column(String(1024))
    source_version = Column(String(255))
    # "metadata" is a reserved attribute name in SQLAlchemy's Declarative API,
    # so the attribute is extra_metadata while the DB column keeps the name "metadata"
    extra_metadata = Column("metadata", JSONB, default={})
    created_at = Column(DateTime, default=datetime.utcnow)

    model_version = relationship("ModelVersion", back_populates="lineage")


# ==================== Storage service ====================

class ArtifactStore:
    """Artifact storage service on S3-compatible storage (S3/MinIO)"""

    def __init__(
        self,
        endpoint_url: str,
        bucket: str,
        access_key: str,
        secret_key: str
    ):
        self.bucket = bucket
        self.client = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key
        )
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Create the bucket if it does not exist"""
        try:
            self.client.head_bucket(Bucket=self.bucket)
        except ClientError:
            # On AWS S3 outside us-east-1, create_bucket also needs a
            # CreateBucketConfiguration with a LocationConstraint
            self.client.create_bucket(Bucket=self.bucket)

    @staticmethod
    def _parse_s3_uri(artifact_uri: str):
        """Split "s3://bucket/key" into (bucket, key)"""
        if not artifact_uri.startswith("s3://"):
            raise ValueError(f"Not an S3 URI: {artifact_uri}")
        bucket, key = artifact_uri[len("s3://"):].split("/", 1)
        return bucket, key

    def upload_model(
        self,
        model_name: str,
        version: int,
        file_obj: BinaryIO,
        filename: str
    ) -> str:
        """Upload a model file and return its s3:// URI"""
        key = f"models/{model_name}/{version}/{filename}"
        self.client.upload_fileobj(file_obj, self.bucket, key)
        return f"s3://{self.bucket}/{key}"

    def download_model(self, artifact_uri: str, local_path: str):
        """Download a model file to local_path"""
        bucket, key = self._parse_s3_uri(artifact_uri)

        local_dir = os.path.dirname(local_path)
        if local_dir:  # os.makedirs("") fails when local_path is a bare filename
            os.makedirs(local_dir, exist_ok=True)
        self.client.download_file(bucket, key, local_path)

    def delete_model(self, artifact_uri: str):
        """Delete model files"""
        bucket, key = self._parse_s3_uri(artifact_uri)

        # Delete every object under the key prefix
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=key):
            for obj in page.get("Contents", []):
                self.client.delete_object(Bucket=bucket, Key=obj["Key"])


# ==================== Registry service ====================

@dataclass
class ModelVersionInfo:
    """Model version info, detached from the DB session"""
    id: str
    model_name: str
    version: int
    stage: ModelStage
    description: Optional[str]
    artifact_uri: str
    framework: Optional[str]
    metrics: Dict[str, float]
    params: Dict[str, Any]
    tags: Dict[str, str]
    created_at: datetime
    created_by: str


class ModelRegistryService:
    """Model registry service"""

    def __init__(
        self,
        database_url: str,
        artifact_store: ArtifactStore
    ):
        self.engine = create_engine(database_url)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        self.artifact_store = artifact_store

    @staticmethod
    def _to_info(v: ModelVersion, model_name: str) -> ModelVersionInfo:
        """Copy an ORM row into a ModelVersionInfo (call before the session closes)"""
        return ModelVersionInfo(
            id=v.id,
            model_name=model_name,
            version=v.version,
            stage=v.stage,
            description=v.description,
            artifact_uri=v.artifact_uri or "",
            framework=v.framework,
            metrics=v.metrics,
            params=v.params,
            tags=v.tags,
            created_at=v.created_at,
            created_by=v.created_by
        )

    # ==================== Model management ====================

    def create_model(
        self,
        name: str,
        description: Optional[str] = None,
        owner: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> str:
        """Create a registered model and return its ID"""
        session = self.Session()
        try:
            model = RegisteredModel(
                name=name,
                description=description,
                owner=owner,
                tags=tags or {}
            )
            session.add(model)
            session.commit()
            return model.id
        finally:
            session.close()

    def get_model(self, name: str) -> Optional[Dict[str, Any]]:
        """Get model info"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                return None
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "owner": model.owner,
                "tags": model.tags,
                "created_at": model.created_at,
                "updated_at": model.updated_at,
                "version_count": len(model.versions)
            }
        finally:
            session.close()

    def list_models(
        self,
        filter_tags: Optional[Dict[str, str]] = None,
        limit: int = 100,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List models, optionally filtered by tag values (PostgreSQL JSONB)"""
        session = self.Session()
        try:
            query = session.query(RegisteredModel)

            if filter_tags:
                for key, value in filter_tags.items():
                    query = query.filter(
                        RegisteredModel.tags[key].astext == value
                    )

            models = query.offset(offset).limit(limit).all()
            return [
                {
                    "id": m.id,
                    "name": m.name,
                    "description": m.description,
                    "owner": m.owner,
                    "tags": m.tags,
                    "created_at": m.created_at,
                    "version_count": len(m.versions)
                }
                for m in models
            ]
        finally:
            session.close()

    def update_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ):
        """Update a model's description and merge in new tags"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                raise ValueError(f"Model not found: {name}")

            if description is not None:
                model.description = description
            if tags is not None:
                # Assign a new dict so SQLAlchemy detects the change
                model.tags = {**model.tags, **tags}

            session.commit()
        finally:
            session.close()

    def delete_model(self, name: str):
        """Delete a model together with all its versions and their artifacts"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=name).first()
            if not model:
                raise ValueError(f"Model not found: {name}")

            # Delete the artifacts of every version
            for version in model.versions:
                if version.artifact_uri:
                    self.artifact_store.delete_model(version.artifact_uri)

            session.delete(model)
            session.commit()
        finally:
            session.close()

    # ==================== Version management ====================

    def create_version(
        self,
        model_name: str,
        source_uri: str,
        model_file: Optional[BinaryIO] = None,
        filename: str = "model.pkl",
        description: Optional[str] = None,
        run_id: Optional[str] = None,
        experiment_id: Optional[str] = None,
        framework: Optional[str] = None,
        framework_version: Optional[str] = None,
        signature: Optional[Dict[str, Any]] = None,
        metrics: Optional[Dict[str, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, str]] = None,
        created_by: str = "system"
    ) -> ModelVersionInfo:
        """Create a model version, uploading the model file if one is given"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")

            # Next version number = current max + 1 (count() + 1 would reuse a
            # number after a version is deleted). Concurrent calls can still race;
            # a unique constraint on (model_id, version) would catch that.
            latest = (
                session.query(ModelVersion.version)
                .filter(ModelVersion.model_id == model.id)
                .order_by(ModelVersion.version.desc())
                .first()
            )
            new_version = (latest[0] if latest else 0) + 1

            # Upload the model file
            artifact_uri = None
            if model_file:
                artifact_uri = self.artifact_store.upload_model(
                    model_name, new_version, model_file, filename
                )

            # Create the version record
            version = ModelVersion(
                model_id=model.id,
                version=new_version,
                stage=ModelStage.NONE,
                description=description,
                source_uri=source_uri,
                artifact_uri=artifact_uri,
                run_id=run_id,
                experiment_id=experiment_id,
                framework=framework,
                framework_version=framework_version,
                signature=signature,
                metrics=metrics or {},
                params=params or {},
                tags=tags or {},
                created_by=created_by
            )
            session.add(version)
            session.commit()

            return self._to_info(version, model_name)
        finally:
            session.close()

    def get_version(self, model_name: str, version: int) -> Optional[ModelVersionInfo]:
        """Get a specific model version"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return None

            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()

            if not v:
                return None

            return self._to_info(v, model_name)
        finally:
            session.close()

    def get_latest_version(
        self,
        model_name: str,
        stage: Optional[ModelStage] = None
    ) -> Optional[ModelVersionInfo]:
        """Get the newest version, optionally restricted to one stage"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return None

            query = session.query(ModelVersion).filter_by(model_id=model.id)
            if stage:
                query = query.filter_by(stage=stage)

            v = query.order_by(ModelVersion.version.desc()).first()
            if not v:
                return None

            return self._to_info(v, model_name)
        finally:
            session.close()

    def list_versions(
        self,
        model_name: str,
        stage: Optional[ModelStage] = None
    ) -> List[ModelVersionInfo]:
        """List model versions, newest first"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return []

            query = session.query(ModelVersion).filter_by(model_id=model.id)
            if stage:
                query = query.filter_by(stage=stage)

            versions = query.order_by(ModelVersion.version.desc()).all()
            return [self._to_info(v, model_name) for v in versions]
        finally:
            session.close()

    # ==================== Stage transitions ====================

    def transition_stage(
        self,
        model_name: str,
        version: int,
        to_stage: ModelStage,
        requester: str,
        require_approval: bool = True
    ) -> Optional[str]:
        """Transition a model version to another stage.

        A promotion to Production with require_approval=True only creates a
        pending approval and returns its ID; otherwise the stage changes
        immediately and None is returned.
        """
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")

            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()

            if not v:
                raise ValueError(f"Version not found: {model_name}/{version}")

            from_stage = v.stage

            if require_approval and to_stage == ModelStage.PRODUCTION:
                # Approval required: record a pending request instead of switching
                approval = ModelApproval(
                    version_id=v.id,
                    from_stage=from_stage,
                    to_stage=to_stage,
                    requester=requester
                )
                session.add(approval)
                session.commit()
                return approval.id
            else:
                # Switch directly
                v.stage = to_stage

                # Archive the previous Production version(s)
                if to_stage == ModelStage.PRODUCTION:
                    old_prod = session.query(ModelVersion).filter_by(
                        model_id=model.id,
                        stage=ModelStage.PRODUCTION
                    ).filter(ModelVersion.id != v.id).all()

                    for old_v in old_prod:
                        old_v.stage = ModelStage.ARCHIVED

                session.commit()
                return None
        finally:
            session.close()

    def approve_transition(
        self,
        approval_id: str,
        approver: str,
        approved: bool,
        comment: Optional[str] = None
    ):
        """Approve or reject a pending stage transition"""
        session = self.Session()
        try:
            approval = session.query(ModelApproval).filter_by(id=approval_id).first()
            if not approval:
                raise ValueError(f"Approval not found: {approval_id}")

            if approval.status != ApprovalStatus.PENDING:
                raise ValueError(f"Approval already resolved: {approval_id}")

            approval.status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED
            approval.approver = approver
            approval.comment = comment
            approval.resolved_at = datetime.utcnow()

            if approved:
                # Apply the stage transition
                version = approval.model_version
                version.stage = approval.to_stage

                # Archive the previous Production version(s)
                if approval.to_stage == ModelStage.PRODUCTION:
                    old_prod = session.query(ModelVersion).filter_by(
                        model_id=version.model_id,
                        stage=ModelStage.PRODUCTION
                    ).filter(ModelVersion.id != version.id).all()

                    for old_v in old_prod:
                        old_v.stage = ModelStage.ARCHIVED

            session.commit()
        finally:
            session.close()

    # ==================== Lineage tracking ====================

    def add_lineage(
        self,
        model_name: str,
        version: int,
        lineage_type: str,
        source_uri: str,
        source_version: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Add a lineage record to a model version"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                raise ValueError(f"Model not found: {model_name}")

            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()

            if not v:
                raise ValueError(f"Version not found: {model_name}/{version}")

            lineage = ModelLineage(
                version_id=v.id,
                lineage_type=lineage_type,
                source_uri=source_uri,
                source_version=source_version,
                extra_metadata=metadata or {}
            )
            session.add(lineage)
            session.commit()
        finally:
            session.close()

    def get_lineage(
        self,
        model_name: str,
        version: int
    ) -> List[Dict[str, Any]]:
        """Get the lineage records of a model version"""
        session = self.Session()
        try:
            model = session.query(RegisteredModel).filter_by(name=model_name).first()
            if not model:
                return []

            v = session.query(ModelVersion).filter_by(
                model_id=model.id,
                version=version
            ).first()

            if not v:
                return []

            return [
                {
                    "type": item.lineage_type,
                    "source_uri": item.source_uri,
                    "source_version": item.source_version,
                    "metadata": item.extra_metadata,
                    "created_at": item.created_at
                }
                for item in v.lineage
            ]
        finally:
            session.close()


# ==================== Usage example ====================

if __name__ == "__main__":
    # Initialize the artifact store
    artifact_store = ArtifactStore(
        endpoint_url="http://minio:9000",
        bucket="model-registry",
        access_key="admin",
        secret_key="password"
    )

    # Initialize the registry service
    registry = ModelRegistryService(
        database_url="postgresql://user:pass@postgres:5432/model_registry",
        artifact_store=artifact_store
    )

    # Create a model
    registry.create_model(
        name="fraud-detector",
        description="Fraud detection model",
        owner="ml-team",
        tags={"domain": "finance", "type": "classification"}
    )

    # Create a version
    version = registry.create_version(
        model_name="fraud-detector",
        source_uri="mlflow://experiment/run123",
        description="Baseline model with XGBoost",
        framework="xgboost",
        framework_version="1.7.0",
        metrics={"auc": 0.95, "precision": 0.88},
        params={"max_depth": 6, "learning_rate": 0.1},
        created_by="alice"
    )
    print(f"Created version: {version.version}")

    # Record lineage (the training dataset)
    registry.add_lineage(
        "fraud-detector",
        version.version,
        lineage_type="dataset",
        source_uri="s3://data/fraud-dataset-v2",
        source_version="2024-01-01",
        metadata={"rows": 1000000, "features": 50}
    )

    # Request promotion to Production
    approval_id = registry.transition_stage(
        "fraud-detector",
        version.version,
        ModelStage.PRODUCTION,
        requester="alice",
        require_approval=True
    )
    print(f"Approval requested: {approval_id}")

    # Approve the request
    if approval_id:
        registry.approve_transition(
            approval_id,
            approver="bob",
            approved=True,
            comment="Metrics look good"
        )
        print("Approved and promoted to Production")

    # Get the Production version
    prod_version = registry.get_latest_version("fraud-detector", ModelStage.PRODUCTION)
    print(f"Production version: {prod_version.version}")
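Lineage records earn their keep when an upstream input turns out to be bad: you then need every model version built from it. The registry above stores lineage per version, so answering "who consumed this dataset?" needs a reverse lookup. Below is a minimal in-memory sketch, assuming lineage rows shaped like the arguments to `add_lineage` plus the owning model name and version; `LineageRecord` and `LineageIndex` are hypothetical names. In the PostgreSQL-backed service, the same query would more naturally be a filter on `model_lineage.source_uri`.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, List, Optional, Set, Tuple


@dataclass(frozen=True)
class LineageRecord:
    """One lineage edge: model version <- upstream source (hypothetical shape)."""
    model_name: str
    version: int
    lineage_type: str                  # "dataset", "parent_model", "code", ...
    source_uri: str
    source_version: Optional[str] = None


class LineageIndex:
    """Reverse index from an upstream source URI to the model versions that used it."""

    def __init__(self) -> None:
        self._by_source: DefaultDict[str, Set[Tuple[str, int, Optional[str]]]] = defaultdict(set)

    def add(self, record: LineageRecord) -> None:
        self._by_source[record.source_uri].add(
            (record.model_name, record.version, record.source_version)
        )

    def consumers(self, source_uri: str, source_version: Optional[str] = None) -> List[Tuple[str, int]]:
        """Model versions built from source_uri; any source version if none is given."""
        return sorted({
            (name, ver)
            for name, ver, src_ver in self._by_source.get(source_uri, set())
            if source_version is None or src_ver == source_version
        })


if __name__ == "__main__":
    index = LineageIndex()
    index.add(LineageRecord("fraud-detector", 1, "dataset", "s3://data/fraud-dataset-v2", "2024-01-01"))
    index.add(LineageRecord("fraud-detector", 2, "dataset", "s3://data/fraud-dataset-v2", "2024-02-01"))
    index.add(LineageRecord("fraud-detector-lite", 3, "dataset", "s3://data/fraud-dataset-v2", "2024-01-01"))
    print(index.consumers("s3://data/fraud-dataset-v2", "2024-01-01"))
    # [('fraud-detector', 1), ('fraud-detector-lite', 3)]
```

Following `parent_model` edges transitively (a fine-tuned model inherits its parent's datasets) would be a natural extension of this lookup.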

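Summary

Experiment tracking and model registration are core MLOps components:

  1. Experiment tracking

    • Complete records of parameters, metrics, and artifacts
    • Experiment comparison and visualization
    • Reproducible experiments
  2. Model registry

    • Unified model version management
    • Stage transitions (Staging → Production)
    • Approval workflow and lineage tracking
  3. Choosing a tool

    • MLflow: open source, full-featured, active community
    • W&B: powerful visualization, rich collaboration features
    • Self-hosted: custom requirements, enterprise compliance
  4. Best practices

    • Automatic experiment logging (autolog)
    • A standardized model packaging format
    • A thorough approval and governance process
    • End-to-end lineage tracking

The next chapter moves on to MLOps practice and explores how to integrate these components into a complete machine learning platform.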
