多模态推理
概述
多模态大模型(Multimodal Large Language Models, MLLMs)能够理解和生成多种模态的数据,包括文本、图像、视频、音频等。本章深入探讨多模态推理的架构设计、优化技术和工程实践。
多模态模型架构
主流架构对比
┌─────────────────────────────────────────────────────────────────────────────┐
│ 多模态模型架构演进 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. 双塔架构 (CLIP) 2. 融合架构 (Flamingo) │
│ ┌─────────┐ ┌─────────┐ ┌─────────────────────────────┐ │
│ │ Image │ │ Text │ │ Language Model │ │
│ │ Encoder │ │ Encoder │ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │
│ └────┬────┘ └────┬────┘ │ │Cross│ │Self │ │Cross│ │ │
│ │ │ │ │Attn │ │Attn │ │Attn │ │ │
│ └─────┬─────┘ │ └──┬──┘ └─────┘ └──┬──┘ │ │
│ │ │ │ │ │ │
│ Contrastive │ ┌──┴───────────────┴──┐ │ │
│ Loss │ │ Vision Encoder │ │ │
│ │ └─────────────────────┘ │ │
│ └─────────────────────────────┘ │
│ │
│ 3. 统一序列架构 (LLaVA) 4. 原生多模态 (Gemini) │
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
│ │ Language Model │ │ Native Multimodal Model │ │
│ │ [IMG][IMG]...[TXT][TXT] │ │ │ │
│ └───────────┬─────────────┘ │ Image ──┐ │ │
│ │ │ Video ──┼──► Unified ──► Output │
│ ┌─────────┴─────────┐ │ Audio ──┤ Encoder │
│ │ Vision Encoder │ │ Text ───┘ │ │
│ │ + Projector │ └─────────────────────────────┘ │
│ └───────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
LLaVA 架构实现
"""
LLaVA 风格的多模态模型实现
"""
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
from typing import Optional, List, Tuple
import torch.nn.functional as F
class VisionProjector(nn.Module):
    """Project vision-encoder features into the language model's embedding space.

    Supported ``projector_type`` values:

    * ``"linear"``   - a single linear layer.
    * ``"mlp"``      - two linear layers with a GELU in between (LLaVA-1.5 default).
    * ``"mlp_deep"`` - three linear layers with GELUs in between.

    Args:
        vision_hidden_size: Feature size produced by the vision encoder.
        text_hidden_size: Embedding size expected by the language model.
        projector_type: One of the values listed above.

    Raises:
        ValueError: If ``projector_type`` is not supported. Failing here avoids
            a module without a ``projector`` attribute that would only break
            later, inside ``forward``.
    """

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 4096,
        projector_type: str = "mlp"  # linear, mlp, mlp_deep
    ):
        super().__init__()
        self.projector_type = projector_type
        if projector_type == "linear":
            # Plain linear projection.
            self.projector = nn.Linear(vision_hidden_size, text_hidden_size)
        elif projector_type == "mlp":
            # Two-layer MLP.
            self.projector = nn.Sequential(
                nn.Linear(vision_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size)
            )
        elif projector_type == "mlp_deep":
            # Deeper, three-layer MLP.
            self.projector = nn.Sequential(
                nn.Linear(vision_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size)
            )
        else:
            raise ValueError(
                f"Unsupported projector_type {projector_type!r}; "
                "expected one of 'linear', 'mlp', 'mlp_deep'"
            )

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """Apply the projection.

        Args:
            vision_features: [batch, num_patches, vision_hidden_size]

        Returns:
            projected: [batch, num_patches, text_hidden_size]
        """
        return self.projector(vision_features)
class SpatialPoolingProjector(nn.Module):
    """Projector that reduces the number of vision tokens by spatial pooling.

    Each ``pool_size x pool_size`` window of neighbouring patches is
    concatenated along the feature axis and then mapped to the text embedding
    size by a two-layer MLP.

    Args:
        vision_hidden_size: Feature size of each incoming patch.
        text_hidden_size: Output embedding size.
        pool_size: Side length of the pooling window (2 means 2x2).
    """

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 4096,
        pool_size: int = 2  # 2x2 pooling
    ):
        super().__init__()
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.pool_size = pool_size
        # Feature size after concatenating one pooling window.
        pooled_dim = vision_hidden_size * pool_size * pool_size
        self.projector = nn.Sequential(
            nn.Linear(pooled_dim, text_hidden_size),
            nn.GELU(),
            nn.Linear(text_hidden_size, text_hidden_size)
        )

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """Pool neighbouring patches and project them.

        Args:
            vision_features: [batch, num_patches, vision_hidden_size], where
                num_patches = H * W for a square grid (e.g. 24 * 24 = 576).

        Returns:
            projected: [batch, num_patches / pool_size**2, text_hidden_size]

        Raises:
            ValueError: If ``num_patches`` is not a perfect square (for
                example because a CLS token is still attached) or the grid
                side is not divisible by ``pool_size``.
        """
        batch_size, num_patches, _ = vision_features.shape
        pool = self.pool_size

        # The patch sequence is interpreted as a square grid.
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"num_patches={num_patches} is not a perfect square; "
                "SpatialPoolingProjector expects a square patch grid"
            )
        if side % pool != 0:
            raise ValueError(
                f"patch grid side {side} is not divisible by pool_size={pool}"
            )

        pooled_side = side // pool
        # [B, H/p, p, W/p, p, C]: split both grid axes into (window index, offset).
        x = vision_features.reshape(batch_size, pooled_side, pool, pooled_side, pool, -1)
        # Bring the two in-window offsets next to the channel axis ...
        x = x.permute(0, 1, 3, 2, 4, 5)
        # ... and merge them into it: [B, (H/p)*(W/p), p*p*C].
        x = x.reshape(batch_size, pooled_side * pooled_side, -1)
        return self.projector(x)
class MultimodalLLM(nn.Module):
    """LLaVA-style multimodal language model.

    A CLIP vision tower encodes images, a ``VisionProjector`` maps the patch
    features into the language model's embedding space, and the projected
    features replace image-placeholder positions in the text embedding
    sequence before it is fed to the language model.

    Args:
        vision_tower: Name or path passed to ``CLIPVisionModel.from_pretrained``.
        language_model: Name or path passed to ``LlamaForCausalLM.from_pretrained``.
        projector_type: Forwarded to ``VisionProjector``.
        freeze_vision: Disable gradients for the vision tower parameters.
        freeze_llm: Disable gradients for the language model parameters.
    """

    def __init__(
        self,
        vision_tower: str = "openai/clip-vit-large-patch14-336",
        language_model: str = "meta-llama/Llama-2-7b-hf",
        projector_type: str = "mlp",
        freeze_vision: bool = True,
        freeze_llm: bool = False
    ):
        super().__init__()
        # Vision encoder.
        self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        self.vision_hidden_size = self.vision_tower.config.hidden_size
        # Language model, loaded in half precision.
        self.language_model = LlamaForCausalLM.from_pretrained(
            language_model,
            torch_dtype=torch.float16
        )
        self.text_hidden_size = self.language_model.config.hidden_size
        # Vision -> text projector.
        self.projector = VisionProjector(
            vision_hidden_size=self.vision_hidden_size,
            text_hidden_size=self.text_hidden_size,
            projector_type=projector_type
        )
        # Optionally freeze sub-modules.
        if freeze_vision:
            for param in self.vision_tower.parameters():
                param.requires_grad = False
        if freeze_llm:
            for param in self.language_model.parameters():
                param.requires_grad = False
        # Placeholder id marking image positions inside ``input_ids``.
        self.image_token_id = -200

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images into features living in the text embedding space.

        Args:
            images: [batch, 3, H, W] preprocessed images.

        Returns:
            image_features: [batch, num_tokens, text_hidden_size]
        """
        tower = self.vision_tower
        # Gradients are only tracked while the vision tower is in training mode.
        grad_ctx = torch.enable_grad() if tower.training else torch.no_grad()
        with grad_ctx:
            vision_outputs = tower(images, output_hidden_states=True)
        # Use the second-to-last hidden state as the visual representation.
        image_features = vision_outputs.hidden_states[-2]
        # Project into the language embedding space.
        return self.projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """Run the language model on text with image features spliced in.

        Args:
            input_ids: [batch, seq_len] token ids; positions equal to
                ``self.image_token_id`` are image placeholders.
            attention_mask: [batch, seq_len]
            images: [batch, 3, H, W] images (optional).
            labels: [batch, seq_len] training labels (optional).

        Returns:
            Whatever the wrapped language model returns for
            ``return_dict=True``.

        Raises:
            ValueError: If a sample has more placeholder positions than the
                image encoder produced feature vectors.
        """
        image_mask = input_ids == self.image_token_id

        embed = self.language_model.get_input_embeddings()
        vocab_size = getattr(embed, "num_embeddings", None)
        placeholder_out_of_vocab = self.image_token_id < 0 or (
            vocab_size is not None and self.image_token_id >= vocab_size
        )
        # The default placeholder (-200) is not a valid row of the embedding
        # table, so it must not reach the lookup. Those positions are
        # overwritten with image features below; any valid id works as a
        # stand-in for the lookup itself.
        lookup_ids = input_ids.masked_fill(image_mask, 0) if placeholder_out_of_vocab else input_ids
        inputs_embeds = embed(lookup_ids)

        if images is not None:
            # The language model may run in a different dtype (fp16) than the
            # vision tower / projector, so align dtypes before writing.
            image_features = self.encode_images(images).to(inputs_embeds.dtype)
            # Write into a copy rather than into the embedding layer's output.
            inputs_embeds = inputs_embeds.clone()
            for i in range(input_ids.shape[0]):
                positions = torch.where(image_mask[i])[0]
                count = positions.numel()
                if count == 0:
                    continue
                if count > image_features.shape[1]:
                    raise ValueError(
                        f"sample {i} has {count} image placeholder tokens but only "
                        f"{image_features.shape[1]} image feature vectors are available"
                    )
                # Fill placeholder slots with the first ``count`` feature vectors.
                inputs_embeds[i, positions] = image_features[i, :count]

        # Language model forward pass.
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )
class DynamicResolutionProcessor:
    """Split an arbitrary-resolution image into fixed-size tiles plus a thumbnail.

    Args:
        base_resolution: Side length (pixels) of every produced tile.
        max_tiles: Upper bound on ``h_tiles * w_tiles``.
        patch_size: Stored on the instance; not used by the methods below.
    """

    def __init__(
        self,
        base_resolution: int = 336,
        max_tiles: int = 6,
        patch_size: int = 14
    ):
        self.base_resolution = base_resolution
        self.max_tiles = max_tiles
        self.patch_size = patch_size

    def get_optimal_grid(self, width: int, height: int) -> Tuple[int, int]:
        """Pick the tile grid that best matches the image's aspect ratio.

        Every grid with ``h_tiles * w_tiles <= max_tiles`` is scored by the
        aspect-ratio mismatch (weighted x2) plus the fraction of the grid
        canvas left uncovered when the image is scaled to fit inside it. The
        lowest score wins; ties keep the first grid found.

        Returns:
            ``(h_tiles, w_tiles)``.

        Raises:
            ValueError: If ``width`` or ``height`` is not positive.
        """
        if width <= 0 or height <= 0:
            raise ValueError(
                f"image dimensions must be positive, got width={width}, height={height}"
            )

        aspect_ratio = width / height
        best_grid = (1, 1)
        best_score = float('inf')
        for h_tiles in range(1, self.max_tiles + 1):
            for w_tiles in range(1, self.max_tiles + 1):
                if h_tiles * w_tiles > self.max_tiles:
                    continue
                ratio_diff = abs(aspect_ratio - w_tiles / h_tiles)
                # Canvas utilisation when the image is scaled to fit the grid.
                target_h = h_tiles * self.base_resolution
                target_w = w_tiles * self.base_resolution
                scale = min(target_h / height, target_w / width)
                actual_h = int(height * scale)
                actual_w = int(width * scale)
                utilization = (actual_h * actual_w) / (target_h * target_w)
                # Combined score (lower is better).
                score = ratio_diff * 2 + (1 - utilization)
                if score < best_score:
                    best_score = score
                    best_grid = (h_tiles, w_tiles)
        return best_grid

    def process_image(self, image: torch.Tensor) -> List[torch.Tensor]:
        """Turn one image into a global thumbnail followed by grid tiles.

        Args:
            image: [3, H, W] image tensor.

        Returns:
            List of [3, base_resolution, base_resolution] tensors. Element 0
            is the thumbnail of the whole image; the rest are the tiles in
            row-major order.
        """
        base = self.base_resolution
        _, h, w = image.shape
        h_tiles, w_tiles = self.get_optimal_grid(w, h)

        batched = image.unsqueeze(0)
        # Resize to the full grid canvas.
        canvas = F.interpolate(
            batched,
            size=(h_tiles * base, w_tiles * base),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)
        # Global thumbnail, resampled once from the ORIGINAL image rather than
        # from the already-resized canvas, to avoid interpolating twice.
        thumbnail = F.interpolate(
            batched,
            size=(base, base),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        tiles = [thumbnail]
        for row in range(h_tiles):
            for col in range(w_tiles):
                tiles.append(
                    canvas[
                        :,
                        row * base:(row + 1) * base,
                        col * base:(col + 1) * base
                    ]
                )
        return tiles
视觉编码优化
高效视觉编码器
"""
高效视觉编码优化技术
"""
import torch
import torch.nn as nn
from typing import Optional
import math
class EfficientViTEncoder(nn.Module):
    """ViT-style image encoder built from ``EfficientViTBlock`` layers.

    Images are cut into non-overlapping patches by a strided convolution, a
    learnable CLS token is prepended, learnable position embeddings are added,
    and the sequence is passed through ``num_layers`` transformer blocks
    followed by a final LayerNorm.
    """

    def __init__(
        self,
        image_size: int = 336,
        patch_size: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
        use_flash_attention: bool = True
    ):
        super().__init__()
        grid_side = image_size // patch_size
        self.patch_size = patch_size
        self.num_patches = grid_side ** 2

        # Patch embedding: one conv step per patch.
        self.patch_embed = nn.Conv2d(
            3, hidden_size,
            kernel_size=patch_size,
            stride=patch_size
        )
        # Position embedding covers every patch plus the CLS slot.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, hidden_size)
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        # Transformer stack.
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                EfficientViTBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    use_flash_attention=use_flash_attention
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        images: torch.Tensor,
        output_hidden_states: bool = False
    ) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            images: [B, 3, H, W]
            output_hidden_states: When true, also return the output of every
                block (collected before the final LayerNorm).

        Returns:
            The normalized token sequence [B, 1 + num_patches, hidden], or a
            ``(sequence, per_layer_outputs)`` tuple when
            ``output_hidden_states`` is true.
        """
        # [B, 3, H, W] -> [B, hidden, H/P, W/P] -> [B, N, hidden]
        tokens = self.patch_embed(images).flatten(2).transpose(1, 2)

        # Prepend the CLS token, then add position information.
        cls = self.cls_token.expand(images.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed

        collected = []
        for block in self.layers:
            tokens = block(tokens)
            if output_hidden_states:
                collected.append(tokens)

        tokens = self.norm(tokens)
        return (tokens, collected) if output_hidden_states else tokens
class EfficientViTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_flash_attention: bool = True
    ):
        super().__init__()
        # Width of the MLP's inner layer.
        mlp_hidden = int(hidden_size * mlp_ratio)

        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = EfficientAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            use_flash_attention=use_flash_attention
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention then MLP, each on a normalized copy of the input."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
class EfficientAttention(nn.Module):
    """Multi-head self-attention with an optional fused SDPA code path.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        use_flash_attention: Use ``scaled_dot_product_attention`` when the
            installed PyTorch provides it; otherwise (or when false) fall
            back to the explicit softmax(QK^T)V formulation.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        use_flash_attention: bool = True
    ):
        super().__init__()
        # Without this check head_dim is silently truncated and the failure
        # only shows up as a reshape error inside forward().
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.use_flash_attention = use_flash_attention
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape [batch, seq_len, hidden]."""
        batch_size, seq_len, _ = x.shape

        # Joint QKV projection, then split: [3, B, heads, N, head_dim].
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        sdpa = getattr(torch.nn.functional, 'scaled_dot_product_attention', None)
        if self.use_flash_attention and sdpa is not None:
            # PyTorch 2.0 SDPA (picks a fused kernel such as Flash Attention
            # when one is applicable).
            attn_output = sdpa(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False
            )
        else:
            # Reference attention.
            scale = self.head_dim ** -0.5
            weights = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)
            attn_output = weights @ v

        # Merge heads and apply the output projection.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.proj(attn_output)
class TokenPruningViT(nn.Module):
    """ViT wrapper that drops low-importance patch tokens at chosen layers.

    After each layer listed in ``prune_layers`` a linear scorer ranks the
    patch tokens (the CLS token is never scored or removed) and only the
    top ``keep_ratio`` fraction is kept, in its original order.

    Args:
        base_encoder: Encoder exposing ``patch_embed``, ``cls_token``,
            ``pos_embed``, ``layers`` and ``norm``; the hidden size is read
            from ``layers[0].attn.qkv.in_features``.
        prune_layers: Zero-based layer indices after which pruning happens.
            They are matched in the order given, so they must be ascending.
        keep_ratios: Fraction of the *current* patch tokens kept at each
            pruning step; must have the same length as ``prune_layers``.

    Raises:
        ValueError: If the two schedules differ in length.
    """

    def __init__(
        self,
        base_encoder: nn.Module,
        prune_layers: list = (6, 12, 18),  # layers after which tokens are pruned
        keep_ratios: list = (0.7, 0.5, 0.3)  # fraction kept at each step
    ):
        super().__init__()
        # Immutable defaults + private copies: the schedule is never shared
        # between instances or with the caller's list.
        prune_layers = list(prune_layers)
        keep_ratios = list(keep_ratios)
        if len(prune_layers) != len(keep_ratios):
            raise ValueError(
                "prune_layers and keep_ratios must have the same length, got "
                f"{len(prune_layers)} and {len(keep_ratios)}"
            )
        self.base_encoder = base_encoder
        self.prune_layers = prune_layers
        self.keep_ratios = keep_ratios
        # One importance scorer per pruning step.
        hidden_size = base_encoder.layers[0].attn.qkv.in_features
        self.importance_predictors = nn.ModuleList([
            nn.Linear(hidden_size, 1)
            for _ in prune_layers
        ])

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and return [B, 1 + kept_patches, hidden]."""
        encoder = self.base_encoder
        batch_size = images.shape[0]

        # Patch embedding.
        x = encoder.patch_embed(images)
        x = x.flatten(2).transpose(1, 2)
        # Prepend CLS and add position embeddings.
        cls_tokens = encoder.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + encoder.pos_embed

        prune_idx = 0
        for layer_idx, layer in enumerate(encoder.layers):
            x = layer(x)
            if prune_idx >= len(self.prune_layers) or layer_idx != self.prune_layers[prune_idx]:
                continue

            cls_part, patches = x[:, :1], x[:, 1:]
            # Importance score per patch token (CLS excluded): [B, N-1].
            importance = self.importance_predictors[prune_idx](patches).squeeze(-1)
            keep_num = int(patches.shape[1] * self.keep_ratios[prune_idx])
            # Top-k tokens, re-sorted so they keep their spatial order.
            _, indices = importance.topk(keep_num, dim=1)
            indices = indices.sort(dim=1).values
            # gather() keeps every index tensor on the same device as ``x``
            # (a CPU ``arange`` mixed with device indices fails on GPU).
            gather_idx = indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1])
            kept_tokens = patches.gather(1, gather_idx)  # [B, keep_num, hidden]
            x = torch.cat([cls_part, kept_tokens], dim=1)
            prune_idx += 1

        return encoder.norm(x)
视觉特征压缩
"""
视觉特征压缩技术
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class PerceiverResampler(nn.Module):
    """Compress a variable number of vision tokens into a fixed number.

    A set of learnable query vectors repeatedly attends to the vision
    features through a stack of ``PerceiverBlock`` layers.
    Reference: the Flamingo paper.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_queries: int = 64,  # number of output tokens
        num_layers: int = 6,
        num_heads: int = 16
    ):
        super().__init__()
        # Learnable latent queries, small random init.
        self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02)
        # Stack of cross-attention blocks.
        blocks = []
        for _ in range(num_layers):
            blocks.append(PerceiverBlock(hidden_size=hidden_size, num_heads=num_heads))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """Resample the vision features.

        Args:
            vision_features: [batch, num_patches, hidden_size]

        Returns:
            compressed: [batch, num_queries, hidden_size]
        """
        # Share the same latent queries across the batch.
        latents = self.queries.unsqueeze(0).expand(vision_features.shape[0], -1, -1)
        for block in self.layers:
            latents = block(latents, vision_features)
        return self.norm(latents)
class PerceiverBlock(nn.Module):
    """Perceiver block: cross-attention, self-attention, then a feed-forward net.

    Every sub-layer is applied to a LayerNorm-ed copy of the queries and added
    back through a residual connection.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)
        # Cross-attention (queries -> context).
        self.cross_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        # Self-attention among the queries.
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        # Feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )

    def forward(
        self,
        queries: torch.Tensor,
        context: torch.Tensor
    ) -> torch.Tensor:
        """Update ``queries`` [B, Q, hidden] using ``context`` [B, N, hidden].

        The context is used as key and value as-is (it is not normalized here).
        """
        # Cross-attention to the context.
        queries = queries + self.cross_attn(
            self.norm1(queries),
            context,
            context
        )[0]
        # Self-attention: normalize once and reuse the result for Q, K and V
        # instead of running the same LayerNorm three times.
        normed = self.norm2(queries)
        queries = queries + self.self_attn(normed, normed, normed)[0]
        # Feed-forward.
        queries = queries + self.ffn(self.norm3(queries))
        return queries
class QFormerCompressor(nn.Module):
    """Q-Former style feature compressor.

    A fixed set of learnable query tokens alternates between cross-attention
    to the vision features and self-attention (optionally together with text
    embeddings), producing ``num_queries`` output tokens.
    Reference: BLIP-2, InstructBLIP.
    """

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 768,
        num_queries: int = 32,
        num_layers: int = 6,
        num_heads: int = 12
    ):
        super().__init__()
        self.num_queries = num_queries
        # Learnable queries. They need a random init: with an all-zero init
        # every query is identical, goes through identical computation and
        # receives an identical gradient, so the queries could never become
        # different from each other.
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_queries, text_hidden_size) * 0.02
        )
        # Cross-attention layers (queries -> vision features).
        self.cross_attention_layers = nn.ModuleList([
            QFormerCrossAttentionLayer(
                text_hidden_size=text_hidden_size,
                vision_hidden_size=vision_hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])
        # Self-attention layers.
        self.self_attention_layers = nn.ModuleList([
            QFormerSelfAttentionLayer(
                hidden_size=text_hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])

    def forward(
        self,
        vision_features: torch.Tensor,
        text_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compress vision features into ``num_queries`` tokens.

        Args:
            vision_features: [batch, num_patches, vision_hidden]
            text_embeds: [batch, text_len, text_hidden] (optional). When
                given, the text tokens take part in every self-attention step
                and are dropped again afterwards.

        Returns:
            [batch, num_queries, text_hidden]
        """
        batch_size = vision_features.shape[0]
        # Share the learned queries across the batch.
        query_output = self.query_tokens.expand(batch_size, -1, -1)
        for cross_layer, self_layer in zip(
            self.cross_attention_layers,
            self.self_attention_layers
        ):
            # Cross-attention with the vision features.
            query_output = cross_layer(query_output, vision_features)
            # Self-attention, optionally conditioned on text.
            if text_embeds is not None:
                combined = torch.cat([query_output, text_embeds], dim=1)
                # Keep only the query positions.
                query_output = self_layer(combined)[:, :self.num_queries]
            else:
                query_output = self_layer(query_output)
        return query_output
class QFormerCrossAttentionLayer(nn.Module):
    """Residual cross-attention from query tokens to vision features.

    Keys and values are taken from the vision features (width
    ``vision_hidden_size``); the queries are LayerNorm-ed first.
    """

    def __init__(self, text_hidden_size: int, vision_hidden_size: int, num_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(text_hidden_size)
        self.cross_attn = nn.MultiheadAttention(
            text_hidden_size,
            num_heads,
            kdim=vision_hidden_size,
            vdim=vision_hidden_size,
            batch_first=True
        )

    def forward(self, query: torch.Tensor, vision: torch.Tensor) -> torch.Tensor:
        """Return ``query`` plus its attention read-out over ``vision``."""
        normed_query = self.norm(query)
        attended, _ = self.cross_attn(normed_query, vision, vision)
        return query + attended
class QFormerSelfAttentionLayer(nn.Module):
    """Pre-norm self-attention followed by a feed-forward net, both residual."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape [batch, seq_len, hidden]."""
        # Normalize once and reuse it for Q, K and V rather than evaluating
        # the same LayerNorm three times.
        normed = self.norm1(x)
        x = x + self.self_attn(normed, normed, normed)[0]
        x = x + self.ffn(self.norm2(x))
        return x
多模态推理服务
高效推理引擎
"""
多模态推理服务实现
"""
import torch
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
from PIL import Image
import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np
@dataclass
class MultimodalRequest:
    """A single multimodal generation request (text plus optional images)."""
    # Caller-chosen identifier; copied into the matching response.
    request_id: str
    # Prompt text.
    text: str
    # Images given either as file paths or as PIL images.
    images: List[Union[str, Image.Image]]
    # Upper bound on the number of generated tokens.
    max_tokens: int = 512
    # Sampling temperature.
    temperature: float = 0.7
    # Streaming flag (presumably requests incremental output - TODO confirm).
    stream: bool = False
@dataclass
class MultimodalResponse:
    """Result of a multimodal generation request."""
    # Identifier of the request this response belongs to.
    request_id: str
    # Generated text.
    text: str
    # Token accounting, keyed by counter name.
    usage: Dict[str, int]
class MultimodalInferenceEngine:
    """Inference engine serving text + image generation requests.

    Images are preprocessed on a thread pool, encoded by the model, merged
    with the prompt's text embeddings, and tokens are then generated one at a
    time against a KV-cache slot obtained from ``MultimodalKVCacheManager``.

    Args:
        model_path: Location handed to ``_load_model`` / ``_load_tokenizer``.
        vision_processor_path: Location handed to ``_load_vision_processor``.
        device: Device that inputs are moved to.
        max_batch_size: Number of KV-cache slots to preallocate.
        max_image_tiles: Stored on the instance; not used by the methods below.
    """

    def __init__(
        self,
        model_path: str,
        vision_processor_path: str,
        device: str = "cuda",
        max_batch_size: int = 8,
        max_image_tiles: int = 6
    ):
        self.device = device
        self.max_batch_size = max_batch_size
        self.max_image_tiles = max_image_tiles
        # Load the model and its pre/post-processing helpers.
        self.model = self._load_model(model_path)
        self.vision_processor = self._load_vision_processor(vision_processor_path)
        self.tokenizer = self._load_tokenizer(model_path)
        # Thread pool for image loading / preprocessing.
        self.image_executor = ThreadPoolExecutor(max_workers=4)
        # Request queue.
        self.request_queue = asyncio.Queue()
        # KV-cache management.
        self.kv_cache_manager = MultimodalKVCacheManager(
            max_batch_size=max_batch_size,
            max_seq_len=4096,
            hidden_size=self.model.config.hidden_size,
            num_layers=self.model.config.num_hidden_layers,
            num_heads=self.model.config.num_attention_heads
        )

    def _load_model(self, model_path: str):
        """Load the multimodal model. Must be provided by a concrete engine."""
        # Fail loudly instead of returning None, which would only surface
        # later as an AttributeError on ``self.model.config``.
        raise NotImplementedError("_load_model must be implemented for a concrete model")

    def _load_vision_processor(self, path: str):
        """Load the vision processor. Must be provided by a concrete engine."""
        raise NotImplementedError("_load_vision_processor must be implemented for a concrete model")

    def _load_tokenizer(self, path: str):
        """Load the tokenizer. Must be provided by a concrete engine."""
        raise NotImplementedError("_load_tokenizer must be implemented for a concrete model")

    async def process_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> torch.Tensor:
        """Load and preprocess images concurrently; return one stacked batch.

        All preprocessed images must share one shape, since they are stacked.
        """
        # We are inside a coroutine, so a running loop is guaranteed.
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                self.image_executor,
                self._preprocess_single_image,
                img
            )
            for img in images
        ]
        processed = await asyncio.gather(*tasks)
        # Merge into a single batch on the target device.
        return torch.stack(processed).to(self.device)

    def _preprocess_single_image(
        self,
        image: Union[str, Image.Image]
    ) -> torch.Tensor:
        """Preprocess one image (path or PIL image) into a pixel tensor."""
        if isinstance(image, str):
            # Close the file handle deterministically once pixels are loaded.
            with Image.open(image) as handle:
                image = handle.convert('RGB')
        processed = self.vision_processor(
            images=image,
            return_tensors="pt"
        )
        return processed.pixel_values.squeeze(0)

    async def generate(
        self,
        request: MultimodalRequest
    ) -> MultimodalResponse:
        """Generate a response for ``request``.

        ``torch.inference_mode`` is entered around the synchronous compute
        sections only. Decorating an ``async def`` with it has no effect on
        the coroutine body (the context exits as soon as the coroutine object
        is created), and holding the context across an ``await`` would leak
        the thread-local mode into other coroutines.
        """
        # 1. Preprocess and encode images.
        if request.images:
            pixel_values = await self.process_images(request.images)
            with torch.inference_mode():
                image_embeds = self.model.encode_images(pixel_values)
        else:
            image_embeds = None

        # 2. Tokenize the prompt.
        input_ids = self.tokenizer.encode(
            request.text,
            return_tensors="pt"
        ).to(self.device)

        # 3. Build the input embedding sequence.
        with torch.inference_mode():
            if image_embeds is not None:
                inputs_embeds = self._merge_image_text(input_ids, image_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings()(input_ids)
        prompt_len = inputs_embeds.shape[1]

        # 4. Reserve a KV-cache slot.
        cache_id = self.kv_cache_manager.allocate(
            seq_len=prompt_len + request.max_tokens
        )

        # 5. Generate; always give the slot back.
        try:
            generated_ids = await self._generate_tokens(
                inputs_embeds=inputs_embeds,
                cache_id=cache_id,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=request.stream
            )
        finally:
            self.kv_cache_manager.free(cache_id)

        # 6. Decode. ``generated_ids`` is [batch=1, num_generated]: decode the
        # single row and count tokens along dim 1 (``len()`` would count the
        # batch dimension and always report 1).
        completion_len = generated_ids.shape[1]
        generated_text = self.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )
        return MultimodalResponse(
            request_id=request.request_id,
            text=generated_text,
            usage={
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_len,
                "total_tokens": prompt_len + completion_len
            }
        )

    def _merge_image_text(
        self,
        input_ids: torch.Tensor,
        image_embeds: torch.Tensor
    ) -> torch.Tensor:
        """Splice image embeddings into the text embeddings at ``<image>`` tokens.

        For a single sequence (the only way ``generate`` calls this) the k-th
        image replaces the k-th ``<image>`` placeholder, so multi-image
        prompts work; placeholders without a matching image are left as text
        and surplus images are ignored. For a batch of several sequences,
        ``image_embeds[i]`` replaces the first placeholder of sequence ``i``.

        Returns:
            [batch, max_merged_len, hidden] tensor, zero-padded on the right,
            with the dtype and device of the text embeddings.
        """
        text_embeds = self.model.get_input_embeddings()(input_ids)
        # Locate the <image> placeholders.
        image_token_id = self.tokenizer.convert_tokens_to_ids("<image>")
        rows, cols = (input_ids == image_token_id).nonzero(as_tuple=True)

        batch_size = input_ids.shape[0]
        merged_embeds = []
        for i in range(batch_size):
            positions = cols[rows == i].tolist()
            if batch_size == 1:
                sources = list(image_embeds[:len(positions)])
            else:
                sources = [image_embeds[i]] if positions else []

            # [text before image] + [image embeds] + [text after image] ...
            pieces = []
            cursor = 0
            for pos, img in zip(positions, sources):
                pieces.append(text_embeds[i, cursor:pos])
                pieces.append(img.to(text_embeds.dtype))
                cursor = pos + 1
            pieces.append(text_embeds[i, cursor:])
            merged_embeds.append(torch.cat(pieces, dim=0))

        # Pad to a common length; keep the embeddings' dtype so a
        # half-precision model does not receive fp32 input.
        max_len = max(e.shape[0] for e in merged_embeds)
        padded = torch.zeros(
            batch_size, max_len, text_embeds.shape[-1],
            dtype=text_embeds.dtype,
            device=text_embeds.device
        )
        for i, emb in enumerate(merged_embeds):
            padded[i, :emb.shape[0]] = emb
        return padded

    async def _generate_tokens(
        self,
        inputs_embeds: torch.Tensor,
        cache_id: int,
        max_tokens: int,
        temperature: float,
        stream: bool
    ) -> torch.Tensor:
        """Autoregressive generation loop.

        Returns:
            [batch, num_generated] tensor of token ids (including the EOS
            token when one was produced); empty along dim 1 when
            ``max_tokens`` is 0. ``stream`` is accepted but not used here.
        """
        generated = []
        # Prefill.
        with torch.inference_mode():
            logits = self.model.forward_prefill(
                inputs_embeds=inputs_embeds,
                cache_id=cache_id
            )
        for _ in range(max_tokens):
            # Sample the next token from the last position.
            with torch.inference_mode():
                next_token = self._sample(logits[:, -1:], temperature)
            generated.append(next_token)
            # Stop on end-of-sequence.
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            # Decode step.
            with torch.inference_mode():
                logits = self.model.forward_decode(
                    input_ids=next_token,
                    cache_id=cache_id
                )
            # Yield control to other coroutines.
            await asyncio.sleep(0)

        if not generated:
            # torch.cat([]) raises; return an explicit empty result instead.
            return torch.empty(
                (inputs_embeds.shape[0], 0),
                dtype=torch.long,
                device=inputs_embeds.device
            )
        return torch.cat(generated, dim=1)

    def _sample(
        self,
        logits: torch.Tensor,
        temperature: float
    ) -> torch.Tensor:
        """Pick the next token from ``logits`` of shape [batch, 1, vocab].

        A non-positive temperature selects the arg-max (greedy decoding).

        Returns:
            [batch, 1] tensor of token ids.
        """
        if temperature <= 0:
            return logits.argmax(dim=-1)
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs.squeeze(1), num_samples=1)
class MultimodalKVCacheManager:
    """Fixed-size pool of KV-cache slots.

    Two tensors of shape
    ``[max_batch_size, num_layers, max_seq_len, num_heads, head_dim]`` are
    preallocated (one for keys, one for values). A slot is one index along
    the first dimension.

    Args:
        max_batch_size: Number of slots.
        max_seq_len: Capacity of each slot in tokens.
        hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
        num_layers: Number of transformer layers.
        num_heads: Number of attention heads.
        device: Device for the pool (defaults to ``"cuda"`` as before).
        dtype: Element type of the pool (defaults to ``torch.float16``).
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
        hidden_size: int,
        num_layers: int,
        num_heads: int,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16
    ):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        head_dim = hidden_size // num_heads
        pool_shape = (max_batch_size, num_layers, max_seq_len, num_heads, head_dim)
        # Preallocated KV-cache pools.
        self.k_cache_pool = torch.zeros(pool_shape, dtype=dtype, device=device)
        self.v_cache_pool = torch.zeros(pool_shape, dtype=dtype, device=device)
        # Free-slot tracking.
        self.free_slots = list(range(max_batch_size))
        self.slot_lengths = {}  # slot_id -> current_length

    def allocate(self, seq_len: int) -> int:
        """Reserve a slot and return its id.

        ``seq_len`` is accepted for interface compatibility; every slot has
        the fixed capacity ``max_seq_len`` regardless of its value.

        Raises:
            RuntimeError: If no slot is free.
        """
        if not self.free_slots:
            raise RuntimeError("No free cache slots available")
        slot_id = self.free_slots.pop(0)
        self.slot_lengths[slot_id] = 0
        return slot_id

    def free(self, slot_id: int):
        """Release a slot.

        Raises:
            KeyError: If ``slot_id`` is not currently allocated. The
                bookkeeping is checked *before* the free list is touched, so
                a double free can no longer put the same slot on the free
                list twice.
        """
        del self.slot_lengths[slot_id]
        self.free_slots.append(slot_id)

    def get_cache(self, slot_id: int) -> tuple:
        """Return ``(k, v)`` views of the filled part of a slot.

        Each view has shape ``[num_layers, current_length, num_heads, head_dim]``.
        """
        length = self.slot_lengths[slot_id]
        return (
            self.k_cache_pool[slot_id, :, :length],
            self.v_cache_pool[slot_id, :, :length]
        )

    def update_cache(
        self,
        slot_id: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor
    ):
        """Append new keys/values to a slot.

        Raises:
            RuntimeError: If the slot would exceed ``max_seq_len``.
        """
        length = self.slot_lengths[slot_id]
        # NOTE(review): the number of new positions is read from dim 2 of
        # ``new_k`` - presumably ``new_k`` mirrors the pool layout with a
        # leading batch dim, [1, layers, new_len, heads, head_dim]; confirm
        # against the model's forward_prefill / forward_decode.
        new_length = new_k.shape[2]
        end = length + new_length
        if end > self.max_seq_len:
            raise RuntimeError(
                f"KV cache slot {slot_id} overflow: {end} tokens exceed "
                f"max_seq_len={self.max_seq_len}"
            )
        self.k_cache_pool[slot_id, :, length:end] = new_k
        self.v_cache_pool[slot_id, :, length:end] = new_v
        self.slot_lengths[slot_id] = end
视频理解
视频编码器
"""
视频理解模块
"""
import torch
import torch.nn as nn
from typing import List, Optional, Tuple
import torch.nn.functional as F
class VideoEncoder(nn.Module):
    """Encode a video frame by frame, then aggregate over time.

    Args:
        image_encoder: Module mapping [N, 3, H, W] frames to
            [N, num_patches, hidden] features.
        num_frames: Stored on the instance; the actual frame count is read
            from the input tensor.
        temporal_aggregation: ``"avg"`` (mean over frames), ``"attention"``
            (``TemporalAttention``) or ``"perceiver"`` (``TemporalPerceiver``).
        hidden_size: Feature width used to build the temporal modules.

    Raises:
        ValueError: If ``temporal_aggregation`` is not one of the modes above.
    """

    _AGGREGATIONS = ("avg", "attention", "perceiver")

    def __init__(
        self,
        image_encoder: nn.Module,
        num_frames: int = 8,
        temporal_aggregation: str = "avg",  # avg, attention, perceiver
        hidden_size: int = 1024
    ):
        super().__init__()
        # Reject unknown modes up front; otherwise forward() would fail with
        # an unbound local variable.
        if temporal_aggregation not in self._AGGREGATIONS:
            raise ValueError(
                f"Unsupported temporal_aggregation {temporal_aggregation!r}; "
                f"expected one of {self._AGGREGATIONS}"
            )
        self.image_encoder = image_encoder
        self.num_frames = num_frames
        self.temporal_aggregation = temporal_aggregation
        if temporal_aggregation == "attention":
            self.temporal_attention = TemporalAttention(hidden_size)
        elif temporal_aggregation == "perceiver":
            self.temporal_perceiver = TemporalPerceiver(
                hidden_size=hidden_size,
                num_queries=64
            )

    def forward(
        self,
        video: torch.Tensor,  # [batch, num_frames, 3, H, W]
        return_all_frames: bool = False
    ) -> torch.Tensor:
        """Encode ``video``.

        Returns:
            The aggregated video features, or a
            ``(video_features, frame_features)`` tuple when
            ``return_all_frames`` is true, where ``frame_features`` is
            [batch, num_frames, num_patches, hidden].
        """
        batch_size, num_frames = video.shape[:2]

        # Encode every frame independently: [B*T, 3, H, W]. reshape() also
        # accepts non-contiguous input (e.g. a permuted tensor), where view()
        # would raise.
        video_flat = video.reshape(-1, *video.shape[2:])
        frame_features = self.image_encoder(video_flat)  # [B*T, num_patches, hidden]

        # Restore the (batch, time) axes.
        num_patches, hidden_size = frame_features.shape[1], frame_features.shape[2]
        frame_features = frame_features.reshape(
            batch_size, num_frames, num_patches, hidden_size
        )

        # Temporal aggregation.
        if self.temporal_aggregation == "avg":
            video_features = frame_features.mean(dim=1)  # [B, patches, hidden]
        elif self.temporal_aggregation == "attention":
            video_features = self.temporal_attention(frame_features)
        elif self.temporal_aggregation == "perceiver":
            video_features = self.temporal_perceiver(frame_features)
        else:
            raise ValueError(
                f"Unsupported temporal_aggregation {self.temporal_aggregation!r}"
            )

        if return_all_frames:
            return video_features, frame_features
        return video_features
class TemporalAttention(nn.Module):
    """Attention across frames, computed independently for every patch position.

    Args:
        hidden_size: Feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        max_frames: Number of frames the learnable temporal position
            embedding can address (32 by default).

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, max_frames: int = 32):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.max_frames = max_frames
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        # Learnable temporal position embedding, one vector per frame index.
        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, max_frames, 1, hidden_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate frames.

        Args:
            x: [batch, num_frames, num_patches, hidden]

        Returns:
            output: [batch, num_patches, hidden]

        Raises:
            ValueError: If ``num_frames`` exceeds ``max_frames``.
        """
        batch_size, num_frames, num_patches, hidden = x.shape
        if num_frames > self.max_frames:
            raise ValueError(
                f"got {num_frames} frames but the temporal position embedding "
                f"only covers {self.max_frames}"
            )

        # Add temporal position information.
        x = x + self.temporal_pos_embed[:, :num_frames]
        # [batch, patches, frames, hidden]: frames become the sequence axis.
        x = x.permute(0, 2, 1, 3)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, patches, frames, hidden] -> [B, patches, heads, frames, head_dim]
            t = t.view(batch_size, num_patches, num_frames, self.num_heads, self.head_dim)
            return t.permute(0, 1, 3, 2, 4)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Scaled dot-product attention over the frame axis.
        scale = self.head_dim ** -0.5
        attn = (torch.matmul(q, k.transpose(-2, -1)) * scale).softmax(dim=-1)
        out = torch.matmul(attn, v)  # [B, patches, heads, frames, head_dim]

        # Collapse the time axis by averaging the attended frames.
        out = out.mean(dim=-2)  # [B, patches, heads, head_dim]
        out = out.reshape(batch_size, num_patches, hidden)
        return self.out_proj(out)
class TemporalPerceiver(nn.Module):
    """Compress a whole video into a fixed number of tokens.

    Learnable queries attend to all (frame, patch) tokens through a stack of
    ``nn.TransformerDecoderLayer`` modules with 8 heads each.
    """

    def __init__(
        self,
        hidden_size: int,
        num_queries: int = 64,
        num_layers: int = 2
    ):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02)
        decoder_layers = []
        for _ in range(num_layers):
            decoder_layers.append(
                nn.TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=8,
                    dim_feedforward=hidden_size * 4,
                    batch_first=True
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample the video features.

        Args:
            x: [batch, num_frames, num_patches, hidden]

        Returns:
            output: [batch, num_queries, hidden]
        """
        batch_size, num_frames, num_patches, hidden = x.shape
        # Merge time and space into one memory sequence.
        memory = x.view(batch_size, num_frames * num_patches, hidden)
        # Same latent queries for every sample in the batch.
        latents = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
        # Each decoder layer: queries are the target, video tokens the memory.
        for decoder_layer in self.layers:
            latents = decoder_layer(latents, memory)
        return latents
class FrameSampler:
    """Select a subset of frames from a video tensor.

    Strategies:

    * ``"uniform"``  - evenly spaced frame indices.
    * ``"keyframe"`` - first and last frame plus the frames following the
      largest frame-to-frame changes.
    * ``"dynamic"``  - random sampling without replacement, biased towards
      frames with large change relative to their predecessor.

    Args:
        num_frames: Number of frames to return (fewer may be returned by the
            keyframe/dynamic strategies when the video is shorter).
        sampling_strategy: One of the strategy names above.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """

    def __init__(
        self,
        num_frames: int = 8,
        sampling_strategy: str = "uniform"  # uniform, keyframe, dynamic
    ):
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.num_frames = num_frames
        self.sampling_strategy = sampling_strategy

    def sample(
        self,
        video: torch.Tensor,  # [total_frames, 3, H, W]
        fps: float = 30.0,
        duration: float = None
    ) -> Tuple[torch.Tensor, List[int]]:
        """Sample frames from ``video``.

        ``fps`` and ``duration`` are accepted but not used by any strategy.

        Returns:
            sampled_frames: the selected frames, [n, 3, H, W]
            frame_indices: indices of the selected frames, ascending

        Raises:
            ValueError: If the video has no frames or the configured
                strategy is unknown.
        """
        total_frames = video.shape[0]
        if total_frames == 0:
            raise ValueError("video contains no frames")

        if self.sampling_strategy == "uniform":
            # Evenly spaced indices between the first and last frame.
            indices = torch.linspace(0, total_frames - 1, self.num_frames).long()
        elif self.sampling_strategy == "keyframe":
            indices = self._keyframe_sampling(video)
        elif self.sampling_strategy == "dynamic":
            indices = self._dynamic_sampling(video)
        else:
            raise ValueError(
                f"Unknown sampling_strategy {self.sampling_strategy!r}; "
                "expected 'uniform', 'keyframe' or 'dynamic'"
            )
        return video[indices], indices.tolist()

    @staticmethod
    def _frame_diffs(video: torch.Tensor) -> torch.Tensor:
        """Mean absolute difference between consecutive frames, [total_frames - 1]."""
        # Work in floating point: unsigned integer frames would wrap around
        # on subtraction and cannot be averaged.
        frames = video if video.is_floating_point() else video.float()
        return (frames[1:] - frames[:-1]).abs().flatten(1).mean(dim=1)

    def _keyframe_sampling(self, video: torch.Tensor) -> torch.Tensor:
        """Pick the first/last frame and the frames with the largest change.

        Returns a sorted tensor of unique indices with at most ``num_frames``
        entries (fewer only when the video itself is shorter).
        """
        total_frames = video.shape[0]
        diffs = self._frame_diffs(video)

        # A set removes duplicates, e.g. when the largest change lands on
        # the last frame, which is always included anyway.
        chosen = {0, total_frames - 1}
        top_k = min(self.num_frames - 2, diffs.numel())
        if top_k > 0:
            # diffs[i] measures the change from frame i to frame i + 1.
            chosen.update((diffs.topk(top_k).indices + 1).tolist())

        if len(chosen) < self.num_frames:
            # Top up: evenly spaced frames first, then any remaining frame.
            evenly_spaced = torch.linspace(0, total_frames - 1, self.num_frames).long().tolist()
            for candidates in (evenly_spaced, range(total_frames)):
                for idx in candidates:
                    if len(chosen) >= self.num_frames:
                        break
                    chosen.add(idx)

        ordered = sorted(chosen)[:self.num_frames]
        return torch.tensor(ordered, dtype=torch.long)

    def _dynamic_sampling(self, video: torch.Tensor) -> torch.Tensor:
        """Randomly sample frames, favouring regions with large change.

        The sampling distribution is an equal mix of the normalized frame
        differences and a uniform prior. Returns sorted unique indices.
        """
        diffs = self._frame_diffs(video)
        # The first frame has no predecessor; give it zero change.
        diffs = torch.cat([diffs.new_zeros(1), diffs])

        uniform = torch.full_like(diffs, 1.0 / diffs.numel())
        total_change = diffs.sum()
        if total_change > 0:
            probs = 0.5 * (diffs / total_change) + 0.5 * uniform
        else:
            # Static video: dividing by zero would give NaN probabilities.
            probs = uniform

        # Without replacement we cannot draw more frames than exist.
        count = min(self.num_frames, probs.numel())
        indices = torch.multinomial(probs, count, replacement=False)
        return indices.sort().values
多模态批处理
异构批处理
"""
多模态异构批处理
处理不同数量图像的请求
"""
import torch
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import asyncio
@dataclass
class MultimodalBatchItem:
    """One request waiting to be batched."""
    # Identifier of the originating request.
    request_id: str
    # Token ids of the prompt; its len() is used as the item's token count.
    text_tokens: torch.Tensor
    # Encoded image features, [num_images, num_patches, hidden], or None.
    image_features: Optional[torch.Tensor]
    # Number of images attached to the request.
    num_images: int
    # Generation budget for this request.
    max_tokens: int
class MultimodalBatcher:
    """Collect multimodal requests into batches under size/image/token budgets.

    Args:
        max_batch_size: Maximum number of requests per batch.
        max_total_images: Maximum number of images per batch.
        max_total_tokens: Maximum number of text tokens per batch.
        timeout_ms: Stored on the instance; not used by the methods below.
    """

    def __init__(
        self,
        max_batch_size: int = 8,
        max_total_images: int = 16,  # max images per batch
        max_total_tokens: int = 8192,  # max text tokens per batch
        timeout_ms: int = 50
    ):
        self.max_batch_size = max_batch_size
        self.max_total_images = max_total_images
        self.max_total_tokens = max_total_tokens
        self.timeout_ms = timeout_ms
        self.pending_items: List[MultimodalBatchItem] = []
        self.lock = asyncio.Lock()

    async def add_request(self, item: MultimodalBatchItem) -> bool:
        """Try to add ``item`` to the pending batch.

        Returns:
            True if the item was queued, False if it would exceed one of the
            batch budgets.
        """
        async with self.lock:
            if len(self.pending_items) >= self.max_batch_size:
                return False
            pending_images = sum(p.num_images for p in self.pending_items)
            if pending_images + item.num_images > self.max_total_images:
                return False
            pending_tokens = sum(len(p.text_tokens) for p in self.pending_items)
            if pending_tokens + len(item.text_tokens) > self.max_total_tokens:
                return False
            self.pending_items.append(item)
            return True

    async def get_batch(self) -> List[MultimodalBatchItem]:
        """Return all pending items and start a new, empty batch."""
        async with self.lock:
            batch, self.pending_items = self.pending_items, []
            return batch

    def prepare_batch(
        self,
        items: List[MultimodalBatchItem],
        pad_token_id: int = 0
    ) -> Dict[str, torch.Tensor]:
        """Turn a list of items into padded / packed tensors.

        Args:
            items: Items of one batch.
            pad_token_id: Id used to right-pad shorter token sequences
                (0 by default, matching the previous behaviour).

        Returns:
            An empty dict for an empty batch, otherwise a dict with:
            ``text_tokens`` [B, max_len], ``text_mask`` [B, max_len] (True on
            real tokens), ``packed_images`` (all image features concatenated
            along dim 0, or None), ``image_offsets`` [B + 1] (item ``i`` owns
            ``packed_images[offsets[i]:offsets[i + 1]]``) and ``request_ids``.

        Raises:
            ValueError: If an item's ``num_images`` disagrees with the first
                dimension of its ``image_features``.
        """
        if not items:
            return {}

        # 1. Right-pad the text tokens.
        max_text_len = max(len(item.text_tokens) for item in items)
        text_tokens = torch.full((len(items), max_text_len), pad_token_id, dtype=torch.long)
        text_mask = torch.zeros(len(items), max_text_len, dtype=torch.bool)
        for row, item in enumerate(items):
            seq_len = len(item.text_tokens)
            text_tokens[row, :seq_len] = item.text_tokens
            text_mask[row, :seq_len] = True

        # 2. Pack the image features (variable number of images per item).
        all_image_features = []
        image_offsets = [0]
        for item in items:
            count = 0
            if item.image_features is not None:
                # Offsets must follow the tensor that is actually packed;
                # a mismatching declared count would mis-slice packed_images.
                count = item.image_features.shape[0]
                if count != item.num_images:
                    raise ValueError(
                        f"request {item.request_id}: num_images={item.num_images} "
                        f"but image_features holds {count} images"
                    )
                all_image_features.append(item.image_features)
            image_offsets.append(image_offsets[-1] + count)

        packed_images = torch.cat(all_image_features, dim=0) if all_image_features else None

        return {
            "text_tokens": text_tokens,
            "text_mask": text_mask,
            "packed_images": packed_images,
            "image_offsets": torch.tensor(image_offsets),
            "request_ids": [item.request_id for item in items]
        }
class MultimodalContinuousBatching:
    """Decode loop that advances every in-flight request by one token per step.

    Requests live in ``running_requests`` (request id -> request state) and
    are removed, with their KV-cache slot freed, as soon as they finish.
    """

    def __init__(
        self,
        model: nn.Module,
        max_batch_size: int = 8,
        max_seq_len: int = 4096
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        # Requests currently being decoded, keyed by request id.
        self.running_requests: Dict[str, RunningRequest] = {}
        # KV cache shared by all running requests.
        self.kv_cache = MultimodalKVCache(
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            model_config=model.config
        )

    async def step(self) -> List[Dict[str, Any]]:
        """Run one decode step for all running requests.

        Returns:
            One dict per request, in batch order. Unfinished requests yield
            ``{"request_id", "token", "finished": False}``; finished ones
            yield ``{"request_id", "tokens", "finished": True}`` where
            ``tokens`` is the full list generated so far.
        """
        if not self.running_requests:
            return []

        # Snapshot once so that row i of the logits always belongs to active[i].
        active = list(self.running_requests.items())

        # Tokens are normalised to scalars so that requests seeded with a
        # shape-[1] token and requests carrying a 0-dim token can share a batch.
        # NOTE(review): this feeds forward_decode a flat [batch] id tensor,
        # the same layout as `positions` - confirm against the model.
        input_ids = torch.stack([req.current_token.reshape(()) for _, req in active])
        positions = torch.tensor([req.position for _, req in active])
        cache_ids = [req.cache_id for _, req in active]

        # Pure inference: do not record an autograd graph for the decode pass.
        with torch.no_grad():
            logits = self.model.forward_decode(
                input_ids=input_ids,
                positions=positions,
                cache_ids=cache_ids,
                kv_cache=self.kv_cache
            )

        results = []
        finished_ids = []
        for row, (req_id, req) in enumerate(active):
            next_token = self._sample(logits[row], req.temperature)
            token_id = next_token.item()
            req.generated_tokens.append(token_id)
            req.current_token = next_token
            req.position += 1

            if self._should_stop(req, next_token):
                results.append({
                    "request_id": req_id,
                    "tokens": req.generated_tokens,
                    "finished": True
                })
                finished_ids.append(req_id)
            else:
                results.append({
                    "request_id": req_id,
                    "token": token_id,
                    "finished": False
                })

        # Release finished requests only after the loop so the batch layout
        # stays stable while results are being collected.
        for req_id in finished_ids:
            self._free_request(req_id)
        return results

    def _sample(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Choose the next token id from ``logits``.

        A temperature of 0 (or below) selects the arg-max; otherwise a token
        is drawn from the temperature-scaled softmax. Both paths return a
        0-dim tensor so the results can be stacked into one batch.
        """
        if temperature <= 0:
            return logits.argmax()
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, 1).reshape(())

    def _should_stop(self, req, token: torch.Tensor) -> bool:
        """Return True on the EOS token or once ``req.max_tokens`` tokens exist."""
        return (
            token.item() == self.model.config.eos_token_id or
            len(req.generated_tokens) >= req.max_tokens
        )

    def _free_request(self, req_id: str):
        """Drop a request from the running set and free its KV-cache slot."""
        req = self.running_requests.pop(req_id)
        self.kv_cache.free(req.cache_id)
@dataclass
class RunningRequest:
    """Mutable decode state of one in-flight request."""

    # Identifier of the request.
    request_id: str
    # Token fed to the model on the next decode step.
    current_token: torch.Tensor
    # Position passed to the model for that token; advanced by one per step.
    position: int
    # Handle of this request's slot in the KV cache.
    cache_id: int
    # Generation stops once this many tokens have been produced.
    max_tokens: int
    # Sampling temperature; 0 selects greedy decoding.
    temperature: float
    # Token ids produced so far, in generation order.
    generated_tokens: List[int]
性能优化
图像缓存与复用
"""
图像特征缓存系统
"""
import torch
import hashlib
from typing import Dict, Optional, Tuple
from collections import OrderedDict
import threading
class ImageFeatureCache:
    """Thread-safe LRU cache mapping an image hash to its encoded features.

    Entries are kept on ``cache_device``; the least recently used entry is
    evicted once ``max_cache_size`` entries are stored. A ``max_cache_size``
    of 0 or less disables caching entirely.
    """

    def __init__(
        self,
        max_cache_size: int = 1000,  # maximum number of cached images
        cache_device: str = "cpu"    # device the cached tensors are kept on
    ):
        self.max_cache_size = max_cache_size
        self.cache_device = cache_device
        # Insertion order doubles as recency order: oldest entry first.
        self.cache: OrderedDict[str, torch.Tensor] = OrderedDict()
        self.lock = threading.Lock()
        # Hit/miss counters for get_stats().
        self.hits = 0
        self.misses = 0

    def get_cache_key(self, image_bytes: bytes) -> str:
        """Return the hex MD5 digest of the raw image bytes (cache key only)."""
        return hashlib.md5(image_bytes).hexdigest()

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the cached features for ``key`` or None, updating the counters."""
        with self.lock:
            features = self.cache.get(key)
            if features is None:
                self.misses += 1
                return None
            # Mark as most recently used.
            self.cache.move_to_end(key)
            self.hits += 1
            return features

    def put(self, key: str, features: torch.Tensor):
        """Store ``features`` under ``key``; an existing entry is only refreshed."""
        if self.max_cache_size <= 0:
            # Nothing may be stored; without this guard the eviction loop
            # below would pop from an empty dict and raise KeyError.
            return

        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
                return

        # Copy to the cache device without holding the lock so concurrent
        # lookups are not stalled by the transfer.
        stored = features.to(self.cache_device)

        with self.lock:
            if key in self.cache:
                # Another thread inserted the same key during the copy.
                self.cache.move_to_end(key)
                return
            while len(self.cache) >= self.max_cache_size:
                self.cache.popitem(last=False)
            self.cache[key] = stored

    def get_stats(self) -> Dict[str, float]:
        """Return hits, misses, hit rate (0.0 before any lookup) and entry count."""
        with self.lock:
            hits, misses, size = self.hits, self.misses, len(self.cache)
        total = hits + misses
        return {
            "hits": hits,
            "misses": misses,
            "hit_rate": hits / total if total > 0 else 0.0,
            "cache_size": size
        }
class PrefetchingImageLoader:
    """Loads and encodes images through a queue drained by background asyncio tasks.

    Encoded features are stored in the shared ``cache`` and returned on
    ``device``. The worker tasks are created in ``__init__`` with
    ``asyncio.create_task``, so the loader must be constructed while an event
    loop is running; call :meth:`close` to stop the workers.

    NOTE: image decoding and the encoder forward pass run synchronously inside
    the worker coroutine, so they block the event loop while they execute.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        cache: ImageFeatureCache,
        num_workers: int = 4,
        image_processor=None,
        device: str = "cuda"
    ):
        """
        Args:
            vision_encoder: Module called on the preprocessed image tensor.
            cache: Feature cache consulted before and filled after encoding.
            num_workers: Number of background worker tasks.
            image_processor: Optional callable mapping a PIL RGB image to the
                tensor the encoder expects. When omitted, a minimal
                conversion to a ``[1, 3, H, W]`` float tensor in [0, 1] is
                used (no resizing or normalisation).
            device: Device the returned features are moved to.
        """
        self.vision_encoder = vision_encoder
        self.cache = cache
        self.image_processor = image_processor
        self.device = device
        # Queue of (image_bytes, future) pairs consumed by the workers.
        self.prefetch_queue = asyncio.Queue()
        self.workers = [
            asyncio.create_task(self._prefetch_worker())
            for _ in range(num_workers)
        ]

    async def _prefetch_worker(self):
        """Serve queued requests forever, resolving each request's future."""
        while True:
            image_bytes, future = await self.prefetch_queue.get()
            try:
                features = self._encode(image_bytes)
            except Exception as exc:
                # Report the failure to the waiter and keep the worker alive.
                # A future that was cancelled meanwhile must not be touched,
                # otherwise InvalidStateError would terminate this task.
                if not future.done():
                    future.set_exception(exc)
            else:
                if not future.done():
                    future.set_result(features)
            finally:
                self.prefetch_queue.task_done()

    def _encode(self, image_bytes: bytes) -> torch.Tensor:
        """Return features for ``image_bytes`` on ``self.device``, using the cache."""
        key = self.cache.get_cache_key(image_bytes)
        features = self.cache.get(key)
        if features is None:
            image = self._load_image(image_bytes)
            with torch.no_grad():
                features = self.vision_encoder(image)
            self.cache.put(key, features)
        return features.to(self.device)

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode ``image_bytes`` to RGB and convert it to the encoder input."""
        from PIL import Image
        import io
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        if self.image_processor is not None:
            return self.image_processor(image)
        # Minimal fallback: raw RGB bytes -> [1, 3, H, W] float tensor in [0, 1].
        width, height = image.size
        pixels = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
        pixels = pixels.reshape(height, width, 3).permute(2, 0, 1)
        return pixels.float().div(255.0).unsqueeze(0)

    async def _enqueue(self, image_bytes: bytes):
        """Queue one image for encoding and return the future of its features."""
        future = asyncio.get_running_loop().create_future()
        await self.prefetch_queue.put((image_bytes, future))
        return future

    async def load(self, image_bytes: bytes) -> torch.Tensor:
        """Return the features of one image, encoding it only on a cache miss."""
        key = self.cache.get_cache_key(image_bytes)
        features = self.cache.get(key)
        if features is not None:
            return features.to(self.device)
        future = await self._enqueue(image_bytes)
        return await future

    async def prefetch(self, image_bytes_list: List[bytes]):
        """Encode every not-yet-cached image and wait until all are done.

        Failures of individual images are swallowed here; they surface again
        when the image is requested through :meth:`load`.
        """
        futures = []
        for image_bytes in image_bytes_list:
            key = self.cache.get_cache_key(image_bytes)
            if self.cache.get(key) is None:
                futures.append(await self._enqueue(image_bytes))
        if futures:
            await asyncio.gather(*futures, return_exceptions=True)

    async def close(self):
        """Cancel the background workers and wait for them to finish."""
        for worker in self.workers:
            worker.cancel()
        await asyncio.gather(*self.workers, return_exceptions=True)
        self.workers = []
多模态量化
Vision-Language 量化策略
"""
多模态模型量化
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional
class MultimodalQuantizer:
    """Quantizes the three parts of a multimodal model with separate bit-widths.

    The model given to :meth:`quantize_model` must expose ``vision_tower``,
    ``language_model`` and ``projector`` attributes. A bit-width that has no
    branch in the corresponding ``_quantize_*`` method leaves that part
    unchanged.
    """

    def __init__(
        self,
        vision_bits: int = 8,     # bit-width for the vision encoder
        language_bits: int = 4,   # bit-width for the language model
        projector_bits: int = 8   # bit-width for the projector
    ):
        self.vision_bits = vision_bits
        self.language_bits = language_bits
        self.projector_bits = projector_bits

    def quantize_model(self, model: nn.Module) -> nn.Module:
        """Quantize all three sub-modules in place and return ``model``."""
        model.vision_tower = self._quantize_vision(
            model.vision_tower,
            bits=self.vision_bits
        )
        model.language_model = self._quantize_language(
            model.language_model,
            bits=self.language_bits
        )
        model.projector = self._quantize_projector(
            model.projector,
            bits=self.projector_bits
        )
        return model

    def _quantize_vision(
        self,
        vision_model: nn.Module,
        bits: int
    ) -> nn.Module:
        """Vision encoder: dynamic INT8 for 8 bits, bitsandbytes 4-bit for 4 bits."""
        if bits == 8:
            return self._dynamic_int8(vision_model)
        if bits == 4:
            return self._replace_linear_with_4bit(vision_model)
        return vision_model

    def _quantize_language(
        self,
        language_model: nn.Module,
        bits: int
    ) -> nn.Module:
        """Language model: GPTQ hook for 4 bits, dynamic INT8 for 8 bits."""
        if bits == 4:
            return self._apply_gptq_quantization(language_model)
        if bits == 8:
            return self._dynamic_int8(language_model)
        return language_model

    def _quantize_projector(
        self,
        projector: nn.Module,
        bits: int
    ) -> nn.Module:
        """Projector: only dynamic INT8 is offered; other widths keep full precision."""
        if bits == 8:
            return self._dynamic_int8(projector)
        return projector

    @staticmethod
    def _dynamic_int8(module: nn.Module) -> nn.Module:
        """Apply PyTorch dynamic INT8 quantization to every ``nn.Linear`` in ``module``."""
        return torch.quantization.quantize_dynamic(
            module,
            {nn.Linear},
            dtype=torch.qint8
        )

    def _replace_linear_with_4bit(self, module: nn.Module) -> nn.Module:
        """Recursively swap every ``nn.Linear`` for a bitsandbytes ``Linear4bit``.

        Returns the converted layer when ``module`` itself is a linear layer,
        otherwise ``module`` with its children replaced in place.
        NOTE(review): bitsandbytes is expected to pack the weights when the
        module is later moved to the accelerator - confirm for the installed
        version.
        """
        from bitsandbytes.nn import Linear4bit, Params4bit

        if isinstance(module, Linear4bit):
            # Already converted (Linear4bit subclasses nn.Linear).
            return module
        if isinstance(module, nn.Linear):
            has_bias = module.bias is not None
            converted = Linear4bit(
                module.in_features,
                module.out_features,
                bias=has_bias,
                compute_dtype=torch.float16,
                quant_type="nf4"
            )
            converted.weight = Params4bit(
                module.weight.data,
                requires_grad=False,
                quant_type="nf4"
            )
            if has_bias:
                converted.bias = module.bias
            return converted

        # Materialise the child list first: children are reassigned below.
        for child_name, child in list(module.named_children()):
            setattr(module, child_name, self._replace_linear_with_4bit(child))
        return module

    def _apply_gptq_quantization(self, model: nn.Module) -> nn.Module:
        """Placeholder for GPTQ 4-bit quantization.

        GPTQ needs calibration data, which this class does not have, so the
        model is returned unchanged and a ``RuntimeWarning`` makes that
        visible instead of silently pretending the model was quantized.
        """
        import warnings
        warnings.warn(
            "GPTQ quantization requires calibration data and is not "
            "implemented here; the language model is returned unquantized.",
            RuntimeWarning,
            stacklevel=2
        )
        return model
class MixedPrecisionConfig:
    """Per-layer precision plan keyed by a substring of the module name.

    Only the FP16 entries are acted upon by :meth:`apply`; the INT8 and INT4
    entries are recorded but their handling is still a placeholder.
    """

    def __init__(self):
        fp16 = torch.float16
        # Substring of the module name -> target precision.
        self.layer_config = dict(
            vision_tower=fp16,     # vision encoder
            projector=fp16,        # kept at FP16 to preserve image-text alignment
            embed_tokens=fp16,     # language-model embedding
            self_attn=torch.int8,  # language-model attention (INT8 planned)
            mlp="int4",            # language-model FFN (INT4 planned)
            lm_head=fp16,          # output layer
        )

    def apply(self, model: nn.Module) -> nn.Module:
        """Convert every module whose name matches an FP16 pattern to half precision.

        A pattern matches when it occurs anywhere in the module's qualified
        name. Returns the same ``model`` object.
        """
        for name, module in model.named_modules():
            matched = [
                dtype
                for pattern, dtype in self.layer_config.items()
                if pattern in name
            ]
            for dtype in matched:
                if dtype == torch.float16:
                    module.half()
                # torch.int8 and "int4" entries: quantization not applied yet.
        return model
部署配置示例
Kubernetes 部署
# multimodal-deployment.yaml
# Deployment: two replicas of the multimodal inference server, one GPU each.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: multimodal-llm
  labels:
    app: multimodal-llm
spec:
  replicas: 2
  selector:
    matchLabels:
      app: multimodal-llm
  template:
    metadata:
      labels:
        app: multimodal-llm
    spec:
      containers:
        - name: multimodal-server
          image: multimodal-llm:latest
          ports:
            - containerPort: 8000
          resources:
            limits:
              nvidia.com/gpu: 1
              memory: "32Gi"
              cpu: "8"
            requests:
              nvidia.com/gpu: 1
              memory: "24Gi"
              cpu: "4"
          # Server settings passed through the environment.
          env:
            - name: MODEL_PATH
              value: "/models/llava-1.5-7b"
            - name: MAX_BATCH_SIZE
              value: "8"
            - name: MAX_IMAGE_TILES
              value: "6"
            - name: VISION_CACHE_SIZE
              value: "1000"
          volumeMounts:
            - name: model-storage
              mountPath: /models
            - name: cache-storage
              mountPath: /cache
          # Liveness is first checked after 120 s, then every 30 s.
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 120
            periodSeconds: 30
          # Readiness is first checked after 60 s, then every 10 s.
          readinessProbe:
            httpGet:
              path: /ready
              port: 8000
            initialDelaySeconds: 60
            periodSeconds: 10
      volumes:
        # Model weights come from a persistent volume claim.
        - name: model-storage
          persistentVolumeClaim:
            claimName: model-pvc
        # RAM-backed scratch volume, capped at 8Gi.
        - name: cache-storage
          emptyDir:
            medium: Memory
            sizeLimit: 8Gi
      # Schedule only onto nodes labelled with this GPU product.
      nodeSelector:
        nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
---
# Cluster-internal Service: port 80 forwards to the server's port 8000.
apiVersion: v1
kind: Service
metadata:
  name: multimodal-llm-service
spec:
  selector:
    app: multimodal-llm
  ports:
    - port: 80
      targetPort: 8000
  type: ClusterIP
---
# Autoscaler for the multimodal-llm Deployment (2 to 8 replicas).
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: multimodal-llm-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: multimodal-llm
  minReplicas: 2
  maxReplicas: 8
  metrics:
    # GPU utilisation cannot be a `Resource` metric: the resource-metrics
    # pipeline only reports cpu and memory, so a `nvidia.com/gpu` Resource
    # target never resolves. It is scaled as a per-pod custom metric
    # instead, like the queue length below. Both require a custom-metrics
    # adapter; the metric name must match what that adapter exposes (the
    # name here is the GPU-utilisation gauge of NVIDIA's DCGM exporter).
    - type: Pods
      pods:
        metric:
          name: DCGM_FI_DEV_GPU_UTIL
        target:
          type: AverageValue
          averageValue: "70"   # percent, averaged over the pods
    - type: Pods
      pods:
        metric:
          name: request_queue_length
        target:
          type: AverageValue
          averageValue: "10"
总结
本章介绍了多模态推理的核心技术:
- 多模态架构: LLaVA、Flamingo、Q-Former 等主流架构
- 视觉编码优化: Token 剪枝、特征压缩、动态分辨率
- 视频理解: 时序建模、帧采样策略
- 推理服务: 异构批处理、Continuous Batching
- 性能优化: 特征缓存、预取、量化
关键优化点:
- 使用 Perceiver/Q-Former 压缩视觉 token 数量
- 动态分辨率处理适应不同图像
- 图像特征缓存提升重复图像处理效率
- 混合精度量化平衡性能和精度