Speculative Decoding 推测解码
概述
Speculative Decoding(推测解码)是一种通过小模型预测、大模型验证来加速自回归生成的技术。本章深入讲解推测解码的原理、实现策略和工程优化。
原理与动机
自回归解码的瓶颈
┌─────────────────────────────────────────────────────────────────────────────┐
│ Autoregressive Decoding Bottleneck │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 传统自回归解码: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Token 1 ──▶ LLM ──▶ Token 2 ──▶ LLM ──▶ Token 3 ──▶ ... │ │
│ │ │ │ │ │ │ │ │
│ │ └────────┬┘ └────────┬┘ └────────┬ │ │
│ │ 延迟1 延迟2 延迟3 │ │
│ │ │ │
│ │ 总延迟 = N × 单次推理延迟 │ │
│ │ GPU利用率: 低 (Memory Bound) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ 问题分析: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ • 每次只生成 1 个 token │ │
│ │ • 串行依赖无法并行 │ │
│ │ • GPU 计算单元利用率很低(batch=1 解码时常远低于 10%) │ │
│ │ • 主要瓶颈:内存带宽(读取 KV Cache 和模型权重) │ │
│ │ │ │
│ │ 理论计算量: 2 × params × tokens │ │
│ │ 实际瓶颈: 内存读取 params + KV cache │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Speculative Decoding 原理
┌─────────────────────────────────────────────────────────────────────────────┐
│ Speculative Decoding Pipeline │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Phase 1: Draft (小模型推测) │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Context ──▶ Draft Model ──▶ [t1, t2, t3, t4, t5] (K个候选token) │ │
│ │ (快速) │ │
│ │ │ │
│ │ 时间: ~K × draft_latency (draft 模型足够小时远小于 target_latency,但不可完全忽略) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ Phase 2: Verify (大模型验证) │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ [Context, t1, t2, t3, t4, t5] ──▶ Target Model ──▶ Verify │ │
│ │ (1次推理) │ │
│ │ │ │
│ │ 一次前向并行验证 K 个 token,并额外得到第 K+1 个位置的分布 │ │
│ │ 时间: ~1 × target_latency │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ Phase 3: Accept/Reject │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 验证结果: t1 ✓, t2 ✓, t3 ✓, t4 ✗ (第4个开始错误) │ │
│ │ │ │
│ │ 接受: t1, t2, t3 + 位置4从修正分布 norm(max(0, p_target − p_draft)) 重新采样的 token │ │
│ │ → 本轮生成 4 个 token │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ 加速比分析: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
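│ │ 假设: K=5, 每个 draft token 的接受率 α=0.8 (按独立同分布近似) │ │
│ │ │ │
│ │ 传统方法: 每生成 1 个 token 需要 1 × target_latency │ │
│ │ 推测解码: 每轮 1 × target_latency + K × draft_latency, │ │
│ │ 期望生成 (1 − α^(K+1)) / (1 − α) = (1 − 0.8^6) / 0.2 ≈ 3.69 个 token │ │
│ │ (约 2.69 个被接受的 draft token + 1 个大模型修正/补充 token) │ │
│ │ │ │
│ │ 加速比 ≈ (1 − α^(K+1)) / ((1 − α)(1 + K·c)), c = draft_latency / target_latency │ │
│ │ 例: c = 0.05 时 ≈ 3.69 / 1.25 ≈ 2.9x │ │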
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
核心算法实现
基础 Speculative Decoding
"""
Speculative Decoding 基础实现
"""
import torch
import torch.nn.functional as F
from typing import Tuple, List, Optional
from dataclasses import dataclass
@dataclass
class SpeculativeConfig:
    """Hyper-parameters for speculative decoding.

    Attributes:
        num_speculative_tokens: K, the number of draft tokens proposed per round.
        temperature: divisor applied to logits before softmax (1.0 = unchanged).
        top_p: nucleus-sampling threshold; values below 1.0 enable top-p filtering.
        use_tree_attention: flag for tree-structured verification; the decoders
            shown in this chapter do not read it.
    """

    num_speculative_tokens: int = 5  # K: draft tokens per round
    temperature: float = 1.0  # logits are divided by this value
    top_p: float = 0.9  # cumulative-probability cut-off for nucleus sampling
    use_tree_attention: bool = False  # reserved switch, currently unused
class SpeculativeDecoder:
    """Speculative decoder: a small draft model proposes, a large target model verifies.

    Each round the draft model samples ``K`` tokens autoregressively. The target
    model scores all of them (plus one bonus position) in a single forward pass.
    Rejection sampling then decides how many to keep, so the output follows the
    target model's temperature/top-p adjusted distribution.

    Both models are called as ``model(input_ids=..., past_key_values=..., use_cache=True)``
    and must return an object with ``.logits`` (``[B, S, V]``) and ``.past_key_values``.
    The KV cache is either a legacy tuple of per-layer ``(key, value)`` tensors laid
    out ``[B, H, S, D]``, or an object exposing ``get_seq_length()`` / ``crop()``
    (e.g. Hugging Face ``DynamicCache``).
    """

    def __init__(
        self,
        target_model,  # large (target) model
        draft_model,  # small (draft) model
        config: "SpeculativeConfig",
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config
        self.K = config.num_speculative_tokens

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Generate until ``max_length`` total tokens or until every row emitted EOS.

        Args:
            input_ids: prompt token ids, ``[B, S]``.
            max_length: cap on the total length (prompt included); output is trimmed to it.
            eos_token_id: optional stop token; generation stops once every row produced it.

        Returns:
            Tensor ``[B, L]``: the prompt followed by the generated tokens.
        """
        generated = input_ids.clone()
        target_past = None
        draft_past = None
        finished = torch.zeros(generated.shape[0], dtype=torch.bool, device=generated.device)

        while generated.shape[1] < max_length:
            prefix_len = generated.shape[1]

            # Phase 1: draft - the small model proposes K candidate tokens.
            draft_tokens, draft_probs, draft_past = self._draft(generated, draft_past)
            # Phase 2: verify - the large model scores all candidates in one pass.
            target_probs, target_past = self._verify(generated, draft_tokens, target_past)
            # Phase 3: accept/reject via rejection sampling.
            accepted_tokens, num_accepted = self._accept_reject(
                draft_tokens, draft_probs, target_probs
            )
            generated = torch.cat([generated, accepted_tokens], dim=1)

            # Both caches may contain rejected draft tokens: keep only the verified
            # prefix. The last emitted token (resampled or bonus) has not been fed to
            # either model yet; it becomes their first input in the next round.
            keep = prefix_len + num_accepted - 1
            draft_past = self._truncate_cache(draft_past, keep)
            target_past = self._truncate_cache(target_past, keep)

            if eos_token_id is not None:
                # EOS may appear anywhere in the accepted chunk, not just at its end.
                finished |= (accepted_tokens == eos_token_id).any(dim=1)
                if finished.all():
                    break

        # A round can append up to K + 1 tokens; never overshoot max_length
        # (but never cut into the prompt either).
        return generated[:, : max(max_length, input_ids.shape[1])]

    @staticmethod
    def _cache_length(past_key_values) -> int:
        """Number of positions stored in a KV cache (0 for ``None``)."""
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            return int(past_key_values.get_seq_length())
        return past_key_values[0][0].shape[2]

    def _draft(
        self,
        context: torch.Tensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        """Sample ``K`` draft tokens with the draft model.

        Returns:
            ``(draft_tokens [B, K], draft_probs [B, K, V], past_key_values)``.
        """
        draft_tokens = []
        draft_probs = []
        # Only feed the part of the context that is not already cached.
        input_ids = context[:, self._cache_length(past_key_values):]
        for _ in range(self.K):
            outputs = self.draft_model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            probs = self._apply_sampling(outputs.logits[:, -1, :])  # [B, V]
            next_token = torch.multinomial(probs, num_samples=1)  # [B, 1]
            draft_tokens.append(next_token)
            draft_probs.append(probs)
            input_ids = next_token

        return torch.cat(draft_tokens, dim=1), torch.stack(draft_probs, dim=1), past_key_values

    def _verify(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.Tensor, Tuple]:
        """Score all draft tokens with the target model in a single forward pass.

        Returns:
            ``(target_probs [B, K+1, V], past_key_values)``. Index ``i < K`` is the target
            distribution for draft token ``i``; index ``K`` is for the bonus token.
        """
        uncached = context[:, self._cache_length(past_key_values):]
        verify_input = torch.cat([uncached, draft_tokens], dim=1)
        outputs = self.target_model(
            input_ids=verify_input,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # Position -(K+1) sits on the last context token and predicts draft token 0.
        logits = outputs.logits[:, -(self.K + 1):, :]
        return self._apply_sampling(logits), outputs.past_key_values

    def _accept_reject(
        self,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        """Rejection sampling that preserves the target distribution.

        - If p_target(x) >= p_draft(x): always accept.
        - Otherwise accept with probability p_target(x) / p_draft(x).
        On the first position where any row rejects, rejected rows resample from
        norm(max(0, p_target - p_draft)) and the round stops. If all K are accepted,
        one bonus token is sampled from the target's last distribution.

        Returns:
            ``(accepted_tokens [B, n], n)`` with ``1 <= n <= K + 1``.
        """
        batch_size = draft_tokens.shape[0]
        device = draft_tokens.device
        accepted = []

        for i in range(self.K):
            draft_token = draft_tokens[:, i]  # [B]
            index = draft_token.unsqueeze(1)
            draft_p = draft_probs[:, i].gather(1, index).squeeze(1)  # [B]
            target_p = target_probs[:, i].gather(1, index).squeeze(1)  # [B]

            accept_prob = torch.clamp(target_p / (draft_p + 1e-10), max=1.0)
            accept_mask = torch.rand(batch_size, device=device) < accept_prob

            if accept_mask.all():
                accepted.append(draft_token)
                continue

            residual = F.relu(target_probs[:, i] - draft_probs[:, i])
            mass = residual.sum(dim=-1, keepdim=True)
            # Rows that accepted may have zero residual mass (p == q); multinomial
            # would raise on them, so fall back to the target distribution there.
            residual = torch.where(mass > 0, residual / mass.clamp(min=1e-10), target_probs[:, i])
            resampled = torch.multinomial(residual, num_samples=1).squeeze(1)
            accepted.append(torch.where(accept_mask, draft_token, resampled))
            break
        else:
            # Every draft token was accepted: take the free bonus token.
            accepted.append(torch.multinomial(target_probs[:, -1], num_samples=1).squeeze(1))

        return torch.stack(accepted, dim=1), len(accepted)

    def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert logits ``[..., V]`` to probabilities using temperature and top-p.

        The input tensor is never modified in place.
        """
        if self.config.temperature != 1.0:
            logits = logits / self.config.temperature

        if self.config.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > self.config.top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
            # masked_fill returns a new tensor instead of writing into the model output.
            logits = logits.masked_fill(to_remove, float("-inf"))

        return F.softmax(logits, dim=-1)

    def _truncate_cache(self, past_key_values, keep_length: int):
        """Keep only the first ``keep_length`` positions of a KV cache (no-op if shorter)."""
        if past_key_values is None:
            return None
        if hasattr(past_key_values, "crop"):
            past_key_values.crop(keep_length)
            return past_key_values
        return tuple(
            (key[:, :, :keep_length, :], value[:, :, :keep_length, :])
            for key, value in past_key_values
        )
Tree Attention (Medusa)
"""
Tree Attention 实现 (Medusa 风格)
"""
import torch
import torch.nn as nn
from typing import List, Tuple
class TreeAttentionMask:
    """Builds boolean attention masks for tree-structured speculative candidates."""

    def __init__(self, num_heads: int, max_depth: int):
        self.num_heads = num_heads
        self.max_depth = max_depth

    def create_tree_mask(
        self,
        tree_structure: List[List[int]]
    ) -> torch.Tensor:
        r"""Create a tree attention mask.

        ``tree_structure[d][j]`` is the number of children of the ``j``-th node at
        depth ``d - 1``. Nodes are numbered breadth-first. Depth 0 hangs off a
        virtual parent, so ``tree_structure[0]`` is a one-element list giving the
        number of roots. Example ``[[1], [2], [2, 2]]``::

                 root
                /    \
               a      b
              / \    / \
             c   d  e   f

        Returns:
            Bool tensor ``[N, N]`` where ``mask[i, j]`` is True iff node ``j`` is
            node ``i`` itself or one of its ancestors.

        Raises:
            ValueError: if a layer lists a different number of branch counts than
                there are nodes at the previous depth.
        """
        # parents[n] = index of node n's parent (None for roots).
        parents: List[Optional[int]] = []
        previous_layer: List[Optional[int]] = [None]  # the virtual parent of the roots

        for depth, layer in enumerate(tree_structure):
            if len(layer) != len(previous_layer):
                raise ValueError(
                    f"tree_structure[{depth}] has {len(layer)} branch counts but depth "
                    f"{depth - 1} has {len(previous_layer)} node(s)"
                )
            current_layer: List[Optional[int]] = []
            for parent, num_children in zip(previous_layer, layer):
                for _ in range(num_children):
                    current_layer.append(len(parents))
                    parents.append(parent)
            previous_layer = current_layer

        total_nodes = len(parents)
        mask = torch.zeros(total_nodes, total_nodes, dtype=torch.bool)
        # Each node attends to itself and every ancestor, found by walking real parent links.
        for node in range(total_nodes):
            ancestor: Optional[int] = node
            while ancestor is not None:
                mask[node, ancestor] = True
                ancestor = parents[ancestor]
        return mask
class MedusaHead(nn.Module):
    """A stack of Medusa heads, one per future position.

    Each head is ``Linear(hidden, hidden) -> SiLU -> Linear(hidden, vocab)``. All
    heads read the same hidden states and each produces logits for a different
    future offset.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4  # number of future positions to predict
    ):
        super().__init__()
        self.num_heads = num_heads

        def build_head() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, vocab_size),
            )

        head_list = nn.ModuleList()
        for _ in range(num_heads):
            head_list.append(build_head())
        self.heads = head_list

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Apply every head to ``hidden_states`` ([B, S, H]).

        Returns:
            One logits tensor per head, in head order.
        """
        per_head_logits = []
        for head in self.heads:
            per_head_logits.append(head(hidden_states))
        return per_head_logits
class MedusaDecoder:
    """Greedy Medusa decoder (simplified to a single candidate chain).

    The base LM head proposes the next token and each Medusa head proposes one
    further token. One forward pass over ``context + candidates`` then keeps the
    longest prefix that matches the model's own greedy predictions.

    ``model`` must support ``model(ids, output_hidden_states=True)`` returning
    ``.logits`` and ``.hidden_states``, and must expose ``model.medusa_heads``,
    which maps hidden states ``[B, 1, H]`` to a list of ``[B, 1, V]`` logits.
    """

    def __init__(
        self,
        model,  # model with attached Medusa heads
        num_heads: int = 4,
        top_k: int = 10  # branching factor reserved for full tree expansion
    ):
        self.model = model
        self.num_heads = num_heads
        self.top_k = top_k
        # Candidate-tree indices (placeholder for a pruned top-k tree).
        self.tree_indices = self._build_tree_indices()

    def _build_tree_indices(self) -> torch.Tensor:
        """Enumerate the ``top_k ** num_heads`` paths of a full (unpruned) candidate tree."""
        # Real implementations prune this tree heavily.
        return torch.arange(self.top_k ** self.num_heads)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int
    ) -> torch.Tensor:
        """Greedily generate up to ``max_length`` total tokens (prompt included)."""
        generated = input_ids.clone()
        while generated.shape[1] < max_length:
            # Step 1: base logits plus Medusa predictions from the last hidden state.
            outputs = self.model(generated, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1][:, -1:]  # [B, 1, H]
            base_logits = outputs.logits[:, -1, :]  # [B, V]
            medusa_logits = self.model.medusa_heads(hidden_states)  # list of [B, 1, V]

            # Step 2: build the candidate chain.
            candidates = self._generate_candidates(base_logits, medusa_logits)
            # Step 3: verify all candidates in one pass.
            accept_length = self._tree_verify(generated, candidates)
            # Step 4: keep the longest verified prefix.
            generated = torch.cat([generated, candidates[:, :accept_length]], dim=1)

        # A round may append up to num_heads + 1 tokens; do not overshoot max_length.
        return generated[:, : max(max_length, input_ids.shape[1])]

    def _generate_candidates(
        self,
        base_logits: torch.Tensor,
        medusa_logits: List[torch.Tensor]
    ) -> torch.Tensor:
        """Return the most likely candidate chain, shape ``[B, 1 + len(medusa_logits)]``.

        Column 0 comes from the base model and column ``k + 1`` from Medusa head ``k``.
        The greedy (argmax) choice is the same rule ``_tree_verify`` checks against.
        A full implementation would expand the cartesian product of per-head top-k.
        """
        columns = [base_logits.argmax(dim=-1)]
        columns.extend(head_logits.squeeze(1).argmax(dim=-1) for head_logits in medusa_logits)
        return torch.stack(columns, dim=1)

    def _tree_verify(
        self,
        context: torch.Tensor,
        candidates: torch.Tensor
    ) -> int:
        """Return how many leading candidates match the model's greedy predictions (>= 1)."""
        verify_input = torch.cat([context, candidates], dim=1)
        logits = self.model(verify_input).logits
        # The logit at position context_len - 1 + i predicts candidate i.
        start = context.shape[1] - 1
        predicted = logits[:, start:start + candidates.shape[1], :].argmax(dim=-1)  # [B, N]
        matches = (predicted == candidates).all(dim=0).tolist()  # all rows must agree

        accept_length = 0
        for ok in matches:
            if not ok:
                break
            accept_length += 1
        # Always make progress: candidate 0 is the base model's own greedy token.
        return max(1, accept_length)
Self-Speculative Decoding
"""
Self-Speculative Decoding - 使用 Early Exit 作为 Draft
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
class EarlyExitTransformer(nn.Module):
    """Transformer with auxiliary LM heads after selected layers (early exit).

    ``TransformerBlock`` is defined elsewhere. Each block is called as
    ``layer(hidden_states, past_key_values=layer_past)`` and is expected to return
    ``(hidden_states, new_layer_past)``.
    """

    def __init__(self, config, exit_layers: Optional[List[int]] = None):
        super().__init__()
        self.config = config
        self.exit_layers = exit_layers or [config.num_hidden_layers // 3]

        # Token embedding. forward() reads it, but the original never created it.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Regular transformer layers.
        self.layers = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ])

        # One early-exit LM head per configured exit layer, keyed by str(layer index).
        self.exit_heads = nn.ModuleDict({
            str(layer): nn.Linear(config.hidden_size, config.vocab_size)
            for layer in self.exit_layers
        })

        # Final LM head used when no early exit is requested.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        exit_layer: Optional[int] = None,
        past_key_values: Optional[Tuple] = None
    ):
        """Run the model, optionally stopping after layer ``exit_layer``.

        Returns:
            ``(logits, past)``, where ``past`` holds one entry per layer actually
            executed (layers ``0..exit_layer`` on early exit, all layers otherwise).

        Raises:
            ValueError: if ``exit_layer`` has no exit head.
        """
        if exit_layer is not None and str(exit_layer) not in self.exit_heads:
            raise ValueError(
                f"exit_layer={exit_layer} has no exit head; configured exit layers: {self.exit_layers}"
            )

        hidden_states = self.embed_tokens(input_ids)
        new_past = []
        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i] if past_key_values else None
            hidden_states, new_layer_past = layer(
                hidden_states,
                past_key_values=layer_past
            )
            new_past.append(new_layer_past)

            # Early exit: skip the remaining layers and use this layer's head.
            if exit_layer is not None and i == exit_layer:
                logits = self.exit_heads[str(i)](hidden_states)
                return logits, tuple(new_past)

        logits = self.lm_head(hidden_states)
        return logits, tuple(new_past)
class SelfSpeculativeDecoder:
    """Self-speculative decoder: the model's early-exit head drafts, the full model verifies.

    ``model`` is called as ``model(input_ids, exit_layer=..., past_key_values=...)`` and
    returns ``(logits, past_key_values)``, like ``EarlyExitTransformer.forward``. Draft
    and full-model passes keep separate caches. Caches are assumed to be tuples of
    per-layer ``(key, value)`` tensors laid out ``[B, H, S, D]``
    (TODO confirm against ``TransformerBlock``).
    """

    def __init__(
        self,
        model: "EarlyExitTransformer",
        draft_exit_layer: int,
        num_speculative_tokens: int = 4
    ):
        self.model = model
        self.draft_exit_layer = draft_exit_layer
        self.K = num_speculative_tokens

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int
    ) -> torch.Tensor:
        """Generate up to ``max_length`` total tokens (prompt included)."""
        generated = input_ids.clone()
        past_key_values = None  # full-model cache
        draft_past = None  # early-exit cache (layers 0..draft_exit_layer)

        while generated.shape[1] < max_length:
            prefix_len = generated.shape[1]

            # Draft: fast generation through the early-exit head.
            draft_tokens, draft_probs, draft_past = self._draft_with_early_exit(
                generated, draft_past
            )
            # Verify: the full model scores all drafts in a single pass.
            target_probs, past_key_values = self._full_model_verify(
                generated, draft_tokens, past_key_values
            )
            accepted, num_accepted = self._speculative_accept(
                draft_tokens, draft_probs, target_probs
            )
            generated = torch.cat([generated, accepted], dim=1)

            # Drop rejected drafts from both caches. The last emitted token has not
            # been fed to either pass yet and is fed at the start of the next round.
            keep = prefix_len + num_accepted - 1
            draft_past = self._truncate_cache(draft_past, keep)
            past_key_values = self._truncate_cache(past_key_values, keep)

        return generated[:, : max(max_length, input_ids.shape[1])]

    @staticmethod
    def _cache_length(past_key_values: Optional[Tuple]) -> int:
        """Number of cached positions (0 for ``None``)."""
        if past_key_values is None:
            return 0
        return past_key_values[0][0].shape[2]

    def _draft_with_early_exit(
        self,
        context: torch.Tensor,
        past_key_values: Optional[Tuple]
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        """Sample ``K`` draft tokens via early exit.

        Returns:
            ``(draft_tokens [B, K], draft_probs [B, K, V], past)``.
        """
        draft_tokens = []
        draft_probs = []
        # Feed only the uncached suffix of the context.
        input_ids = context[:, self._cache_length(past_key_values):]
        for _ in range(self.K):
            logits, past_key_values = self.model(
                input_ids,
                exit_layer=self.draft_exit_layer,
                past_key_values=past_key_values
            )
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(next_token)
            draft_probs.append(probs)
            input_ids = next_token

        return (
            torch.cat(draft_tokens, dim=1),
            torch.stack(draft_probs, dim=1),
            past_key_values
        )

    def _full_model_verify(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor,
        past_key_values: Optional[Tuple]
    ) -> Tuple[torch.Tensor, Tuple]:
        """Score drafts with the full model.

        Returns:
            ``(probs [B, K+1, V], past)``. Index ``i < K`` scores draft ``i``;
            index ``K`` is the bonus position.
        """
        uncached = context[:, self._cache_length(past_key_values):]
        verify_input = torch.cat([uncached, draft_tokens], dim=1)
        logits, new_past = self.model(
            verify_input,
            exit_layer=None,  # no early exit
            past_key_values=past_key_values
        )
        probs = F.softmax(logits[:, -(self.K + 1):, :], dim=-1)
        return probs, new_past

    def _speculative_accept(
        self,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor
    ) -> Tuple[torch.Tensor, int]:
        """Rejection sampling; returns ``(tokens [B, n], n)`` with ``1 <= n <= K + 1``."""
        batch_size = draft_tokens.shape[0]
        device = draft_tokens.device
        accepted = []
        for i in range(self.K):
            token = draft_tokens[:, i]
            index = token.unsqueeze(1)
            draft_p = draft_probs[:, i].gather(1, index).squeeze(1)
            target_p = target_probs[:, i].gather(1, index).squeeze(1)
            accept_prob = torch.clamp(target_p / (draft_p + 1e-10), max=1.0)
            accept_mask = torch.rand(batch_size, device=device) < accept_prob
            if accept_mask.all():
                accepted.append(token)
                continue
            residual = F.relu(target_probs[:, i] - draft_probs[:, i])
            mass = residual.sum(dim=-1, keepdim=True)
            # Zero-mass rows (p == q) can only be accepted rows; avoid multinomial errors.
            residual = torch.where(mass > 0, residual / mass.clamp(min=1e-10), target_probs[:, i])
            resampled = torch.multinomial(residual, num_samples=1).squeeze(1)
            accepted.append(torch.where(accept_mask, token, resampled))
            break
        else:
            accepted.append(torch.multinomial(target_probs[:, -1], num_samples=1).squeeze(1))
        return torch.stack(accepted, dim=1), len(accepted)

    def _truncate_cache(self, past_key_values: Optional[Tuple], keep_length: int) -> Optional[Tuple]:
        """Keep the first ``keep_length`` positions of every layer's ``(key, value)``."""
        if past_key_values is None:
            return None
        return tuple(
            (key[:, :, :keep_length, :], value[:, :, :keep_length, :])
            for key, value in past_key_values
        )
性能优化
批量推测解码
"""
批量推测解码优化
"""
import torch
from typing import List
class BatchedSpeculativeDecoder:
    """Speculative decoding for a batch whose rows advance at different speeds.

    Rows accept different numbers of tokens per round, so sequences become ragged.
    Per-row tokens and caches are therefore kept in Python lists.
    ``_batched_draft`` / ``_batched_verify`` are the extension points that must pad
    the ragged rows into one forward pass; they are intentionally left unimplemented.
    """

    def __init__(
        self,
        target_model,
        draft_model,
        config: "SpeculativeConfig"
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config
        self.K = config.num_speculative_tokens

    @torch.no_grad()
    def generate_batch(
        self,
        input_ids: torch.Tensor,  # [B, S]
        max_lengths: List[int]  # per-row maximum total length
    ) -> List[torch.Tensor]:
        """Generate every row up to its own ``max_lengths[i]``.

        Returns:
            One ``[1, L_i]`` tensor per row, trimmed to ``max_lengths[i]``
            (never shorter than the prompt).

        Raises:
            ValueError: if ``len(max_lengths)`` differs from the batch size.
        """
        batch_size = input_ids.shape[0]
        if len(max_lengths) != batch_size:
            raise ValueError(
                f"max_lengths has {len(max_lengths)} entries but the batch has {batch_size} rows"
            )

        generated = [input_ids[i:i + 1].clone() for i in range(batch_size)]
        draft_pasts = [None] * batch_size
        target_pasts = [None] * batch_size
        # Rows already at their limit never enter the loop.
        active = [i for i in range(batch_size) if generated[i].shape[1] < max_lengths[i]]

        while active:
            # Rows have different lengths, so they are passed as a list; the batched
            # helpers are responsible for padding.
            contexts = [generated[i] for i in active]
            draft_tokens, draft_probs, new_draft_pasts = self._batched_draft(
                contexts,
                [draft_pasts[i] for i in active]
            )
            target_probs, new_target_pasts = self._batched_verify(
                contexts,
                draft_tokens,
                [target_pasts[i] for i in active]
            )

            still_active = []
            for j, i in enumerate(active):
                prefix_len = generated[i].shape[1]
                accepted, num_accepted = self._accept_reject_single(
                    draft_tokens[j:j + 1],
                    draft_probs[j:j + 1],
                    target_probs[j:j + 1]
                )
                generated[i] = torch.cat([generated[i], accepted], dim=1)

                # Keep only the verified prefix in this row's caches.
                keep = prefix_len + num_accepted - 1
                draft_pasts[i] = self._extract_past(new_draft_pasts, j, keep)
                target_pasts[i] = self._extract_past(new_target_pasts, j, keep)

                if generated[i].shape[1] < max_lengths[i]:
                    still_active.append(i)
            active = still_active

        return [
            seq[:, : max(limit, input_ids.shape[1])]
            for seq, limit in zip(generated, max_lengths)
        ]

    def _batched_draft(self, contexts, pasts):
        """Draft ``K`` tokens for every row in one padded forward pass.

        Expected to return ``(draft_tokens [B, K], draft_probs [B, K, V], batched_past)``.
        """
        raise NotImplementedError(
            "_batched_draft must pad the ragged contexts and run the draft model"
        )

    def _batched_verify(self, contexts, draft_tokens, pasts):
        """Score every row's drafts in one padded forward pass.

        Expected to return ``(target_probs [B, K+1, V], batched_past)``.
        """
        raise NotImplementedError(
            "_batched_verify must pad the ragged contexts and run the target model"
        )

    def _accept_reject_single(self, draft_tokens, draft_probs, target_probs):
        """Rejection sampling for one row.

        Args:
            draft_tokens: ``[1, K]``; draft_probs: ``[1, K, V]``; target_probs: ``[1, K+1, V]``.

        Returns:
            ``(tokens [1, n], n)`` with ``1 <= n <= K + 1``.
        """
        device = draft_tokens.device
        accepted = []
        for i in range(draft_tokens.shape[1]):
            token = int(draft_tokens[0, i])
            p_draft = draft_probs[0, i, token]
            p_target = target_probs[0, i, token]
            accept_prob = torch.clamp(p_target / (p_draft + 1e-10), max=1.0)
            if torch.rand((), device=device) < accept_prob:
                accepted.append(token)
                continue
            residual = torch.clamp(target_probs[0, i] - draft_probs[0, i], min=0)
            total = residual.sum()
            dist = residual / total if total > 0 else target_probs[0, i]
            accepted.append(int(torch.multinomial(dist, 1)))
            break
        else:
            # All drafts accepted: sample the bonus token.
            accepted.append(int(torch.multinomial(target_probs[0, -1], 1)))
        tokens = torch.tensor([accepted], dtype=draft_tokens.dtype, device=device)
        return tokens, len(accepted)

    def _extract_past(self, batched_past, row: int, keep_length: int):
        """Slice one row out of a batched tuple-of-(key, value) cache and keep ``keep_length`` positions.

        Assumes ``[B, H, S, D]`` tensors with right padding, so that positions
        ``[0, keep_length)`` of row ``row`` belong to that sample (TODO confirm once
        the batched helpers are implemented).
        """
        if batched_past is None:
            return None
        return tuple(
            (key[row:row + 1, :, :keep_length], value[row:row + 1, :, :keep_length])
            for key, value in batched_past
        )
class AdaptiveSpeculativeDecoder:
    """Speculative decoder that adapts K (draft length) to the observed acceptance rate.

    Models are called as ``model(input_ids=...)`` and must return an object with
    ``.logits`` of shape ``[B, S, V]``. For clarity this class recomputes the full
    prefix on every call (no KV cache), which makes a varying K trivial to support.
    """

    def __init__(
        self,
        target_model,
        draft_model,
        min_k: int = 2,
        max_k: int = 8,
        target_accept_rate: float = 0.8
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.min_k = min_k
        self.max_k = max_k
        self.target_accept_rate = target_accept_rate

        # K used for the next round.
        self.current_k = (min_k + max_k) // 2

        # Recent per-round acceptance rates (bounded to the last window_size entries).
        self.accept_rates = []
        self.window_size = 10

    def _update_k(self, accept_rate: float):
        """Record ``accept_rate`` and nudge K by one once a full window is available."""
        self.accept_rates.append(accept_rate)
        # Only the latest window is ever read; drop older entries so the list stays bounded.
        del self.accept_rates[:-self.window_size]

        if len(self.accept_rates) >= self.window_size:
            avg_rate = sum(self.accept_rates) / len(self.accept_rates)
            if avg_rate > self.target_accept_rate + 0.1:
                # High acceptance: speculate further ahead.
                self.current_k = min(self.current_k + 1, self.max_k)
            elif avg_rate < self.target_accept_rate - 0.1:
                # Low acceptance: waste less work on rejected drafts.
                self.current_k = max(self.current_k - 1, self.min_k)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
        """Generate up to ``max_length`` total tokens, adapting K after every round."""
        generated = input_ids.clone()
        while generated.shape[1] < max_length:
            K = self.current_k
            draft_tokens, draft_probs = self._draft(generated, K)
            target_probs = self._verify(generated, draft_tokens)
            accepted, num_accepted = self._accept_reject(draft_tokens, draft_probs, target_probs)
            generated = torch.cat([generated, accepted], dim=1)

            # num_accepted includes the resampled/bonus token, which is not a draft acceptance.
            self._update_k((num_accepted - 1) / K)

        return generated[:, : max(max_length, input_ids.shape[1])]

    def _draft(self, context: torch.Tensor, K: int):
        """Sample ``K`` draft tokens; returns ``(tokens [B, K], probs [B, K, V])``."""
        sequence = context
        draft_tokens, draft_probs = [], []
        for _ in range(K):
            logits = self.draft_model(input_ids=sequence).logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(next_token)
            draft_probs.append(probs)
            sequence = torch.cat([sequence, next_token], dim=1)
        return torch.cat(draft_tokens, dim=1), torch.stack(draft_probs, dim=1)

    def _verify(self, context: torch.Tensor, draft_tokens: torch.Tensor) -> torch.Tensor:
        """Return target probabilities ``[B, K+1, V]`` for each draft position plus the bonus."""
        K = draft_tokens.shape[1]
        logits = self.target_model(input_ids=torch.cat([context, draft_tokens], dim=1)).logits
        return torch.softmax(logits[:, -(K + 1):, :], dim=-1)

    def _accept_reject(self, draft_tokens, draft_probs, target_probs):
        """Rejection sampling; returns ``(tokens [B, n], n)`` with ``1 <= n <= K + 1``."""
        batch_size, K = draft_tokens.shape
        device = draft_tokens.device
        accepted = []
        for i in range(K):
            token = draft_tokens[:, i]
            index = token.unsqueeze(1)
            draft_p = draft_probs[:, i].gather(1, index).squeeze(1)
            target_p = target_probs[:, i].gather(1, index).squeeze(1)
            accept_prob = torch.clamp(target_p / (draft_p + 1e-10), max=1.0)
            accept_mask = torch.rand(batch_size, device=device) < accept_prob
            if accept_mask.all():
                accepted.append(token)
                continue
            residual = torch.clamp(target_probs[:, i] - draft_probs[:, i], min=0)
            mass = residual.sum(dim=-1, keepdim=True)
            # Zero-mass rows (p == q) can only be accepted rows; avoid multinomial errors.
            residual = torch.where(mass > 0, residual / mass.clamp(min=1e-10), target_probs[:, i])
            resampled = torch.multinomial(residual, num_samples=1).squeeze(1)
            accepted.append(torch.where(accept_mask, token, resampled))
            break
        else:
            accepted.append(torch.multinomial(target_probs[:, -1], num_samples=1).squeeze(1))
        return torch.stack(accepted, dim=1), len(accepted)
小结
本章深入讲解了 Speculative Decoding:
- 核心原理:小模型推测 + 大模型验证
- 接受算法:Rejection Sampling 保证分布一致性
- 高级技术:Tree Attention (Medusa)、Self-Speculative
- 性能优化:批量处理、自适应 K
下一章我们将探讨 多模态推理,讲解图像、音频、视频等多模态内容的处理。