HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 完整学习路径

    • AI教程 - 从零到一的完整学习路径
    • 第00章:AI基础与发展史
    • 第01章:Python与AI开发环境
    • 第02章:数学基础-线性代数与微积分
    • 03-数据集详解-从获取到预处理
    • 04-从零训练第一个模型
    • 05-模型文件详解
    • 06-分布式训练-多GPU与多机
    • 07-模型调度与资源管理
    • 08-Transformer架构深度解析
    • 09-大语言模型原理与架构
    • 10-Token与Tokenization详解
    • 11-Prompt Engineering完全指南
    • 第12章:模型微调与LoRA技术
    • 第13章:RLHF与对齐技术
    • 第14章 AI编程助手原理与实现
    • 15-RAG系统设计与实现
    • 16-Agent智能体与工具调用
    • 17-多模态大模型
    • 第18章:AI前沿技术趋势
    • 第19章 AI热门话题与应用案例

04-从零训练第一个模型

1. 完整训练流程

训练一个深度学习模型的完整流程包括以下几个关键步骤:

数据准备 → 模型定义 → 损失函数 → 优化器 → 训练循环 → 验证 → 测试 → 保存模型

让我们通过一个简单的例子来理解整个流程:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# ============ 1. 数据准备 ============
# 假设已经有了训练、验证、测试数据集
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ 2. 模型定义 ============
class SimpleModel(nn.Module):
    """A minimal two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        input_size: dimensionality of each input vector.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Returns raw logits; softmax is left to the loss function.
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleModel(input_size=784, hidden_size=128, num_classes=10)

# ============ 3. 损失函数 ============
criterion = nn.CrossEntropyLoss()

# ============ 4. 优化器 ============
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ============ 5. 训练循环 ============
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 10
for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ============ 6. 验证阶段 ============
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    train_loss = train_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    val_acc = 100 * correct / total

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, '
          f'Val Acc: {val_acc:.2f}%')

# ============ 7. 测试阶段 ============
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

test_acc = 100 * test_correct / test_total
print(f'Test Accuracy: {test_acc:.2f}%')

# ============ 8. 保存模型 ============
torch.save(model.state_dict(), 'model.pth')

现在让我们详细讲解每个步骤。

1.1 数据准备

数据准备是训练的第一步,需要确保数据格式正确、已经过预处理、并且划分为训练集、验证集和测试集。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载数据集
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 划分训练集和验证集
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# 测试集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"训练集大小: {len(train_dataset)}")
print(f"验证集大小: {len(val_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

1.2 模型定义

模型定义需要继承nn.Module并实现__init__和forward方法。

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (e.g. MNIST).

    Layout: [conv3x3 -> ReLU -> maxpool] x 2, then fc -> ReLU -> dropout -> fc.
    forward() returns raw logits (no softmax), suitable for nn.CrossEntropyLoss.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        # Convolution layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Shared 2x2 max-pooling (halves H and W each time it is applied)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected head; 64 * 7 * 7 matches two pooling halvings of 28x28
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

        # Dropout regularizes the hidden FC layer
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: float tensor of shape [B, 1, 28, 28].

        Returns:
            Logits tensor of shape [B, num_classes].
        """
        # First conv block
        x = self.conv1(x)  # [B, 1, 28, 28] -> [B, 32, 28, 28]
        x = F.relu(x)
        x = self.pool(x)   # [B, 32, 28, 28] -> [B, 32, 14, 14]

        # Second conv block
        x = self.conv2(x)  # [B, 32, 14, 14] -> [B, 64, 14, 14]
        x = F.relu(x)
        x = self.pool(x)   # [B, 64, 14, 14] -> [B, 64, 7, 7]

        # Flatten everything except the batch dimension. Using flatten(1)
        # instead of view(-1, 64*7*7) keeps the batch size fixed, so a
        # wrongly sized input raises an error instead of silently
        # re-batching the data.
        x = x.flatten(1)  # [B, 64, 7, 7] -> [B, 3136]

        # FC head
        x = self.fc1(x)  # [B, 3136] -> [B, 128]
        x = F.relu(x)
        x = self.dropout(x)

        x = self.fc2(x)  # [B, 128] -> [B, 10]

        return x

# 创建模型实例
model = CNN(num_classes=10)

# 查看模型结构
print(model)

# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")

1.3 损失函数选择

不同任务使用不同的损失函数:

分类任务

import torch.nn as nn

# 多分类(输出是logits,未经过softmax)
criterion = nn.CrossEntropyLoss()

# 二分类
criterion = nn.BCEWithLogitsLoss()  # 输出是logits
# 或
criterion = nn.BCELoss()  # 输出已经过sigmoid

# 多标签分类
criterion = nn.BCEWithLogitsLoss()  # 每个类别独立

回归任务

# 均方误差
criterion = nn.MSELoss()

# 平均绝对误差
criterion = nn.L1Loss()

# 平滑L1损失
criterion = nn.SmoothL1Loss()

其他任务

# 对比学习
criterion = nn.CosineEmbeddingLoss()

# 排序学习
criterion = nn.MarginRankingLoss()

# 负对数似然
criterion = nn.NLLLoss()

1.4 优化器配置

常用优化器

import torch.optim as optim

# SGD(随机梯度下降)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# Adam(自适应学习率)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-4)

# AdamW(Adam + 权重衰减)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# RMSprop
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)

# Adagrad
optimizer = optim.Adagrad(model.parameters(), lr=0.01)

学习率调度器

from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau

# 每N个epoch降低学习率
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# 余弦退火
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0)

# 根据指标自动降低学习率
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

# 在训练循环中使用
for epoch in range(num_epochs):
    train(...)
    val_loss = validate(...)

    # StepLR或CosineAnnealingLR
    scheduler.step()

    # ReduceLROnPlateau
    scheduler.step(val_loss)

1.5 训练循环

完整的训练循环包括前向传播、计算损失、反向传播、更新参数。

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run a single optimization pass over *train_loader*.

    Args:
        model: network to train (moved to *device* by the caller).
        train_loader: DataLoader yielding (data, target) batches.
        criterion: loss taking (logits, target).
        optimizer: optimizer over model.parameters().
        device: torch.device for inputs and targets.

    Returns:
        (mean batch loss, accuracy in percent) for the epoch.
    """
    model.train()  # enable dropout / batch-norm training behavior

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (inputs, labels) in enumerate(train_loader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard update: zero grads -> forward -> loss -> backward -> step.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the progress report.
        loss_sum += batch_loss.item()
        preds = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += preds.eq(labels).sum().item()

        # Periodic progress line, every 100 batches.
        if step % 100 == 0:
            print(f'Batch [{step}/{len(train_loader)}], '
                  f'Loss: {batch_loss.item():.4f}, '
                  f'Acc: {100.*n_correct/n_seen:.2f}%')

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

1.6 验证和测试

验证和测试时不需要计算梯度,使用torch.no_grad()节省内存。

def validate(model, val_loader, criterion, device):
    """Evaluate *model* on *val_loader* without updating weights.

    Args:
        model: trained network (moved to *device* by the caller).
        val_loader: DataLoader yielding (data, target) batches.
        criterion: loss taking (logits, target).
        device: torch.device for inputs and targets.

    Returns:
        (mean batch loss, accuracy in percent).
    """
    model.eval()  # disable dropout / use running batch-norm statistics

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():  # skip autograd bookkeeping to save memory
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()

            preds = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += preds.eq(labels).sum().item()

    return loss_sum / len(val_loader), 100. * n_correct / n_seen

2. 实战案例1:MNIST手写数字识别

MNIST是深度学习的"Hello World",包含70,000张28x28的手写数字图像(0-9)。

2.1 数据加载

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子
torch.manual_seed(42)

# 数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差
])

# 下载和加载数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 划分训练集和验证集
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"训练集: {len(train_dataset)} 样本")
print(f"验证集: {len(val_dataset)} 样本")
print(f"测试集: {len(test_dataset)} 样本")

# 可视化一些样本
def show_samples(loader, num_samples=10):
    """Draw the first *num_samples* images of one batch and save to mnist_samples.png.

    Args:
        loader: DataLoader yielding (images, labels); images shaped [B, 1, H, W].
        num_samples: how many images to draw (the grid is fixed at 2x5 = 10 axes).
    """
    data_iter = iter(loader)
    images, labels = next(data_iter)

    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        if i >= num_samples:
            break
        # squeeze() drops the channel dimension so imshow receives a 2-D array
        img = images[i].squeeze().numpy()
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')

    plt.tight_layout()
    # Save to disk instead of plt.show() so this works in headless environments
    plt.savefig('mnist_samples.png')
    print("保存样本可视化: mnist_samples.png")
show_samples(train_loader)

2.2 CNN模型实现

class MNISTNet(nn.Module):
    """CNN for MNIST digit classification.

    Two convolutional stages with batch normalization and spatial dropout,
    followed by a batch-normalized fully connected head. forward() returns
    raw logits of shape [B, 10].
    """

    def __init__(self):
        super(MNISTNet, self).__init__()

        # Stage 1: 1 -> 32 -> 64 channels at 28x28, then pool to 14x14
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(0.25)

        # Stage 2: 64 -> 128 channels at 14x14, then pool to 7x7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout2d(0.25)

        # Classifier head over the flattened 128*7*7 feature map
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn4 = nn.BatchNorm1d(256)
        self.dropout3 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        # Stage 1: two conv+BN+ReLU, then downsample 28x28 -> 14x14
        h = torch.relu(self.bn1(self.conv1(x)))
        h = torch.relu(self.bn2(self.conv2(h)))
        h = self.dropout1(self.pool1(h))

        # Stage 2: one conv+BN+ReLU, then downsample 14x14 -> 7x7
        h = torch.relu(self.bn3(self.conv3(h)))
        h = self.dropout2(self.pool2(h))

        # Flatten per sample and run the classifier head
        h = h.view(h.size(0), -1)
        h = torch.relu(self.bn4(self.fc1(h)))
        h = self.dropout3(h)
        return self.fc2(h)

# 创建模型
model = MNISTNet()
print(model)

# 统计参数
total_params = sum(p.numel() for p in model.parameters())
print(f"\n总参数量: {total_params:,}")

2.3 完整训练代码

import time
from collections import defaultdict

def train_mnist():
    """Train the MNISTNet classifier end to end and report test accuracy.

    NOTE(review): relies on module-level globals `train_loader`, `val_loader`,
    `test_loader` and the `MNISTNet` class defined earlier in this file —
    confirm they exist before reusing this function elsewhere.

    Returns:
        defaultdict(list) with per-epoch 'train_loss', 'train_acc',
        'val_loss', 'val_acc' and 'lr' histories.
    """
    # Pick GPU when available, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Model
    model = MNISTNet().to(device)

    # Loss and optimizer; StepLR halves the learning rate every 5 epochs
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Per-epoch metric history (keys created lazily)
    history = defaultdict(list)

    # Best validation accuracy seen so far, for checkpointing
    best_val_acc = 0.0

    # Training loop
    num_epochs = 15
    for epoch in range(num_epochs):
        epoch_start = time.time()

        # ========== Training phase ==========
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            if (batch_idx + 1) % 200 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                      f'Batch [{batch_idx+1}/{len(train_loader)}], '
                      f'Loss: {loss.item():.4f}')

        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total

        # ========== Validation phase ==========
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)

                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Step the schedule once per epoch (StepLR takes no metric)
        scheduler.step()

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        epoch_time = time.time() - epoch_start

        print(f'\nEpoch [{epoch+1}/{num_epochs}] - {epoch_time:.2f}s')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 60)

        # Checkpoint whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, 'best_mnist_model.pth')
            print(f'保存最佳模型 (Val Acc: {val_acc:.2f}%)')

    # ========== Test phase ==========
    print("\n开始测试...")
    # Reload the best checkpoint before measuring test accuracy
    # NOTE(review): torch.load without map_location assumes the same device
    # layout as when the checkpoint was saved — confirm for cross-device use
    checkpoint = torch.load('best_mnist_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            _, predicted = output.max(1)
            test_total += target.size(0)
            test_correct += predicted.eq(target).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f'Test Accuracy: {test_acc:.2f}%')

    return history

# 运行训练
history = train_mnist()

2.4 可视化训练过程

def plot_training_history(history):
    """Plot loss, accuracy and learning-rate curves; save to training_history.png.

    Args:
        history: mapping with list-valued keys 'train_loss', 'val_loss',
            'train_acc', 'val_acc' and 'lr' (one entry per epoch).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Loss curves
    axes[0].plot(history['train_loss'], label='Train Loss')
    axes[0].plot(history['val_loss'], label='Val Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss Curve')
    axes[0].legend()
    axes[0].grid(True)

    # Accuracy curves
    axes[1].plot(history['train_acc'], label='Train Acc')
    axes[1].plot(history['val_acc'], label='Val Acc')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Accuracy Curve')
    axes[1].legend()
    axes[1].grid(True)

    # Learning-rate schedule (log scale makes step decays visible)
    axes[2].plot(history['lr'], label='Learning Rate')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title('Learning Rate Schedule')
    axes[2].legend()
    axes[2].grid(True)
    axes[2].set_yscale('log')

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150)
    print("保存训练历史: training_history.png")

plot_training_history(history)

2.5 预测可视化

def visualize_predictions(model, test_loader, device, num_samples=20):
    """Plot one batch of test images with true vs. predicted labels; save to predictions.png.

    Args:
        model: trained classifier returning [B, num_classes] logits.
        test_loader: DataLoader yielding (images, labels).
        device: torch.device to run inference on.
        num_samples: how many images to draw (grid is fixed at 4x5 = 20 axes).
    """
    model.eval()

    # Grab a single batch
    data_iter = iter(test_loader)
    images, labels = next(data_iter)

    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)
        _, predicted = outputs.max(1)

    # Move back to CPU for numpy/matplotlib
    images = images.cpu()
    predicted = predicted.cpu()

    # Draw the grid
    fig, axes = plt.subplots(4, 5, figsize=(15, 12))
    for i, ax in enumerate(axes.flat):
        if i >= num_samples:
            break

        img = images[i].squeeze().numpy()
        true_label = labels[i].item()
        pred_label = predicted[i].item()

        ax.imshow(img, cmap='gray')

        # Green title for a correct prediction, red for an incorrect one
        color = 'green' if true_label == pred_label else 'red'
        ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('predictions.png', dpi=150)
    print("保存预测可视化: predictions.png")

# 加载最佳模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MNISTNet().to(device)
checkpoint = torch.load('best_mnist_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

visualize_predictions(model, test_loader, device)

3. 实战案例2:情感分析模型

情感分析是NLP的经典任务,判断文本的情感倾向(正面/负面)。

3.1 文本数据处理

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import Counter
import re

# ========== 数据集类 ==========
class SentimentDataset(Dataset):
    """Text-classification dataset mapping raw strings to fixed-length id tensors.

    Each item is (LongTensor[max_length] of token ids, LongTensor label).
    Unknown words map to '<UNK>'; short sequences are padded with '<PAD>'
    and long ones are clipped.
    """

    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Encode: tokenize, then map each token to its id with <UNK> fallback.
        ids = [self.vocab.get(tok, self.vocab['<UNK>'])
               for tok in self.tokenize(self.texts[idx])]

        # Force the sequence to exactly max_length: pad short, clip long.
        pad_id = self.vocab['<PAD>']
        ids = (ids + [pad_id] * self.max_length)[:self.max_length]

        return (torch.tensor(ids, dtype=torch.long),
                torch.tensor(self.labels[idx], dtype=torch.long))

    @staticmethod
    def tokenize(text):
        """Lowercase, drop everything but letters/whitespace, split on whitespace."""
        return re.sub(r'[^a-z\s]', '', text.lower()).split()

# ========== 构建词汇表 ==========
def build_vocab(texts, min_freq=2, max_vocab_size=10000):
    """Build a word -> index vocabulary from raw texts.

    Indices 0 and 1 are reserved for '<PAD>' and '<UNK>'. The remaining
    entries are the up-to-max_vocab_size most frequent tokens whose count
    is at least min_freq, numbered in descending-frequency order.
    """
    # Count token frequencies across the whole corpus.
    freqs = Counter()
    for text in texts:
        freqs.update(SentimentDataset.tokenize(text))

    # Reserved special tokens come first.
    vocab = {'<PAD>': 0, '<UNK>': 1}
    for word, count in freqs.most_common(max_vocab_size):
        if count >= min_freq:
            vocab[word] = len(vocab)

    print(f"词汇表大小: {len(vocab)}")
    return vocab

# ========== 准备数据 ==========
# 示例数据(实际使用时应从文件加载)
train_texts = [
    "this movie is great and amazing",
    "i love this film so much",
    "wonderful acting and great story",
    "best movie ever seen",
    "terrible movie and bad acting",
    "i hate this film",
    "worst movie ever",
    "boring and disappointing",
] * 1000  # 扩展数据

train_labels = [1, 1, 1, 1, 0, 0, 0, 0] * 1000  # 1: positive, 0: negative

# 划分数据
from sklearn.model_selection import train_test_split

X_train, X_temp, y_train, y_temp = train_test_split(
    train_texts, train_labels, test_size=0.3, random_state=42, stratify=train_labels
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"训练集: {len(X_train)}")
print(f"验证集: {len(X_val)}")
print(f"测试集: {len(X_test)}")

# 构建词汇表
vocab = build_vocab(X_train)

# 创建数据集
train_dataset = SentimentDataset(X_train, y_train, vocab, max_length=50)
val_dataset = SentimentDataset(X_val, y_val, vocab, max_length=50)
test_dataset = SentimentDataset(X_test, y_test, vocab, max_length=50)

# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

3.2 RNN/LSTM模型

class SentimentLSTM(nn.Module):
    """Bidirectional LSTM sentence classifier.

    Pipeline: embedding -> dropout -> BiLSTM -> concat(final forward/backward
    hidden states) -> dropout -> linear. forward() returns raw logits of
    shape [batch, num_classes].
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 num_layers=2, dropout=0.5, num_classes=2):
        super(SentimentLSTM, self).__init__()

        # Token embeddings; index 0 is the padding token
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Bidirectional LSTM; inter-layer dropout only applies with >1 layer
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # Classifier over the concatenated forward+backward final states
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: LongTensor [batch, seq_len] of token ids -> logits [batch, num_classes]."""
        emb = self.dropout(self.embedding(x))  # [batch, seq_len, embedding_dim]

        # hidden: [num_layers * 2, batch, hidden_dim]; the last two entries
        # are the top layer's forward and backward final states.
        _, (hidden, _) = self.lstm(emb)
        final = torch.cat([hidden[-2], hidden[-1]], dim=1)  # [batch, hidden_dim*2]

        return self.fc(self.dropout(final))

# 创建模型
model = SentimentLSTM(
    vocab_size=len(vocab),
    embedding_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.5,
    num_classes=2
)

print(model)
print(f"总参数量: {sum(p.numel() for p in model.parameters()):,}")

3.3 词嵌入(Word2Vec/GloVe)

如果有预训练的词向量,可以加载使用:

def load_pretrained_embeddings(vocab, embedding_file, embedding_dim=300):
    """Build an embedding matrix initialized from a GloVe-style text file.

    Each line of *embedding_file* is "<word> <v1> <v2> ...". Words present in
    *vocab* receive their pretrained vector; all other rows keep a small
    random init.

    GloVe downloads: https://nlp.stanford.edu/projects/glove/
    (e.g. glove.6B.300d.txt).

    Returns:
        torch.FloatTensor of shape [len(vocab), embedding_dim].
    """
    # Small random init for words that have no pretrained vector.
    matrix = np.random.randn(len(vocab), embedding_dim) * 0.01

    with open(embedding_file, 'r', encoding='utf-8') as fh:
        for raw in fh:
            fields = raw.strip().split()
            token = fields[0]
            if token in vocab:
                matrix[vocab[token]] = np.array([float(v) for v in fields[1:]])

    return torch.FloatTensor(matrix)

# 使用预训练词向量
# pretrained_embeddings = load_pretrained_embeddings(vocab, 'glove.6B.300d.txt', 300)
# model.embedding.weight.data.copy_(pretrained_embeddings)
# model.embedding.weight.requires_grad = True  # 可以选择是否继续训练

3.4 完整训练流程

from collections import defaultdict
import time

def train_sentiment_model():
    """Train the SentimentLSTM classifier and report test accuracy.

    NOTE(review): relies on module-level globals `vocab`, `train_loader`,
    `val_loader`, `test_loader` and the `SentimentLSTM` class defined earlier
    in this file — confirm they exist before reusing elsewhere.

    Returns:
        (history, vocab) where history is a defaultdict(list) of per-epoch
        'train_loss', 'train_acc', 'val_loss', 'val_acc' and 'lr'.
    """
    # Pick GPU when available, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Model
    model = SentimentLSTM(
        vocab_size=len(vocab),
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.5,
        num_classes=2
    ).to(device)

    # Loss and optimizer; ReduceLROnPlateau watches validation accuracy
    # (mode='max') and halves the LR after 2 stagnant epochs
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    # Per-epoch metric history and best-checkpoint tracking
    history = defaultdict(list)
    best_val_acc = 0.0

    # Training loop
    num_epochs = 20
    for epoch in range(num_epochs):
        epoch_start = time.time()

        # ========== Training phase ==========
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Gradient clipping to guard against exploding RNN gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total

        # ========== Validation phase ==========
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)

                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Plateau scheduler needs the monitored metric
        scheduler.step(val_acc)

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        epoch_time = time.time() - epoch_start

        print(f'Epoch [{epoch+1}/{num_epochs}] - {epoch_time:.2f}s')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 60)

        # Checkpoint (including the vocab) whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'vocab': vocab,
                'val_acc': val_acc,
            }, 'best_sentiment_model.pth')
            print(f'保存最佳模型 (Val Acc: {val_acc:.2f}%)')

    # ========== Test phase ==========
    print("\n开始测试...")
    # Reload the best checkpoint before measuring test accuracy
    checkpoint = torch.load('best_sentiment_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            _, predicted = output.max(1)
            test_total += target.size(0)
            test_correct += predicted.eq(target).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f'Test Accuracy: {test_acc:.2f}%')

    return history, vocab

# 运行训练
history, vocab = train_sentiment_model()

3.5 预测新文本

def predict_sentiment(model, text, vocab, device, max_length=50):
    """Classify one raw string with a trained SentimentLSTM.

    Args:
        model: trained classifier returning [1, 2] logits.
        text: raw input string.
        vocab: word -> index mapping containing '<PAD>' and '<UNK>'.
        device: torch.device to run inference on.
        max_length: fixed sequence length (pad short, clip long).

    Returns:
        ("Positive" | "Negative", softmax probability of the chosen class).
    """
    model.eval()

    # Encode exactly like the training data: tokenize, map to ids with
    # <UNK> fallback, then pad/clip to max_length.
    ids = [vocab.get(tok, vocab['<UNK>'])
           for tok in SentimentDataset.tokenize(text)]
    ids = (ids + [vocab['<PAD>']] * max_length)[:max_length]

    # Batch of one sequence
    batch = torch.tensor([ids], dtype=torch.long).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        label_idx = logits.argmax(dim=1).item()
        confidence = probs[0][label_idx].item()

    return ("Positive" if label_idx == 1 else "Negative"), confidence

# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('best_sentiment_model.pth')
vocab = checkpoint['vocab']

model = SentimentLSTM(len(vocab)).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# 测试
test_texts = [
    "this movie is absolutely amazing and wonderful",
    "i hate this terrible film",
    "great acting and fantastic story",
    "boring and disappointing experience"
]

for text in test_texts:
    sentiment, confidence = predict_sentiment(model, text, vocab, device)
    print(f"Text: {text}")
    print(f"Sentiment: {sentiment} (confidence: {confidence:.2%})\n")

4. 模型评估

4.1 评估指标

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

def evaluate_model(model, test_loader, device):
    """Compute accuracy/precision/recall/F1 and the confusion matrix on *test_loader*.

    Args:
        model: classifier whose forward returns [batch, num_classes] logits.
        test_loader: DataLoader yielding (data, target); targets stay on CPU.
        device: torch.device to run inference on.

    Returns:
        (accuracy, macro precision, macro recall, macro F1, confusion matrix).
    """
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            _, predicted = output.max(1)

            # Predictions come back from the device; targets never left the CPU
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(target.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Scalar metrics (macro-averaged over classes)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Per-class report
    print("\n分类报告:")
    print(classification_report(all_labels, all_preds))

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    print("\n混淆矩阵:")
    print(cm)

    return accuracy, precision, recall, f1, cm

# 评估
accuracy, precision, recall, f1, cm = evaluate_model(model, test_loader, device)

4.2 混淆矩阵可视化

import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(cm, class_names):
    """Render a confusion matrix as an annotated heatmap and save it to disk.

    Args:
        cm: square confusion matrix (num_classes x num_classes) of integer
            counts, as returned by sklearn's confusion_matrix.
        class_names: axis tick labels, one per class, in matrix order.
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=150)
    # Fix: close the figure so repeated calls don't accumulate open figures.
    plt.close()
    print("保存混淆矩阵: confusion_matrix.png")

# MNIST example: the ten classes are the digit labels "0"-"9".
class_names = [str(i) for i in range(10)]
plot_confusion_matrix(cm, class_names)

4.3 ROC曲线和AUC

from sklearn.metrics import roc_curve, auc, roc_auc_score
import matplotlib.pyplot as plt

def plot_roc_curve(model, test_loader, device, num_classes=2):
    """Compute class probabilities on the test set and plot ROC curve(s).

    For num_classes == 2 a single curve for the positive class is saved to
    roc_curve.png; otherwise one-vs-rest curves for every class are saved to
    roc_curve_multiclass.png.

    Args:
        model: classifier producing per-class logits.
        test_loader: DataLoader yielding (data, target) batches.
        device: torch device used for inference.
        num_classes: number of classes (2 selects the binary plot).
    """
    model.eval()

    all_probs = []
    all_labels = []

    # Collect softmax probabilities for the whole test set.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            probs = torch.softmax(output, dim=1)

            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(target.numpy())

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)

    # Binary case: one curve from the positive-class probability column.
    if num_classes == 2:
        fpr, tpr, _ = roc_curve(all_labels, all_probs[:, 1])
        roc_auc = auc(fpr, tpr)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('roc_curve.png', dpi=150)
        # Fix: close the figure so repeated calls don't leak open figures.
        plt.close()
        print(f"AUC: {roc_auc:.4f}")
        print("保存ROC曲线: roc_curve.png")

    # Multi-class case: one one-vs-rest curve per class.
    else:
        from sklearn.preprocessing import label_binarize

        # Binarize labels into an (n_samples, num_classes) indicator matrix.
        y_bin = label_binarize(all_labels, classes=range(num_classes))

        plt.figure(figsize=(10, 8))
        for i in range(num_classes):
            fpr, tpr, _ = roc_curve(y_bin[:, i], all_probs[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2, label=f'Class {i} (AUC = {roc_auc:.2f})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Multi-class ROC Curve')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('roc_curve_multiclass.png', dpi=150)
        # Fix: close the figure so repeated calls don't leak open figures.
        plt.close()
        print("保存多分类ROC曲线: roc_curve_multiclass.png")

# Binary-classification example.
plot_roc_curve(model, test_loader, device, num_classes=2)

5. 超参数调优

5.1 学习率调优

def find_learning_rate(model, train_loader, criterion, device,
                       start_lr=1e-7, end_lr=10, num_iterations=100):
    """Exponential learning-rate range test.

    Sweeps the learning rate geometrically from start_lr to end_lr over
    num_iterations mini-batches, recording the loss at each step, then plots
    loss-vs-LR and suggests the LR where loss falls fastest.

    Args:
        model: model to probe (weights are restored after the sweep).
        train_loader: DataLoader supplying (data, target) batches.
        criterion: loss function.
        device: torch device for the sweep.
        start_lr / end_lr: bounds of the geometric LR schedule.
        num_iterations: number of mini-batch steps in the sweep.

    Returns:
        (lrs, losses, best_lr): recorded LRs, recorded losses, suggested LR.
    """
    model.train()

    # Fix: snapshot the weights so the sweep (which trains at extreme
    # learning rates) does not permanently corrupt the model.
    initial_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # A fresh Adam at start_lr; the scheduler multiplies the LR each step so
    # that it reaches end_lr after num_iterations steps.
    optimizer = optim.Adam(model.parameters(), lr=start_lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=(end_lr / start_lr) ** (1 / num_iterations)
    )

    losses = []
    lrs = []

    iterator = iter(train_loader)
    for iteration in range(num_iterations):
        try:
            data, target = next(iterator)
        except StopIteration:
            # Restart the loader if it is exhausted mid-sweep.
            iterator = iter(train_loader)
            data, target = next(iterator)

        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Record the LR used for this step and the resulting loss.
        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(loss.item())

        lr_scheduler.step()

    # Restore the pre-sweep weights.
    model.load_state_dict(initial_state)

    # Plot loss against learning rate on a log-x axis.
    plt.figure(figsize=(10, 6))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('lr_finder.png', dpi=150)
    # Fix: close the figure so repeated calls don't leak open figures.
    plt.close()
    print("保存学习率测试: lr_finder.png")

    # Heuristic: suggest the LR where the loss decreases fastest
    # (most negative finite-difference gradient).
    gradients = np.gradient(losses)
    best_lr_idx = np.argmin(gradients)
    best_lr = lrs[best_lr_idx]
    print(f"建议学习率: {best_lr:.6f}")

    return lrs, losses, best_lr

# 使用
# lrs, losses, best_lr = find_learning_rate(model, train_loader, criterion, device)

5.2 批量大小调优

def test_batch_sizes(model, dataset, batch_sizes=(16, 32, 64, 128, 256)):
    """Measure data-loading throughput for several batch sizes.

    For each candidate batch size, iterates up to 100 batches from a fresh
    DataLoader, times the loop, and reports samples/second; results are also
    saved as a bar chart (batch_size_comparison.png).

    Args:
        model: unused here; kept for interface compatibility.
        dataset: dataset to wrap in a DataLoader.
        batch_sizes: batch sizes to benchmark (immutable default avoids the
            shared-mutable-default pitfall).

    Returns:
        Dict mapping batch size -> {'time': seconds, 'throughput': samples/s}.
    """
    import time

    results = {}

    for batch_size in batch_sizes:
        print(f"\n测试 batch_size = {batch_size}")

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

        # Time the iteration and count the samples actually delivered.
        start_time = time.time()
        num_batches = 100
        samples_seen = 0

        for i, (data, target) in enumerate(loader):
            if i >= num_batches:
                break
            # Bug fix: count real samples — the last batch may be short and
            # the loader may hold fewer than num_batches batches, so assuming
            # num_batches * batch_size overstated throughput.
            samples_seen += len(data)
            # Simulate training.
            _ = data, target

        elapsed = time.time() - start_time
        throughput = samples_seen / elapsed

        results[batch_size] = {
            'time': elapsed,
            'throughput': throughput
        }

        print(f"时间: {elapsed:.2f}s, 吞吐量: {throughput:.2f} samples/s")

    # Visualize throughput per batch size.
    batch_sizes_list = list(results.keys())
    throughputs = [results[bs]['throughput'] for bs in batch_sizes_list]

    plt.figure(figsize=(10, 6))
    plt.bar(range(len(batch_sizes_list)), throughputs)
    plt.xlabel('Batch Size')
    plt.ylabel('Throughput (samples/s)')
    plt.title('Batch Size vs Throughput')
    plt.xticks(range(len(batch_sizes_list)), batch_sizes_list)
    plt.grid(True, axis='y')
    plt.tight_layout()
    plt.savefig('batch_size_comparison.png', dpi=150)
    # Fix: close the figure so repeated calls don't leak open figures.
    plt.close()
    print("\n保存批量大小比较: batch_size_comparison.png")

    return results

5.3 网格搜索

from itertools import product

def grid_search(param_grid, train_loader, val_loader, device):
    """Exhaustively evaluate every hyper-parameter combination in param_grid.

    Each combination trains a fresh MNISTNet for 5 epochs and is scored by
    validation accuracy.

    Args:
        param_grid: dict mapping parameter name -> list of candidate values;
            must contain 'lr' and 'weight_decay'.
        train_loader / val_loader: DataLoaders of (data, target) batches.
        device: torch device for training and evaluation.

    Returns:
        (best_params, results): the best-scoring parameter dict and a list of
        {'params', 'val_acc'} records, one per combination.
    """
    best_params = None
    best_val_acc = 0.0
    results = []

    # Materialize every combination up front so progress can be reported.
    names = list(param_grid.keys())
    combinations = list(product(*param_grid.values()))

    print(f"总共 {len(combinations)} 组参数组合")

    for idx, combo in enumerate(combinations):
        param_dict = dict(zip(names, combo))
        print(f"\n[{idx+1}/{len(combinations)}] 测试参数: {param_dict}")

        # Fresh model per trial so runs do not contaminate each other.
        model = MNISTNet().to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(
            model.parameters(),
            lr=param_dict['lr'],
            weight_decay=param_dict['weight_decay']
        )

        # Short training budget: 5 epochs per combination.
        model.train()
        for _ in range(5):
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()
                optimizer.step()

        # Score on the validation set.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                predictions = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predictions == target).sum().item()

        val_acc = 100. * correct / total
        print(f"验证准确率: {val_acc:.2f}%")

        results.append({
            'params': param_dict,
            'val_acc': val_acc
        })

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_params = param_dict

    print(f"\n最佳参数: {best_params}")
    print(f"最佳验证准确率: {best_val_acc:.2f}%")

    return best_params, results

# Hyper-parameter grid: every combination of these values is evaluated
# (3 learning rates x 3 weight decays = 9 trials).
param_grid = {
    'lr': [0.0001, 0.001, 0.01],
    'weight_decay': [0, 1e-5, 1e-4]
}

# Run the grid search (uncomment to execute; trains one model per combination).
# best_params, results = grid_search(param_grid, train_loader, val_loader, device)

6. 可视化工具

6.1 TensorBoard使用

from torch.utils.tensorboard import SummaryWriter
import torch

def train_with_tensorboard(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    """Train `model` while logging losses, accuracies, parameter/gradient
    histograms, and sample prediction images to TensorBoard.

    Args:
        model: classifier producing per-class logits.
        train_loader: DataLoader of (data, target) training batches.
        val_loader: DataLoader of (data, target) validation batches.
        criterion: loss function (e.g. CrossEntropyLoss).
        optimizer: optimizer already bound to model.parameters().
        device: torch device for training.
        num_epochs: number of full passes over train_loader.
    """
    # Create the SummaryWriter; events are written under runs/experiment_1.
    writer = SummaryWriter('runs/experiment_1')

    # Log the model graph, using one real batch as the example input.
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    writer.add_graph(model, images.to(device))

    global_step = 0

    for epoch in range(num_epochs):
        # ========== Training phase ==========
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            # Log the per-batch loss every 100 batches.
            if batch_idx % 100 == 0:
                writer.add_scalar('Loss/train_batch', loss.item(), global_step)

            global_step += 1

        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total

        # ========== Validation phase ==========
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Log per-epoch scalar metrics.
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/val', val_acc, epoch)
        writer.add_scalar('Learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # Log weight and gradient histograms for every parameter.
        for name, param in model.named_parameters():
            writer.add_histogram(f'Parameters/{name}', param, epoch)
            if param.grad is not None:
                writer.add_histogram(f'Gradients/{name}', param.grad, epoch)

        # Every 5 epochs, log a grid of validation images.
        if epoch % 5 == 0:
            dataiter = iter(val_loader)
            images, labels = next(dataiter)
            images = images.to(device)

            with torch.no_grad():
                outputs = model(images)
                _, predicted = outputs.max(1)

            # NOTE(review): `torchvision` is used here but is not imported in
            # this snippet — confirm it is imported at module level, otherwise
            # this branch raises NameError on epoch 0.
            img_grid = torchvision.utils.make_grid(images[:16].cpu())
            writer.add_image('Predictions', img_grid, epoch)

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    writer.close()
    print("\n训练完成!运行以下命令查看TensorBoard:")
    print("tensorboard --logdir=runs")

# 使用
# train_with_tensorboard(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=15)

6.2 Weights & Biases (wandb)

import wandb

def train_with_wandb(model, train_loader, val_loader, criterion, optimizer, device, config):
    """Train `model` while streaming metrics to Weights & Biases.

    Args:
        model: classifier producing per-class logits.
        train_loader: DataLoader of (data, target) training batches.
        val_loader: DataLoader of (data, target) validation batches.
        criterion: loss function.
        optimizer: optimizer already bound to model.parameters().
        device: torch device for training.
        config: hyper-parameter dict; must contain 'epochs'. The whole dict
            is recorded with the wandb run.
    """
    # Initialize the wandb run; `config` is stored with the run metadata.
    wandb.init(
        project="mnist-classification",
        config=config,
        name="experiment-1"
    )

    # Track parameters and gradients every 100 steps.
    wandb.watch(model, log='all', log_freq=100)

    num_epochs = config['epochs']

    for epoch in range(num_epochs):
        # ========== Training phase ==========
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()

            # Log the per-batch loss every 100 batches; 'batch' is a global
            # step counter across epochs.
            if batch_idx % 100 == 0:
                wandb.log({
                    'batch_loss': loss.item(),
                    'batch': epoch * len(train_loader) + batch_idx
                })

        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total

        # ========== Validation phase ==========
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)

                val_loss += loss.item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Log per-epoch metrics.
        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Save the final weights locally and upload them with the run.
    torch.save(model.state_dict(), 'model_final.pth')
    wandb.save('model_final.pth')

    wandb.finish()

# Experiment configuration recorded by wandb alongside the run;
# 'epochs' is read by train_with_wandb, the rest is run metadata.
config = {
    'epochs': 15,
    'batch_size': 64,
    'learning_rate': 0.001,
    'architecture': 'CNN',
    'optimizer': 'Adam'
}

# Usage (uncomment to run a tracked training session):
# train_with_wandb(model, train_loader, val_loader, criterion, optimizer, device, config)
Prev
03-数据集详解-从获取到预处理
Next
05-模型文件详解