08-Transformer架构深度解析

引言

Transformer是现代深度学习最重要的架构之一,彻底改变了自然语言处理领域。本章将从数学原理到代码实现,深度解析Transformer的每个组件。

1. Attention机制

1.1 Self-Attention数学原理

核心思想

Self-Attention允许序列中的每个位置关注序列中的所有位置,从而捕获长距离依赖关系。

给定输入序列 X = [x₁, x₂, ..., xₙ]

Self-Attention计算每个位置与其他所有位置的相关性,
然后基于这些相关性加权聚合信息。

数学定义

Attention(Q, K, V) = softmax(QK^T / √d_k) V

其中:
- Q (Query): 查询矩阵,形状 [seq_len, d_k]
- K (Key): 键矩阵,形状 [seq_len, d_k]
- V (Value): 值矩阵,形状 [seq_len, d_v]
- d_k: 键的维度(用于缩放)

计算步骤

import numpy as np

def self_attention_numpy(X, W_q, W_k, W_v):
    """
    NumPy实现Self-Attention

    Args:
        X: 输入序列 [seq_len, d_model]
        W_q: Query权重 [d_model, d_k]
        W_k: Key权重 [d_model, d_k]
        W_v: Value权重 [d_model, d_v]

    Returns:
        输出序列 [seq_len, d_v]
    """
    # 步骤1: 计算Q, K, V
    Q = np.dot(X, W_q)  # [seq_len, d_k]
    K = np.dot(X, W_k)  # [seq_len, d_k]
    V = np.dot(X, W_v)  # [seq_len, d_v]

    # 步骤2: 计算注意力分数
    # QK^T: [seq_len, seq_len]
    scores = np.dot(Q, K.T)

    # 步骤3: 缩放(避免softmax梯度消失)
    d_k = Q.shape[-1]
    scaled_scores = scores / np.sqrt(d_k)

    # 步骤4: Softmax归一化
    # 对每一行做softmax
    attention_weights = softmax(scaled_scores, axis=-1)

    # 步骤5: 加权聚合
    output = np.dot(attention_weights, V)  # [seq_len, d_v]

    return output, attention_weights

def softmax(x, axis=-1):
    """数值稳定的Softmax"""
    # 减去最大值提高数值稳定性
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# 示例
seq_len = 4
d_model = 8
d_k = d_v = 4

# 输入序列
X = np.random.randn(seq_len, d_model)

# 权重矩阵
W_q = np.random.randn(d_model, d_k)
W_k = np.random.randn(d_model, d_k)
W_v = np.random.randn(d_model, d_v)

# 计算Self-Attention
output, weights = self_attention_numpy(X, W_q, W_k, W_v)

print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"\n注意力权重(每行求和应为1):\n{weights}")
print(f"每行求和: {weights.sum(axis=-1)}")

为什么要缩放?

# 展示不缩放的问题
d_k = 64
Q = np.random.randn(10, d_k)
K = np.random.randn(10, d_k)

# 不缩放的分数
scores_unscaled = np.dot(Q, K.T)
print(f"不缩放分数的方差: {scores_unscaled.var():.2f}")

# 缩放后的分数
scores_scaled = scores_unscaled / np.sqrt(d_k)
print(f"缩放后分数的方差: {scores_scaled.var():.2f}")

# 分析:
# 当d_k较大时,QK^T的方差约为d_k
# 导致softmax输入值很大,梯度接近0
# 缩放后方差接近1,避免梯度消失
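接着上面的例子再看一步:未缩放的分数经过softmax后几乎变成one-hot分布,反向传播时梯度接近0(沿用前面定义的softmax函数,仅作示意):

probs_unscaled = softmax(scores_unscaled, axis=-1)
probs_scaled = softmax(scores_scaled, axis=-1)

print(f"未缩放: 每行最大概率的均值 = {probs_unscaled.max(axis=-1).mean():.3f}")
print(f"缩放后: 每行最大概率的均值 = {probs_scaled.max(axis=-1).mean():.3f}")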

1.2 Q、K、V矩阵计算

直观理解

"""
Q (Query): "我想要什么信息?"
K (Key):   "我有什么信息?"
V (Value): "我的实际内容是什么?"

类比:图书馆检索
- Query: 你的搜索关键词
- Key: 书籍的索引标签
- Value: 书籍的实际内容

匹配过程:
1. 用Query与所有Key比较(QK^T)
2. 找到最相关的书(softmax)
3. 读取这些书的内容(加权V)
"""

可视化示例

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention():
    """可视化Attention机制"""
    # 简单句子: "The cat sat on the mat"
    words = ["The", "cat", "sat", "on", "the", "mat"]
    seq_len = len(words)

    # 模拟词嵌入
    np.random.seed(42)
    X = np.random.randn(seq_len, 8)

    # 权重
    W_q = np.random.randn(8, 4)
    W_k = np.random.randn(8, 4)
    W_v = np.random.randn(8, 4)

    # 计算Attention
    Q = np.dot(X, W_q)
    K = np.dot(X, W_k)
    V = np.dot(X, W_v)

    scores = np.dot(Q, K.T) / np.sqrt(4)
    attention_weights = softmax(scores, axis=-1)

    # 可视化
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        attention_weights,
        xticklabels=words,
        yticklabels=words,
        annot=True,
        fmt='.2f',
        cmap='YlOrRd'
    )
    plt.title('Self-Attention Weights')
    plt.xlabel('Key (attending to)')
    plt.ylabel('Query (attending from)')
    plt.tight_layout()
    plt.savefig('attention_weights.png')

    # 解读:
    # attention_weights[i, j] 表示位置i对位置j的注意力
    # 例如:attention_weights[2, 1] 表示 "sat" 对 "cat" 的注意力

visualize_attention()

1.3 完整代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """Self-Attention层"""

    def __init__(self, d_model, d_k, d_v):
        """
        Args:
            d_model: 输入维度
            d_k: Key/Query维度
            d_v: Value维度
        """
        super().__init__()
        self.d_k = d_k

        # 线性变换矩阵
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: [batch_size, seq_len, seq_len] (可选)

        Returns:
            output: [batch_size, seq_len, d_v]
            attention_weights: [batch_size, seq_len, seq_len]
        """
        # 1. 计算Q, K, V
        Q = self.W_q(x)  # [batch, seq_len, d_k]
        K = self.W_k(x)  # [batch, seq_len, d_k]
        V = self.W_v(x)  # [batch, seq_len, d_v]

        # 2. 计算注意力分数
        # Q @ K^T: [batch, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # 3. 缩放
        scores = scores / math.sqrt(self.d_k)

        # 4. 应用mask(可选)
        if mask is not None:
            # mask中0的位置会被设为-inf,softmax后为0
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 5. Softmax
        attention_weights = F.softmax(scores, dim=-1)

        # 6. 加权聚合
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

# 使用示例
batch_size = 2
seq_len = 5
d_model = 8

x = torch.randn(batch_size, seq_len, d_model)

attention = SelfAttention(d_model, d_k=4, d_v=4)
output, weights = attention(x)

print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")

# 验证注意力权重归一化
print(f"注意力权重每行求和: {weights.sum(dim=-1)}")

Masked Attention

def create_causal_mask(seq_len):
    """
    创建因果mask(用于解码器)
    上三角为0,防止看到未来信息

    Returns:
        mask: [seq_len, seq_len]

    例如 seq_len=4:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# 使用causal mask
seq_len = 5
mask = create_causal_mask(seq_len)

print("因果Mask:")
print(mask)

# 应用到attention
x = torch.randn(1, seq_len, 8)
attention = SelfAttention(8, 4, 4)

# 扩展mask维度以匹配batch
mask = mask.unsqueeze(0)  # [1, seq_len, seq_len]

output, weights = attention(x, mask=mask)

print(f"\n带Mask的注意力权重:")
print(weights[0])  # 上三角都是0,表示每个位置只能看到自己及之前的位置

Padding Mask

def create_padding_mask(seq_lengths, max_len):
    """
    创建padding mask

    Args:
        seq_lengths: [batch_size] 每个序列的实际长度
        max_len: 最大长度

    Returns:
        mask: [batch_size, max_len, max_len]
    """
    batch_size = len(seq_lengths)
    mask = torch.zeros(batch_size, max_len, max_len)

    for i, length in enumerate(seq_lengths):
        # 实际序列部分设为1
        mask[i, :length, :length] = 1

    return mask

# batch中序列长度不同
seq_lengths = torch.tensor([3, 5, 4])
max_len = 5

mask = create_padding_mask(seq_lengths, max_len)

print("Padding Mask (batch=3, max_len=5):")
for i in range(3):
    print(f"\n序列{i} (长度={seq_lengths[i]}):")
    print(mask[i])

2. Multi-Head Attention

2.1 多头注意力机制

核心思想

"""
Single-head attention只能捕获一种关系模式
Multi-head attention可以并行学习多种关系

类比:
- 单头:从一个角度看问题
- 多头:从多个角度同时看问题

例如,对于句子 "The cat sat on the mat":
- Head 1: 关注语法关系(主谓宾)
- Head 2: 关注语义关系(动作-对象)
- Head 3: 关注位置关系(空间位置)
- ...
"""

数学定义

MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) W^O

其中:
headᵢ = Attention(QW^Q_i, KW^K_i, VW^V_i)

参数:
- h: 头的数量
- W^Q_i, W^K_i, W^V_i: 第i个头的投影矩阵
- W^O: 输出投影矩阵

2.2 为什么需要多头

import torch
import torch.nn as nn

def analyze_multihead_benefits():
    """分析多头注意力的优势"""

    # 句子: "The cat sat on the mat"
    words = ["The", "cat", "sat", "on", "the", "mat"]

    # 模拟3个注意力头学到的不同模式

    # Head 1: 语法关系(动词关注主语和宾语)
    attention_head1 = np.array([
        [0.8, 0.1, 0.0, 0.0, 0.0, 0.1],  # The -> The, cat, mat
        [0.1, 0.8, 0.1, 0.0, 0.0, 0.0],  # cat -> cat, sat
        [0.0, 0.4, 0.2, 0.1, 0.0, 0.3],  # sat -> cat, mat (主谓宾)
        [0.0, 0.0, 0.3, 0.4, 0.1, 0.2],  # on -> sat, on, mat
        [0.0, 0.0, 0.0, 0.2, 0.6, 0.2],  # the -> the, mat
        [0.1, 0.1, 0.2, 0.2, 0.1, 0.3],  # mat -> sat, on
    ])

    # Head 2: 位置关系(相邻词)
    attention_head2 = np.array([
        [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],  # The -> The, cat
        [0.3, 0.4, 0.3, 0.0, 0.0, 0.0],  # cat -> The, cat, sat
        [0.0, 0.3, 0.4, 0.3, 0.0, 0.0],  # sat -> cat, sat, on
        [0.0, 0.0, 0.3, 0.4, 0.3, 0.0],  # on -> sat, on, the
        [0.0, 0.0, 0.0, 0.3, 0.4, 0.3],  # the -> on, the, mat
        [0.0, 0.0, 0.0, 0.0, 0.4, 0.6],  # mat -> the, mat
    ])

    # Head 3: 语义关系(名词和修饰词)
    attention_head3 = np.array([
        [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],  # The -> cat (修饰)
        [0.2, 0.8, 0.0, 0.0, 0.0, 0.0],  # cat -> The, cat
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # sat -> sat
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],  # on -> on
        [0.0, 0.0, 0.0, 0.0, 0.4, 0.6],  # the -> mat (修饰)
        [0.0, 0.0, 0.0, 0.0, 0.3, 0.7],  # mat -> the, mat
    ])

    # 可视化三个头
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for idx, (head, title) in enumerate([
        (attention_head1, 'Head 1: Syntactic'),
        (attention_head2, 'Head 2: Positional'),
        (attention_head3, 'Head 3: Semantic')
    ]):
        sns.heatmap(head, annot=True, fmt='.1f', cmap='YlOrRd',
                   xticklabels=words, yticklabels=words, ax=axes[idx])
        axes[idx].set_title(title)

    plt.tight_layout()
    plt.savefig('multihead_attention.png')

    print("不同的头关注不同的语言学特征:")
    print("Head 1: 捕获语法依赖关系")
    print("Head 2: 捕获局部上下文")
    print("Head 3: 捕获语义修饰关系")

analyze_multihead_benefits()

2.3 PyTorch实现

class MultiHeadAttention(nn.Module):
    """多头注意力"""

    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Args:
            d_model: 模型维度(必须能被num_heads整除)
            num_heads: 头的数量
            dropout: Dropout比例
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # Q, K, V的线性变换(所有头共享参数,然后拆分)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # 输出投影
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """
        将最后一维拆分为(num_heads, d_k)

        Args:
            x: [batch, seq_len, d_model]

        Returns:
            [batch, num_heads, seq_len, d_k]
        """
        batch_size, seq_len, d_model = x.size()

        # 重塑: [batch, seq_len, num_heads, d_k]
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)

        # 转置: [batch, num_heads, seq_len, d_k]
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """
        合并多个头

        Args:
            x: [batch, num_heads, seq_len, d_k]

        Returns:
            [batch, seq_len, d_model]
        """
        batch_size, num_heads, seq_len, d_k = x.size()

        # 转置: [batch, seq_len, num_heads, d_k]
        x = x.transpose(1, 2).contiguous()

        # 合并: [batch, seq_len, d_model]
        return x.view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, seq_len_q, d_model]
            key: [batch, seq_len_k, d_model]
            value: [batch, seq_len_v, d_model]
            mask: [batch, 1, seq_len_q, seq_len_k] 或 [batch, seq_len_q, seq_len_k]

        Returns:
            output: [batch, seq_len_q, d_model]
            attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
        """
        batch_size = query.size(0)

        # 1. 线性变换
        Q = self.W_q(query)  # [batch, seq_len_q, d_model]
        K = self.W_k(key)    # [batch, seq_len_k, d_model]
        V = self.W_v(value)  # [batch, seq_len_v, d_model]

        # 2. 拆分多头
        Q = self.split_heads(Q)  # [batch, num_heads, seq_len_q, d_k]
        K = self.split_heads(K)  # [batch, num_heads, seq_len_k, d_k]
        V = self.split_heads(V)  # [batch, num_heads, seq_len_v, d_k]

        # 3. 计算注意力分数
        # [batch, num_heads, seq_len_q, seq_len_k]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 4. 应用mask
        if mask is not None:
            # 扩展mask到多头: [batch, 1, seq_len_q, seq_len_k]
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 5. Softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 6. 加权聚合
        # [batch, num_heads, seq_len_q, d_k]
        context = torch.matmul(attention_weights, V)

        # 7. 合并多头
        # [batch, seq_len_q, d_model]
        context = self.combine_heads(context)

        # 8. 输出投影
        output = self.W_o(context)

        return output, attention_weights

# 使用示例
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8

x = torch.randn(batch_size, seq_len, d_model)

mha = MultiHeadAttention(d_model, num_heads)

# Self-Attention
output, weights = mha(x, x, x)

print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
print(f"注意力权重: {weights.shape}")

# Cross-Attention(如Encoder-Decoder)
encoder_output = torch.randn(batch_size, 20, d_model)
decoder_input = torch.randn(batch_size, 10, d_model)

output, weights = mha(
    query=decoder_input,
    key=encoder_output,
    value=encoder_output
)

print(f"\nCross-Attention:")
print(f"Decoder输入: {decoder_input.shape}")
print(f"Encoder输出: {encoder_output.shape}")
print(f"输出: {output.shape}")
print(f"注意力权重: {weights.shape}")

3. Transformer完整架构

3.1 Encoder-Decoder结构

"""
Transformer架构:

Encoder:
  输入 -> Embedding -> Position Encoding
    -> [Multi-Head Attention -> Add&Norm
        -> Feed Forward -> Add&Norm] × N
    -> Encoder输出

Decoder:
  输出 -> Embedding -> Position Encoding
    -> [Masked Multi-Head Attention -> Add&Norm
        -> Cross-Attention -> Add&Norm
        -> Feed Forward -> Add&Norm] × N
    -> Linear -> Softmax
    -> 输出概率
"""

3.2 Position Encoding

为什么需要位置编码?

"""
Attention机制是permutation-invariant(排列不变的)
即:Attention([A, B, C]) = Attention([C, A, B])

但语言是有顺序的:
"狗咬了人" ≠ "人咬了狗"

因此需要注入位置信息
"""

正弦位置编码

class PositionalEncoding(nn.Module):
    """正弦位置编码"""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        """
        Args:
            d_model: 模型维度
            max_len: 最大序列长度
            dropout: Dropout比例
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)

        # 位置索引 [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 频率项
        # div_term[i] = 1 / 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # 应用正弦和余弦
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 增加batch维度 [1, max_len, d_model]
        pe = pe.unsqueeze(0)

        # 注册为buffer(不是参数,但会保存到state_dict)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]

        Returns:
            [batch, seq_len, d_model]
        """
        # 添加位置编码
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# 可视化位置编码
def visualize_positional_encoding():
    """可视化位置编码"""
    d_model = 512
    max_len = 100

    pe_layer = PositionalEncoding(d_model, max_len)
    pe = pe_layer.pe.squeeze(0).numpy()  # [max_len, d_model]

    plt.figure(figsize=(12, 6))

    # 绘制前64维
    plt.imshow(pe[:, :64].T, cmap='RdBu', aspect='auto')
    plt.xlabel('Position')
    plt.ylabel('Dimension')
    plt.title('Positional Encoding')
    plt.colorbar()
    plt.tight_layout()
    plt.savefig('positional_encoding.png')

    # 绘制几个位置的编码曲线
    plt.figure(figsize=(12, 6))
    positions = [0, 10, 20, 40, 80]
    for pos in positions:
        plt.plot(pe[pos, :128], label=f'Position {pos}')

    plt.xlabel('Dimension')
    plt.ylabel('Value')
    plt.title('Positional Encoding Values at Different Positions')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('pe_curves.png')

visualize_positional_encoding()

可学习位置编码

class LearnedPositionalEncoding(nn.Module):
    """可学习的位置编码(如BERT)"""

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 位置嵌入(作为参数学习)
        self.position_embeddings = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()

        # 创建位置索引 [seq_len]
        positions = torch.arange(seq_len, device=x.device)

        # 获取位置编码 [seq_len, d_model]
        position_encodings = self.position_embeddings(positions)

        # 广播到batch: [batch, seq_len, d_model]
        x = x + position_encodings.unsqueeze(0)

        return self.dropout(x)

3.3 Feed Forward Network

class FeedForward(nn.Module):
    """Position-wise Feed-Forward Network"""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: 模型维度
            d_ff: 前馈网络隐藏层维度(通常是d_model的4倍)
            dropout: Dropout比例
        """
        super().__init__()

        # FFN(x) = max(0, xW1 + b1)W2 + b2
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]

        Returns:
            [batch, seq_len, d_model]
        """
        # 第一层 + ReLU
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层
        x = self.linear2(x)

        return x

# 使用示例
batch_size = 2
seq_len = 10
d_model = 512
d_ff = 2048

x = torch.randn(batch_size, seq_len, d_model)

ffn = FeedForward(d_model, d_ff)
output = ffn(x)

print(f"输入: {x.shape}")
print(f"输出: {output.shape}")

3.4 Layer Normalization

class LayerNorm(nn.Module):
    """Layer Normalization"""

    def __init__(self, d_model, eps=1e-6):
        """
        Args:
            d_model: 特征维度
            eps: 数值稳定性参数
        """
        super().__init__()

        # 可学习参数
        self.gamma = nn.Parameter(torch.ones(d_model))   # 缩放
        self.beta = nn.Parameter(torch.zeros(d_model))  # 平移

        self.eps = eps

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]

        Returns:
            [batch, seq_len, d_model]
        """
        # 计算均值和方差(在最后一维)
        mean = x.mean(dim=-1, keepdim=True)  # [batch, seq_len, 1]
        std = x.std(dim=-1, unbiased=False, keepdim=True)  # [batch, seq_len, 1],有偏方差,与nn.LayerNorm一致

        # 归一化
        x_norm = (x - mean) / (std + self.eps)

        # 缩放和平移
        return self.gamma * x_norm + self.beta

# 对比LayerNorm vs BatchNorm
def compare_normalization():
    """比较不同归一化方法"""
    batch_size = 2
    seq_len = 3
    d_model = 4

    x = torch.tensor([
        [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [9.0, 10.0, 11.0, 12.0]],

        [[13.0, 14.0, 15.0, 16.0],
         [17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0]]
    ])

    print("原始数据:")
    print(x)

    # LayerNorm: 对每个样本的每个位置归一化
    ln = nn.LayerNorm(d_model)
    x_ln = ln(x)
    print("\nLayerNorm (在特征维度归一化):")
    print(x_ln)
    print("每个位置的均值:", x_ln.mean(dim=-1))
    print("每个位置的方差:", x_ln.var(dim=-1))

    # BatchNorm: 对整个batch的每个特征归一化
    bn = nn.BatchNorm1d(d_model)
    x_bn = bn(x.transpose(1, 2)).transpose(1, 2)
    print("\nBatchNorm (在batch维度归一化):")
    print(x_bn)

compare_normalization()

3.5 残差连接

class SublayerConnection(nn.Module):
    """
    残差连接 + LayerNorm

    实现两种顺序:
    1. Post-LN: Norm(x + Sublayer(x))   [原始Transformer]
    2. Pre-LN:  x + Sublayer(Norm(x))   [更稳定,现代实现]
    """

    def __init__(self, d_model, dropout=0.1, pre_norm=True):
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.pre_norm = pre_norm

    def forward(self, x, sublayer):
        """
        Args:
            x: 输入
            sublayer: 子层函数(如MultiHeadAttention或FeedForward)

        Returns:
            输出
        """
        if self.pre_norm:
            # Pre-LN: Norm -> Sublayer -> Dropout -> Add
            return x + self.dropout(sublayer(self.norm(x)))
        else:
            # Post-LN: Sublayer -> Dropout -> Add -> Norm
            return self.norm(x + self.dropout(sublayer(x)))

# 使用示例
d_model = 512
x = torch.randn(2, 10, d_model)

# 创建子层
sublayer = FeedForward(d_model, 2048)

# Pre-LN残差连接
residual = SublayerConnection(d_model, pre_norm=True)
output = residual(x, sublayer)

print(f"输入: {x.shape}")
print(f"输出: {output.shape}")

4. 从零实现Transformer

4.1 Encoder层

class EncoderLayer(nn.Module):
    """Transformer Encoder层"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: 模型维度
            num_heads: 注意力头数
            d_ff: 前馈网络隐藏层维度
            dropout: Dropout比例
        """
        super().__init__()

        # Multi-Head Self-Attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)

        # Feed-Forward Network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # 两个残差连接
        self.sublayer1 = SublayerConnection(d_model, dropout)
        self.sublayer2 = SublayerConnection(d_model, dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            mask: [batch, seq_len, seq_len]

        Returns:
            [batch, seq_len, d_model]
        """
        # Self-Attention子层
        x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, mask)[0])

        # Feed-Forward子层
        x = self.sublayer2(x, self.feed_forward)

        return x

class TransformerEncoder(nn.Module):
    """Transformer Encoder"""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            num_layers: Encoder层数
            d_model: 模型维度
            num_heads: 注意力头数
            d_ff: 前馈网络维度
            dropout: Dropout比例
        """
        super().__init__()

        # N个Encoder层
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # 最后的LayerNorm
        self.norm = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            mask: [batch, seq_len, seq_len]

        Returns:
            [batch, seq_len, d_model]
        """
        # 通过所有Encoder层
        for layer in self.layers:
            x = layer(x, mask)

        # 最后的LayerNorm
        return self.norm(x)

# 使用示例
encoder = TransformerEncoder(
    num_layers=6,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    dropout=0.1
)

x = torch.randn(2, 10, 512)
output = encoder(x)

print(f"输入: {x.shape}")
print(f"输出: {output.shape}")

4.2 Decoder层

class DecoderLayer(nn.Module):
    """Transformer Decoder层"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Masked Multi-Head Self-Attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)

        # Cross-Attention (Encoder-Decoder Attention)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)

        # Feed-Forward Network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # 三个残差连接
        self.sublayer1 = SublayerConnection(d_model, dropout)
        self.sublayer2 = SublayerConnection(d_model, dropout)
        self.sublayer3 = SublayerConnection(d_model, dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder输入 [batch, tgt_len, d_model]
            encoder_output: Encoder输出 [batch, src_len, d_model]
            src_mask: Encoder mask [batch, src_len, src_len]
            tgt_mask: Decoder mask (causal) [batch, tgt_len, tgt_len]

        Returns:
            [batch, tgt_len, d_model]
        """
        # Masked Self-Attention子层
        x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, tgt_mask)[0])

        # Cross-Attention子层(Query来自Decoder,Key/Value来自Encoder)
        x = self.sublayer2(
            x,
            lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)[0]
        )

        # Feed-Forward子层
        x = self.sublayer3(x, self.feed_forward)

        return x

class TransformerDecoder(nn.Module):
    """Transformer Decoder"""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # N个Decoder层
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # 最后的LayerNorm
        self.norm = LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [batch, tgt_len, d_model]
            encoder_output: [batch, src_len, d_model]
            src_mask: [batch, src_len, src_len]
            tgt_mask: [batch, tgt_len, tgt_len]

        Returns:
            [batch, tgt_len, d_model]
        """
        # 通过所有Decoder层
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        # 最后的LayerNorm
        return self.norm(x)

# 使用示例
decoder = TransformerDecoder(
    num_layers=6,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    dropout=0.1
)

# Decoder输入和Encoder输出
decoder_input = torch.randn(2, 8, 512)
encoder_output = torch.randn(2, 10, 512)

# 创建causal mask
tgt_len = decoder_input.size(1)
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0)

output = decoder(decoder_input, encoder_output, tgt_mask=tgt_mask)

print(f"Decoder输入: {decoder_input.shape}")
print(f"Encoder输出: {encoder_output.shape}")
print(f"Decoder输出: {output.shape}")

4.3 完整Transformer模型

class Transformer(nn.Module):
    """完整的Transformer模型(用于序列到序列任务)"""

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ff=2048,
        max_seq_len=5000,
        dropout=0.1
    ):
        """
        Args:
            src_vocab_size: 源语言词表大小
            tgt_vocab_size: 目标语言词表大小
            d_model: 模型维度
            num_heads: 注意力头数
            num_encoder_layers: Encoder层数
            num_decoder_layers: Decoder层数
            d_ff: 前馈网络维度
            max_seq_len: 最大序列长度
            dropout: Dropout比例
        """
        super().__init__()

        self.d_model = d_model

        # 嵌入层
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        # Encoder和Decoder
        self.encoder = TransformerEncoder(
            num_encoder_layers, d_model, num_heads, d_ff, dropout
        )

        self.decoder = TransformerDecoder(
            num_decoder_layers, d_model, num_heads, d_ff, dropout
        )

        # 输出投影
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # 初始化参数
        self._init_parameters()

    def _init_parameters(self):
        """参数初始化"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask=None):
        """
        Encoder

        Args:
            src: [batch, src_len]
            src_mask: [batch, src_len, src_len]

        Returns:
            [batch, src_len, d_model]
        """
        # 嵌入 + 位置编码
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)

        # Encoder
        return self.encoder(x, src_mask)

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Decoder

        Args:
            tgt: [batch, tgt_len]
            encoder_output: [batch, src_len, d_model]
            src_mask: [batch, src_len, src_len]
            tgt_mask: [batch, tgt_len, tgt_len]

        Returns:
            [batch, tgt_len, d_model]
        """
        # 嵌入 + 位置编码
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)

        # Decoder
        return self.decoder(x, encoder_output, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        前向传播

        Args:
            src: [batch, src_len]
            tgt: [batch, tgt_len]
            src_mask: [batch, src_len, src_len]
            tgt_mask: [batch, tgt_len, tgt_len]

        Returns:
            logits: [batch, tgt_len, tgt_vocab_size]
        """
        # Encode
        encoder_output = self.encode(src, src_mask)

        # Decode
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)

        # 输出投影
        logits = self.output_projection(decoder_output)

        return logits

    def generate(self, src, max_len=50, start_symbol=1, end_symbol=2):
        """
        自回归生成

        Args:
            src: [batch, src_len]
            max_len: 最大生成长度
            start_symbol: 起始符号ID
            end_symbol: 结束符号ID

        Returns:
            generated: [batch, gen_len]
        """
        batch_size = src.size(0)
        device = src.device

        # Encode一次
        encoder_output = self.encode(src)

        # 初始化Decoder输入(起始符号)
        decoder_input = torch.full((batch_size, 1), start_symbol, dtype=torch.long, device=device)

        # 自回归生成
        for _ in range(max_len - 1):
            # 创建causal mask
            tgt_len = decoder_input.size(1)
            tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=device))
            tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)

            # Decode
            decoder_output = self.decode(decoder_input, encoder_output, tgt_mask=tgt_mask)

            # 预测下一个token(只看最后一个位置)
            logits = self.output_projection(decoder_output[:, -1:, :])
            next_token = logits.argmax(dim=-1)

            # 拼接到decoder_input
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            # 检查是否所有序列都生成了结束符号
            if (next_token == end_symbol).all():
                break

        return decoder_input

# 使用示例
model = Transformer(
    src_vocab_size=10000,
    tgt_vocab_size=10000,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=2048,
    max_seq_len=5000,
    dropout=0.1
)

# 训练
src = torch.randint(0, 10000, (2, 20))  # batch_size=2, src_len=20
tgt = torch.randint(0, 10000, (2, 15))  # tgt_len=15

# 创建causal mask
tgt_len = tgt.size(1)
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0)

logits = model(src, tgt, tgt_mask=tgt_mask)

print(f"源序列: {src.shape}")
print(f"目标序列: {tgt.shape}")
print(f"输出logits: {logits.shape}")

# 推理
generated = model.generate(src, max_len=20)
print(f"生成序列: {generated.shape}")

4.4 完整训练代码(机器翻译)

# train_transformer.py - 完整的Transformer训练脚本

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math

class TranslationDataset(Dataset):
    """翻译数据集"""

    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab,
                 src_tokenizer, tgt_tokenizer, max_len=100):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src = self.src_sentences[idx]
        tgt = self.tgt_sentences[idx]

        # Tokenize
        src_tokens = ['<sos>'] + self.src_tokenizer(src)[:self.max_len-2] + ['<eos>']
        tgt_tokens = ['<sos>'] + self.tgt_tokenizer(tgt)[:self.max_len-2] + ['<eos>']

        # 转换为ID
        src_ids = [self.src_vocab[token] for token in src_tokens]
        tgt_ids = [self.tgt_vocab[token] for token in tgt_tokens]

        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def collate_fn(batch):
    """批处理函数(添加padding)"""
    src_batch, tgt_batch = zip(*batch)

    # Padding
    src_batch = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_batch = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=0)

    return src_batch, tgt_batch

def create_masks(src, tgt, pad_idx=0):
    """创建masks"""
    batch_size = src.size(0)
    src_len = src.size(1)
    tgt_len = tgt.size(1)

    # Source mask(padding位置)
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, src_len]

    # Target mask(causal + padding)
    tgt_padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, tgt_len]

    # Causal mask(转成bool,才能与padding mask做按位与)
    tgt_causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device))
    tgt_causal_mask = tgt_causal_mask.unsqueeze(0).unsqueeze(0)

    # 组合:既要看不到未来,也要屏蔽padding
    tgt_mask = tgt_padding_mask & tgt_causal_mask

    return src_mask, tgt_mask

def train_epoch(model, dataloader, optimizer, criterion, device):
    """训练一个epoch"""
    model.train()
    total_loss = 0

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src = src.to(device)
        tgt = tgt.to(device)

        # Decoder输入(去掉最后一个token)
        tgt_input = tgt[:, :-1]

        # 标签(去掉第一个token)
        tgt_output = tgt[:, 1:]

        # 创建masks
        src_mask, tgt_mask = create_masks(src, tgt_input)

        # Forward
        logits = model(src, tgt_input, src_mask, tgt_mask)

        # 计算loss
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_output.reshape(-1)
        )

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')

    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    """评估"""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device)
            tgt = tgt.to(device)

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            src_mask, tgt_mask = create_masks(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask)

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt_output.reshape(-1)
            )

            total_loss += loss.item()

    return total_loss / len(dataloader)

def main():
    # 超参数
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    dropout = 0.1
    batch_size = 32
    num_epochs = 10
    lr = 0.0001

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 准备数据(示例:英语->法语)
    # 实际使用时从文件或数据集加载
    src_sentences = [
        "Hello, how are you?",
        "I am fine, thank you.",
        # ... 更多句子
    ]

    tgt_sentences = [
        "Bonjour, comment allez-vous?",
        "Je vais bien, merci.",
        # ...
    ]

    # Tokenizer
    src_tokenizer = get_tokenizer('basic_english')
    tgt_tokenizer = get_tokenizer('basic_english')  # 实际应使用法语tokenizer

    # 构建词表
    def yield_tokens(sentences, tokenizer):
        for sentence in sentences:
            yield tokenizer(sentence)

    src_vocab = build_vocab_from_iterator(
        yield_tokens(src_sentences, src_tokenizer),
        specials=['<pad>', '<unk>', '<sos>', '<eos>'],  # <pad>放在索引0,与padding_value=0、ignore_index=0保持一致
        special_first=True
    )

    tgt_vocab = build_vocab_from_iterator(
        yield_tokens(tgt_sentences, tgt_tokenizer),
        specials=['<pad>', '<unk>', '<sos>', '<eos>'],
        special_first=True
    )

    src_vocab.set_default_index(src_vocab['<unk>'])
    tgt_vocab.set_default_index(tgt_vocab['<unk>'])

    # 创建数据集
    dataset = TranslationDataset(
        src_sentences, tgt_sentences,
        src_vocab, tgt_vocab,
        src_tokenizer, tgt_tokenizer
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    # 创建模型
    model = Transformer(
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(tgt_vocab),
        d_model=d_model,
        num_heads=num_heads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        d_ff=d_ff,
        dropout=dropout
    ).to(device)

    # 优化器和损失函数
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    # Label Smoothing
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

    # 学习率调度器(Transformer原文的warmup策略;LambdaLR会将该系数乘到上面的基础lr上)
    def lr_lambda(step):
        if step == 0:
            step = 1
        return (d_model ** -0.5) * min(step ** -0.5, step * (4000 ** -1.5))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # 训练循环
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')

        train_loss = train_epoch(model, dataloader, optimizer, criterion, device)
        val_loss = evaluate(model, dataloader, criterion, device)

        scheduler.step()  # 注:严格的warmup调度应在每个训练step后调用,这里简化为每个epoch调用一次

        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')

        # 保存最佳模型
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_transformer.pt')
            print('Model saved!')

if __name__ == '__main__':
    main()

5. Transformer变体

5.1 BERT(Encoder-only)

class BERT(nn.Module):
    """BERT模型(只使用Encoder)"""

    def __init__(
        self,
        vocab_size,
        d_model=768,
        num_heads=12,
        num_layers=12,
        d_ff=3072,
        max_seq_len=512,
        dropout=0.1
    ):
        super().__init__()

        self.d_model = d_model

        # Token嵌入
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # 位置嵌入(可学习)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Segment嵌入(区分句子A和句子B)
        self.segment_embedding = nn.Embedding(2, d_model)

        self.dropout = nn.Dropout(dropout)

        # Transformer Encoder
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)

        # MLM头(Masked Language Modeling)
        self.mlm_head = nn.Linear(d_model, vocab_size)

        # NSP头(Next Sentence Prediction)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """
        Args:
            input_ids: [batch, seq_len]
            segment_ids: [batch, seq_len] (句子A为0,句子B为1)
            attention_mask: [batch, seq_len]

        Returns:
            encoder_output: [batch, seq_len, d_model]
        """
        batch_size, seq_len = input_ids.size()

        # Token嵌入
        token_emb = self.token_embedding(input_ids)

        # 位置嵌入
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_emb = self.position_embedding(positions)

        # Segment嵌入
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        segment_emb = self.segment_embedding(segment_ids)

        # 组合嵌入
        embeddings = token_emb + position_emb + segment_emb
        embeddings = self.dropout(embeddings)

        # Attention mask
        if attention_mask is not None:
            # 扩展维度 [batch, 1, 1, seq_len]
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Encoder
        encoder_output = self.encoder(embeddings, attention_mask)

        return encoder_output

    def get_mlm_logits(self, encoder_output):
        """获取MLM预测logits"""
        return self.mlm_head(encoder_output)

    def get_nsp_logits(self, encoder_output):
        """获取NSP预测logits(使用[CLS] token)"""
        cls_output = encoder_output[:, 0, :]  # [batch, d_model]
        return self.nsp_head(cls_output)

# 使用示例
bert = BERT(vocab_size=30000, d_model=768, num_heads=12, num_layers=12)

# 输入
input_ids = torch.randint(0, 30000, (2, 128))
segment_ids = torch.cat([
    torch.zeros(2, 64, dtype=torch.long),
    torch.ones(2, 64, dtype=torch.long)
], dim=1)

# Forward
encoder_output = bert(input_ids, segment_ids)
mlm_logits = bert.get_mlm_logits(encoder_output)
nsp_logits = bert.get_nsp_logits(encoder_output)

print(f"Encoder输出: {encoder_output.shape}")
print(f"MLM logits: {mlm_logits.shape}")
print(f"NSP logits: {nsp_logits.shape}")

5.2 GPT(Decoder-only)

class GPT(nn.Module):
    """GPT模型(只使用Decoder,无Cross-Attention)"""

    def __init__(
        self,
        vocab_size,
        d_model=768,
        num_heads=12,
        num_layers=12,
        d_ff=3072,
        max_seq_len=1024,
        dropout=0.1
    ):
        super().__init__()

        self.d_model = d_model

        # Token嵌入
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # 位置嵌入
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        self.dropout = nn.Dropout(dropout)

        # Transformer Decoder层(移除Cross-Attention)
        self.layers = nn.ModuleList([
            GPTBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = LayerNorm(d_model)

        # 语言模型头
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # 权重共享(嵌入和输出投影)
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids):
        """
        Args:
            input_ids: [batch, seq_len]

        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        batch_size, seq_len = input_ids.size()

        # Token嵌入
        token_emb = self.token_embedding(input_ids)

        # 位置嵌入
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_emb = self.position_embedding(positions)

        # 组合
        x = self.dropout(token_emb + position_emb)

        # Causal mask
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        # 通过所有层
        for layer in self.layers:
            x = layer(x, causal_mask)

        # 最后的LayerNorm
        x = self.ln_f(x)

        # 语言模型头
        logits = self.lm_head(x)

        return logits

    def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
        """
        自回归生成

        Args:
            input_ids: [batch, seq_len]
            max_new_tokens: 生成的最大token数
            temperature: 温度参数
            top_k: Top-K采样

        Returns:
            generated: [batch, seq_len + max_new_tokens]
        """
        for _ in range(max_new_tokens):
            # 截断到最大长度
            input_ids_cond = input_ids[:, -self.position_embedding.num_embeddings:]

            # Forward
            logits = self(input_ids_cond)

            # 只看最后一个位置
            logits = logits[:, -1, :] / temperature

            # Top-K采样
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Softmax
            probs = F.softmax(logits, dim=-1)

            # 采样
            next_token = torch.multinomial(probs, num_samples=1)

            # 拼接
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

class GPTBlock(nn.Module):
    """GPT块(Masked Self-Attention + FFN)"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.ln1 = LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ln2 = LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask):
        # Pre-LN: Norm -> Attention -> Add(LayerNorm只需计算一次)
        h = self.ln1(x)
        x = x + self.attn(h, h, h, mask)[0]

        # Pre-LN: Norm -> FFN -> Add
        x = x + self.ffn(self.ln2(x))

        return x

# 使用示例
gpt = GPT(vocab_size=50257, d_model=768, num_heads=12, num_layers=12)

# 输入
input_ids = torch.randint(0, 50257, (2, 64))

# Forward
logits = gpt(input_ids)

print(f"输入: {input_ids.shape}")
print(f"输出logits: {logits.shape}")

# 生成
generated = gpt.generate(input_ids, max_new_tokens=50, temperature=0.8, top_k=50)
print(f"生成序列: {generated.shape}")

5.3 T5(Encoder-Decoder)

T5使用完整的Encoder-Decoder架构,与原始Transformer类似,但使用相对位置编码。

class RelativePositionBias(nn.Module):
    """T5相对位置偏置"""

    def __init__(self, num_heads, max_distance=128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance

        # 相对位置嵌入
        self.relative_attention_bias = nn.Embedding(
            2 * max_distance + 1,
            num_heads
        )

    def forward(self, seq_len):
        """
        计算相对位置偏置

        Args:
            seq_len: 序列长度

        Returns:
            bias: [1, num_heads, seq_len, seq_len]
        """
        # 计算相对位置
        positions = torch.arange(seq_len, device=self.relative_attention_bias.weight.device)
        relative_positions = positions[None, :] - positions[:, None]  # [seq_len, seq_len]

        # 裁剪到[-max_distance, max_distance]
        relative_positions = torch.clamp(
            relative_positions,
            -self.max_distance,
            self.max_distance
        )

        # 偏移使索引从0开始
        relative_positions += self.max_distance

        # 获取偏置
        bias = self.relative_attention_bias(relative_positions)  # [seq_len, seq_len, num_heads]

        # 转置为 [1, num_heads, seq_len, seq_len]
        bias = bias.permute(2, 0, 1).unsqueeze(0)

        return bias

# T5使用的是原始Transformer架构 + 相对位置编码
# 这里省略完整实现,关键区别在于:
# 1. 使用相对位置偏置而非绝对位置编码
# 2. LayerNorm放在残差连接之前(Pre-LN)
# 3. 使用RMSNorm而非LayerNorm
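下面用一个最小示意展示相对位置偏置如何加到注意力分数上(沿用本文的缩放写法;T5原文实际上省略了缩放,并对相对距离做了对数分桶,这里只演示"分数 + 偏置"这一步):

import math
import torch
import torch.nn.functional as F

num_heads, seq_len, d_k = 8, 6, 64
bias_layer = RelativePositionBias(num_heads=num_heads, max_distance=128)

Q = torch.randn(1, num_heads, seq_len, d_k)
K = torch.randn(1, num_heads, seq_len, d_k)
V = torch.randn(1, num_heads, seq_len, d_k)

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [1, num_heads, seq_len, seq_len]
scores = scores + bias_layer(seq_len)                # 加上可学习的相对位置偏置
weights = F.softmax(scores, dim=-1)
output = weights @ V

print(output.shape)  # torch.Size([1, 8, 6, 64])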

6. 效率优化

6.1 Flash Attention

"""
Flash Attention原理:

标准Attention:
1. 计算QK^T: O(N^2 × d)
2. Softmax: O(N^2)
3. 乘以V: O(N^2 × d)
显存: O(N^2) (存储attention matrix)

Flash Attention:
1. 分块计算,不存储完整attention matrix
2. 使用kernel fusion减少内存访问
3. 显存: O(N) (只存储输出)
4. 速度: 2-4x加速
"""

# 使用Flash Attention (需要安装flash-attn库)
try:
    from flash_attn import flash_attn_qkvpacked_func

    class FlashMultiHeadAttention(nn.Module):
        """使用Flash Attention的多头注意力"""

        def __init__(self, d_model, num_heads, dropout=0.1):
            super().__init__()
            assert d_model % num_heads == 0

            self.d_model = d_model
            self.num_heads = num_heads
            self.d_k = d_model // num_heads

            self.qkv_proj = nn.Linear(d_model, 3 * d_model)
            self.out_proj = nn.Linear(d_model, d_model)
            self.dropout = dropout

        def forward(self, x, mask=None):
            batch_size, seq_len, _ = x.size()

            # QKV投影
            qkv = self.qkv_proj(x)
            qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)

            # Flash Attention(packed QKV接口,要求CUDA + fp16/bf16输入)
            # 注意:flash-attn不支持任意mask,这里简单地用mask是否存在来决定是否causal
            output = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=self.dropout if self.training else 0.0,
                causal=mask is not None
            )

            # 重塑
            output = output.reshape(batch_size, seq_len, self.d_model)

            # 输出投影
            return self.out_proj(output)

except ImportError:
    print("Flash Attention not installed")

6.2 Multi-Query Attention

class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA)

    与Multi-Head Attention的区别:
    - MHA: 每个头有独立的K, V
    - MQA: 所有头共享K, V,只有Q是多头的

    优势:
    - 减少KV cache大小(重要!)
    - 推理速度更快
    - 参数量更少
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q是多头的
        self.W_q = nn.Linear(d_model, d_model)

        # K, V是单头的
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Q: 多头
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        Q = Q.transpose(1, 2)  # [batch, num_heads, seq_len, d_k]

        # K, V: 单头
        K = self.W_k(x).unsqueeze(1)  # [batch, 1, seq_len, d_k]
        V = self.W_v(x).unsqueeze(1)  # [batch, 1, seq_len, d_k]

        # 广播K, V到所有头
        K = K.expand(batch_size, self.num_heads, seq_len, self.d_k)
        V = V.expand(batch_size, self.num_heads, seq_len, self.d_k)

        # 标准Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context = torch.matmul(attention_weights, V)

        # 合并头
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)

        return self.W_o(context)

# KV cache大小比较
def compare_kv_cache_size():
    """比较MHA和MQA的KV cache大小"""
    batch_size = 1
    seq_len = 2048
    d_model = 4096
    num_heads = 32
    d_k = d_model // num_heads  # 128

    bytes_per_elem = 2  # 按fp16计,每个元素2字节

    # MHA: 每个头都有一份K, V
    mha_kv_size = 2 * batch_size * num_heads * seq_len * d_k * bytes_per_elem
    # MQA: 所有头共享一份K, V
    mqa_kv_size = 2 * batch_size * 1 * seq_len * d_k * bytes_per_elem

    print(f"MHA KV cache: {mha_kv_size / 1024**2:.2f} MB")
    print(f"MQA KV cache: {mqa_kv_size / 1024**2:.2f} MB")
    print(f"节省: {(1 - mqa_kv_size/mha_kv_size)*100:.1f}%")

compare_kv_cache_size()

6.3 Group Query Attention

class GroupQueryAttention(nn.Module):
    """
    Group Query Attention (GQA)

    MHA和MQA的折中方案:
    - 将num_heads个头分成num_groups组
    - 每组共享K, V

    例如:8个头分成2组
    - Head 0-3: 共享KV_0
    - Head 4-7: 共享KV_1
    """

    def __init__(self, d_model, num_heads, num_groups, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_groups == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.heads_per_group = num_heads // num_groups
        self.d_k = d_model // num_heads

        # Q: 所有头
        self.W_q = nn.Linear(d_model, d_model)

        # K, V: 每组一个
        self.W_k = nn.Linear(d_model, num_groups * self.d_k)
        self.W_v = nn.Linear(d_model, num_groups * self.d_k)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Q: [batch, num_heads, seq_len, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        Q = Q.transpose(1, 2)

        # K, V: [batch, num_groups, seq_len, d_k]
        K = self.W_k(x).view(batch_size, seq_len, self.num_groups, self.d_k)
        K = K.transpose(1, 2)

        V = self.W_v(x).view(batch_size, seq_len, self.num_groups, self.d_k)
        V = V.transpose(1, 2)

        # 扩展K, V到所有头
        # [batch, num_groups, seq_len, d_k] -> [batch, num_heads, seq_len, d_k]
        K = K.repeat_interleave(self.heads_per_group, dim=1)
        V = V.repeat_interleave(self.heads_per_group, dim=1)

        # 标准Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context = torch.matmul(attention_weights, V)

        # 合并头
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)

        return self.W_o(context)

# 使用示例
gqa = GroupQueryAttention(d_model=512, num_heads=8, num_groups=2)

x = torch.randn(2, 10, 512)
output = gqa(x)

print(f"输入: {x.shape}")
print(f"输出: {output.shape}")

# KV cache大小:
# MHA: 2 × batch × num_heads × seq_len × d_k
# GQA: 2 × batch × num_groups × seq_len × d_k (num_groups < num_heads)
# MQA: 2 × batch × 1 × seq_len × d_k
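按上面的公式粗略估算三者的KV cache大小(数值仅作示意,按fp16每元素2字节计):

batch, seq_len, d_k = 1, 2048, 128
num_heads, num_groups = 32, 8
bytes_per_elem = 2  # fp16

mha = 2 * batch * num_heads * seq_len * d_k * bytes_per_elem
gqa = 2 * batch * num_groups * seq_len * d_k * bytes_per_elem
mqa = 2 * batch * 1 * seq_len * d_k * bytes_per_elem

print(f"MHA: {mha / 1024**2:.1f} MB")
print(f"GQA: {gqa / 1024**2:.1f} MB(约为MHA的 {num_groups / num_heads:.0%})")
print(f"MQA: {mqa / 1024**2:.1f} MB")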