04-数据版本管理

概述

在机器学习项目中,数据版本管理是确保实验可复现性的关键环节。本文深入探讨 DVC、LakeFS、Delta Lake 等数据版本管理工具的原理与实践,以及如何构建端到端的数据版本管理系统。

数据版本管理的挑战

为什么需要数据版本管理

┌─────────────────────────────────────────────────────────────────────┐
│                    ML 项目数据管理挑战                               │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌────────────────┐     ┌────────────────┐     ┌────────────────┐  │
│  │   数据体量大    │     │   版本追踪难    │     │   协作困难      │  │
│  │                │     │                │     │                │  │
│  │ • TB/PB 级数据 │     │ • 数据漂移     │     │ • 多人修改     │  │
│  │ • 存储成本高   │     │ • 血缘不清     │     │ • 冲突解决     │  │
│  │ • 传输耗时     │     │ • 难以回溯     │     │ • 权限控制     │  │
│  └────────────────┘     └────────────────┘     └────────────────┘  │
│           │                     │                     │            │
│           └─────────────────────┼─────────────────────┘            │
│                                 ▼                                   │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                   数据版本管理系统                            │   │
│  │                                                             │   │
│  │  • 增量存储:只存储变化部分,节省空间                         │   │
│  │  • 版本追踪:完整记录数据变更历史                             │   │
│  │  • 分支合并:支持并行开发和实验                               │   │
│  │  • 血缘追踪:数据来源和转换链路                               │   │
│  │  • 访问控制:细粒度权限管理                                   │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

数据版本管理工具对比

| 特性 | DVC | LakeFS | Delta Lake | Pachyderm |
|------|-----|--------|------------|-----------|
| 架构模式 | Git 扩展 | Git-like API | 表格式存储 | 数据管道 |
| 存储后端 | S3/GCS/Azure | S3 兼容 | S3/HDFS | S3/GCS |
| 元数据存储 | 文件系统 | PostgreSQL | 事务日志 | etcd |
| 分支支持 | ✓ | ✓ | ✓ | ✓ |
| ACID 事务 | ✗ | ✓ | ✓ | ✓ |
| 数据血缘 | 基础 | ✓ | 基础 | ✓ |
| 学习曲线 | 低 | 中 | 中 | 高 |
| 适用场景 | ML 项目 | 数据湖 | 数据仓库 | 复杂管道 |

DVC 深度实践

DVC 架构与原理

┌─────────────────────────────────────────────────────────────────────┐
│                         DVC 架构                                    │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      Git Repository                          │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ .dvc     │  │ dvc.yaml │  │ dvc.lock │  │ params   │    │   │
│  │  │ files    │  │ pipeline │  │ state    │  │ .yaml    │    │   │
│  │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘    │   │
│  └───────┼─────────────┼─────────────┼─────────────┼───────────┘   │
│          │             │             │             │                │
│          ▼             ▼             ▼             ▼                │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                     DVC Core                                 │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ Hash     │  │ Pipeline │  │ Cache    │  │ Remote   │    │   │
│  │  │ Manager  │  │ Engine   │  │ Manager  │  │ Storage  │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                   Remote Storage                             │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │   S3     │  │   GCS    │  │  Azure   │  │  HDFS    │    │   │
│  │  │          │  │          │  │  Blob    │  │          │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘
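
DVC 的核心机制是基于内容哈希的指针文件:dvc add 会计算数据的 MD5,把真实数据移入本地缓存并同步到远程存储,Git 中只提交体积很小的 .dvc 指针文件。下面是一个最小示意脚本(并非 DVC 源码,文件名与内容均为示例),用来说明上图中 Hash Manager 与 .dvc 文件之间的对应关系:

# dvc_pointer_demo.py - 示意 .dvc 指针文件的基本结构(非 DVC 官方实现,仅用于说明)
import hashlib
from pathlib import Path

import yaml


def compute_md5(path: Path, chunk_size: int = 8 * 1024 * 1024) -> str:
    """按块计算文件内容的 MD5;DVC 以该哈希在 .dvc/cache 中定位真实数据"""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


# 构造一个示例数据文件(实际项目中是 dvc add 的目标文件或目录)
data_file = Path("train.csv")
data_file.write_text("id,label\n1,0\n2,1\n")

pointer = {
    "outs": [
        {
            "md5": compute_md5(data_file),
            "size": data_file.stat().st_size,
            "path": data_file.name,
        }
    ]
}
# 打印出的内容即提交到 Git 的 .dvc 指针文件;目录还会生成以 .dir 结尾的清单哈希
print(yaml.safe_dump(pointer, sort_keys=False))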

DVC 项目配置

# .dvc/config - DVC 配置文件
[core]
    remote = s3remote
    autostage = true

[remote "s3remote"]
    url = s3://ml-data-bucket/dvc-store
    region = us-east-1
    access_key_id = ${AWS_ACCESS_KEY_ID}
    secret_access_key = ${AWS_SECRET_ACCESS_KEY}

[remote "local"]
    url = /data/dvc-cache

# 缓存配置
[cache]
    type = symlink
    local = /data/dvc-local-cache

# 性能配置
[dvc]
    jobs = 8

# dvc.yaml - DVC Pipeline 定义
stages:
  # 数据获取阶段
  fetch_data:
    cmd: python src/data/fetch.py --config configs/data.yaml
    deps:
      - src/data/fetch.py
      - configs/data.yaml
    outs:
      - data/raw/:
          persist: true

  # 数据预处理
  preprocess:
    cmd: python src/data/preprocess.py
    deps:
      - src/data/preprocess.py
      - data/raw/
    params:
      - preprocess.train_ratio
      - preprocess.val_ratio
      - preprocess.seed
    outs:
      - data/processed/train/
      - data/processed/val/
      - data/processed/test/
    plots:
      - reports/data_stats.json:
          x: category
          y: count

  # 特征工程
  feature_engineering:
    cmd: python src/features/build.py
    deps:
      - src/features/build.py
      - data/processed/
    params:
      - features
    outs:
      - data/features/
    metrics:
      - reports/feature_stats.json:
          cache: false

  # 模型训练
  train:
    cmd: python src/models/train.py
    deps:
      - src/models/train.py
      - data/features/
    params:
      - model.type
      - model.hyperparameters
      - train.epochs
      - train.batch_size
    outs:
      - models/
    metrics:
      - reports/metrics.json:
          cache: false
    plots:
      - reports/loss_curve.csv:
          x: epoch
          y: loss

  # 模型评估
  evaluate:
    cmd: python src/models/evaluate.py
    deps:
      - src/models/evaluate.py
      - models/
      - data/processed/test/
    metrics:
      - reports/evaluation.json:
          cache: false
    plots:
      - reports/confusion_matrix.png:
          persist: true

# params.yaml - 参数配置
preprocess:
  train_ratio: 0.7
  val_ratio: 0.15
  seed: 42

features:
  numeric:
    - age
    - income
    - score
  categorical:
    - category
    - region
  text:
    - description
  transformations:
    - name: StandardScaler
      columns: [age, income, score]
    - name: OneHotEncoder
      columns: [category, region]

model:
  type: xgboost
  hyperparameters:
    max_depth: 6
    learning_rate: 0.1
    n_estimators: 100
    objective: binary:logistic

train:
  epochs: 100
  batch_size: 256
  early_stopping: 10
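
dvc.yaml 中各阶段通过 params 字段声明依赖的参数,params.yaml 中对应值发生变化时,dvc repro 会重新执行相关阶段。在阶段脚本内部可以用 DVC 自带的 dvc.api.params_show() 读取参数,避免手写 YAML 解析。下面是 src/models/train.py 的一个节选示意(假设在包含上述 params.yaml 的 DVC 仓库根目录下运行):

# src/models/train.py(节选示意)
import dvc.api

params = dvc.api.params_show()   # 默认读取仓库根目录下的 params.yaml
model_cfg = params["model"]
train_cfg = params["train"]

print(
    f"Training {model_cfg['type']} for {train_cfg['epochs']} epochs, "
    f"batch_size={train_cfg['batch_size']}"
)
# 后续:根据 model_cfg["hyperparameters"] 构建模型并开始训练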

DVC Python API

# dvc_manager.py
"""
DVC 数据版本管理器
提供 Python API 进行数据版本控制
"""

import os
import json
import hashlib
from pathlib import Path
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field
from datetime import datetime
import subprocess
import yaml
import shutil


@dataclass
class DVCFile:
    """DVC 文件信息"""
    path: str
    md5: str
    size: int
    nfiles: Optional[int] = None
    remote: Optional[str] = None


@dataclass
class DVCStage:
    """DVC Pipeline 阶段"""
    name: str
    cmd: str
    deps: List[str] = field(default_factory=list)
    outs: List[str] = field(default_factory=list)
    params: List[str] = field(default_factory=list)
    metrics: List[str] = field(default_factory=list)
    plots: List[str] = field(default_factory=list)


class DVCManager:
    """DVC 管理器"""

    def __init__(self, repo_path: str = "."):
        self.repo_path = Path(repo_path).resolve()
        self.dvc_dir = self.repo_path / ".dvc"
        self._ensure_initialized()

    def _ensure_initialized(self):
        """确保 DVC 已初始化"""
        if not self.dvc_dir.exists():
            self._run_dvc("init")

    def _run_dvc(self, *args, capture_output=True) -> subprocess.CompletedProcess:
        """运行 DVC 命令"""
        cmd = ["dvc"] + list(args)
        result = subprocess.run(
            cmd,
            cwd=self.repo_path,
            capture_output=capture_output,
            text=True
        )
        if result.returncode != 0:
            raise RuntimeError(f"DVC command failed: {result.stderr}")
        return result

    def _run_git(self, *args) -> subprocess.CompletedProcess:
        """运行 Git 命令"""
        cmd = ["git"] + list(args)
        return subprocess.run(
            cmd,
            cwd=self.repo_path,
            capture_output=True,
            text=True
        )

    # ==================== Remote 管理 ====================

    def add_remote(
        self,
        name: str,
        url: str,
        default: bool = False,
        **kwargs
    ):
        """添加远程存储"""
        args = ["remote", "add"]
        if default:
            args.append("-d")
        args.extend([name, url])

        self._run_dvc(*args)

        # 设置额外配置
        for key, value in kwargs.items():
            self._run_dvc("remote", "modify", name, key, str(value))

    def list_remotes(self) -> Dict[str, str]:
        """列出所有远程存储"""
        result = self._run_dvc("remote", "list")
        remotes = {}
        for line in result.stdout.strip().split("\n"):
            if line:
                parts = line.split("\t")
                if len(parts) == 2:
                    remotes[parts[0]] = parts[1]
        return remotes

    # ==================== 数据追踪 ====================

    def add(self, path: Union[str, Path], **kwargs) -> DVCFile:
        """添加文件/目录到 DVC 追踪"""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Path not found: {path}")

        args = ["add", str(path)]

        # 处理选项
        if kwargs.get("external"):
            args.append("--external")
        if kwargs.get("no_commit"):
            args.append("--no-commit")

        self._run_dvc(*args)

        # 读取生成的 .dvc 文件
        dvc_file = path.with_suffix(path.suffix + ".dvc")
        return self._parse_dvc_file(dvc_file)

    def _parse_dvc_file(self, dvc_file: Path) -> DVCFile:
        """解析 .dvc 文件"""
        with open(dvc_file) as f:
            data = yaml.safe_load(f)

        out = data.get("outs", [{}])[0]
        return DVCFile(
            path=out.get("path", ""),
            md5=out.get("md5", ""),
            size=out.get("size", 0),
            nfiles=out.get("nfiles")
        )

    def remove(self, path: Union[str, Path], keep_in_cache: bool = True):
        """从 DVC 追踪中移除文件"""
        args = ["remove", str(path)]
        if not keep_in_cache:
            args.append("--outs")

        self._run_dvc(*args)

    # ==================== Push/Pull ====================

    def push(
        self,
        targets: Optional[List[str]] = None,
        remote: Optional[str] = None,
        jobs: int = 4
    ):
        """推送数据到远程存储"""
        args = ["push", "-j", str(jobs)]

        if remote:
            args.extend(["-r", remote])

        if targets:
            args.extend(targets)

        self._run_dvc(*args)

    def pull(
        self,
        targets: Optional[List[str]] = None,
        remote: Optional[str] = None,
        jobs: int = 4
    ):
        """从远程存储拉取数据"""
        args = ["pull", "-j", str(jobs)]

        if remote:
            args.extend(["-r", remote])

        if targets:
            args.extend(targets)

        self._run_dvc(*args)

    def fetch(
        self,
        targets: Optional[List[str]] = None,
        remote: Optional[str] = None,
        jobs: int = 4
    ):
        """仅获取数据到缓存(不 checkout)"""
        args = ["fetch", "-j", str(jobs)]

        if remote:
            args.extend(["-r", remote])

        if targets:
            args.extend(targets)

        self._run_dvc(*args)

    # ==================== Pipeline 管理 ====================

    def run_pipeline(
        self,
        stages: Optional[List[str]] = None,
        force: bool = False,
        single_item: bool = False
    ):
        """运行 DVC Pipeline"""
        args = ["repro"]

        if force:
            args.append("-f")
        if single_item:
            args.append("-s")

        if stages:
            args.extend(stages)

        self._run_dvc(*args, capture_output=False)

    def get_pipeline_status(self) -> Dict[str, str]:
        """获取 Pipeline 状态"""
        result = self._run_dvc("status")
        # 解析状态输出
        status = {}
        current_stage = None

        for line in result.stdout.split("\n"):
            if line.endswith(":"):
                current_stage = line[:-1]
                status[current_stage] = "modified"
            elif "changed" in line.lower():
                if current_stage:
                    status[current_stage] = "changed"

        return status

    def dag(self, output_format: str = "ascii") -> str:
        """获取 Pipeline DAG"""
        args = ["dag"]
        if output_format == "dot":
            args.append("--dot")
        elif output_format == "mermaid":
            args.append("--mermaid")

        result = self._run_dvc(*args)
        return result.stdout

    # ==================== 指标与参数 ====================

    def get_metrics(self, all_branches: bool = False) -> Dict[str, Any]:
        """获取实验指标"""
        args = ["metrics", "show", "--json"]

        if all_branches:
            args.append("-A")

        result = self._run_dvc(*args)
        return json.loads(result.stdout)

    def diff_metrics(
        self,
        a_rev: str = "HEAD",
        b_rev: Optional[str] = None
    ) -> Dict[str, Any]:
        """比较指标差异"""
        args = ["metrics", "diff", "--json", a_rev]
        if b_rev:
            args.append(b_rev)

        result = self._run_dvc(*args)
        return json.loads(result.stdout)

    def get_params(self) -> Dict[str, Any]:
        """获取当前参数"""
        args = ["params", "show", "--json"]
        result = self._run_dvc(*args)
        return json.loads(result.stdout)

    def diff_params(
        self,
        a_rev: str = "HEAD",
        b_rev: Optional[str] = None
    ) -> Dict[str, Any]:
        """比较参数差异"""
        args = ["params", "diff", "--json", a_rev]
        if b_rev:
            args.append(b_rev)

        result = self._run_dvc(*args)
        return json.loads(result.stdout)

    # ==================== 实验管理 ====================

    def exp_run(
        self,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        queue: bool = False
    ) -> str:
        """运行实验"""
        args = ["exp", "run"]

        if name:
            args.extend(["-n", name])

        if queue:
            args.append("--queue")

        if params:
            for key, value in params.items():
                args.extend(["-S", f"{key}={value}"])

        # 捕获输出以便返回实验运行结果(capture_output=False 时 stdout 为 None)
        result = self._run_dvc(*args)
        return result.stdout

    def exp_list(self, all_commits: bool = False) -> List[Dict[str, Any]]:
        """列出实验"""
        args = ["exp", "show", "--json"]

        if all_commits:
            args.append("-A")

        result = self._run_dvc(*args)
        return json.loads(result.stdout)

    def exp_apply(self, exp_name: str):
        """应用实验"""
        self._run_dvc("exp", "apply", exp_name)

    def exp_branch(self, exp_name: str, branch_name: str):
        """从实验创建分支"""
        self._run_dvc("exp", "branch", exp_name, branch_name)

    def exp_remove(self, exp_names: List[str]):
        """删除实验"""
        args = ["exp", "remove"] + exp_names
        self._run_dvc(*args)

    # ==================== 数据血缘 ====================

    def get_lineage(self, path: str) -> Dict[str, Any]:
        """获取数据血缘"""
        # 解析 dvc.lock 获取依赖关系
        lock_file = self.repo_path / "dvc.lock"
        if not lock_file.exists():
            return {}

        with open(lock_file) as f:
            lock_data = yaml.safe_load(f)

        lineage = {"target": path, "upstream": [], "downstream": []}

        stages = lock_data.get("stages", {})
        for stage_name, stage_data in stages.items():
            outs = [o.get("path") for o in stage_data.get("outs", [])]
            deps = [d.get("path") for d in stage_data.get("deps", [])]

            if path in outs:
                lineage["upstream"].extend([
                    {"stage": stage_name, "deps": deps}
                ])

            if path in deps:
                lineage["downstream"].extend([
                    {"stage": stage_name, "outs": outs}
                ])

        return lineage


class DataVersionSnapshot:
    """数据版本快照管理"""

    def __init__(self, dvc_manager: DVCManager):
        self.dvc = dvc_manager
        self.snapshots_dir = dvc_manager.repo_path / ".dvc" / "snapshots"
        self.snapshots_dir.mkdir(exist_ok=True)

    def create_snapshot(
        self,
        name: str,
        description: str = "",
        paths: Optional[List[str]] = None
    ) -> str:
        """创建数据快照"""
        snapshot_id = f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        snapshot_dir = self.snapshots_dir / snapshot_id
        snapshot_dir.mkdir()

        # 收集 .dvc 文件
        dvc_files = []
        if paths:
            for path in paths:
                dvc_file = Path(path).with_suffix(Path(path).suffix + ".dvc")
                if dvc_file.exists():
                    dvc_files.append(dvc_file)
        else:
            dvc_files = list(self.dvc.repo_path.rglob("*.dvc"))

        # 复制 .dvc 文件
        for dvc_file in dvc_files:
            rel_path = dvc_file.relative_to(self.dvc.repo_path)
            dest = snapshot_dir / rel_path
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(dvc_file, dest)

        # 保存元数据
        metadata = {
            "id": snapshot_id,
            "name": name,
            "description": description,
            "created_at": datetime.now().isoformat(),
            "files": [str(f.relative_to(self.dvc.repo_path)) for f in dvc_files],
            "git_commit": self._get_git_commit()
        }

        with open(snapshot_dir / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)

        return snapshot_id

    def restore_snapshot(self, snapshot_id: str):
        """恢复数据快照"""
        snapshot_dir = self.snapshots_dir / snapshot_id
        if not snapshot_dir.exists():
            raise ValueError(f"Snapshot not found: {snapshot_id}")

        # 恢复 .dvc 文件
        for dvc_file in snapshot_dir.rglob("*.dvc"):
            rel_path = dvc_file.relative_to(snapshot_dir)
            dest = self.dvc.repo_path / rel_path
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(dvc_file, dest)

        # 拉取数据
        self.dvc.pull()

    def list_snapshots(self) -> List[Dict[str, Any]]:
        """列出所有快照"""
        snapshots = []
        for snapshot_dir in self.snapshots_dir.iterdir():
            if snapshot_dir.is_dir():
                metadata_file = snapshot_dir / "metadata.json"
                if metadata_file.exists():
                    with open(metadata_file) as f:
                        snapshots.append(json.load(f))
        return sorted(snapshots, key=lambda x: x["created_at"], reverse=True)

    def _get_git_commit(self) -> str:
        """获取当前 Git commit"""
        result = self.dvc._run_git("rev-parse", "HEAD")
        return result.stdout.strip()


# 使用示例
if __name__ == "__main__":
    # 初始化 DVC 管理器
    dvc = DVCManager("./ml-project")

    # 添加远程存储
    dvc.add_remote(
        "s3",
        "s3://ml-bucket/dvc-store",
        default=True,
        region="us-east-1"
    )

    # 添加数据到 DVC
    dvc.add("data/raw/")
    dvc.add("models/")

    # 推送数据
    dvc.push()

    # 运行 Pipeline
    dvc.run_pipeline()

    # 查看指标
    metrics = dvc.get_metrics()
    print(f"Metrics: {json.dumps(metrics, indent=2)}")

    # 创建快照
    snapshot = DataVersionSnapshot(dvc)
    snapshot_id = snapshot.create_snapshot(
        name="baseline",
        description="Initial baseline model"
    )
    print(f"Created snapshot: {snapshot_id}")

    # 运行实验
    dvc.exp_run(
        params={
            "model.learning_rate": 0.05,
            "train.epochs": 200
        },
        name="lr_experiment"
    )

    # 获取数据血缘
    lineage = dvc.get_lineage("models/model.pkl")
    print(f"Lineage: {json.dumps(lineage, indent=2)}")

LakeFS 数据湖版本管理

LakeFS 架构

┌─────────────────────────────────────────────────────────────────────┐
│                        LakeFS 架构                                  │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      应用层                                  │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ Spark    │  │ Presto   │  │ Hive     │  │ Python   │    │   │
│  │  │          │  │          │  │          │  │ Client   │    │   │
│  │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘    │   │
│  └───────┼─────────────┼─────────────┼─────────────┼───────────┘   │
│          │             │             │             │                │
│          └─────────────┴──────┬──────┴─────────────┘                │
│                               ▼                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                     LakeFS Server                            │   │
│  │  ┌──────────────────────────────────────────────────────┐   │   │
│  │  │                    S3 Gateway                         │   │   │
│  │  │  • S3 API 兼容        • 透明代理                      │   │   │
│  │  │  • 版本化访问         • 分支隔离                      │   │   │
│  │  └──────────────────────────────────────────────────────┘   │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ Branch   │  │ Commit   │  │ Merge    │  │ Garbage  │    │   │
│  │  │ Manager  │  │ Engine   │  │ Engine   │  │ Collector│    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                               │                                     │
│                               ▼                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                     存储层                                   │   │
│  │  ┌──────────────────────┐  ┌──────────────────────┐        │   │
│  │  │   Metadata Store     │  │   Object Storage     │        │   │
│  │  │   (PostgreSQL)       │  │   (S3/GCS/Azure)     │        │   │
│  │  │                      │  │                      │        │   │
│  │  │  • 分支信息          │  │  • 实际数据文件      │        │   │
│  │  │  • 提交历史          │  │  • 增量存储          │        │   │
│  │  │  • 对象引用          │  │  • 数据去重          │        │   │
│  │  └──────────────────────┘  └──────────────────────┘        │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘
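
上图中的 S3 Gateway 意味着任何标准 S3 客户端都可以直接访问 LakeFS:把 endpoint 指向 LakeFS 服务,仓库名作为 bucket,分支或 commit 作为 key 的第一级前缀。下面用 boto3 做一个最小示意(endpoint、凭证与仓库名均为示例值,需事先创建好对应仓库与分支):

# 通过 S3 Gateway 用 boto3 访问 LakeFS(endpoint/凭证/仓库名均为示例)
import boto3

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:8000",   # LakeFS 服务地址
    aws_access_key_id="AKIAIOSFODNN7EXAMPLE",
    aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
)

# Bucket = 仓库名,Key = <分支或 commit>/<对象路径>
s3.put_object(Bucket="ml-data", Key="main/data/raw/sample.csv", Body=b"a,b\n1,2\n")
resp = s3.get_object(Bucket="ml-data", Key="main/data/raw/sample.csv")
print(resp["Body"].read().decode())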

LakeFS Python 客户端

# lakefs_manager.py
"""
LakeFS 数据湖版本管理器
提供分支、提交、合并等 Git-like 操作
"""

import os
from typing import Dict, List, Optional, Any, Iterator, BinaryIO
from dataclasses import dataclass
from datetime import datetime
import lakefs_client
from lakefs_client import Configuration, ApiClient
from lakefs_client.api import (
    repositories_api,
    branches_api,
    commits_api,
    objects_api,
    refs_api,
    actions_api
)
from lakefs_client.model import (
    RepositoryCreation,
    BranchCreation,
    CommitCreation,
    Merge,
    ObjectStatsList,
    RevertCreation
)
import pandas as pd
from io import BytesIO


@dataclass
class LakeFSConfig:
    """LakeFS 配置"""
    endpoint: str
    access_key_id: str
    secret_access_key: str


@dataclass
class ObjectInfo:
    """对象信息"""
    path: str
    size: int
    checksum: str
    mtime: datetime
    content_type: Optional[str] = None


@dataclass
class CommitInfo:
    """提交信息"""
    id: str
    message: str
    committer: str
    creation_date: datetime
    parents: List[str]
    metadata: Dict[str, str]


class LakeFSManager:
    """LakeFS 管理器"""

    def __init__(self, config: LakeFSConfig):
        self.config = config
        self._setup_client()

    def _setup_client(self):
        """设置 LakeFS 客户端"""
        configuration = Configuration(
            host=self.config.endpoint,
            username=self.config.access_key_id,
            password=self.config.secret_access_key
        )

        self.api_client = ApiClient(configuration)
        self.repos_api = repositories_api.RepositoriesApi(self.api_client)
        self.branches_api = branches_api.BranchesApi(self.api_client)
        self.commits_api = commits_api.CommitsApi(self.api_client)
        self.objects_api = objects_api.ObjectsApi(self.api_client)
        self.refs_api = refs_api.RefsApi(self.api_client)
        self.actions_api = actions_api.ActionsApi(self.api_client)

    # ==================== Repository 管理 ====================

    def create_repository(
        self,
        name: str,
        storage_namespace: str,
        default_branch: str = "main"
    ) -> Dict[str, Any]:
        """创建仓库"""
        repo_creation = RepositoryCreation(
            name=name,
            storage_namespace=storage_namespace,
            default_branch=default_branch
        )
        result = self.repos_api.create_repository(repo_creation)
        return {
            "id": result.id,
            "storage_namespace": result.storage_namespace,
            "default_branch": result.default_branch,
            "creation_date": result.creation_date
        }

    def list_repositories(self) -> List[Dict[str, Any]]:
        """列出所有仓库"""
        result = self.repos_api.list_repositories()
        return [
            {
                "id": repo.id,
                "storage_namespace": repo.storage_namespace,
                "default_branch": repo.default_branch,
                "creation_date": repo.creation_date
            }
            for repo in result.results
        ]

    def delete_repository(self, name: str):
        """删除仓库"""
        self.repos_api.delete_repository(name)

    # ==================== Branch 管理 ====================

    def create_branch(
        self,
        repository: str,
        name: str,
        source: str = "main"
    ) -> Dict[str, Any]:
        """创建分支"""
        branch_creation = BranchCreation(name=name, source=source)
        result = self.branches_api.create_branch(repository, branch_creation)
        return {"name": name, "commit_id": result}

    def list_branches(self, repository: str) -> List[Dict[str, Any]]:
        """列出分支"""
        result = self.branches_api.list_branches(repository)
        return [
            {"name": ref.id, "commit_id": ref.commit_id}
            for ref in result.results
        ]

    def delete_branch(self, repository: str, branch: str):
        """删除分支"""
        self.branches_api.delete_branch(repository, branch)

    def get_branch(self, repository: str, branch: str) -> Dict[str, Any]:
        """获取分支信息"""
        result = self.branches_api.get_branch(repository, branch)
        return {"name": branch, "commit_id": result.commit_id}

    # ==================== Commit 管理 ====================

    def commit(
        self,
        repository: str,
        branch: str,
        message: str,
        metadata: Optional[Dict[str, str]] = None
    ) -> CommitInfo:
        """提交更改"""
        commit_creation = CommitCreation(
            message=message,
            metadata=metadata or {}
        )
        result = self.commits_api.commit(repository, branch, commit_creation)
        return CommitInfo(
            id=result.id,
            message=result.message,
            committer=result.committer,
            creation_date=result.creation_date,
            parents=result.parents,
            metadata=result.metadata or {}
        )

    def get_commit(self, repository: str, commit_id: str) -> CommitInfo:
        """获取提交信息"""
        result = self.commits_api.get_commit(repository, commit_id)
        return CommitInfo(
            id=result.id,
            message=result.message,
            committer=result.committer,
            creation_date=result.creation_date,
            parents=result.parents,
            metadata=result.metadata or {}
        )

    def log(
        self,
        repository: str,
        ref: str,
        limit: int = 100
    ) -> List[CommitInfo]:
        """获取提交历史"""
        result = self.refs_api.log_commits(repository, ref, amount=limit)
        return [
            CommitInfo(
                id=commit.id,
                message=commit.message,
                committer=commit.committer,
                creation_date=commit.creation_date,
                parents=commit.parents,
                metadata=commit.metadata or {}
            )
            for commit in result.results
        ]

    # ==================== Merge 操作 ====================

    def merge(
        self,
        repository: str,
        source_ref: str,
        destination_branch: str,
        message: Optional[str] = None,
        strategy: str = "default"
    ) -> Dict[str, Any]:
        """合并分支"""
        merge_request = Merge(
            message=message or f"Merge {source_ref} into {destination_branch}",
            strategy=strategy
        )
        result = self.refs_api.merge_into_branch(
            repository,
            source_ref,
            destination_branch,
            merge=merge_request
        )
        return {"reference": result.reference}

    def diff(
        self,
        repository: str,
        left_ref: str,
        right_ref: str,
        prefix: str = ""
    ) -> List[Dict[str, Any]]:
        """比较两个引用之间的差异"""
        result = self.refs_api.diff_refs(
            repository,
            left_ref,
            right_ref,
            prefix=prefix if prefix else None
        )
        return [
            {
                "path": diff.path,
                "type": diff.type,
                "size_bytes": diff.size_bytes
            }
            for diff in result.results
        ]

    def revert(
        self,
        repository: str,
        branch: str,
        parent_number: int = 1
    ):
        """回退提交"""
        revert = RevertCreation(
            ref=branch,
            parent_number=parent_number
        )
        self.branches_api.revert_branch(repository, branch, revert)

    # ==================== Object 操作 ====================

    def upload_object(
        self,
        repository: str,
        branch: str,
        path: str,
        content: BinaryIO,
        content_type: Optional[str] = None
    ) -> ObjectInfo:
        """上传对象"""
        result = self.objects_api.upload_object(
            repository,
            branch,
            path,
            content=content
        )
        return ObjectInfo(
            path=result.path,
            size=result.size_bytes,
            checksum=result.checksum,
            mtime=result.mtime,
            content_type=result.content_type
        )

    def get_object(
        self,
        repository: str,
        ref: str,
        path: str
    ) -> bytes:
        """获取对象内容"""
        result = self.objects_api.get_object(repository, ref, path)
        return result.read()

    def delete_object(self, repository: str, branch: str, path: str):
        """删除对象"""
        self.objects_api.delete_object(repository, branch, path)

    def list_objects(
        self,
        repository: str,
        ref: str,
        prefix: str = "",
        delimiter: str = "",
        limit: int = 1000
    ) -> List[ObjectInfo]:
        """列出对象"""
        result = self.objects_api.list_objects(
            repository,
            ref,
            prefix=prefix if prefix else None,
            delimiter=delimiter if delimiter else None,
            amount=limit
        )
        return [
            ObjectInfo(
                path=obj.path,
                size=obj.size_bytes,
                checksum=obj.checksum,
                mtime=obj.mtime,
                content_type=obj.content_type
            )
            for obj in result.results
        ]

    def stat_object(
        self,
        repository: str,
        ref: str,
        path: str
    ) -> ObjectInfo:
        """获取对象状态"""
        result = self.objects_api.stat_object(repository, ref, path)
        return ObjectInfo(
            path=result.path,
            size=result.size_bytes,
            checksum=result.checksum,
            mtime=result.mtime,
            content_type=result.content_type
        )

    # ==================== 便捷方法 ====================

    def upload_dataframe(
        self,
        repository: str,
        branch: str,
        path: str,
        df: pd.DataFrame,
        format: str = "parquet"
    ) -> ObjectInfo:
        """上传 DataFrame"""
        buffer = BytesIO()

        if format == "parquet":
            df.to_parquet(buffer, index=False)
            content_type = "application/octet-stream"
        elif format == "csv":
            df.to_csv(buffer, index=False)
            content_type = "text/csv"
        else:
            raise ValueError(f"Unsupported format: {format}")

        buffer.seek(0)
        return self.upload_object(
            repository, branch, path, buffer, content_type
        )

    def read_dataframe(
        self,
        repository: str,
        ref: str,
        path: str,
        format: str = "parquet"
    ) -> pd.DataFrame:
        """读取 DataFrame"""
        content = self.get_object(repository, ref, path)
        buffer = BytesIO(content)

        if format == "parquet":
            return pd.read_parquet(buffer)
        elif format == "csv":
            return pd.read_csv(buffer)
        else:
            raise ValueError(f"Unsupported format: {format}")

    def get_s3_uri(self, repository: str, ref: str, path: str = "") -> str:
        """获取 S3 兼容 URI"""
        return f"s3://{repository}/{ref}/{path}".rstrip("/")


class LakeFSDataPipeline:
    """LakeFS 数据管道"""

    def __init__(self, manager: LakeFSManager, repository: str):
        self.manager = manager
        self.repository = repository

    def create_experiment_branch(
        self,
        experiment_name: str,
        source_branch: str = "main"
    ) -> str:
        """创建实验分支"""
        branch_name = f"experiment/{experiment_name}"
        self.manager.create_branch(
            self.repository,
            branch_name,
            source_branch
        )
        return branch_name

    def ingest_data(
        self,
        branch: str,
        data: pd.DataFrame,
        table_name: str,
        partition_cols: Optional[List[str]] = None
    ) -> List[str]:
        """数据摄入"""
        paths = []

        if partition_cols:
            # 分区写入
            for name, group in data.groupby(partition_cols):
                if isinstance(name, tuple):
                    partition_path = "/".join(
                        f"{col}={val}" for col, val in zip(partition_cols, name)
                    )
                else:
                    partition_path = f"{partition_cols[0]}={name}"

                path = f"data/{table_name}/{partition_path}/data.parquet"
                self.manager.upload_dataframe(
                    self.repository, branch, path, group
                )
                paths.append(path)
        else:
            # 单文件写入
            path = f"data/{table_name}/data.parquet"
            self.manager.upload_dataframe(
                self.repository, branch, path, data
            )
            paths.append(path)

        return paths

    def promote_to_production(
        self,
        source_branch: str,
        validation_passed: bool = True,
        commit_message: str = "Promote to production"
    ):
        """推广到生产"""
        if not validation_passed:
            raise ValueError("Validation must pass before promotion")

        # 先提交源分支
        self.manager.commit(
            self.repository,
            source_branch,
            f"Prepare for production: {commit_message}"
        )

        # 合并到 main
        self.manager.merge(
            self.repository,
            source_branch,
            "main",
            message=commit_message
        )

    def create_snapshot(self, name: str, description: str = "") -> str:
        """创建数据快照(标签)"""
        # 使用分支作为快照(LakeFS 也支持 tag,这里用分支以便附加提交与元数据)
        snapshot_branch = f"snapshot/{name}"
        self.manager.create_branch(
            self.repository,
            snapshot_branch,
            "main"
        )

        # 添加快照元数据
        self.manager.commit(
            self.repository,
            snapshot_branch,
            f"Snapshot: {description}",
            metadata={"snapshot_name": name, "description": description}
        )

        return snapshot_branch


# 使用示例
if __name__ == "__main__":
    # 配置 LakeFS
    config = LakeFSConfig(
        endpoint="http://localhost:8000",
        access_key_id="AKIAIOSFODNN7EXAMPLE",
        secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
    )

    # 创建管理器
    lakefs = LakeFSManager(config)

    # 创建仓库
    lakefs.create_repository(
        name="ml-data",
        storage_namespace="s3://my-bucket/lakefs/ml-data"
    )

    # 创建数据管道
    pipeline = LakeFSDataPipeline(lakefs, "ml-data")

    # 创建实验分支
    branch = pipeline.create_experiment_branch("feature-v2")

    # 摄入数据
    df = pd.DataFrame({
        "feature1": [1, 2, 3],
        "feature2": ["a", "b", "c"],
        "label": [0, 1, 0]
    })
    pipeline.ingest_data(branch, df, "training_data")

    # 提交更改
    lakefs.commit(
        "ml-data",
        branch,
        "Add new training data",
        metadata={"version": "2.0", "rows": "3"}
    )

    # 查看差异
    diff = lakefs.diff("ml-data", "main", branch)
    print(f"Changes: {diff}")

    # 推广到生产
    pipeline.promote_to_production(branch, validation_passed=True)
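
在实际的训练或 ETL 任务中,通常不会用 Python 客户端逐个上传对象,而是让 Spark 等计算引擎经由 S3 Gateway 按"仓库/分支/路径"直接读写,借助分支实现零拷贝的数据隔离。下面是一个 PySpark 侧的配置示意(endpoint 与凭证为示例值,需要 hadoop-aws 依赖;也可以改用 LakeFS 提供的 Hadoop FileSystem 实现):

# PySpark 通过 S3A + LakeFS S3 Gateway 读取指定分支的数据(配置值均为示例)
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("lakefs-s3a-demo")
    .config("spark.hadoop.fs.s3a.endpoint", "http://localhost:8000")
    .config("spark.hadoop.fs.s3a.access.key", "AKIAIOSFODNN7EXAMPLE")
    .config("spark.hadoop.fs.s3a.secret.key", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY")
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .getOrCreate()
)

# 路径格式:s3a://<仓库>/<分支或 commit>/<对象路径>,对应上例中摄入的数据
df = spark.read.parquet("s3a://ml-data/experiment/feature-v2/data/training_data/data.parquet")
df.show()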

Delta Lake 表格版本管理

Delta Lake 与 ML 工作流集成

# delta_lake_manager.py
"""
Delta Lake 数据版本管理
支持 ACID 事务、时间旅行、Schema 演进
"""

from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
from datetime import datetime, timedelta
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, current_timestamp, lit
from delta import DeltaTable, configure_spark_with_delta_pip
import json


@dataclass
class DeltaTableInfo:
    """Delta 表信息"""
    name: str
    location: str
    num_files: int
    size_bytes: int
    num_records: int
    partitions: List[str]
    schema: Dict[str, str]
    version: int
    created_at: datetime
    last_modified: datetime


class DeltaLakeManager:
    """Delta Lake 管理器"""

    def __init__(
        self,
        spark: Optional[SparkSession] = None,
        warehouse_path: str = "/data/delta-warehouse"
    ):
        self.warehouse_path = warehouse_path
        self.spark = spark or self._create_spark_session()

    def _create_spark_session(self) -> SparkSession:
        """创建 Spark Session"""
        builder = SparkSession.builder \
            .appName("DeltaLakeManager") \
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
            .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
            .config("spark.sql.warehouse.dir", self.warehouse_path)

        return configure_spark_with_delta_pip(builder).getOrCreate()

    # ==================== 表管理 ====================

    def create_table(
        self,
        name: str,
        df: DataFrame,
        partition_by: Optional[List[str]] = None,
        mode: str = "errorIfExists"
    ) -> str:
        """创建 Delta 表"""
        path = f"{self.warehouse_path}/{name}"
        writer = df.write.format("delta").mode(mode)

        if partition_by:
            writer = writer.partitionBy(*partition_by)

        writer.save(path)
        return path

    def get_table(self, name: str) -> DeltaTable:
        """获取 Delta 表"""
        path = f"{self.warehouse_path}/{name}"
        return DeltaTable.forPath(self.spark, path)

    def table_exists(self, name: str) -> bool:
        """检查表是否存在"""
        path = f"{self.warehouse_path}/{name}"
        try:
            DeltaTable.forPath(self.spark, path)
            return True
        except Exception:
            return False

    def get_table_info(self, name: str) -> DeltaTableInfo:
        """获取表详细信息"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)

        # 获取详情
        detail = table.detail().collect()[0]
        history = table.history(1).collect()[0]

        return DeltaTableInfo(
            name=name,
            location=detail.location,
            num_files=detail.numFiles,
            size_bytes=detail.sizeInBytes,
            num_records=self.spark.read.format("delta").load(path).count(),
            partitions=detail.partitionColumns,
            schema={f.name: str(f.dataType) for f in self.spark.read.format("delta").load(path).schema.fields},
            version=history.version,
            created_at=detail.createdAt,
            last_modified=detail.lastModified
        )

    def drop_table(self, name: str):
        """删除表"""
        path = f"{self.warehouse_path}/{name}"
        import shutil
        shutil.rmtree(path, ignore_errors=True)

    # ==================== 数据操作 ====================

    def write(
        self,
        name: str,
        df: DataFrame,
        mode: str = "append",
        partition_by: Optional[List[str]] = None,
        merge_schema: bool = False
    ):
        """写入数据"""
        path = f"{self.warehouse_path}/{name}"
        writer = df.write.format("delta").mode(mode)

        if partition_by:
            writer = writer.partitionBy(*partition_by)
        if merge_schema:
            writer = writer.option("mergeSchema", "true")

        writer.save(path)

    def read(
        self,
        name: str,
        version: Optional[int] = None,
        timestamp: Optional[str] = None
    ) -> DataFrame:
        """读取数据(支持时间旅行)"""
        path = f"{self.warehouse_path}/{name}"
        reader = self.spark.read.format("delta")

        if version is not None:
            reader = reader.option("versionAsOf", version)
        elif timestamp:
            reader = reader.option("timestampAsOf", timestamp)

        return reader.load(path)

    def upsert(
        self,
        name: str,
        updates: DataFrame,
        condition: str,
        update_columns: Optional[List[str]] = None
    ):
        """Upsert 操作(合并更新)"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)

        merge_builder = table.alias("target").merge(
            updates.alias("source"),
            condition
        )

        # 匹配时更新
        if update_columns:
            update_set = {col: f"source.{col}" for col in update_columns}
        else:
            update_set = {f.name: f"source.{f.name}" for f in updates.schema.fields}

        merge_builder = merge_builder.whenMatchedUpdate(set=update_set)

        # 不匹配时插入
        merge_builder = merge_builder.whenNotMatchedInsertAll()

        merge_builder.execute()

    def delete(self, name: str, condition: str):
        """删除数据"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)
        table.delete(condition)

    def update(self, name: str, condition: str, set_values: Dict[str, Any]):
        """更新数据"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)
        table.update(condition, set_values)

    # ==================== 版本管理 ====================

    def history(self, name: str, limit: int = 100) -> List[Dict[str, Any]]:
        """获取版本历史"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)
        history_df = table.history(limit)

        return [
            {
                "version": row.version,
                "timestamp": row.timestamp,
                "operation": row.operation,
                "operationParameters": row.operationParameters,
                "operationMetrics": row.operationMetrics,
                "userMetadata": row.userMetadata
            }
            for row in history_df.collect()
        ]

    def restore(self, name: str, version: Optional[int] = None, timestamp: Optional[str] = None):
        """恢复到指定版本"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)

        if version is not None:
            table.restoreToVersion(version)
        elif timestamp:
            table.restoreToTimestamp(timestamp)
        else:
            raise ValueError("Must specify version or timestamp")

    def vacuum(self, name: str, retention_hours: int = 168):
        """清理旧版本文件"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)
        table.vacuum(retention_hours)

    def optimize(self, name: str, z_order_by: Optional[List[str]] = None):
        """优化表"""
        path = f"{self.warehouse_path}/{name}"
        table = DeltaTable.forPath(self.spark, path)

        if z_order_by:
            table.optimize().executeZOrderBy(*z_order_by)
        else:
            table.optimize().executeCompaction()

    # ==================== Schema 演进 ====================

    def add_columns(
        self,
        name: str,
        columns: Dict[str, str],
        default_values: Optional[Dict[str, Any]] = None
    ):
        """添加列"""
        path = f"{self.warehouse_path}/{name}"

        # 读取现有数据
        df = self.spark.read.format("delta").load(path)

        # 添加新列
        for col_name, col_type in columns.items():
            default = default_values.get(col_name) if default_values else None
            df = df.withColumn(col_name, lit(default).cast(col_type))

        # 重写表
        df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").save(path)

    def rename_column(self, name: str, old_name: str, new_name: str):
        """重命名列"""
        path = f"{self.warehouse_path}/{name}"
        df = self.spark.read.format("delta").load(path)
        df = df.withColumnRenamed(old_name, new_name)
        df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").save(path)

    # ==================== Change Data Feed ====================

    def enable_cdf(self, name: str):
        """启用 Change Data Feed"""
        path = f"{self.warehouse_path}/{name}"
        self.spark.sql(f"""
            ALTER TABLE delta.`{path}`
            SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
        """)

    def read_changes(
        self,
        name: str,
        start_version: int,
        end_version: Optional[int] = None
    ) -> DataFrame:
        """读取变更数据"""
        path = f"{self.warehouse_path}/{name}"
        reader = self.spark.read.format("delta") \
            .option("readChangeFeed", "true") \
            .option("startingVersion", start_version)

        if end_version:
            reader = reader.option("endingVersion", end_version)

        return reader.load(path)


class MLDataVersionManager:
    """ML 数据版本管理器"""

    def __init__(self, delta_manager: DeltaLakeManager):
        self.delta = delta_manager

    def create_feature_table(
        self,
        name: str,
        df: DataFrame,
        primary_key: str,
        timestamp_col: str,
        partition_by: Optional[List[str]] = None
    ) -> str:
        """创建特征表"""
        # 添加版本列
        df = df.withColumn("_version_timestamp", current_timestamp())

        path = self.delta.create_table(
            name,
            df,
            partition_by=partition_by
        )

        # 启用 CDF
        self.delta.enable_cdf(name)

        return path

    def update_features(
        self,
        name: str,
        updates: DataFrame,
        primary_key: str
    ):
        """更新特征"""
        updates = updates.withColumn("_version_timestamp", current_timestamp())

        self.delta.upsert(
            name,
            updates,
            condition=f"target.{primary_key} = source.{primary_key}"
        )

    def get_features_at_time(
        self,
        name: str,
        entity_ids: List[Any],
        timestamp: str,
        primary_key: str
    ) -> DataFrame:
        """获取指定时间点的特征(Point-in-Time)"""
        df = self.delta.read(name, timestamp=timestamp)
        return df.filter(col(primary_key).isin(entity_ids))

    def create_training_snapshot(
        self,
        name: str,
        version: Optional[int] = None,
        description: str = ""
    ) -> int:
        """创建训练数据快照"""
        if version is None:
            # 获取当前版本
            history = self.delta.history(name, 1)
            version = history[0]["version"]

        # 记录快照元数据
        snapshot_info = {
            "table": name,
            "version": version,
            "description": description,
            "created_at": datetime.now().isoformat()
        }

        # 保存到元数据表
        snapshot_df = self.delta.spark.createDataFrame([snapshot_info])

        if self.delta.table_exists("_training_snapshots"):
            self.delta.write("_training_snapshots", snapshot_df, mode="append")
        else:
            self.delta.create_table("_training_snapshots", snapshot_df)

        return version

    def load_training_snapshot(self, name: str, version: int) -> DataFrame:
        """加载训练数据快照"""
        return self.delta.read(name, version=version)

    def compare_versions(
        self,
        name: str,
        version1: int,
        version2: int
    ) -> Dict[str, Any]:
        """比较两个版本"""
        df1 = self.delta.read(name, version=version1)
        df2 = self.delta.read(name, version=version2)

        count1 = df1.count()
        count2 = df2.count()

        # Schema 差异
        schema1 = set(f.name for f in df1.schema.fields)
        schema2 = set(f.name for f in df2.schema.fields)

        return {
            "version1": {"version": version1, "row_count": count1, "columns": list(schema1)},
            "version2": {"version": version2, "row_count": count2, "columns": list(schema2)},
            "row_count_diff": count2 - count1,
            "added_columns": list(schema2 - schema1),
            "removed_columns": list(schema1 - schema2)
        }


# 使用示例
if __name__ == "__main__":
    # 创建管理器
    delta = DeltaLakeManager(warehouse_path="/data/ml-warehouse")

    # 创建 ML 数据管理器
    ml_data = MLDataVersionManager(delta)

    # 创建示例数据
    data = [
        ("user1", 25, 50000, 0.8),
        ("user2", 30, 75000, 0.6),
        ("user3", 35, 100000, 0.9)
    ]
    df = delta.spark.createDataFrame(data, ["user_id", "age", "income", "score"])

    # 创建特征表
    ml_data.create_feature_table(
        "user_features",
        df,
        primary_key="user_id",
        timestamp_col="timestamp"
    )

    # 创建训练快照
    version = ml_data.create_training_snapshot(
        "user_features",
        description="Baseline training data"
    )
    print(f"Created snapshot at version {version}")

    # 更新特征
    updates = delta.spark.createDataFrame([
        ("user1", 26, 55000, 0.85)
    ], ["user_id", "age", "income", "score"])

    ml_data.update_features("user_features", updates, "user_id")

    # 比较版本
    diff = ml_data.compare_versions("user_features", version, version + 1)
    print(f"Version diff: {json.dumps(diff, indent=2)}")

    # 加载历史快照进行训练
    training_data = ml_data.load_training_snapshot("user_features", version)
    training_data.show()
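
上面的版本历史与时间旅行同样可以用 SQL 完成,便于在 Notebook 或 BI 工具中直接排查数据版本问题。下面是与 DeltaLakeManager 的 history/read 等价的 SQL 形式示意(沿用上例中的 SparkSession 与表路径):

# 用 SQL 查看版本历史并按版本号时间旅行(沿用上例的 delta.spark 与表路径)
spark = delta.spark
path = "/data/ml-warehouse/user_features"

# 每次写入/合并都会生成一个新版本,DESCRIBE HISTORY 列出全部操作记录
spark.sql(f"DESCRIBE HISTORY delta.`{path}`") \
    .select("version", "timestamp", "operation").show(truncate=False)

# 按版本号读取,等价于 reader.option("versionAsOf", 0)
df_v0 = spark.read.format("delta").option("versionAsOf", 0).load(path)
df_v0.show()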

总结

数据版本管理是 ML 项目可复现性的基石。本文介绍了三种主流方案:

  1. DVC - Git 扩展模式

    • 适合:小团队 ML 项目
    • 优势:学习曲线低,与 Git 深度集成
    • 劣势:大规模数据管理能力有限
  2. LakeFS - 数据湖版本控制

    • 适合:数据湖环境
    • 优势:S3 兼容,零拷贝分支
    • 劣势:需要额外基础设施
  3. Delta Lake - 表格式存储

    • 适合:数据仓库场景
    • 优势:ACID 事务,Schema 演进
    • 劣势:依赖 Spark 生态

选择建议:

  • 小型 ML 项目:DVC
  • 数据湖环境:LakeFS
  • 数据仓库/大规模分析:Delta Lake

下一章节将探讨实验跟踪与模型注册,完成 MLOps 数据管理的闭环。
