HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 完整学习路径

    • AI教程 - 从零到一的完整学习路径
    • 第00章:AI基础与发展史
    • 第01章:Python与AI开发环境
    • 第02章:数学基础-线性代数与微积分
    • 03-数据集详解-从获取到预处理
    • 04-从零训练第一个模型
    • 05-模型文件详解
    • 06-分布式训练-多GPU与多机
    • 07-模型调度与资源管理
    • 08-Transformer架构深度解析
    • 09-大语言模型原理与架构
    • 10-Token与Tokenization详解
    • 11-Prompt Engineering完全指南
    • 第12章:模型微调与LoRA技术
    • 第13章:RLHF与对齐技术
    • 第14章 AI编程助手原理与实现
    • 15-RAG系统设计与实现
    • 16-Agent智能体与工具调用
    • 17-多模态大模型
    • 第18章:AI前沿技术趋势
    • 第19章 AI热门话题与应用案例

05-模型文件详解

1. 模型文件是什么

深度学习模型文件保存了训练好的模型的各种信息,使得我们可以在不重新训练的情况下加载和使用模型。

1.1 模型文件包含的内容

1. 权重参数(Weights/Parameters)

  • 神经网络各层的权重矩阵
  • 偏置项(bias)
  • 例如:卷积层的卷积核、全连接层的权重矩阵

2. 模型架构(Architecture)

  • 网络的结构定义
  • 层的配置和连接方式
  • 有些格式包含,有些不包含

3. 优化器状态(Optimizer State)

  • 优化器的当前状态
  • 例如Adam的动量信息
  • 用于断点续训

4. 训练配置(Training Configuration)

  • 当前的epoch
  • 学习率
  • 其他训练相关的超参数

1.2 为什么需要保存模型

# 训练一个模型可能需要几天甚至几周
for epoch in range(1000):
    train(...)  # 非常耗时

# 如果不保存模型:
# 1. 程序崩溃 -> 从头开始训练
# 2. 想要部署 -> 需要重新训练
# 3. 想要分享 -> 需要让别人也训练

# 保存模型后:
torch.save(model.state_dict(), 'model.pth')
# 1. 可以随时加载继续训练
# 2. 可以直接用于推理
# 3. 可以分享给他人使用

1.3 常见模型文件格式

格式框架特点
.pth / .ptPyTorch二进制格式,使用pickle序列化
.ckptPyTorch Lightningcheckpoint格式,包含更多信息
.safetensors通用更安全的格式,避免pickle的安全问题
.h5TensorFlow/KerasHDF5格式
SavedModelTensorFlow目录格式,包含完整模型
.pbTensorFlowProtocol Buffer格式
.onnxONNX跨框架格式
.tfliteTensorFlow Lite移动端部署格式

2. PyTorch模型文件

2.1 .pt/.pth格式

PyTorch使用Python的pickle序列化模型数据。

保存和加载权重

import torch
import torch.nn as nn

# 定义一个简单模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建模型
model = SimpleNet()

# ========== 保存模型权重 ==========
# 方式1:只保存权重(推荐)
torch.save(model.state_dict(), 'model_weights.pth')

# 方式2:保存整个模型(不推荐,兼容性差)
torch.save(model, 'model_complete.pth')

# ========== 加载模型权重 ==========
# 加载方式1
model = SimpleNet()  # 需要先定义模型结构
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 设置为评估模式

# 加载方式2
model = torch.load('model_complete.pth')
model.eval()

为什么推荐保存state_dict?

# state_dict只保存参数,不保存模型结构
# 优点:
# 1. 文件更小
# 2. 更灵活,可以修改模型代码
# 3. 兼容性更好
# 4. 避免pickle的安全问题

# 完整模型包含Python代码
# 缺点:
# 1. 文件更大
# 2. 依赖原始代码
# 3. Python版本变化可能导致无法加载
# 4. 安全风险(pickle可以执行任意代码)

2.2 state_dict的内容

import torch
import torch.nn as nn

# 创建模型
model = SimpleNet()

# 查看state_dict
state_dict = model.state_dict()
print("State Dict Keys:")
for key in state_dict.keys():
    print(f"  {key}: {state_dict[key].shape}")

# 输出:
# State Dict Keys:
#   fc1.weight: torch.Size([20, 10])
#   fc1.bias: torch.Size([20])
#   fc2.weight: torch.Size([5, 20])
#   fc2.bias: torch.Size([5])

# 查看具体的权重
print("\nfc1.weight:")
print(state_dict['fc1.weight'])

print("\nfc1.bias:")
print(state_dict['fc1.bias'])

# state_dict是一个OrderedDict
print(f"\nType: {type(state_dict)}")
# Type: <class 'collections.OrderedDict'>

更复杂的模型

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 8 * 8, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = CNN()
state_dict = model.state_dict()

print("CNN State Dict:")
for key, value in state_dict.items():
    print(f"{key:30s} {str(value.shape):20s} {value.dtype}")

# 输出:
# conv1.weight                   torch.Size([64, 3, 3, 3])    torch.float32
# conv1.bias                     torch.Size([64])             torch.float32
# bn1.weight                     torch.Size([64])             torch.float32
# bn1.bias                       torch.Size([64])             torch.float32
# bn1.running_mean               torch.Size([64])             torch.float32
# bn1.running_var                torch.Size([64])             torch.float32
# bn1.num_batches_tracked        torch.Size([])               torch.int64
# conv2.weight                   torch.Size([128, 64, 3, 3])  torch.float32
# conv2.bias                     torch.Size([128])            torch.float32
# bn2.weight                     torch.Size([128])            torch.float32
# bn2.bias                       torch.Size([128])            torch.float32
# bn2.running_mean               torch.Size([128])            torch.float32
# bn2.running_var                torch.Size([128])            torch.float32
# bn2.num_batches_tracked        torch.Size([])               torch.int64
# fc.weight                      torch.Size([10, 8192])       torch.float32
# fc.bias                        torch.Size([10])             torch.float32

2.3 完整保存和加载代码

保存完整的训练状态

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 训练一些epoch
epoch = 50
train_loss = 0.123
val_loss = 0.456
best_acc = 95.67

# ========== 保存完整checkpoint ==========
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_loss': train_loss,
    'val_loss': val_loss,
    'best_acc': best_acc,
}

torch.save(checkpoint, 'checkpoint.pth')
print("Checkpoint saved!")

# ========== 加载完整checkpoint ==========
checkpoint = torch.load('checkpoint.pth')

model = CNN()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']
val_loss = checkpoint['val_loss']
best_acc = checkpoint['best_acc']

print(f"Checkpoint loaded! Epoch: {epoch}, Best Acc: {best_acc}%")

# 继续训练
model.train()
# ... 继续训练代码 ...

保存多个版本的模型

def save_checkpoint(model, optimizer, epoch, val_acc, filename):
    """保存checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
    }
    torch.save(checkpoint, filename)
    print(f"保存checkpoint: {filename}")

# 在训练循环中
best_acc = 0.0
for epoch in range(num_epochs):
    # 训练...
    val_acc = validate(...)

    # 保存每个epoch的checkpoint
    save_checkpoint(model, optimizer, epoch, val_acc,
                   f'checkpoint_epoch_{epoch}.pth')

    # 保存最佳模型
    if val_acc > best_acc:
        best_acc = val_acc
        save_checkpoint(model, optimizer, epoch, val_acc,
                       'best_model.pth')

    # 保存最新模型(用于断点续训)
    save_checkpoint(model, optimizer, epoch, val_acc,
                   'latest_checkpoint.pth')

2.4 断点续训实现

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def train_with_resume(model, train_loader, val_loader, device,
                     num_epochs=100, resume_from=None):
    """支持断点续训的训练函数"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    start_epoch = 0
    best_acc = 0.0

    # ========== 如果有checkpoint,加载它 ==========
    if resume_from and os.path.exists(resume_from):
        print(f"从checkpoint恢复: {resume_from}")
        checkpoint = torch.load(resume_from)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['best_acc']

        print(f"从epoch {start_epoch} 继续训练, 最佳准确率: {best_acc:.2f}%")

    model = model.to(device)

    # ========== 训练循环 ==========
    for epoch in range(start_epoch, num_epochs):
        # 训练
        model.train()
        train_loss = 0.0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        # 验证
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_acc = 100. * val_correct / val_total

        scheduler.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # ========== 保存checkpoint ==========
        # 保存最新checkpoint(每个epoch)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_acc': val_acc,
            'best_acc': best_acc,
        }, 'latest_checkpoint.pth')

        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
            }, 'best_model.pth')
            print(f'保存最佳模型 (Val Acc: {val_acc:.2f}%)')

    print(f'训练完成! 最佳准确率: {best_acc:.2f}%')

# 使用示例
# 首次训练
# train_with_resume(model, train_loader, val_loader, device, num_epochs=100)

# 如果训练中断,从checkpoint恢复
# train_with_resume(model, train_loader, val_loader, device,
#                  num_epochs=100, resume_from='latest_checkpoint.pth')

2.5 跨设备保存和加载

# ========== 在GPU上训练,保存模型 ==========
device = torch.device('cuda')
model = CNN().to(device)
# ... 训练 ...
torch.save(model.state_dict(), 'model.pth')

# ========== 在CPU上加载 ==========
device = torch.device('cpu')
model = CNN()
# map_location指定加载到哪个设备
model.load_state_dict(torch.load('model.pth', map_location=device))

# ========== 在不同GPU上加载 ==========
# 保存在GPU 0
device = torch.device('cuda:0')
torch.save(model.state_dict(), 'model.pth')

# 加载到GPU 1
device = torch.device('cuda:1')
model.load_state_dict(torch.load('model.pth', map_location='cuda:1'))

# ========== 保存整个checkpoint ==========
# 保存时
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

# 加载时
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 将优化器状态移到正确的设备
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

3. TensorFlow模型文件

3.1 SavedModel格式

SavedModel是TensorFlow推荐的保存格式,是一个目录结构。

import tensorflow as tf
from tensorflow import keras

# 创建一个简单模型
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 假设已经训练了模型
# model.fit(x_train, y_train, epochs=5)

# ========== 保存为SavedModel格式 ==========
model.save('saved_model/my_model')

# 目录结构:
# saved_model/my_model/
# ├── assets/
# ├── variables/
# │   ├── variables.data-00000-of-00001
# │   └── variables.index
# └── saved_model.pb

# ========== 加载SavedModel ==========
loaded_model = keras.models.load_model('saved_model/my_model')

# 使用加载的模型
predictions = loaded_model.predict(x_test)

SavedModel目录结构说明

saved_model/my_model/
├── assets/                          # 额外资源文件
├── variables/                       # 模型变量(权重)
│   ├── variables.data-00000-of-00001  # 变量数据
│   └── variables.index                # 变量索引
└── saved_model.pb                   # 模型结构和元信息

查看SavedModel信息

# 使用saved_model_cli查看模型信息
saved_model_cli show --dir saved_model/my_model --all

# 输出:
# MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
#
# signature_def['__saved_model_init_op']:
#   ...
#
# signature_def['serving_default']:
#   The given SavedModel SignatureDef contains the following input(s):
#     inputs['dense_input'] tensor_info:
#         dtype: DT_FLOAT
#         shape: (-1, 784)
#         name: serving_default_dense_input:0
#   The given SavedModel SignatureDef contains the following output(s):
#     outputs['dense_1'] tensor_info:
#         dtype: DT_FLOAT
#         shape: (-1, 10)
#         name: StatefulPartitionedCall:0

3.2 .h5格式

HDF5格式是Keras的传统保存格式。

import tensorflow as tf
from tensorflow import keras

# 创建模型
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# ========== 保存为HDF5格式 ==========
# 保存整个模型(结构 + 权重 + 优化器状态)
model.save('my_model.h5')

# 只保存权重
model.save_weights('model_weights.h5')

# ========== 加载HDF5格式 ==========
# 加载整个模型
loaded_model = keras.models.load_model('my_model.h5')

# 只加载权重(需要先定义模型结构)
model = keras.Sequential([...])  # 定义相同的结构
model.load_weights('model_weights.h5')

查看HDF5文件内容

import h5py

def print_h5_structure(filename, group_path='/'):
    """打印HDF5文件结构"""
    with h5py.File(filename, 'r') as f:
        def print_attrs(name, obj):
            print(name)
            if isinstance(obj, h5py.Dataset):
                print(f"  Shape: {obj.shape}")
                print(f"  Dtype: {obj.dtype}")

        if group_path == '/':
            f.visititems(print_attrs)
        else:
            f[group_path].visititems(print_attrs)

# 查看模型文件结构
print_h5_structure('my_model.h5')

# 输出示例:
# model_weights/dense/dense/bias:0
#   Shape: (128,)
#   Dtype: float32
# model_weights/dense/dense/kernel:0
#   Shape: (784, 128)
#   Dtype: float32
# model_weights/dense_1/dense_1/bias:0
#   Shape: (10,)
#   Dtype: float32
# model_weights/dense_1/dense_1/kernel:0
#   Shape: (128, 10)
#   Dtype: float32

3.3 Checkpoint文件

TensorFlow checkpoint用于保存训练过程中的模型状态。

import tensorflow as tf
from tensorflow import keras

# 创建模型
model = keras.Sequential([...])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# ========== 训练时保存checkpoint ==========
# 回调函数,每个epoch保存一次
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath='training_checkpoints/cp-{epoch:04d}.ckpt',
    save_weights_only=True,
    save_freq='epoch',
    verbose=1
)

# 训练
model.fit(x_train, y_train,
          epochs=10,
          callbacks=[checkpoint_callback])

# 目录结构:
# training_checkpoints/
# ├── cp-0001.ckpt.data-00000-of-00001
# ├── cp-0001.ckpt.index
# ├── cp-0002.ckpt.data-00000-of-00001
# ├── cp-0002.ckpt.index
# └── checkpoint

# ========== 加载checkpoint ==========
# 加载最新的checkpoint
latest = tf.train.latest_checkpoint('training_checkpoints')
print(f"最新checkpoint: {latest}")

model.load_weights(latest)

# 或者加载特定的checkpoint
model.load_weights('training_checkpoints/cp-0005.ckpt')

自定义checkpoint保存

import tensorflow as tf

# 创建checkpoint对象
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    model=model,
    epoch=tf.Variable(0)
)

# 创建checkpoint管理器
manager = tf.train.CheckpointManager(
    checkpoint,
    directory='./checkpoints',
    max_to_keep=5  # 只保留最近5个checkpoint
)

# 训练循环
for epoch in range(num_epochs):
    # 训练...

    # 保存checkpoint
    checkpoint.epoch.assign_add(1)
    save_path = manager.save()
    print(f'保存checkpoint: {save_path}')

# 恢复最新的checkpoint
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f"恢复自: {manager.latest_checkpoint}")
else:
    print("从头开始训练")

4. 模型文件结构

4.1 打开.pth文件看内部结构

import torch
import pickle

# ========== 方法1:使用torch.load ==========
state_dict = torch.load('model.pth')
print("State Dict Keys:")
for key, value in state_dict.items():
    print(f"{key}: {value.shape}, {value.dtype}")

# ========== 方法2:使用pickle查看原始内容 ==========
with open('model.pth', 'rb') as f:
    # 不完全加载,只查看结构
    unpickler = pickle.Unpickler(f)
    data = unpickler.load()
    print(f"\nType: {type(data)}")
    print(f"Keys: {data.keys() if isinstance(data, dict) else 'Not a dict'}")

# ========== 方法3:详细分析 ==========
def analyze_model_file(filepath):
    """详细分析模型文件"""
    import os

    # 文件大小
    file_size = os.path.getsize(filepath) / (1024 * 1024)  # MB
    print(f"文件大小: {file_size:.2f} MB")

    # 加载内容
    data = torch.load(filepath)

    if isinstance(data, dict):
        print("\n=== 字典内容 ===")
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                print(f"{key}:")
                print(f"  形状: {value.shape}")
                print(f"  数据类型: {value.dtype}")
                print(f"  设备: {value.device}")
                print(f"  内存大小: {value.element_size() * value.numel() / (1024*1024):.2f} MB")
                print(f"  最小值: {value.min().item():.6f}")
                print(f"  最大值: {value.max().item():.6f}")
                print(f"  均值: {value.float().mean().item():.6f}")
                print(f"  标准差: {value.float().std().item():.6f}")
            elif isinstance(value, dict):
                print(f"{key}: (nested dict with {len(value)} items)")
            else:
                print(f"{key}: {value}")
            print()

    elif isinstance(data, torch.nn.Module):
        print("\n=== 完整模型 ===")
        print(data)
        print(f"\n总参数量: {sum(p.numel() for p in data.parameters()):,}")
    else:
        print(f"未知格式: {type(data)}")

# 使用
analyze_model_file('model.pth')

输出示例

文件大小: 45.32 MB

=== 字典内容 ===
conv1.weight:
  形状: torch.Size([64, 3, 7, 7])
  数据类型: torch.float32
  设备: cpu
  内存大小: 0.04 MB
  最小值: -0.156789
  最大值: 0.142345
  均值: -0.000123
  标准差: 0.054321

conv1.bias:
  形状: torch.Size([64])
  数据类型: torch.float32
  设备: cpu
  内存大小: 0.00 MB
  最小值: -0.089765
  最大值: 0.078901
  均值: 0.001234
  标准差: 0.023456

...

4.2 张量形状和数据类型

import torch
import torch.nn as nn

# 创建一个CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc = nn.Linear(128 * 8 * 8, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = SimpleCNN()
torch.save(model.state_dict(), 'cnn_model.pth')

# ========== 分析张量形状 ==========
state_dict = torch.load('cnn_model.pth')

print("层级结构和张量形状分析:\n")
print(f"{'参数名称':<40} {'形状':<25} {'数据类型':<15} {'参数量':>15}")
print("=" * 100)

total_params = 0
for name, tensor in state_dict.items():
    num_params = tensor.numel()
    total_params += num_params

    # 解析层的类型
    layer_type = "Unknown"
    if 'conv' in name and 'weight' in name:
        layer_type = "Conv2d"
        # Conv2d权重形状: [out_channels, in_channels, kernel_h, kernel_w]
        shape_info = f"out={tensor.shape[0]}, in={tensor.shape[1]}, k={tensor.shape[2]}x{tensor.shape[3]}"
    elif 'conv' in name and 'bias' in name:
        layer_type = "Conv2d Bias"
        shape_info = f"channels={tensor.shape[0]}"
    elif 'bn' in name and 'weight' in name:
        layer_type = "BatchNorm"
        shape_info = f"features={tensor.shape[0]}"
    elif 'fc' in name and 'weight' in name:
        layer_type = "Linear"
        # Linear权重形状: [out_features, in_features]
        shape_info = f"out={tensor.shape[0]}, in={tensor.shape[1]}"
    elif 'fc' in name and 'bias' in name:
        layer_type = "Linear Bias"
        shape_info = f"features={tensor.shape[0]}"
    else:
        shape_info = str(tensor.shape)

    print(f"{name:<40} {str(tensor.shape):<25} {str(tensor.dtype):<15} {num_params:>15,}")
    print(f"  └─ {layer_type}: {shape_info}")

print("=" * 100)
print(f"{'总参数量:':<80} {total_params:>15,}")
print(f"{'模型大小 (float32):':<80} {total_params * 4 / (1024*1024):>12.2f} MB")

4.3 层级结构和命名

import torch
import torch.nn as nn

# 创建一个有嵌套结构的模型
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return torch.relu(out)

class ComplexNet(nn.Module):
    def __init__(self):
        super(ComplexNet, self).__init__()
        # 第一层
        self.conv_input = nn.Conv2d(3, 64, 7, padding=3)

        # 残差块
        self.layer1 = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 64)
        )

        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128)
        )

        # 全局池化和分类器
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_input(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = ComplexNet()

# ========== 查看模型的层级结构 ==========
print("=== 模型层级结构 ===\n")
for name, module in model.named_modules():
    if name == '':
        print("ComplexNet (root)")
    else:
        indent = '  ' * name.count('.')
        print(f"{indent}{name}: {module.__class__.__name__}")

# 输出:
# ComplexNet (root)
# conv_input: Conv2d
# layer1: Sequential
#   layer1.0: ResidualBlock
#     layer1.0.conv1: Conv2d
#     layer1.0.bn1: BatchNorm2d
#     layer1.0.conv2: Conv2d
#     layer1.0.bn2: BatchNorm2d
#   layer1.1: ResidualBlock
#     layer1.1.conv1: Conv2d
#     layer1.1.bn1: BatchNorm2d
#     layer1.1.conv2: Conv2d
#     layer1.1.bn2: BatchNorm2d
# ...

# ========== 查看参数的命名 ==========
print("\n=== 参数命名规则 ===\n")
for name, param in model.named_parameters():
    print(f"{name:<50} {str(param.shape):<25} {param.numel():>10,}")

# 输出:
# conv_input.weight                                  torch.Size([64, 3, 7, 7])         9,408
# conv_input.bias                                    torch.Size([64])                     64
# layer1.0.conv1.weight                             torch.Size([64, 64, 3, 3])       36,864
# layer1.0.conv1.bias                               torch.Size([64])                     64
# layer1.0.bn1.weight                               torch.Size([64])                     64
# layer1.0.bn1.bias                                 torch.Size([64])                     64
# ...

理解命名规则

# PyTorch的参数命名规则:
# 1. 使用点号(.)分隔层级
# 2. 模块名.子模块名.参数名# layer1.0.conv1.weight
#   ↓     ↓  ↓      ↓
#   模块  索引 子层  参数

# 对于Sequential容器,使用数字索引
# 对于自定义容器,使用属性名

# 可以通过名称访问参数
state_dict = model.state_dict()
conv1_weight = state_dict['layer1.0.conv1.weight']
print(f"Shape: {conv1_weight.shape}")

5. 模型转换

5.1 PyTorch → ONNX

ONNX (Open Neural Network Exchange) 是一个开放的模型格式,支持跨框架使用。

import torch
import torch.nn as nn

# 定义模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 10, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

model = SimpleNet()
model.eval()

# ========== 导出为ONNX ==========
# 创建虚拟输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出
torch.onnx.export(
    model,                          # 模型
    dummy_input,                    # 虚拟输入
    "model.onnx",                   # 输出文件名
    export_params=True,             # 导出参数
    opset_version=11,               # ONNX版本
    do_constant_folding=True,       # 常量折叠优化
    input_names=['input'],          # 输入名称
    output_names=['output'],        # 输出名称
    dynamic_axes={                  # 动态维度
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

print("模型已导出为ONNX格式")

# ========== 验证ONNX模型 ==========
import onnx

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX模型验证通过")

# 打印模型信息
print("\nONNX模型信息:")
print(f"  IR版本: {onnx_model.ir_version}")
print(f"  生产者: {onnx_model.producer_name}")
print(f"  图的输入数: {len(onnx_model.graph.input)}")
print(f"  图的输出数: {len(onnx_model.graph.output)}")
print(f"  节点数: {len(onnx_model.graph.node)}")

# ========== 使用ONNX Runtime推理 ==========
import onnxruntime as ort
import numpy as np

# 创建推理会话
ort_session = ort.InferenceSession("model.onnx")

# 准备输入
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 推理
outputs = ort_session.run(
    None,
    {"input": input_data}
)

print(f"\n输出形状: {outputs[0].shape}")

# ========== 比较PyTorch和ONNX的输出 ==========
with torch.no_grad():
    pytorch_output = model(torch.from_numpy(input_data))

print("\n输出比较:")
print(f"PyTorch输出: {pytorch_output.numpy().flatten()[:5]}")
print(f"ONNX输出:    {outputs[0].flatten()[:5]}")
print(f"最大差异: {np.max(np.abs(pytorch_output.numpy() - outputs[0]))}")

5.2 ONNX → TensorRT

TensorRT是NVIDIA的高性能推理引擎。

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

def build_engine_from_onnx(onnx_file_path, engine_file_path, fp16_mode=False):
    """
    从ONNX文件构建TensorRT引擎

    需要安装:
    pip install tensorrt
    pip install pycuda
    """
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

    # 创建builder
    builder = trt.Builder(TRT_LOGGER)

    # 创建网络
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)

    # 创建ONNX解析器
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # 解析ONNX文件
    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # 配置builder
    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1GB

    # 启用FP16
    if fp16_mode and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("启用FP16模式")

    # 构建引擎
    print("构建TensorRT引擎...")
    engine = builder.build_engine(network, config)

    if engine is None:
        print("构建引擎失败")
        return None

    # 保存引擎
    with open(engine_file_path, 'wb') as f:
        f.write(engine.serialize())

    print(f"TensorRT引擎已保存到: {engine_file_path}")
    return engine

# 使用
# engine = build_engine_from_onnx('model.onnx', 'model.trt', fp16_mode=True)

使用TensorRT推理

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

class TRTInference:
    def __init__(self, engine_path):
        """加载TensorRT引擎"""
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)

        # 加载引擎
        with open(engine_path, 'rb') as f:
            self.engine = self.runtime.deserialize_cuda_engine(f.read())

        self.context = self.engine.create_execution_context()

        # 分配内存
        self.inputs = []
        self.outputs = []
        self.bindings = []

        for binding in self.engine:
            size = trt.volume(self.engine.get_binding_shape(binding))
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))

            # 分配设备内存
            device_mem = cuda.mem_alloc(size * np.dtype(dtype).itemsize)
            self.bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                self.inputs.append({'name': binding, 'mem': device_mem, 'size': size, 'dtype': dtype})
            else:
                self.outputs.append({'name': binding, 'mem': device_mem, 'size': size, 'dtype': dtype})

    def infer(self, input_data):
        """执行推理"""
        # 复制输入数据到设备
        cuda.memcpy_htod(self.inputs[0]['mem'], input_data)

        # 执行推理
        self.context.execute_v2(bindings=self.bindings)

        # 复制输出数据到主机
        output = np.empty(self.outputs[0]['size'], dtype=self.outputs[0]['dtype'])
        cuda.memcpy_dtoh(output, self.outputs[0]['mem'])

        return output

# 使用
# trt_infer = TRTInference('model.trt')
# input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# output = trt_infer.infer(input_data)

5.3 量化和剪枝

量化(Quantization)

将float32精度降低到int8,减小模型大小和提高推理速度。

import torch
import torch.quantization

# ========== 动态量化 ==========
def dynamic_quantization(model):
    """
    动态量化(推理时量化,适用于LSTM/RNN)
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.LSTM},  # 要量化的层类型
        dtype=torch.qint8
    )
    return quantized_model

# ========== 静态量化 ==========
def static_quantization(model, calibration_loader):
    """
    静态量化(需要校准数据)
    """
    # 1. 融合模块(Conv + BN + ReLU)
    model_fused = torch.quantization.fuse_modules(
        model,
        [['conv', 'bn', 'relu']]
    )

    # 2. 设置量化配置
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    # 3. 准备模型
    torch.quantization.prepare(model_fused, inplace=True)

    # 4. 校准(使用代表性数据)
    model_fused.eval()
    with torch.no_grad():
        for data, _ in calibration_loader:
            model_fused(data)

    # 5. 转换为量化模型
    torch.quantization.convert(model_fused, inplace=True)

    return model_fused

# ========== 量化感知训练 ==========
def quantization_aware_training(model, train_loader, num_epochs=10):
    """
    量化感知训练(在训练时模拟量化)
    """
    # 1. 设置量化配置
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # 2. 准备模型
    torch.quantization.prepare_qat(model, inplace=True)

    # 3. 训练
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    # 4. 转换为量化模型
    model.eval()
    torch.quantization.convert(model, inplace=True)

    return model

# ========== 比较模型大小和速度 ==========
import time

def compare_models(original_model, quantized_model, test_data):
    """比较原始模型和量化模型"""
    # 模型大小
    def get_model_size(model):
        torch.save(model.state_dict(), "temp.pth")
        size = os.path.getsize("temp.pth") / (1024 * 1024)  # MB
        os.remove("temp.pth")
        return size

    original_size = get_model_size(original_model)
    quantized_size = get_model_size(quantized_model)

    print(f"原始模型大小: {original_size:.2f} MB")
    print(f"量化模型大小: {quantized_size:.2f} MB")
    print(f"压缩比: {original_size / quantized_size:.2f}x")

    # 推理速度
    original_model.eval()
    quantized_model.eval()

    # 原始模型
    start = time.time()
    with torch.no_grad():
        for _ in range(100):
            _ = original_model(test_data)
    original_time = time.time() - start

    # 量化模型
    start = time.time()
    with torch.no_grad():
        for _ in range(100):
            _ = quantized_model(test_data)
    quantized_time = time.time() - start

    print(f"\n原始模型推理时间: {original_time:.3f}s")
    print(f"量化模型推理时间: {quantized_time:.3f}s")
    print(f"加速比: {original_time / quantized_time:.2f}x")

剪枝(Pruning)

移除不重要的权重,减小模型大小。

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """
    剪枝模型

    Args:
        model: 要剪枝的模型
        amount: 剪枝比例(0-1)
    """
    # 对每个卷积层和全连接层进行剪枝
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            # L1非结构化剪枝
            prune.l1_unstructured(module, name='weight', amount=amount)

    print(f"剪枝完成,移除了 {amount*100}% 的权重")

    return model

def structured_pruning(model, amount=0.3):
    """
    结构化剪枝(移除整个通道或神经元)
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # 按L2范数剪枝卷积核
            prune.ln_structured(
                module,
                name='weight',
                amount=amount,
                n=2,
                dim=0  # 输出通道维度
            )

    print(f"结构化剪枝完成")
    return model

def global_pruning(model, amount=0.3):
    """
    全局剪枝(在所有层中选择最不重要的权重)
    """
    # 收集所有要剪枝的参数
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            parameters_to_prune.append((module, 'weight'))

    # 全局剪枝
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    print(f"全局剪枝完成")
    return model

def remove_pruning(model):
    """
    永久移除剪枝掩码,使剪枝生效
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            try:
                prune.remove(module, 'weight')
            except:
                pass

    return model

# ========== 使用示例 ==========
# 原始模型
model = SimpleCNN()
print(f"原始模型参数量: {sum(p.numel() for p in model.parameters()):,}")

# 剪枝
pruned_model = prune_model(model, amount=0.5)

# 查看剪枝效果
for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(f"\n{name}:")
        print(f"  权重形状: {module.weight.shape}")
        print(f"  零元素比例: {(module.weight == 0).sum().item() / module.weight.numel():.2%}")

# 永久移除剪枝掩码
pruned_model = remove_pruning(pruned_model)

# 保存剪枝后的模型
torch.save(pruned_model.state_dict(), 'pruned_model.pth')

5.4 模型压缩技术

# 知识蒸馏(Knowledge Distillation)
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: 学生模型输出
            teacher_logits: 教师模型输出
            labels: 真实标签
        """
        # 软目标损失(蒸馏损失)
        soft_loss = self.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        # 硬目标损失(分类损失)
        hard_loss = self.ce(student_logits, labels)

        # 组合损失
        loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

        return loss

def train_with_distillation(teacher_model, student_model, train_loader,
                            device, num_epochs=10):
    """使用知识蒸馏训练学生模型"""
    teacher_model.eval()  # 教师模型不训练
    student_model.train()

    criterion = DistillationLoss(temperature=3.0, alpha=0.7)
    optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        total_loss = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            # 教师模型推理
            with torch.no_grad():
                teacher_logits = teacher_model(data)

            # 学生模型训练
            optimizer.zero_grad()
            student_logits = student_model(data)
            loss = criterion(student_logits, teacher_logits, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

    return student_model

6. 模型部署格式

6.1 ONNX Runtime

import onnxruntime as ort
import numpy as np

class ONNXModel:
    def __init__(self, onnx_path):
        """加载ONNX模型"""
        # 创建推理会话
        self.session = ort.InferenceSession(onnx_path)

        # 获取输入输出信息
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        print(f"输入名称: {self.input_name}")
        print(f"输出名称: {self.output_name}")

    def predict(self, input_data):
        """推理"""
        # 确保输入是numpy数组
        if not isinstance(input_data, np.ndarray):
            input_data = input_data.numpy()

        # 推理
        outputs = self.session.run(
            [self.output_name],
            {self.input_name: input_data}
        )

        return outputs[0]

# 使用
model = ONNXModel('model.onnx')
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = model.predict(input_data)
print(f"输出形状: {output.shape}")

6.2 TorchScript

import torch

# ========== 方法1: Tracing ==========
def export_torchscript_trace(model, example_input):
    """
    使用Tracing导出TorchScript
    适用于不包含控制流的模型
    """
    model.eval()

    # Trace模型
    traced_model = torch.jit.trace(model, example_input)

    # 保存
    traced_model.save('model_traced.pt')

    print("TorchScript (Trace) 导出完成")

    return traced_model

# ========== 方法2: Scripting ==========
def export_torchscript_script(model):
    """
    使用Scripting导出TorchScript
    支持控制流(if, loop等)
    """
    model.eval()

    # Script模型
    scripted_model = torch.jit.script(model)

    # 保存
    scripted_model.save('model_scripted.pt')

    print("TorchScript (Script) 导出完成")

    return scripted_model

# ========== 加载TorchScript模型 ==========
def load_torchscript(path):
    """加载TorchScript模型"""
    model = torch.jit.load(path)
    model.eval()
    return model

# ========== 使用示例 ==========
# 原始模型
model = SimpleCNN()
example_input = torch.randn(1, 3, 224, 224)

# 导出
traced_model = export_torchscript_trace(model, example_input)

# 加载和使用
loaded_model = load_torchscript('model_traced.pt')
output = loaded_model(example_input)
print(f"输出形状: {output.shape}")

# ========== 优化TorchScript模型 ==========
# 冻结模型(移除训练相关的操作)
frozen_model = torch.jit.freeze(traced_model)
frozen_model.save('model_frozen.pt')

# 进一步优化
optimized_model = torch.jit.optimize_for_inference(frozen_model)
optimized_model.save('model_optimized.pt')

6.3 TensorFlow Lite

import tensorflow as tf

# ========== 转换为TFLite ==========
def convert_to_tflite(saved_model_dir, output_file):
    """
    将SavedModel转换为TFLite格式
    """
    # 创建转换器
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

    # 优化选项
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # 转换
    tflite_model = converter.convert()

    # 保存
    with open(output_file, 'wb') as f:
        f.write(tflite_model)

    print(f"TFLite模型已保存到: {output_file}")

# ========== 量化转换 ==========
def convert_to_tflite_quantized(saved_model_dir, output_file, representative_dataset):
    """
    转换为量化的TFLite模型
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

    # 全整数量化
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    tflite_model = converter.convert()

    with open(output_file, 'wb') as f:
        f.write(tflite_model)

    print(f"量化TFLite模型已保存到: {output_file}")

# ========== 使用TFLite模型推理 ==========
def tflite_inference(model_path, input_data):
    """使用TFLite模型推理"""
    # 加载模型
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # 获取输入输出张量
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("输入信息:")
    print(f"  形状: {input_details[0]['shape']}")
    print(f"  类型: {input_details[0]['dtype']}")

    # 设置输入
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # 推理
    interpreter.invoke()

    # 获取输出
    output_data = interpreter.get_tensor(output_details[0]['index'])

    return output_data

# 使用
# convert_to_tflite('saved_model/my_model', 'model.tflite')
# output = tflite_inference('model.tflite', input_data)
Prev
04-从零训练第一个模型
Next
06-分布式训练-多GPU与多机