模型存储与管理
概述
大规模 AI 训练过程中会产生海量的模型文件,包括训练检查点、最终模型、中间产物等。如何高效、可靠地存储和管理这些模型文件,是 AI 平台的核心能力之一。本章深入讲解模型存储架构、版本管理、检查点优化等关键技术。
模型文件特征分析
模型文件类型
┌─────────────────────────────────────────────────────────────────┐
│ 模型文件类型 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 检查点文件 │ │ 最终模型 │ │ 中间产物 │ │
│ │ Checkpoint │ │ Final Model │ │ Artifacts │ │
│ ├─────────────┤ ├─────────────┤ ├─────────────┤ │
│ │ 权重参数 │ │ 权重参数 │ │ 优化器状态 │ │
│ │ 优化器状态 │ │ 模型配置 │ │ 梯度累积 │ │
│ │ 训练状态 │ │ Tokenizer │ │ 激活值缓存 │ │
│ │ 随机种子 │ │ 推理配置 │ │ 通信缓冲 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ 大小特征: │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 模型规模 │ 检查点大小 │ 存储需求/天 │ │
│ ├──────────────────────────────────────────────────────────┤ │
│ │ 7B (LLaMA) │ 14-28 GB │ 200-500 GB │ │
│ │ 13B │ 26-52 GB │ 400-800 GB │ │
│ │ 65B │ 130-260 GB │ 2-4 TB │ │
│ │ 175B (GPT-3) │ 350-700 GB │ 5-10 TB │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
访问模式分析
// model_access_pattern.go
package storage
import (
"time"
)
// AccessPattern 模型文件访问模式
type AccessPattern struct {
// 检查点写入模式
CheckpointWrite CheckpointWritePattern
// 检查点读取模式(恢复训练)
CheckpointRead CheckpointReadPattern
// 模型导出模式
ModelExport ModelExportPattern
// 模型加载模式(推理)
ModelLoad ModelLoadPattern
}
// CheckpointWritePattern 检查点写入特征
type CheckpointWritePattern struct {
// 写入频率:每 N 步保存一次
SaveInterval int
// 典型间隔:1000-10000 步
TypicalSteps int
// 写入大小:与模型规模成正比
WriteSize int64
// 并发写入:分布式训练时所有 rank 同时写
ConcurrentWriters int
// 写入带宽需求:高峰期需要高吞吐
RequiredBandwidth int64 // bytes/s
// 写入模式:顺序写为主
Sequential bool
}
// CheckpointReadPattern 检查点读取特征
type CheckpointReadPattern struct {
// 读取频率:训练恢复时
Frequency string // "rare" - 仅故障恢复或继续训练
// 读取大小:完整检查点
ReadSize int64
// 并发读取:所有 rank 同时读
ConcurrentReaders int
// 读取带宽需求
RequiredBandwidth int64
// 读取模式:顺序读
Sequential bool
}
// ModelExportPattern 模型导出特征
type ModelExportPattern struct {
// 导出格式
Formats []string // ["pytorch", "safetensors", "onnx", "tensorrt"]
// 导出频率:训练结束或里程碑
Frequency string // "milestone"
// 处理类型:可能需要格式转换
RequiresConversion bool
}
// ModelLoadPattern 模型加载特征(推理场景)
type ModelLoadPattern struct {
// 加载频率:服务启动、扩容
Frequency string // "on_demand"
// 加载速度要求:影响服务启动时间
LatencyRequirement time.Duration
// 并发加载:多实例同时加载
ConcurrentLoaders int
// 部分加载:分层加载支持
SupportsPartialLoad bool
}
// CalculateStorageRequirements 计算存储需求
func CalculateStorageRequirements(
modelParams int64, // 模型参数量
trainingDays int,
checkpointInterval int, // 小时
keepCheckpoints int, // 保留数量
) StorageRequirements {
	// 每个参数占用字节:FP16 权重 2 字节、FP32 为 4 字节,混合精度训练(含 FP32 主副本)约 6 字节
	bytesPerParam := int64(6)
	// 检查点大小 ≈ 参数量 × 每参数字节 × 2(近似计入 Adam 优化器的动量/方差状态),为粗略上界估算
	checkpointSize := modelParams * bytesPerParam * 2
// 每天检查点数
checkpointsPerDay := 24 / checkpointInterval
// 总存储需求
totalCheckpointStorage := checkpointSize * int64(keepCheckpoints)
// 峰值写入带宽(假设检查点写入时间 < 10分钟)
peakWriteBandwidth := checkpointSize / (10 * 60) // bytes/s
return StorageRequirements{
CheckpointSize: checkpointSize,
TotalStorage: totalCheckpointStorage,
PeakWriteBandwidth: peakWriteBandwidth,
CheckpointsPerDay: checkpointsPerDay,
RecommendedReplicas: 3, // 三副本保证可靠性
}
}
type StorageRequirements struct {
CheckpointSize int64
TotalStorage int64
PeakWriteBandwidth int64
CheckpointsPerDay int
RecommendedReplicas int
}
分布式存储架构
存储系统选型
┌─────────────────────────────────────────────────────────────────┐
│ AI 训练存储架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 存储层次架构 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────┐ │
│ │ 本地 NVMe │ ← 最快,容量有限 │
│ │ Cache │ 用于:激活值缓存、临时文件 │
│ └──────┬──────┘ │
│ │ │
│ ┌──────▼──────┐ │
│ │ 共享 NVMe │ ← 高性能,跨节点共享 │
│ │ over Fabric │ 用于:检查点缓存、模型加载 │
│ └──────┬──────┘ │
│ │ │
│ ┌──────▼──────┐ │
│ │ 并行文件 │ ← 高吞吐,大容量 │
│ │ 系统 │ 用于:检查点持久化、数据集 │
│ │ (Lustre/ │ │
│ │ GPFS) │ │
│ └──────┬──────┘ │
│ │ │
│ ┌──────▼──────┐ │
│ │ 对象存储 │ ← 最大容量,成本最低 │
│ │ (S3/OSS) │ 用于:模型归档、长期保存 │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
并行文件系统集成
// parallel_fs.go
package storage
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"syscall"
)
// ParallelFileSystem 并行文件系统接口
type ParallelFileSystem interface {
// 并行写入
ParallelWrite(ctx context.Context, path string, data []byte, opts WriteOptions) error
// 并行读取
ParallelRead(ctx context.Context, path string, opts ReadOptions) ([]byte, error)
// 条带化配置
SetStriping(path string, stripeCount int, stripeSize int64) error
// 获取文件布局
GetLayout(path string) (*FileLayout, error)
}
// LustreFS Lustre 文件系统实现
type LustreFS struct {
mountPoint string
defaultStripeCount int
defaultStripeSize int64
maxStripeCount int
}
// WriteOptions 写入选项
type WriteOptions struct {
StripeCount int // 条带数量
StripeSize int64 // 条带大小
Parallel bool // 是否并行写
DirectIO bool // 是否使用 Direct I/O
Compression string // 压缩算法
}
// ReadOptions 读取选项
type ReadOptions struct {
Parallel bool
DirectIO bool
ReadAhead int64 // 预读大小
}
// FileLayout 文件布局信息
type FileLayout struct {
StripeCount int
StripeSize int64
StripeOffset int64
OSTs []int // Object Storage Targets
}
func NewLustreFS(mountPoint string) *LustreFS {
return &LustreFS{
mountPoint: mountPoint,
defaultStripeCount: 4,
defaultStripeSize: 1 << 20, // 1MB
maxStripeCount: 128,
}
}
// SetStriping 设置文件条带化
func (l *LustreFS) SetStriping(path string, stripeCount int, stripeSize int64) error {
fullPath := filepath.Join(l.mountPoint, path)
// 确保目录存在
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("create directory: %w", err)
}
// 使用 lfs setstripe 命令设置条带
// 实际实现中使用 Lustre API
return l.setStripeViaAPI(fullPath, stripeCount, stripeSize)
}
func (l *LustreFS) setStripeViaAPI(path string, stripeCount int, stripeSize int64) error {
// Lustre ioctl 调用设置条带
// LOV_USER_MAGIC_V3 结构
type lovUserMd struct {
Magic uint32
Pattern uint32
ObjectID uint64
ObjectSeq uint64
StripeSize uint32
StripeCount uint16
StripeOffset uint16
}
// 这里简化处理,实际需要调用 Lustre API
return nil
}
// ParallelWrite 并行写入大文件
func (l *LustreFS) ParallelWrite(ctx context.Context, path string, data []byte, opts WriteOptions) error {
fullPath := filepath.Join(l.mountPoint, path)
// 设置条带化
if opts.StripeCount > 0 {
if err := l.SetStriping(path, opts.StripeCount, opts.StripeSize); err != nil {
return err
}
}
// 打开文件
flags := os.O_CREATE | os.O_WRONLY | os.O_TRUNC
if opts.DirectIO {
flags |= syscall.O_DIRECT
}
file, err := os.OpenFile(fullPath, flags, 0644)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer file.Close()
if !opts.Parallel {
// 顺序写入
_, err = file.Write(data)
return err
}
// 并行写入
return l.parallelWriteChunks(ctx, file, data, opts)
}
func (l *LustreFS) parallelWriteChunks(ctx context.Context, file *os.File, data []byte, opts WriteOptions) error {
chunkSize := opts.StripeSize
if chunkSize == 0 {
chunkSize = l.defaultStripeSize
}
numChunks := (int64(len(data)) + chunkSize - 1) / chunkSize
var wg sync.WaitGroup
errCh := make(chan error, numChunks)
for i := int64(0); i < numChunks; i++ {
start := i * chunkSize
end := start + chunkSize
if end > int64(len(data)) {
end = int64(len(data))
}
wg.Add(1)
go func(offset int64, chunk []byte) {
defer wg.Done()
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
default:
}
_, err := file.WriteAt(chunk, offset)
if err != nil {
errCh <- fmt.Errorf("write at %d: %w", offset, err)
}
}(start, data[start:end])
}
wg.Wait()
close(errCh)
for err := range errCh {
if err != nil {
return err
}
}
return file.Sync()
}
// ParallelRead 并行读取大文件
func (l *LustreFS) ParallelRead(ctx context.Context, path string, opts ReadOptions) ([]byte, error) {
fullPath := filepath.Join(l.mountPoint, path)
// 获取文件信息
info, err := os.Stat(fullPath)
if err != nil {
return nil, fmt.Errorf("stat file: %w", err)
}
// 打开文件
flags := os.O_RDONLY
if opts.DirectIO {
flags |= syscall.O_DIRECT
}
file, err := os.OpenFile(fullPath, flags, 0)
if err != nil {
return nil, fmt.Errorf("open file: %w", err)
}
defer file.Close()
// 获取文件布局以优化读取
layout, _ := l.GetLayout(path)
data := make([]byte, info.Size())
if !opts.Parallel {
_, err = io.ReadFull(file, data)
return data, err
}
// 根据条带布局并行读取
return l.parallelReadChunks(ctx, file, data, layout, opts)
}
func (l *LustreFS) parallelReadChunks(
ctx context.Context,
file *os.File,
data []byte,
layout *FileLayout,
opts ReadOptions,
) ([]byte, error) {
chunkSize := int64(1 << 20) // 1MB default
if layout != nil && layout.StripeSize > 0 {
chunkSize = layout.StripeSize
}
numChunks := (int64(len(data)) + chunkSize - 1) / chunkSize
var wg sync.WaitGroup
errCh := make(chan error, numChunks)
// 控制并发度
sem := make(chan struct{}, 16)
for i := int64(0); i < numChunks; i++ {
start := i * chunkSize
end := start + chunkSize
if end > int64(len(data)) {
end = int64(len(data))
}
wg.Add(1)
go func(offset, length int64) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
default:
}
_, err := file.ReadAt(data[offset:offset+length], offset)
if err != nil && err != io.EOF {
errCh <- fmt.Errorf("read at %d: %w", offset, err)
}
}(start, end-start)
}
wg.Wait()
close(errCh)
for err := range errCh {
if err != nil {
return nil, err
}
}
return data, nil
}
// GetLayout 获取文件条带布局
func (l *LustreFS) GetLayout(path string) (*FileLayout, error) {
// 通过 Lustre API 获取文件布局
// 简化实现
return &FileLayout{
StripeCount: l.defaultStripeCount,
StripeSize: l.defaultStripeSize,
StripeOffset: 0,
OSTs: []int{0, 1, 2, 3},
}, nil
}
// OptimalStripeConfig 计算最优条带配置
func (l *LustreFS) OptimalStripeConfig(fileSize int64, accessPattern string) (stripeCount int, stripeSize int64) {
switch accessPattern {
case "checkpoint_write":
// 检查点写入:高并发,大文件
// 使用更多条带以提高吞吐
if fileSize > 10<<30 { // > 10GB
return min(l.maxStripeCount, 64), 4 << 20 // 64 条带,4MB
} else if fileSize > 1<<30 { // > 1GB
return 32, 2 << 20 // 32 条带,2MB
}
return 16, 1 << 20 // 16 条带,1MB
case "checkpoint_read":
// 检查点读取:顺序读为主
if fileSize > 10<<30 {
return 32, 4 << 20
}
return 16, 2 << 20
case "model_serve":
// 模型加载:追求最快加载
return min(l.maxStripeCount, 128), 1 << 20
default:
return l.defaultStripeCount, l.defaultStripeSize
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
对象存储集成
// object_storage.go
package storage
import (
"bytes"
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"io"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// ObjectStorage 对象存储接口
type ObjectStorage interface {
// 上传对象
Upload(ctx context.Context, bucket, key string, data []byte) error
// 分片上传大对象
MultipartUpload(ctx context.Context, bucket, key string, reader io.Reader, size int64) error
// 下载对象
Download(ctx context.Context, bucket, key string) ([]byte, error)
// 并行下载大对象
ParallelDownload(ctx context.Context, bucket, key string) ([]byte, error)
// 列出对象
List(ctx context.Context, bucket, prefix string) ([]ObjectInfo, error)
// 删除对象
Delete(ctx context.Context, bucket, key string) error
// 复制对象
Copy(ctx context.Context, srcBucket, srcKey, dstBucket, dstKey string) error
}
// ObjectInfo 对象信息
type ObjectInfo struct {
Key string
Size int64
LastModified time.Time
ETag string
StorageClass string
}
// S3Storage S3 对象存储实现
type S3Storage struct {
client *s3.Client
partSize int64
concurrency int
}
func NewS3Storage(client *s3.Client) *S3Storage {
return &S3Storage{
client: client,
partSize: 100 << 20, // 100MB per part
concurrency: 16,
}
}
// MultipartUpload 分片上传大文件
func (s *S3Storage) MultipartUpload(
ctx context.Context,
bucket, key string,
reader io.Reader,
size int64,
) error {
// 创建分片上传
createResp, err := s.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
StorageClass: types.StorageClassStandardIa, // 检查点使用低频存储
})
if err != nil {
return fmt.Errorf("create multipart upload: %w", err)
}
uploadID := *createResp.UploadId
// 计算分片
numParts := (size + s.partSize - 1) / s.partSize
parts := make([]types.CompletedPart, numParts)
var wg sync.WaitGroup
errCh := make(chan error, numParts)
sem := make(chan struct{}, s.concurrency)
var mu sync.Mutex
partNum := int32(1)
for {
// 读取一个分片
buf := make([]byte, s.partSize)
n, readErr := io.ReadFull(reader, buf)
if n == 0 {
break
}
currentPart := partNum
partNum++
data := buf[:n]
wg.Add(1)
go func(pn int32, d []byte) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
			// 计算 MD5(Content-MD5 头要求 base64 编码的摘要)
			hash := md5.Sum(d)
			md5Str := base64.StdEncoding.EncodeToString(hash[:])
// 上传分片
uploadResp, err := s.client.UploadPart(ctx, &s3.UploadPartInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: aws.String(uploadID),
PartNumber: aws.Int32(pn),
Body: bytes.NewReader(d),
ContentMD5: aws.String(md5Str),
})
if err != nil {
errCh <- fmt.Errorf("upload part %d: %w", pn, err)
return
}
mu.Lock()
parts[pn-1] = types.CompletedPart{
ETag: uploadResp.ETag,
PartNumber: aws.Int32(pn),
}
mu.Unlock()
}(currentPart, data)
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
break
}
if readErr != nil {
// 取消上传
s.abortMultipartUpload(ctx, bucket, key, uploadID)
return fmt.Errorf("read data: %w", readErr)
}
}
wg.Wait()
close(errCh)
// 检查错误
for err := range errCh {
if err != nil {
s.abortMultipartUpload(ctx, bucket, key, uploadID)
return err
}
}
// 完成上传
_, err = s.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: aws.String(uploadID),
MultipartUpload: &types.CompletedMultipartUpload{
Parts: parts[:partNum-1],
},
})
if err != nil {
return fmt.Errorf("complete multipart upload: %w", err)
}
return nil
}
func (s *S3Storage) abortMultipartUpload(ctx context.Context, bucket, key, uploadID string) {
s.client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: aws.String(uploadID),
})
}
// ParallelDownload 并行下载大文件
func (s *S3Storage) ParallelDownload(ctx context.Context, bucket, key string) ([]byte, error) {
// 获取对象大小
headResp, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
return nil, fmt.Errorf("head object: %w", err)
}
size := *headResp.ContentLength
if size <= s.partSize {
// 小文件直接下载
return s.Download(ctx, bucket, key)
}
// 并行下载
data := make([]byte, size)
numParts := (size + s.partSize - 1) / s.partSize
var wg sync.WaitGroup
errCh := make(chan error, numParts)
sem := make(chan struct{}, s.concurrency)
for i := int64(0); i < numParts; i++ {
start := i * s.partSize
end := start + s.partSize - 1
if end >= size {
end = size - 1
}
wg.Add(1)
go func(s3Start, s3End int64) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
rangeStr := fmt.Sprintf("bytes=%d-%d", s3Start, s3End)
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Range: aws.String(rangeStr),
})
if err != nil {
errCh <- fmt.Errorf("get range %s: %w", rangeStr, err)
return
}
defer resp.Body.Close()
_, err = io.ReadFull(resp.Body, data[s3Start:s3End+1])
if err != nil {
errCh <- fmt.Errorf("read range %s: %w", rangeStr, err)
}
}(start, end)
}
wg.Wait()
close(errCh)
for err := range errCh {
if err != nil {
return nil, err
}
}
return data, nil
}
// Download 下载对象
func (s *S3Storage) Download(ctx context.Context, bucket, key string) ([]byte, error) {
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
return nil, fmt.Errorf("get object: %w", err)
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// Upload 上传对象
func (s *S3Storage) Upload(ctx context.Context, bucket, key string, data []byte) error {
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: bytes.NewReader(data),
})
return err
}
// List 列出对象
func (s *S3Storage) List(ctx context.Context, bucket, prefix string) ([]ObjectInfo, error) {
var objects []ObjectInfo
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
})
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("list objects: %w", err)
}
for _, obj := range page.Contents {
objects = append(objects, ObjectInfo{
Key: *obj.Key,
Size: *obj.Size,
LastModified: *obj.LastModified,
ETag: *obj.ETag,
StorageClass: string(obj.StorageClass),
})
}
}
return objects, nil
}
// Delete 删除对象
func (s *S3Storage) Delete(ctx context.Context, bucket, key string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
return err
}
// Copy 复制对象
func (s *S3Storage) Copy(ctx context.Context, srcBucket, srcKey, dstBucket, dstKey string) error {
_, err := s.client.CopyObject(ctx, &s3.CopyObjectInput{
Bucket: aws.String(dstBucket),
Key: aws.String(dstKey),
CopySource: aws.String(fmt.Sprintf("%s/%s", srcBucket, srcKey)),
})
return err
}
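将本地检查点归档到对象存储时,可以直接把文件句柄交给 MultipartUpload,按 partSize 切分并发上传。下面是一个使用示意(bucket 与对象键均为假设值):
// s3_usage_example.go(示意代码)
package storage

import (
	"context"
	"os"
)

func ArchiveCheckpointToS3(ctx context.Context, store *S3Storage, localPath string) error {
	f, err := os.Open(localPath)
	if err != nil {
		return err
	}
	defer f.Close()
	info, err := f.Stat()
	if err != nil {
		return err
	}
	// 大文件按 partSize 分片并发上传,失败时内部会调用 AbortMultipartUpload 清理
	return store.MultipartUpload(ctx, "model-archive", "llm-demo/step-1000/model.pt", f, info.Size())
}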
检查点管理系统
检查点生命周期
┌─────────────────────────────────────────────────────────────────┐
│ 检查点生命周期管理 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 创建阶段: │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ 触发 │───▶│ 序列化 │───▶│ 压缩 │───▶│ 上传 │ │
│ │ 保存 │ │ 状态 │ │ (可选) │ │ 存储 │ │
│ └────────┘ └────────┘ └────────┘ └────────┘ │
│ │
│ 存储阶段: │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Hot Storage Warm Storage Cold Storage │ │
│ │ (最新 3 个) (近期 10 个) (里程碑) │ │
│ │ ┌───┐ ┌───┐ ┌───┐ │ │
│ │ │NVMe│ ───▶ │PFS│ ───▶ │S3 │ │ │
│ │ └───┘ └───┘ └───┘ │ │
│ │ │ │
│ └────────────────────────────────────────────────────────┘ │
│ │
│ 恢复阶段: │
│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ 定位 │───▶│ 下载 │───▶│ 解压 │───▶│ 反序列 │ │
│ │ 检查点 │ │ 数据 │ │ (可选) │ │ 化 │ │
│ └────────┘ └────────┘ └────────┘ └────────┘ │
│ │
│ 清理阶段: │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ 策略 │───▶│ 标记 │───▶│ 删除 │ │
│ │ 评估 │ │ 过期 │ │ 清理 │ │
│ └────────┘ └────────┘ └────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
检查点管理器实现
// checkpoint_manager.go
package checkpoint
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"path/filepath"
	"sort"
	"sync"
	"time"
)
// CheckpointManager 检查点管理器
type CheckpointManager struct {
// 存储后端
hotStorage Storage // 本地/NVMe
warmStorage Storage // 并行文件系统
coldStorage Storage // 对象存储
// 元数据存储
metaStore MetadataStore
// 配置
config CheckpointConfig
// 运行时状态
mu sync.RWMutex
activeUploads map[string]*UploadTask
// 指标
metrics *CheckpointMetrics
}
// CheckpointConfig 检查点配置
type CheckpointConfig struct {
// 保存策略
SaveInterval int // 每 N 步保存
SaveIntervalMinutes int // 或每 N 分钟保存
// 保留策略
HotRetention int // Hot 存储保留数量
WarmRetention int // Warm 存储保留数量
ColdRetention int // Cold 存储保留数量(里程碑)
// 存储路径
HotPath string
WarmPath string
ColdBucket string
// 性能配置
AsyncUpload bool
CompressionAlgo string // "none", "lz4", "zstd"
ParallelIO bool
// 数据完整性
VerifyChecksum bool
// 生命周期
MilestoneSteps []int // 标记为里程碑的步数
}
// Checkpoint 检查点元数据
type Checkpoint struct {
ID string `json:"id"`
JobID string `json:"job_id"`
Step int64 `json:"step"`
Epoch int `json:"epoch"`
CreatedAt time.Time `json:"created_at"`
// 存储位置
StorageTier StorageTier `json:"storage_tier"`
Path string `json:"path"`
// 文件信息
Files []CheckpointFile `json:"files"`
TotalSize int64 `json:"total_size"`
Checksum string `json:"checksum"`
// 训练状态
TrainingState TrainingState `json:"training_state"`
// 标签
IsMilestone bool `json:"is_milestone"`
Tags map[string]string `json:"tags"`
}
// CheckpointFile 检查点文件
type CheckpointFile struct {
Name string `json:"name"`
Size int64 `json:"size"`
Checksum string `json:"checksum"`
Type string `json:"type"` // "model", "optimizer", "scheduler", "rng"
}
// TrainingState 训练状态
type TrainingState struct {
GlobalStep int64 `json:"global_step"`
Epoch int `json:"epoch"`
LearningRate float64 `json:"learning_rate"`
Loss float64 `json:"loss"`
Metrics map[string]float64 `json:"metrics"`
RandomState []byte `json:"random_state"`
}
type StorageTier string
const (
StorageTierHot StorageTier = "hot"
StorageTierWarm StorageTier = "warm"
StorageTierCold StorageTier = "cold"
)
// NewCheckpointManager 创建检查点管理器
func NewCheckpointManager(
hotStorage, warmStorage, coldStorage Storage,
metaStore MetadataStore,
config CheckpointConfig,
) *CheckpointManager {
return &CheckpointManager{
hotStorage: hotStorage,
warmStorage: warmStorage,
coldStorage: coldStorage,
metaStore: metaStore,
config: config,
activeUploads: make(map[string]*UploadTask),
metrics: NewCheckpointMetrics(),
}
}
// Save 保存检查点
func (m *CheckpointManager) Save(
ctx context.Context,
jobID string,
step int64,
epoch int,
state *TrainingState,
data map[string][]byte, // 文件名 -> 数据
) (*Checkpoint, error) {
startTime := time.Now()
// 创建检查点元数据
checkpoint := &Checkpoint{
ID: fmt.Sprintf("%s-step-%d", jobID, step),
JobID: jobID,
Step: step,
Epoch: epoch,
CreatedAt: time.Now(),
StorageTier: StorageTierHot,
TrainingState: *state,
IsMilestone: m.isMilestone(step),
Tags: make(map[string]string),
}
// 计算检查点路径
checkpoint.Path = m.checkpointPath(jobID, step)
// 处理每个文件
var files []CheckpointFile
var totalSize int64
for name, content := range data {
// 可选压缩
processedData := content
if m.config.CompressionAlgo != "none" {
var err error
processedData, err = m.compress(content)
if err != nil {
return nil, fmt.Errorf("compress %s: %w", name, err)
}
}
// 计算校验和
checksum := m.calculateChecksum(processedData)
// 保存到 Hot 存储
filePath := filepath.Join(checkpoint.Path, name)
if err := m.hotStorage.Write(ctx, filePath, processedData); err != nil {
return nil, fmt.Errorf("write %s: %w", name, err)
}
files = append(files, CheckpointFile{
Name: name,
Size: int64(len(processedData)),
Checksum: checksum,
Type: m.inferFileType(name),
})
totalSize += int64(len(processedData))
}
checkpoint.Files = files
checkpoint.TotalSize = totalSize
checkpoint.Checksum = m.calculateCheckpointChecksum(files)
// 保存元数据
if err := m.metaStore.SaveCheckpoint(ctx, checkpoint); err != nil {
return nil, fmt.Errorf("save metadata: %w", err)
}
// 记录指标
m.metrics.RecordSave(totalSize, time.Since(startTime))
// 异步迁移到 Warm 存储
if m.config.AsyncUpload {
go m.migrateToWarm(context.Background(), checkpoint)
}
// 触发清理
go m.cleanup(context.Background(), jobID)
return checkpoint, nil
}
// Load 加载检查点
func (m *CheckpointManager) Load(
ctx context.Context,
checkpointID string,
) (map[string][]byte, *TrainingState, error) {
startTime := time.Now()
// 获取元数据
checkpoint, err := m.metaStore.GetCheckpoint(ctx, checkpointID)
if err != nil {
return nil, nil, fmt.Errorf("get checkpoint metadata: %w", err)
}
// 选择存储后端
storage := m.getStorage(checkpoint.StorageTier)
// 加载所有文件
data := make(map[string][]byte)
for _, file := range checkpoint.Files {
filePath := filepath.Join(checkpoint.Path, file.Name)
content, err := storage.Read(ctx, filePath)
if err != nil {
return nil, nil, fmt.Errorf("read %s: %w", file.Name, err)
}
// 验证校验和
if m.config.VerifyChecksum {
if checksum := m.calculateChecksum(content); checksum != file.Checksum {
return nil, nil, fmt.Errorf("checksum mismatch for %s", file.Name)
}
}
// 解压
if m.config.CompressionAlgo != "none" {
content, err = m.decompress(content)
if err != nil {
return nil, nil, fmt.Errorf("decompress %s: %w", file.Name, err)
}
}
data[file.Name] = content
}
// 记录指标
m.metrics.RecordLoad(checkpoint.TotalSize, time.Since(startTime))
return data, &checkpoint.TrainingState, nil
}
// LoadLatest 加载最新检查点
func (m *CheckpointManager) LoadLatest(
ctx context.Context,
jobID string,
) (map[string][]byte, *TrainingState, error) {
// 获取最新检查点
checkpoints, err := m.metaStore.ListCheckpoints(ctx, jobID)
if err != nil {
return nil, nil, fmt.Errorf("list checkpoints: %w", err)
}
if len(checkpoints) == 0 {
return nil, nil, fmt.Errorf("no checkpoints found for job %s", jobID)
}
// 按步数排序,取最新
sort.Slice(checkpoints, func(i, j int) bool {
return checkpoints[i].Step > checkpoints[j].Step
})
return m.Load(ctx, checkpoints[0].ID)
}
// migrateToWarm 迁移到 Warm 存储
func (m *CheckpointManager) migrateToWarm(ctx context.Context, checkpoint *Checkpoint) {
// 检查是否已在迁移
m.mu.Lock()
if _, exists := m.activeUploads[checkpoint.ID]; exists {
m.mu.Unlock()
return
}
task := &UploadTask{
CheckpointID: checkpoint.ID,
StartTime: time.Now(),
Status: "in_progress",
}
m.activeUploads[checkpoint.ID] = task
m.mu.Unlock()
defer func() {
m.mu.Lock()
delete(m.activeUploads, checkpoint.ID)
m.mu.Unlock()
}()
// 复制文件到 Warm 存储
for _, file := range checkpoint.Files {
srcPath := filepath.Join(checkpoint.Path, file.Name)
dstPath := filepath.Join(m.config.WarmPath, checkpoint.JobID, checkpoint.ID, file.Name)
data, err := m.hotStorage.Read(ctx, srcPath)
if err != nil {
m.metrics.RecordMigrationError("hot_to_warm")
return
}
if err := m.warmStorage.Write(ctx, dstPath, data); err != nil {
m.metrics.RecordMigrationError("hot_to_warm")
return
}
}
// 更新元数据
checkpoint.StorageTier = StorageTierWarm
checkpoint.Path = filepath.Join(m.config.WarmPath, checkpoint.JobID, checkpoint.ID)
m.metaStore.SaveCheckpoint(ctx, checkpoint)
// 如果是里程碑,还要迁移到 Cold 存储
if checkpoint.IsMilestone {
m.migrateToCold(ctx, checkpoint)
}
}
// migrateToCold 迁移到 Cold 存储
func (m *CheckpointManager) migrateToCold(ctx context.Context, checkpoint *Checkpoint) {
// 上传到对象存储
for _, file := range checkpoint.Files {
srcPath := filepath.Join(checkpoint.Path, file.Name)
dstKey := fmt.Sprintf("%s/%s/%s", checkpoint.JobID, checkpoint.ID, file.Name)
data, err := m.warmStorage.Read(ctx, srcPath)
if err != nil {
m.metrics.RecordMigrationError("warm_to_cold")
return
}
if err := m.coldStorage.Write(ctx, dstKey, data); err != nil {
m.metrics.RecordMigrationError("warm_to_cold")
return
}
}
// 更新元数据,记录 Cold 存储位置
checkpoint.Tags["cold_path"] = fmt.Sprintf("s3://%s/%s/%s",
m.config.ColdBucket, checkpoint.JobID, checkpoint.ID)
m.metaStore.SaveCheckpoint(ctx, checkpoint)
}
// cleanup 清理过期检查点
func (m *CheckpointManager) cleanup(ctx context.Context, jobID string) {
checkpoints, err := m.metaStore.ListCheckpoints(ctx, jobID)
if err != nil {
return
}
// 按存储层分组
hotCheckpoints := filterByTier(checkpoints, StorageTierHot)
warmCheckpoints := filterByTier(checkpoints, StorageTierWarm)
// 按步数排序(降序)
sort.Slice(hotCheckpoints, func(i, j int) bool {
return hotCheckpoints[i].Step > hotCheckpoints[j].Step
})
sort.Slice(warmCheckpoints, func(i, j int) bool {
return warmCheckpoints[i].Step > warmCheckpoints[j].Step
})
// 清理 Hot 存储中超出保留数量的检查点
if len(hotCheckpoints) > m.config.HotRetention {
for _, cp := range hotCheckpoints[m.config.HotRetention:] {
m.deleteCheckpoint(ctx, cp, StorageTierHot)
}
}
// 清理 Warm 存储中超出保留数量的非里程碑检查点
nonMilestoneWarm := filterNonMilestone(warmCheckpoints)
if len(nonMilestoneWarm) > m.config.WarmRetention {
for _, cp := range nonMilestoneWarm[m.config.WarmRetention:] {
m.deleteCheckpoint(ctx, cp, StorageTierWarm)
}
}
}
func (m *CheckpointManager) deleteCheckpoint(ctx context.Context, cp *Checkpoint, tier StorageTier) {
storage := m.getStorage(tier)
for _, file := range cp.Files {
filePath := filepath.Join(cp.Path, file.Name)
storage.Delete(ctx, filePath)
}
// 如果所有存储层都删除了,才删除元数据
// 这里简化处理,实际需要跟踪每个层的删除状态
}
func (m *CheckpointManager) getStorage(tier StorageTier) Storage {
switch tier {
case StorageTierHot:
return m.hotStorage
case StorageTierWarm:
return m.warmStorage
case StorageTierCold:
return m.coldStorage
default:
return m.hotStorage
}
}
func (m *CheckpointManager) isMilestone(step int64) bool {
for _, ms := range m.config.MilestoneSteps {
if int64(ms) == step {
return true
}
}
// 或者每 N 步自动标记
return step%10000 == 0
}
func (m *CheckpointManager) checkpointPath(jobID string, step int64) string {
return filepath.Join(m.config.HotPath, jobID, fmt.Sprintf("step-%d", step))
}
func (m *CheckpointManager) compress(data []byte) ([]byte, error) {
	// 根据 config.CompressionAlgo 选择压缩实现(lz4/zstd),此处保留直通实现
	return data, nil
}
func (m *CheckpointManager) decompress(data []byte) ([]byte, error) {
	// 与 compress 对应的解压实现,此处保留直通实现
	return data, nil
}
func (m *CheckpointManager) calculateChecksum(data []byte) string {
	// 使用 SHA-256 计算单个文件的校验和
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}
func (m *CheckpointManager) calculateCheckpointChecksum(files []CheckpointFile) string {
	// 将各文件名与校验和依次喂入哈希,作为整个检查点的校验和
	h := sha256.New()
	for _, f := range files {
		h.Write([]byte(f.Name))
		h.Write([]byte(f.Checksum))
	}
	return hex.EncodeToString(h.Sum(nil))
}
func (m *CheckpointManager) inferFileType(name string) string {
switch {
case name == "model.pt" || name == "pytorch_model.bin":
return "model"
case name == "optimizer.pt":
return "optimizer"
case name == "scheduler.pt":
return "scheduler"
case name == "rng_state.pt":
return "rng"
default:
return "other"
}
}
type UploadTask struct {
CheckpointID string
StartTime time.Time
Status string
}
func filterByTier(checkpoints []*Checkpoint, tier StorageTier) []*Checkpoint {
var result []*Checkpoint
for _, cp := range checkpoints {
if cp.StorageTier == tier {
result = append(result, cp)
}
}
return result
}
func filterNonMilestone(checkpoints []*Checkpoint) []*Checkpoint {
var result []*Checkpoint
for _, cp := range checkpoints {
if !cp.IsMilestone {
result = append(result, cp)
}
}
return result
}
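下面用一个简化示例演示 CheckpointManager 的典型接法:构造三级存储后端、配置保留策略,然后在训练循环中按步保存(存储后端实现、路径与数值均为假设):
// checkpoint_usage_example.go(示意代码)
package checkpoint

import "context"

func SaveExample(ctx context.Context, hot, warm, cold Storage, meta MetadataStore) error {
	mgr := NewCheckpointManager(hot, warm, cold, meta, CheckpointConfig{
		SaveInterval:    1000,
		HotRetention:    3,
		WarmRetention:   10,
		HotPath:         "/nvme/checkpoints",
		WarmPath:        "/lustre/checkpoints",
		ColdBucket:      "model-archive",
		AsyncUpload:     true,
		CompressionAlgo: "none",
		VerifyChecksum:  true,
	})
	state := &TrainingState{GlobalStep: 1000, Epoch: 1, LearningRate: 3e-4, Loss: 2.1}
	// data 中的内容由训练框架序列化得到,这里用占位字节示意
	_, err := mgr.Save(ctx, "llm-demo", 1000, 1, state, map[string][]byte{
		"model.pt":     {/* 序列化后的模型权重 */},
		"optimizer.pt": {/* 序列化后的优化器状态 */},
	})
	return err
}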
分布式检查点
// distributed_checkpoint.go
package checkpoint
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"
)
// DistributedCheckpointManager 分布式检查点管理器
type DistributedCheckpointManager struct {
// 本地检查点管理器
localManager *CheckpointManager
// 分布式协调
coordinator Coordinator
// 当前进程信息
rank int
worldSize int
// 配置
config DistributedCheckpointConfig
}
// DistributedCheckpointConfig 分布式配置
type DistributedCheckpointConfig struct {
// 分片策略
ShardingStrategy string // "full", "sharded", "fsdp"
// 聚合配置
AggregateOnSave bool // 保存时聚合到 rank 0
AggregateOnLoad bool // 加载时从聚合文件分发
// 并行 IO
ParallelRanks int // 并行保存的 rank 数
// 超时
BarrierTimeout time.Duration
}
// DistributedCheckpoint 分布式检查点
type DistributedCheckpoint struct {
*Checkpoint
// 分片信息
Shards []ShardInfo `json:"shards"`
// 分布式状态
WorldSize int `json:"world_size"`
}
// ShardInfo 分片信息
type ShardInfo struct {
Rank int `json:"rank"`
Files []string `json:"files"`
TotalSize int64 `json:"total_size"`
NodeID string `json:"node_id"`
}
// Coordinator 分布式协调接口
type Coordinator interface {
Barrier(ctx context.Context, name string) error
Broadcast(ctx context.Context, data []byte, root int) ([]byte, error)
Gather(ctx context.Context, data []byte, root int) ([][]byte, error)
AllGather(ctx context.Context, data []byte) ([][]byte, error)
}
// NewDistributedCheckpointManager 创建分布式检查点管理器
func NewDistributedCheckpointManager(
localManager *CheckpointManager,
coordinator Coordinator,
rank, worldSize int,
config DistributedCheckpointConfig,
) *DistributedCheckpointManager {
return &DistributedCheckpointManager{
localManager: localManager,
coordinator: coordinator,
rank: rank,
worldSize: worldSize,
config: config,
}
}
// Save 分布式保存检查点
func (m *DistributedCheckpointManager) Save(
ctx context.Context,
jobID string,
step int64,
epoch int,
state *TrainingState,
localData map[string][]byte, // 本 rank 的数据
) (*DistributedCheckpoint, error) {
// 阶段 1:准备(所有 rank 同步)
if err := m.coordinator.Barrier(ctx, fmt.Sprintf("ckpt-prepare-%d", step)); err != nil {
return nil, fmt.Errorf("barrier at prepare: %w", err)
}
var checkpoint *DistributedCheckpoint
var saveErr error
switch m.config.ShardingStrategy {
case "full":
// 每个 rank 保存完整检查点
checkpoint, saveErr = m.saveFullReplicated(ctx, jobID, step, epoch, state, localData)
case "sharded":
// 每个 rank 保存自己的分片
checkpoint, saveErr = m.saveSharded(ctx, jobID, step, epoch, state, localData)
case "fsdp":
// FSDP 风格:参数分片 + 优化器分片
checkpoint, saveErr = m.saveFSDP(ctx, jobID, step, epoch, state, localData)
}
if saveErr != nil {
return nil, saveErr
}
// 阶段 2:完成(所有 rank 同步)
if err := m.coordinator.Barrier(ctx, fmt.Sprintf("ckpt-complete-%d", step)); err != nil {
return nil, fmt.Errorf("barrier at complete: %w", err)
}
return checkpoint, nil
}
// saveSharded 分片保存
func (m *DistributedCheckpointManager) saveSharded(
ctx context.Context,
jobID string,
step int64,
epoch int,
state *TrainingState,
localData map[string][]byte,
) (*DistributedCheckpoint, error) {
// 每个 rank 保存自己的分片
// 使用 rank 作为文件名前缀
shardedData := make(map[string][]byte)
for name, data := range localData {
shardedName := fmt.Sprintf("rank_%d_%s", m.rank, name)
shardedData[shardedName] = data
}
// 调用本地管理器保存
localCheckpoint, err := m.localManager.Save(ctx, jobID, step, epoch, state, shardedData)
if err != nil {
return nil, fmt.Errorf("save local shard: %w", err)
}
// 收集所有 rank 的分片信息
shardInfo := ShardInfo{
Rank: m.rank,
Files: getFileNames(shardedData),
TotalSize: calculateTotalSize(shardedData),
NodeID: getNodeID(),
}
shardInfoBytes, _ := json.Marshal(shardInfo)
allShardInfoBytes, err := m.coordinator.AllGather(ctx, shardInfoBytes)
if err != nil {
return nil, fmt.Errorf("gather shard info: %w", err)
}
// 解析所有分片信息
var allShards []ShardInfo
for _, b := range allShardInfoBytes {
var si ShardInfo
json.Unmarshal(b, &si)
allShards = append(allShards, si)
}
return &DistributedCheckpoint{
Checkpoint: localCheckpoint,
Shards: allShards,
WorldSize: m.worldSize,
}, nil
}
// saveFSDP FSDP 风格保存
func (m *DistributedCheckpointManager) saveFSDP(
ctx context.Context,
jobID string,
step int64,
epoch int,
state *TrainingState,
localData map[string][]byte,
) (*DistributedCheckpoint, error) {
// FSDP 保存策略:
// 1. 模型参数:每个 rank 保存自己的分片,或聚合后保存
// 2. 优化器状态:每个 rank 保存自己的分片
// 3. 其他状态:只在 rank 0 保存
shardedData := make(map[string][]byte)
for name, data := range localData {
switch {
case isModelParam(name):
if m.config.AggregateOnSave {
// 聚合模型参数到 rank 0
aggregated, err := m.aggregateParams(ctx, data)
if err != nil {
return nil, err
}
if m.rank == 0 {
shardedData[name] = aggregated
}
} else {
// 分片保存
shardedData[fmt.Sprintf("model_shard_%d.pt", m.rank)] = data
}
case isOptimizerState(name):
// 优化器状态总是分片保存
shardedData[fmt.Sprintf("optim_shard_%d.pt", m.rank)] = data
default:
// 其他状态只在 rank 0 保存
if m.rank == 0 {
shardedData[name] = data
}
}
}
return m.saveSharded(ctx, jobID, step, epoch, state, shardedData)
}
// saveFullReplicated 完全复制保存(每个 rank 保存完整)
func (m *DistributedCheckpointManager) saveFullReplicated(
ctx context.Context,
jobID string,
step int64,
epoch int,
state *TrainingState,
localData map[string][]byte,
) (*DistributedCheckpoint, error) {
// 控制并发度:只有部分 rank 同时保存
if m.rank < m.config.ParallelRanks {
localCheckpoint, err := m.localManager.Save(ctx, jobID, step, epoch, state, localData)
if err != nil {
return nil, err
}
return &DistributedCheckpoint{
Checkpoint: localCheckpoint,
Shards: []ShardInfo{{
Rank: m.rank,
TotalSize: localCheckpoint.TotalSize,
}},
WorldSize: m.worldSize,
}, nil
}
	// 其他 rank 不重复保存,由外层 barrier 完成同步后直接返回空结果
return &DistributedCheckpoint{}, nil
}
// Load 分布式加载检查点
func (m *DistributedCheckpointManager) Load(
ctx context.Context,
checkpointID string,
) (map[string][]byte, *TrainingState, error) {
// 阶段 1:准备
if err := m.coordinator.Barrier(ctx, fmt.Sprintf("load-prepare-%s", checkpointID)); err != nil {
return nil, nil, err
}
var data map[string][]byte
var state *TrainingState
var loadErr error
switch m.config.ShardingStrategy {
case "sharded", "fsdp":
data, state, loadErr = m.loadSharded(ctx, checkpointID)
case "full":
data, state, loadErr = m.localManager.Load(ctx, checkpointID)
}
if loadErr != nil {
return nil, nil, loadErr
}
// 阶段 2:完成
if err := m.coordinator.Barrier(ctx, fmt.Sprintf("load-complete-%s", checkpointID)); err != nil {
return nil, nil, err
}
return data, state, nil
}
// loadSharded 加载分片检查点
func (m *DistributedCheckpointManager) loadSharded(
ctx context.Context,
checkpointID string,
) (map[string][]byte, *TrainingState, error) {
// 获取元数据
metadata, err := m.localManager.metaStore.GetCheckpoint(ctx, checkpointID)
if err != nil {
return nil, nil, err
}
// 加载本 rank 的分片
data := make(map[string][]byte)
for _, file := range metadata.Files {
// 检查是否是本 rank 的文件
if isRankFile(file.Name, m.rank) || isSharedFile(file.Name) {
content, err := m.localManager.hotStorage.Read(ctx,
fmt.Sprintf("%s/%s", metadata.Path, file.Name))
if err != nil {
// 尝试从 warm 存储读取
content, err = m.localManager.warmStorage.Read(ctx,
fmt.Sprintf("%s/%s", metadata.Path, file.Name))
if err != nil {
return nil, nil, err
}
}
// 去除 rank 前缀
cleanName := removeRankPrefix(file.Name)
data[cleanName] = content
}
}
return data, &metadata.TrainingState, nil
}
// aggregateParams 聚合参数到 rank 0
func (m *DistributedCheckpointManager) aggregateParams(ctx context.Context, localParams []byte) ([]byte, error) {
// 使用 Gather 收集到 rank 0
allParams, err := m.coordinator.Gather(ctx, localParams, 0)
if err != nil {
return nil, err
}
if m.rank != 0 {
return nil, nil
}
// 在 rank 0 上拼接所有参数
var totalSize int
for _, p := range allParams {
totalSize += len(p)
}
result := make([]byte, 0, totalSize)
for _, p := range allParams {
result = append(result, p...)
}
return result, nil
}
// 辅助函数
func getFileNames(data map[string][]byte) []string {
names := make([]string, 0, len(data))
for name := range data {
names = append(names, name)
}
return names
}
func calculateTotalSize(data map[string][]byte) int64 {
var total int64
for _, d := range data {
total += int64(len(d))
}
return total
}
func getNodeID() string {
// 返回节点标识
return "node-1"
}
func isModelParam(name string) bool {
return name == "model.pt" || name == "pytorch_model.bin"
}
func isOptimizerState(name string) bool {
return name == "optimizer.pt"
}
func isRankFile(name string, rank int) bool {
	return strings.HasPrefix(name, fmt.Sprintf("rank_%d_", rank))
}
func isSharedFile(name string) bool {
	// 共享文件不带 "rank_" 前缀
	return !strings.HasPrefix(name, "rank_")
}
func removeRankPrefix(name string) string {
	// 去除 "rank_N_" 前缀
	if !strings.HasPrefix(name, "rank_") {
		return name
	}
	rest := strings.TrimPrefix(name, "rank_")
	if idx := strings.Index(rest, "_"); idx >= 0 {
		return rest[idx+1:]
	}
	return name
}
模型版本管理
模型注册中心
// model_registry.go
package registry
import (
"context"
"fmt"
"sort"
"time"
)
// ModelRegistry 模型注册中心
type ModelRegistry struct {
store MetadataStore
storage ModelStorage
validator ModelValidator
}
// Model 模型定义
type Model struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Owner string `json:"owner"`
Tags []string `json:"tags"`
Labels map[string]string `json:"labels"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ModelVersion 模型版本
type ModelVersion struct {
ID string `json:"id"`
ModelID string `json:"model_id"`
Version string `json:"version"` // 语义版本号
Stage ModelStage `json:"stage"` // 生命周期阶段
// 来源
Source ModelSource `json:"source"`
// 存储
Artifacts []Artifact `json:"artifacts"`
// 元数据
Framework string `json:"framework"` // pytorch, tensorflow, etc.
Architecture string `json:"architecture"` // transformer, cnn, etc.
Parameters int64 `json:"parameters"` // 参数量
// 训练信息
TrainingJob string `json:"training_job"`
Checkpoint string `json:"checkpoint"`
Metrics map[string]float64 `json:"metrics"`
// 签名(输入输出格式)
Signature ModelSignature `json:"signature"`
// 审计
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
ApprovedBy string `json:"approved_by"`
ApprovedAt *time.Time `json:"approved_at"`
}
// ModelStage 模型阶段
type ModelStage string
const (
StageNone ModelStage = "none"
StageDevelopment ModelStage = "development"
StageStaging ModelStage = "staging"
StageProduction ModelStage = "production"
StageArchived ModelStage = "archived"
)
// ModelSource 模型来源
type ModelSource struct {
Type string `json:"type"` // "training", "import", "conversion"
TrainingRun string `json:"training_run,omitempty"`
ImportPath string `json:"import_path,omitempty"`
SourceModel string `json:"source_model,omitempty"` // 转换来源
}
// Artifact 模型产物
type Artifact struct {
Name string `json:"name"`
Path string `json:"path"`
Size int64 `json:"size"`
Checksum string `json:"checksum"`
Format string `json:"format"` // pytorch, safetensors, onnx, etc.
}
// ModelSignature 模型签名
type ModelSignature struct {
Inputs []TensorSpec `json:"inputs"`
Outputs []TensorSpec `json:"outputs"`
}
// TensorSpec 张量规格
type TensorSpec struct {
Name string `json:"name"`
Dtype string `json:"dtype"`
Shape []int `json:"shape"` // -1 表示动态维度
}
// NewModelRegistry 创建模型注册中心
func NewModelRegistry(store MetadataStore, storage ModelStorage) *ModelRegistry {
return &ModelRegistry{
store: store,
storage: storage,
}
}
// CreateModel 创建模型
func (r *ModelRegistry) CreateModel(ctx context.Context, model *Model) error {
model.ID = generateID()
model.CreatedAt = time.Now()
model.UpdatedAt = time.Now()
return r.store.CreateModel(ctx, model)
}
// RegisterVersion 注册新版本
func (r *ModelRegistry) RegisterVersion(
ctx context.Context,
modelID string,
version *ModelVersion,
) error {
// 验证模型存在
model, err := r.store.GetModel(ctx, modelID)
if err != nil {
return fmt.Errorf("model not found: %w", err)
}
// 验证版本号唯一
existing, _ := r.store.GetVersionByNumber(ctx, modelID, version.Version)
if existing != nil {
return fmt.Errorf("version %s already exists", version.Version)
}
// 设置默认值
version.ID = generateID()
version.ModelID = model.ID
version.Stage = StageNone
version.CreatedAt = time.Now()
// 验证产物
for i, artifact := range version.Artifacts {
// 验证文件存在
exists, err := r.storage.Exists(ctx, artifact.Path)
if err != nil || !exists {
return fmt.Errorf("artifact not found: %s", artifact.Path)
}
// 获取文件信息
info, err := r.storage.Stat(ctx, artifact.Path)
if err != nil {
return err
}
version.Artifacts[i].Size = info.Size
// 计算校验和
if artifact.Checksum == "" {
checksum, err := r.storage.Checksum(ctx, artifact.Path)
if err != nil {
return err
}
version.Artifacts[i].Checksum = checksum
}
}
// 保存版本
return r.store.CreateVersion(ctx, version)
}
// TransitionStage 转换阶段
func (r *ModelRegistry) TransitionStage(
ctx context.Context,
modelID string,
version string,
targetStage ModelStage,
approver string,
) error {
// 获取版本
v, err := r.store.GetVersionByNumber(ctx, modelID, version)
if err != nil {
return err
}
// 验证转换规则
if !isValidTransition(v.Stage, targetStage) {
return fmt.Errorf("invalid transition from %s to %s", v.Stage, targetStage)
}
// 特殊处理:升级到 Production
if targetStage == StageProduction {
// 降级当前 Production 版本
currentProd, _ := r.GetProductionVersion(ctx, modelID)
if currentProd != nil {
currentProd.Stage = StageArchived
r.store.UpdateVersion(ctx, currentProd)
}
// 记录审批信息
now := time.Now()
v.ApprovedBy = approver
v.ApprovedAt = &now
}
v.Stage = targetStage
return r.store.UpdateVersion(ctx, v)
}
// GetProductionVersion 获取生产版本
func (r *ModelRegistry) GetProductionVersion(ctx context.Context, modelID string) (*ModelVersion, error) {
versions, err := r.store.ListVersions(ctx, modelID)
if err != nil {
return nil, err
}
for _, v := range versions {
if v.Stage == StageProduction {
return v, nil
}
}
return nil, fmt.Errorf("no production version found")
}
// GetLatestVersion 获取最新版本
func (r *ModelRegistry) GetLatestVersion(ctx context.Context, modelID string) (*ModelVersion, error) {
versions, err := r.store.ListVersions(ctx, modelID)
if err != nil {
return nil, err
}
if len(versions) == 0 {
return nil, fmt.Errorf("no versions found")
}
// 按创建时间排序
sort.Slice(versions, func(i, j int) bool {
return versions[i].CreatedAt.After(versions[j].CreatedAt)
})
return versions[0], nil
}
// CompareVersions 比较版本
func (r *ModelRegistry) CompareVersions(
ctx context.Context,
modelID string,
version1, version2 string,
) (*VersionComparison, error) {
v1, err := r.store.GetVersionByNumber(ctx, modelID, version1)
if err != nil {
return nil, err
}
v2, err := r.store.GetVersionByNumber(ctx, modelID, version2)
if err != nil {
return nil, err
}
comparison := &VersionComparison{
Version1: v1,
Version2: v2,
MetricsDiff: make(map[string]float64),
}
// 比较指标
for metric, val1 := range v1.Metrics {
if val2, ok := v2.Metrics[metric]; ok {
comparison.MetricsDiff[metric] = val2 - val1
}
}
// 比较参数量
comparison.ParametersDiff = v2.Parameters - v1.Parameters
// 比较文件大小
comparison.SizeDiff = r.calculateTotalSize(v2.Artifacts) - r.calculateTotalSize(v1.Artifacts)
return comparison, nil
}
type VersionComparison struct {
Version1 *ModelVersion
Version2 *ModelVersion
MetricsDiff map[string]float64
ParametersDiff int64
SizeDiff int64
}
func (r *ModelRegistry) calculateTotalSize(artifacts []Artifact) int64 {
var total int64
for _, a := range artifacts {
total += a.Size
}
return total
}
func isValidTransition(from, to ModelStage) bool {
validTransitions := map[ModelStage][]ModelStage{
StageNone: {StageDevelopment, StageArchived},
StageDevelopment: {StageStaging, StageArchived},
StageStaging: {StageProduction, StageDevelopment, StageArchived},
StageProduction: {StageArchived},
StageArchived: {StageDevelopment}, // 可以复活到开发阶段
}
allowed, ok := validTransitions[from]
if !ok {
return false
}
for _, s := range allowed {
if s == to {
return true
}
}
return false
}
func generateID() string {
// 生成唯一 ID
return fmt.Sprintf("%d", time.Now().UnixNano())
}
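下面的示意代码把上述接口串起来:注册一个新版本,然后沿 development → staging → production 逐级推进(版本号、指标与审批人均为假设值):
// registry_usage_example.go(示意代码)
package registry

import "context"

func RegisterAndPromote(ctx context.Context, r *ModelRegistry, modelID string) error {
	version := &ModelVersion{
		Version:    "1.2.0",
		Framework:  "pytorch",
		Parameters: 7_000_000_000,
		Artifacts: []Artifact{
			{Name: "model.safetensors", Path: "models/llm-demo/1.2.0/model.safetensors", Format: "safetensors"},
		},
		Metrics: map[string]float64{"eval_loss": 1.83},
	}
	if err := r.RegisterVersion(ctx, modelID, version); err != nil {
		return err
	}
	// 逐级推进,跳级转换会被 isValidTransition 拒绝
	if err := r.TransitionStage(ctx, modelID, "1.2.0", StageDevelopment, ""); err != nil {
		return err
	}
	if err := r.TransitionStage(ctx, modelID, "1.2.0", StageStaging, ""); err != nil {
		return err
	}
	return r.TransitionStage(ctx, modelID, "1.2.0", StageProduction, "alice")
}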
模型血缘追踪
// model_lineage.go
package registry
import (
	"context"
	"fmt"
	"time"
)
// LineageTracker 血缘追踪器
type LineageTracker struct {
store LineageStore
}
// LineageNode 血缘节点
type LineageNode struct {
ID string `json:"id"`
Type NodeType `json:"type"`
Name string `json:"name"`
Version string `json:"version,omitempty"`
Metadata interface{} `json:"metadata,omitempty"`
}
// NodeType 节点类型
type NodeType string
const (
NodeTypeDataset NodeType = "dataset"
NodeTypeModel NodeType = "model"
NodeTypeExperiment NodeType = "experiment"
NodeTypeCheckpoint NodeType = "checkpoint"
NodeTypeArtifact NodeType = "artifact"
)
// LineageEdge 血缘边
type LineageEdge struct {
ID string `json:"id"`
Source string `json:"source"` // 源节点 ID
Target string `json:"target"` // 目标节点 ID
Relation string `json:"relation"` // 关系类型
Metadata interface{} `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// RelationType 关系类型
const (
RelationTrainedOn = "trained_on" // 模型在数据集上训练
RelationDerivedFrom = "derived_from" // 派生自
RelationProducedBy = "produced_by" // 由实验产生
RelationConvertedTo = "converted_to" // 转换为
RelationFinetunedFrom = "finetuned_from" // 微调自
)
// NewLineageTracker 创建血缘追踪器
func NewLineageTracker(store LineageStore) *LineageTracker {
return &LineageTracker{store: store}
}
// RecordTraining 记录训练血缘
func (t *LineageTracker) RecordTraining(
ctx context.Context,
modelVersion *ModelVersion,
datasets []string,
experiment string,
baseModel string, // 可选,微调场景
) error {
// 创建模型节点
modelNode := &LineageNode{
ID: fmt.Sprintf("model:%s:%s", modelVersion.ModelID, modelVersion.Version),
Type: NodeTypeModel,
Name: modelVersion.ModelID,
Version: modelVersion.Version,
Metadata: map[string]interface{}{
"parameters": modelVersion.Parameters,
"framework": modelVersion.Framework,
},
}
if err := t.store.CreateNode(ctx, modelNode); err != nil {
return err
}
// 创建与数据集的关系
for _, dataset := range datasets {
edge := &LineageEdge{
ID: generateID(),
Source: fmt.Sprintf("dataset:%s", dataset),
Target: modelNode.ID,
Relation: RelationTrainedOn,
CreatedAt: time.Now(),
}
if err := t.store.CreateEdge(ctx, edge); err != nil {
return err
}
}
// 创建与实验的关系
if experiment != "" {
edge := &LineageEdge{
ID: generateID(),
Source: fmt.Sprintf("experiment:%s", experiment),
Target: modelNode.ID,
Relation: RelationProducedBy,
CreatedAt: time.Now(),
}
if err := t.store.CreateEdge(ctx, edge); err != nil {
return err
}
}
// 微调场景:创建与基础模型的关系
if baseModel != "" {
edge := &LineageEdge{
ID: generateID(),
Source: fmt.Sprintf("model:%s", baseModel),
Target: modelNode.ID,
Relation: RelationFinetunedFrom,
CreatedAt: time.Now(),
}
if err := t.store.CreateEdge(ctx, edge); err != nil {
return err
}
}
return nil
}
// GetUpstream 获取上游血缘
func (t *LineageTracker) GetUpstream(
ctx context.Context,
nodeID string,
depth int,
) (*LineageGraph, error) {
return t.traverse(ctx, nodeID, depth, true)
}
// GetDownstream 获取下游血缘
func (t *LineageTracker) GetDownstream(
ctx context.Context,
nodeID string,
depth int,
) (*LineageGraph, error) {
return t.traverse(ctx, nodeID, depth, false)
}
// LineageGraph 血缘图
type LineageGraph struct {
Nodes []*LineageNode `json:"nodes"`
Edges []*LineageEdge `json:"edges"`
}
func (t *LineageTracker) traverse(
ctx context.Context,
startNode string,
maxDepth int,
upstream bool,
) (*LineageGraph, error) {
visited := make(map[string]bool)
graph := &LineageGraph{
Nodes: []*LineageNode{},
Edges: []*LineageEdge{},
}
var dfs func(nodeID string, depth int) error
dfs = func(nodeID string, depth int) error {
if depth > maxDepth || visited[nodeID] {
return nil
}
visited[nodeID] = true
// 获取节点
node, err := t.store.GetNode(ctx, nodeID)
if err != nil {
return err
}
graph.Nodes = append(graph.Nodes, node)
// 获取边
var edges []*LineageEdge
if upstream {
edges, err = t.store.GetIncomingEdges(ctx, nodeID)
} else {
edges, err = t.store.GetOutgoingEdges(ctx, nodeID)
}
if err != nil {
return err
}
for _, edge := range edges {
graph.Edges = append(graph.Edges, edge)
var nextNode string
if upstream {
nextNode = edge.Source
} else {
nextNode = edge.Target
}
if err := dfs(nextNode, depth+1); err != nil {
return err
}
}
return nil
}
if err := dfs(startNode, 0); err != nil {
return nil, err
}
return graph, nil
}
// FindImpact 查找影响范围
func (t *LineageTracker) FindImpact(
ctx context.Context,
nodeID string,
) (*ImpactAnalysis, error) {
downstream, err := t.GetDownstream(ctx, nodeID, 10)
if err != nil {
return nil, err
}
analysis := &ImpactAnalysis{
AffectedModels: []*LineageNode{},
AffectedExperiments: []*LineageNode{},
AffectedArtifacts: []*LineageNode{},
}
for _, node := range downstream.Nodes {
switch node.Type {
case NodeTypeModel:
analysis.AffectedModels = append(analysis.AffectedModels, node)
case NodeTypeExperiment:
analysis.AffectedExperiments = append(analysis.AffectedExperiments, node)
case NodeTypeArtifact:
analysis.AffectedArtifacts = append(analysis.AffectedArtifacts, node)
}
}
return analysis, nil
}
type ImpactAnalysis struct {
AffectedModels []*LineageNode
AffectedExperiments []*LineageNode
AffectedArtifacts []*LineageNode
}
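血缘追踪的典型用法是:训练产出新版本时记录一次,数据集或基础模型变更时再做影响分析。下面是一个简单示意(数据集、实验与基础模型名称均为假设值):
// lineage_usage_example.go(示意代码)
package registry

import (
	"context"
	"fmt"
)

func RecordAndAnalyze(ctx context.Context, tracker *LineageTracker, version *ModelVersion) error {
	// 记录:模型在哪些数据集上训练、由哪个实验产出、从哪个基础模型微调而来
	if err := tracker.RecordTraining(ctx, version,
		[]string{"common-crawl-zh", "wiki-zh"}, "exp-20240601", "llama-7b"); err != nil {
		return err
	}
	// 分析:某个数据集变更时,下游受影响的模型有哪些
	impact, err := tracker.FindImpact(ctx, "dataset:common-crawl-zh")
	if err != nil {
		return err
	}
	fmt.Printf("受影响模型数: %d\n", len(impact.AffectedModels))
	return nil
}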
模型格式转换
格式转换器
// model_converter.go
package converter
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
)
// ModelConverter 模型格式转换器
type ModelConverter interface {
Convert(ctx context.Context, input, output string, opts ConvertOptions) error
SupportedFormats() []string
}
// ConvertOptions 转换选项
type ConvertOptions struct {
SourceFormat string
TargetFormat string
Precision string // fp32, fp16, int8
OptLevel int // 优化级别
DeviceType string // cpu, cuda
}
// PyTorchConverter PyTorch 转换器
type PyTorchConverter struct {
pythonPath string
}
func NewPyTorchConverter(pythonPath string) *PyTorchConverter {
return &PyTorchConverter{pythonPath: pythonPath}
}
// Convert 执行转换
func (c *PyTorchConverter) Convert(ctx context.Context, input, output string, opts ConvertOptions) error {
switch opts.TargetFormat {
case "safetensors":
return c.toSafetensors(ctx, input, output)
case "onnx":
return c.toONNX(ctx, input, output, opts)
case "torchscript":
return c.toTorchScript(ctx, input, output)
default:
return fmt.Errorf("unsupported target format: %s", opts.TargetFormat)
}
}
// toSafetensors 转换为 safetensors 格式
func (c *PyTorchConverter) toSafetensors(ctx context.Context, input, output string) error {
script := `
import torch
from safetensors.torch import save_file
import sys
input_path = sys.argv[1]
output_path = sys.argv[2]
# 加载 PyTorch 模型
state_dict = torch.load(input_path, map_location='cpu')
# 如果是完整模型,提取 state_dict
if hasattr(state_dict, 'state_dict'):
state_dict = state_dict.state_dict()
elif 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
elif 'model' in state_dict:
state_dict = state_dict['model']
# 转换为 safetensors
save_file(state_dict, output_path)
print(f"Converted to {output_path}")
`
return c.runPythonScript(ctx, script, input, output)
}
// toONNX 转换为 ONNX 格式
func (c *PyTorchConverter) toONNX(ctx context.Context, input, output string, opts ConvertOptions) error {
script := `
import torch
import sys
input_path = sys.argv[1]
output_path = sys.argv[2]
# 加载模型
model = torch.load(input_path, map_location='cpu')
model.eval()
# 创建示例输入(需要根据模型调整)
# 这里假设是语言模型
batch_size = 1
seq_length = 128
dummy_input = torch.randint(0, 1000, (batch_size, seq_length))
# 导出 ONNX
torch.onnx.export(
model,
dummy_input,
output_path,
input_names=['input_ids'],
output_names=['logits'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence'},
'logits': {0: 'batch_size', 1: 'sequence'}
},
opset_version=14,
do_constant_folding=True,
)
print(f"Exported to {output_path}")
`
return c.runPythonScript(ctx, script, input, output)
}
// toTorchScript 转换为 TorchScript
func (c *PyTorchConverter) toTorchScript(ctx context.Context, input, output string) error {
script := `
import torch
import sys
input_path = sys.argv[1]
output_path = sys.argv[2]
# 加载模型
model = torch.load(input_path, map_location='cpu')
model.eval()
# 尝试 trace
try:
# 创建示例输入
dummy_input = torch.randint(0, 1000, (1, 128))
traced = torch.jit.trace(model, dummy_input)
traced.save(output_path)
except Exception as e:
# 如果 trace 失败,尝试 script
print(f"Trace failed: {e}, trying script...")
scripted = torch.jit.script(model)
scripted.save(output_path)
print(f"Converted to {output_path}")
`
return c.runPythonScript(ctx, script, input, output)
}
func (c *PyTorchConverter) runPythonScript(ctx context.Context, script string, args ...string) error {
// 创建临时脚本文件
tmpFile, err := os.CreateTemp("", "convert_*.py")
if err != nil {
return err
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.WriteString(script); err != nil {
return err
}
tmpFile.Close()
// 执行脚本
cmdArgs := append([]string{tmpFile.Name()}, args...)
cmd := exec.CommandContext(ctx, c.pythonPath, cmdArgs...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *PyTorchConverter) SupportedFormats() []string {
return []string{"pytorch", "safetensors", "onnx", "torchscript"}
}
// ConversionPipeline 转换流水线
type ConversionPipeline struct {
converters map[string]ModelConverter
storage ModelStorage
}
// NewConversionPipeline 创建转换流水线
func NewConversionPipeline(storage ModelStorage) *ConversionPipeline {
return &ConversionPipeline{
converters: make(map[string]ModelConverter),
storage: storage,
}
}
// RegisterConverter 注册转换器
func (p *ConversionPipeline) RegisterConverter(name string, converter ModelConverter) {
p.converters[name] = converter
}
// ConvertModel 转换模型
func (p *ConversionPipeline) ConvertModel(
ctx context.Context,
modelID, version string,
targetFormats []string,
) ([]Artifact, error) {
// 获取源模型
sourceArtifact, err := p.storage.GetArtifact(ctx, modelID, version)
if err != nil {
return nil, fmt.Errorf("get source artifact: %w", err)
}
// 下载到临时目录
tmpDir, err := os.MkdirTemp("", "model_convert_*")
if err != nil {
return nil, err
}
defer os.RemoveAll(tmpDir)
inputPath := filepath.Join(tmpDir, "input", sourceArtifact.Name)
if err := p.storage.Download(ctx, sourceArtifact.Path, inputPath); err != nil {
return nil, err
}
var artifacts []Artifact
// 执行转换
for _, format := range targetFormats {
converter := p.findConverter(sourceArtifact.Format, format)
if converter == nil {
return nil, fmt.Errorf("no converter for %s -> %s", sourceArtifact.Format, format)
}
outputPath := filepath.Join(tmpDir, "output", fmt.Sprintf("model.%s", format))
os.MkdirAll(filepath.Dir(outputPath), 0755)
opts := ConvertOptions{
SourceFormat: sourceArtifact.Format,
TargetFormat: format,
}
if err := converter.Convert(ctx, inputPath, outputPath, opts); err != nil {
return nil, fmt.Errorf("convert to %s: %w", format, err)
}
// 上传转换结果
storagePath := fmt.Sprintf("%s/%s/model.%s", modelID, version, format)
if err := p.storage.Upload(ctx, outputPath, storagePath); err != nil {
return nil, err
}
// 获取文件信息
	info, err := os.Stat(outputPath)
	if err != nil {
		return nil, err
	}
artifacts = append(artifacts, Artifact{
Name: fmt.Sprintf("model.%s", format),
Path: storagePath,
Size: info.Size(),
Format: format,
})
}
return artifacts, nil
}
func (p *ConversionPipeline) findConverter(source, target string) ModelConverter {
for _, converter := range p.converters {
formats := converter.SupportedFormats()
hasSource, hasTarget := false, false
for _, f := range formats {
if f == source {
hasSource = true
}
if f == target {
hasTarget = true
}
}
if hasSource && hasTarget {
return converter
}
}
return nil
}
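转换流水线的典型接法是:注册可用的转换器,然后针对某个已注册版本批量产出目标格式。下面是一个使用示意(python 解释器路径与模型标识均为假设值):
// converter_usage_example.go(示意代码)
package converter

import "context"

func ConvertExample(ctx context.Context, storage ModelStorage) ([]Artifact, error) {
	pipeline := NewConversionPipeline(storage)
	pipeline.RegisterConverter("pytorch", NewPyTorchConverter("/usr/bin/python3"))
	// 将指定版本的模型转换为 safetensors 与 onnx 两种格式并返回产物信息
	return pipeline.ConvertModel(ctx, "llm-demo", "1.2.0", []string{"safetensors", "onnx"})
}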
Kubernetes 集成
PVC 动态供给
# model-storage-class.yaml
apiVersion: storage.k8s.io/v1
kind: StorageClass
metadata:
name: model-storage
provisioner: csi.juicefs.com
parameters:
# JuiceFS 配置(适合模型存储)
csi.storage.k8s.io/provisioner-secret-name: juicefs-secret
csi.storage.k8s.io/provisioner-secret-namespace: kube-system
csi.storage.k8s.io/node-publish-secret-name: juicefs-secret
csi.storage.k8s.io/node-publish-secret-namespace: kube-system
# 挂载选项
mountOptions: "cache-size=10240,writeback"
reclaimPolicy: Retain
volumeBindingMode: Immediate
allowVolumeExpansion: true
---
# checkpoint-storage-class.yaml
apiVersion: storage.k8s.io/v1
kind: StorageClass
metadata:
name: checkpoint-storage
provisioner: lustre.csi.hpe.com
parameters:
# Lustre 配置(适合检查点)
mgsSpec: "192.168.1.10@tcp:/lustre"
stripeSize: "4M"
stripeCount: "4"
reclaimPolicy: Delete
volumeBindingMode: WaitForFirstConsumer
---
# 模型存储 PVC 模板
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: model-storage-{{ .JobName }}
spec:
storageClassName: model-storage
accessModes:
- ReadWriteMany
resources:
requests:
storage: 500Gi
---
# 检查点 PVC 模板
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: checkpoint-storage-{{ .JobName }}
spec:
storageClassName: checkpoint-storage
accessModes:
- ReadWriteMany
resources:
requests:
storage: 1Ti
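上面两个 PVC 模板中的 {{ .JobName }} 占位符需要在提交训练任务时渲染为实际任务名。下面是一个基于 Go 标准库 text/template 的渲染示意(RenderPVC 为假设的辅助函数,只演示渲染过程,创建 PVC 仍需通过 kubectl 或 client-go 完成):
// pvc_template_render.go(示意代码)
package k8s

import (
	"bytes"
	"text/template"
)

// RenderPVC 将 PVC 模板中的 {{ .JobName }} 渲染为实际任务名,返回可提交的 YAML 文本
func RenderPVC(pvcTemplate, jobName string) (string, error) {
	tmpl, err := template.New("pvc").Parse(pvcTemplate)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, struct{ JobName string }{JobName: jobName}); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// 使用示例:
//   manifest, err := RenderPVC(checkpointPVCTemplate, "llm-training")
//   // 渲染结果即为名为 checkpoint-storage-llm-training 的 PVC 清单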
训练任务存储挂载
# training-job-with-storage.yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
name: llm-training
spec:
pytorchReplicaSpecs:
Master:
replicas: 1
template:
spec:
containers:
- name: pytorch
image: training:latest
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
# 数据集(只读)
- name: dataset
mountPath: /data
readOnly: true
# 检查点(读写,高性能)
- name: checkpoints
mountPath: /checkpoints
# 模型输出(读写)
- name: models
mountPath: /models
# 本地缓存(NVMe)
- name: local-cache
mountPath: /cache
env:
- name: CHECKPOINT_DIR
value: /checkpoints
- name: MODEL_OUTPUT_DIR
value: /models
volumes:
# 数据集 - 使用共享存储
- name: dataset
persistentVolumeClaim:
claimName: dataset-imagenet
# 检查点 - 使用高性能并行文件系统
- name: checkpoints
persistentVolumeClaim:
claimName: checkpoint-storage-llm-training
# 模型输出 - 使用对象存储
- name: models
persistentVolumeClaim:
claimName: model-storage-llm-training
# 本地缓存 - 使用本地 NVMe
- name: local-cache
emptyDir:
medium: Memory # 或使用 hostPath 挂载 NVMe
sizeLimit: 100Gi
Worker:
replicas: 7
template:
# 类似配置...
模型服务挂载
# model-serving-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: model-server
spec:
replicas: 3
selector:
matchLabels:
app: model-server
template:
metadata:
labels:
app: model-server
spec:
# 使用 Init Container 预加载模型
initContainers:
- name: model-loader
image: model-loader:latest
command:
- /bin/sh
- -c
- |
# 从对象存储下载模型到本地缓存
aws s3 cp s3://models/llm-v1/ /cache/model/ --recursive
# 验证完整性
md5sum -c /cache/model/checksums.md5
volumeMounts:
- name: model-cache
mountPath: /cache
env:
- name: AWS_ACCESS_KEY_ID
valueFrom:
secretKeyRef:
name: s3-credentials
key: access-key
- name: AWS_SECRET_ACCESS_KEY
valueFrom:
secretKeyRef:
name: s3-credentials
key: secret-key
containers:
- name: server
image: model-server:latest
resources:
limits:
nvidia.com/gpu: 1
volumeMounts:
- name: model-cache
mountPath: /models
readOnly: true
env:
- name: MODEL_PATH
value: /models/model
volumes:
- name: model-cache
emptyDir:
sizeLimit: 50Gi
# 亲和性:优先调度到有缓存的节点
affinity:
nodeAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 100
preference:
matchExpressions:
- key: model-cache/llm-v1
operator: Exists
监控与告警
Prometheus 指标
// storage_metrics.go
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
// 检查点指标
CheckpointSaveTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "checkpoint_save_total",
Help: "Total number of checkpoint saves",
},
[]string{"job_id", "status"},
)
CheckpointSaveLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "checkpoint_save_latency_seconds",
Help: "Checkpoint save latency in seconds",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600},
},
[]string{"job_id", "storage_tier"},
)
CheckpointSizeBytes = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "checkpoint_size_bytes",
Help: "Checkpoint size in bytes",
Buckets: prometheus.ExponentialBuckets(1<<20, 2, 26), // 1 MiB 到 32 TiB,覆盖大模型检查点
},
[]string{"job_id"},
)
CheckpointLoadTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "checkpoint_load_total",
Help: "Total number of checkpoint loads",
},
[]string{"job_id", "status"},
)
CheckpointLoadLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "checkpoint_load_latency_seconds",
Help: "Checkpoint load latency in seconds",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300},
},
[]string{"job_id", "storage_tier"},
)
// 存储指标
StorageUsageBytes = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "model_storage_usage_bytes",
Help: "Model storage usage in bytes",
},
[]string{"storage_tier", "job_id"},
)
StorageBandwidthBytes = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "model_storage_bandwidth_bytes_total",
Help: "Model storage bandwidth in bytes",
},
[]string{"storage_tier", "operation"}, // operation: read, write
)
StorageOperationErrors = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "model_storage_errors_total",
Help: "Total storage operation errors",
},
[]string{"storage_tier", "operation", "error_type"},
)
// 模型注册指标
ModelVersionsTotal = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "model_versions_total",
Help: "Total number of model versions",
},
[]string{"model_id", "stage"},
)
ModelConversionLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "model_conversion_latency_seconds",
Help: "Model format conversion latency",
Buckets: []float64{10, 30, 60, 120, 300, 600, 1800},
},
[]string{"source_format", "target_format"},
)
// 数据迁移指标
StorageMigrationTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "storage_migration_total",
Help: "Total storage migrations",
},
[]string{"source_tier", "target_tier", "status"},
)
StorageMigrationBytes = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "storage_migration_bytes_total",
Help: "Total bytes migrated between storage tiers",
},
[]string{"source_tier", "target_tier"},
)
)
// RecordCheckpointSave 记录检查点保存
func RecordCheckpointSave(jobID string, tier string, size int64, latency float64, success bool) {
status := "success"
if !success {
status = "failure"
}
CheckpointSaveTotal.WithLabelValues(jobID, status).Inc()
CheckpointSaveLatency.WithLabelValues(jobID, tier).Observe(latency)
CheckpointSizeBytes.WithLabelValues(jobID).Observe(float64(size))
StorageBandwidthBytes.WithLabelValues(tier, "write").Add(float64(size))
}
// RecordCheckpointLoad 记录检查点加载
func RecordCheckpointLoad(jobID string, tier string, size int64, latency float64, success bool) {
status := "success"
if !success {
status = "failure"
}
CheckpointLoadTotal.WithLabelValues(jobID, status).Inc()
CheckpointLoadLatency.WithLabelValues(jobID, tier).Observe(latency)
StorageBandwidthBytes.WithLabelValues(tier, "read").Add(float64(size))
}
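在训练侧的检查点保存路径中,可以按如下方式接入上述指标(示意代码:saveCheckpoint 为假设的实际保存函数,返回写入的字节数):
// checkpoint_metrics_usage.go(示意代码,依赖 context、fmt、time 标准库与上面的 metrics 包)
func saveCheckpointWithMetrics(ctx context.Context, jobID, tier, dir string) error {
	start := time.Now()

	// saveCheckpoint 为假设的保存函数:序列化并写出检查点,返回写入的字节数
	size, err := saveCheckpoint(ctx, dir)

	latency := time.Since(start).Seconds()
	// 无论成功与否都记录一次保存事件,便于统计失败率与延迟分布
	metrics.RecordCheckpointSave(jobID, tier, size, latency, err == nil)
	if err != nil {
		return fmt.Errorf("save checkpoint: %w", err)
	}
	return nil
}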
Grafana Dashboard
{
"dashboard": {
"title": "Model Storage Dashboard",
"panels": [
{
"title": "Checkpoint Save Performance",
"type": "graph",
"targets": [
{
"expr": "histogram_quantile(0.99, rate(checkpoint_save_latency_seconds_bucket[5m]))",
"legendFormat": "p99 - {{storage_tier}}"
},
{
"expr": "histogram_quantile(0.50, rate(checkpoint_save_latency_seconds_bucket[5m]))",
"legendFormat": "p50 - {{storage_tier}}"
}
]
},
{
"title": "Storage Bandwidth",
"type": "graph",
"targets": [
{
"expr": "rate(model_storage_bandwidth_bytes_total[5m])",
"legendFormat": "{{storage_tier}} - {{operation}}"
}
],
"yaxes": [{"format": "bytes"}]
},
{
"title": "Storage Usage by Tier",
"type": "piechart",
"targets": [
{
"expr": "sum(model_storage_usage_bytes) by (storage_tier)",
"legendFormat": "{{storage_tier}}"
}
]
},
{
"title": "Checkpoint Operations",
"type": "stat",
"targets": [
{
"expr": "sum(rate(checkpoint_save_total{status=\"success\"}[1h]))",
"legendFormat": "Saves/hour"
},
{
"expr": "sum(rate(checkpoint_load_total{status=\"success\"}[1h]))",
"legendFormat": "Loads/hour"
}
]
},
{
"title": "Storage Errors",
"type": "table",
"targets": [
{
"expr": "sum(increase(model_storage_errors_total[24h])) by (storage_tier, operation, error_type)",
"format": "table"
}
]
},
{
"title": "Model Versions by Stage",
"type": "bargauge",
"targets": [
{
"expr": "sum(model_versions_total) by (stage)",
"legendFormat": "{{stage}}"
}
]
}
]
}
}
最佳实践
存储架构设计原则
┌─────────────────────────────────────────────────────────────────┐
│ 存储架构最佳实践 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 分层存储策略 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • Hot 层:最新 2-3 个检查点,使用本地 NVMe 或 RDMA │ │
│ │ • Warm 层:近期检查点,使用并行文件系统(Lustre/GPFS) │ │
│ │ • Cold 层:里程碑和归档,使用对象存储(S3/OSS) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 2. 检查点优化 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 异步写入:不阻塞训练进程 │ │
│ │ • 增量检查点:只保存变化的部分 │ │
│ │ • 压缩存储:使用 LZ4/ZSTD 减少存储和带宽 │ │
│ │ • 分片保存:大模型并行写入多个文件 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 3. 数据完整性 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 写入时计算校验和 │ │
│ │ • 读取时验证完整性 │ │
│ │ • 三副本存储关键数据 │ │
│ │ • 定期数据巡检 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 4. 性能优化 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 条带化配置:根据文件大小调整条带数和大小 │ │
│ │ • 并行 I/O:多线程/进程并发读写 │ │
│ │ • 预取优化:训练时预加载下一批数据 │ │
│ │ • 缓存策略:合理配置各级缓存大小 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 5. 容量规划 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 预估公式:存储 = 模型大小 × 2 × 保留数量 × 安全系数 │ │
│ │ • 监控阈值:80% 告警,90% 自动清理 │ │
│ │ • 弹性扩展:支持在线扩容 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
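以上面第 2 点中的"异步写入"为例,下面给出一个最小示意:训练进程只把序列化好的检查点放入内存队列即可返回,由后台 goroutine 负责落盘(仅演示思路,真实实现还需要处理分片写入、校验和与失败重试):
// async_checkpoint_writer.go(示意代码)
package checkpoint

import (
	"log"
	"os"
)

type snapshot struct {
	path string
	data []byte // 实际系统中通常是序列化后的分片,这里简化为一个字节切片
}

// AsyncWriter 在后台 goroutine 中落盘检查点,训练循环只负责入队,不被 I/O 阻塞
type AsyncWriter struct {
	queue chan snapshot
	done  chan struct{}
}

func NewAsyncWriter(queueSize int) *AsyncWriter {
	w := &AsyncWriter{
		queue: make(chan snapshot, queueSize),
		done:  make(chan struct{}),
	}
	go w.loop()
	return w
}

// Submit 提交一次检查点写入后立即返回;队列已满时会阻塞,可按需改为丢弃最旧快照
func (w *AsyncWriter) Submit(path string, data []byte) {
	w.queue <- snapshot{path: path, data: data}
}

// Close 关闭队列并等待剩余检查点全部写完
func (w *AsyncWriter) Close() {
	close(w.queue)
	<-w.done
}

func (w *AsyncWriter) loop() {
	defer close(w.done)
	for s := range w.queue {
		if err := os.WriteFile(s.path, s.data, 0644); err != nil {
			log.Printf("async checkpoint write failed: %s: %v", s.path, err)
		}
	}
}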
小结
本章详细介绍了 AI 训练平台的模型存储与管理:
- 模型文件特征:分析了检查点、模型文件的大小特征和访问模式
- 分布式存储架构:讲解了并行文件系统和对象存储的集成方案
- 检查点管理:实现了完整的检查点生命周期管理和分布式检查点
- 模型版本管理:构建了模型注册中心和血缘追踪系统
- 格式转换:实现了多格式转换流水线
- Kubernetes 集成:展示了存储类配置和训练任务挂载
下一章我们将探讨实验管理,讲解如何系统化地管理机器学习实验、追踪超参数和对比分析结果。