模型存储与管理

概述

大规模 AI 训练过程中会产生海量的模型文件,包括训练检查点、最终模型、中间产物等。如何高效、可靠地存储和管理这些模型文件,是 AI 平台的核心能力之一。本章深入讲解模型存储架构、版本管理、检查点优化等关键技术。

模型文件特征分析

模型文件类型

┌─────────────────────────────────────────────────────────────────┐
│                      模型文件类型                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐              │
│  │  检查点文件   │  │  最终模型    │  │  中间产物    │              │
│  │ Checkpoint  │  │ Final Model │  │ Artifacts   │              │
│  ├─────────────┤  ├─────────────┤  ├─────────────┤              │
│  │ 权重参数     │  │ 权重参数     │  │ 优化器状态   │              │
│  │ 优化器状态   │  │ 模型配置     │  │ 梯度累积     │              │
│  │ 训练状态     │  │ Tokenizer   │  │ 激活值缓存   │              │
│  │ 随机种子     │  │ 推理配置     │  │ 通信缓冲     │              │
│  └─────────────┘  └─────────────┘  └─────────────┘              │
│                                                                  │
│  大小特征:                                                        │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ 模型规模        │ 检查点大小      │ 存储需求/天            │   │
│  ├──────────────────────────────────────────────────────────┤   │
│  │ 7B (LLaMA)     │ 14-28 GB       │ 200-500 GB            │   │
│  │ 13B            │ 26-52 GB       │ 400-800 GB            │   │
│  │ 65B            │ 130-260 GB     │ 2-4 TB                │   │
│  │ 175B (GPT-3)   │ 350-700 GB     │ 5-10 TB               │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
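
按上表口径粗略估算:权重部分的大小 ≈ 参数量 × 每参数字节数(FP16 约 2 字节、FP32 约 4 字节),7B 模型即约 14-28 GB;若同时保存优化器状态,实际检查点还会更大。下面用一个极简函数演示这笔算术(假设性示例,仅作说明):

// size_estimate.go(示意)
package storage

// EstimateWeightSizeGB 估算纯权重文件大小(GB,按 1 GB = 1e9 字节)
// dtypeBytes:FP16 取 2,FP32 取 4
func EstimateWeightSizeGB(params, dtypeBytes int64) float64 {
    return float64(params*dtypeBytes) / 1e9
}

// 例如 EstimateWeightSizeGB(7_000_000_000, 2) = 14.0,对应上表 7B 行的下限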

访问模式分析

// model_access_pattern.go
package storage

import (
    "time"
)

// AccessPattern 模型文件访问模式
type AccessPattern struct {
    // 检查点写入模式
    CheckpointWrite CheckpointWritePattern
    // 检查点读取模式(恢复训练)
    CheckpointRead CheckpointReadPattern
    // 模型导出模式
    ModelExport ModelExportPattern
    // 模型加载模式(推理)
    ModelLoad ModelLoadPattern
}

// CheckpointWritePattern 检查点写入特征
type CheckpointWritePattern struct {
    // 写入频率:每 N 步保存一次
    SaveInterval int
    // 典型间隔:1000-10000 步
    TypicalSteps int
    // 写入大小:与模型规模成正比
    WriteSize int64
    // 并发写入:分布式训练时所有 rank 同时写
    ConcurrentWriters int
    // 写入带宽需求:高峰期需要高吞吐
    RequiredBandwidth int64 // bytes/s
    // 写入模式:顺序写为主
    Sequential bool
}

// CheckpointReadPattern 检查点读取特征
type CheckpointReadPattern struct {
    // 读取频率:训练恢复时
    Frequency string // "rare" - 仅故障恢复或继续训练
    // 读取大小:完整检查点
    ReadSize int64
    // 并发读取:所有 rank 同时读
    ConcurrentReaders int
    // 读取带宽需求
    RequiredBandwidth int64
    // 读取模式:顺序读
    Sequential bool
}

// ModelExportPattern 模型导出特征
type ModelExportPattern struct {
    // 导出格式
    Formats []string // ["pytorch", "safetensors", "onnx", "tensorrt"]
    // 导出频率:训练结束或里程碑
    Frequency string // "milestone"
    // 处理类型:可能需要格式转换
    RequiresConversion bool
}

// ModelLoadPattern 模型加载特征(推理场景)
type ModelLoadPattern struct {
    // 加载频率:服务启动、扩容
    Frequency string // "on_demand"
    // 加载速度要求:影响服务启动时间
    LatencyRequirement time.Duration
    // 并发加载:多实例同时加载
    ConcurrentLoaders int
    // 部分加载:分层加载支持
    SupportsPartialLoad bool
}

// CalculateStorageRequirements 计算存储需求
func CalculateStorageRequirements(
    modelParams int64, // 模型参数量
    trainingDays int,
    checkpointInterval int, // 小时
    keepCheckpoints int, // 保留数量
) StorageRequirements {

    // 每个参数占用字节(FP16 = 2, FP32 = 4, 混合精度约 6)
    bytesPerParam := int64(6)

    // 检查点大小 = 参数 × 每参数字节 × 2(参数+优化器状态)
    checkpointSize := modelParams * bytesPerParam * 2

    // 每天检查点数
    checkpointsPerDay := 24 / checkpointInterval

    // 总存储需求
    totalCheckpointStorage := checkpointSize * int64(keepCheckpoints)

    // 峰值写入带宽(假设检查点写入时间 < 10分钟)
    peakWriteBandwidth := checkpointSize / (10 * 60) // bytes/s

    return StorageRequirements{
        CheckpointSize:       checkpointSize,
        TotalStorage:         totalCheckpointStorage,
        PeakWriteBandwidth:   peakWriteBandwidth,
        CheckpointsPerDay:    checkpointsPerDay,
        RecommendedReplicas:  3, // 三副本保证可靠性
    }
}

type StorageRequirements struct {
    CheckpointSize      int64
    TotalStorage        int64
    PeakWriteBandwidth  int64
    CheckpointsPerDay   int
    RecommendedReplicas int
}
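
下面是调用 CalculateStorageRequirements 的一个用法示意(假设性示例:7B 模型、每 4 小时保存一次、保留 10 个检查点):

// storage_requirements_example.go(示意)
package storage

import "fmt"

func ExampleCalculateStorageRequirements() {
    req := CalculateStorageRequirements(
        7_000_000_000, // 模型参数量:7B
        30,            // 训练天数
        4,             // 检查点间隔:每 4 小时
        10,            // 保留 10 个检查点
    )

    fmt.Printf("单个检查点大小: %.1f GB\n", float64(req.CheckpointSize)/1e9)
    fmt.Printf("总存储需求:     %.1f GB\n", float64(req.TotalStorage)/1e9)
    fmt.Printf("峰值写入带宽:   %.1f GB/s\n", float64(req.PeakWriteBandwidth)/1e9)
    fmt.Printf("每天检查点数:   %d\n", req.CheckpointsPerDay)
}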

分布式存储架构

存储系统选型

┌─────────────────────────────────────────────────────────────────┐
│                    AI 训练存储架构                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    存储层次架构                           │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
│         ┌─────────────┐                                         │
│         │  本地 NVMe   │  ← 最快,容量有限                       │
│         │  Cache      │    用于:激活值缓存、临时文件             │
│         └──────┬──────┘                                         │
│                │                                                 │
│         ┌──────▼──────┐                                         │
│         │  共享 NVMe   │  ← 高性能,跨节点共享                   │
│         │  over Fabric │    用于:检查点缓存、模型加载            │
│         └──────┬──────┘                                         │
│                │                                                 │
│         ┌──────▼──────┐                                         │
│         │  并行文件    │  ← 高吞吐,大容量                       │
│         │  系统        │    用于:检查点持久化、数据集             │
│         │  (Lustre/    │                                        │
│         │   GPFS)      │                                        │
│         └──────┬──────┘                                         │
│                │                                                 │
│         ┌──────▼──────┐                                         │
│         │  对象存储    │  ← 最大容量,成本最低                    │
│         │  (S3/OSS)   │    用于:模型归档、长期保存               │
│         └─────────────┘                                         │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

并行文件系统集成

// parallel_fs.go
package storage

import (
    "context"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "sync"
    "syscall"
)

// ParallelFileSystem 并行文件系统接口
type ParallelFileSystem interface {
    // 并行写入
    ParallelWrite(ctx context.Context, path string, data []byte, opts WriteOptions) error
    // 并行读取
    ParallelRead(ctx context.Context, path string, opts ReadOptions) ([]byte, error)
    // 条带化配置
    SetStriping(path string, stripeCount int, stripeSize int64) error
    // 获取文件布局
    GetLayout(path string) (*FileLayout, error)
}

// LustreFS Lustre 文件系统实现
type LustreFS struct {
    mountPoint   string
    defaultStripeCount int
    defaultStripeSize  int64
    maxStripeCount     int
}

// WriteOptions 写入选项
type WriteOptions struct {
    StripeCount   int   // 条带数量
    StripeSize    int64 // 条带大小
    Parallel      bool  // 是否并行写
    DirectIO      bool  // 是否使用 Direct I/O
    Compression   string // 压缩算法
}

// ReadOptions 读取选项
type ReadOptions struct {
    Parallel    bool
    DirectIO    bool
    ReadAhead   int64 // 预读大小
}

// FileLayout 文件布局信息
type FileLayout struct {
    StripeCount  int
    StripeSize   int64
    StripeOffset int64
    OSTs         []int // Object Storage Targets
}

func NewLustreFS(mountPoint string) *LustreFS {
    return &LustreFS{
        mountPoint:         mountPoint,
        defaultStripeCount: 4,
        defaultStripeSize:  1 << 20, // 1MB
        maxStripeCount:     128,
    }
}

// SetStriping 设置文件条带化
func (l *LustreFS) SetStriping(path string, stripeCount int, stripeSize int64) error {
    fullPath := filepath.Join(l.mountPoint, path)

    // 确保目录存在
    dir := filepath.Dir(fullPath)
    if err := os.MkdirAll(dir, 0755); err != nil {
        return fmt.Errorf("create directory: %w", err)
    }

    // 使用 lfs setstripe 命令设置条带
    // 实际实现中使用 Lustre API
    return l.setStripeViaAPI(fullPath, stripeCount, stripeSize)
}

func (l *LustreFS) setStripeViaAPI(path string, stripeCount int, stripeSize int64) error {
    // Lustre ioctl 调用设置条带
    // LOV_USER_MAGIC_V3 结构
    type lovUserMd struct {
        Magic       uint32
        Pattern     uint32
        ObjectID    uint64
        ObjectSeq   uint64
        StripeSize  uint32
        StripeCount uint16
        StripeOffset uint16
    }

    // 这里简化处理,实际需要调用 Lustre API
    return nil
}

// ParallelWrite 并行写入大文件
func (l *LustreFS) ParallelWrite(ctx context.Context, path string, data []byte, opts WriteOptions) error {
    fullPath := filepath.Join(l.mountPoint, path)

    // 设置条带化
    if opts.StripeCount > 0 {
        if err := l.SetStriping(path, opts.StripeCount, opts.StripeSize); err != nil {
            return err
        }
    }

    // 打开文件
    flags := os.O_CREATE | os.O_WRONLY | os.O_TRUNC
    if opts.DirectIO {
        flags |= syscall.O_DIRECT
    }

    file, err := os.OpenFile(fullPath, flags, 0644)
    if err != nil {
        return fmt.Errorf("open file: %w", err)
    }
    defer file.Close()

    if !opts.Parallel {
        // 顺序写入
        _, err = file.Write(data)
        return err
    }

    // 并行写入
    return l.parallelWriteChunks(ctx, file, data, opts)
}

func (l *LustreFS) parallelWriteChunks(ctx context.Context, file *os.File, data []byte, opts WriteOptions) error {
    chunkSize := opts.StripeSize
    if chunkSize == 0 {
        chunkSize = l.defaultStripeSize
    }

    numChunks := (int64(len(data)) + chunkSize - 1) / chunkSize

    var wg sync.WaitGroup
    errCh := make(chan error, numChunks)

    for i := int64(0); i < numChunks; i++ {
        start := i * chunkSize
        end := start + chunkSize
        if end > int64(len(data)) {
            end = int64(len(data))
        }

        wg.Add(1)
        go func(offset int64, chunk []byte) {
            defer wg.Done()

            select {
            case <-ctx.Done():
                errCh <- ctx.Err()
                return
            default:
            }

            _, err := file.WriteAt(chunk, offset)
            if err != nil {
                errCh <- fmt.Errorf("write at %d: %w", offset, err)
            }
        }(start, data[start:end])
    }

    wg.Wait()
    close(errCh)

    for err := range errCh {
        if err != nil {
            return err
        }
    }

    return file.Sync()
}

// ParallelRead 并行读取大文件
func (l *LustreFS) ParallelRead(ctx context.Context, path string, opts ReadOptions) ([]byte, error) {
    fullPath := filepath.Join(l.mountPoint, path)

    // 获取文件信息
    info, err := os.Stat(fullPath)
    if err != nil {
        return nil, fmt.Errorf("stat file: %w", err)
    }

    // 打开文件
    flags := os.O_RDONLY
    if opts.DirectIO {
        flags |= syscall.O_DIRECT
    }

    file, err := os.OpenFile(fullPath, flags, 0)
    if err != nil {
        return nil, fmt.Errorf("open file: %w", err)
    }
    defer file.Close()

    // 获取文件布局以优化读取
    layout, _ := l.GetLayout(path)

    data := make([]byte, info.Size())

    if !opts.Parallel {
        _, err = io.ReadFull(file, data)
        return data, err
    }

    // 根据条带布局并行读取
    return l.parallelReadChunks(ctx, file, data, layout, opts)
}

func (l *LustreFS) parallelReadChunks(
    ctx context.Context,
    file *os.File,
    data []byte,
    layout *FileLayout,
    opts ReadOptions,
) ([]byte, error) {
    chunkSize := int64(1 << 20) // 1MB default
    if layout != nil && layout.StripeSize > 0 {
        chunkSize = layout.StripeSize
    }

    numChunks := (int64(len(data)) + chunkSize - 1) / chunkSize

    var wg sync.WaitGroup
    errCh := make(chan error, numChunks)

    // 控制并发度
    sem := make(chan struct{}, 16)

    for i := int64(0); i < numChunks; i++ {
        start := i * chunkSize
        end := start + chunkSize
        if end > int64(len(data)) {
            end = int64(len(data))
        }

        wg.Add(1)
        go func(offset, length int64) {
            defer wg.Done()

            sem <- struct{}{}
            defer func() { <-sem }()

            select {
            case <-ctx.Done():
                errCh <- ctx.Err()
                return
            default:
            }

            _, err := file.ReadAt(data[offset:offset+length], offset)
            if err != nil && err != io.EOF {
                errCh <- fmt.Errorf("read at %d: %w", offset, err)
            }
        }(start, end-start)
    }

    wg.Wait()
    close(errCh)

    for err := range errCh {
        if err != nil {
            return nil, err
        }
    }

    return data, nil
}

// GetLayout 获取文件条带布局
func (l *LustreFS) GetLayout(path string) (*FileLayout, error) {
    // 通过 Lustre API 获取文件布局
    // 简化实现
    return &FileLayout{
        StripeCount:  l.defaultStripeCount,
        StripeSize:   l.defaultStripeSize,
        StripeOffset: 0,
        OSTs:         []int{0, 1, 2, 3},
    }, nil
}

// OptimalStripeConfig 计算最优条带配置
func (l *LustreFS) OptimalStripeConfig(fileSize int64, accessPattern string) (stripeCount int, stripeSize int64) {
    switch accessPattern {
    case "checkpoint_write":
        // 检查点写入:高并发,大文件
        // 使用更多条带以提高吞吐
        if fileSize > 10<<30 { // > 10GB
            return min(l.maxStripeCount, 64), 4 << 20 // 64 条带,4MB
        } else if fileSize > 1<<30 { // > 1GB
            return 32, 2 << 20 // 32 条带,2MB
        }
        return 16, 1 << 20 // 16 条带,1MB

    case "checkpoint_read":
        // 检查点读取:顺序读为主
        if fileSize > 10<<30 {
            return 32, 4 << 20
        }
        return 16, 2 << 20

    case "model_serve":
        // 模型加载:追求最快加载
        return min(l.maxStripeCount, 128), 1 << 20

    default:
        return l.defaultStripeCount, l.defaultStripeSize
    }
}

func min(a, b int) int {
    if a < b {
        return a
    }
    return b
}
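
下面是把 OptimalStripeConfig 与 ParallelWrite 结合起来写入检查点文件的一个用法示意(假设性示例,挂载点与路径均为虚构):

// lustre_usage.go(示意)
package storage

import "context"

// WriteCheckpointToLustre 按"检查点写入"访问模式选择条带参数后并行写入
func WriteCheckpointToLustre(ctx context.Context, data []byte) error {
    lfs := NewLustreFS("/mnt/lustre") // 假设的挂载点

    // 根据文件大小与访问模式计算条带配置
    stripeCount, stripeSize := lfs.OptimalStripeConfig(int64(len(data)), "checkpoint_write")

    return lfs.ParallelWrite(ctx, "jobs/job-1/step-1000/model.pt", data, WriteOptions{
        StripeCount: stripeCount,
        StripeSize:  stripeSize,
        Parallel:    true,
    })
}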

对象存储集成

// object_storage.go
package storage

import (
    "bytes"
    "context"
    "crypto/md5"
    "encoding/hex"
    "fmt"
    "io"
    "sync"
    "time"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/s3"
    "github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// ObjectStorage 对象存储接口
type ObjectStorage interface {
    // 上传对象
    Upload(ctx context.Context, bucket, key string, data []byte) error
    // 分片上传大对象
    MultipartUpload(ctx context.Context, bucket, key string, reader io.Reader, size int64) error
    // 下载对象
    Download(ctx context.Context, bucket, key string) ([]byte, error)
    // 并行下载大对象
    ParallelDownload(ctx context.Context, bucket, key string) ([]byte, error)
    // 列出对象
    List(ctx context.Context, bucket, prefix string) ([]ObjectInfo, error)
    // 删除对象
    Delete(ctx context.Context, bucket, key string) error
    // 复制对象
    Copy(ctx context.Context, srcBucket, srcKey, dstBucket, dstKey string) error
}

// ObjectInfo 对象信息
type ObjectInfo struct {
    Key          string
    Size         int64
    LastModified time.Time
    ETag         string
    StorageClass string
}

// S3Storage S3 对象存储实现
type S3Storage struct {
    client      *s3.Client
    partSize    int64
    concurrency int
}

func NewS3Storage(client *s3.Client) *S3Storage {
    return &S3Storage{
        client:      client,
        partSize:    100 << 20, // 100MB per part
        concurrency: 16,
    }
}

// MultipartUpload 分片上传大文件
func (s *S3Storage) MultipartUpload(
    ctx context.Context,
    bucket, key string,
    reader io.Reader,
    size int64,
) error {
    // 创建分片上传
    createResp, err := s.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
        Bucket:       aws.String(bucket),
        Key:          aws.String(key),
        StorageClass: types.StorageClassStandardIa, // 检查点使用低频存储
    })
    if err != nil {
        return fmt.Errorf("create multipart upload: %w", err)
    }

    uploadID := *createResp.UploadId

    // 计算分片
    numParts := (size + s.partSize - 1) / s.partSize
    parts := make([]types.CompletedPart, numParts)

    var wg sync.WaitGroup
    errCh := make(chan error, numParts)
    sem := make(chan struct{}, s.concurrency)

    var mu sync.Mutex
    partNum := int32(1)

    for {
        // 读取一个分片
        buf := make([]byte, s.partSize)
        n, readErr := io.ReadFull(reader, buf)
        if n == 0 {
            break
        }

        currentPart := partNum
        partNum++
        data := buf[:n]

        wg.Add(1)
        go func(pn int32, d []byte) {
            defer wg.Done()
            sem <- struct{}{}
            defer func() { <-sem }()

            // 计算 MD5(S3 的 Content-MD5 头要求 base64 编码的摘要)
            hash := md5.Sum(d)
            md5Str := base64.StdEncoding.EncodeToString(hash[:])

            // 上传分片
            uploadResp, err := s.client.UploadPart(ctx, &s3.UploadPartInput{
                Bucket:     aws.String(bucket),
                Key:        aws.String(key),
                UploadId:   aws.String(uploadID),
                PartNumber: aws.Int32(pn),
                Body:       bytes.NewReader(d),
                ContentMD5: aws.String(md5Str),
            })
            if err != nil {
                errCh <- fmt.Errorf("upload part %d: %w", pn, err)
                return
            }

            mu.Lock()
            parts[pn-1] = types.CompletedPart{
                ETag:       uploadResp.ETag,
                PartNumber: aws.Int32(pn),
            }
            mu.Unlock()
        }(currentPart, data)

        if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
            break
        }
        if readErr != nil {
            // 取消上传
            s.abortMultipartUpload(ctx, bucket, key, uploadID)
            return fmt.Errorf("read data: %w", readErr)
        }
    }

    wg.Wait()
    close(errCh)

    // 检查错误
    for err := range errCh {
        if err != nil {
            s.abortMultipartUpload(ctx, bucket, key, uploadID)
            return err
        }
    }

    // 完成上传
    _, err = s.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
        Bucket:   aws.String(bucket),
        Key:      aws.String(key),
        UploadId: aws.String(uploadID),
        MultipartUpload: &types.CompletedMultipartUpload{
            Parts: parts[:partNum-1],
        },
    })
    if err != nil {
        return fmt.Errorf("complete multipart upload: %w", err)
    }

    return nil
}

func (s *S3Storage) abortMultipartUpload(ctx context.Context, bucket, key, uploadID string) {
    s.client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
        Bucket:   aws.String(bucket),
        Key:      aws.String(key),
        UploadId: aws.String(uploadID),
    })
}

// ParallelDownload 并行下载大文件
func (s *S3Storage) ParallelDownload(ctx context.Context, bucket, key string) ([]byte, error) {
    // 获取对象大小
    headResp, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
        Bucket: aws.String(bucket),
        Key:    aws.String(key),
    })
    if err != nil {
        return nil, fmt.Errorf("head object: %w", err)
    }

    size := *headResp.ContentLength
    if size <= s.partSize {
        // 小文件直接下载
        return s.Download(ctx, bucket, key)
    }

    // 并行下载
    data := make([]byte, size)
    numParts := (size + s.partSize - 1) / s.partSize

    var wg sync.WaitGroup
    errCh := make(chan error, numParts)
    sem := make(chan struct{}, s.concurrency)

    for i := int64(0); i < numParts; i++ {
        start := i * s.partSize
        end := start + s.partSize - 1
        if end >= size {
            end = size - 1
        }

        wg.Add(1)
        go func(s3Start, s3End int64) {
            defer wg.Done()
            sem <- struct{}{}
            defer func() { <-sem }()

            rangeStr := fmt.Sprintf("bytes=%d-%d", s3Start, s3End)

            resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
                Bucket: aws.String(bucket),
                Key:    aws.String(key),
                Range:  aws.String(rangeStr),
            })
            if err != nil {
                errCh <- fmt.Errorf("get range %s: %w", rangeStr, err)
                return
            }
            defer resp.Body.Close()

            _, err = io.ReadFull(resp.Body, data[s3Start:s3End+1])
            if err != nil {
                errCh <- fmt.Errorf("read range %s: %w", rangeStr, err)
            }
        }(start, end)
    }

    wg.Wait()
    close(errCh)

    for err := range errCh {
        if err != nil {
            return nil, err
        }
    }

    return data, nil
}

// Download 下载对象
func (s *S3Storage) Download(ctx context.Context, bucket, key string) ([]byte, error) {
    resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
        Bucket: aws.String(bucket),
        Key:    aws.String(key),
    })
    if err != nil {
        return nil, fmt.Errorf("get object: %w", err)
    }
    defer resp.Body.Close()

    return io.ReadAll(resp.Body)
}

// Upload 上传对象
func (s *S3Storage) Upload(ctx context.Context, bucket, key string, data []byte) error {
    _, err := s.client.PutObject(ctx, &s3.PutObjectInput{
        Bucket: aws.String(bucket),
        Key:    aws.String(key),
        Body:   bytes.NewReader(data),
    })
    return err
}

// List 列出对象
func (s *S3Storage) List(ctx context.Context, bucket, prefix string) ([]ObjectInfo, error) {
    var objects []ObjectInfo

    paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
        Bucket: aws.String(bucket),
        Prefix: aws.String(prefix),
    })

    for paginator.HasMorePages() {
        page, err := paginator.NextPage(ctx)
        if err != nil {
            return nil, fmt.Errorf("list objects: %w", err)
        }

        for _, obj := range page.Contents {
            objects = append(objects, ObjectInfo{
                Key:          *obj.Key,
                Size:         *obj.Size,
                LastModified: *obj.LastModified,
                ETag:         *obj.ETag,
                StorageClass: string(obj.StorageClass),
            })
        }
    }

    return objects, nil
}

// Delete 删除对象
func (s *S3Storage) Delete(ctx context.Context, bucket, key string) error {
    _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
        Bucket: aws.String(bucket),
        Key:    aws.String(key),
    })
    return err
}

// Copy 复制对象
func (s *S3Storage) Copy(ctx context.Context, srcBucket, srcKey, dstBucket, dstKey string) error {
    _, err := s.client.CopyObject(ctx, &s3.CopyObjectInput{
        Bucket:     aws.String(dstBucket),
        Key:        aws.String(dstKey),
        CopySource: aws.String(fmt.Sprintf("%s/%s", srcBucket, srcKey)),
    })
    return err
}
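
下面是把本地检查点文件分片上传到对象存储归档的一个用法示意(假设性示例,bucket 与路径均为虚构):

// s3_usage.go(示意)
package storage

import (
    "context"
    "fmt"
    "os"
)

// ArchiveCheckpointToS3 将本地检查点文件归档到对象存储
func ArchiveCheckpointToS3(ctx context.Context, store *S3Storage, localPath, bucket, key string) error {
    f, err := os.Open(localPath)
    if err != nil {
        return fmt.Errorf("open checkpoint: %w", err)
    }
    defer f.Close()

    info, err := f.Stat()
    if err != nil {
        return fmt.Errorf("stat checkpoint: %w", err)
    }

    // 大文件走分片上传;小文件可直接调用 Upload
    return store.MultipartUpload(ctx, bucket, key, f, info.Size())
}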

检查点管理系统

检查点生命周期

┌─────────────────────────────────────────────────────────────────┐
│                    检查点生命周期管理                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  创建阶段:                                                       │
│  ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐          │
│  │ 触发    │───▶│ 序列化  │───▶│ 压缩    │───▶│ 上传    │          │
│  │ 保存    │    │ 状态    │    │ (可选)  │    │ 存储    │          │
│  └────────┘    └────────┘    └────────┘    └────────┘          │
│                                                                  │
│  存储阶段:                                                       │
│  ┌────────────────────────────────────────────────────────┐    │
│  │                                                         │    │
│  │   Hot Storage    Warm Storage    Cold Storage          │    │
│  │   (最新 3 个)     (近期 10 个)    (里程碑)              │    │
│  │   ┌───┐          ┌───┐          ┌───┐                 │    │
│  │   │NVMe│    ───▶  │PFS│    ───▶  │S3 │                 │    │
│  │   └───┘          └───┘          └───┘                 │    │
│  │                                                         │    │
│  └────────────────────────────────────────────────────────┘    │
│                                                                  │
│  恢复阶段:                                                       │
│  ┌────────┐    ┌────────┐    ┌────────┐    ┌────────┐          │
│  │ 定位    │───▶│ 下载    │───▶│ 解压    │───▶│ 反序列  │          │
│  │ 检查点  │    │ 数据    │    │ (可选)  │    │ 化      │          │
│  └────────┘    └────────┘    └────────┘    └────────┘          │
│                                                                  │
│  清理阶段:                                                       │
│  ┌────────┐    ┌────────┐    ┌────────┐                        │
│  │ 策略    │───▶│ 标记    │───▶│ 删除    │                        │
│  │ 评估    │    │ 过期    │    │ 清理    │                        │
│  └────────┘    └────────┘    └────────┘                        │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

检查点管理器实现

// checkpoint_manager.go
package checkpoint

import (
    "context"
    "encoding/json"
    "fmt"
    "path/filepath"
    "sort"
    "sync"
    "time"
)

// CheckpointManager 检查点管理器
type CheckpointManager struct {
    // 存储后端
    hotStorage  Storage // 本地/NVMe
    warmStorage Storage // 并行文件系统
    coldStorage Storage // 对象存储

    // 元数据存储
    metaStore MetadataStore

    // 配置
    config CheckpointConfig

    // 运行时状态
    mu             sync.RWMutex
    activeUploads  map[string]*UploadTask

    // 指标
    metrics *CheckpointMetrics
}

// CheckpointConfig 检查点配置
type CheckpointConfig struct {
    // 保存策略
    SaveInterval        int           // 每 N 步保存
    SaveIntervalMinutes int           // 或每 N 分钟保存

    // 保留策略
    HotRetention  int // Hot 存储保留数量
    WarmRetention int // Warm 存储保留数量
    ColdRetention int // Cold 存储保留数量(里程碑)

    // 存储路径
    HotPath  string
    WarmPath string
    ColdBucket string

    // 性能配置
    AsyncUpload     bool
    CompressionAlgo string // "none", "lz4", "zstd"
    ParallelIO      bool

    // 数据完整性
    VerifyChecksum  bool

    // 生命周期
    MilestoneSteps []int // 标记为里程碑的步数
}

// Checkpoint 检查点元数据
type Checkpoint struct {
    ID        string    `json:"id"`
    JobID     string    `json:"job_id"`
    Step      int64     `json:"step"`
    Epoch     int       `json:"epoch"`
    CreatedAt time.Time `json:"created_at"`

    // 存储位置
    StorageTier StorageTier `json:"storage_tier"`
    Path        string      `json:"path"`

    // 文件信息
    Files       []CheckpointFile `json:"files"`
    TotalSize   int64            `json:"total_size"`
    Checksum    string           `json:"checksum"`

    // 训练状态
    TrainingState TrainingState `json:"training_state"`

    // 标签
    IsMilestone bool              `json:"is_milestone"`
    Tags        map[string]string `json:"tags"`
}

// CheckpointFile 检查点文件
type CheckpointFile struct {
    Name     string `json:"name"`
    Size     int64  `json:"size"`
    Checksum string `json:"checksum"`
    Type     string `json:"type"` // "model", "optimizer", "scheduler", "rng"
}

// TrainingState 训练状态
type TrainingState struct {
    GlobalStep    int64             `json:"global_step"`
    Epoch         int               `json:"epoch"`
    LearningRate  float64           `json:"learning_rate"`
    Loss          float64           `json:"loss"`
    Metrics       map[string]float64 `json:"metrics"`
    RandomState   []byte            `json:"random_state"`
}

type StorageTier string

const (
    StorageTierHot  StorageTier = "hot"
    StorageTierWarm StorageTier = "warm"
    StorageTierCold StorageTier = "cold"
)

// NewCheckpointManager 创建检查点管理器
func NewCheckpointManager(
    hotStorage, warmStorage, coldStorage Storage,
    metaStore MetadataStore,
    config CheckpointConfig,
) *CheckpointManager {
    return &CheckpointManager{
        hotStorage:    hotStorage,
        warmStorage:   warmStorage,
        coldStorage:   coldStorage,
        metaStore:     metaStore,
        config:        config,
        activeUploads: make(map[string]*UploadTask),
        metrics:       NewCheckpointMetrics(),
    }
}

// Save 保存检查点
func (m *CheckpointManager) Save(
    ctx context.Context,
    jobID string,
    step int64,
    epoch int,
    state *TrainingState,
    data map[string][]byte, // 文件名 -> 数据
) (*Checkpoint, error) {

    startTime := time.Now()

    // 创建检查点元数据
    checkpoint := &Checkpoint{
        ID:            fmt.Sprintf("%s-step-%d", jobID, step),
        JobID:         jobID,
        Step:          step,
        Epoch:         epoch,
        CreatedAt:     time.Now(),
        StorageTier:   StorageTierHot,
        TrainingState: *state,
        IsMilestone:   m.isMilestone(step),
        Tags:          make(map[string]string),
    }

    // 计算检查点路径
    checkpoint.Path = m.checkpointPath(jobID, step)

    // 处理每个文件
    var files []CheckpointFile
    var totalSize int64

    for name, content := range data {
        // 可选压缩
        processedData := content
        if m.config.CompressionAlgo != "none" {
            var err error
            processedData, err = m.compress(content)
            if err != nil {
                return nil, fmt.Errorf("compress %s: %w", name, err)
            }
        }

        // 计算校验和
        checksum := m.calculateChecksum(processedData)

        // 保存到 Hot 存储
        filePath := filepath.Join(checkpoint.Path, name)
        if err := m.hotStorage.Write(ctx, filePath, processedData); err != nil {
            return nil, fmt.Errorf("write %s: %w", name, err)
        }

        files = append(files, CheckpointFile{
            Name:     name,
            Size:     int64(len(processedData)),
            Checksum: checksum,
            Type:     m.inferFileType(name),
        })
        totalSize += int64(len(processedData))
    }

    checkpoint.Files = files
    checkpoint.TotalSize = totalSize
    checkpoint.Checksum = m.calculateCheckpointChecksum(files)

    // 保存元数据
    if err := m.metaStore.SaveCheckpoint(ctx, checkpoint); err != nil {
        return nil, fmt.Errorf("save metadata: %w", err)
    }

    // 记录指标
    m.metrics.RecordSave(totalSize, time.Since(startTime))

    // 异步迁移到 Warm 存储
    if m.config.AsyncUpload {
        go m.migrateToWarm(context.Background(), checkpoint)
    }

    // 触发清理
    go m.cleanup(context.Background(), jobID)

    return checkpoint, nil
}

// Load 加载检查点
func (m *CheckpointManager) Load(
    ctx context.Context,
    checkpointID string,
) (map[string][]byte, *TrainingState, error) {

    startTime := time.Now()

    // 获取元数据
    checkpoint, err := m.metaStore.GetCheckpoint(ctx, checkpointID)
    if err != nil {
        return nil, nil, fmt.Errorf("get checkpoint metadata: %w", err)
    }

    // 选择存储后端
    storage := m.getStorage(checkpoint.StorageTier)

    // 加载所有文件
    data := make(map[string][]byte)

    for _, file := range checkpoint.Files {
        filePath := filepath.Join(checkpoint.Path, file.Name)

        content, err := storage.Read(ctx, filePath)
        if err != nil {
            return nil, nil, fmt.Errorf("read %s: %w", file.Name, err)
        }

        // 验证校验和
        if m.config.VerifyChecksum {
            if checksum := m.calculateChecksum(content); checksum != file.Checksum {
                return nil, nil, fmt.Errorf("checksum mismatch for %s", file.Name)
            }
        }

        // 解压
        if m.config.CompressionAlgo != "none" {
            content, err = m.decompress(content)
            if err != nil {
                return nil, nil, fmt.Errorf("decompress %s: %w", file.Name, err)
            }
        }

        data[file.Name] = content
    }

    // 记录指标
    m.metrics.RecordLoad(checkpoint.TotalSize, time.Since(startTime))

    return data, &checkpoint.TrainingState, nil
}

// LoadLatest 加载最新检查点
func (m *CheckpointManager) LoadLatest(
    ctx context.Context,
    jobID string,
) (map[string][]byte, *TrainingState, error) {

    // 获取最新检查点
    checkpoints, err := m.metaStore.ListCheckpoints(ctx, jobID)
    if err != nil {
        return nil, nil, fmt.Errorf("list checkpoints: %w", err)
    }

    if len(checkpoints) == 0 {
        return nil, nil, fmt.Errorf("no checkpoints found for job %s", jobID)
    }

    // 按步数排序,取最新
    sort.Slice(checkpoints, func(i, j int) bool {
        return checkpoints[i].Step > checkpoints[j].Step
    })

    return m.Load(ctx, checkpoints[0].ID)
}

// migrateToWarm 迁移到 Warm 存储
func (m *CheckpointManager) migrateToWarm(ctx context.Context, checkpoint *Checkpoint) {
    // 检查是否已在迁移
    m.mu.Lock()
    if _, exists := m.activeUploads[checkpoint.ID]; exists {
        m.mu.Unlock()
        return
    }
    task := &UploadTask{
        CheckpointID: checkpoint.ID,
        StartTime:    time.Now(),
        Status:       "in_progress",
    }
    m.activeUploads[checkpoint.ID] = task
    m.mu.Unlock()

    defer func() {
        m.mu.Lock()
        delete(m.activeUploads, checkpoint.ID)
        m.mu.Unlock()
    }()

    // 复制文件到 Warm 存储
    for _, file := range checkpoint.Files {
        srcPath := filepath.Join(checkpoint.Path, file.Name)
        dstPath := filepath.Join(m.config.WarmPath, checkpoint.JobID, checkpoint.ID, file.Name)

        data, err := m.hotStorage.Read(ctx, srcPath)
        if err != nil {
            m.metrics.RecordMigrationError("hot_to_warm")
            return
        }

        if err := m.warmStorage.Write(ctx, dstPath, data); err != nil {
            m.metrics.RecordMigrationError("hot_to_warm")
            return
        }
    }

    // 更新元数据
    checkpoint.StorageTier = StorageTierWarm
    checkpoint.Path = filepath.Join(m.config.WarmPath, checkpoint.JobID, checkpoint.ID)
    m.metaStore.SaveCheckpoint(ctx, checkpoint)

    // 如果是里程碑,还要迁移到 Cold 存储
    if checkpoint.IsMilestone {
        m.migrateToCold(ctx, checkpoint)
    }
}

// migrateToCold 迁移到 Cold 存储
func (m *CheckpointManager) migrateToCold(ctx context.Context, checkpoint *Checkpoint) {
    // 上传到对象存储
    for _, file := range checkpoint.Files {
        srcPath := filepath.Join(checkpoint.Path, file.Name)
        dstKey := fmt.Sprintf("%s/%s/%s", checkpoint.JobID, checkpoint.ID, file.Name)

        data, err := m.warmStorage.Read(ctx, srcPath)
        if err != nil {
            m.metrics.RecordMigrationError("warm_to_cold")
            return
        }

        if err := m.coldStorage.Write(ctx, dstKey, data); err != nil {
            m.metrics.RecordMigrationError("warm_to_cold")
            return
        }
    }

    // 更新元数据,记录 Cold 存储位置
    checkpoint.Tags["cold_path"] = fmt.Sprintf("s3://%s/%s/%s",
        m.config.ColdBucket, checkpoint.JobID, checkpoint.ID)
    m.metaStore.SaveCheckpoint(ctx, checkpoint)
}

// cleanup 清理过期检查点
func (m *CheckpointManager) cleanup(ctx context.Context, jobID string) {
    checkpoints, err := m.metaStore.ListCheckpoints(ctx, jobID)
    if err != nil {
        return
    }

    // 按存储层分组
    hotCheckpoints := filterByTier(checkpoints, StorageTierHot)
    warmCheckpoints := filterByTier(checkpoints, StorageTierWarm)

    // 按步数排序(降序)
    sort.Slice(hotCheckpoints, func(i, j int) bool {
        return hotCheckpoints[i].Step > hotCheckpoints[j].Step
    })
    sort.Slice(warmCheckpoints, func(i, j int) bool {
        return warmCheckpoints[i].Step > warmCheckpoints[j].Step
    })

    // 清理 Hot 存储中超出保留数量的检查点
    if len(hotCheckpoints) > m.config.HotRetention {
        for _, cp := range hotCheckpoints[m.config.HotRetention:] {
            m.deleteCheckpoint(ctx, cp, StorageTierHot)
        }
    }

    // 清理 Warm 存储中超出保留数量的非里程碑检查点
    nonMilestoneWarm := filterNonMilestone(warmCheckpoints)
    if len(nonMilestoneWarm) > m.config.WarmRetention {
        for _, cp := range nonMilestoneWarm[m.config.WarmRetention:] {
            m.deleteCheckpoint(ctx, cp, StorageTierWarm)
        }
    }
}

func (m *CheckpointManager) deleteCheckpoint(ctx context.Context, cp *Checkpoint, tier StorageTier) {
    storage := m.getStorage(tier)

    for _, file := range cp.Files {
        filePath := filepath.Join(cp.Path, file.Name)
        storage.Delete(ctx, filePath)
    }

    // 如果所有存储层都删除了,才删除元数据
    // 这里简化处理,实际需要跟踪每个层的删除状态
}

func (m *CheckpointManager) getStorage(tier StorageTier) Storage {
    switch tier {
    case StorageTierHot:
        return m.hotStorage
    case StorageTierWarm:
        return m.warmStorage
    case StorageTierCold:
        return m.coldStorage
    default:
        return m.hotStorage
    }
}

func (m *CheckpointManager) isMilestone(step int64) bool {
    for _, ms := range m.config.MilestoneSteps {
        if int64(ms) == step {
            return true
        }
    }
    // 或者每 N 步自动标记
    return step%10000 == 0
}

func (m *CheckpointManager) checkpointPath(jobID string, step int64) string {
    return filepath.Join(m.config.HotPath, jobID, fmt.Sprintf("step-%d", step))
}

func (m *CheckpointManager) compress(data []byte) ([]byte, error) {
    // 实现压缩逻辑
    return data, nil
}

func (m *CheckpointManager) decompress(data []byte) ([]byte, error) {
    // 实现解压逻辑
    return data, nil
}

func (m *CheckpointManager) calculateChecksum(data []byte) string {
    // 实现校验和计算
    return ""
}

func (m *CheckpointManager) calculateCheckpointChecksum(files []CheckpointFile) string {
    // 计算整个检查点的校验和
    return ""
}

func (m *CheckpointManager) inferFileType(name string) string {
    switch {
    case name == "model.pt" || name == "pytorch_model.bin":
        return "model"
    case name == "optimizer.pt":
        return "optimizer"
    case name == "scheduler.pt":
        return "scheduler"
    case name == "rng_state.pt":
        return "rng"
    default:
        return "other"
    }
}

type UploadTask struct {
    CheckpointID string
    StartTime    time.Time
    Status       string
}

func filterByTier(checkpoints []*Checkpoint, tier StorageTier) []*Checkpoint {
    var result []*Checkpoint
    for _, cp := range checkpoints {
        if cp.StorageTier == tier {
            result = append(result, cp)
        }
    }
    return result
}

func filterNonMilestone(checkpoints []*Checkpoint) []*Checkpoint {
    var result []*Checkpoint
    for _, cp := range checkpoints {
        if !cp.IsMilestone {
            result = append(result, cp)
        }
    }
    return result
}
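
下面是按上文生命周期(Hot 保留最新 3 个、Warm 保留近期 10 个、里程碑归档到 Cold)配置并保存检查点的一个用法示意(假设性示例,存储后端、路径与任务名均为虚构):

// checkpoint_usage.go(示意)
package checkpoint

import "context"

func saveExample(ctx context.Context, hot, warm, cold Storage, meta MetadataStore) error {
    mgr := NewCheckpointManager(hot, warm, cold, meta, CheckpointConfig{
        SaveInterval:    1000,            // 每 1000 步保存一次
        HotRetention:    3,               // 对应生命周期图中的"最新 3 个"
        WarmRetention:   10,              // 对应"近期 10 个"
        HotPath:         "/nvme/ckpt",    // 假设的本地 NVMe 路径
        WarmPath:        "/lustre/ckpt",  // 假设的并行文件系统路径
        ColdBucket:      "model-archive", // 假设的对象存储 bucket
        AsyncUpload:     true,
        CompressionAlgo: "none",
        VerifyChecksum:  true,
    })

    state := &TrainingState{GlobalStep: 1000, Epoch: 1, Loss: 2.31}
    data := map[string][]byte{
        "model.pt":     nil, // 实际应为序列化后的模型权重
        "optimizer.pt": nil, // 实际应为优化器状态
    }

    _, err := mgr.Save(ctx, "llama-7b-sft", 1000, 1, state, data)
    return err
}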

分布式检查点

// distributed_checkpoint.go
package checkpoint

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// DistributedCheckpointManager 分布式检查点管理器
type DistributedCheckpointManager struct {
    // 本地检查点管理器
    localManager *CheckpointManager

    // 分布式协调
    coordinator Coordinator

    // 当前进程信息
    rank      int
    worldSize int

    // 配置
    config DistributedCheckpointConfig
}

// DistributedCheckpointConfig 分布式配置
type DistributedCheckpointConfig struct {
    // 分片策略
    ShardingStrategy string // "full", "sharded", "fsdp"

    // 聚合配置
    AggregateOnSave   bool // 保存时聚合到 rank 0
    AggregateOnLoad   bool // 加载时从聚合文件分发

    // 并行 IO
    ParallelRanks     int  // 并行保存的 rank 数

    // 超时
    BarrierTimeout    time.Duration
}

// DistributedCheckpoint 分布式检查点
type DistributedCheckpoint struct {
    *Checkpoint

    // 分片信息
    Shards []ShardInfo `json:"shards"`

    // 分布式状态
    WorldSize int `json:"world_size"`
}

// ShardInfo 分片信息
type ShardInfo struct {
    Rank       int      `json:"rank"`
    Files      []string `json:"files"`
    TotalSize  int64    `json:"total_size"`
    NodeID     string   `json:"node_id"`
}

// Coordinator 分布式协调接口
type Coordinator interface {
    Barrier(ctx context.Context, name string) error
    Broadcast(ctx context.Context, data []byte, root int) ([]byte, error)
    Gather(ctx context.Context, data []byte, root int) ([][]byte, error)
    AllGather(ctx context.Context, data []byte) ([][]byte, error)
}

// NewDistributedCheckpointManager 创建分布式检查点管理器
func NewDistributedCheckpointManager(
    localManager *CheckpointManager,
    coordinator Coordinator,
    rank, worldSize int,
    config DistributedCheckpointConfig,
) *DistributedCheckpointManager {
    return &DistributedCheckpointManager{
        localManager: localManager,
        coordinator:  coordinator,
        rank:         rank,
        worldSize:    worldSize,
        config:       config,
    }
}

// Save 分布式保存检查点
func (m *DistributedCheckpointManager) Save(
    ctx context.Context,
    jobID string,
    step int64,
    epoch int,
    state *TrainingState,
    localData map[string][]byte, // 本 rank 的数据
) (*DistributedCheckpoint, error) {

    // 阶段 1:准备(所有 rank 同步)
    if err := m.coordinator.Barrier(ctx, fmt.Sprintf("ckpt-prepare-%d", step)); err != nil {
        return nil, fmt.Errorf("barrier at prepare: %w", err)
    }

    var checkpoint *DistributedCheckpoint
    var saveErr error

    switch m.config.ShardingStrategy {
    case "full":
        // 每个 rank 保存完整检查点
        checkpoint, saveErr = m.saveFullReplicated(ctx, jobID, step, epoch, state, localData)

    case "sharded":
        // 每个 rank 保存自己的分片
        checkpoint, saveErr = m.saveSharded(ctx, jobID, step, epoch, state, localData)

    case "fsdp":
        // FSDP 风格:参数分片 + 优化器分片
        checkpoint, saveErr = m.saveFSDP(ctx, jobID, step, epoch, state, localData)
    }

    if saveErr != nil {
        return nil, saveErr
    }

    // 阶段 2:完成(所有 rank 同步)
    if err := m.coordinator.Barrier(ctx, fmt.Sprintf("ckpt-complete-%d", step)); err != nil {
        return nil, fmt.Errorf("barrier at complete: %w", err)
    }

    return checkpoint, nil
}

// saveSharded 分片保存
func (m *DistributedCheckpointManager) saveSharded(
    ctx context.Context,
    jobID string,
    step int64,
    epoch int,
    state *TrainingState,
    localData map[string][]byte,
) (*DistributedCheckpoint, error) {

    // 每个 rank 保存自己的分片
    // 使用 rank 作为文件名前缀
    shardedData := make(map[string][]byte)
    for name, data := range localData {
        shardedName := fmt.Sprintf("rank_%d_%s", m.rank, name)
        shardedData[shardedName] = data
    }

    // 调用本地管理器保存
    localCheckpoint, err := m.localManager.Save(ctx, jobID, step, epoch, state, shardedData)
    if err != nil {
        return nil, fmt.Errorf("save local shard: %w", err)
    }

    // 收集所有 rank 的分片信息
    shardInfo := ShardInfo{
        Rank:      m.rank,
        Files:     getFileNames(shardedData),
        TotalSize: calculateTotalSize(shardedData),
        NodeID:    getNodeID(),
    }

    shardInfoBytes, _ := json.Marshal(shardInfo)
    allShardInfoBytes, err := m.coordinator.AllGather(ctx, shardInfoBytes)
    if err != nil {
        return nil, fmt.Errorf("gather shard info: %w", err)
    }

    // 解析所有分片信息
    var allShards []ShardInfo
    for _, b := range allShardInfoBytes {
        var si ShardInfo
        json.Unmarshal(b, &si)
        allShards = append(allShards, si)
    }

    return &DistributedCheckpoint{
        Checkpoint: localCheckpoint,
        Shards:     allShards,
        WorldSize:  m.worldSize,
    }, nil
}

// saveFSDP FSDP 风格保存
func (m *DistributedCheckpointManager) saveFSDP(
    ctx context.Context,
    jobID string,
    step int64,
    epoch int,
    state *TrainingState,
    localData map[string][]byte,
) (*DistributedCheckpoint, error) {

    // FSDP 保存策略:
    // 1. 模型参数:每个 rank 保存自己的分片,或聚合后保存
    // 2. 优化器状态:每个 rank 保存自己的分片
    // 3. 其他状态:只在 rank 0 保存

    shardedData := make(map[string][]byte)

    for name, data := range localData {
        switch {
        case isModelParam(name):
            if m.config.AggregateOnSave {
                // 聚合模型参数到 rank 0
                aggregated, err := m.aggregateParams(ctx, data)
                if err != nil {
                    return nil, err
                }
                if m.rank == 0 {
                    shardedData[name] = aggregated
                }
            } else {
                // 分片保存
                shardedData[fmt.Sprintf("model_shard_%d.pt", m.rank)] = data
            }

        case isOptimizerState(name):
            // 优化器状态总是分片保存
            shardedData[fmt.Sprintf("optim_shard_%d.pt", m.rank)] = data

        default:
            // 其他状态只在 rank 0 保存
            if m.rank == 0 {
                shardedData[name] = data
            }
        }
    }

    return m.saveSharded(ctx, jobID, step, epoch, state, shardedData)
}

// saveFullReplicated 完全复制保存(每个 rank 保存完整)
func (m *DistributedCheckpointManager) saveFullReplicated(
    ctx context.Context,
    jobID string,
    step int64,
    epoch int,
    state *TrainingState,
    localData map[string][]byte,
) (*DistributedCheckpoint, error) {

    // 控制并发度:只有部分 rank 同时保存
    if m.rank < m.config.ParallelRanks {
        localCheckpoint, err := m.localManager.Save(ctx, jobID, step, epoch, state, localData)
        if err != nil {
            return nil, err
        }

        return &DistributedCheckpoint{
            Checkpoint: localCheckpoint,
            Shards: []ShardInfo{{
                Rank:      m.rank,
                TotalSize: localCheckpoint.TotalSize,
            }},
            WorldSize: m.worldSize,
        }, nil
    }

    // 其他 rank 等待
    return &DistributedCheckpoint{}, nil
}

// Load 分布式加载检查点
func (m *DistributedCheckpointManager) Load(
    ctx context.Context,
    checkpointID string,
) (map[string][]byte, *TrainingState, error) {

    // 阶段 1:准备
    if err := m.coordinator.Barrier(ctx, fmt.Sprintf("load-prepare-%s", checkpointID)); err != nil {
        return nil, nil, err
    }

    var data map[string][]byte
    var state *TrainingState
    var loadErr error

    switch m.config.ShardingStrategy {
    case "sharded", "fsdp":
        data, state, loadErr = m.loadSharded(ctx, checkpointID)
    case "full":
        data, state, loadErr = m.localManager.Load(ctx, checkpointID)
    }

    if loadErr != nil {
        return nil, nil, loadErr
    }

    // 阶段 2:完成
    if err := m.coordinator.Barrier(ctx, fmt.Sprintf("load-complete-%s", checkpointID)); err != nil {
        return nil, nil, err
    }

    return data, state, nil
}

// loadSharded 加载分片检查点
func (m *DistributedCheckpointManager) loadSharded(
    ctx context.Context,
    checkpointID string,
) (map[string][]byte, *TrainingState, error) {

    // 获取元数据
    metadata, err := m.localManager.metaStore.GetCheckpoint(ctx, checkpointID)
    if err != nil {
        return nil, nil, err
    }

    // 加载本 rank 的分片
    data := make(map[string][]byte)

    for _, file := range metadata.Files {
        // 检查是否是本 rank 的文件
        if isRankFile(file.Name, m.rank) || isSharedFile(file.Name) {
            content, err := m.localManager.hotStorage.Read(ctx,
                fmt.Sprintf("%s/%s", metadata.Path, file.Name))
            if err != nil {
                // 尝试从 warm 存储读取
                content, err = m.localManager.warmStorage.Read(ctx,
                    fmt.Sprintf("%s/%s", metadata.Path, file.Name))
                if err != nil {
                    return nil, nil, err
                }
            }

            // 去除 rank 前缀
            cleanName := removeRankPrefix(file.Name)
            data[cleanName] = content
        }
    }

    return data, &metadata.TrainingState, nil
}

// aggregateParams 聚合参数到 rank 0
func (m *DistributedCheckpointManager) aggregateParams(ctx context.Context, localParams []byte) ([]byte, error) {
    // 使用 Gather 收集到 rank 0
    allParams, err := m.coordinator.Gather(ctx, localParams, 0)
    if err != nil {
        return nil, err
    }

    if m.rank != 0 {
        return nil, nil
    }

    // 在 rank 0 上拼接所有参数
    var totalSize int
    for _, p := range allParams {
        totalSize += len(p)
    }

    result := make([]byte, 0, totalSize)
    for _, p := range allParams {
        result = append(result, p...)
    }

    return result, nil
}

// 辅助函数
func getFileNames(data map[string][]byte) []string {
    names := make([]string, 0, len(data))
    for name := range data {
        names = append(names, name)
    }
    return names
}

func calculateTotalSize(data map[string][]byte) int64 {
    var total int64
    for _, d := range data {
        total += int64(len(d))
    }
    return total
}

func getNodeID() string {
    // 返回节点标识
    return "node-1"
}

func isModelParam(name string) bool {
    return name == "model.pt" || name == "pytorch_model.bin"
}

func isOptimizerState(name string) bool {
    return name == "optimizer.pt"
}

func isRankFile(name string, rank int) bool {
    // 判断文件是否属于指定 rank(以 "rank_N_" 前缀标识)
    return strings.HasPrefix(name, fmt.Sprintf("rank_%d_", rank))
}

func isSharedFile(name string) bool {
    // 共享文件不带 rank 前缀
    return !strings.HasPrefix(name, "rank_")
}

func removeRankPrefix(name string) string {
    // 去除 "rank_N_" 前缀;非分片文件原样返回
    if !strings.HasPrefix(name, "rank_") {
        return name
    }
    rest := name[len("rank_"):]
    if idx := strings.Index(rest, "_"); idx >= 0 {
        return rest[idx+1:]
    }
    return name
}
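
下面是每个 rank 以 "sharded" 策略保存自己那部分状态的一个用法示意(假设性示例,rank、worldSize 等参数来自训练框架初始化):

// distributed_usage.go(示意)
package checkpoint

import (
    "context"
    "time"
)

func distributedSaveExample(
    ctx context.Context,
    local *CheckpointManager,
    coord Coordinator,
    rank, worldSize int,
    modelShard, optimShard []byte,
) error {
    dmgr := NewDistributedCheckpointManager(local, coord, rank, worldSize,
        DistributedCheckpointConfig{
            ShardingStrategy: "sharded", // 每个 rank 保存自己的分片
            ParallelRanks:    8,         // full 策略下允许同时保存的 rank 数
            BarrierTimeout:   5 * time.Minute,
        })

    state := &TrainingState{GlobalStep: 2000, Epoch: 2}
    localData := map[string][]byte{
        "model.pt":     modelShard, // 本 rank 持有的参数分片
        "optimizer.pt": optimShard, // 本 rank 持有的优化器状态分片
    }

    _, err := dmgr.Save(ctx, "llama-7b-sft", 2000, 2, state, localData)
    return err
}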

模型版本管理

模型注册中心

// model_registry.go
package registry

import (
    "context"
    "fmt"
    "sort"
    "time"
)

// ModelRegistry 模型注册中心
type ModelRegistry struct {
    store     MetadataStore
    storage   ModelStorage
    validator ModelValidator
}

// Model 模型定义
type Model struct {
    ID          string            `json:"id"`
    Name        string            `json:"name"`
    Description string            `json:"description"`
    Owner       string            `json:"owner"`
    Tags        []string          `json:"tags"`
    Labels      map[string]string `json:"labels"`
    CreatedAt   time.Time         `json:"created_at"`
    UpdatedAt   time.Time         `json:"updated_at"`
}

// ModelVersion 模型版本
type ModelVersion struct {
    ID          string            `json:"id"`
    ModelID     string            `json:"model_id"`
    Version     string            `json:"version"`       // 语义版本号
    Stage       ModelStage        `json:"stage"`         // 生命周期阶段

    // 来源
    Source      ModelSource       `json:"source"`

    // 存储
    Artifacts   []Artifact        `json:"artifacts"`

    // 元数据
    Framework   string            `json:"framework"`     // pytorch, tensorflow, etc.
    Architecture string           `json:"architecture"`  // transformer, cnn, etc.
    Parameters  int64             `json:"parameters"`    // 参数量

    // 训练信息
    TrainingJob string            `json:"training_job"`
    Checkpoint  string            `json:"checkpoint"`
    Metrics     map[string]float64 `json:"metrics"`

    // 签名(输入输出格式)
    Signature   ModelSignature    `json:"signature"`

    // 审计
    CreatedBy   string            `json:"created_by"`
    CreatedAt   time.Time         `json:"created_at"`
    ApprovedBy  string            `json:"approved_by"`
    ApprovedAt  *time.Time        `json:"approved_at"`
}

// ModelStage 模型阶段
type ModelStage string

const (
    StageNone       ModelStage = "none"
    StageDevelopment ModelStage = "development"
    StageStaging    ModelStage = "staging"
    StageProduction ModelStage = "production"
    StageArchived   ModelStage = "archived"
)

// ModelSource 模型来源
type ModelSource struct {
    Type        string `json:"type"` // "training", "import", "conversion"
    TrainingRun string `json:"training_run,omitempty"`
    ImportPath  string `json:"import_path,omitempty"`
    SourceModel string `json:"source_model,omitempty"` // 转换来源
}

// Artifact 模型产物
type Artifact struct {
    Name        string `json:"name"`
    Path        string `json:"path"`
    Size        int64  `json:"size"`
    Checksum    string `json:"checksum"`
    Format      string `json:"format"` // pytorch, safetensors, onnx, etc.
}

// ModelSignature 模型签名
type ModelSignature struct {
    Inputs  []TensorSpec `json:"inputs"`
    Outputs []TensorSpec `json:"outputs"`
}

// TensorSpec 张量规格
type TensorSpec struct {
    Name   string  `json:"name"`
    Dtype  string  `json:"dtype"`
    Shape  []int   `json:"shape"`  // -1 表示动态维度
}

// NewModelRegistry 创建模型注册中心
func NewModelRegistry(store MetadataStore, storage ModelStorage) *ModelRegistry {
    return &ModelRegistry{
        store:   store,
        storage: storage,
    }
}

// CreateModel 创建模型
func (r *ModelRegistry) CreateModel(ctx context.Context, model *Model) error {
    model.ID = generateID()
    model.CreatedAt = time.Now()
    model.UpdatedAt = time.Now()

    return r.store.CreateModel(ctx, model)
}

// RegisterVersion 注册新版本
func (r *ModelRegistry) RegisterVersion(
    ctx context.Context,
    modelID string,
    version *ModelVersion,
) error {
    // 验证模型存在
    model, err := r.store.GetModel(ctx, modelID)
    if err != nil {
        return fmt.Errorf("model not found: %w", err)
    }

    // 验证版本号唯一
    existing, _ := r.store.GetVersionByNumber(ctx, modelID, version.Version)
    if existing != nil {
        return fmt.Errorf("version %s already exists", version.Version)
    }

    // 设置默认值
    version.ID = generateID()
    version.ModelID = model.ID
    version.Stage = StageNone
    version.CreatedAt = time.Now()

    // 验证产物
    for i, artifact := range version.Artifacts {
        // 验证文件存在
        exists, err := r.storage.Exists(ctx, artifact.Path)
        if err != nil || !exists {
            return fmt.Errorf("artifact not found: %s", artifact.Path)
        }

        // 获取文件信息
        info, err := r.storage.Stat(ctx, artifact.Path)
        if err != nil {
            return err
        }

        version.Artifacts[i].Size = info.Size

        // 计算校验和
        if artifact.Checksum == "" {
            checksum, err := r.storage.Checksum(ctx, artifact.Path)
            if err != nil {
                return err
            }
            version.Artifacts[i].Checksum = checksum
        }
    }

    // 保存版本
    return r.store.CreateVersion(ctx, version)
}

// TransitionStage 转换阶段
func (r *ModelRegistry) TransitionStage(
    ctx context.Context,
    modelID string,
    version string,
    targetStage ModelStage,
    approver string,
) error {
    // 获取版本
    v, err := r.store.GetVersionByNumber(ctx, modelID, version)
    if err != nil {
        return err
    }

    // 验证转换规则
    if !isValidTransition(v.Stage, targetStage) {
        return fmt.Errorf("invalid transition from %s to %s", v.Stage, targetStage)
    }

    // 特殊处理:升级到 Production
    if targetStage == StageProduction {
        // 降级当前 Production 版本
        currentProd, _ := r.GetProductionVersion(ctx, modelID)
        if currentProd != nil {
            currentProd.Stage = StageArchived
            r.store.UpdateVersion(ctx, currentProd)
        }

        // 记录审批信息
        now := time.Now()
        v.ApprovedBy = approver
        v.ApprovedAt = &now
    }

    v.Stage = targetStage
    return r.store.UpdateVersion(ctx, v)
}

// GetProductionVersion 获取生产版本
func (r *ModelRegistry) GetProductionVersion(ctx context.Context, modelID string) (*ModelVersion, error) {
    versions, err := r.store.ListVersions(ctx, modelID)
    if err != nil {
        return nil, err
    }

    for _, v := range versions {
        if v.Stage == StageProduction {
            return v, nil
        }
    }

    return nil, fmt.Errorf("no production version found")
}

// GetLatestVersion 获取最新版本
func (r *ModelRegistry) GetLatestVersion(ctx context.Context, modelID string) (*ModelVersion, error) {
    versions, err := r.store.ListVersions(ctx, modelID)
    if err != nil {
        return nil, err
    }

    if len(versions) == 0 {
        return nil, fmt.Errorf("no versions found")
    }

    // 按创建时间排序
    sort.Slice(versions, func(i, j int) bool {
        return versions[i].CreatedAt.After(versions[j].CreatedAt)
    })

    return versions[0], nil
}

// CompareVersions 比较版本
func (r *ModelRegistry) CompareVersions(
    ctx context.Context,
    modelID string,
    version1, version2 string,
) (*VersionComparison, error) {
    v1, err := r.store.GetVersionByNumber(ctx, modelID, version1)
    if err != nil {
        return nil, err
    }

    v2, err := r.store.GetVersionByNumber(ctx, modelID, version2)
    if err != nil {
        return nil, err
    }

    comparison := &VersionComparison{
        Version1: v1,
        Version2: v2,
        MetricsDiff: make(map[string]float64),
    }

    // 比较指标
    for metric, val1 := range v1.Metrics {
        if val2, ok := v2.Metrics[metric]; ok {
            comparison.MetricsDiff[metric] = val2 - val1
        }
    }

    // 比较参数量
    comparison.ParametersDiff = v2.Parameters - v1.Parameters

    // 比较文件大小
    comparison.SizeDiff = r.calculateTotalSize(v2.Artifacts) - r.calculateTotalSize(v1.Artifacts)

    return comparison, nil
}

type VersionComparison struct {
    Version1       *ModelVersion
    Version2       *ModelVersion
    MetricsDiff    map[string]float64
    ParametersDiff int64
    SizeDiff       int64
}

func (r *ModelRegistry) calculateTotalSize(artifacts []Artifact) int64 {
    var total int64
    for _, a := range artifacts {
        total += a.Size
    }
    return total
}

func isValidTransition(from, to ModelStage) bool {
    validTransitions := map[ModelStage][]ModelStage{
        StageNone:        {StageDevelopment, StageArchived},
        StageDevelopment: {StageStaging, StageArchived},
        StageStaging:     {StageProduction, StageDevelopment, StageArchived},
        StageProduction:  {StageArchived},
        StageArchived:    {StageDevelopment}, // 可以复活到开发阶段
    }

    allowed, ok := validTransitions[from]
    if !ok {
        return false
    }

    for _, s := range allowed {
        if s == to {
            return true
        }
    }
    return false
}

func generateID() string {
    // 生成唯一 ID
    return fmt.Sprintf("%d", time.Now().UnixNano())
}
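
A quick end-to-end sketch of how the registry above is used: create a model, register a version whose artifact already sits in storage, and walk it through the stage lifecycle. The Model.Name field and the "alice" approver are illustrative assumptions; any MetadataStore/ModelStorage implementation can be plugged in.

// registry_usage_example.go (sketch, same package as the registry above)
func registerAndPromote(ctx context.Context, store MetadataStore, storage ModelStorage) error {
    registry := NewModelRegistry(store, storage)

    // 1. Create the model record (Name is an assumed field for illustration).
    model := &Model{Name: "llm-chat-7b"}
    if err := registry.CreateModel(ctx, model); err != nil {
        return err
    }

    // 2. Register a version; RegisterVersion verifies the artifact exists in
    //    storage and fills in its size and checksum.
    version := &ModelVersion{
        Version:   "v1.0.0",
        Framework: "pytorch",
        Metrics:   map[string]float64{"eval_loss": 1.83},
        Artifacts: []Artifact{{
            Name:   "model.safetensors",
            Path:   "llm-chat-7b/v1.0.0/model.safetensors",
            Format: "safetensors",
        }},
    }
    if err := registry.RegisterVersion(ctx, model.ID, version); err != nil {
        return err
    }

    // 3. Walk the allowed transitions: None -> Development -> Staging -> Production.
    for _, stage := range []ModelStage{StageDevelopment, StageStaging, StageProduction} {
        if err := registry.TransitionStage(ctx, model.ID, version.Version, stage, "alice"); err != nil {
            return err
        }
    }
    return nil
}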

模型血缘追踪

// model_lineage.go
package registry

import (
    "context"
    "fmt"
    "time"
)

// LineageTracker 血缘追踪器
type LineageTracker struct {
    store LineageStore
}

// LineageNode 血缘节点
type LineageNode struct {
    ID       string      `json:"id"`
    Type     NodeType    `json:"type"`
    Name     string      `json:"name"`
    Version  string      `json:"version,omitempty"`
    Metadata interface{} `json:"metadata,omitempty"`
}

// NodeType 节点类型
type NodeType string

const (
    NodeTypeDataset    NodeType = "dataset"
    NodeTypeModel      NodeType = "model"
    NodeTypeExperiment NodeType = "experiment"
    NodeTypeCheckpoint NodeType = "checkpoint"
    NodeTypeArtifact   NodeType = "artifact"
)

// LineageEdge 血缘边
type LineageEdge struct {
    ID        string    `json:"id"`
    Source    string    `json:"source"`    // 源节点 ID
    Target    string    `json:"target"`    // 目标节点 ID
    Relation  string    `json:"relation"`  // 关系类型
    Metadata  interface{} `json:"metadata,omitempty"`
    CreatedAt time.Time `json:"created_at"`
}

// RelationType 关系类型
const (
    RelationTrainedOn    = "trained_on"     // 模型在数据集上训练
    RelationDerivedFrom  = "derived_from"   // 派生自
    RelationProducedBy   = "produced_by"    // 由实验产生
    RelationConvertedTo  = "converted_to"   // 转换为
    RelationFinetunedFrom = "finetuned_from" // 微调自
)

// NewLineageTracker 创建血缘追踪器
func NewLineageTracker(store LineageStore) *LineageTracker {
    return &LineageTracker{store: store}
}

// RecordTraining 记录训练血缘
func (t *LineageTracker) RecordTraining(
    ctx context.Context,
    modelVersion *ModelVersion,
    datasets []string,
    experiment string,
    baseModel string, // 可选,微调场景
) error {
    // 创建模型节点
    modelNode := &LineageNode{
        ID:      fmt.Sprintf("model:%s:%s", modelVersion.ModelID, modelVersion.Version),
        Type:    NodeTypeModel,
        Name:    modelVersion.ModelID,
        Version: modelVersion.Version,
        Metadata: map[string]interface{}{
            "parameters": modelVersion.Parameters,
            "framework":  modelVersion.Framework,
        },
    }
    if err := t.store.CreateNode(ctx, modelNode); err != nil {
        return err
    }

    // 创建与数据集的关系
    for _, dataset := range datasets {
        edge := &LineageEdge{
            ID:        generateID(),
            Source:    fmt.Sprintf("dataset:%s", dataset),
            Target:    modelNode.ID,
            Relation:  RelationTrainedOn,
            CreatedAt: time.Now(),
        }
        if err := t.store.CreateEdge(ctx, edge); err != nil {
            return err
        }
    }

    // 创建与实验的关系
    if experiment != "" {
        edge := &LineageEdge{
            ID:        generateID(),
            Source:    fmt.Sprintf("experiment:%s", experiment),
            Target:    modelNode.ID,
            Relation:  RelationProducedBy,
            CreatedAt: time.Now(),
        }
        if err := t.store.CreateEdge(ctx, edge); err != nil {
            return err
        }
    }

    // 微调场景:创建与基础模型的关系
    if baseModel != "" {
        edge := &LineageEdge{
            ID:        generateID(),
            Source:    fmt.Sprintf("model:%s", baseModel),
            Target:    modelNode.ID,
            Relation:  RelationFinetunedFrom,
            CreatedAt: time.Now(),
        }
        if err := t.store.CreateEdge(ctx, edge); err != nil {
            return err
        }
    }

    return nil
}

// GetUpstream 获取上游血缘
func (t *LineageTracker) GetUpstream(
    ctx context.Context,
    nodeID string,
    depth int,
) (*LineageGraph, error) {
    return t.traverse(ctx, nodeID, depth, true)
}

// GetDownstream 获取下游血缘
func (t *LineageTracker) GetDownstream(
    ctx context.Context,
    nodeID string,
    depth int,
) (*LineageGraph, error) {
    return t.traverse(ctx, nodeID, depth, false)
}

// LineageGraph 血缘图
type LineageGraph struct {
    Nodes []*LineageNode `json:"nodes"`
    Edges []*LineageEdge `json:"edges"`
}

func (t *LineageTracker) traverse(
    ctx context.Context,
    startNode string,
    maxDepth int,
    upstream bool,
) (*LineageGraph, error) {
    visited := make(map[string]bool)
    graph := &LineageGraph{
        Nodes: []*LineageNode{},
        Edges: []*LineageEdge{},
    }

    var dfs func(nodeID string, depth int) error
    dfs = func(nodeID string, depth int) error {
        if depth > maxDepth || visited[nodeID] {
            return nil
        }
        visited[nodeID] = true

        // 获取节点
        node, err := t.store.GetNode(ctx, nodeID)
        if err != nil {
            return err
        }
        graph.Nodes = append(graph.Nodes, node)

        // 获取边
        var edges []*LineageEdge
        if upstream {
            edges, err = t.store.GetIncomingEdges(ctx, nodeID)
        } else {
            edges, err = t.store.GetOutgoingEdges(ctx, nodeID)
        }
        if err != nil {
            return err
        }

        for _, edge := range edges {
            graph.Edges = append(graph.Edges, edge)

            var nextNode string
            if upstream {
                nextNode = edge.Source
            } else {
                nextNode = edge.Target
            }

            if err := dfs(nextNode, depth+1); err != nil {
                return err
            }
        }

        return nil
    }

    if err := dfs(startNode, 0); err != nil {
        return nil, err
    }

    return graph, nil
}

// FindImpact 查找影响范围
func (t *LineageTracker) FindImpact(
    ctx context.Context,
    nodeID string,
) (*ImpactAnalysis, error) {
    downstream, err := t.GetDownstream(ctx, nodeID, 10)
    if err != nil {
        return nil, err
    }

    analysis := &ImpactAnalysis{
        AffectedModels:      []*LineageNode{},
        AffectedExperiments: []*LineageNode{},
        AffectedArtifacts:   []*LineageNode{},
    }

    for _, node := range downstream.Nodes {
        switch node.Type {
        case NodeTypeModel:
            analysis.AffectedModels = append(analysis.AffectedModels, node)
        case NodeTypeExperiment:
            analysis.AffectedExperiments = append(analysis.AffectedExperiments, node)
        case NodeTypeArtifact:
            analysis.AffectedArtifacts = append(analysis.AffectedArtifacts, node)
        }
    }

    return analysis, nil
}

type ImpactAnalysis struct {
    AffectedModels      []*LineageNode
    AffectedExperiments []*LineageNode
    AffectedArtifacts   []*LineageNode
}
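
A short usage sketch for the tracker: record that a new version was fine-tuned from a base model on two datasets inside an experiment, then run an impact analysis when one of those datasets turns out to be bad. The node IDs ("wiki-zh-2024", "exp-42", "llama-7b:v1.0") are illustrative, and the dataset/experiment nodes are assumed to have been created by their own registration flows — RecordTraining only creates the model node and the edges pointing at it.

// lineage_usage_example.go (sketch, same package as the tracker above)
func trackFinetune(ctx context.Context, tracker *LineageTracker, version *ModelVersion) error {
    if err := tracker.RecordTraining(ctx, version,
        []string{"wiki-zh-2024", "code-stack-v2"}, // datasets
        "exp-42",        // experiment
        "llama-7b:v1.0", // base model (fine-tune scenario)
    ); err != nil {
        return err
    }

    // If the dataset is later found to be corrupted, FindImpact walks the
    // downstream graph and lists every model/experiment/artifact that depends on it.
    impact, err := tracker.FindImpact(ctx, "dataset:wiki-zh-2024")
    if err != nil {
        return err
    }
    fmt.Printf("affected models: %d, experiments: %d\n",
        len(impact.AffectedModels), len(impact.AffectedExperiments))
    return nil
}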

模型格式转换

格式转换器

// model_converter.go
package converter

import (
    "context"
    "fmt"
    "os"
    "os/exec"
    "path/filepath"
)

// ModelConverter 模型格式转换器
type ModelConverter interface {
    Convert(ctx context.Context, input, output string, opts ConvertOptions) error
    SupportedFormats() []string
}

// ConvertOptions 转换选项
type ConvertOptions struct {
    SourceFormat string
    TargetFormat string
    Precision    string // fp32, fp16, int8
    OptLevel     int    // 优化级别
    DeviceType   string // cpu, cuda
}

// PyTorchConverter PyTorch 转换器
type PyTorchConverter struct {
    pythonPath string
}

func NewPyTorchConverter(pythonPath string) *PyTorchConverter {
    return &PyTorchConverter{pythonPath: pythonPath}
}

// Convert 执行转换
func (c *PyTorchConverter) Convert(ctx context.Context, input, output string, opts ConvertOptions) error {
    switch opts.TargetFormat {
    case "safetensors":
        return c.toSafetensors(ctx, input, output)
    case "onnx":
        return c.toONNX(ctx, input, output, opts)
    case "torchscript":
        return c.toTorchScript(ctx, input, output)
    default:
        return fmt.Errorf("unsupported target format: %s", opts.TargetFormat)
    }
}

// toSafetensors 转换为 safetensors 格式
func (c *PyTorchConverter) toSafetensors(ctx context.Context, input, output string) error {
    script := `
import torch
from safetensors.torch import save_file
import sys

input_path = sys.argv[1]
output_path = sys.argv[2]

# 加载 PyTorch 模型
state_dict = torch.load(input_path, map_location='cpu')

# 如果是完整模型,提取 state_dict
if hasattr(state_dict, 'state_dict'):
    state_dict = state_dict.state_dict()
elif 'state_dict' in state_dict:
    state_dict = state_dict['state_dict']
elif 'model' in state_dict:
    state_dict = state_dict['model']

# 转换为 safetensors
save_file(state_dict, output_path)
print(f"Converted to {output_path}")
`
    return c.runPythonScript(ctx, script, input, output)
}

// toONNX 转换为 ONNX 格式
func (c *PyTorchConverter) toONNX(ctx context.Context, input, output string, opts ConvertOptions) error {
    script := `
import torch
import sys

input_path = sys.argv[1]
output_path = sys.argv[2]

# 加载模型
model = torch.load(input_path, map_location='cpu')
model.eval()

# 创建示例输入(需要根据模型调整)
# 这里假设是语言模型
batch_size = 1
seq_length = 128
dummy_input = torch.randint(0, 1000, (batch_size, seq_length))

# 导出 ONNX
torch.onnx.export(
    model,
    dummy_input,
    output_path,
    input_names=['input_ids'],
    output_names=['logits'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'logits': {0: 'batch_size', 1: 'sequence'}
    },
    opset_version=14,
    do_constant_folding=True,
)
print(f"Exported to {output_path}")
`
    return c.runPythonScript(ctx, script, input, output)
}

// toTorchScript 转换为 TorchScript
func (c *PyTorchConverter) toTorchScript(ctx context.Context, input, output string) error {
    script := `
import torch
import sys

input_path = sys.argv[1]
output_path = sys.argv[2]

# 加载模型
model = torch.load(input_path, map_location='cpu')
model.eval()

# 尝试 trace
try:
    # 创建示例输入
    dummy_input = torch.randint(0, 1000, (1, 128))
    traced = torch.jit.trace(model, dummy_input)
    traced.save(output_path)
except Exception as e:
    # 如果 trace 失败,尝试 script
    print(f"Trace failed: {e}, trying script...")
    scripted = torch.jit.script(model)
    scripted.save(output_path)

print(f"Converted to {output_path}")
`
    return c.runPythonScript(ctx, script, input, output)
}

func (c *PyTorchConverter) runPythonScript(ctx context.Context, script string, args ...string) error {
    // 创建临时脚本文件
    tmpFile, err := os.CreateTemp("", "convert_*.py")
    if err != nil {
        return err
    }
    defer os.Remove(tmpFile.Name())

    if _, err := tmpFile.WriteString(script); err != nil {
        return err
    }
    tmpFile.Close()

    // 执行脚本
    cmdArgs := append([]string{tmpFile.Name()}, args...)
    cmd := exec.CommandContext(ctx, c.pythonPath, cmdArgs...)
    cmd.Stdout = os.Stdout
    cmd.Stderr = os.Stderr

    return cmd.Run()
}

func (c *PyTorchConverter) SupportedFormats() []string {
    return []string{"pytorch", "safetensors", "onnx", "torchscript"}
}

// ConversionPipeline 转换流水线
type ConversionPipeline struct {
    converters map[string]ModelConverter
    storage    ModelStorage
}

// NewConversionPipeline 创建转换流水线
func NewConversionPipeline(storage ModelStorage) *ConversionPipeline {
    return &ConversionPipeline{
        converters: make(map[string]ModelConverter),
        storage:    storage,
    }
}

// RegisterConverter 注册转换器
func (p *ConversionPipeline) RegisterConverter(name string, converter ModelConverter) {
    p.converters[name] = converter
}

// ConvertModel 转换模型
func (p *ConversionPipeline) ConvertModel(
    ctx context.Context,
    modelID, version string,
    targetFormats []string,
) ([]Artifact, error) {

    // 获取源模型
    sourceArtifact, err := p.storage.GetArtifact(ctx, modelID, version)
    if err != nil {
        return nil, fmt.Errorf("get source artifact: %w", err)
    }

    // 下载到临时目录
    tmpDir, err := os.MkdirTemp("", "model_convert_*")
    if err != nil {
        return nil, err
    }
    defer os.RemoveAll(tmpDir)

    inputPath := filepath.Join(tmpDir, "input", sourceArtifact.Name)
    if err := os.MkdirAll(filepath.Dir(inputPath), 0755); err != nil {
        return nil, err
    }
    if err := p.storage.Download(ctx, sourceArtifact.Path, inputPath); err != nil {
        return nil, err
    }

    var artifacts []Artifact

    // 执行转换
    for _, format := range targetFormats {
        converter := p.findConverter(sourceArtifact.Format, format)
        if converter == nil {
            return nil, fmt.Errorf("no converter for %s -> %s", sourceArtifact.Format, format)
        }

        outputPath := filepath.Join(tmpDir, "output", fmt.Sprintf("model.%s", format))
        if err := os.MkdirAll(filepath.Dir(outputPath), 0755); err != nil {
            return nil, err
        }

        opts := ConvertOptions{
            SourceFormat: sourceArtifact.Format,
            TargetFormat: format,
        }

        if err := converter.Convert(ctx, inputPath, outputPath, opts); err != nil {
            return nil, fmt.Errorf("convert to %s: %w", format, err)
        }

        // 上传转换结果
        storagePath := fmt.Sprintf("%s/%s/model.%s", modelID, version, format)
        if err := p.storage.Upload(ctx, outputPath, storagePath); err != nil {
            return nil, err
        }

        // 获取文件信息
        info, err := os.Stat(outputPath)
        if err != nil {
            return nil, err
        }

        artifacts = append(artifacts, Artifact{
            Name:   fmt.Sprintf("model.%s", format),
            Path:   storagePath,
            Size:   info.Size(),
            Format: format,
        })
    }

    return artifacts, nil
}

func (p *ConversionPipeline) findConverter(source, target string) ModelConverter {
    for _, converter := range p.converters {
        formats := converter.SupportedFormats()
        hasSource, hasTarget := false, false
        for _, f := range formats {
            if f == source {
                hasSource = true
            }
            if f == target {
                hasTarget = true
            }
        }
        if hasSource && hasTarget {
            return converter
        }
    }
    return nil
}
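
To tie the converter and the pipeline together, a minimal wiring sketch: register the PyTorch converter and fan one checkpoint out to two serving formats. The Python path and the model/version IDs are placeholders.

// converter_usage_example.go (sketch, same package as the pipeline above)
func convertToServingFormats(ctx context.Context, storage ModelStorage) ([]Artifact, error) {
    pipeline := NewConversionPipeline(storage)
    // The registration key is only a name; matching is done via SupportedFormats().
    pipeline.RegisterConverter("pytorch", NewPyTorchConverter("/usr/bin/python3"))

    // The source artifact's Format must be "pytorch" for findConverter to match.
    return pipeline.ConvertModel(ctx, "llm-chat-7b", "v1.0.0", []string{"safetensors", "onnx"})
}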

Kubernetes 集成

PVC 动态供给

# model-storage-class.yaml
apiVersion: storage.k8s.io/v1
kind: StorageClass
metadata:
  name: model-storage
provisioner: csi.juicefs.com
parameters:
  # JuiceFS 配置(适合模型存储)
  csi.storage.k8s.io/provisioner-secret-name: juicefs-secret
  csi.storage.k8s.io/provisioner-secret-namespace: kube-system
  csi.storage.k8s.io/node-publish-secret-name: juicefs-secret
  csi.storage.k8s.io/node-publish-secret-namespace: kube-system
# 挂载选项
mountOptions:
  - cache-size=10240
  - writeback
reclaimPolicy: Retain
volumeBindingMode: Immediate
allowVolumeExpansion: true

---
# checkpoint-storage-class.yaml
apiVersion: storage.k8s.io/v1
kind: StorageClass
metadata:
  name: checkpoint-storage
provisioner: lustre.csi.hpe.com
parameters:
  # Lustre 配置(适合检查点)
  mgsSpec: "192.168.1.10@tcp:/lustre"
  stripeSize: "4M"
  stripeCount: "4"
reclaimPolicy: Delete
volumeBindingMode: WaitForFirstConsumer

---
# 模型存储 PVC 模板
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: model-storage-{{ .JobName }}
spec:
  storageClassName: model-storage
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 500Gi

---
# 检查点 PVC 模板
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: checkpoint-storage-{{ .JobName }}
spec:
  storageClassName: checkpoint-storage
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 1Ti
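
The two PVC manifests above use Go template placeholders ({{ .JobName }}). Below is a minimal sketch of how a job controller might render them per training job before applying, assuming the template text is saved to a file such as pvc.yaml.tmpl; only text/template from the standard library is used.

// render_pvc_template.go (sketch)
package main

import (
    "bytes"
    "os"
    "text/template"
)

// PVCParams carries the values substituted into the manifest template.
type PVCParams struct {
    JobName string
}

// renderPVC parses the manifest template and returns the rendered YAML,
// ready to pipe into `kubectl apply -f -`.
func renderPVC(templatePath string, params PVCParams) ([]byte, error) {
    raw, err := os.ReadFile(templatePath)
    if err != nil {
        return nil, err
    }
    tmpl, err := template.New("pvc").Parse(string(raw))
    if err != nil {
        return nil, err
    }
    var out bytes.Buffer
    if err := tmpl.Execute(&out, params); err != nil {
        return nil, err
    }
    return out.Bytes(), nil
}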

训练任务存储挂载

# training-job-with-storage.yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: llm-training
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: training:latest
            resources:
              limits:
                nvidia.com/gpu: 8
            volumeMounts:
            # 数据集(只读)
            - name: dataset
              mountPath: /data
              readOnly: true
            # 检查点(读写,高性能)
            - name: checkpoints
              mountPath: /checkpoints
            # 模型输出(读写)
            - name: models
              mountPath: /models
            # 本地缓存(NVMe)
            - name: local-cache
              mountPath: /cache
            env:
            - name: CHECKPOINT_DIR
              value: /checkpoints
            - name: MODEL_OUTPUT_DIR
              value: /models
          volumes:
          # 数据集 - 使用共享存储
          - name: dataset
            persistentVolumeClaim:
              claimName: dataset-imagenet
          # 检查点 - 使用高性能并行文件系统
          - name: checkpoints
            persistentVolumeClaim:
              claimName: checkpoint-storage-llm-training
          # 模型输出 - 使用对象存储
          - name: models
            persistentVolumeClaim:
              claimName: model-storage-llm-training
          # 本地缓存 - 使用本地 NVMe
          - name: local-cache
            emptyDir:
              medium: Memory  # 或使用 hostPath 挂载 NVMe
              sizeLimit: 100Gi
    Worker:
      replicas: 7
      template:
        # 类似配置...

模型服务挂载

# model-serving-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: model-server
spec:
  replicas: 3
  selector:
    matchLabels:
      app: model-server
  template:
    metadata:
      labels:
        app: model-server
    spec:
      # 使用 Init Container 预加载模型
      initContainers:
      - name: model-loader
        image: model-loader:latest
        command:
        - /bin/sh
        - -c
        - |
          # 从对象存储下载模型到本地缓存
          aws s3 cp s3://models/llm-v1/ /cache/model/ --recursive
          # 验证完整性
          md5sum -c /cache/model/checksums.md5
        volumeMounts:
        - name: model-cache
          mountPath: /cache
        env:
        - name: AWS_ACCESS_KEY_ID
          valueFrom:
            secretKeyRef:
              name: s3-credentials
              key: access-key
        - name: AWS_SECRET_ACCESS_KEY
          valueFrom:
            secretKeyRef:
              name: s3-credentials
              key: secret-key

      containers:
      - name: server
        image: model-server:latest
        resources:
          limits:
            nvidia.com/gpu: 1
        volumeMounts:
        - name: model-cache
          mountPath: /models
          readOnly: true
        env:
        - name: MODEL_PATH
          value: /models/model

      volumes:
      - name: model-cache
        emptyDir:
          sizeLimit: 50Gi

      # 亲和性:优先调度到有缓存的节点
      affinity:
        nodeAffinity:
          preferredDuringSchedulingIgnoredDuringExecution:
          - weight: 100
            preference:
              matchExpressions:
              - key: model-cache/llm-v1
                operator: Exists
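
If the init container above is replaced by a Go-based model loader, the md5sum -c checksums.md5 step can be reproduced with the standard library. The manifest format assumed here is the usual md5sum output, one "<md5>  <relative path>" pair per line.

// verify_checksums.go (sketch)
package main

import (
    "bufio"
    "crypto/md5"
    "encoding/hex"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "strings"
)

// verifyChecksums checks every file listed in the manifest against its
// expected MD5, resolving relative paths under dir.
func verifyChecksums(dir, manifest string) error {
    f, err := os.Open(manifest)
    if err != nil {
        return err
    }
    defer f.Close()

    scanner := bufio.NewScanner(f)
    for scanner.Scan() {
        fields := strings.Fields(scanner.Text())
        if len(fields) != 2 {
            continue // skip blank or malformed lines
        }
        want, rel := fields[0], fields[1]

        file, err := os.Open(filepath.Join(dir, rel))
        if err != nil {
            return err
        }
        h := md5.New()
        _, copyErr := io.Copy(h, file)
        file.Close()
        if copyErr != nil {
            return copyErr
        }
        if got := hex.EncodeToString(h.Sum(nil)); got != want {
            return fmt.Errorf("checksum mismatch for %s: got %s want %s", rel, got, want)
        }
    }
    return scanner.Err()
}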

监控与告警

Prometheus 指标

// storage_metrics.go
package metrics

import (
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

var (
    // 检查点指标
    CheckpointSaveTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "checkpoint_save_total",
            Help: "Total number of checkpoint saves",
        },
        []string{"job_id", "status"},
    )

    CheckpointSaveLatency = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "checkpoint_save_latency_seconds",
            Help:    "Checkpoint save latency in seconds",
            Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600},
        },
        []string{"job_id", "storage_tier"},
    )

    CheckpointSizeBytes = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "checkpoint_size_bytes",
            Help:    "Checkpoint size in bytes",
            Buckets: prometheus.ExponentialBuckets(1<<20, 2, 26), // 1MB to 32TB
        },
        []string{"job_id"},
    )

    CheckpointLoadTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "checkpoint_load_total",
            Help: "Total number of checkpoint loads",
        },
        []string{"job_id", "status"},
    )

    CheckpointLoadLatency = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "checkpoint_load_latency_seconds",
            Help:    "Checkpoint load latency in seconds",
            Buckets: []float64{1, 5, 10, 30, 60, 120, 300},
        },
        []string{"job_id", "storage_tier"},
    )

    // 存储指标
    StorageUsageBytes = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "model_storage_usage_bytes",
            Help: "Model storage usage in bytes",
        },
        []string{"storage_tier", "job_id"},
    )

    StorageBandwidthBytes = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "model_storage_bandwidth_bytes_total",
            Help: "Model storage bandwidth in bytes",
        },
        []string{"storage_tier", "operation"}, // operation: read, write
    )

    StorageOperationErrors = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "model_storage_errors_total",
            Help: "Total storage operation errors",
        },
        []string{"storage_tier", "operation", "error_type"},
    )

    // 模型注册指标
    ModelVersionsTotal = promauto.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "model_versions_total",
            Help: "Total number of model versions",
        },
        []string{"model_id", "stage"},
    )

    ModelConversionLatency = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "model_conversion_latency_seconds",
            Help:    "Model format conversion latency",
            Buckets: []float64{10, 30, 60, 120, 300, 600, 1800},
        },
        []string{"source_format", "target_format"},
    )

    // 数据迁移指标
    StorageMigrationTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "storage_migration_total",
            Help: "Total storage migrations",
        },
        []string{"source_tier", "target_tier", "status"},
    )

    StorageMigrationBytes = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "storage_migration_bytes_total",
            Help: "Total bytes migrated between storage tiers",
        },
        []string{"source_tier", "target_tier"},
    )
)

// RecordCheckpointSave 记录检查点保存
func RecordCheckpointSave(jobID string, tier string, size int64, latency float64, success bool) {
    status := "success"
    if !success {
        status = "failure"
    }

    CheckpointSaveTotal.WithLabelValues(jobID, status).Inc()
    CheckpointSaveLatency.WithLabelValues(jobID, tier).Observe(latency)
    CheckpointSizeBytes.WithLabelValues(jobID).Observe(float64(size))
    StorageBandwidthBytes.WithLabelValues(tier, "write").Add(float64(size))
}

// RecordCheckpointLoad 记录检查点加载
func RecordCheckpointLoad(jobID string, tier string, size int64, latency float64, success bool) {
    status := "success"
    if !success {
        status = "failure"
    }

    CheckpointLoadTotal.WithLabelValues(jobID, status).Inc()
    CheckpointLoadLatency.WithLabelValues(jobID, tier).Observe(latency)
    StorageBandwidthBytes.WithLabelValues(tier, "read").Add(float64(size))
}
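
A small sketch of how a checkpoint writer can use the helpers above. The saveCheckpoint callback is a placeholder for whatever actually serializes the state, and the error_type label value is an assumption.

// checkpoint_instrumentation.go (sketch, same package as the metrics above)
package metrics

import "time"

// instrumentedSave times an arbitrary checkpoint write, records size/latency
// metrics, and counts failures. saveCheckpoint returns the bytes written.
func instrumentedSave(jobID, tier string, saveCheckpoint func() (int64, error)) error {
    start := time.Now()
    size, err := saveCheckpoint()
    RecordCheckpointSave(jobID, tier, size, time.Since(start).Seconds(), err == nil)
    if err != nil {
        StorageOperationErrors.WithLabelValues(tier, "write", "checkpoint_save_failed").Inc()
    }
    return err
}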

Grafana Dashboard

{
  "dashboard": {
    "title": "Model Storage Dashboard",
    "panels": [
      {
        "title": "Checkpoint Save Performance",
        "type": "graph",
        "targets": [
          {
            "expr": "histogram_quantile(0.99, rate(checkpoint_save_latency_seconds_bucket[5m]))",
            "legendFormat": "p99 - {{storage_tier}}"
          },
          {
            "expr": "histogram_quantile(0.50, rate(checkpoint_save_latency_seconds_bucket[5m]))",
            "legendFormat": "p50 - {{storage_tier}}"
          }
        ]
      },
      {
        "title": "Storage Bandwidth",
        "type": "graph",
        "targets": [
          {
            "expr": "rate(model_storage_bandwidth_bytes_total[5m])",
            "legendFormat": "{{storage_tier}} - {{operation}}"
          }
        ],
        "yaxes": [{"format": "bytes"}]
      },
      {
        "title": "Storage Usage by Tier",
        "type": "piechart",
        "targets": [
          {
            "expr": "sum(model_storage_usage_bytes) by (storage_tier)",
            "legendFormat": "{{storage_tier}}"
          }
        ]
      },
      {
        "title": "Checkpoint Operations",
        "type": "stat",
        "targets": [
          {
            "expr": "sum(rate(checkpoint_save_total{status=\"success\"}[1h]))",
            "legendFormat": "Saves/hour"
          },
          {
            "expr": "sum(rate(checkpoint_load_total{status=\"success\"}[1h]))",
            "legendFormat": "Loads/hour"
          }
        ]
      },
      {
        "title": "Storage Errors",
        "type": "table",
        "targets": [
          {
            "expr": "sum(increase(model_storage_errors_total[24h])) by (storage_tier, operation, error_type)",
            "format": "table"
          }
        ]
      },
      {
        "title": "Model Versions by Stage",
        "type": "bargauge",
        "targets": [
          {
            "expr": "sum(model_versions_total) by (stage)",
            "legendFormat": "{{stage}}"
          }
        ]
      }
    ]
  }
}

最佳实践

存储架构设计原则

┌─────────────────────────────────────────────────────────────────┐
│                    存储架构最佳实践                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  1. 分层存储策略                                                 │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Hot 层:最新 2-3 个检查点,使用本地 NVMe 或 RDMA        │   │
│  │ • Warm 层:近期检查点,使用并行文件系统(Lustre/GPFS)    │   │
│  │ • Cold 层:里程碑和归档,使用对象存储(S3/OSS)           │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  2. 检查点优化                                                   │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 异步写入:不阻塞训练进程                                │   │
│  │ • 增量检查点:只保存变化的部分                            │   │
│  │ • 压缩存储:使用 LZ4/ZSTD 减少存储和带宽                  │   │
│  │ • 分片保存:大模型并行写入多个文件                        │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  3. 数据完整性                                                   │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 写入时计算校验和                                        │   │
│  │ • 读取时验证完整性                                        │   │
│  │ • 三副本存储关键数据                                      │   │
│  │ • 定期数据巡检                                            │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  4. 性能优化                                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 条带化配置:根据文件大小调整条带数和大小                │   │
│  │ • 并行 I/O:多线程/进程并发读写                           │   │
│  │ • 预取优化:训练时预加载下一批数据                        │   │
│  │ • 缓存策略:合理配置各级缓存大小                          │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
│  5. 容量规划                                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • 预估公式:存储 = 模型大小 × 2 × 保留数量 × 安全系数    │   │
│  │ • 监控阈值:80% 告警,90% 自动清理                        │   │
│  │ • 弹性扩展:支持在线扩容                                  │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
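
The capacity-planning rule of thumb in item 5 of the box above translates into a one-line helper. The ×2 (weights plus optimizer state) and the 1.5 safety factor used in the example are assumptions to tune per framework and precision.

// capacity_planning.go (sketch)
package planning

// estimateCheckpointStorage applies: storage = model size × 2 × retained checkpoints × safety factor.
func estimateCheckpointStorage(modelSizeBytes int64, retained int, safetyFactor float64) int64 {
    return int64(float64(modelSizeBytes) * 2 * float64(retained) * safetyFactor)
}

// Example: a 7B model in fp16 is roughly 14 GB, so keeping 5 checkpoints with a
// 1.5 safety factor needs about 14 GB × 2 × 5 × 1.5 ≈ 210 GB.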

小结

本章详细介绍了 AI 训练平台的模型存储与管理:

  1. 模型文件特征:分析了检查点、模型文件的大小特征和访问模式
  2. 分布式存储架构:讲解了并行文件系统和对象存储的集成方案
  3. 检查点管理:实现了完整的检查点生命周期管理和分布式检查点
  4. 模型版本管理:构建了模型注册中心和血缘追踪系统
  5. 格式转换:实现了多格式转换流水线
  6. Kubernetes 集成:展示了存储类配置和训练任务挂载

下一章我们将探讨 实验管理,讲解如何系统化地管理机器学习实验、追踪超参数和对比分析结果。
