08-Transformer架构深度解析
引言
Transformer是现代深度学习最重要的架构之一,彻底改变了自然语言处理领域。本章将从数学原理到代码实现,深度解析Transformer的每个组件。
1. Attention机制
1.1 Self-Attention数学原理
核心思想
Self-Attention允许序列中的每个位置关注序列中的所有位置,从而捕获长距离依赖关系。
给定输入序列 X = [x₁, x₂, ..., xₙ]
Self-Attention计算每个位置与其他所有位置的相关性,
然后基于这些相关性加权聚合信息。
数学定义
Attention(Q, K, V) = softmax(QK^T / √d_k) V
其中:
- Q (Query): 查询矩阵,形状 [seq_len, d_k]
- K (Key): 键矩阵,形状 [seq_len, d_k]
- V (Value): 值矩阵,形状 [seq_len, d_v]
- d_k: 键的维度(用于缩放)
计算步骤
import numpy as np
def self_attention_numpy(X, W_q, W_k, W_v):
"""
NumPy实现Self-Attention
Args:
X: 输入序列 [seq_len, d_model]
W_q: Query权重 [d_model, d_k]
W_k: Key权重 [d_model, d_k]
W_v: Value权重 [d_model, d_v]
Returns:
输出序列 [seq_len, d_v]
"""
# 步骤1: 计算Q, K, V
Q = np.dot(X, W_q) # [seq_len, d_k]
K = np.dot(X, W_k) # [seq_len, d_k]
V = np.dot(X, W_v) # [seq_len, d_v]
# 步骤2: 计算注意力分数
# QK^T: [seq_len, seq_len]
scores = np.dot(Q, K.T)
# 步骤3: 缩放(避免softmax梯度消失)
d_k = Q.shape[-1]
scaled_scores = scores / np.sqrt(d_k)
# 步骤4: Softmax归一化
# 对每一行做softmax
attention_weights = softmax(scaled_scores, axis=-1)
# 步骤5: 加权聚合
output = np.dot(attention_weights, V) # [seq_len, d_v]
return output, attention_weights
def softmax(x, axis=-1):
"""数值稳定的Softmax"""
# 减去最大值提高数值稳定性
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# 示例
seq_len = 4
d_model = 8
d_k = d_v = 4
# 输入序列
X = np.random.randn(seq_len, d_model)
# 权重矩阵
W_q = np.random.randn(d_model, d_k)
W_k = np.random.randn(d_model, d_k)
W_v = np.random.randn(d_model, d_v)
# 计算Self-Attention
output, weights = self_attention_numpy(X, W_q, W_k, W_v)
print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"\n注意力权重(每行求和应为1):\n{weights}")
print(f"每行求和: {weights.sum(axis=-1)}")
为什么要缩放?
# 展示不缩放的问题
d_k = 64
Q = np.random.randn(10, d_k)
K = np.random.randn(10, d_k)
# 不缩放的分数
scores_unscaled = np.dot(Q, K.T)
print(f"不缩放分数的方差: {scores_unscaled.var():.2f}")
# 缩放后的分数
scores_scaled = scores_unscaled / np.sqrt(d_k)
print(f"缩放后分数的方差: {scores_scaled.var():.2f}")
# 分析:
# 当d_k较大时,QK^T的方差约为d_k
# 导致softmax输入值很大,梯度接近0
# 缩放后方差接近1,避免梯度消失
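下面用一个小演示直观感受这一点(示意性代码,复用上面定义的softmax函数):分数很大时softmax输出接近one-hot,反向传播时对应的梯度几乎为0;按√d_k缩放后分布更平缓。
# 示意:softmax在大输入下的饱和现象
large_scores = np.array([20.0, 10.0, 5.0, 1.0])  # 类似未缩放的注意力分数
small_scores = large_scores / np.sqrt(64)        # 按√d_k缩放后的分数
print("未缩放:", softmax(large_scores))           # 接近one-hot,对应梯度接近0
print("缩放后:", softmax(small_scores))           # 分布更平缓,梯度更健康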
1.2 Q、K、V矩阵计算
直观理解
"""
Q (Query): "我想要什么信息?"
K (Key): "我有什么信息?"
V (Value): "我的实际内容是什么?"
类比:图书馆检索
- Query: 你的搜索关键词
- Key: 书籍的索引标签
- Value: 书籍的实际内容
匹配过程:
1. 用Query与所有Key比较(QK^T)
2. 找到最相关的书(softmax)
3. 读取这些书的内容(加权V)
"""
可视化示例
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention():
"""可视化Attention机制"""
# 简单句子: "The cat sat on the mat"
words = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(words)
# 模拟词嵌入
np.random.seed(42)
X = np.random.randn(seq_len, 8)
# 权重
W_q = np.random.randn(8, 4)
W_k = np.random.randn(8, 4)
W_v = np.random.randn(8, 4)
# 计算Attention
Q = np.dot(X, W_q)
K = np.dot(X, W_k)
V = np.dot(X, W_v)
scores = np.dot(Q, K.T) / np.sqrt(4)
attention_weights = softmax(scores, axis=-1)
# 可视化
plt.figure(figsize=(8, 6))
sns.heatmap(
attention_weights,
xticklabels=words,
yticklabels=words,
annot=True,
fmt='.2f',
cmap='YlOrRd'
)
plt.title('Self-Attention Weights')
plt.xlabel('Key (attending to)')
plt.ylabel('Query (attending from)')
plt.tight_layout()
plt.savefig('attention_weights.png')
# 解读:
# attention_weights[i, j] 表示位置i对位置j的注意力
# 例如:attention_weights[2, 1] 表示 "sat" 对 "cat" 的注意力
visualize_attention()
1.3 完整代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
"""Self-Attention层"""
def __init__(self, d_model, d_k, d_v):
"""
Args:
d_model: 输入维度
d_k: Key/Query维度
d_v: Value维度
"""
super().__init__()
self.d_k = d_k
# 线性变换矩阵
self.W_q = nn.Linear(d_model, d_k, bias=False)
self.W_k = nn.Linear(d_model, d_k, bias=False)
self.W_v = nn.Linear(d_model, d_v, bias=False)
def forward(self, x, mask=None):
"""
Args:
x: [batch_size, seq_len, d_model]
mask: [batch_size, seq_len, seq_len] (可选)
Returns:
output: [batch_size, seq_len, d_v]
attention_weights: [batch_size, seq_len, seq_len]
"""
# 1. 计算Q, K, V
Q = self.W_q(x) # [batch, seq_len, d_k]
K = self.W_k(x) # [batch, seq_len, d_k]
V = self.W_v(x) # [batch, seq_len, d_v]
# 2. 计算注意力分数
# Q @ K^T: [batch, seq_len, seq_len]
scores = torch.matmul(Q, K.transpose(-2, -1))
# 3. 缩放
scores = scores / math.sqrt(self.d_k)
# 4. 应用mask(可选)
if mask is not None:
# mask中0的位置会被设为-inf,softmax后为0
scores = scores.masked_fill(mask == 0, float('-inf'))
# 5. Softmax
attention_weights = F.softmax(scores, dim=-1)
# 6. 加权聚合
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 使用示例
batch_size = 2
seq_len = 5
d_model = 8
x = torch.randn(batch_size, seq_len, d_model)
attention = SelfAttention(d_model, d_k=4, d_v=4)
output, weights = attention(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
# 验证注意力权重归一化
print(f"注意力权重每行求和: {weights.sum(dim=-1)}")
Masked Attention
def create_causal_mask(seq_len):
"""
创建因果mask(用于解码器)
上三角为0,防止看到未来信息
Returns:
mask: [seq_len, seq_len]
例如 seq_len=4:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# 使用causal mask
seq_len = 5
mask = create_causal_mask(seq_len)
print("因果Mask:")
print(mask)
# 应用到attention
x = torch.randn(1, seq_len, 8)
attention = SelfAttention(8, 4, 4)
# 扩展mask维度以匹配batch
mask = mask.unsqueeze(0) # [1, seq_len, seq_len]
output, weights = attention(x, mask=mask)
print(f"\n带Mask的注意力权重:")
print(weights[0])
# 上三角都是0,表示每个位置只能看到自己及之前的位置
Padding Mask
def create_padding_mask(seq_lengths, max_len):
"""
创建padding mask
Args:
seq_lengths: [batch_size] 每个序列的实际长度
max_len: 最大长度
Returns:
mask: [batch_size, max_len, max_len]
"""
batch_size = len(seq_lengths)
mask = torch.zeros(batch_size, max_len, max_len)
for i, length in enumerate(seq_lengths):
# 实际序列部分设为1
mask[i, :length, :length] = 1
    return mask
# 示例:batch中序列长度不同
seq_lengths = torch.tensor([3, 5, 4])
max_len = 5
mask = create_padding_mask(seq_lengths, max_len)
print("Padding Mask (batch=3, max_len=5):")
for i in range(3):
print(f"\n序列{i} (长度={seq_lengths[i]}):")
print(mask[i])
2. Multi-Head Attention
2.1 多头注意力机制
核心思想
"""
Single-head attention只能捕获一种关系模式
Multi-head attention可以并行学习多种关系
类比:
- 单头:从一个角度看问题
- 多头:从多个角度同时看问题
例如,对于句子 "The cat sat on the mat":
- Head 1: 关注语法关系(主谓宾)
- Head 2: 关注语义关系(动作-对象)
- Head 3: 关注位置关系(空间位置)
- ...
"""
数学定义
MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) W^O
其中:
headᵢ = Attention(QW^Q_i, KW^K_i, VW^V_i)
参数:
- h: 头的数量
- W^Q_i, W^K_i, W^V_i: 第i个头的投影矩阵
- W^O: 输出投影矩阵
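一个简单的维度换算示例(示意性计算,数值与下文2.3节实现中的设置一致):把 d_model 拆成 h 个头后,每个头的维度为 d_model/h,四个投影矩阵的总参数量与单头使用完整 d_model 时相同。
# 示意:多头注意力的维度与参数量(假设 d_model=512, h=8)
d_model, h = 512, 8
d_k = d_model // h                   # 每个头的维度: 64
proj_params = 4 * d_model * d_model  # W_q/W_k/W_v/W_o 的参数量(忽略bias)
print(f"d_k = {d_k}, 投影参数量 = {proj_params:,}")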
2.2 为什么需要多头
import torch
import torch.nn as nn
def analyze_multihead_benefits():
"""分析多头注意力的优势"""
# 句子: "The cat sat on the mat"
words = ["The", "cat", "sat", "on", "the", "mat"]
# 模拟3个注意力头学到的不同模式
# Head 1: 语法关系(动词关注主语和宾语)
attention_head1 = np.array([
[0.8, 0.1, 0.0, 0.0, 0.0, 0.1], # The -> The, cat, mat
[0.1, 0.8, 0.1, 0.0, 0.0, 0.0], # cat -> cat, sat
[0.0, 0.4, 0.2, 0.1, 0.0, 0.3], # sat -> cat, mat (主谓宾)
[0.0, 0.0, 0.3, 0.4, 0.1, 0.2], # on -> sat, on, mat
[0.0, 0.0, 0.0, 0.2, 0.6, 0.2], # the -> the, mat
[0.1, 0.1, 0.2, 0.2, 0.1, 0.3], # mat -> sat, on
])
# Head 2: 位置关系(相邻词)
attention_head2 = np.array([
[0.5, 0.5, 0.0, 0.0, 0.0, 0.0], # The -> The, cat
[0.3, 0.4, 0.3, 0.0, 0.0, 0.0], # cat -> The, cat, sat
[0.0, 0.3, 0.4, 0.3, 0.0, 0.0], # sat -> cat, sat, on
[0.0, 0.0, 0.3, 0.4, 0.3, 0.0], # on -> sat, on, the
[0.0, 0.0, 0.0, 0.3, 0.4, 0.3], # the -> on, the, mat
[0.0, 0.0, 0.0, 0.0, 0.4, 0.6], # mat -> the, mat
])
# Head 3: 语义关系(名词和修饰词)
attention_head3 = np.array([
[0.5, 0.5, 0.0, 0.0, 0.0, 0.0], # The -> cat (修饰)
[0.2, 0.8, 0.0, 0.0, 0.0, 0.0], # cat -> The, cat
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # sat -> sat
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # on -> on
[0.0, 0.0, 0.0, 0.0, 0.4, 0.6], # the -> mat (修饰)
[0.0, 0.0, 0.0, 0.0, 0.3, 0.7], # mat -> the, mat
])
# 可视化三个头
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for idx, (head, title) in enumerate([
(attention_head1, 'Head 1: Syntactic'),
(attention_head2, 'Head 2: Positional'),
(attention_head3, 'Head 3: Semantic')
]):
sns.heatmap(head, annot=True, fmt='.1f', cmap='YlOrRd',
xticklabels=words, yticklabels=words, ax=axes[idx])
axes[idx].set_title(title)
plt.tight_layout()
plt.savefig('multihead_attention.png')
print("不同的头关注不同的语言学特征:")
print("Head 1: 捕获语法依赖关系")
print("Head 2: 捕获局部上下文")
print("Head 3: 捕获语义修饰关系")
analyze_multihead_benefits()
2.3 PyTorch实现
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, d_model, num_heads, dropout=0.1):
"""
Args:
d_model: 模型维度(必须能被num_heads整除)
num_heads: 头的数量
dropout: Dropout比例
"""
super().__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# Q, K, V的线性变换(所有头共享参数,然后拆分)
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出投影
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x):
"""
将最后一维拆分为(num_heads, d_k)
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, num_heads, seq_len, d_k]
"""
batch_size, seq_len, d_model = x.size()
# 重塑: [batch, seq_len, num_heads, d_k]
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
# 转置: [batch, num_heads, seq_len, d_k]
return x.transpose(1, 2)
def combine_heads(self, x):
"""
合并多个头
Args:
x: [batch, num_heads, seq_len, d_k]
Returns:
[batch, seq_len, d_model]
"""
batch_size, num_heads, seq_len, d_k = x.size()
# 转置: [batch, seq_len, num_heads, d_k]
x = x.transpose(1, 2).contiguous()
# 合并: [batch, seq_len, d_model]
return x.view(batch_size, seq_len, self.d_model)
def forward(self, query, key, value, mask=None):
"""
Args:
query: [batch, seq_len_q, d_model]
key: [batch, seq_len_k, d_model]
value: [batch, seq_len_v, d_model]
mask: [batch, 1, seq_len_q, seq_len_k] 或 [batch, seq_len_q, seq_len_k]
Returns:
output: [batch, seq_len_q, d_model]
attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
"""
batch_size = query.size(0)
# 1. 线性变换
Q = self.W_q(query) # [batch, seq_len_q, d_model]
K = self.W_k(key) # [batch, seq_len_k, d_model]
V = self.W_v(value) # [batch, seq_len_v, d_model]
# 2. 拆分多头
Q = self.split_heads(Q) # [batch, num_heads, seq_len_q, d_k]
K = self.split_heads(K) # [batch, num_heads, seq_len_k, d_k]
V = self.split_heads(V) # [batch, num_heads, seq_len_v, d_k]
# 3. 计算注意力分数
# [batch, num_heads, seq_len_q, seq_len_k]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 4. 应用mask
if mask is not None:
# 扩展mask到多头: [batch, 1, seq_len_q, seq_len_k]
if mask.dim() == 3:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf'))
# 5. Softmax
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# 6. 加权聚合
# [batch, num_heads, seq_len_q, d_k]
context = torch.matmul(attention_weights, V)
# 7. 合并多头
# [batch, seq_len_q, d_model]
context = self.combine_heads(context)
# 8. 输出投影
output = self.W_o(context)
return output, attention_weights
# 使用示例
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
# Self-Attention
output, weights = mha(x, x, x)
print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
print(f"注意力权重: {weights.shape}")
# Cross-Attention(如Encoder-Decoder)
encoder_output = torch.randn(batch_size, 20, d_model)
decoder_input = torch.randn(batch_size, 10, d_model)
output, weights = mha(
query=decoder_input,
key=encoder_output,
value=encoder_output
)
print(f"\nCross-Attention:")
print(f"Decoder输入: {decoder_input.shape}")
print(f"Encoder输出: {encoder_output.shape}")
print(f"输出: {output.shape}")
print(f"注意力权重: {weights.shape}")
3. Transformer完整架构
3.1 Encoder-Decoder结构
"""
Transformer架构:
Encoder:
输入 -> Embedding -> Position Encoding
-> [Multi-Head Attention -> Add&Norm
-> Feed Forward -> Add&Norm] × N
-> Encoder输出
Decoder:
输出 -> Embedding -> Position Encoding
-> [Masked Multi-Head Attention -> Add&Norm
-> Cross-Attention -> Add&Norm
-> Feed Forward -> Add&Norm] × N
-> Linear -> Softmax
-> 输出概率
"""
3.2 Position Encoding
为什么需要位置编码?
"""
Attention机制本身不感知输入顺序(permutation-equivariant,排列等变)
即:把输入 [A, B, C] 打乱为 [C, A, B],输出只会被做同样的打乱,每个token得到的表示不变
但语言是有顺序的:
"狗咬了人" ≠ "人咬了狗"
因此需要注入位置信息
"""
正弦位置编码
class PositionalEncoding(nn.Module):
"""正弦位置编码"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
"""
Args:
d_model: 模型维度
max_len: 最大序列长度
dropout: Dropout比例
"""
super().__init__()
self.dropout = nn.Dropout(dropout)
# 创建位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 位置索引 [max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 频率项
        # div_term[i] = 1 / 10000^(2i/d_model) = 10000^(-2i/d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# 应用正弦和余弦
# PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加batch维度 [1, max_len, d_model]
pe = pe.unsqueeze(0)
# 注册为buffer(不是参数,但会保存到state_dict)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model]
"""
# 添加位置编码
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# 可视化位置编码
def visualize_positional_encoding():
"""可视化位置编码"""
d_model = 512
max_len = 100
pe_layer = PositionalEncoding(d_model, max_len)
pe = pe_layer.pe.squeeze(0).numpy() # [max_len, d_model]
plt.figure(figsize=(12, 6))
# 绘制前64维
plt.imshow(pe[:, :64].T, cmap='RdBu', aspect='auto')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Positional Encoding')
plt.colorbar()
plt.tight_layout()
plt.savefig('positional_encoding.png')
# 绘制几个位置的编码曲线
plt.figure(figsize=(12, 6))
positions = [0, 10, 20, 40, 80]
for pos in positions:
plt.plot(pe[pos, :128], label=f'Position {pos}')
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.title('Positional Encoding Values at Different Positions')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('pe_curves.png')
visualize_positional_encoding()
可学习位置编码
class LearnedPositionalEncoding(nn.Module):
"""可学习的位置编码(如BERT)"""
def __init__(self, d_model, max_len=512, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# 位置嵌入(作为参数学习)
self.position_embeddings = nn.Embedding(max_len, d_model)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.size()
# 创建位置索引 [seq_len]
positions = torch.arange(seq_len, device=x.device)
# 获取位置编码 [seq_len, d_model]
position_encodings = self.position_embeddings(positions)
# 广播到batch: [batch, seq_len, d_model]
x = x + position_encodings.unsqueeze(0)
return self.dropout(x)
3.3 Feed Forward Network
class FeedForward(nn.Module):
"""Position-wise Feed-Forward Network"""
def __init__(self, d_model, d_ff, dropout=0.1):
"""
Args:
d_model: 模型维度
d_ff: 前馈网络隐藏层维度(通常是d_model的4倍)
dropout: Dropout比例
"""
super().__init__()
# FFN(x) = max(0, xW1 + b1)W2 + b2
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model]
"""
# 第一层 + ReLU
x = self.linear1(x)
x = F.relu(x)
x = self.dropout(x)
# 第二层
x = self.linear2(x)
return x
# 使用示例
batch_size = 2
seq_len = 10
d_model = 512
d_ff = 2048
x = torch.randn(batch_size, seq_len, d_model)
ffn = FeedForward(d_model, d_ff)
output = ffn(x)
print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
3.4 Layer Normalization
class LayerNorm(nn.Module):
"""Layer Normalization"""
def __init__(self, d_model, eps=1e-6):
"""
Args:
d_model: 特征维度
eps: 数值稳定性参数
"""
super().__init__()
# 可学习参数
self.gamma = nn.Parameter(torch.ones(d_model)) # 缩放
self.beta = nn.Parameter(torch.zeros(d_model)) # 平移
self.eps = eps
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model]
"""
# 计算均值和方差(在最后一维)
mean = x.mean(dim=-1, keepdim=True) # [batch, seq_len, 1]
        var = x.var(dim=-1, keepdim=True, unbiased=False) # [batch, seq_len, 1]
        # 归一化(eps放在开方内,与nn.LayerNorm一致)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
# 缩放和平移
return self.gamma * x_norm + self.beta
# 对比LayerNorm vs BatchNorm
def compare_normalization():
"""比较不同归一化方法"""
batch_size = 2
seq_len = 3
d_model = 4
x = torch.tensor([
[[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]],
[[13.0, 14.0, 15.0, 16.0],
[17.0, 18.0, 19.0, 20.0],
[21.0, 22.0, 23.0, 24.0]]
])
print("原始数据:")
print(x)
# LayerNorm: 对每个样本的每个位置归一化
ln = nn.LayerNorm(d_model)
x_ln = ln(x)
print("\nLayerNorm (在特征维度归一化):")
print(x_ln)
print("每个位置的均值:", x_ln.mean(dim=-1))
print("每个位置的方差:", x_ln.var(dim=-1))
# BatchNorm: 对整个batch的每个特征归一化
bn = nn.BatchNorm1d(d_model)
x_bn = bn(x.transpose(1, 2)).transpose(1, 2)
print("\nBatchNorm (在batch维度归一化):")
print(x_bn)
compare_normalization()
3.5 残差连接
class SublayerConnection(nn.Module):
"""
残差连接 + LayerNorm
实现两种顺序:
    1. Post-LN: Norm(x + Sublayer(x)) [原始Transformer]
    2. Pre-LN: x + Sublayer(Norm(x)) [更稳定,现代实现]
"""
def __init__(self, d_model, dropout=0.1, pre_norm=True):
super().__init__()
self.norm = LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.pre_norm = pre_norm
def forward(self, x, sublayer):
"""
Args:
x: 输入
sublayer: 子层函数(如MultiHeadAttention或FeedForward)
Returns:
输出
"""
if self.pre_norm:
# Pre-LN: Norm -> Sublayer -> Dropout -> Add
return x + self.dropout(sublayer(self.norm(x)))
else:
# Post-LN: Sublayer -> Dropout -> Add -> Norm
return self.norm(x + self.dropout(sublayer(x)))
# 使用示例
d_model = 512
x = torch.randn(2, 10, d_model)
# 创建子层
sublayer = FeedForward(d_model, 2048)
# Pre-LN残差连接
residual = SublayerConnection(d_model, pre_norm=True)
output = residual(x, sublayer)
print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
4. 从零实现Transformer
4.1 Encoder层
class EncoderLayer(nn.Module):
"""Transformer Encoder层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
"""
Args:
d_model: 模型维度
num_heads: 注意力头数
d_ff: 前馈网络隐藏层维度
dropout: Dropout比例
"""
super().__init__()
# Multi-Head Self-Attention
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
# Feed-Forward Network
self.feed_forward = FeedForward(d_model, d_ff, dropout)
# 两个残差连接
self.sublayer1 = SublayerConnection(d_model, dropout)
self.sublayer2 = SublayerConnection(d_model, dropout)
def forward(self, x, mask=None):
"""
Args:
x: [batch, seq_len, d_model]
mask: [batch, seq_len, seq_len]
Returns:
[batch, seq_len, d_model]
"""
# Self-Attention子层
x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, mask)[0])
# Feed-Forward子层
x = self.sublayer2(x, self.feed_forward)
return x
class TransformerEncoder(nn.Module):
"""Transformer Encoder"""
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
"""
Args:
num_layers: Encoder层数
d_model: 模型维度
num_heads: 注意力头数
d_ff: 前馈网络维度
dropout: Dropout比例
"""
super().__init__()
# N个Encoder层
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# 最后的LayerNorm
self.norm = LayerNorm(d_model)
def forward(self, x, mask=None):
"""
Args:
x: [batch, seq_len, d_model]
mask: [batch, seq_len, seq_len]
Returns:
[batch, seq_len, d_model]
"""
# 通过所有Encoder层
for layer in self.layers:
x = layer(x, mask)
# 最后的LayerNorm
return self.norm(x)
# 使用示例
encoder = TransformerEncoder(
num_layers=6,
d_model=512,
num_heads=8,
d_ff=2048,
dropout=0.1
)
x = torch.randn(2, 10, 512)
output = encoder(x)
print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
4.2 Decoder层
class DecoderLayer(nn.Module):
"""Transformer Decoder层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# Masked Multi-Head Self-Attention
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
# Cross-Attention (Encoder-Decoder Attention)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
# Feed-Forward Network
self.feed_forward = FeedForward(d_model, d_ff, dropout)
# 三个残差连接
self.sublayer1 = SublayerConnection(d_model, dropout)
self.sublayer2 = SublayerConnection(d_model, dropout)
self.sublayer3 = SublayerConnection(d_model, dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: Decoder输入 [batch, tgt_len, d_model]
encoder_output: Encoder输出 [batch, src_len, d_model]
src_mask: Encoder mask [batch, src_len, src_len]
tgt_mask: Decoder mask (causal) [batch, tgt_len, tgt_len]
Returns:
[batch, tgt_len, d_model]
"""
# Masked Self-Attention子层
x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, tgt_mask)[0])
# Cross-Attention子层(Query来自Decoder,Key/Value来自Encoder)
x = self.sublayer2(
x,
lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)[0]
)
# Feed-Forward子层
x = self.sublayer3(x, self.feed_forward)
return x
class TransformerDecoder(nn.Module):
"""Transformer Decoder"""
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# N个Decoder层
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# 最后的LayerNorm
self.norm = LayerNorm(d_model)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: [batch, tgt_len, d_model]
encoder_output: [batch, src_len, d_model]
src_mask: [batch, src_len, src_len]
tgt_mask: [batch, tgt_len, tgt_len]
Returns:
[batch, tgt_len, d_model]
"""
# 通过所有Decoder层
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
# 最后的LayerNorm
return self.norm(x)
# 使用示例
decoder = TransformerDecoder(
num_layers=6,
d_model=512,
num_heads=8,
d_ff=2048,
dropout=0.1
)
# Decoder输入和Encoder输出
decoder_input = torch.randn(2, 8, 512)
encoder_output = torch.randn(2, 10, 512)
# 创建causal mask
tgt_len = decoder_input.size(1)
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0)
output = decoder(decoder_input, encoder_output, tgt_mask=tgt_mask)
print(f"Decoder输入: {decoder_input.shape}")
print(f"Encoder输出: {encoder_output.shape}")
print(f"Decoder输出: {output.shape}")
4.3 完整Transformer模型
class Transformer(nn.Module):
"""完整的Transformer模型(用于序列到序列任务)"""
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
max_seq_len=5000,
dropout=0.1
):
"""
Args:
src_vocab_size: 源语言词表大小
tgt_vocab_size: 目标语言词表大小
d_model: 模型维度
num_heads: 注意力头数
num_encoder_layers: Encoder层数
num_decoder_layers: Decoder层数
d_ff: 前馈网络维度
max_seq_len: 最大序列长度
dropout: Dropout比例
"""
super().__init__()
self.d_model = d_model
# 嵌入层
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
# Encoder和Decoder
self.encoder = TransformerEncoder(
num_encoder_layers, d_model, num_heads, d_ff, dropout
)
self.decoder = TransformerDecoder(
num_decoder_layers, d_model, num_heads, d_ff, dropout
)
# 输出投影
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
# 初始化参数
self._init_parameters()
def _init_parameters(self):
"""参数初始化"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(self, src, src_mask=None):
"""
Encoder
Args:
src: [batch, src_len]
src_mask: [batch, src_len, src_len]
Returns:
[batch, src_len, d_model]
"""
# 嵌入 + 位置编码
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
# Encoder
return self.encoder(x, src_mask)
def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
"""
Decoder
Args:
tgt: [batch, tgt_len]
encoder_output: [batch, src_len, d_model]
src_mask: [batch, src_len, src_len]
tgt_mask: [batch, tgt_len, tgt_len]
Returns:
[batch, tgt_len, d_model]
"""
# 嵌入 + 位置编码
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
# Decoder
return self.decoder(x, encoder_output, src_mask, tgt_mask)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
前向传播
Args:
src: [batch, src_len]
tgt: [batch, tgt_len]
src_mask: [batch, src_len, src_len]
tgt_mask: [batch, tgt_len, tgt_len]
Returns:
logits: [batch, tgt_len, tgt_vocab_size]
"""
# Encode
encoder_output = self.encode(src, src_mask)
# Decode
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
# 输出投影
logits = self.output_projection(decoder_output)
return logits
def generate(self, src, max_len=50, start_symbol=1, end_symbol=2):
"""
自回归生成
Args:
src: [batch, src_len]
max_len: 最大生成长度
start_symbol: 起始符号ID
end_symbol: 结束符号ID
Returns:
generated: [batch, gen_len]
"""
batch_size = src.size(0)
device = src.device
# Encode一次
encoder_output = self.encode(src)
# 初始化Decoder输入(起始符号)
decoder_input = torch.full((batch_size, 1), start_symbol, dtype=torch.long, device=device)
# 自回归生成
for _ in range(max_len - 1):
# 创建causal mask
tgt_len = decoder_input.size(1)
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=device))
tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
# Decode
decoder_output = self.decode(decoder_input, encoder_output, tgt_mask=tgt_mask)
# 预测下一个token(只看最后一个位置)
logits = self.output_projection(decoder_output[:, -1:, :])
next_token = logits.argmax(dim=-1)
# 拼接到decoder_input
decoder_input = torch.cat([decoder_input, next_token], dim=1)
# 检查是否所有序列都生成了结束符号
if (next_token == end_symbol).all():
break
return decoder_input
# 使用示例
model = Transformer(
src_vocab_size=10000,
tgt_vocab_size=10000,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
max_seq_len=5000,
dropout=0.1
)
# 训练
src = torch.randint(0, 10000, (2, 20)) # batch_size=2, src_len=20
tgt = torch.randint(0, 10000, (2, 15)) # tgt_len=15
# 创建causal mask
tgt_len = tgt.size(1)
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0)
logits = model(src, tgt, tgt_mask=tgt_mask)
print(f"源序列: {src.shape}")
print(f"目标序列: {tgt.shape}")
print(f"输出logits: {logits.shape}")
# 推理
generated = model.generate(src, max_len=20)
print(f"生成序列: {generated.shape}")
4.4 完整训练代码(机器翻译)
# train_transformer.py - 完整的Transformer训练脚本
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math
class TranslationDataset(Dataset):
"""翻译数据集"""
def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab,
src_tokenizer, tgt_tokenizer, max_len=100):
self.src_sentences = src_sentences
self.tgt_sentences = tgt_sentences
self.src_vocab = src_vocab
self.tgt_vocab = tgt_vocab
self.src_tokenizer = src_tokenizer
self.tgt_tokenizer = tgt_tokenizer
self.max_len = max_len
def __len__(self):
return len(self.src_sentences)
def __getitem__(self, idx):
src = self.src_sentences[idx]
tgt = self.tgt_sentences[idx]
# Tokenize
src_tokens = ['<sos>'] + self.src_tokenizer(src)[:self.max_len-2] + ['<eos>']
tgt_tokens = ['<sos>'] + self.tgt_tokenizer(tgt)[:self.max_len-2] + ['<eos>']
# 转换为ID
src_ids = [self.src_vocab[token] for token in src_tokens]
tgt_ids = [self.tgt_vocab[token] for token in tgt_tokens]
return torch.tensor(src_ids), torch.tensor(tgt_ids)
def collate_fn(batch):
"""批处理函数(添加padding)"""
src_batch, tgt_batch = zip(*batch)
# Padding
src_batch = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
tgt_batch = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=0)
return src_batch, tgt_batch
def create_masks(src, tgt, pad_idx=0):
"""创建masks"""
batch_size = src.size(0)
src_len = src.size(1)
tgt_len = tgt.size(1)
# Source mask(padding位置)
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, src_len]
# Target mask(causal + padding)
tgt_padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, tgt_len]
    # Causal mask(使用bool类型,才能与padding mask做按位与)
    tgt_causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device))
    tgt_causal_mask = tgt_causal_mask.unsqueeze(0).unsqueeze(0)
    # 组合
    tgt_mask = tgt_padding_mask & tgt_causal_mask
return src_mask, tgt_mask
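# 示例形状(仅作说明,假设 batch=2, src_len=7, tgt_len=5, pad_idx=0):
#   src_mask: [2, 1, 1, 7],会广播到所有注意力头和所有query位置
#   tgt_mask: [2, 1, 5, 5],是padding mask与causal mask按位与的结果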
def train_epoch(model, dataloader, optimizer, criterion, device):
"""训练一个epoch"""
model.train()
total_loss = 0
for batch_idx, (src, tgt) in enumerate(dataloader):
src = src.to(device)
tgt = tgt.to(device)
# Decoder输入(去掉最后一个token)
tgt_input = tgt[:, :-1]
# 标签(去掉第一个token)
tgt_output = tgt[:, 1:]
# 创建masks
src_mask, tgt_mask = create_masks(src, tgt_input)
# Forward
logits = model(src, tgt_input, src_mask, tgt_mask)
# 计算loss
loss = criterion(
logits.reshape(-1, logits.size(-1)),
tgt_output.reshape(-1)
)
# Backward
optimizer.zero_grad()
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
return total_loss / len(dataloader)
def evaluate(model, dataloader, criterion, device):
"""评估"""
model.eval()
total_loss = 0
with torch.no_grad():
for src, tgt in dataloader:
src = src.to(device)
tgt = tgt.to(device)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
src_mask, tgt_mask = create_masks(src, tgt_input)
logits = model(src, tgt_input, src_mask, tgt_mask)
loss = criterion(
logits.reshape(-1, logits.size(-1)),
tgt_output.reshape(-1)
)
total_loss += loss.item()
return total_loss / len(dataloader)
def main():
# 超参数
d_model = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6
d_ff = 2048
dropout = 0.1
batch_size = 32
num_epochs = 10
lr = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 准备数据(示例:英语->法语)
# 实际使用时从文件或数据集加载
src_sentences = [
"Hello, how are you?",
"I am fine, thank you.",
# ... 更多句子
]
tgt_sentences = [
"Bonjour, comment allez-vous?",
"Je vais bien, merci.",
# ...
]
# Tokenizer
src_tokenizer = get_tokenizer('basic_english')
tgt_tokenizer = get_tokenizer('basic_english') # 实际应使用法语tokenizer
# 构建词表
def yield_tokens(sentences, tokenizer):
for sentence in sentences:
yield tokenizer(sentence)
    src_vocab = build_vocab_from_iterator(
        yield_tokens(src_sentences, src_tokenizer),
        specials=['<pad>', '<unk>', '<sos>', '<eos>'],  # <pad>=0,与padding_value/ignore_index=0保持一致
        special_first=True
    )
    tgt_vocab = build_vocab_from_iterator(
        yield_tokens(tgt_sentences, tgt_tokenizer),
        specials=['<pad>', '<unk>', '<sos>', '<eos>'],
        special_first=True
    )
src_vocab.set_default_index(src_vocab['<unk>'])
tgt_vocab.set_default_index(tgt_vocab['<unk>'])
# 创建数据集
dataset = TranslationDataset(
src_sentences, tgt_sentences,
src_vocab, tgt_vocab,
src_tokenizer, tgt_tokenizer
)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn
)
# 创建模型
model = Transformer(
src_vocab_size=len(src_vocab),
tgt_vocab_size=len(tgt_vocab),
d_model=d_model,
num_heads=num_heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ff=d_ff,
dropout=dropout
).to(device)
# 优化器和损失函数
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
# Label Smoothing
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
# 学习率调度器(Transformer原文的warmup策略)
def lr_lambda(step):
if step == 0:
step = 1
return (d_model ** -0.5) * min(step ** -0.5, step * (4000 ** -1.5))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
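    # 注:这是原论文的warmup策略,前4000个step学习率线性上升,之后按 step^(-0.5) 衰减;
    # LambdaLR会把该系数乘到上面的基础lr上,且该调度通常按训练step(而非epoch)调用scheduler.step()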
# 训练循环
best_loss = float('inf')
for epoch in range(num_epochs):
print(f'\nEpoch {epoch+1}/{num_epochs}')
train_loss = train_epoch(model, dataloader, optimizer, criterion, device)
val_loss = evaluate(model, dataloader, criterion, device)
scheduler.step()
print(f'Train Loss: {train_loss:.4f}')
print(f'Val Loss: {val_loss:.4f}')
# 保存最佳模型
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), 'best_transformer.pt')
print('Model saved!')
if __name__ == '__main__':
main()
5. Transformer变体
5.1 BERT(Encoder-only)
class BERT(nn.Module):
"""BERT模型(只使用Encoder)"""
def __init__(
self,
vocab_size,
d_model=768,
num_heads=12,
num_layers=12,
d_ff=3072,
max_seq_len=512,
dropout=0.1
):
super().__init__()
self.d_model = d_model
# Token嵌入
self.token_embedding = nn.Embedding(vocab_size, d_model)
# 位置嵌入(可学习)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
# Segment嵌入(区分句子A和句子B)
self.segment_embedding = nn.Embedding(2, d_model)
self.dropout = nn.Dropout(dropout)
# Transformer Encoder
self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
# MLM头(Masked Language Modeling)
self.mlm_head = nn.Linear(d_model, vocab_size)
# NSP头(Next Sentence Prediction)
self.nsp_head = nn.Linear(d_model, 2)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
"""
Args:
input_ids: [batch, seq_len]
segment_ids: [batch, seq_len] (句子A为0,句子B为1)
attention_mask: [batch, seq_len]
Returns:
encoder_output: [batch, seq_len, d_model]
"""
batch_size, seq_len = input_ids.size()
# Token嵌入
token_emb = self.token_embedding(input_ids)
# 位置嵌入
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_emb = self.position_embedding(positions)
# Segment嵌入
if segment_ids is None:
segment_ids = torch.zeros_like(input_ids)
segment_emb = self.segment_embedding(segment_ids)
# 组合嵌入
embeddings = token_emb + position_emb + segment_emb
embeddings = self.dropout(embeddings)
# Attention mask
if attention_mask is not None:
# 扩展维度 [batch, 1, 1, seq_len]
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Encoder
encoder_output = self.encoder(embeddings, attention_mask)
return encoder_output
def get_mlm_logits(self, encoder_output):
"""获取MLM预测logits"""
return self.mlm_head(encoder_output)
def get_nsp_logits(self, encoder_output):
"""获取NSP预测logits(使用[CLS] token)"""
cls_output = encoder_output[:, 0, :] # [batch, d_model]
return self.nsp_head(cls_output)
# 使用示例
bert = BERT(vocab_size=30000, d_model=768, num_heads=12, num_layers=12)
# 输入
input_ids = torch.randint(0, 30000, (2, 128))
segment_ids = torch.cat([
torch.zeros(2, 64, dtype=torch.long),
torch.ones(2, 64, dtype=torch.long)
], dim=1)
# Forward
encoder_output = bert(input_ids, segment_ids)
mlm_logits = bert.get_mlm_logits(encoder_output)
nsp_logits = bert.get_nsp_logits(encoder_output)
print(f"Encoder输出: {encoder_output.shape}")
print(f"MLM logits: {mlm_logits.shape}")
print(f"NSP logits: {nsp_logits.shape}")
5.2 GPT(Decoder-only)
class GPT(nn.Module):
"""GPT模型(只使用Decoder,无Cross-Attention)"""
def __init__(
self,
vocab_size,
d_model=768,
num_heads=12,
num_layers=12,
d_ff=3072,
max_seq_len=1024,
dropout=0.1
):
super().__init__()
self.d_model = d_model
# Token嵌入
self.token_embedding = nn.Embedding(vocab_size, d_model)
# 位置嵌入
self.position_embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
# Transformer Decoder层(移除Cross-Attention)
self.layers = nn.ModuleList([
GPTBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.ln_f = LayerNorm(d_model)
# 语言模型头
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# 权重共享(嵌入和输出投影)
self.lm_head.weight = self.token_embedding.weight
def forward(self, input_ids):
"""
Args:
input_ids: [batch, seq_len]
Returns:
logits: [batch, seq_len, vocab_size]
"""
batch_size, seq_len = input_ids.size()
# Token嵌入
token_emb = self.token_embedding(input_ids)
# 位置嵌入
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_emb = self.position_embedding(positions)
# 组合
x = self.dropout(token_emb + position_emb)
# Causal mask
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# 通过所有层
for layer in self.layers:
x = layer(x, causal_mask)
# 最后的LayerNorm
x = self.ln_f(x)
# 语言模型头
logits = self.lm_head(x)
return logits
def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
"""
自回归生成
Args:
input_ids: [batch, seq_len]
max_new_tokens: 生成的最大token数
temperature: 温度参数
top_k: Top-K采样
Returns:
generated: [batch, seq_len + max_new_tokens]
"""
for _ in range(max_new_tokens):
# 截断到最大长度
input_ids_cond = input_ids[:, -self.position_embedding.num_embeddings:]
# Forward
logits = self(input_ids_cond)
# 只看最后一个位置
logits = logits[:, -1, :] / temperature
# Top-K采样
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# Softmax
probs = F.softmax(logits, dim=-1)
# 采样
next_token = torch.multinomial(probs, num_samples=1)
# 拼接
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids
class GPTBlock(nn.Module):
"""GPT块(Masked Self-Attention + FFN)"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.ln1 = LayerNorm(d_model)
self.attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ln2 = LayerNorm(d_model)
self.ffn = FeedForward(d_model, d_ff, dropout)
def forward(self, x, mask):
        # Pre-LN: Norm -> Attention -> Add(LayerNorm只计算一次,Q/K/V共用)
        h = self.ln1(x)
        x = x + self.attn(h, h, h, mask)[0]
# Pre-LN: Norm -> FFN -> Add
x = x + self.ffn(self.ln2(x))
return x
# 使用示例
gpt = GPT(vocab_size=50257, d_model=768, num_heads=12, num_layers=12)
# 输入
input_ids = torch.randint(0, 50257, (2, 64))
# Forward
logits = gpt(input_ids)
print(f"输入: {input_ids.shape}")
print(f"输出logits: {logits.shape}")
# 生成
generated = gpt.generate(input_ids, max_new_tokens=50, temperature=0.8, top_k=50)
print(f"生成序列: {generated.shape}")
5.3 T5(Encoder-Decoder)
T5使用完整的Encoder-Decoder架构,与原始Transformer类似,但使用相对位置编码。
class RelativePositionBias(nn.Module):
"""T5相对位置偏置"""
def __init__(self, num_heads, max_distance=128):
super().__init__()
self.num_heads = num_heads
self.max_distance = max_distance
# 相对位置嵌入
self.relative_attention_bias = nn.Embedding(
2 * max_distance + 1,
num_heads
)
def forward(self, seq_len):
"""
计算相对位置偏置
Args:
seq_len: 序列长度
Returns:
bias: [1, num_heads, seq_len, seq_len]
"""
# 计算相对位置
positions = torch.arange(seq_len, device=self.relative_attention_bias.weight.device)
relative_positions = positions[None, :] - positions[:, None] # [seq_len, seq_len]
# 裁剪到[-max_distance, max_distance]
relative_positions = torch.clamp(
relative_positions,
-self.max_distance,
self.max_distance
)
# 偏移使索引从0开始
relative_positions += self.max_distance
# 获取偏置
bias = self.relative_attention_bias(relative_positions) # [seq_len, seq_len, num_heads]
# 转置为 [1, num_heads, seq_len, seq_len]
bias = bias.permute(2, 0, 1).unsqueeze(0)
return bias
# T5使用的是原始Transformer架构 + 相对位置编码
# 这里省略完整实现,关键区别在于:
# 1. 使用相对位置偏置而非绝对位置编码
# 2. LayerNorm放在残差连接之前(Pre-LN)
# 3. 使用RMSNorm而非LayerNorm
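作为补充,下面给出RMSNorm的一个最小实现草图(示意性代码,接口仿照前文的LayerNorm;T5的具体实现细节以官方代码为准):
class RMSNorm(nn.Module):
    """RMSNorm:只按均方根缩放,不减均值、无平移项(最小示意实现)"""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))  # 可学习缩放
        self.eps = eps
    def forward(self, x):
        # x: [batch, seq_len, d_model]
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
# 用法与LayerNorm相同
rms_norm = RMSNorm(512)
print(rms_norm(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])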
6. 效率优化
6.1 Flash Attention
"""
Flash Attention原理:
标准Attention:
1. 计算QK^T: O(N^2 × d)
2. Softmax: O(N^2)
3. 乘以V: O(N^2 × d)
显存: O(N^2) (存储attention matrix)
Flash Attention:
1. 分块计算,不存储完整attention matrix
2. 使用kernel fusion减少内存访问
3. 显存: O(N) (只存储输出)
4. 速度: 2-4x加速
"""
# 使用Flash Attention (需要安装flash-attn库)
try:
    from flash_attn import flash_attn_qkvpacked_func
class FlashMultiHeadAttention(nn.Module):
"""使用Flash Attention的多头注意力"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = dropout
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.size()
# QKV投影
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
            # Flash Attention(qkv打包形式: [batch, seq_len, 3, num_heads, d_k],要求CUDA上的fp16/bf16)
            output = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=self.dropout if self.training else 0.0,
                causal=mask is not None
            )
# 重塑
output = output.reshape(batch_size, seq_len, self.d_model)
# 输出投影
return self.out_proj(output)
except ImportError:
print("Flash Attention not installed")
6.2 Multi-Query Attention
class MultiQueryAttention(nn.Module):
"""
Multi-Query Attention (MQA)
与Multi-Head Attention的区别:
- MHA: 每个头有独立的K, V
- MQA: 所有头共享K, V,只有Q是多头的
优势:
- 减少KV cache大小(重要!)
- 推理速度更快
- 参数量更少
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Q是多头的
self.W_q = nn.Linear(d_model, d_model)
# K, V是单头的
self.W_k = nn.Linear(d_model, self.d_k)
self.W_v = nn.Linear(d_model, self.d_k)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.size()
# Q: 多头
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
Q = Q.transpose(1, 2) # [batch, num_heads, seq_len, d_k]
# K, V: 单头
K = self.W_k(x).unsqueeze(1) # [batch, 1, seq_len, d_k]
V = self.W_v(x).unsqueeze(1) # [batch, 1, seq_len, d_k]
# 广播K, V到所有头
K = K.expand(batch_size, self.num_heads, seq_len, self.d_k)
V = V.expand(batch_size, self.num_heads, seq_len, self.d_k)
# 标准Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
context = torch.matmul(attention_weights, V)
# 合并头
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, seq_len, self.d_model)
return self.W_o(context)
# KV cache大小比较
def compare_kv_cache_size():
"""比较MHA和MQA的KV cache大小"""
batch_size = 1
seq_len = 2048
d_model = 4096
num_heads = 32
d_k = d_model // num_heads # 128
    bytes_per_elem = 2  # 假设fp16,每个元素2字节
    # MHA: 每个头都有K, V
    mha_kv_size = 2 * batch_size * num_heads * seq_len * d_k * bytes_per_elem
    # MQA: 只有一份K, V
    mqa_kv_size = 2 * batch_size * 1 * seq_len * d_k * bytes_per_elem
    print(f"MHA KV cache: {mha_kv_size / 1024**2:.2f} MB")
    print(f"MQA KV cache: {mqa_kv_size / 1024**2:.2f} MB")
print(f"节省: {(1 - mqa_kv_size/mha_kv_size)*100:.1f}%")
compare_kv_cache_size()
6.3 Group Query Attention
class GroupQueryAttention(nn.Module):
"""
Group Query Attention (GQA)
MHA和MQA的折中方案:
- 将num_heads个头分成num_groups组
- 每组共享K, V
例如:8个头分成2组
- Head 0-3: 共享KV_0
- Head 4-7: 共享KV_1
"""
def __init__(self, d_model, num_heads, num_groups, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
assert num_heads % num_groups == 0
self.d_model = d_model
self.num_heads = num_heads
self.num_groups = num_groups
self.heads_per_group = num_heads // num_groups
self.d_k = d_model // num_heads
# Q: 所有头
self.W_q = nn.Linear(d_model, d_model)
# K, V: 每组一个
self.W_k = nn.Linear(d_model, num_groups * self.d_k)
self.W_v = nn.Linear(d_model, num_groups * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.size()
# Q: [batch, num_heads, seq_len, d_k]
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
Q = Q.transpose(1, 2)
# K, V: [batch, num_groups, seq_len, d_k]
K = self.W_k(x).view(batch_size, seq_len, self.num_groups, self.d_k)
K = K.transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_groups, self.d_k)
V = V.transpose(1, 2)
# 扩展K, V到所有头
# [batch, num_groups, seq_len, d_k] -> [batch, num_heads, seq_len, d_k]
K = K.repeat_interleave(self.heads_per_group, dim=1)
V = V.repeat_interleave(self.heads_per_group, dim=1)
# 标准Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
context = torch.matmul(attention_weights, V)
# 合并头
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, seq_len, self.d_model)
return self.W_o(context)
# 使用示例
gqa = GroupQueryAttention(d_model=512, num_heads=8, num_groups=2)
x = torch.randn(2, 10, 512)
output = gqa(x)
print(f"输入: {x.shape}")
print(f"输出: {output.shape}")
# KV cache大小:
# MHA: 2 × batch × num_heads × seq_len × d_k
# GQA: 2 × batch × num_groups × seq_len × d_k (num_groups < num_heads)
# MQA: 2 × batch × 1 × seq_len × d_k
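按照上面的公式做一个具体的数值对比(示意性计算,沿用6.2节的设置 batch=1, seq_len=2048, num_heads=32, d_k=128,并假设 num_groups=8、fp16):
# 示意:MHA / GQA / MQA 的KV cache大小对比(fp16,每元素2字节)
batch, seq_len, num_heads, num_groups, d_k = 1, 2048, 32, 8, 128
for name, kv_heads in [("MHA", num_heads), ("GQA", num_groups), ("MQA", 1)]:
    size_mb = 2 * batch * kv_heads * seq_len * d_k * 2 / 1024**2
    print(f"{name}: {size_mb:.1f} MB")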