Speculative Decoding

Overview

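Speculative Decoding is a technique that speeds up autoregressive generation: a small model proposes tokens and a large model verifies them. This chapter covers how speculative decoding works, strategies for implementing it, and the engineering optimizations around it.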

Principles and Motivation

The Autoregressive Decoding Bottleneck

┌─────────────────────────────────────────────────────────────────────────────┐
│                     Autoregressive Decoding Bottleneck                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Standard autoregressive decoding:                                          │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   Token 1 ──▶ LLM ──▶ Token 2 ──▶ LLM ──▶ Token 3 ──▶ ...             │  │
│  │   Every LLM call is one full forward pass (latency 1, 2, 3, ...)      │  │
│  │                                                                       │  │
│  │   Total latency = N × single forward-pass latency                     │  │
│  │   GPU utilization: low (memory bound)                                 │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
│  Analysis:                                                                  │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   • Only 1 token is generated per forward pass                        │  │
│  │   • Serial dependency: steps cannot run in parallel                   │  │
│  │   • GPU compute utilization is typically low (< 10%)                  │  │
│  │   • Main bottleneck: memory bandwidth (weights + KV cache)            │  │
│  │                                                                       │  │
│  │   Theoretical compute: 2 × params × tokens (FLOPs)                    │  │
│  │   Actual bottleneck: reading params + KV cache from memory            │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

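A rough roofline-style estimate makes the "memory bound" claim concrete. The sketch below assumes every weight is read once per forward pass, ignores KV-cache and activation traffic, and uses hypothetical hardware figures (300 TFLOP/s of compute, 2 TB/s of memory bandwidth) that do not describe any particular GPU; the helper names are illustrative.

```python
from typing import Tuple


def decode_step_times(num_params: float, bytes_per_param: float,
                      peak_flops: float, mem_bandwidth: float,
                      tokens_per_pass: int = 1) -> Tuple[float, float]:
    """Roofline-style estimate of one forward pass of a decoder-only LLM.

    Returns (compute_seconds, memory_seconds). Simplification: only the weights
    are read from memory; KV-cache and activation traffic are ignored.
    """
    flops = 2 * num_params * tokens_per_pass     # ~2 FLOPs per parameter per token
    bytes_moved = num_params * bytes_per_param   # weights are read once per pass
    return flops / peak_flops, bytes_moved / mem_bandwidth


def compute_utilization(compute_s: float, memory_s: float) -> float:
    """Fraction of peak compute used when the pass runs at the slower of the two limits."""
    return compute_s / max(compute_s, memory_s)


# Hypothetical: 7B parameters in fp16, 300 TFLOP/s peak compute, 2 TB/s bandwidth
PARAMS, BYTES_PER_PARAM, PEAK_FLOPS, BANDWIDTH = 7e9, 2, 300e12, 2e12

for n in (1, 6):  # 1 = plain decoding; 6 = K=5 drafts + 1 bonus position in one pass
    comp, mem = decode_step_times(PARAMS, BYTES_PER_PARAM, PEAK_FLOPS, BANDWIDTH, n)
    print(f"{n} token(s)/pass: compute {comp * 1e3:.3f} ms, "
          f"memory {mem * 1e3:.1f} ms, utilization {compute_utilization(comp, mem):.2%}")
```

Under these assumptions a single-token step uses well under 1% of peak compute, and scoring six tokens in one pass costs the same weight-read time as scoring one. That idle compute is what the verify phase of speculative decoding puts to work.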
How Speculative Decoding Works

┌─────────────────────────────────────────────────────────────────────────────┐
│                        Speculative Decoding Pipeline                        │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Phase 1: Draft (small model proposes)                                      │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   Context ──▶ Draft Model ──▶ [t1, t2, t3, t4, t5]  (K tokens)        │  │
│  │   Cost: ~K × draft_latency (small when the draft model is cheap)      │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
│  Phase 2: Verify (large model checks all K in parallel)                     │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   [Context, t1, t2, t3, t4, t5] ──▶ Target Model ──▶ Verify           │  │
│  │   One forward pass scores all K positions (+1 bonus position)         │  │
│  │   Cost: ~1 × target_latency                                           │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
│  Phase 3: Accept/Reject                                                     │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   Result: t1 ✓, t2 ✓, t3 ✓, t4 ✗ (first rejection at position 4)      │  │
│  │                                                                       │  │
│  │   Keep t1, t2, t3 + one token resampled by the target at position 4   │  │
│  │   → 4 tokens produced this round                                      │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
│  Speedup analysis:                                                          │
│  ┌───────────────────────────────────────────────────────────────────────┐  │
│  │                                                                       │  │
│  │   Assume K = 5, per-token acceptance rate α = 0.8 (i.i.d.)            │  │
│  │                                                                       │  │
│  │   Standard decoding, 5 tokens: 5 × target_latency                     │  │
│  │   Expected tokens per round: (1 − α^(K+1)) / (1 − α) ≈ 3.69           │  │
│  │   Cost per round: 1 × target_latency + K × draft_latency              │  │
│  │                                                                       │  │
│  │   Speedup ≈ 3.69 / (1 + K·c), c = draft_latency / target_latency      │  │
│  │   e.g. c = 0.05 → 3.69 / 1.25 ≈ 2.95x                                 │  │
│  │                                                                       │  │
│  └───────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

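The numbers in the speedup box come from a simple model in which every draft token is accepted independently with the same probability α (real acceptance varies from token to token). Under that assumption the first i drafts all survive with probability α^i, and each round always emits one extra token from the target, so the expected output per round is the geometric sum Σ_{i=0}^{K} α^i. The sketch below (function names are illustrative) checks the closed form against a Monte Carlo simulation.

```python
import random


def expected_tokens_per_round(alpha: float, k: int) -> float:
    """sum_{i=0..K} alpha^i = (1 - alpha^(K+1)) / (1 - alpha): accepted drafts + 1 target token."""
    if alpha == 1.0:
        return k + 1.0  # every draft accepted, plus the bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Speedup over plain decoding; c = draft_latency / target_latency.

    One round costs 1 target pass + K draft steps, i.e. (1 + K*c) target passes.
    """
    return expected_tokens_per_round(alpha, k) / (1 + k * c)


def simulate_round(alpha: float, k: int, rng: random.Random) -> int:
    """Tokens emitted by one simulated round under the i.i.d. acceptance model."""
    accepted = 0
    while accepted < k and rng.random() < alpha:
        accepted += 1
    # +1: the corrected token after a rejection, or the bonus token if all K survive
    return accepted + 1


rng = random.Random(0)
trials = 100_000
simulated = sum(simulate_round(0.8, 5, rng) for _ in range(trials)) / trials
print(f"closed form: {expected_tokens_per_round(0.8, 5):.3f} tokens/round, "
      f"simulated: {simulated:.3f}")
print(f"speedup with c = 0.05: {expected_speedup(0.8, 5, 0.05):.2f}x")
```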
Core Algorithm Implementation

Basic Speculative Decoding

"""
Basic Speculative Decoding implementation
"""
import torch
import torch.nn.functional as F
from typing import Tuple, List, Optional
from dataclasses import dataclass

@dataclass
class SpeculativeConfig:
    """Speculative decoding configuration"""
    num_speculative_tokens: int = 5  # K: number of tokens drafted per round
    temperature: float = 1.0
    top_p: float = 0.9
    use_tree_attention: bool = False


class SpeculativeDecoder:
    """Speculative decoder"""

    def __init__(
        self,
        target_model,  # large (target) model
        draft_model,   # small (draft) model
        config: SpeculativeConfig
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config
        self.K = config.num_speculative_tokens

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int,
        eos_token_id: Optional[int] = None
    ) -> torch.Tensor:
        """Generate tokens (batch size 1, to keep the accept/reject logic simple)"""
        assert input_ids.shape[0] == 1, "this reference version supports batch size 1"

        # For clarity, both models re-encode the whole prefix every round instead of
        # reusing KV caches across rounds. Production implementations keep both caches
        # and crop them back to the accepted length after each round.
        generated = input_ids.clone()

        while generated.shape[1] < max_length:
            # Phase 1: Draft - the small model proposes K candidate tokens
            draft_tokens, draft_probs = self._draft(generated)

            # Phase 2: Verify - the large model scores all candidates in one pass
            target_probs = self._verify(generated, draft_tokens)

            # Phase 3: Accept/Reject - decide how many tokens to keep
            accepted_tokens, num_accepted = self._accept_reject(
                draft_tokens,
                draft_probs,
                target_probs
            )

            # Append this round's tokens (between 1 and K+1 of them)
            generated = torch.cat([generated, accepted_tokens], dim=1)

            # Stop condition: cut right after the first EOS among the new tokens
            if eos_token_id is not None:
                eos_pos = (accepted_tokens[0] == eos_token_id).nonzero()
                if eos_pos.numel() > 0:
                    cut = generated.shape[1] - num_accepted + eos_pos[0].item() + 1
                    return generated[:, :cut]

        # A round can overshoot max_length by up to K tokens
        return generated[:, :max_length]

    def _draft(
        self,
        context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample K candidate tokens from the draft model"""
        draft_tokens = []
        draft_probs = []
        past_key_values = None  # the draft cache only lives for this round

        for i in range(self.K):
            # The first step encodes the full prefix; later steps feed only the newest token
            input_ids = context if i == 0 else draft_tokens[-1]

            # Forward pass
            outputs = self.draft_model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                use_cache=True
            )

            logits = outputs.logits[:, -1, :]  # [B, V]
            past_key_values = outputs.past_key_values

            # Sample with the same temperature/top-p transform the target uses
            probs = self._apply_sampling(logits)
            next_token = torch.multinomial(probs, num_samples=1)  # [B, 1]

            draft_tokens.append(next_token)
            draft_probs.append(probs)

        # Stack the results
        draft_tokens = torch.cat(draft_tokens, dim=1)  # [B, K]
        draft_probs = torch.stack(draft_probs, dim=1)  # [B, K, V]

        return draft_tokens, draft_probs

    def _verify(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Score the candidate tokens with the target model"""
        # Full input: context + draft_tokens
        verify_input = torch.cat([context, draft_tokens], dim=1)

        # One forward pass scores every draft position at once
        outputs = self.target_model(input_ids=verify_input)

        # Logits at position j predict token j+1, so the last K+1 positions give
        # p_target for t1..tK plus one bonus position after tK
        logits = outputs.logits[:, -self.K - 1:, :]  # [B, K+1, V]

        return self._apply_sampling(logits)

    def _accept_reject(
        self,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor
    ) -> Tuple[torch.Tensor, int]:
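        """
        Accept/reject step

        Rejection sampling keeps the output distribution identical to the target's:
        - if p_target(x) >= p_draft(x): always accept
        - if p_target(x) < p_draft(x): accept with probability p_target(x) / p_draft(x)
        - at the first rejection, resample from norm(max(0, p_target - p_draft)) and stop
        The early break applies to the whole batch, so this is exact for batch size 1.
        """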
        batch_size = draft_tokens.shape[0]
        device = draft_tokens.device

        accepted = []
        all_accepted = True

        for i in range(self.K):
            draft_token = draft_tokens[:, i]  # [B]

            # Probability of this token under the draft and target distributions
            draft_p = draft_probs[:, i].gather(1, draft_token.unsqueeze(1)).squeeze(1)  # [B]
            target_p = target_probs[:, i].gather(1, draft_token.unsqueeze(1)).squeeze(1)  # [B]

            # Acceptance probability: min(1, p_target / p_draft)
            accept_prob = torch.minimum(
                torch.ones_like(target_p),
                target_p / (draft_p + 1e-10)
            )

            # Accept with that probability
            random_vals = torch.rand(batch_size, device=device)
            accept_mask = random_vals < accept_prob

            if accept_mask.all():
                accepted.append(draft_token)
            else:
                all_accepted = False
                # Rejected: resample from the corrected (residual) distribution
                # p_corrected = max(0, p_target - p_draft) / sum(max(0, p_target - p_draft))
                residual_probs = F.relu(target_probs[:, i] - draft_probs[:, i])
                residual_probs = residual_probs / residual_probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)

                # Keep the draft token where accepted, the residual sample where rejected
                corrected_token = torch.where(
                    accept_mask,
                    draft_token,
                    torch.multinomial(residual_probs, num_samples=1).squeeze(1)
                )
                accepted.append(corrected_token)
                break

        # If all K drafts were accepted, sample one bonus token from the last target position
        if all_accepted:
            extra_token = torch.multinomial(target_probs[:, -1], num_samples=1).squeeze(1)
            accepted.append(extra_token)

        accepted_tokens = torch.stack(accepted, dim=1)  # [B, num_accepted]
        num_accepted = len(accepted)  # tokens emitted this round: accepted drafts + 1

        return accepted_tokens, num_accepted

    def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
        """Turn logits into the sampling distribution (temperature, then top-p)"""
        # Temperature
        if self.config.temperature != 1.0:
            logits = logits / self.config.temperature

        # Top-p (nucleus) sampling
        if self.config.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens once the cumulative probability exceeds top_p; the shift
            # keeps the first token that crosses the threshold (and always the top-1)
            sorted_indices_to_remove = cumulative_probs > self.config.top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            # Map the mask from sorted order back to vocabulary order
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1,
                index=sorted_indices,
                src=sorted_indices_to_remove
            )
            # masked_fill returns a new tensor, so the model's output logits stay untouched
            logits = logits.masked_fill(indices_to_remove, float('-inf'))

        return F.softmax(logits, dim=-1)

    def _truncate_cache(
        self,
        past_key_values: Tuple,
        keep_length: int
    ) -> Tuple:
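        """Drop the last (K - keep_length) speculative positions from a legacy tuple KV cache"""
        if past_key_values is None:
            return None

        truncated = []
        for layer_past in past_key_values:
            # layer_past: (key, value), each [B, H, S, D]
            key, value = layer_past
            # An explicit length avoids the [:-0] pitfall when keep_length == K
            new_len = key.shape[2] - self.K + keep_length
            truncated.append((
                key[:, :, :new_len, :],
                value[:, :, :new_len, :]
            ))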

        return tuple(truncated)
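Why the accept/reject rule in `_accept_reject` is lossless can be checked on a toy example. The pure-Python sketch below works on a single position with made-up target (`p`) and draft (`q`) distributions over a 4-token vocabulary. It computes the exact distribution of the emitted token, which works out to min(p, q) + max(0, p − q) = p, and compares it with an empirical sample. Helper names are illustrative, not part of any library.

```python
import random
from typing import List, Sequence


def residual(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """norm(max(0, p - q)): the distribution sampled after a rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(r)
    # total == 0 only when p == q, in which case nothing is ever rejected
    return [x / total for x in r] if total > 0 else list(p)


def speculative_sample(p: Sequence[float], q: Sequence[float], rng: random.Random) -> int:
    """Draw x ~ q; keep it with prob min(1, p[x]/q[x]), else resample from the residual."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    return rng.choices(vocab, weights=residual(p, q))[0]


def exact_output_distribution(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """P(out = x) = q(x)·min(1, p(x)/q(x)) + P(reject)·residual(x)."""
    kept = [min(pi, qi) for pi, qi in zip(p, q)]  # q·min(1, p/q) == min(p, q)
    p_reject = 1.0 - sum(kept)
    return [k + p_reject * r for k, r in zip(kept, residual(p, q))]


p = [0.5, 0.2, 0.2, 0.1]  # toy target distribution
q = [0.3, 0.4, 0.1, 0.2]  # toy draft distribution

print("exact  :", [round(v, 6) for v in exact_output_distribution(p, q)])

rng = random.Random(0)
n = 100_000
counts = [0] * len(p)
for _ in range(n):
    counts[speculative_sample(p, q, rng)] += 1
print("sampled:", [c / n for c in counts])
print("acceptance probability:", round(sum(min(a, b) for a, b in zip(p, q)), 6))
```

The acceptance probability is Σ_x min(p(x), q(x)), 0.7 for these toy distributions, so the closer the draft distribution is to the target, the more drafts survive.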

Tree Attention (Medusa)

"""
Tree Attention implementation (Medusa style)
"""
import torch
import torch.nn as nn
from typing import List, Tuple

class TreeAttentionMask:
    """Tree attention mask"""

    def __init__(self, num_heads: int, max_depth: int):
        self.num_heads = num_heads
        self.max_depth = max_depth

    def create_tree_mask(
        self,
        tree_structure: List[List[int]]
    ) -> torch.Tensor:
        r"""
        Build a tree attention mask

        tree_structure[d][j] is the number of children of the j-th node at depth d
        (depth 0 holds only the root). For example [[2], [2, 2]] describes:
            root
           /    \
          a      b
         / \    / \
        c   d  e   f
        Nodes are numbered in BFS order (root = 0). mask[i, j] is True when node j
        is node i itself or one of its ancestors.
        """
        # parent[i] is the index of node i's parent (-1 for the root)
        parent = [-1]
        frontier = [0]  # nodes at the current depth, left to right

        for layer in tree_structure:
            assert len(layer) == len(frontier), "need one child count per node"
            next_frontier = []
            for node, num_children in zip(frontier, layer):
                for _ in range(num_children):
                    parent.append(node)
                    next_frontier.append(len(parent) - 1)
            frontier = next_frontier

        total_nodes = len(parent)
        mask = torch.zeros(total_nodes, total_nodes, dtype=torch.bool)

        # Each node can see itself and every ancestor on its path to the root
        for node in range(total_nodes):
            ancestor = node
            while ancestor != -1:
                mask[node, ancestor] = True
                ancestor = parent[ancestor]

        return mask


class MedusaHead(nn.Module):
    """Medusa prediction heads"""

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4  # predict 4 future positions
    ):
        super().__init__()
        self.num_heads = num_heads

        # Each head predicts one future position
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, vocab_size)
            )
            for _ in range(num_heads)
        ])

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            hidden_states: [B, S, H]

        Returns:
            List of logits for each future position
        """
        return [head(hidden_states) for head in self.heads]


class MedusaDecoder:
    """Medusa decoder"""

    def __init__(
        self,
        model,  # model with Medusa heads attached (exposes model.medusa_heads)
        num_heads: int = 4,
        top_k: int = 10  # top-k candidates kept per head
    ):
        self.model = model
        self.num_heads = num_heads
        self.top_k = top_k

        # Build the candidate tree
        self.tree_indices = self._build_tree_indices()

    def _build_tree_indices(self) -> torch.Tensor:
        """Build candidate tree indices"""
        # Simplified: top-k per head gives k^num_heads candidate paths
        # (in practice the tree is pruned)
        total_candidates = self.top_k ** self.num_heads
        return torch.arange(total_candidates)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int
    ) -> torch.Tensor:
        generated = input_ids.clone()

        while generated.shape[1] < max_length:
            # Step 1: one forward pass for the base logits and the Medusa head predictions
            outputs = self.model(generated, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1][:, -1:]  # [B, 1, H]

            # The base model's next-token logits
            base_logits = outputs.logits[:, -1, :]  # [B, V]

            # Medusa head predictions
            medusa_logits = self.model.medusa_heads(hidden_states)  # List of [B, 1, V]

            # Step 2: build candidates
            candidates = self._generate_candidates(base_logits, medusa_logits)

            # Step 3: verify the candidates in a single forward pass
            accept_length = self._tree_verify(generated, candidates)

            # Step 4: accept the longest correct prefix
            accepted = candidates[:, :accept_length]
            generated = torch.cat([generated, accepted], dim=1)

        return generated

    def _generate_candidates(
        self,
        base_logits: torch.Tensor,
        medusa_logits: List[torch.Tensor]
    ) -> torch.Tensor:
        """Build candidate tokens (base prediction + one per Medusa head)"""
        batch_size = base_logits.shape[0]
        device = base_logits.device

        # top-k at each position
        all_topk = []

        # Position 0: base model prediction
        _, topk_0 = torch.topk(base_logits, self.top_k, dim=-1)
        all_topk.append(topk_0)

        # Positions 1..N: Medusa head predictions
        for head_logits in medusa_logits:
            _, topk_i = torch.topk(head_logits.squeeze(1), self.top_k, dim=-1)
            all_topk.append(topk_i)

        # Build candidate paths (Cartesian product; could be optimized)
        # Simplified here to the single most likely path
        candidates = torch.stack([t[:, 0] for t in all_topk], dim=1)

        return candidates

    def _tree_verify(
        self,
        context: torch.Tensor,
        candidates: torch.Tensor
    ) -> int:
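        """Verify the candidates and return the accepted length (greedy)"""
        # Score every candidate position in one forward pass
        verify_input = torch.cat([context, candidates], dim=1)
        outputs = self.model(verify_input)
        logits = outputs.logits

        # Check each position: the logits at pos predict the token at pos + 1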
        accept_length = 0
        for i in range(candidates.shape[1]):
            pos = context.shape[1] + i - 1
            predicted = logits[:, pos, :].argmax(dim=-1)
            if (predicted == candidates[:, i]).all():
                accept_length = i + 1
            else:
                break

        return max(1, accept_length)
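`_generate_candidates` above keeps only the single most likely path. The pure-Python sketch below shows the idea that simplification skips: merge the Cartesian product of per-position top-k tokens into a trie (shared prefixes become one node), build the ancestor mask that lets each node attend only to its own path, and greedily walk down the tree to find the longest accepted branch. It assumes greedy verification, as in `_tree_verify`; `target_next` stands in for the target's argmax at each node, which a real system would obtain from one forward pass with the tree mask. This is an illustration, not Medusa's reference code, and in practice the tree is pruned rather than fully expanded.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple


def build_candidate_tree(topk_per_pos: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Merge every path in the Cartesian product of per-position top-k tokens into a trie.

    Returns (tokens, parents) in insertion order. Node 0 is a virtual root
    (token -1, parent -1) standing for the last already-accepted token.
    """
    tokens, parents = [-1], [-1]
    child_of: Dict[Tuple[int, int], int] = {}
    for path in product(*topk_per_pos):
        node = 0
        for tok in path:
            if (node, tok) not in child_of:  # shared prefixes map to one node
                tokens.append(tok)
                parents.append(node)
                child_of[(node, tok)] = len(tokens) - 1
            node = child_of[(node, tok)]
    return tokens, parents


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """mask[i][j] is True iff node j is node i itself or one of its ancestors."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = True
            j = parents[j]
    return mask


def greedy_accept(tokens: List[int], parents: List[int],
                  target_next: Dict[int, int]) -> List[int]:
    """Walk down the tree while a child matches the target's greedy prediction.

    target_next[node] is the target's argmax after the path ending at `node`.
    Returns the matched tokens plus the target's own token after the last match.
    """
    children: Dict[int, List[int]] = {}
    for node, par in enumerate(parents):
        if par != -1:
            children.setdefault(par, []).append(node)

    accepted: List[int] = []
    node = 0
    while True:
        want = target_next[node]
        match = [c for c in children.get(node, []) if tokens[c] == want]
        if not match:
            accepted.append(want)  # the target's prediction is always a valid next token
            return accepted
        node = match[0]
        accepted.append(tokens[node])


# Top-2 tokens at two future positions -> 4 paths, 6 tree nodes + root
tokens, parents = build_candidate_tree([[5, 7], [1, 2]])
mask = tree_attention_mask(parents)
print(tokens, parents)
print(greedy_accept(tokens, parents, {0: 7, 4: 2, 6: 9}))
```

Here the target agrees with the branch 7 → 2 and then predicts 9 itself, so one verification pass yields three tokens: [7, 2, 9].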

Self-Speculative Decoding

"""
Self-Speculative Decoding - the model's own early-exit layers act as the draft
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple

class EarlyExitTransformer(nn.Module):
    """Transformer that supports early exit"""

    def __init__(self, config, exit_layers: Optional[List[int]] = None):
        super().__init__()
        self.config = config
        self.exit_layers = exit_layers or [config.num_hidden_layers // 3]

        # Token embeddings (used at the start of forward)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Regular transformer layers. TransformerBlock is assumed to be defined
        # elsewhere and to return (hidden_states, layer_past)
        self.layers = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ])

        # Early-exit heads
        self.exit_heads = nn.ModuleDict({
            str(layer): nn.Linear(config.hidden_size, config.vocab_size)
            for layer in self.exit_layers
        })

        # Final head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        exit_layer: Optional[int] = None,
        past_key_values: Optional[Tuple] = None
    ):
        hidden_states = self.embed_tokens(input_ids)

        new_past = []

        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i] if past_key_values else None
            hidden_states, new_layer_past = layer(
                hidden_states,
                past_key_values=layer_past
            )
            new_past.append(new_layer_past)

            # Early exit: project this layer's hidden states straight to the vocabulary
            if exit_layer is not None and i == exit_layer:
                logits = self.exit_heads[str(i)](hidden_states)
                return logits, tuple(new_past[:i+1])

        # Normal output after the last layer
        logits = self.lm_head(hidden_states)
        return logits, tuple(new_past)


class SelfSpeculativeDecoder(SpeculativeDecoder):
    """Self-speculative decoder

    Draft = the same model exiting at draft_exit_layer; target = the full model.
    Inherits the rejection-sampling step _accept_reject from SpeculativeDecoder
    (the basic implementation earlier in this chapter), which expects batch size 1.
    """

    def __init__(
        self,
        model: EarlyExitTransformer,
        draft_exit_layer: int,
        num_speculative_tokens: int = 4
    ):
        # SpeculativeDecoder.__init__ is not called: there is no separate draft model
        self.model = model
        self.draft_exit_layer = draft_exit_layer
        self.K = num_speculative_tokens

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int
    ) -> torch.Tensor:
        generated = input_ids.clone()

        while generated.shape[1] < max_length:
            # Draft: fast proposals from the early-exit path
            draft_tokens, draft_probs = self._draft_with_early_exit(generated)

            # Verify: score all drafts with the full model in one pass
            target_probs, _ = self._full_model_verify(generated, draft_tokens)

            # Accept/Reject (rejection sampling inherited from SpeculativeDecoder)
            accepted, num_accepted = self._accept_reject(
                draft_tokens,
                draft_probs,
                target_probs
            )

            # As in the basic version, KV caches are not carried across rounds here
            generated = torch.cat([generated, accepted], dim=1)

        return generated[:, :max_length]

    def _draft_with_early_exit(
        self,
        context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate K draft tokens through the early-exit path"""
        draft_tokens = []
        draft_probs = []
        past_key_values = None  # cache of the shallow layers, valid for this round only

        for i in range(self.K):
            input_ids = context if i == 0 else draft_tokens[-1]

            logits, past_key_values = self.model(
                input_ids,
                exit_layer=self.draft_exit_layer,
                past_key_values=past_key_values
            )

            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            draft_tokens.append(next_token)
            draft_probs.append(probs)

        return (
            torch.cat(draft_tokens, dim=1),   # [B, K]
            torch.stack(draft_probs, dim=1)   # [B, K, V]
        )

    def _full_model_verify(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple]:
        """Verify with the full model"""
        verify_input = torch.cat([context, draft_tokens], dim=1)

        logits, new_past = self.model(
            verify_input,
            exit_layer=None,  # run all layers, no early exit
            past_key_values=None
        )

        # Last K+1 positions: p_target for t1..tK plus the bonus position after tK
        probs = F.softmax(logits[:, -self.K - 1:, :], dim=-1)

        return probs, new_past
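Self-speculation trades draft quality against draft cost: exiting deeper gives drafts that the full model accepts more often, but each draft step costs more. A minimal sketch of that trade-off, assuming i.i.d. acceptance, a draft cost proportional to the number of layers run (exit-head cost ignored), and hypothetical per-layer acceptance rates that are not measurements; the function names are illustrative.

```python
from typing import Dict, Iterable, Tuple


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """(1 - alpha^(K+1)) / ((1 - alpha)(1 + K*c)), assuming i.i.d. acceptance rate alpha."""
    tokens = k + 1.0 if alpha == 1.0 else (1 - alpha ** (k + 1)) / (1 - alpha)
    return tokens / (1 + k * c)


def best_exit_config(acceptance_by_layer: Dict[int, float], num_layers: int,
                     k_values: Iterable[int]) -> Tuple[int, int, float]:
    """Grid-search (exit_layer, K) for the highest expected speedup.

    Assumed cost model: a draft step through layers 0..exit_layer costs
    (exit_layer + 1) / num_layers of a full forward pass.
    """
    k_values = list(k_values)
    best = (-1, -1, 0.0)
    for layer, alpha in acceptance_by_layer.items():
        c = (layer + 1) / num_layers
        for k in k_values:
            s = expected_speedup(alpha, k, c)
            if s > best[2]:
                best = (layer, k, s)
    return best


# Hypothetical acceptance rates for a 32-layer model (not measured values)
acceptance = {3: 0.6, 7: 0.8, 15: 0.9}
layer, k, s = best_exit_config(acceptance, num_layers=32, k_values=range(1, 9))
print(f"best exit layer {layer}, K={k}, expected speedup {s:.2f}x")
```

With these made-up numbers the middle exit (layer 7, K = 3) beats both the cheap shallow exit and the accurate deep one; real acceptance rates would shift the answer.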

Performance Optimization

Batched Speculative Decoding

"""
Batched speculative decoding
"""
import torch
from typing import List

class BatchedSpeculativeDecoder:
    """Speculative decoder with batch support"""

    def __init__(
        self,
        target_model,
        draft_model,
        config: SpeculativeConfig
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config
        self.K = config.num_speculative_tokens

    @torch.no_grad()
    def generate_batch(
        self,
        input_ids: torch.Tensor,  # [B, S]
        max_lengths: List[int]    # maximum length for each sample
    ) -> List[torch.Tensor]:
        """Batched generation"""
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Track per-sample state
        active_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
        generated = [input_ids[i:i+1].clone() for i in range(batch_size)]
        results = [None] * batch_size

        # Per-sample draft and target KV caches
        draft_pasts = [None] * batch_size
        target_pasts = [None] * batch_size

        while active_mask.any():
            # Collect the active samples
            active_indices = active_mask.nonzero().squeeze(-1).tolist()
            if not active_indices:
                break

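            # Batched draft. torch.cat needs equal lengths: once samples accept different
            # numbers of tokens they must be padded (with an attention mask) first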
            draft_inputs = torch.cat([
                generated[i] for i in active_indices
            ], dim=0)

            draft_tokens, draft_probs, new_draft_pasts = self._batched_draft(
                draft_inputs,
                [draft_pasts[i] for i in active_indices]
            )

            # Batched verify
            target_probs, new_target_pasts = self._batched_verify(
                draft_inputs,
                draft_tokens,
                [target_pasts[i] for i in active_indices]
            )

            # Accept/reject per sample
            for j, i in enumerate(active_indices):
                accepted, num_accepted = self._accept_reject_single(
                    draft_tokens[j:j+1],
                    draft_probs[j:j+1],
                    target_probs[j:j+1]
                )

                generated[i] = torch.cat([generated[i], accepted], dim=1)

                # Update caches
                draft_pasts[i] = self._extract_past(new_draft_pasts, j, num_accepted)
                target_pasts[i] = self._extract_past(new_target_pasts, j, num_accepted)

                # Check the stop condition
                if generated[i].shape[1] >= max_lengths[i]:
                    active_mask[i] = False
                    results[i] = generated[i]

        # Collect any remaining samples
        for i in range(batch_size):
            if results[i] is None:
                results[i] = generated[i]

        return results

    def _batched_draft(self, inputs, pasts):
        """Batched draft generation"""
        # Merge the per-sample past_key_values
        # ...
        pass

    def _batched_verify(self, contexts, draft_tokens, pasts):
        """Batched verification"""
        # ...
        pass


class AdaptiveSpeculativeDecoder:
    """Adaptive speculative decoding: adjusts K dynamically"""

    def __init__(
        self,
        target_model,
        draft_model,
        min_k: int = 2,
        max_k: int = 8,
        target_accept_rate: float = 0.8
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.min_k = min_k
        self.max_k = max_k
        self.target_accept_rate = target_accept_rate

        # Current K
        self.current_k = (min_k + max_k) // 2

        # Statistics
        self.accept_rates = []
        self.window_size = 10

    def _update_k(self, accept_rate: float):
        """Adjust K from the recent acceptance rate"""
        self.accept_rates.append(accept_rate)

        if len(self.accept_rates) >= self.window_size:
            avg_rate = sum(self.accept_rates[-self.window_size:]) / self.window_size

            if avg_rate > self.target_accept_rate + 0.1:
                # High acceptance: increase K
                self.current_k = min(self.current_k + 1, self.max_k)
            elif avg_rate < self.target_accept_rate - 0.1:
                # Low acceptance: decrease K
                self.current_k = max(self.current_k - 1, self.min_k)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
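        """Adaptive generation"""
        generated = input_ids.clone()

        while generated.shape[1] < max_length:
            # Use the current K
            K = self.current_k

            # Draft / Verify / Accept-Reject: helpers analogous to SpeculativeDecoder's,
            # but taking K per round (not shown here)
            draft_tokens, draft_probs = self._draft(generated, K)
            target_probs = self._verify(generated, draft_tokens)
            accepted, num_accepted = self._accept_reject(
                draft_tokens,
                draft_probs,
                target_probs
            )

            generated = torch.cat([generated, accepted], dim=1)

            # Update K. As in SpeculativeDecoder._accept_reject, num_accepted also counts
            # the one token sampled from the target, so accepted drafts = num_accepted - 1
            accept_rate = (num_accepted - 1) / K
            self._update_k(accept_rate)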

        return generated
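Why adapting K pays off: under the same i.i.d. acceptance model used in the speedup analysis, the K that maximizes expected speedup grows with the acceptance rate. A controller that raises K when acceptance is high and lowers it when acceptance drops (as `_update_k` does) is therefore chasing a moving optimum. The sketch assumes a hypothetical draft-to-target cost ratio c = 0.05; the names are illustrative.

```python
from typing import Iterable


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """(1 - alpha^(K+1)) / ((1 - alpha)(1 + K*c)), assuming i.i.d. acceptance rate alpha."""
    tokens = k + 1.0 if alpha == 1.0 else (1 - alpha ** (k + 1)) / (1 - alpha)
    return tokens / (1 + k * c)


def best_k(alpha: float, c: float, k_values: Iterable[int]) -> int:
    """K with the highest expected speedup for a given acceptance rate."""
    return max(k_values, key=lambda k: expected_speedup(alpha, k, c))


C = 0.05  # hypothetical draft_latency / target_latency
for alpha in (0.6, 0.7, 0.8, 0.9):
    k = best_k(alpha, C, range(1, 11))
    print(f"alpha={alpha:.1f}: best K={k}, speedup {expected_speedup(alpha, k, C):.2f}x")
```

For c = 0.05 the best K is 4 at α = 0.6 and 8 at α = 0.8, which is why a fixed K tuned for one workload can leave speed on the table for another.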

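Summary

This chapter covered Speculative Decoding in depth:

  1. Core idea: a small model drafts, the large model verifies
  2. Acceptance algorithm: rejection sampling keeps the output distribution identical to the target model's
  3. Advanced techniques: Tree Attention (Medusa), Self-Speculative decoding
  4. Performance optimizations: batching, adaptive K

The next chapter covers multimodal inference: how images, audio, video, and other multimodal content are processed.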