HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 完整学习路径

    • AI教程 - 从零到一的完整学习路径
    • 第00章:AI基础与发展史
    • 第01章:Python与AI开发环境
    • 第02章:数学基础-线性代数与微积分
    • 03-数据集详解-从获取到预处理
    • 04-从零训练第一个模型
    • 05-模型文件详解
    • 06-分布式训练-多GPU与多机
    • 07-模型调度与资源管理
    • 08-Transformer架构深度解析
    • 09-大语言模型原理与架构
    • 10-Token与Tokenization详解
    • 11-Prompt Engineering完全指南
    • 第12章:模型微调与LoRA技术
    • 第13章:RLHF与对齐技术
    • 第14章 AI编程助手原理与实现
    • 15-RAG系统设计与实现
    • 16-Agent智能体与工具调用
    • 17-多模态大模型
    • 第18章:AI前沿技术趋势
    • 第19章 AI热门话题与应用案例

09-大语言模型原理与架构

引言

大语言模型(Large Language Models, LLMs)是人工智能领域的重大突破,它们能够理解和生成人类语言,完成翻译、问答、代码生成等多种任务。本章将深入探讨LLM的发展历程、核心架构、预训练过程和推理机制。

1. LLM发展历程

1.1 GPT系列的演进

GPT-1 (2018年6月)

GPT-1(Generative Pre-trained Transformer)是OpenAI发布的第一个GPT模型,标志着预训练+微调范式的开端。

核心特点:

  • 参数量:117M(1.17亿)
  • 架构:12层Transformer Decoder
  • 预训练数据:BooksCorpus(约5GB文本)
  • 训练任务:语言模型(下一个词预测)
  • 创新点:通过无监督预训练学习通用语言表示,再在下游任务上微调

架构细节:

输入维度:768
注意力头数:12
前馈网络维度:3072
层数:12
词表大小:40,478
最大序列长度:512

训练过程:

  1. 预训练阶段:在BooksCorpus上进行语言建模
  2. 微调阶段:在具体任务(如文本分类、问答)上微调

GPT-2 (2019年2月)

GPT-2是GPT-1的放大版本,证明了"规模即一切"的假设。

核心特点:

  • 参数量:1.5B(15亿)- 最大版本
  • 架构:48层Transformer Decoder
  • 预训练数据:WebText(40GB文本,800万网页)
  • 词表大小:50,257(使用BPE)
  • 创新点:Zero-shot学习能力,无需微调即可完成多种任务

模型版本:

版本参数量层数隐藏维度注意力头数
Small117M1276812
Medium345M24102416
Large762M36128020
XL1.5B48160025

训练数据WebText:

  • 来源:Reddit上获得至少3个赞的外链
  • 过滤:去重、质量筛选
  • 规模:约40GB,800万文档

GPT-3 (2020年5月)

GPT-3是GPT系列的里程碑,展示了In-context Learning的强大能力。

核心特点:

  • 参数量:175B(1750亿)
  • 架构:96层Transformer Decoder
  • 预训练数据:Common Crawl、WebText2、Books1、Books2、Wikipedia(约570GB)
  • 上下文长度:2048 tokens
  • 创新点:Few-shot和Zero-shot能力显著提升,无需微调

架构配置(175B版本):

层数:96
隐藏维度:12288
注意力头数:96
头维度:128
前馈网络维度:49152
词表大小:50257

训练数据混合:

数据集权重Token数Epochs
Common Crawl (filtered)60%~410B0.44
WebText222%~19B2.9
Books18%~12B1.9
Books28%~55B0.43
Wikipedia3%~3B3.4

In-context Learning:

  • Zero-shot:仅提供任务描述
  • One-shot:提供1个示例
  • Few-shot:提供几个示例(通常5-10个)

示例:

# Few-shot示例
输入:
Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>

输出:
fromage

计算成本:

  • 训练时间:约3640 PetaFLOP/s-days
  • 成本估算:约460万美元(使用V100 GPU)
  • GPU数量:约1万块V100训练数月

GPT-3.5 和 ChatGPT (2022年11月)

ChatGPT基于GPT-3.5模型,引入了RLHF(Reinforcement Learning from Human Feedback)。

训练流程:

  1. 监督微调(SFT)

    • 雇佣标注员编写高质量对话
    • 在GPT-3基础上进行监督学习
    • 约13,000个对话样本
  2. 奖励模型训练(RM)

    • 标注员对多个回复进行排序
    • 训练奖励模型预测人类偏好
    • 约33,000个比较样本
  3. 近端策略优化(PPO)

    • 使用奖励模型优化策略
    • 平衡奖励最大化和KL散度
    • 迭代优化

RLHF公式:

目标函数:

maximize E[r(x, y)] - β * KL(π(y|x) || π_ref(y|x))

其中:
- r(x, y):奖励模型给出的分数
- π(y|x):当前策略(模型)
- π_ref(y|x):参考策略(SFT模型)
- β:KL惩罚系数
- KL:KL散度,防止模型偏离过远

InstructGPT vs ChatGPT:

  • InstructGPT:更适合单轮指令跟随
  • ChatGPT:优化多轮对话能力

GPT-4 (2023年3月)

GPT-4是OpenAI最强大的多模态模型。

核心特点:

  • 参数量:未公开(传闻1.76T,使用MoE)
  • 多模态:支持图像输入
  • 上下文长度:8K(标准版)、32K(扩展版)
  • 性能:在多项测试中接近人类水平

架构推测(未官方确认):

  • 可能使用Mixture of Experts(MoE)
  • 8个专家模型,每个约220B参数
  • 总参数1.76T,激活参数约280B

能力提升:

测试GPT-3.5GPT-4
Bar Exam10%90%
SAT Math70%89%
Leetcode (Easy)31%68%
Leetcode (Medium)21%40%
AP Biology62%85%

多模态能力:

# GPT-4 Vision示例(伪代码)
response = openai.ChatCompletion.create(
    model="gpt-4-vision-preview",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "这张图片里有什么?"},
                {"type": "image_url", "image_url": "https://..."}
            ]
        }
    ]
)

1.2 BERT预训练范式

BERT(Bidirectional Encoder Representations from Transformers)由Google在2018年10月发布,采用Encoder架构。

核心创新

1. 双向编码

  • GPT使用单向(从左到右)Decoder
  • BERT使用双向Encoder,能同时看到上下文

2. 预训练任务

Masked Language Model (MLM):

原句:我 喜欢 吃 苹果
掩码:我 [MASK] 吃 苹果
目标:预测[MASK]位置的词是"喜欢"

实现细节:

  • 随机选择15%的tokens进行掩码
  • 其中80%替换为[MASK]
  • 10%替换为随机词
  • 10%保持不变

Next Sentence Prediction (NSP):

输入A:今天天气很好
输入B:我们去公园玩吧
标签:IsNext (正样本)

输入A:今天天气很好
输入B:量子力学是物理学分支
标签:NotNext (负样本)

模型架构

BERT-Base:

层数:12
隐藏维度:768
注意力头数:12
前馈网络维度:3072
参数量:110M
最大序列长度:512

BERT-Large:

层数:24
隐藏维度:1024
注意力头数:16
前馈网络维度:4096
参数量:340M
最大序列长度:512

输入表示

BERT的输入由三部分嵌入相加组成:

# BERT输入嵌入
def get_bert_input(token_ids, segment_ids, position_ids):
    token_embeddings = TokenEmbedding(token_ids)
    segment_embeddings = SegmentEmbedding(segment_ids)
    position_embeddings = PositionEmbedding(position_ids)

    return token_embeddings + segment_embeddings + position_embeddings

# 示例
输入:[CLS] 我 喜欢 AI [SEP] AI 很 有趣 [SEP]
Token IDs: [101, 2769, 1599, 2335, 102, 2335, 2523, 3300, 102]
Segment IDs: [0, 0, 0, 0, 0, 1, 1, 1, 1]
Position IDs: [0, 1, 2, 3, 4, 5, 6, 7, 8]

训练数据

  • BooksCorpus(800M words)
  • English Wikipedia(2500M words)
  • 总计:约33亿词

BERT的影响

BERT开创了"预训练+微调"的新范式,催生了大量变体:

  • RoBERTa:去除NSP,更大batch size和数据
  • ALBERT:参数共享,减少模型大小
  • ELECTRA:替换检测任务,训练效率更高
  • DeBERTa:解耦注意力,相对位置编码

1.3 开源模型生态

LLaMA系列(Meta)

LLaMA 1 (2023年2月):

Meta发布的开源基础模型,性能优异且高效。

模型规模:

模型参数量层数隐藏维度注意力头数学习率
LLaMA-7B7B324096323.0e-4
LLaMA-13B13B405120403.0e-4
LLaMA-33B33B606656521.5e-4
LLaMA-65B65B808192641.5e-4

训练数据(1.4T tokens):

数据集采样比例Epochs
CommonCrawl67%1.10
C415%1.06
Github4.5%0.64
Wikipedia4.5%2.45
Books4.5%2.23
ArXiv2.5%1.06
StackExchange2%1.03

架构创新:

  • Pre-normalization(GPT-3风格)
  • SwiGLU激活函数(替代ReLU)
  • Rotary Position Embedding(RoPE)
  • 去除绝对位置编码

LLaMA 2 (2023年7月):

增强版本,训练数据增加40%。

  • 训练数据:2T tokens
  • 上下文长度:4096 tokens
  • 模型:7B、13B、70B
  • 开源协议:商用友好

LLaMA 2-Chat: 经过RLHF优化的对话模型。

训练流程:

  1. 监督微调(SFT):27,540条高质量对话
  2. 奖励模型:100万条二元比较数据
  3. RLHF:5轮迭代优化

LLaMA 3 (2024年4月):

  • 训练数据:15T tokens
  • 上下文长度:8192 tokens
  • 词表大小:128,000(扩展)
  • 模型:8B、70B、405B

Mistral系列(Mistral AI)

Mistral 7B (2023年9月):

7B参数的高性能开源模型。

核心特性:

  • Grouped Query Attention(GQA)
  • Sliding Window Attention(SWA)
  • Rolling Buffer Cache

Sliding Window Attention:

传统注意力:每个token关注所有历史token
滑动窗口:每个token仅关注窗口内的token(如4096)

优势:
- 降低计算复杂度
- 支持更长序列
- 保持性能

Grouped Query Attention:

传统MHA:每个头都有独立的K、V
GQA:多个头共享K、V

例子(8头,2组):
Group 1: Q1, Q2, Q3, Q4 共享 K1, V1
Group 2: Q5, Q6, Q7, Q8 共享 K2, V2

优势:
- 减少KV Cache大小
- 加速推理
- 轻微性能损失

性能对比:

模型参数量MMLUHellaSwagArc-C
Mistral 7B7B62.5%83.3%59.4%
LLaMA 2 7B7B45.3%77.2%52.9%
LLaMA 2 13B13B54.8%79.2%58.0%

Mixtral 8x7B (2023年12月):

使用Mixture of Experts架构。

MoE架构:

总参数:46.7B
激活参数:12.9B(每次前向)

结构:
- 8个专家网络(每个7B)
- Top-2路由(每次激活2个专家)
- 路由器网络选择专家

前向过程:
输入 -> 路由器 -> 选择Top-2专家 -> 加权组合 -> 输出

路由器实现:

class MoELayer(nn.Module):
    def __init__(self, num_experts=8, d_model=4096):
        self.experts = nn.ModuleList([
            FeedForward(d_model) for _ in range(num_experts)
        ])
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # 计算路由权重
        router_logits = self.router(x)  # [batch, seq, num_experts]
        router_probs = F.softmax(router_logits, dim=-1)

        # 选择Top-2专家
        top2_probs, top2_indices = torch.topk(router_probs, 2, dim=-1)

        # 归一化权重
        top2_probs = top2_probs / top2_probs.sum(dim=-1, keepdim=True)

        # 专家计算
        output = torch.zeros_like(x)
        for i in range(2):
            expert_idx = top2_indices[:, :, i]
            expert_prob = top2_probs[:, :, i:i+1]
            # 批量调用专家
            expert_output = self.experts[expert_idx](x)
            output += expert_prob * expert_output

        return output

Qwen系列(阿里巴巴)

Qwen 1.5 (2024年2月):

针对中文优化的开源模型。

模型规模:

  • 0.5B、1.8B、4B、7B、14B、72B

训练数据:

  • 总量:3T tokens
  • 中文比例:约40%
  • 代码比例:约10%

词表设计:

  • 词表大小:151,643
  • 中文token效率:比GPT-3.5高约30%

中文优化示例:

# GPT-3 Tokenizer
text = "我喜欢编程"
tokens = ["我", "喜", "欢", "编", "程"]  # 5个tokens

# Qwen Tokenizer
text = "我喜欢编程"
tokens = ["我", "喜欢", "编程"]  # 3个tokens(更高效)

Qwen-VL: 多模态版本,支持图像理解。

Qwen-Audio: 支持音频输入的模型。

其他重要开源模型

Falcon (TII):

  • 训练数据:1T-2T tokens
  • 高质量数据筛选
  • RefinedWeb数据集

MPT (MosaicML):

  • 商用友好许可
  • 高效训练框架
  • 支持长上下文(65K)

Baichuan (百川智能):

  • 中文优化
  • Baichuan-7B、13B
  • 开源商用

ChatGLM (智谱AI):

  • GLM架构(混合Encoder-Decoder)
  • ChatGLM-6B、ChatGLM2-6B、ChatGLM3-6B
  • 量化友好设计

2. LLM架构详解

2.1 Decoder-only架构

现代LLM(如GPT、LLaMA)主要采用Decoder-only架构,这是Transformer的简化版本。

为什么选择Decoder-only?

1. 自回归生成

  • 自然支持文本生成
  • 训练目标清晰(下一个token预测)

2. 统一架构

  • 同一模型支持理解和生成
  • 无需Encoder-Decoder的复杂交互

3. 扩展性好

  • 易于扩展到超大规模
  • 训练效率高

完整架构代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GPTConfig:
    """GPT模型配置"""
    def __init__(
        self,
        vocab_size=50257,      # 词表大小
        n_layer=12,            # 层数
        n_head=12,             # 注意力头数
        n_embd=768,            # 嵌入维度
        block_size=1024,       # 最大序列长度
        dropout=0.1,           # Dropout率
        bias=True,             # 是否使用bias
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.block_size = block_size
        self.dropout = dropout
        self.bias = bias

class CausalSelfAttention(nn.Module):
    """因果自注意力机制"""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Q、K、V的投影(合并为一个矩阵提高效率)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        # 输出投影
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        # 因果掩码(下三角矩阵)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, sequence_length, embedding_dim

        # 计算Q、K、V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # 分离多头
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # 计算注意力分数
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # 应用因果掩码
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))

        # Softmax
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # 加权求和
        y = att @ v  # (B, nh, T, hs)

        # 合并多头
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # 输出投影
        y = self.resid_dropout(self.c_proj(y))

        return y

class MLP(nn.Module):
    """前馈网络"""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    """Transformer块"""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-norm架构
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    """完整的GPT模型"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embedding
            wpe=nn.Embedding(config.block_size, config.n_embd),  # position embedding
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # 权重共享
        self.transformer.wte.weight = self.lm_head.weight

        # 初始化权重
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"序列长度{t}超过最大长度{self.config.block_size}"

        # 位置索引
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)

        # 前向传播
        tok_emb = self.transformer.wte(idx)  # token embeddings (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)

        if targets is not None:
            # 训练模式:计算loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
        else:
            # 推理模式:只计算最后一个token的logits
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        生成文本

        Args:
            idx: (B, T) 输入token序列
            max_new_tokens: 生成的最大token数
            temperature: 温度参数,控制随机性
            top_k: Top-k采样
        """
        for _ in range(max_new_tokens):
            # 截断到最大长度
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # 前向传播
            logits, _ = self(idx_cond)

            # 取最后一个token的logits
            logits = logits[:, -1, :] / temperature

            # Top-k采样
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Softmax转换为概率
            probs = F.softmax(logits, dim=-1)

            # 采样
            idx_next = torch.multinomial(probs, num_samples=1)

            # 拼接
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# 使用示例
if __name__ == "__main__":
    # 创建GPT-2 Small配置
    config = GPTConfig(
        vocab_size=50257,
        n_layer=12,
        n_head=12,
        n_embd=768,
        block_size=1024,
    )

    model = GPT(config)

    # 统计参数量
    n_params = sum(p.numel() for p in model.parameters())
    print(f"参数量: {n_params/1e6:.2f}M")

    # 前向传播测试
    idx = torch.randint(0, config.vocab_size, (2, 128))  # batch=2, seq_len=128
    logits, loss = model(idx, targets=idx)
    print(f"Logits shape: {logits.shape}")

    # 生成测试
    generated = model.generate(idx[:, :10], max_new_tokens=50, temperature=0.8, top_k=50)
    print(f"Generated shape: {generated.shape}")

2.2 Causal Attention Mask

因果掩码是Decoder-only架构的核心,确保每个位置只能关注之前的位置。

掩码矩阵

# 创建因果掩码
def create_causal_mask(seq_len):
    """
    创建因果掩码

    对于序列长度为4的例子:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]

    1表示可以关注,0表示不能关注
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# 可视化
seq_len = 4
mask = create_causal_mask(seq_len)
print("Causal Mask:")
print(mask)

# 在注意力计算中应用
def apply_causal_mask(attention_scores, mask):
    """
    将掩码应用到注意力分数上

    Args:
        attention_scores: (B, H, T, T) 注意力分数
        mask: (T, T) 因果掩码
    """
    # 将0的位置设为负无穷,softmax后会变成0
    attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
    return attention_scores

注意力可视化

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens):
    """
    可视化注意力权重

    Args:
        attention_weights: (seq_len, seq_len) 注意力权重矩阵
        tokens: 序列的token列表
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        annot=True,
        fmt='.2f'
    )
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title('Causal Attention Weights')
    plt.tight_layout()
    plt.show()

# 示例
tokens = ["The", "cat", "sat", "on"]
seq_len = len(tokens)

# 模拟注意力权重(经过softmax后,被掩码的位置为0)
attention = torch.tril(torch.rand(seq_len, seq_len))
# 按行归一化
attention = attention / attention.sum(dim=-1, keepdim=True)

visualize_attention(attention.numpy(), tokens)

输出示例:

       The   cat   sat    on
The   1.00  0.00  0.00  0.00
cat   0.40  0.60  0.00  0.00
sat   0.20  0.35  0.45  0.00
on    0.15  0.25  0.30  0.30

2.3 词表和Tokenizer

BPE (Byte Pair Encoding)

GPT系列使用的分词算法。

算法原理:

  1. 初始化:将文本分割为字符
  2. 统计:统计相邻字符对的频率
  3. 合并:合并频率最高的字符对
  4. 迭代:重复步骤2-3,直到达到词表大小

完整实现:

import re
from collections import defaultdict, Counter

class BPETokenizer:
    """Byte Pair Encoding分词器"""

    def __init__(self, vocab_size=1000):
        self.vocab_size = vocab_size
        self.merges = {}  # 合并规则
        self.vocab = {}   # 词表

    def get_stats(self, words):
        """统计相邻字符对的频率"""
        pairs = defaultdict(int)
        for word, freq in words.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i + 1]] += freq
        return pairs

    def merge_vocab(self, pair, words):
        """合并词表中的字符对"""
        new_words = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)

        for word in words:
            new_word = word.replace(bigram, replacement)
            new_words[new_word] = words[word]

        return new_words

    def train(self, texts):
        """训练BPE分词器"""
        # 1. 初始化:统计词频
        word_freqs = Counter()
        for text in texts:
            words = text.split()
            word_freqs.update(words)

        # 2. 将每个词拆分为字符,添加结束符
        vocab = {}
        for word, freq in word_freqs.items():
            vocab[' '.join(list(word)) + ' </w>'] = freq

        # 3. 初始词表:所有字符
        self.vocab = set()
        for word in vocab.keys():
            self.vocab.update(word.split())

        # 4. 迭代合并
        for i in range(self.vocab_size - len(self.vocab)):
            pairs = self.get_stats(vocab)
            if not pairs:
                break

            # 找频率最高的pair
            best_pair = max(pairs, key=pairs.get)

            # 合并
            vocab = self.merge_vocab(best_pair, vocab)

            # 记录合并规则
            self.merges[best_pair] = i
            self.vocab.add(''.join(best_pair))

            if (i + 1) % 100 == 0:
                print(f"合并 {i + 1}: {best_pair} -> {''.join(best_pair)}")

        # 5. 构建最终词表
        self.vocab = {token: i for i, token in enumerate(sorted(self.vocab))}

        print(f"训练完成!词表大小: {len(self.vocab)}")

    def tokenize(self, text):
        """对文本进行分词"""
        tokens = []
        words = text.split()

        for word in words:
            # 将词拆分为字符
            word_tokens = list(word) + ['</w>']

            # 应用合并规则
            while len(word_tokens) > 1:
                pairs = [(word_tokens[i], word_tokens[i + 1])
                        for i in range(len(word_tokens) - 1)]

                # 找出可以合并的pair(在merges中且优先级最高)
                mergeable_pairs = [p for p in pairs if p in self.merges]
                if not mergeable_pairs:
                    break

                # 选择优先级最高的(最早学到的)
                best_pair = min(mergeable_pairs, key=lambda x: self.merges[x])

                # 执行合并
                new_tokens = []
                i = 0
                while i < len(word_tokens):
                    if i < len(word_tokens) - 1 and \
                       (word_tokens[i], word_tokens[i + 1]) == best_pair:
                        new_tokens.append(word_tokens[i] + word_tokens[i + 1])
                        i += 2
                    else:
                        new_tokens.append(word_tokens[i])
                        i += 1

                word_tokens = new_tokens

            tokens.extend(word_tokens)

        return tokens

    def encode(self, text):
        """将文本编码为token IDs"""
        tokens = self.tokenize(text)
        return [self.vocab.get(token, self.vocab.get('<unk>', 0)) for token in tokens]

    def decode(self, token_ids):
        """将token IDs解码为文本"""
        id_to_token = {i: token for token, i in self.vocab.items()}
        tokens = [id_to_token.get(i, '<unk>') for i in token_ids]
        text = ''.join(tokens).replace('</w>', ' ').strip()
        return text

# 使用示例
if __name__ == "__main__":
    # 训练数据
    texts = [
        "low lower lowest",
        "new newer newest",
        "wide wider widest",
    ] * 100  # 重复以增加频率

    # 训练分词器
    tokenizer = BPETokenizer(vocab_size=100)
    tokenizer.train(texts)

    # 测试分词
    test_text = "lower newest"
    tokens = tokenizer.tokenize(test_text)
    print(f"\n文本: {test_text}")
    print(f"分词: {tokens}")

    # 测试编码解码
    encoded = tokenizer.encode(test_text)
    print(f"编码: {encoded}")
    decoded = tokenizer.decode(encoded)
    print(f"解码: {decoded}")

WordPiece

BERT使用的分词算法,与BPE类似但合并策略不同。

核心区别:

  • BPE:选择频率最高的pair合并
  • WordPiece:选择使语言模型似然最大的pair合并

公式:

score(pair) = freq(pair) / (freq(first) * freq(second))

实现:

class WordPieceTokenizer:
    """WordPiece分词器"""

    def __init__(self, vocab_size=1000):
        self.vocab_size = vocab_size
        self.vocab = {}

    def get_pair_scores(self, words):
        """计算字符对的分数(似然增益)"""
        pair_freqs = defaultdict(int)
        token_freqs = defaultdict(int)

        for word, freq in words.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pair_freqs[(symbols[i], symbols[i + 1])] += freq
            for symbol in symbols:
                token_freqs[symbol] += freq

        # 计算分数
        scores = {}
        for pair, freq in pair_freqs.items():
            first, second = pair
            scores[pair] = freq / (token_freqs[first] * token_freqs[second])

        return scores

    def train(self, texts):
        """训练WordPiece分词器"""
        # 初始化词表
        word_freqs = Counter()
        for text in texts:
            words = text.split()
            word_freqs.update(words)

        vocab = {}
        for word, freq in word_freqs.items():
            # 除了第一个字符,其他都加##前缀
            tokens = [word[0]] + [f'##{c}' for c in word[1:]]
            vocab[' '.join(tokens)] = freq

        # 初始词表
        self.vocab = set()
        for word in vocab.keys():
            self.vocab.update(word.split())

        # 迭代合并
        for i in range(self.vocab_size - len(self.vocab)):
            scores = self.get_pair_scores(vocab)
            if not scores:
                break

            # 选择分数最高的pair
            best_pair = max(scores, key=scores.get)

            # 合并
            vocab = self.merge_vocab(best_pair, vocab)
            self.vocab.add(''.join(best_pair))

            if (i + 1) % 100 == 0:
                print(f"合并 {i + 1}: {best_pair}, score={scores[best_pair]:.6f}")

        self.vocab = {token: i for i, token in enumerate(sorted(self.vocab))}
        print(f"训练完成!词表大小: {len(self.vocab)}")

    def merge_vocab(self, pair, words):
        """合并词表中的字符对"""
        new_words = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)

        for word in words:
            new_word = word.replace(bigram, replacement)
            new_words[new_word] = words[word]

        return new_words

    def tokenize(self, text):
        """对文本进行分词"""
        tokens = []
        for word in text.split():
            # 贪心最长匹配
            word_tokens = []
            start = 0

            while start < len(word):
                end = len(word)
                found = False

                while start < end:
                    substr = word[start:end]
                    if start > 0:
                        substr = f'##{substr}'

                    if substr in self.vocab:
                        word_tokens.append(substr)
                        start = end
                        found = True
                        break

                    end -= 1

                if not found:
                    # 未知字符
                    word_tokens.append('[UNK]')
                    start += 1

            tokens.extend(word_tokens)

        return tokens

SentencePiece

LLaMA等模型使用的分词器,直接在原始文本上训练,无需预分词。

特点:

  • 语言无关(不依赖空格分词)
  • 支持多种算法(BPE、Unigram)
  • 可逆(能完美还原原文,包括空格)

使用示例:

import sentencepiece as spm

# 训练SentencePiece模型
spm.SentencePieceTrainer.train(
    input='corpus.txt',
    model_prefix='tokenizer',
    vocab_size=32000,
    character_coverage=0.9995,
    model_type='bpe',  # 或'unigram'
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
)

# 加载模型
sp = spm.SentencePieceProcessor()
sp.load('tokenizer.model')

# 编码
text = "Hello, world! 你好世界!"
tokens = sp.encode_as_pieces(text)
ids = sp.encode_as_ids(text)

print(f"Text: {text}")
print(f"Tokens: {tokens}")
print(f"IDs: {ids}")

# 解码
decoded = sp.decode_ids(ids)
print(f"Decoded: {decoded}")
assert text == decoded  # 完美还原

2.4 Position Embedding

位置编码让模型能够理解token的位置信息。

绝对位置编码(Sinusoidal)

原始Transformer使用的方法。

公式:

PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

其中:
- pos: 位置
- i: 维度索引
- d: 嵌入维度

实现:

def get_sinusoidal_encoding(seq_len, d_model):
    """
    生成正弦位置编码

    Args:
        seq_len: 序列长度
        d_model: 嵌入维度

    Returns:
        (seq_len, d_model) 的位置编码矩阵
    """
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )

    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe

# 可视化
import matplotlib.pyplot as plt

pe = get_sinusoidal_encoding(100, 512)
plt.figure(figsize=(15, 5))
plt.imshow(pe, cmap='RdBu', aspect='auto')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.title('Sinusoidal Position Encoding')
plt.colorbar()
plt.show()

可学习位置编码

GPT系列使用的方法。

class LearnedPositionEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return self.pos_embedding(positions)

RoPE (Rotary Position Embedding)

LLaMA、GLM等模型使用的相对位置编码。

核心思想: 通过旋转矩阵编码相对位置信息。

优势:

  • 自然支持相对位置
  • 可外推到更长序列
  • 计算效率高

实现:

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    """
    预计算RoPE的旋转频率

    Args:
        dim: 头维度
        max_seq_len: 最大序列长度
        theta: 基础频率
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(max_seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # 复数形式
    return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
    """调整形状以便广播"""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    应用RoPE到Q和K

    Args:
        xq: Query (batch, seq_len, n_heads, head_dim)
        xk: Key (batch, seq_len, n_heads, head_dim)
        freqs_cis: 旋转频率
    """
    # 将实数转换为复数
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 应用旋转
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# 使用示例
class RoPEAttention(nn.Module):
    def __init__(self, n_heads, head_dim, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim

        # 预计算旋转频率
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(head_dim, max_seq_len)
        )

    def forward(self, q, k, v):
        # q, k, v: (batch, seq_len, n_heads, head_dim)
        seq_len = q.shape[1]

        # 应用RoPE
        q, k = apply_rotary_emb(q, k, self.freqs_cis[:seq_len])

        # 计算注意力
        # ... (标准注意力计算)

        return output

RoPE的数学原理:

对于位置m和n的两个向量,RoPE通过旋转使得它们的内积仅依赖于相对位置(m-n):

<RoPE(q_m), RoPE(k_n)> = <R_m * q_m, R_n * k_n>
                        = <R_(m-n) * q_m, k_n>

这样模型自然学会相对位置关系。

ALiBi (Attention with Linear Biases)

在注意力分数上添加线性偏置。

公式:

attention_score(q_i, k_j) = q_i · k_j - m * (i - j)

其中:
- m: 每个头的斜率(不同头使用不同斜率)
- (i - j): 相对距离

实现:

class ALiBiAttention(nn.Module):
    def __init__(self, n_heads, max_seq_len):
        super().__init__()
        self.n_heads = n_heads

        # 为每个头计算斜率
        slopes = torch.Tensor(self._get_slopes(n_heads))

        # 创建偏置矩阵
        alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(n_heads, -1, -1)
        alibi = alibi.view(n_heads, 1, max_seq_len)

        self.register_buffer('alibi', alibi)

    def _get_slopes(self, n_heads):
        """计算每个头的斜率"""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            return get_slopes_power_of_2(n_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            return (
                get_slopes_power_of_2(closest_power_of_2) +
                self._get_slopes(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
            )

    def forward(self, q, k, v):
        # q, k: (batch, n_heads, seq_len, head_dim)
        seq_len = q.size(2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

        # 添加ALiBi偏置
        scores += self.alibi[:, :, :seq_len, :seq_len]

        # Softmax
        attn = F.softmax(scores, dim=-1)

        # 加权求和
        output = torch.matmul(attn, v)

        return output

# 可视化不同头的斜率
n_heads = 8
slopes = ALiBiAttention(n_heads, 512)._get_slopes(n_heads)
print(f"Slopes for {n_heads} heads: {slopes}")

优势:

  • 无需位置嵌入
  • 外推性好(可推广到更长序列)
  • 训练和推理效率高

3. 预训练过程

3.1 Next Token Prediction任务

LLM的核心训练目标是预测下一个token。

训练目标(交叉熵损失):

L = -∑∑ log P(x_t | x_1, ..., x_{t-1})

其中:
- x_t: 第t个token
- P(x_t | x_1, ..., x_{t-1}): 给定前文预测x_t的概率

实现:

def compute_loss(model, batch):
    """
    计算语言模型损失

    Args:
        model: GPT模型
        batch: (batch_size, seq_len+1) 的token序列
    """
    # 输入和目标
    inputs = batch[:, :-1]  # 前seq_len个token
    targets = batch[:, 1:]  # 后seq_len个token(向右移一位)

    # 前向传播
    logits, loss = model(inputs, targets=targets)

    return loss

# 训练循环
def train_step(model, optimizer, batch):
    model.train()
    optimizer.zero_grad()

    loss = compute_loss(model, batch)

    loss.backward()

    # 梯度裁剪
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    return loss.item()

示例:

输入序列: "The cat sat on the"
模型预测: cat  sat  on  the  mat
实际目标: cat  sat  on  the  mat

Loss = CrossEntropy([logits_1, ..., logits_5], [cat, sat, on, the, mat])

3.2 训练数据规模和来源

数据规模演进

模型训练数据量数据来源
GPT-15GBBooksCorpus
GPT-240GBWebText
GPT-3570GBCommon Crawl, Books, Wikipedia
LLaMA1.4T tokens多源混合
LLaMA 22T tokens公开数据
LLaMA 315T tokens高质量筛选数据

数据来源

1. Common Crawl

  • 网页爬虫数据
  • 规模巨大(PB级)
  • 需要大量清洗

清洗流程:

def clean_common_crawl(text):
    """清洗Common Crawl数据"""
    # 1. 去除HTML标签
    text = remove_html_tags(text)

    # 2. 语言过滤
    if not is_english(text):
        return None

    # 3. 质量过滤
    if len(text.split()) < 50:  # 太短
        return None

    if profanity_score(text) > 0.5:  # 低俗内容
        return None

    # 4. 去重
    if is_duplicate(text):
        return None

    # 5. 个人信息脱敏
    text = remove_pii(text)

    return text

2. 书籍数据

  • Books1, Books2
  • 高质量长文本
  • 版权问题(通常未公开)

3. Wikipedia

  • 高质量百科知识
  • 多语言覆盖
  • 定期更新

4. 代码数据

  • GitHub
  • StackOverflow
  • 提升代码能力

5. 学术论文

  • ArXiv
  • PubMed
  • 提升专业知识

数据混合策略

不同数据源按比例混合:

class DataMixer:
    """数据混合器"""

    def __init__(self, datasets, weights):
        """
        Args:
            datasets: 数据集列表
            weights: 采样权重
        """
        self.datasets = datasets
        self.weights = np.array(weights)
        self.weights = self.weights / self.weights.sum()

    def sample_batch(self, batch_size):
        """采样一个批次"""
        # 根据权重选择数据集
        dataset_idx = np.random.choice(len(self.datasets), p=self.weights)
        dataset = self.datasets[dataset_idx]

        # 从选中的数据集采样
        return dataset.sample(batch_size)

# 使用示例
mixer = DataMixer(
    datasets=[common_crawl, books, wikipedia, code, arxiv],
    weights=[0.60, 0.16, 0.03, 0.045, 0.025]
)

batch = mixer.sample_batch(batch_size=32)

3.3 计算量估算(FLOPs)

FLOPs计算公式

前向传播的FLOPs:

对于单个样本,序列长度为L,词表大小为V:

1. Token Embedding: 忽略(查表操作)

2. Self-Attention (每层):
   - QKV投影: 3 * (2 * L * d * d) = 6Ld²
   - 注意力计算: 2 * L² * d
   - 输出投影: 2 * L * d * d = 2Ld²
   小计: 8Ld² + 2L²d

3. FFN (每层):
   - 上投影: 2 * L * d * 4d = 8Ld²
   - 下投影: 2 * L * 4d * d = 8Ld²
   小计: 16Ld²

4. 每层总计: 24Ld² + 2L²d

5. N层: N * (24Ld² + 2L²d)

6. 输出层: 2 * L * d * V

总FLOPs ≈ N * (24Ld² + 2L²d) + 2LdV

简化(忽略L²项,因为d >> L):
FLOPs ≈ 2 * N * L * d² * (12 + V/Nd)

训练的总FLOPs:

训练FLOPs = 前向FLOPs * 3 * 训练tokens

其中乘以3是因为:
- 前向传播: 1x
- 反向传播: 2x (约等于)

实际计算:

def estimate_training_flops(
    n_params,      # 参数量
    n_tokens,      # 训练token数
    n_layers,      # 层数
    d_model,       # 隐藏维度
    seq_len,       # 序列长度
    vocab_size,    # 词表大小
):
    """
    估算训练FLOPs

    Returns:
        总FLOPs数
    """
    # 每个token的前向FLOPs
    flops_per_token = 2 * n_params + 2 * n_layers * seq_len * d_model

    # 训练FLOPs(包含反向传播)
    training_flops = 3 * flops_per_token * n_tokens

    return training_flops

# GPT-3 175B的计算量
n_params = 175e9
n_tokens = 300e9
n_layers = 96
d_model = 12288
seq_len = 2048
vocab_size = 50257

flops = estimate_training_flops(n_params, n_tokens, n_layers, d_model, seq_len, vocab_size)

print(f"GPT-3训练FLOPs: {flops:.2e}")
print(f"GPT-3训练FLOPs: {flops / 1e21:.2f} ZFLOPs")

# 输出: GPT-3训练FLOPs: 3.14e+23 (约314 ZFLOPs)

Chinchilla缩放定律

Chinchilla论文提出:给定计算预算C,最优的模型大小N和训练数据量D满足:

N_opt ∝ C^a
D_opt ∝ C^b

其中 a ≈ 0.5, b ≈ 0.5

即:模型参数量和训练数据量应以相同速度增长

启示:

  • GPT-3可能训练不足(175B参数,300B tokens)
  • Chinchilla:70B参数,1.4T tokens,性能更好
  • LLaMA 2:70B参数,2T tokens

3.4 训练时间和成本

硬件配置

常用GPU:

GPU内存FP16性能价格(云)
A100 80GB80GB312 TFLOPS$3/小时
H100 80GB80GB1000 TFLOPS$5/小时
V100 32GB32GB125 TFLOPS$2/小时

训练时间估算

def estimate_training_time(
    total_flops,        # 总计算量
    gpu_flops,          # 单GPU性能(FLOPS)
    n_gpus,             # GPU数量
    mfu=0.5,            # 模型FLOPs利用率
):
    """
    估算训练时间

    Args:
        total_flops: 总FLOPs
        gpu_flops: 单GPU的FP16 FLOPS
        n_gpus: GPU数量
        mfu: 模型FLOPs利用率(通常30%-50%)

    Returns:
        训练时间(小时)
    """
    effective_flops = gpu_flops * n_gpus * mfu
    training_seconds = total_flops / effective_flops
    training_hours = training_seconds / 3600

    return training_hours

# LLaMA 65B训练时间
total_flops = 6.3e23  # 约630 ZFLOPs
gpu_flops = 312e12    # A100的FP16性能
n_gpus = 2048         # 使用的GPU数量
mfu = 0.4             # 实际利用率

hours = estimate_training_time(total_flops, gpu_flops, n_gpus, mfu)
days = hours / 24

print(f"LLaMA 65B训练时间:")
print(f"  {hours:.0f} 小时")
print(f"  {days:.0f} 天")
print(f"使用 {n_gpus} 张A100 GPU")

# 输出: 约21天,使用2048张A100

成本估算

def estimate_training_cost(
    training_hours,
    n_gpus,
    gpu_cost_per_hour=3.0,
    power_cost_per_kwh=0.1,
    gpu_power_kw=0.4,
    overhead=0.3,
):
    """
    估算训练成本

    Args:
        training_hours: 训练时间(小时)
        n_gpus: GPU数量
        gpu_cost_per_hour: GPU租用成本($/小时)
        power_cost_per_kwh: 电费($/kWh)
        gpu_power_kw: 单GPU功耗(kW)
        overhead: 其他成本占比

    Returns:
        总成本(美元)
    """
    # GPU租用成本
    gpu_cost = training_hours * n_gpus * gpu_cost_per_hour

    # 电费
    power_cost = training_hours * n_gpus * gpu_power_kw * power_cost_per_kwh

    # 总成本(含overhead)
    total_cost = (gpu_cost + power_cost) * (1 + overhead)

    return {
        'gpu_cost': gpu_cost,
        'power_cost': power_cost,
        'total_cost': total_cost,
    }

# LLaMA 65B成本
costs = estimate_training_cost(
    training_hours=21 * 24,
    n_gpus=2048,
    gpu_cost_per_hour=3.0,
)

print(f"LLaMA 65B训练成本:")
print(f"  GPU租用: ${costs['gpu_cost']/1e6:.2f}M")
print(f"  电费: ${costs['power_cost']/1e6:.2f}M")
print(f"  总成本: ${costs['total_cost']/1e6:.2f}M")

# 输出: 约$4M-$5M

实际案例:

模型参数量GPUGPU数时间成本
GPT-3 175B175BV100~10000~3月~$5M
LLaMA 65B65BA1002048~21天~$4M
LLaMA 2 70B70BA100不详不详~$5M
Mistral 7B7BA100不详不详~$0.5M

4. 模型规模

4.1 参数量、层数、隐藏维度

模型缩放规律

参数量 ≈ 12 * n_layers * d_model²

其中:
- n_layers: 层数
- d_model: 隐藏维度

详细分解(每层):
- QKV投影: 3 * d_model * d_model = 3d²
- 注意力输出: d_model * d_model = d²
- FFN: 2 * d_model * (4 * d_model) = 8d²
小计: 12d²

实际计算:

def calculate_params(n_layers, d_model, vocab_size):
    """
    计算模型参数量

    Args:
        n_layers: 层数
        d_model: 隐藏维度
        vocab_size: 词表大小

    Returns:
        总参数量
    """
    # Transformer层参数
    transformer_params = n_layers * 12 * d_model ** 2

    # Embedding层(token + position)
    embedding_params = 2 * vocab_size * d_model

    # 输出层(通常与token embedding共享)
    # output_params = vocab_size * d_model

    # 总参数(不重复计算共享的embedding)
    total_params = transformer_params + embedding_params

    return total_params

# 验证GPT-2模型配置
configs = [
    ("GPT-2 Small", 12, 768, 50257),
    ("GPT-2 Medium", 24, 1024, 50257),
    ("GPT-2 Large", 36, 1280, 50257),
    ("GPT-2 XL", 48, 1600, 50257),
]

for name, n_layers, d_model, vocab_size in configs:
    params = calculate_params(n_layers, d_model, vocab_size)
    print(f"{name}: {params/1e6:.1f}M 参数")

# 输出:
# GPT-2 Small: 124.4M 参数
# GPT-2 Medium: 354.8M 参数
# GPT-2 Large: 774.0M 参数
# GPT-2 XL: 1542.0M 参数

4.2 7B、13B、70B模型对比

常见模型配置

model_configs = {
    "7B": {
        "n_layers": 32,
        "d_model": 4096,
        "n_heads": 32,
        "ffn_dim": 11008,  # SwiGLU: ~2.7 * d_model
    },
    "13B": {
        "n_layers": 40,
        "d_model": 5120,
        "n_heads": 40,
        "ffn_dim": 13824,
    },
    "70B": {
        "n_layers": 80,
        "d_model": 8192,
        "n_heads": 64,
        "ffn_dim": 28672,
    },
}

def print_model_stats(name, config):
    """打印模型统计信息"""
    n_layers = config["n_layers"]
    d_model = config["d_model"]
    n_heads = config["n_heads"]
    ffn_dim = config["ffn_dim"]

    # 计算参数量
    attn_params = n_layers * (4 * d_model ** 2)  # QKV + output
    ffn_params = n_layers * (2 * d_model * ffn_dim)  # up + down
    total_params = attn_params + ffn_params

    # 激活值内存(batch=1, seq=2048, dtype=fp16)
    seq_len = 2048
    activation_memory = (
        seq_len * d_model * 2 +  # token + pos embedding
        n_layers * seq_len * d_model * 2 * 4  # 每层4份激活值
    ) * 2 / 1024**3  # 转换为GB (fp16)

    print(f"\n{name}:")
    print(f"  层数: {n_layers}")
    print(f"  隐藏维度: {d_model}")
    print(f"  注意力头数: {n_heads}")
    print(f"  每头维度: {d_model // n_heads}")
    print(f"  FFN维度: {ffn_dim}")
    print(f"  参数量: {total_params/1e9:.1f}B")
    print(f"  激活内存: {activation_memory:.2f}GB")

for name, config in model_configs.items():
    print_model_stats(name, config)

输出:

7B:
  层数: 32
  隐藏维度: 4096
  注意力头数: 32
  每头维度: 128
  FFN维度: 11008
  参数量: 6.7B
  激活内存: 2.1GB

13B:
  层数: 40
  隐藏维度: 5120
  注意力头数: 40
  每头维度: 128
  FFN维度: 13824
  参数量: 13.0B
  激活内存: 3.3GB

70B:
  层数: 80
  隐藏维度: 8192
  注意力头数: 64
  每头维度: 128
  FFN维度: 28672
  参数量: 65.2B
  激活内存: 10.5GB

性能对比

模型参数量MMLUHellaSwagHumanEval推理速度
7B7B35.1%76.1%12.2%快
13B13B46.9%79.2%18.3%中
70B70B69.8%87.3%29.9%慢

选择建议:

  • 7B: 适合资源受限场景,快速响应
  • 13B: 性能和效率平衡
  • 70B: 追求最佳性能

4.3 显存占用计算公式

模型权重

权重内存 = 参数量 * 每参数字节数

数据类型:
- FP32: 4 bytes
- FP16/BF16: 2 bytes
- INT8: 1 byte
- INT4: 0.5 bytes

示例:

def calculate_weight_memory(n_params, dtype='fp16'):
    """计算权重占用的显存"""
    bytes_per_param = {
        'fp32': 4,
        'fp16': 2,
        'bf16': 2,
        'int8': 1,
        'int4': 0.5,
    }

    memory_bytes = n_params * bytes_per_param[dtype]
    memory_gb = memory_bytes / (1024 ** 3)

    return memory_gb

# LLaMA 7B
params_7b = 7e9
print(f"7B模型权重:")
print(f"  FP32: {calculate_weight_memory(params_7b, 'fp32'):.2f}GB")
print(f"  FP16: {calculate_weight_memory(params_7b, 'fp16'):.2f}GB")
print(f"  INT8: {calculate_weight_memory(params_7b, 'int8'):.2f}GB")
print(f"  INT4: {calculate_weight_memory(params_7b, 'int4'):.2f}GB")

# 输出:
# FP32: 26.01GB
# FP16: 13.01GB
# INT8: 6.50GB
# INT4: 3.25GB

训练显存

训练时需要存储:

  1. 模型权重 (W)
  2. 梯度 (G): 与权重相同大小
  3. 优化器状态 (O): Adam需要2份(momentum + variance)
  4. 激活值 (A): 用于反向传播

公式(Adam优化器,混合精度训练):

总显存 = 模型权重(FP16) + 梯度(FP16) + 优化器状态(FP32*2) + 激活值

具体:
- 模型权重: 2 * P bytes
- 梯度: 2 * P bytes
- 优化器状态: 8 * P bytes (FP32的momentum和variance)
- 激活值: 与batch size和序列长度相关

总计: 12 * P bytes(不含激活值)

实现:

def calculate_training_memory(
    n_params,
    batch_size,
    seq_len,
    n_layers,
    d_model,
    mixed_precision=True,
):
    """
    计算训练显存

    Args:
        n_params: 参数量
        batch_size: 批次大小
        seq_len: 序列长度
        n_layers: 层数
        d_model: 隐藏维度
        mixed_precision: 是否使用混合精度
    """
    # 模型权重 + 梯度 + 优化器状态
    if mixed_precision:
        # 权重FP16: 2P, 梯度FP16: 2P, 优化器FP32: 8P
        model_memory = 12 * n_params
    else:
        # 全FP32: 4P + 4P + 8P = 16P
        model_memory = 16 * n_params

    # 激活值(每层保存)
    # 包括:注意力中间结果、FFN中间结果等
    activation_per_layer = batch_size * seq_len * d_model * (
        2 +  # QK, softmax(QK)
        2 +  # attention output, residual
        4    # FFN (2个中间层 * 2)
    )

    if mixed_precision:
        activation_memory = activation_per_layer * n_layers * 2  # FP16
    else:
        activation_memory = activation_per_layer * n_layers * 4  # FP32

    # 总显存
    total_memory = model_memory + activation_memory
    total_memory_gb = total_memory / (1024 ** 3)

    return {
        'model_gb': model_memory / (1024 ** 3),
        'activation_gb': activation_memory / (1024 ** 3),
        'total_gb': total_memory_gb,
    }

# LLaMA 7B训练显存
memory = calculate_training_memory(
    n_params=7e9,
    batch_size=4,
    seq_len=2048,
    n_layers=32,
    d_model=4096,
    mixed_precision=True,
)

print(f"7B模型训练显存 (batch=4, seq=2048):")
print(f"  模型+梯度+优化器: {memory['model_gb']:.2f}GB")
print(f"  激活值: {memory['activation_gb']:.2f}GB")
print(f"  总计: {memory['total_gb']:.2f}GB")

# 输出示例: 总计约120GB

推理显存

推理时只需要:

  1. 模型权重
  2. KV Cache(存储已计算的K和V)
  3. 当前输入的激活值

公式:

推理显存 = 模型权重 + KV Cache + 临时激活值

KV Cache = 2 * n_layers * batch * seq_len * d_model * bytes_per_element

其中2是因为K和V各一份

实现:

def calculate_inference_memory(
    n_params,
    batch_size,
    max_seq_len,
    n_layers,
    d_model,
    dtype='fp16',
):
    """计算推理显存"""
    bytes_per_elem = {'fp32': 4, 'fp16': 2, 'bf16': 2, 'int8': 1}[dtype]

    # 模型权重
    model_memory = n_params * bytes_per_elem

    # KV Cache
    kv_cache_memory = (
        2 *  # K和V
        n_layers *
        batch_size *
        max_seq_len *
        d_model *
        bytes_per_elem
    )

    # 临时激活值(只需一层)
    activation_memory = batch_size * max_seq_len * d_model * bytes_per_elem * 4

    total_memory = model_memory + kv_cache_memory + activation_memory

    return {
        'model_gb': model_memory / (1024 ** 3),
        'kv_cache_gb': kv_cache_memory / (1024 ** 3),
        'activation_gb': activation_memory / (1024 ** 3),
        'total_gb': total_memory / (1024 ** 3),
    }

# LLaMA 7B推理显存
memory = calculate_inference_memory(
    n_params=7e9,
    batch_size=1,
    max_seq_len=2048,
    n_layers=32,
    d_model=4096,
    dtype='fp16',
)

print(f"7B模型推理显存 (batch=1, seq=2048, fp16):")
print(f"  模型权重: {memory['model_gb']:.2f}GB")
print(f"  KV Cache: {memory['kv_cache_gb']:.2f}GB")
print(f"  激活值: {memory['activation_gb']:.2f}GB")
print(f"  总计: {memory['total_gb']:.2f}GB")

# 输出: 约14-16GB,24GB显卡可运行

不同模型的推理显存需求:

模型FP16INT8INT4推荐显卡
7B14GB8GB5GBRTX 3090 (24GB)
13B26GB14GB8GBRTX 4090 (24GB) 需量化
70B140GB72GB38GBA100 80GB * 2

5. 推理过程详解

5.1 自回归生成

自回归生成是逐个token生成文本的过程。

流程:

  1. 输入prompt,编码为token IDs
  2. 前向传播,得到下一个token的概率分布
  3. 采样得到下一个token
  4. 将新token加入序列
  5. 重复2-4,直到生成结束符或达到最大长度

完整实现:

@torch.no_grad()
def autoregressive_generate(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=1.0,
):
    """
    自回归生成

    Args:
        model: 语言模型
        tokenizer: 分词器
        prompt: 输入文本
        max_new_tokens: 最大生成长度
        temperature: 温度参数
        top_k: Top-k采样
        top_p: Top-p (nucleus)采样
        repetition_penalty: 重复惩罚

    Returns:
        生成的文本
    """
    model.eval()
    device = next(model.parameters()).device

    # 编码prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # 用于repetition penalty
    generated_tokens = input_ids.clone()

    for _ in range(max_new_tokens):
        # 前向传播
        outputs = model(input_ids)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]

        # 取最后一个token的logits
        next_token_logits = logits[:, -1, :]

        # 应用重复惩罚
        if repetition_penalty != 1.0:
            for token_id in set(generated_tokens[0].tolist()):
                next_token_logits[0, token_id] /= repetition_penalty

        # 应用温度
        next_token_logits = next_token_logits / temperature

        # Top-k过滤
        if top_k is not None:
            indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
            next_token_logits[indices_to_remove] = float('-inf')

        # Top-p (nucleus)过滤
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # 移除累积概率超过top_p的tokens
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            next_token_logits[indices_to_remove] = float('-inf')

        # 转换为概率
        probs = F.softmax(next_token_logits, dim=-1)

        # 采样
        next_token = torch.multinomial(probs, num_samples=1)

        # 添加到序列
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        # 检查是否生成结束符
        if next_token.item() == tokenizer.eos_token_id:
            break

    # 解码
    generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    return generated_text

# 使用示例
prompt = "The future of artificial intelligence is"
generated = autoregressive_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_new_tokens=50,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.1,
)
print(generated)

5.2 KV Cache机制

KV Cache通过缓存已计算的K和V来加速推理。

问题: 自回归生成时,每次都要重新计算所有历史token的K和V,造成大量重复计算。

解决: 缓存已计算的K和V,每次只计算新token的KV。

实现:

class KVCache:
    """KV Cache管理器"""

    def __init__(self, n_layers, n_heads, head_dim, max_seq_len, dtype=torch.float16):
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.dtype = dtype

        # 为每层创建K和V的缓存
        self.k_cache = []
        self.v_cache = []

        for _ in range(n_layers):
            self.k_cache.append(
                torch.zeros(1, n_heads, max_seq_len, head_dim, dtype=dtype)
            )
            self.v_cache.append(
                torch.zeros(1, n_heads, max_seq_len, head_dim, dtype=dtype)
            )

        self.current_len = 0

    def update(self, layer_idx, k, v):
        """
        更新KV Cache

        Args:
            layer_idx: 层索引
            k: 新的K (batch, n_heads, seq_len, head_dim)
            v: 新的V
        """
        seq_len = k.size(2)

        # 写入cache
        self.k_cache[layer_idx][:, :, self.current_len:self.current_len + seq_len, :] = k
        self.v_cache[layer_idx][:, :, self.current_len:self.current_len + seq_len, :] = v

    def get(self, layer_idx):
        """
        获取当前的K和V

        Returns:
            k, v: 当前长度的K和V
        """
        k = self.k_cache[layer_idx][:, :, :self.current_len, :]
        v = self.v_cache[layer_idx][:, :, :self.current_len, :]
        return k, v

    def forward_step(self):
        """前进一步"""
        self.current_len += 1

    def reset(self):
        """重置cache"""
        self.current_len = 0

class AttentionWithKVCache(nn.Module):
    """带KV Cache的注意力层"""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.n_embd // config.n_heads

        self.q_proj = nn.Linear(config.n_embd, config.n_embd)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, kv_cache=None, layer_idx=0):
        """
        前向传播

        Args:
            x: 输入 (batch, seq_len, n_embd)
            kv_cache: KV Cache对象
            layer_idx: 层索引
        """
        B, T, C = x.size()

        # 计算Q, K, V
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        if kv_cache is not None:
            # 使用KV Cache
            kv_cache.update(layer_idx, k, v)
            k, v = kv_cache.get(layer_idx)

        # 计算注意力
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = F.softmax(att, dim=-1)
        y = att @ v

        # 合并多头
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)

        return y

# 使用KV Cache生成
@torch.no_grad()
def generate_with_kv_cache(model, input_ids, max_new_tokens=100):
    """使用KV Cache的生成"""
    device = input_ids.device

    # 创建KV Cache
    kv_cache = KVCache(
        n_layers=model.config.n_layer,
        n_heads=model.config.n_head,
        head_dim=model.config.n_embd // model.config.n_head,
        max_seq_len=model.config.block_size,
    )
    kv_cache.to(device)

    # 第一次前向:处理整个prompt
    logits = model(input_ids, kv_cache=kv_cache)
    kv_cache.current_len = input_ids.size(1)

    # 采样第一个新token
    next_token = torch.multinomial(F.softmax(logits[:, -1, :], dim=-1), num_samples=1)
    input_ids = torch.cat([input_ids, next_token], dim=1)

    # 自回归生成
    for _ in range(max_new_tokens - 1):
        # 只需前向最后一个token
        logits = model(next_token, kv_cache=kv_cache)
        kv_cache.forward_step()

        # 采样
        next_token = torch.multinomial(F.softmax(logits[:, -1, :], dim=-1), num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

加速效果:

# 不使用KV Cache:每次前向O(n²)
# 生成100个token: 1 + 2 + 3 + ... + 100 = 5050次注意力计算

# 使用KV Cache:每次前向O(n)
# 生成100个token: 100次注意力计算

# 加速比:5050 / 100 = 50.5x

5.3 解码策略

Greedy Decoding

每次选择概率最高的token。

def greedy_search(logits):
    """贪心搜索"""
    return torch.argmax(logits, dim=-1)

优点:

  • 简单快速
  • 确定性输出

缺点:

  • 可能陷入重复
  • 缺乏多样性

Sampling

根据概率分布随机采样。

def sample(logits, temperature=1.0):
    """
    采样

    Args:
        logits: (vocab_size,) 未归一化的分数
        temperature: 温度参数
            - temperature < 1: 更确定(尖锐分布)
            - temperature > 1: 更随机(平滑分布)
    """
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# 温度效果演示
def show_temperature_effect():
    logits = torch.tensor([2.0, 1.0, 0.5])

    for temp in [0.5, 1.0, 2.0]:
        probs = F.softmax(logits / temp, dim=-1)
        print(f"Temperature={temp}: {probs.tolist()}")

show_temperature_effect()

# 输出:
# Temperature=0.5: [0.632, 0.258, 0.110]  (更集中)
# Temperature=1.0: [0.506, 0.307, 0.186]  (原始)
# Temperature=2.0: [0.420, 0.331, 0.249]  (更平滑)

Beam Search

维护top-k个候选序列。

def beam_search(
    model,
    input_ids,
    beam_size=4,
    max_length=100,
    length_penalty=1.0,
):
    """
    束搜索

    Args:
        model: 语言模型
        input_ids: 初始输入 (1, seq_len)
        beam_size: 束大小
        max_length: 最大长度
        length_penalty: 长度惩罚(>1惩罚短序列)
    """
    device = input_ids.device
    vocab_size = model.config.vocab_size

    # 初始化:beam_size个候选
    sequences = input_ids.repeat(beam_size, 1)  # (beam_size, seq_len)
    scores = torch.zeros(beam_size, device=device)  # 累积log概率

    for _ in range(max_length):
        # 前向传播
        logits = model(sequences)[:, -1, :]  # (beam_size, vocab_size)
        log_probs = F.log_softmax(logits, dim=-1)

        # 计算候选分数
        # (beam_size, 1) + (beam_size, vocab_size) = (beam_size, vocab_size)
        candidate_scores = scores.unsqueeze(1) + log_probs

        # 展平并选择top-k
        candidate_scores = candidate_scores.view(-1)  # (beam_size * vocab_size,)
        topk_scores, topk_indices = torch.topk(candidate_scores, beam_size)

        # 恢复beam和token索引
        beam_indices = topk_indices // vocab_size
        token_indices = topk_indices % vocab_size

        # 更新序列
        sequences = torch.cat([
            sequences[beam_indices],
            token_indices.unsqueeze(1)
        ], dim=1)

        # 更新分数(应用长度惩罚)
        scores = topk_scores / (sequences.size(1) ** length_penalty)

    # 返回得分最高的序列
    best_idx = scores.argmax()
    return sequences[best_idx]

# 使用示例
output = beam_search(
    model=model,
    input_ids=input_ids,
    beam_size=4,
    max_length=50,
    length_penalty=1.2,
)

Top-k Sampling

只从概率最高的k个token中采样。

def top_k_sampling(logits, k=50, temperature=1.0):
    """
    Top-k采样

    Args:
        logits: (vocab_size,)
        k: 保留的top-k个token
        temperature: 温度
    """
    # 应用温度
    logits = logits / temperature

    # 找到top-k
    topk_logits, topk_indices = torch.topk(logits, k)

    # 过滤其他token(设为负无穷)
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits[topk_indices] = topk_logits

    # 采样
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

可视化:

import matplotlib.pyplot as plt

# 模拟一个概率分布
vocab_size = 100
logits = torch.randn(vocab_size) * 2
probs = F.softmax(logits, dim=-1)

# 排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True)

# 绘制
plt.figure(figsize=(12, 4))

# 原始分布
plt.subplot(1, 2, 1)
plt.bar(range(vocab_size), sorted_probs)
plt.title('Original Distribution')
plt.xlabel('Token (sorted by probability)')
plt.ylabel('Probability')

# Top-k (k=10)
k = 10
topk_probs = torch.zeros_like(sorted_probs)
topk_probs[:k] = sorted_probs[:k]
topk_probs = topk_probs / topk_probs.sum()  # 重新归一化

plt.subplot(1, 2, 2)
plt.bar(range(vocab_size), topk_probs)
plt.title(f'Top-{k} Sampling')
plt.xlabel('Token (sorted by probability)')
plt.ylabel('Probability')

plt.tight_layout()
plt.show()

Top-p (Nucleus) Sampling

从累积概率达到p的最小token集合中采样。

def top_p_sampling(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus)采样

    Args:
        logits: (vocab_size,)
        p: 累积概率阈值
        temperature: 温度
    """
    # 应用温度
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    # 降序排序
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    # 计算累积概率
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # 找到累积概率超过p的位置
    sorted_indices_to_remove = cumulative_probs > p

    # 保留第一个超过p的token(确保至少有一个)
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # 过滤
    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
    filtered_logits = logits.clone()
    filtered_logits[indices_to_remove] = float('-inf')

    # 采样
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# 可视化Top-p
def visualize_top_p(p=0.9):
    # 模拟分布
    logits = torch.randn(100) * 2
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # 找到p的位置
    nucleus_size = (cumulative_probs <= p).sum().item() + 1

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.bar(range(len(sorted_probs)), sorted_probs)
    plt.axvline(nucleus_size, color='red', linestyle='--', label=f'Nucleus size={nucleus_size}')
    plt.title(f'Top-p={p} Distribution')
    plt.xlabel('Token (sorted)')
    plt.ylabel('Probability')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(cumulative_probs)
    plt.axhline(p, color='red', linestyle='--', label=f'p={p}')
    plt.axvline(nucleus_size, color='red', linestyle='--')
    plt.title('Cumulative Probability')
    plt.xlabel('Token (sorted)')
    plt.ylabel('Cumulative Probability')
    plt.legend()

    plt.tight_layout()
    plt.show()

    print(f"Nucleus包含 {nucleus_size} 个token")

visualize_top_p(p=0.9)

Top-k vs Top-p:

  • Top-k: 固定数量,适合概率分布比较均匀的情况
  • Top-p: 动态数量,自动适应分布的尖锐程度

对比总结

策略优点缺点适用场景
Greedy快速、确定重复、单调需要确定性输出
Sampling多样性可能不连贯创意写作
Beam Search高质量慢、缺乏多样性翻译、摘要
Top-k平衡质量和多样性k难调通用
Top-p自适应稍慢通用(推荐)

5.4 完整推理代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

class LLMInference:
    """完整的LLM推理类"""

    def __init__(self, model_name_or_path, device='cuda'):
        """
        初始化

        Args:
            model_name_or_path: 模型名称或路径
            device: 设备
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            device_map='auto',
        )
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt,
        max_new_tokens=100,
        temperature=1.0,
        top_k=None,
        top_p=None,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    ):
        """
        生成文本

        Args:
            prompt: 输入提示
            max_new_tokens: 最大生成长度
            temperature: 温度
            top_k: Top-k采样
            top_p: Top-p采样
            repetition_penalty: 重复惩罚
            do_sample: 是否采样(否则贪心)
            num_return_sequences: 返回序列数

        Returns:
            生成的文本列表
        """
        # 编码输入
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']

        # 生成
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # 解码
        generated_texts = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return generated_texts

    def chat(
        self,
        messages,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
    ):
        """
        对话生成

        Args:
            messages: 对话历史
                [
                    {"role": "user", "content": "你好"},
                    {"role": "assistant", "content": "你好!有什么可以帮你的吗?"},
                    {"role": "user", "content": "介绍一下Python"},
                ]
        """
        # 应用对话模板
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # 生成
        outputs = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        return outputs[0]

    def batch_generate(self, prompts, **kwargs):
        """批量生成"""
        results = []
        for prompt in prompts:
            result = self.generate(prompt, **kwargs)
            results.extend(result)
        return results

# 使用示例
if __name__ == "__main__":
    # 初始化
    inferencer = LLMInference("meta-llama/Llama-2-7b-chat-hf")

    # 单次生成
    prompt = "The future of artificial intelligence is"
    outputs = inferencer.generate(
        prompt,
        max_new_tokens=100,
        temperature=0.8,
        top_p=0.9,
        num_return_sequences=3,
    )

    print("生成结果:")
    for i, text in enumerate(outputs):
        print(f"\n样本 {i+1}:")
        print(text)

    # 对话
    messages = [
        {"role": "user", "content": "什么是大语言模型?"},
    ]

    response = inferencer.chat(messages, temperature=0.7)
    print(f"\n助手: {response}")

    # 批量生成
    prompts = [
        "Once upon a time",
        "In the year 2050",
        "The secret to happiness is",
    ]

    results = inferencer.batch_generate(
        prompts,
        max_new_tokens=50,
        temperature=0.9,
    )

    print("\n批量生成结果:")
    for prompt, result in zip(prompts, results):
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {result}")
Prev
08-Transformer架构深度解析
Next
10-Token与Tokenization详解