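Multimodal Inference

Overview

Multimodal large language models (MLLMs) can understand and generate data in several modalities, including text, images, video, and audio. This chapter covers the architecture design, optimization techniques, and engineering practice of multimodal inference.

Multimodal Model Architectures

Comparison of Mainstream Architectures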

┌─────────────────────────────────────────────────────────────────────────────┐
│                 Evolution of Multimodal Model Architectures                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  1. Dual-tower (CLIP)        2. Fusion (Flamingo)                           │
│  ┌─────────┐ ┌─────────┐     ┌─────────────────────────────┐               │
│  │  Image  │ │  Text   │     │      Language Model         │               │
│  │ Encoder │ │ Encoder │     │  ┌─────┐ ┌─────┐ ┌─────┐   │               │
│  └────┬────┘ └────┬────┘     │  │Cross│ │Self │ │Cross│   │               │
│       │           │          │  │Attn │ │Attn │ │Attn │   │               │
│       └─────┬─────┘          │  └──┬──┘ └─────┘ └──┬──┘   │               │
│             │                │     │               │       │               │
│      Contrastive             │  ┌──┴───────────────┴──┐   │               │
│        Loss                  │  │   Vision Encoder    │   │               │
│                              │  └─────────────────────┘   │               │
│                              └─────────────────────────────┘               │
│                                                                             │
│  3. Unified sequence (LLaVA) 4. Native multimodal (Gemini)                  │
│  ┌─────────────────────────┐ ┌─────────────────────────────┐               │
│  │    Language Model       │ │   Native Multimodal Model   │               │
│  │ [IMG][IMG]...[TXT][TXT] │ │                             │               │
│  └───────────┬─────────────┘ │  Image ──┐                  │               │
│              │               │  Video ──┼──► Unified ──► Output           │
│    ┌─────────┴─────────┐     │  Audio ──┤    Encoder                      │
│    │  Vision Encoder   │     │  Text ───┘                  │               │
│    │   + Projector     │     └─────────────────────────────┘               │
│    └───────────────────┘                                                    │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
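
To make the unified-sequence layout in panel 3 concrete, the following is a minimal sketch of how one image placeholder in a prompt might be expanded into one slot per visual token, so that projected image features can later be written into those slots. The function name `expand_image_placeholders`, the placeholder id -200, and the sample token ids are illustrative assumptions, not part of any library API.

```python
from typing import List


def expand_image_placeholders(
    token_ids: List[int],
    image_token_id: int,
    tokens_per_image: int,
) -> List[int]:
    """Repeat every image placeholder so it reserves one slot per visual token."""
    expanded: List[int] = []
    for tok in token_ids:
        if tok == image_token_id:
            expanded.extend([image_token_id] * tokens_per_image)
        else:
            expanded.append(tok)
    return expanded


# A ViT with 14-pixel patches on a 336x336 input yields (336 // 14) ** 2 patch tokens.
TOKENS_PER_IMAGE = (336 // 14) ** 2

prompt_ids = [1, -200, 3148, 1001, 2]  # hypothetical ids; -200 marks the image
sequence = expand_image_placeholders(prompt_ids, -200, TOKENS_PER_IMAGE)
print(TOKENS_PER_IMAGE, len(sequence))
```

With these assumed values a single image contributes 576 positions to the language model's sequence, which is why the token-reduction techniques later in the chapter matter for prefill cost.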

LLaVA Architecture Implementation

"""
LLaVA-style multimodal model implementation
"""
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
from typing import Optional, List, Tuple
import torch.nn.functional as F


class VisionProjector(nn.Module):
    """Vision feature projector: maps vision features into the language model's embedding space."""

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 4096,
        projector_type: str = "mlp"  # linear, mlp, mlp_deep
    ):
        super().__init__()
        self.projector_type = projector_type

        if projector_type == "linear":
            # Simple linear projection
            self.projector = nn.Linear(vision_hidden_size, text_hidden_size)

        elif projector_type == "mlp":
            # Two-layer MLP (the LLaVA-1.5 default)
            self.projector = nn.Sequential(
                nn.Linear(vision_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size)
            )

        elif projector_type == "mlp_deep":
            # Deeper three-layer MLP
            self.projector = nn.Sequential(
                nn.Linear(vision_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size),
                nn.GELU(),
                nn.Linear(text_hidden_size, text_hidden_size)
            )

        else:
            raise ValueError(f"Unknown projector_type: {projector_type!r}")

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [batch, num_patches, vision_hidden_size]
        Returns:
            projected: [batch, num_patches, text_hidden_size]
        """
        return self.projector(vision_features)


class SpatialPoolingProjector(nn.Module):
    """Spatial pooling projector: reduces the number of vision tokens."""

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 4096,
        pool_size: int = 2  # 2x2 pooling
    ):
        super().__init__()
        self.pool_size = pool_size

        # Feature dimension after merging each pool_size x pool_size neighborhood
        pooled_dim = vision_hidden_size * pool_size * pool_size

        self.projector = nn.Sequential(
            nn.Linear(pooled_dim, text_hidden_size),
            nn.GELU(),
            nn.Linear(text_hidden_size, text_hidden_size)
        )

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [batch, num_patches, vision_hidden_size]
                            num_patches = H * W (e.g., 24 * 24 = 576)
        Returns:
            projected: [batch, reduced_patches, text_hidden_size]
        """
        batch_size = vision_features.shape[0]

        # Assume a square patch grid with no CLS token
        num_patches = vision_features.shape[1]
        h = w = int(num_patches ** 0.5)

        # Reshape: [batch, H, W, hidden]
        x = vision_features.view(batch_size, h, w, -1)

        # Space-to-depth pooling: [batch, H//pool, W//pool, hidden * pool^2]
        new_h = h // self.pool_size
        new_w = w // self.pool_size

        x = x.view(batch_size, new_h, self.pool_size, new_w, self.pool_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(batch_size, new_h * new_w, -1)

        return self.projector(x)


class MultimodalLLM(nn.Module):
    """Multimodal large language model"""

    def __init__(
        self,
        vision_tower: str = "openai/clip-vit-large-patch14-336",
        language_model: str = "meta-llama/Llama-2-7b-hf",
        projector_type: str = "mlp",
        freeze_vision: bool = True,
        freeze_llm: bool = False
    ):
        super().__init__()

        # Load the vision encoder
        self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        self.vision_hidden_size = self.vision_tower.config.hidden_size

        # Load the language model
        self.language_model = LlamaForCausalLM.from_pretrained(
            language_model,
            torch_dtype=torch.float16
        )
        self.text_hidden_size = self.language_model.config.hidden_size

        # Projector
        self.projector = VisionProjector(
            vision_hidden_size=self.vision_hidden_size,
            text_hidden_size=self.text_hidden_size,
            projector_type=projector_type
        )

        # Freeze parameters
        if freeze_vision:
            for param in self.vision_tower.parameters():
                param.requires_grad = False

        if freeze_llm:
            for param in self.language_model.parameters():
                param.requires_grad = False

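        # Special token: placeholder id marking image positions in input_ids
        # (negative, so it never collides with a real vocabulary id)
        self.image_token_id = -200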

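    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images into vision features projected into the language space.

        Args:
            images: [batch, 3, H, W] preprocessed images
        Returns:
            image_features: [batch, num_patches, text_hidden_size]
        """
        # Track gradients only if the vision tower has trainable parameters
        vision_trainable = any(p.requires_grad for p in self.vision_tower.parameters())
        with torch.set_grad_enabled(vision_trainable and torch.is_grad_enabled()):
            vision_outputs = self.vision_tower(images, output_hidden_states=True)
            # Use the penultimate layer's features (richer semantics) and drop
            # the CLS token so that only the num_patches patch tokens remain
            image_features = vision_outputs.hidden_states[-2][:, 1:]

        # Project into the language model's embedding space
        image_features = self.projector(image_features)

        return image_features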

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """
        Forward pass.

        Args:
            input_ids: [batch, seq_len] text token ids; image positions hold
                self.image_token_id, one placeholder per visual token
            attention_mask: [batch, seq_len]
            images: [batch, 3, H, W] images (optional)
            labels: [batch, seq_len] training labels
        Returns:
            The language model's output (loss and logits).
        """
        # The placeholder id is negative and not a valid embedding index,
        # so map it to 0 before the lookup; those rows are overwritten below
        image_token_mask = input_ids == self.image_token_id
        safe_input_ids = input_ids.masked_fill(image_token_mask, 0)
        inputs_embeds = self.language_model.get_input_embeddings()(safe_input_ids)

        if images is not None:
            # Encode images: [batch, num_patches, hidden], cast to the LLM dtype
            image_features = self.encode_images(images).to(inputs_embeds.dtype)

            # Replace the placeholder positions with image features
            batch_size = input_ids.shape[0]
            for i in range(batch_size):
                image_token_indices = torch.where(image_token_mask[i])[0]
                num_image_tokens = image_token_indices.numel()

                if num_image_tokens > 0:
                    inputs_embeds[i, image_token_indices] = image_features[i, :num_image_tokens]

        # Language model forward pass
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

        return outputs


class DynamicResolutionProcessor:
    """Dynamic-resolution processor: supports images of arbitrary resolution."""

    def __init__(
        self,
        base_resolution: int = 336,
        max_tiles: int = 6,
        patch_size: int = 14
    ):
        self.base_resolution = base_resolution
        self.max_tiles = max_tiles
        self.patch_size = patch_size

    def get_optimal_grid(self, width: int, height: int) -> Tuple[int, int]:
        """Compute the best tile grid; returns (h_tiles, w_tiles)."""
        aspect_ratio = width / height

        best_grid = (1, 1)
        best_score = float('inf')

        for h_tiles in range(1, self.max_tiles + 1):
            for w_tiles in range(1, self.max_tiles + 1):
                if h_tiles * w_tiles > self.max_tiles:
                    continue

                grid_ratio = w_tiles / h_tiles
                ratio_diff = abs(aspect_ratio - grid_ratio)

                # Account for resolution utilization
                target_h = h_tiles * self.base_resolution
                target_w = w_tiles * self.base_resolution

                scale_h = target_h / height
                scale_w = target_w / width
                scale = min(scale_h, scale_w)

                actual_h = int(height * scale)
                actual_w = int(width * scale)

                utilization = (actual_h * actual_w) / (target_h * target_w)

                # Combined score (lower is better)
                score = ratio_diff * 2 + (1 - utilization)

                if score < best_score:
                    best_score = score
                    best_grid = (h_tiles, w_tiles)

        return best_grid

    def process_image(self, image: torch.Tensor) -> List[torch.Tensor]:
        """
        Split an image into tiles.

        Args:
            image: [3, H, W] original image (floating-point tensor)
        Returns:
            tiles: list of [3, base_res, base_res] tiles, with a global thumbnail first
        """
        _, h, w = image.shape
        h_tiles, w_tiles = self.get_optimal_grid(w, h)

        # Compute the target size
        target_h = h_tiles * self.base_resolution
        target_w = w_tiles * self.base_resolution

        # Resize the image (the aspect ratio is not preserved)
        image = F.interpolate(
            image.unsqueeze(0),
            size=(target_h, target_w),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        # Split into tiles
        tiles = []
        for i in range(h_tiles):
            for j in range(w_tiles):
                tile = image[
                    :,
                    i * self.base_resolution:(i + 1) * self.base_resolution,
                    j * self.base_resolution:(j + 1) * self.base_resolution
                ]
                tiles.append(tile)

        # Prepend a global thumbnail
        thumbnail = F.interpolate(
            image.unsqueeze(0),
            size=(self.base_resolution, self.base_resolution),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)
        tiles.insert(0, thumbnail)

        return tiles
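
Tiling multiplies the number of visual tokens the language model has to process. The sketch below estimates that budget for a given tile grid, assuming every tile and the thumbnail are encoded independently at the base resolution and that an optional pooling step merges each `pool_size` x `pool_size` patch neighborhood into one token. The helper name `visual_token_budget` is hypothetical.

```python
def visual_token_budget(
    h_tiles: int,
    w_tiles: int,
    base_resolution: int = 336,
    patch_size: int = 14,
    pool_size: int = 1,
) -> int:
    """Count visual tokens for a tiled image plus one global thumbnail."""
    patches_per_side = base_resolution // patch_size
    tokens_per_tile = (patches_per_side // pool_size) ** 2
    num_tiles = h_tiles * w_tiles + 1  # +1 for the thumbnail
    return num_tiles * tokens_per_tile


print(visual_token_budget(1, 1))               # 2 tiles x 576 tokens
print(visual_token_budget(2, 3))               # 7 tiles x 576 tokens
print(visual_token_budget(2, 3, pool_size=2))  # 7 tiles x 144 tokens
```

Under these assumptions a 2 x 3 grid costs 4032 tokens without pooling and 1008 with 2 x 2 pooling, so the tile limit and the projector choice together determine how much of the context window an image consumes.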

Vision Encoding Optimization

Efficient Vision Encoder

"""
Efficient vision encoding techniques
"""
import torch
import torch.nn as nn
from typing import Optional
import math


class EfficientViTEncoder(nn.Module):
    """Efficient vision encoder"""

    def __init__(
        self,
        image_size: int = 336,
        patch_size: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
        use_flash_attention: bool = True
    ):
        super().__init__()

        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            3, hidden_size,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Position embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, hidden_size)
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        # Transformer layers
        self.layers = nn.ModuleList([
            EfficientViTBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                use_flash_attention=use_flash_attention
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        images: torch.Tensor,
        output_hidden_states: bool = False
    ) -> torch.Tensor:
        batch_size = images.shape[0]

        # Patch embedding: [B, 3, H, W] -> [B, hidden, H/P, W/P] -> [B, N, hidden]
        x = self.patch_embed(images)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Position embedding
        x = x + self.pos_embed

        # Transformer layers
        hidden_states = []
        for layer in self.layers:
            x = layer(x)
            if output_hidden_states:
                hidden_states.append(x)

        x = self.norm(x)

        if output_hidden_states:
            return x, hidden_states
        return x


class EfficientViTBlock(nn.Module):
    """Efficient ViT block"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_flash_attention: bool = True
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = EfficientAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            use_flash_attention=use_flash_attention
        )

        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class EfficientAttention(nn.Module):
    """Efficient attention (supports Flash Attention)"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        use_flash_attention: bool = True
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.use_flash_attention = use_flash_attention

        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # QKV projection
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.use_flash_attention and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            # Use PyTorch 2.0 SDPA, which dispatches to a Flash Attention kernel when one is available
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False
            )
        else:
            # Standard attention
            scale = self.head_dim ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            attn_output = attn @ v

        # Output projection
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.proj(attn_output)


class TokenPruningViT(nn.Module):
    """Token-pruning ViT: dynamically reduces the number of vision tokens."""

    def __init__(
        self,
        base_encoder: nn.Module,
        prune_layers: list = [6, 12, 18],  # 0-based indices of the layers after which to prune
        keep_ratios: list = [0.7, 0.5, 0.3]  # fraction of the remaining patch tokens kept at each stage
    ):
        super().__init__()
        self.base_encoder = base_encoder
        self.prune_layers = prune_layers
        self.keep_ratios = keep_ratios

        # Importance-scoring heads, one per pruning stage
        hidden_size = base_encoder.layers[0].attn.qkv.in_features
        self.importance_predictors = nn.ModuleList([
            nn.Linear(hidden_size, 1)
            for _ in prune_layers
        ])

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        batch_size = images.shape[0]

        # Patch embedding
        x = self.base_encoder.patch_embed(images)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token
        cls_tokens = self.base_encoder.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.base_encoder.pos_embed

        prune_idx = 0

        for layer_idx, layer in enumerate(self.base_encoder.layers):
            x = layer(x)

            # Prune after the designated layers
            if prune_idx < len(self.prune_layers) and layer_idx == self.prune_layers[prune_idx]:
                # Score token importance (CLS token excluded)
                patches = x[:, 1:]
                importance = self.importance_predictors[prune_idx](patches)  # [B, N-1, 1]
                importance = importance.squeeze(-1)  # [B, N-1]

                # Number of patch tokens to keep
                num_patches = patches.shape[1]
                keep_num = int(num_patches * self.keep_ratios[prune_idx])

                # Select the top-k most important tokens
                _, indices = importance.topk(keep_num, dim=1)
                indices = indices.sort(dim=1).values  # preserve spatial order

                # Gather the kept tokens on the same device: [B, keep_num, hidden]
                gather_idx = indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1])
                kept_tokens = torch.gather(patches, 1, gather_idx)

                # Re-attach the CLS token
                x = torch.cat([x[:, :1], kept_tokens], dim=1)

                prune_idx += 1

        x = self.base_encoder.norm(x)
        return x
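
Because each keep ratio is applied to the tokens that survived the previous stage, the ratios compound rather than describing the final fraction directly. The sketch below works through the arithmetic for the default settings, assuming 576 patch tokens (a 336 x 336 input with 14-pixel patches); the helper name `pruning_schedule` is hypothetical.

```python
from typing import List, Sequence


def pruning_schedule(num_patches: int, keep_ratios: Sequence[float]) -> List[int]:
    """Return the patch-token count after each successive pruning stage."""
    counts: List[int] = []
    remaining = num_patches
    for ratio in keep_ratios:
        # Same truncating rule as int(num_patches * keep_ratio) in the encoder
        remaining = int(remaining * ratio)
        counts.append(remaining)
    return counts


counts = pruning_schedule(576, [0.7, 0.5, 0.3])
print(counts)
```

With these inputs the stages leave 403, 201, and 60 patch tokens, so only about a tenth of the original patches reach the final layers. If the intent is to keep 30% at the end, the per-stage ratios would need to be chosen so that their product is 0.3.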

Visual Feature Compression

"""
Visual feature compression techniques
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class PerceiverResampler(nn.Module):
    """
    Perceiver Resampler: compresses a variable number of vision tokens into a fixed number.
    Reference: the Flamingo paper
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_queries: int = 64,  # fixed number of output tokens
        num_layers: int = 6,
        num_heads: int = 16
    ):
        super().__init__()

        # Learnable query tokens
        self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02)

        # Cross-attention layers
        self.layers = nn.ModuleList([
            PerceiverBlock(
                hidden_size=hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [batch, num_patches, hidden_size]
        Returns:
            compressed: [batch, num_queries, hidden_size]
        """
        batch_size = vision_features.shape[0]

        # Expand the queries across the batch
        queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)

        # Cross-attention
        for layer in self.layers:
            queries = layer(queries, vision_features)

        return self.norm(queries)


class PerceiverBlock(nn.Module):
    """Perceiver Block with cross-attention"""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm3 = nn.LayerNorm(hidden_size)

        # Cross-attention
        self.cross_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )

        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )

    def forward(
        self,
        queries: torch.Tensor,
        context: torch.Tensor
    ) -> torch.Tensor:
        # Cross-attention to context
        queries = queries + self.cross_attn(
            self.norm1(queries),
            context,
            context
        )[0]

        # Self-attention
        queries = queries + self.self_attn(
            self.norm2(queries),
            self.norm2(queries),
            self.norm2(queries)
        )[0]

        # FFN
        queries = queries + self.ffn(self.norm3(queries))

        return queries


class QFormerCompressor(nn.Module):
    """
    Q-Former-style feature compression
    Reference: BLIP-2, InstructBLIP
    """

    def __init__(
        self,
        vision_hidden_size: int = 1024,
        text_hidden_size: int = 768,
        num_queries: int = 32,
        num_layers: int = 6,
        num_heads: int = 12
    ):
        super().__init__()

        self.num_queries = num_queries

        # Learnable queries (small random init so the queries start out distinct)
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_queries, text_hidden_size) * 0.02
        )

        # Cross-attention layers
        self.cross_attention_layers = nn.ModuleList([
            QFormerCrossAttentionLayer(
                text_hidden_size=text_hidden_size,
                vision_hidden_size=vision_hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])

        # Self-attention layers
        self.self_attention_layers = nn.ModuleList([
            QFormerSelfAttentionLayer(
                hidden_size=text_hidden_size,
                num_heads=num_heads
            )
            for _ in range(num_layers)
        ])

    def forward(
        self,
        vision_features: torch.Tensor,
        text_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            vision_features: [batch, num_patches, vision_hidden]
            text_embeds: [batch, text_len, text_hidden] (optional, for text-conditioned compression)
        Returns:
            query_output: [batch, num_queries, text_hidden]
        """
        batch_size = vision_features.shape[0]

        # Initialize the queries
        query_output = self.query_tokens.expand(batch_size, -1, -1)

        for cross_layer, self_layer in zip(
            self.cross_attention_layers,
            self.self_attention_layers
        ):
            # Cross-attention with vision
            query_output = cross_layer(query_output, vision_features)

            # Self-attention (optionally conditioned on text)
            if text_embeds is not None:
                combined = torch.cat([query_output, text_embeds], dim=1)
                query_output = self_layer(combined)[:, :self.num_queries]
            else:
                query_output = self_layer(query_output)

        return query_output


class QFormerCrossAttentionLayer(nn.Module):
    def __init__(self, text_hidden_size: int, vision_hidden_size: int, num_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(text_hidden_size)
        self.cross_attn = nn.MultiheadAttention(
            text_hidden_size, num_heads,
            kdim=vision_hidden_size, vdim=vision_hidden_size,
            batch_first=True
        )

    def forward(self, query: torch.Tensor, vision: torch.Tensor) -> torch.Tensor:
        return query + self.cross_attn(self.norm(query), vision, vision)[0]


class QFormerSelfAttentionLayer(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.ffn(self.norm2(x))
        return x
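
Both compressors rely on the same property of cross-attention: the output has one vector per query, no matter how many context tokens are attended to. The following dependency-free sketch illustrates that property with single-head attention and identity projections. It is a simplification for illustration, not the multi-head, learned-projection attention used in the classes above, and the names `softmax` and `cross_attend` are hypothetical helpers.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries: List[Vector], context: List[Vector]) -> List[Vector]:
    """Single-head cross-attention with identity projections.

    Every query produces exactly one output vector, a softmax-weighted
    average of the context vectors, so the output length equals len(queries).
    """
    dim = len(queries[0])
    scale = 1.0 / math.sqrt(dim)
    outputs: List[Vector] = []
    for q in queries:
        scores = [scale * sum(qi * ci for qi, ci in zip(q, c)) for c in context]
        weights = softmax(scores)
        outputs.append(
            [sum(w * c[d] for w, c in zip(weights, context)) for d in range(dim)]
        )
    return outputs


queries = [[1.0, 0.0], [0.0, 1.0]]                    # 2 stand-in "learned" queries
short_context = [[0.5, 0.1], [0.2, 0.9], [0.7, 0.7]]  # 3 vision tokens
long_context = short_context * 4                      # 12 vision tokens

out_short = cross_attend(queries, short_context)
out_long = cross_attend(queries, long_context)
print(len(out_short), len(out_long))
```

Both calls return two vectors, which is what lets a resampler with 64 queries, or a Q-Former with 32, hand the language model a fixed-length visual prefix regardless of how many patches or tiles went in.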

Multimodal Inference Serving

Efficient Inference Engine

"""
Multimodal inference service implementation
"""
import torch
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
from PIL import Image
import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np


@dataclass
class MultimodalRequest:
    """Multimodal request"""
    request_id: str
    text: str
    images: List[Union[str, Image.Image]]  # image paths or PIL images
    max_tokens: int = 512
    temperature: float = 0.7
    stream: bool = False


@dataclass
class MultimodalResponse:
    """Multimodal response"""
    request_id: str
    text: str
    usage: Dict[str, int]


class MultimodalInferenceEngine:
    """Multimodal inference engine"""

    def __init__(
        self,
        model_path: str,
        vision_processor_path: str,
        device: str = "cuda",
        max_batch_size: int = 8,
        max_image_tiles: int = 6
    ):
        self.device = device
        self.max_batch_size = max_batch_size
        self.max_image_tiles = max_image_tiles

        # Load the model, vision processor, and tokenizer
        self.model = self._load_model(model_path)
        self.vision_processor = self._load_vision_processor(vision_processor_path)
        self.tokenizer = self._load_tokenizer(model_path)

        # Thread pool for image preprocessing
        self.image_executor = ThreadPoolExecutor(max_workers=4)

        # Request queue
        self.request_queue = asyncio.Queue()

        # KV cache management
        self.kv_cache_manager = MultimodalKVCacheManager(
            max_batch_size=max_batch_size,
            max_seq_len=4096,
            hidden_size=self.model.config.hidden_size,
            num_layers=self.model.config.num_hidden_layers,
            num_heads=self.model.config.num_attention_heads
        )

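    def _load_model(self, model_path: str):
        """Load the multimodal model."""
        # Placeholder: a real implementation loads a concrete model here
        raise NotImplementedError

    def _load_vision_processor(self, path: str):
        """Load the vision processor."""
        raise NotImplementedError

    def _load_tokenizer(self, path: str):
        """Load the tokenizer."""
        raise NotImplementedError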

    async def process_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> torch.Tensor:
        """Load and preprocess images asynchronously."""
        loop = asyncio.get_running_loop()

        # Load and preprocess the images in parallel on the thread pool
        tasks = []
        for img in images:
            task = loop.run_in_executor(
                self.image_executor,
                self._preprocess_single_image,
                img
            )
            tasks.append(task)

        processed = await asyncio.gather(*tasks)

        # Stack into a batch (all images must share the same processed shape)
        return torch.stack(processed).to(self.device)

    def _preprocess_single_image(
        self,
        image: Union[str, Image.Image]
    ) -> torch.Tensor:
        """Preprocess a single image."""
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')

        # Run the vision processor
        processed = self.vision_processor(
            images=image,
            return_tensors="pt"
        )

        return processed.pixel_values.squeeze(0)

    @torch.inference_mode()
    async def generate(
        self,
        request: MultimodalRequest
    ) -> MultimodalResponse:
        """生成响应"""
        # 1. 预处理图像
        if request.images:
            image_features = await self.process_images(request.images)
            image_embeds = self.model.encode_images(image_features)
        else:
            image_embeds = None

        # 2. 编码文本
        input_ids = self.tokenizer.encode(
            request.text,
            return_tensors="pt"
        ).to(self.device)

        # 3. 准备输入
        if image_embeds is not None:
            inputs_embeds = self._merge_image_text(input_ids, image_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        # 4. 分配 KV Cache
        cache_id = self.kv_cache_manager.allocate(
            seq_len=inputs_embeds.shape[1] + request.max_tokens
        )

        # 5. 生成
        try:
            generated_ids = await self._generate_tokens(
                inputs_embeds=inputs_embeds,
                cache_id=cache_id,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=request.stream
            )
        finally:
            # 释放 cache
            self.kv_cache_manager.free(cache_id)

        # 6. 解码
        generated_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True
        )

        return MultimodalResponse(
            request_id=request.request_id,
            text=generated_text,
            usage={
                "prompt_tokens": inputs_embeds.shape[1],
                "completion_tokens": len(generated_ids),
                "total_tokens": inputs_embeds.shape[1] + len(generated_ids)
            }
        )

    def _merge_image_text(
        self,
        input_ids: torch.Tensor,
        image_embeds: torch.Tensor
    ) -> torch.Tensor:
        """合并图像和文本 embedding"""
        text_embeds = self.model.get_input_embeddings()(input_ids)

        # 找到 <image> 占位符位置
        image_token_id = self.tokenizer.convert_tokens_to_ids("<image>")
        image_positions = (input_ids == image_token_id).nonzero(as_tuple=True)

        # 构建完整的 embedding 序列
        # [text before image] + [image embeds] + [text after image]
        batch_size = input_ids.shape[0]

        merged_embeds = []
        for i in range(batch_size):
            pos = image_positions[1][image_positions[0] == i]
            if len(pos) > 0:
                pos = pos[0].item()
                before = text_embeds[i, :pos]
                after = text_embeds[i, pos+1:]
                merged = torch.cat([before, image_embeds[i], after], dim=0)
            else:
                merged = text_embeds[i]
            merged_embeds.append(merged)

        # Pad to same length
        max_len = max(e.shape[0] for e in merged_embeds)
        padded = torch.zeros(batch_size, max_len, text_embeds.shape[-1], device=self.device)
        for i, emb in enumerate(merged_embeds):
            padded[i, :emb.shape[0]] = emb

        return padded

    async def _generate_tokens(
        self,
        inputs_embeds: torch.Tensor,
        cache_id: int,
        max_tokens: int,
        temperature: float,
        stream: bool
    ) -> torch.Tensor:
        """Token generation loop (assumes batch size 1; returns shape [1, num_generated])."""
        generated = []

        # Prefill
        logits = self.model.forward_prefill(
            inputs_embeds=inputs_embeds,
            cache_id=cache_id
        )

        for _ in range(max_tokens):
            # Sample
            next_token = self._sample(logits[:, -1:], temperature)
            generated.append(next_token)

            # Stop at end-of-sequence (.item() assumes batch size 1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break

            # Decode step
            logits = self.model.forward_decode(
                input_ids=next_token,
                cache_id=cache_id
            )

            # Yield control to other coroutines
            await asyncio.sleep(0)

        return torch.cat(generated, dim=1)

    def _sample(
        self,
        logits: torch.Tensor,
        temperature: float
    ) -> torch.Tensor:
        """Sample the next token from logits of shape [batch, 1, vocab]."""
        if temperature == 0:
            # Greedy decoding
            return logits.argmax(dim=-1)

        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs.squeeze(1), num_samples=1)


class MultimodalKVCacheManager:
    """Multimodal KV cache manager."""

    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
        hidden_size: int,
        num_layers: int,
        num_heads: int
    ):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        head_dim = hidden_size // num_heads

        # Preallocate the KV cache pools
        self.k_cache_pool = torch.zeros(
            max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
            dtype=torch.float16,
            device="cuda"
        )
        self.v_cache_pool = torch.zeros(
            max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
            dtype=torch.float16,
            device="cuda"
        )

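        # Track free slots
        self.free_slots = list(range(max_batch_size))
        self.slot_lengths = {}  # slot_id -> current_length

    def allocate(self, seq_len: int) -> int:
        """Allocate a cache slot for a sequence of up to seq_len tokens."""
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        if not self.free_slots:
            raise RuntimeError("No free cache slots available")

        slot_id = self.free_slots.pop(0)
        self.slot_lengths[slot_id] = 0

        return slot_id

    def free(self, slot_id: int):
        """Release a cache slot."""
        # Delete first so that a double free raises KeyError
        # before the slot can be added to the free list twice
        del self.slot_lengths[slot_id]
        self.free_slots.append(slot_id)

    def get_cache(self, slot_id: int) -> tuple:
        """Return the filled part of the KV cache for the given slot."""
        length = self.slot_lengths[slot_id]
        return (
            self.k_cache_pool[slot_id, :, :length],
            self.v_cache_pool[slot_id, :, :length]
        )

    def update_cache(
        self,
        slot_id: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor
    ):
        """Append new keys/values, each shaped [num_layers, new_len, num_heads, head_dim]."""
        length = self.slot_lengths[slot_id]
        # Once the slot axis is indexed away, the sequence axis is dim 1
        new_length = new_k.shape[1]
        if length + new_length > self.max_seq_len:
            raise ValueError("KV cache overflow: sequence exceeds max_seq_len")

        self.k_cache_pool[slot_id, :, length:length+new_length] = new_k
        self.v_cache_pool[slot_id, :, length:length+new_length] = new_v

        self.slot_lengths[slot_id] = length + new_length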
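
Because both pools are preallocated for the worst case, their size is fixed by the constructor arguments alone. The sketch below estimates that footprint. The model dimensions (32 layers, 32 heads, hidden size 4096) are illustrative assumptions, not values taken from this chapter.

```python
def kv_cache_bytes(max_batch_size, max_seq_len, hidden_size,
                   num_layers, num_heads, bytes_per_elem=2):
    """Bytes used by the K and V pools together (fp16 is 2 bytes per element)."""
    head_dim = hidden_size // num_heads
    elems_per_pool = (
        max_batch_size * num_layers * max_seq_len * num_heads * head_dim
    )
    return 2 * elems_per_pool * bytes_per_elem  # factor 2: one pool each for K and V


# Hypothetical 7B-class configuration (assumed values)
total_bytes = kv_cache_bytes(max_batch_size=8, max_seq_len=4096,
                             hidden_size=4096, num_layers=32, num_heads=32)
per_token_bytes = kv_cache_bytes(1, 1, 4096, 32, 32)

print(f"pool size: {total_bytes / 2**30:.1f} GiB")
print(f"per token: {per_token_bytes / 2**20:.2f} MiB")
```

Under these assumptions the pools occupy 16 GiB regardless of how many slots are in use, so max_batch_size and max_seq_len should be chosen with the GPU memory budget in mind.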

Video Understanding

Video Encoder

"""
Video understanding module
"""
import torch
import torch.nn as nn
from typing import List, Optional, Tuple
import torch.nn.functional as F


class VideoEncoder(nn.Module):
    """Video encoder."""

    def __init__(
        self,
        image_encoder: nn.Module,
        num_frames: int = 8,
        temporal_aggregation: str = "avg",  # avg, attention, perceiver
        hidden_size: int = 1024
    ):
        super().__init__()

        self.image_encoder = image_encoder
        self.num_frames = num_frames
        self.temporal_aggregation = temporal_aggregation

        if temporal_aggregation == "attention":
            self.temporal_attention = TemporalAttention(hidden_size)
        elif temporal_aggregation == "perceiver":
            self.temporal_perceiver = TemporalPerceiver(
                hidden_size=hidden_size,
                num_queries=64
            )

    def forward(
        self,
        video: torch.Tensor,  # [batch, num_frames, 3, H, W]
        return_all_frames: bool = False
    ) -> torch.Tensor:
        """Returns video_features, or (video_features, frame_features) if return_all_frames is set."""
        batch_size, num_frames = video.shape[:2]

        # Encode every frame independently
        video_flat = video.reshape(-1, *video.shape[2:])  # [B*T, 3, H, W]
        frame_features = self.image_encoder(video_flat)  # [B*T, num_patches, hidden]

        # Restore the frame axis
        num_patches = frame_features.shape[1]
        hidden_size = frame_features.shape[2]
        frame_features = frame_features.reshape(
            batch_size, num_frames, num_patches, hidden_size
        )

        # Temporal aggregation
        if self.temporal_aggregation == "avg":
            video_features = frame_features.mean(dim=1)  # [B, patches, hidden]

        elif self.temporal_aggregation == "attention":
            # Temporal attention
            video_features = self.temporal_attention(frame_features)

        elif self.temporal_aggregation == "perceiver":
            # Perceiver-style temporal compression
            video_features = self.temporal_perceiver(frame_features)

        else:
            raise ValueError(
                f"Unknown temporal_aggregation: {self.temporal_aggregation}"
            )

        if return_all_frames:
            return video_features, frame_features
        return video_features


class TemporalAttention(nn.Module):
    """Temporal attention."""

    def __init__(self, hidden_size: int, num_heads: int = 8):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        # Learnable temporal position embedding
        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, 32, 1, hidden_size)  # supports at most 32 frames
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, num_frames, num_patches, hidden]
        Returns:
            output: [batch, num_patches, hidden]
        """
        batch_size, num_frames, num_patches, hidden = x.shape

        # Add the temporal position embedding
        x = x + self.temporal_pos_embed[:, :num_frames]

        # Permute to [batch, patches, frames, hidden]
        x = x.permute(0, 2, 1, 3)

        # Project to per-head queries, keys, and values
        q = self.q_proj(x).view(batch_size, num_patches, num_frames, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, num_patches, num_frames, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, num_patches, num_frames, self.num_heads, self.head_dim)

        # [B, patches, heads, frames, head_dim]
        q = q.permute(0, 1, 3, 2, 4)
        k = k.permute(0, 1, 3, 2, 4)
        v = v.permute(0, 1, 3, 2, 4)

        # Attention
        scale = self.head_dim ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)

        out = torch.matmul(attn, v)  # [B, patches, heads, frames, head_dim]

        # Aggregate over time by mean-pooling the per-frame attention outputs
        out = out.mean(dim=-2)  # [B, patches, heads, head_dim]
        out = out.view(batch_size, num_patches, hidden)

        return self.out_proj(out)


class TemporalPerceiver(nn.Module):
    """Temporal Perceiver: compresses a video into a fixed-length token sequence."""

    def __init__(
        self,
        hidden_size: int,
        num_queries: int = 64,
        num_layers: int = 2
    ):
        super().__init__()

        self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02)

        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_size,
                nhead=8,
                dim_feedforward=hidden_size * 4,
                batch_first=True
            )
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, num_frames, num_patches, hidden]
        Returns:
            output: [batch, num_queries, hidden]
        """
        batch_size, num_frames, num_patches, hidden = x.shape

        # Flatten temporal and spatial dimensions
        x = x.view(batch_size, num_frames * num_patches, hidden)

        # Expand queries
        queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)

        # Cross-attention
        for layer in self.layers:
            queries = layer(queries, x)

        return queries


class FrameSampler:
    """Video frame sampler."""

    def __init__(
        self,
        num_frames: int = 8,
        sampling_strategy: str = "uniform"  # uniform, keyframe, dynamic
    ):
        self.num_frames = num_frames
        self.sampling_strategy = sampling_strategy

    def sample(
        self,
        video: torch.Tensor,  # [total_frames, 3, H, W]
        fps: float = 30.0,
        duration: Optional[float] = None
    ) -> Tuple[torch.Tensor, List[int]]:
        """
        Sample frames from a video.

        Returns:
            sampled_frames: [num_frames, 3, H, W]
            frame_indices: indices of the sampled frames
        """
        total_frames = video.shape[0]

        if self.sampling_strategy == "uniform":
            # Uniform sampling
            indices = torch.linspace(0, total_frames - 1, self.num_frames).long()

        elif self.sampling_strategy == "keyframe":
            # Keyframe-based sampling (keyframes are derived from inter-frame differences)
            indices = self._keyframe_sampling(video)

        elif self.sampling_strategy == "dynamic":
            # Dynamic sampling based on inter-frame differences
            indices = self._dynamic_sampling(video)

        else:
            raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")

        return video[indices], indices.tolist()

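    def _keyframe_sampling(self, video: torch.Tensor) -> torch.Tensor:
        """Keyframe-based sampling: keep the frames with the largest change."""
        total_frames = len(video)
        # Mean absolute difference between consecutive frames
        # (cast to float so that uint8 video does not wrap around)
        diffs = (video[1:].float() - video[:-1].float()).abs().mean(dim=[1, 2, 3])

        # Frames with the largest change (diffs[i] describes frame i + 1)
        k = max(min(self.num_frames - 2, len(diffs)), 0)
        _, top_indices = diffs.topk(k)

        # Add the first and last frames, then deduplicate
        # (torch.unique also returns the indices sorted)
        indices = torch.unique(torch.cat([
            torch.tensor([0, total_frames - 1], device=video.device),
            top_indices + 1
        ]))

        # If there are too few frames, top up with uniformly spaced ones
        if len(indices) < self.num_frames:
            chosen = set(indices.tolist())
            uniform = torch.linspace(0, total_frames - 1, self.num_frames).long().tolist()
            extra = [i for i in uniform if i not in chosen]
            extra = extra[:self.num_frames - len(indices)]
            indices = torch.cat([
                indices,
                torch.tensor(extra, dtype=torch.long, device=video.device)
            ])

        return indices[:self.num_frames].sort().values

    def _dynamic_sampling(self, video: torch.Tensor) -> torch.Tensor:
        """Dynamic sampling: draw more frames from regions with large changes."""
        # Mean absolute difference between consecutive frames
        diffs = (video[1:].float() - video[:-1].float()).abs().mean(dim=[1, 2, 3])
        diffs = torch.cat([diffs.new_zeros(1), diffs])

        # Normalize into a probability distribution (guarding against a static video)
        probs = diffs / diffs.sum().clamp_min(1e-8)

        # Mix in a uniform prior
        uniform = torch.ones_like(probs) / len(probs)
        probs = 0.5 * probs + 0.5 * uniform

        # Sample without replacement; cannot draw more frames than exist
        num_samples = min(self.num_frames, len(probs))
        indices = torch.multinomial(probs, num_samples, replacement=False)

        return indices.sort().values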
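
The choice of sampling and aggregation strategy determines how many visual tokens the language model has to process. The sketch below is a torch-free illustration: `uniform_frame_indices` is an integer-arithmetic stand-in for the `torch.linspace(...).long()` call in `FrameSampler`, and `visual_token_count` summarizes the output lengths of the three aggregation modes. The `all_frames` mode and the figure of 576 patches per frame are assumptions made for the comparison.

```python
def uniform_frame_indices(total_frames, num_frames):
    """Evenly spaced frame indices that include the first and last frame."""
    if num_frames == 1:
        return [0]
    return [i * (total_frames - 1) // (num_frames - 1) for i in range(num_frames)]


def visual_token_count(aggregation, num_frames, num_patches, num_queries=64):
    """Number of visual tokens handed to the language model."""
    if aggregation in ("avg", "attention"):
        return num_patches                # the time axis is pooled away
    if aggregation == "perceiver":
        return num_queries                # fixed number of learned queries
    if aggregation == "all_frames":       # hypothetical: no temporal pooling
        return num_frames * num_patches
    raise ValueError(f"unknown aggregation: {aggregation}")


indices = uniform_frame_indices(total_frames=240, num_frames=8)
print(indices)
for mode in ("all_frames", "avg", "perceiver"):
    print(mode, visual_token_count(mode, num_frames=8, num_patches=576))
```

With these assumed numbers, keeping every frame costs 4608 tokens, temporal pooling costs 576, and the Perceiver path costs 64, which is why compression matters for long videos.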

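Multimodal Batching

Heterogeneous Batching

"""
Heterogeneous multimodal batching
Handles requests that carry different numbers of images
"""
import torch
import torch.nn as nn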
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import asyncio


@dataclass
class MultimodalBatchItem:
    """A single item in a batch."""
    request_id: str
    text_tokens: torch.Tensor
    image_features: Optional[torch.Tensor]  # [num_images, num_patches, hidden]
    num_images: int
    max_tokens: int


class MultimodalBatcher:
    """Multimodal batcher."""

    def __init__(
        self,
        max_batch_size: int = 8,
        max_total_images: int = 16,  # max images per batch
        max_total_tokens: int = 8192,  # max text tokens per batch
        timeout_ms: int = 50
    ):
        self.max_batch_size = max_batch_size
        self.max_total_images = max_total_images
        self.max_total_tokens = max_total_tokens
        self.timeout_ms = timeout_ms

        self.pending_items: List[MultimodalBatchItem] = []
        self.lock = asyncio.Lock()

    async def add_request(self, item: MultimodalBatchItem) -> bool:
        """Add a request to the pending batch; returns False if it does not fit."""
        async with self.lock:
            # Check whether the request fits in the current batch
            current_images = sum(i.num_images for i in self.pending_items)
            current_tokens = sum(len(i.text_tokens) for i in self.pending_items)

            can_add = (
                len(self.pending_items) < self.max_batch_size and
                current_images + item.num_images <= self.max_total_images and
                current_tokens + len(item.text_tokens) <= self.max_total_tokens
            )

            if can_add:
                self.pending_items.append(item)
                return True
            return False

    async def get_batch(self) -> List[MultimodalBatchItem]:
        """Take the current batch and reset the pending list."""
        async with self.lock:
            batch = self.pending_items
            self.pending_items = []
            return batch

    def prepare_batch(
        self,
        items: List[MultimodalBatchItem]
    ) -> Dict[str, torch.Tensor]:
        """Prepare the batch tensors."""
        if not items:
            return {}

        # 1. Pad the text tokens
        max_text_len = max(len(item.text_tokens) for item in items)
        text_tokens = torch.zeros(len(items), max_text_len, dtype=torch.long)
        text_mask = torch.zeros(len(items), max_text_len, dtype=torch.bool)

        for i, item in enumerate(items):
            seq_len = len(item.text_tokens)
            text_tokens[i, :seq_len] = item.text_tokens
            text_mask[i, :seq_len] = True

        # 2. Image features (variable number of images per request),
        #    stored in a packed representation
        all_image_features = []
        image_offsets = [0]

        for item in items:
            if item.image_features is not None:
                all_image_features.append(item.image_features)
                image_offsets.append(image_offsets[-1] + item.num_images)
            else:
                image_offsets.append(image_offsets[-1])

        if all_image_features:
            packed_images = torch.cat(all_image_features, dim=0)
        else:
            packed_images = None

        return {
            "text_tokens": text_tokens,
            "text_mask": text_mask,
            "packed_images": packed_images,
            "image_offsets": torch.tensor(image_offsets),
            "request_ids": [item.request_id for item in items]
        }


class MultimodalContinuousBatching:
    """Multimodal continuous batching."""

    def __init__(
        self,
        model: nn.Module,
        max_batch_size: int = 8,
        max_seq_len: int = 4096
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

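        # Requests currently being decoded
        self.running_requests: Dict[str, RunningRequest] = {}

        # KV cache (MultimodalKVCache is not defined in this chapter; it is assumed to
        # play the same role as MultimodalKVCacheManager, built from the model config)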
        self.kv_cache = MultimodalKVCache(
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            model_config=model.config
        )

    async def step(self) -> List[Dict[str, Any]]:
        """Run one decode step."""
        if not self.running_requests:
            return []

        # Collect the inputs for the current batch
        batch_input_ids = []
        batch_positions = []
        batch_cache_ids = []

        for req_id, req in self.running_requests.items():
            batch_input_ids.append(req.current_token)
            batch_positions.append(req.position)
            batch_cache_ids.append(req.cache_id)

        # Batched inference
        input_ids = torch.stack(batch_input_ids)
        positions = torch.tensor(batch_positions)

        logits = self.model.forward_decode(
            input_ids=input_ids,
            positions=positions,
            cache_ids=batch_cache_ids,
            kv_cache=self.kv_cache
        )

        # Sample and update
        results = []
        completed = []

        for i, (req_id, req) in enumerate(self.running_requests.items()):
            # Sample
            next_token = self._sample(logits[i], req.temperature)
            req.generated_tokens.append(next_token.item())
            req.current_token = next_token
            req.position += 1

            # Check the stopping conditions
            if self._should_stop(req, next_token):
                results.append({
                    "request_id": req_id,
                    "tokens": req.generated_tokens,
                    "finished": True
                })
                completed.append(req_id)
            else:
                results.append({
                    "request_id": req_id,
                    "token": next_token.item(),
                    "finished": False
                })

        # Clean up finished requests
        for req_id in completed:
            self._free_request(req_id)

        return results

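    def _sample(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
        if temperature == 0:
            # Greedy decoding; keepdim=True matches the shape returned by multinomial,
            # so current_token tensors can be stacked across requests
            return logits.argmax(dim=-1, keepdim=True)
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, 1)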

    def _should_stop(self, req, token: torch.Tensor) -> bool:
        return (
            token.item() == self.model.config.eos_token_id or
            len(req.generated_tokens) >= req.max_tokens
        )

    def _free_request(self, req_id: str):
        req = self.running_requests.pop(req_id)
        self.kv_cache.free(req.cache_id)


@dataclass
class RunningRequest:
    """A request that is currently being decoded."""
    request_id: str
    current_token: torch.Tensor
    position: int
    cache_id: int
    max_tokens: int
    temperature: float
    generated_tokens: List[int]
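
The packed representation built by `prepare_batch` stores all images in one flat sequence and uses `image_offsets` to mark where each request's images begin and end. The sketch below shows how the offsets are built and how a consumer could recover one request's images from them. It uses plain lists in place of tensors, and the helper names are hypothetical; the same slicing applies along dim 0 of `packed_images`.

```python
def build_image_offsets(num_images_per_request):
    """Prefix sums: request i owns packed[offsets[i]:offsets[i + 1]]."""
    offsets = [0]
    for n in num_images_per_request:
        offsets.append(offsets[-1] + n)
    return offsets


def images_for_request(packed_images, offsets, i):
    """Slice one request's images out of the packed sequence."""
    return packed_images[offsets[i]:offsets[i + 1]]


# Three requests carrying 2, 0, and 3 images respectively
counts = [2, 0, 3]
packed = ["a0", "a1", "c0", "c1", "c2"]   # stand-ins for per-image feature tensors
offsets = build_image_offsets(counts)

print(offsets)
for i in range(len(counts)):
    print(i, images_for_request(packed, offsets, i))
```

A text-only request simply gets an empty slice, so no padding images are needed to keep the batch rectangular.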

Performance Optimization

Image Caching and Reuse

"""
Image feature caching system
"""
import asyncio
import hashlib
import threading
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


class ImageFeatureCache:
    """Image feature cache."""

    def __init__(
        self,
        max_cache_size: int = 1000,  # max number of cached images
        cache_device: str = "cpu"  # device that holds the cached features
    ):
        self.max_cache_size = max_cache_size
        self.cache_device = cache_device

        # LRU cache
        self.cache: OrderedDict[str, torch.Tensor] = OrderedDict()
        self.lock = threading.Lock()

        # Statistics
        self.hits = 0
        self.misses = 0

    def get_cache_key(self, image_bytes: bytes) -> str:
        """Compute the cache key for an image from its raw bytes."""
        return hashlib.md5(image_bytes).hexdigest()

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the cached features, or None on a miss."""
        with self.lock:
            if key in self.cache:
                # Move to the end (most recently used)
                self.cache.move_to_end(key)
                self.hits += 1
                return self.cache[key]
            self.misses += 1
            return None

    def put(self, key: str, features: torch.Tensor):
        """Cache the features."""
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
                return

            # Move to the cache device
            features = features.to(self.cache_device)

            # Evict the least recently used entries when at capacity
            while len(self.cache) >= self.max_cache_size:
                self.cache.popitem(last=False)

            self.cache[key] = features

    def get_stats(self) -> Dict[str, float]:
        """Return cache statistics."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache)
        }


class PrefetchingImageLoader:
    """Prefetching image loader."""

    def __init__(
        self,
        vision_encoder: nn.Module,
        cache: ImageFeatureCache,
        num_workers: int = 4
    ):
        self.vision_encoder = vision_encoder
        self.cache = cache

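        # Prefetch queue and worker tasks (asyncio.create_task needs a running
        # event loop, so construct this loader from inside a coroutine)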
        self.prefetch_queue = asyncio.Queue()
        self.workers = []

        for _ in range(num_workers):
            worker = asyncio.create_task(self._prefetch_worker())
            self.workers.append(worker)

    async def _prefetch_worker(self):
        """Prefetch worker coroutine."""
        while True:
            image_bytes, future = await self.prefetch_queue.get()

            try:
                key = self.cache.get_cache_key(image_bytes)

                # Check the cache
                features = self.cache.get(key)

                if features is None:
                    # Load and encode (the synchronous encoder call blocks the event loop while it runs)
                    image = self._load_image(image_bytes)
                    with torch.no_grad():
                        features = self.vision_encoder(image)
                    self.cache.put(key, features)

                future.set_result(features.cuda())

            except Exception as e:
                future.set_exception(e)

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Load an image."""
        from PIL import Image
        import io

        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Preprocessing... (elided in the original; it must produce `processed_image`)
        return processed_image

    async def load(self, image_bytes: bytes) -> torch.Tensor:
        """Load image features, using the cache when possible."""
        key = self.cache.get_cache_key(image_bytes)

        # Check the cache first
        features = self.cache.get(key)
        if features is not None:
            return features.cuda()

        # Enqueue for the prefetch workers
        future = asyncio.get_running_loop().create_future()
        await self.prefetch_queue.put((image_bytes, future))

        return await future

    async def prefetch(self, image_bytes_list: List[bytes]):
        """Prefetch a batch of images."""
        futures = []
        for image_bytes in image_bytes_list:
            key = self.cache.get_cache_key(image_bytes)
            if self.cache.get(key) is None:
                future = asyncio.Future()
                await self.prefetch_queue.put((image_bytes, future))
                futures.append(future)

        # Wait for all prefetches to finish
        if futures:
            await asyncio.gather(*futures, return_exceptions=True)
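
The eviction behavior of `ImageFeatureCache` comes from two `OrderedDict` operations: `move_to_end` on every hit and `popitem(last=False)` when the cache is full. The sketch below isolates that mechanism in a small standalone class. `TinyLRU` is a hypothetical name, strings stand in for feature tensors, and SHA-256 is used for the content key here instead of the MD5 in the listing above (either serves as a non-security content hash).

```python
import hashlib
from collections import OrderedDict


class TinyLRU:
    """Minimal LRU cache with hit/miss counters."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, key):
        if key in self.data:
            self.data.move_to_end(key)      # mark as most recently used
            self.hits += 1
            return self.data[key]
        self.misses += 1
        return None

    def put(self, key, value):
        if key in self.data:
            self.data.move_to_end(key)
            return
        while len(self.data) >= self.capacity:
            self.data.popitem(last=False)   # evict the least recently used entry
        self.data[key] = value


def content_key(image_bytes):
    return hashlib.sha256(image_bytes).hexdigest()


cache = TinyLRU(capacity=2)
key_a, key_b, key_c = (content_key(b) for b in (b"img-a", b"img-b", b"img-c"))

cache.put(key_a, "feat-a")
cache.put(key_b, "feat-b")
cache.get(key_a)                 # touch a, so b becomes least recently used
cache.put(key_c, "feat-c")       # cache is full: b is evicted

b_after = cache.get(key_b)
a_after = cache.get(key_a)
print(b_after, a_after, cache.hits, cache.misses)
```

Keying on the image bytes means that the same image sent by different requests maps to one cache entry, which is what makes the hit rate meaningful for repeated images.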

Multimodal Quantization

Vision-Language Quantization Strategy

"""
Multimodal model quantization
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional


class MultimodalQuantizer:
    """Multimodal model quantizer."""

    def __init__(
        self,
        vision_bits: int = 8,  # bit width for the vision encoder
        language_bits: int = 4,  # bit width for the language model
        projector_bits: int = 8  # bit width for the projector
    ):
        self.vision_bits = vision_bits
        self.language_bits = language_bits
        self.projector_bits = projector_bits

    def quantize_model(self, model: nn.Module) -> nn.Module:
        """Quantize the whole model."""
        # 1. Quantize the vision encoder (usually INT8)
        model.vision_tower = self._quantize_vision(
            model.vision_tower,
            bits=self.vision_bits
        )

        # 2. Quantize the language model (INT4 is an option)
        model.language_model = self._quantize_language(
            model.language_model,
            bits=self.language_bits
        )

        # 3. Quantize the projector
        model.projector = self._quantize_projector(
            model.projector,
            bits=self.projector_bits
        )

        return model

    def _quantize_vision(
        self,
        vision_model: nn.Module,
        bits: int
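    ) -> nn.Module:
        """Quantize the vision encoder."""
        if bits == 8:
            # Dynamic INT8 quantization (PyTorch dynamic quantization runs on CPU backends)
            return torch.quantization.quantize_dynamic(
                vision_model,
                {nn.Linear},
                dtype=torch.qint8
            )
        elif bits == 4:
            # bitsandbytes 4-bit; the _replace_linear_with_4bit helper is not shown in this chapter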
            from bitsandbytes.nn import Linear4bit
            return self._replace_linear_with_4bit(vision_model)
        return vision_model

    def _quantize_language(
        self,
        language_model: nn.Module,
        bits: int
    ) -> nn.Module:
        """Quantize the language model."""
        if bits == 4:
            # GPTQ/AWQ-style 4-bit quantization
            return self._apply_gptq_quantization(language_model)
        elif bits == 8:
            return torch.quantization.quantize_dynamic(
                language_model,
                {nn.Linear},
                dtype=torch.qint8
            )
        return language_model

    def _quantize_projector(
        self,
        projector: nn.Module,
        bits: int
    ) -> nn.Module:
        """Quantize the projector, which is usually kept at higher precision."""
        if bits == 8:
            return torch.quantization.quantize_dynamic(
                projector,
                {nn.Linear},
                dtype=torch.qint8
            )
        return projector

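    def _apply_gptq_quantization(self, model: nn.Module) -> nn.Module:
        """Apply GPTQ quantization."""
        # Requires calibration data.
        # This is a simplified placeholder.
        from auto_gptq import AutoGPTQForCausalLM

        # A real implementation would quantize with calibration data here;
        # as written, the model is returned unchanged.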
        return model


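class MixedPrecisionConfig:
    """Mixed-precision configuration."""

    def __init__(self):
        # Per-component precision
        self.layer_config = {
            # Vision encoder: FP16
            "vision_tower": torch.float16,

            # Projector: keep FP16 to preserve image-text alignment
            "projector": torch.float16,

            # Language model embeddings: FP16
            "embed_tokens": torch.float16,

            # Language model attention: INT8 is an option
            "self_attn": torch.int8,

            # Language model FFN: INT4 is an option
            "mlp": "int4",

            # Output layer: FP16
            "lm_head": torch.float16
        }

    def apply(self, model: nn.Module) -> nn.Module:
        """Apply the mixed-precision configuration."""
        # Patterns are matched as substrings of the module name, so "self_attn"
        # and "mlp" would also match any submodules with those names inside
        # the vision tower.
        for name, module in model.named_modules():
            for pattern, dtype in self.layer_config.items():
                if pattern in name:
                    if dtype == torch.float16:
                        module.half()
                    elif dtype == torch.int8:
                        # INT8 quantization would go here (placeholder)
                        pass
                    elif dtype == "int4":
                        # INT4 quantization would go here (placeholder)
                        pass
        return model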
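
To make the precision trade-off between bit widths concrete, the sketch below applies symmetric absmax quantization to a handful of weights at 8 and 4 bits. This is a generic textbook scheme chosen for illustration; it is not the algorithm used by GPTQ, AWQ, or bitsandbytes, and the weight values are made up.

```python
def quantize_absmax(weights, bits):
    """Symmetric quantization: map [-max|w|, +max|w|] onto signed integers."""
    qmax = 2 ** (bits - 1) - 1                  # 127 for INT8, 7 for INT4
    scale = max(abs(w) for w in weights) / qmax
    if scale == 0:
        scale = 1.0                             # all-zero input: any scale works
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    return [v * scale for v in q]


def max_abs_error(original, restored):
    return max(abs(a - b) for a, b in zip(original, restored))


weights = [0.5, -1.27, 0.03, 1.0]               # made-up example weights

q8, s8 = quantize_absmax(weights, bits=8)
q4, s4 = quantize_absmax(weights, bits=4)
err8 = max_abs_error(weights, dequantize(q8, s8))
err4 = max_abs_error(weights, dequantize(q4, s4))

print("int8:", q8, f"max error {err8:.4f}")
print("int4:", q4, f"max error {err4:.4f}")
```

The rounding error is bounded by half the scale, and the scale grows as the bit width shrinks, which is one reason to keep small but alignment-critical components such as the projector at higher precision.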

Deployment Configuration Example

Kubernetes Deployment

# multimodal-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: multimodal-llm
  labels:
    app: multimodal-llm
spec:
  replicas: 2
  selector:
    matchLabels:
      app: multimodal-llm
  template:
    metadata:
      labels:
        app: multimodal-llm
    spec:
      containers:
      - name: multimodal-server
        image: multimodal-llm:latest
        ports:
        - containerPort: 8000
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "32Gi"
            cpu: "8"
          requests:
            nvidia.com/gpu: 1
            memory: "24Gi"
            cpu: "4"
        env:
        - name: MODEL_PATH
          value: "/models/llava-1.5-7b"
        - name: MAX_BATCH_SIZE
          value: "8"
        - name: MAX_IMAGE_TILES
          value: "6"
        - name: VISION_CACHE_SIZE
          value: "1000"
        volumeMounts:
        - name: model-storage
          mountPath: /models
        - name: cache-storage
          mountPath: /cache
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 120
          periodSeconds: 30
        readinessProbe:
          httpGet:
            path: /ready
            port: 8000
          initialDelaySeconds: 60
          periodSeconds: 10
      volumes:
      - name: model-storage
        persistentVolumeClaim:
          claimName: model-pvc
      - name: cache-storage
        emptyDir:
          medium: Memory
          sizeLimit: 8Gi
      nodeSelector:
        nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
---
apiVersion: v1
kind: Service
metadata:
  name: multimodal-llm-service
spec:
  selector:
    app: multimodal-llm
  ports:
  - port: 80
    targetPort: 8000
  type: ClusterIP
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: multimodal-llm-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: multimodal-llm
  minReplicas: 2
  maxReplicas: 8
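  metrics:
  # The resource metrics API (metrics-server) reports only cpu and memory, so GPU
  # utilization has to come from a custom per-pod metric. DCGM_FI_DEV_GPU_UTIL
  # assumes the NVIDIA DCGM exporter plus a metrics adapter that exposes it to the HPA.
  - type: Pods
    pods:
      metric:
        name: DCGM_FI_DEV_GPU_UTIL
      target:
        type: AverageValue
        averageValue: "70"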
  - type: Pods
    pods:
      metric:
        name: request_queue_length
      target:
        type: AverageValue
        averageValue: "10"
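
For each metric, the HorizontalPodAutoscaler proposes a replica count of ceil(currentReplicas × currentValue / targetValue) and, when several metrics are configured, takes the largest proposal. The sketch below applies that rule to the bounds in this manifest (2 to 8 replicas) and its two targets (70 and 10). The observed metric values are invented, and the controller's tolerance band and stabilization window are ignored.

```python
import math


def desired_replicas(current_replicas, current_value, target_value,
                     min_replicas, max_replicas):
    """Replica count proposed by one metric, clamped to the HPA bounds."""
    proposal = math.ceil(current_replicas * current_value / target_value)
    return max(min_replicas, min(max_replicas, proposal))


def hpa_decision(current_replicas, metrics, min_replicas=2, max_replicas=8):
    """With several metrics, the largest proposal wins."""
    return max(
        desired_replicas(current_replicas, current, target,
                         min_replicas, max_replicas)
        for current, target in metrics
    )


# (observed average, target) pairs: queue length 25 vs 10, utilization 60 vs 70
scaled = hpa_decision(current_replicas=2, metrics=[(25, 10), (60, 70)])

# A queue length of 30 on 4 replicas would ask for 12, which is capped at 8
capped = hpa_decision(current_replicas=4, metrics=[(30, 10)])

print(scaled, capped)
```

In the first case the queue-length metric drives the scale-up to 5 replicas even though utilization is below its target, which is why a queue-based metric is a useful complement for bursty inference traffic.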

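Summary

This chapter covered the core techniques of multimodal inference:

  1. Multimodal architectures: mainstream designs such as LLaVA, Flamingo, and Q-Former
  2. Vision encoding optimization: token pruning, feature compression, dynamic resolution
  3. Video understanding: temporal modeling and frame sampling strategies
  4. Inference serving: heterogeneous batching and continuous batching
  5. Performance optimization: feature caching, prefetching, quantization

Key optimizations:

  • Compress the number of visual tokens with a Perceiver or Q-Former
  • Use dynamic-resolution processing to adapt to different images
  • Cache image features to avoid re-encoding repeated images
  • Use mixed-precision quantization to balance performance and accuracy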