03-System Design Interview Questions

Overview

This article focuses on system design interview questions for AI infrastructure, covering the design and implementation of core systems such as distributed training platforms, model serving platforms, feature platforms, and vector databases.

System Design Methodology

Design Process Framework

System Design Four-Step Method

Step 1: Clarify requirements (5 min)
├── Functional requirements: core features, user scenarios
├── Non-functional requirements: performance, availability, consistency
├── Scale estimation: QPS, data volume, user count
└── Constraints: tech stack, cost, timeline

Step 2: High-level design (10 min)
├── Core components: identify the major modules
├── Data flow: how a request moves through the system
├── Interface design: API definitions
└── Technology choices: storage, compute, communication

Step 3: Detailed design (15 min)
├── Deep dive into the core modules
├── Data model design
├── Key algorithms
└── Design for scalability

Step 4: Optimization discussion (10 min)
├── Bottleneck analysis
├── Fault-tolerance design
├── Monitoring and alerting
└── Evolution path
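Step 1's scale estimation is usually a quick back-of-envelope calculation done aloud. A minimal sketch of the kind of arithmetic interviewers expect (all numbers here are illustrative assumptions, not figures from this article):

```python
import math

def estimate_gpu_count(qps: int, batch_latency_ms: float,
                       batch_size: int, util_target: float = 0.6) -> int:
    """Rough GPU count needed to serve `qps`, derated by a utilization target."""
    batches_per_sec = 1000.0 / batch_latency_ms * util_target
    qps_per_gpu = batches_per_sec * batch_size
    return math.ceil(qps / qps_per_gpu)

# e.g. 10k QPS, 50 ms per batch of 16, 60% target utilization
print(estimate_gpu_count(10_000, 50, 16))  # → 53
```

The point of the exercise is not the exact number but showing which inputs (latency, batch size, utilization headroom) drive capacity.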

Question 1: Design a Distributed Model Training Platform

Problem Statement

Design a platform for large-scale distributed training that supports:

  • Multiple parallelism strategies (data parallelism, model parallelism, pipeline parallelism)
  • Elastic training (recovery from node failures, dynamic scaling)
  • Multi-tenant resource isolation
  • Training job lifecycle management

Requirements Analysis

Functional requirements:

  • Submit training jobs (code, data, configuration)
  • Resource scheduling (GPU, memory, storage)
  • Distributed training coordination
  • Training monitoring and logging
  • Checkpoint management
  • Job lifecycle management

Non-functional requirements:

  • 1000+ GPUs training concurrently
  • Job startup latency < 5 minutes
  • Automatic recovery from a single-node failure < 10 minutes
  • Resource utilization > 80%
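These targets can be sanity-checked with rough arithmetic before drawing any boxes. A sketch estimating checkpoint size and write time for a large model (the parameter count, optimizer layout, and bandwidth figures are illustrative assumptions):

```python
def checkpoint_size_gb(params_billion: float) -> float:
    """fp16 weights plus Adam state (fp32 master weights + two fp32 moments)."""
    weights_gb = params_billion * 2          # 2 bytes per parameter
    optimizer_gb = params_billion * 4 * 3    # 4 bytes per parameter x 3 fp32 tensors
    return weights_gb + optimizer_gb

def save_minutes(size_gb: float, aggregate_gb_per_s: float) -> float:
    """Wall-clock minutes to persist one checkpoint at the given bandwidth."""
    return size_gb / aggregate_gb_per_s / 60

size = checkpoint_size_gb(70)                          # a 70B-parameter model
print(round(size), round(save_minutes(size, 10), 1))   # → 980 1.6
```

Numbers like these justify design decisions later, e.g. why checkpointing must be asynchronous and why aggregate storage bandwidth matters.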

High-Level Architecture

Distributed Training Platform Architecture (top to bottom)

┌ User Interface Layer       Web Console | CLI Tool | SDK | REST API | Notebook
├ Job Management Layer       Job Submission Service | Job Scheduling Service |
│                            Lifecycle Manager | Queue Service
├ Resource Management Layer  Resource Pool Manager | Quota Service |
│                            Resource Scheduling Service | Autoscaling Service
├ Training Runtime Layer     Training Controller | Communication Service |
│                            Checkpoint Manager | Failure Recovery Service
└ Infrastructure Layer       Kubernetes | GPU Operator | Distributed Storage |
                             High-speed Network (RDMA)

Core Component Design

1. Job Submission and Scheduling

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
import asyncio
from datetime import datetime


class QuotaExceededException(Exception):
    """Raised when a namespace has exhausted its resource quota."""


class ParallelStrategy(Enum):
    """Parallelism strategy"""
    DATA_PARALLEL = "data_parallel"
    MODEL_PARALLEL = "model_parallel"
    PIPELINE_PARALLEL = "pipeline_parallel"
    HYBRID = "hybrid"


class JobStatus(Enum):
    """Job status"""
    PENDING = "pending"
    QUEUED = "queued"
    SCHEDULING = "scheduling"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class ResourceRequirements:
    """Resource requirements"""
    gpu_count: int
    gpu_type: str = "A100"
    gpu_memory_gb: int = 80
    cpu_count: int = 8
    memory_gb: int = 64
    storage_gb: int = 100
    network_bandwidth_gbps: int = 100


@dataclass
class TrainingConfig:
    """Training configuration"""
    framework: str  # pytorch, tensorflow
    parallel_strategy: ParallelStrategy
    world_size: int
    batch_size: int
    gradient_accumulation_steps: int = 1
    mixed_precision: bool = True
    checkpoint_interval: int = 1000
    max_steps: Optional[int] = None
    max_epochs: Optional[int] = None


@dataclass
class TrainingJob:
    """Training job"""
    job_id: str
    name: str
    user_id: str
    namespace: str

    # Code and data
    code_path: str
    data_path: str
    output_path: str

    # Configuration
    config: TrainingConfig
    resources: ResourceRequirements

    # Status
    status: JobStatus = JobStatus.PENDING
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None

    # Runtime information
    worker_pods: List[str] = field(default_factory=list)
    current_step: int = 0
    current_epoch: int = 0

    # Metadata
    priority: int = 0
    preemptible: bool = True
    tags: Dict[str, str] = field(default_factory=dict)


class JobScheduler:
    """Job scheduler"""

    def __init__(self, resource_manager: 'ResourceManager'):
        self.resource_manager = resource_manager
        self.pending_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self.running_jobs: Dict[str, TrainingJob] = {}

    def _priority_key(self, job: TrainingJob) -> tuple:
        # Smaller tuples dequeue first: higher priority, then earlier submit time.
        # job_id breaks ties so the queue never tries to compare TrainingJob objects.
        return (-job.priority, job.created_at.timestamp(), job.job_id)

    async def submit_job(self, job: TrainingJob) -> str:
        """Submit a job"""
        # Validate the job configuration
        self._validate_job(job)

        # Check the namespace quota
        if not await self.resource_manager.check_quota(
            job.namespace, job.resources
        ):
            raise QuotaExceededException(f"Quota exceeded for {job.namespace}")

        # Enqueue; smaller keys are scheduled first
        await self.pending_queue.put((self._priority_key(job), job))

        job.status = JobStatus.QUEUED
        return job.job_id

    async def schedule_loop(self):
        """Scheduling loop"""
        while True:
            try:
                # Pop the next job to schedule
                priority_key, job = await asyncio.wait_for(
                    self.pending_queue.get(),
                    timeout=1.0
                )

                # Try to schedule it
                job.status = JobStatus.SCHEDULING
                scheduled = await self._try_schedule(job)

                if not scheduled:
                    # Put it back and back off briefly
                    await self.pending_queue.put((priority_key, job))
                    job.status = JobStatus.QUEUED
                    await asyncio.sleep(5)

            except asyncio.TimeoutError:
                continue
            except Exception as e:
                print(f"Schedule error: {e}")

    async def _try_schedule(self, job: TrainingJob) -> bool:
        """Try to schedule a job"""
        # Query the available resources
        available = await self.resource_manager.get_available_resources(
            job.resources.gpu_type
        )

        required_gpus = job.config.world_size

        if available.gpu_count < required_gpus:
            # Not enough free GPUs; see whether preemption can help
            if job.priority > 0:
                preempted = await self._try_preempt(job, required_gpus)
                if not preempted:
                    return False
            else:
                return False

        # Allocate resources
        allocation = await self.resource_manager.allocate(
            job.job_id,
            job.resources,
            job.config.world_size
        )

        if not allocation:
            return False

        # Launch training
        await self._start_training(job, allocation)

        job.status = JobStatus.RUNNING
        job.started_at = datetime.now()
        self.running_jobs[job.job_id] = job

        return True

    async def _try_preempt(self, job: TrainingJob, required_gpus: int) -> bool:
        """Try to preempt lower-priority jobs"""
        preemptible_jobs = [
            j for j in self.running_jobs.values()
            if j.preemptible and j.priority < job.priority
        ]

        # Preempt the lowest-priority jobs first
        preemptible_jobs.sort(key=lambda j: j.priority)

        freed_gpus = 0
        jobs_to_preempt = []

        for j in preemptible_jobs:
            jobs_to_preempt.append(j)
            freed_gpus += j.config.world_size
            if freed_gpus >= required_gpus:
                break

        if freed_gpus < required_gpus:
            return False

        # Carry out the preemption
        for j in jobs_to_preempt:
            await self._preempt_job(j)

        return True

    async def _preempt_job(self, job: TrainingJob):
        """Preempt a running job"""
        # Save a checkpoint first (delegated to the checkpoint manager below)
        await self._save_checkpoint(job)

        # Release its resources
        await self.resource_manager.release(job.job_id)

        # Re-enqueue it
        del self.running_jobs[job.job_id]
        job.status = JobStatus.QUEUED
        await self.pending_queue.put((self._priority_key(job), job))

    async def _start_training(
        self,
        job: TrainingJob,
        allocation: 'ResourceAllocation'
    ):
        """Launch training"""
        # Render the training spec
        training_spec = self._generate_training_spec(job, allocation)

        # Create the Kubernetes resources
        await self._create_k8s_resources(training_spec)

    def _validate_job(self, job: TrainingJob):
        """Validate the job configuration"""
        if job.config.world_size <= 0:
            raise ValueError("world_size must be positive")

        if job.resources.gpu_count <= 0:
            raise ValueError("gpu_count must be positive")
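The scheduler's queue ordering can be exercised standalone. A minimal sketch showing that higher-priority jobs dequeue first and equal priorities fall back to FIFO (the job names and timestamps are hypothetical; a job_id tiebreaker is included so the tuples stay comparable when priority and submit time tie):

```python
import asyncio

async def demo() -> list:
    q: asyncio.PriorityQueue = asyncio.PriorityQueue()
    t0 = 1_700_000_000.0  # illustrative submit timestamps
    # Key is (-priority, submit_time, job_id): smaller tuples dequeue first
    for job_id, priority, ts in [
        ("low-early", 0, t0),
        ("high-late", 10, t0 + 60),
        ("low-late", 0, t0 + 60),
    ]:
        await q.put(((-priority, ts, job_id), job_id))
    order = []
    while not q.empty():
        _, job_id = await q.get()
        order.append(job_id)
    return order

print(asyncio.run(demo()))  # → ['high-late', 'low-early', 'low-late']
```

Without the job_id tiebreaker, two entries with identical priority and timestamp would force the queue to compare the payload objects, which raises TypeError for non-comparable job objects.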

2. Elastic Training Support

import json
from datetime import datetime
from typing import Dict, List


class ElasticTrainingController:
    """Elastic training controller"""

    def __init__(self, etcd_client, k8s_client):
        self.etcd = etcd_client
        self.k8s = k8s_client
        self.rendezvous_handlers: Dict[str, 'RendezvousHandler'] = {}

    async def register_job(self, job: TrainingJob):
        """Register an elastic training job"""
        handler = RendezvousHandler(
            job_id=job.job_id,
            min_workers=max(1, job.config.world_size // 2),
            max_workers=job.config.world_size * 2,
            etcd=self.etcd
        )
        self.rendezvous_handlers[job.job_id] = handler

    async def handle_worker_failure(self, job_id: str, worker_id: str):
        """Handle a worker failure"""
        handler = self.rendezvous_handlers.get(job_id)
        if not handler:
            return

        # Mark the worker as failed
        await handler.mark_worker_failed(worker_id)

        # Decide whether a new rendezvous is needed
        active_workers = await handler.get_active_workers()

        if len(active_workers) < handler.min_workers:
            # Wait for new workers to join, or start new ones
            await self._scale_up_workers(job_id)
        else:
            # Trigger a new rendezvous with the surviving workers
            await handler.trigger_rendezvous()

    async def handle_worker_join(self, job_id: str, worker_id: str):
        """Handle a worker joining"""
        handler = self.rendezvous_handlers.get(job_id)
        if not handler:
            return

        await handler.register_worker(worker_id)

        # Check whether the start condition is met
        active_workers = await handler.get_active_workers()

        if len(active_workers) >= handler.min_workers:
            await handler.trigger_rendezvous()

    async def _scale_up_workers(self, job_id: str):
        """Scale up workers"""
        job = await self._get_job(job_id)
        handler = self.rendezvous_handlers[job_id]

        current_workers = len(await handler.get_active_workers())
        target_workers = min(
            handler.max_workers,
            current_workers + 1
        )

        # Create new worker Pods
        await self._create_worker_pod(job, target_workers - current_workers)


class RendezvousHandler:
    """Rendezvous handler"""

    def __init__(
        self,
        job_id: str,
        min_workers: int,
        max_workers: int,
        etcd
    ):
        self.job_id = job_id
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.etcd = etcd
        self.rendezvous_version = 0

    async def register_worker(self, worker_id: str):
        """Register a worker"""
        key = f"/rendezvous/{self.job_id}/workers/{worker_id}"
        value = {
            "worker_id": worker_id,
            "status": "active",
            "registered_at": datetime.now().isoformat()
        }
        await self.etcd.put(key, json.dumps(value))

    async def mark_worker_failed(self, worker_id: str):
        """Mark a worker as failed"""
        key = f"/rendezvous/{self.job_id}/workers/{worker_id}"
        await self.etcd.delete(key)

    async def get_active_workers(self) -> List[str]:
        """List the active workers"""
        prefix = f"/rendezvous/{self.job_id}/workers/"
        result = await self.etcd.get_prefix(prefix)

        workers = []
        for key, value in result:
            worker_info = json.loads(value)
            if worker_info.get("status") == "active":
                workers.append(worker_info["worker_id"])

        return workers

    async def trigger_rendezvous(self):
        """Trigger a rendezvous"""
        self.rendezvous_version += 1

        workers = await self.get_active_workers()

        # Publish the new rendezvous membership
        rendezvous_info = {
            "version": self.rendezvous_version,
            "workers": workers,
            "world_size": len(workers),
            "timestamp": datetime.now().isoformat()
        }

        key = f"/rendezvous/{self.job_id}/current"
        await self.etcd.put(key, json.dumps(rendezvous_info))

        # Notify every worker
        await self._notify_workers(workers, rendezvous_info)

    async def _notify_workers(
        self,
        workers: List[str],
        rendezvous_info: dict
    ):
        """Tell workers to reconfigure themselves"""
        for i, worker_id in enumerate(workers):
            worker_config = {
                **rendezvous_info,
                "rank": i,
                "local_rank": 0,  # simplified: one GPU per worker
            }

            key = f"/rendezvous/{self.job_id}/workers/{worker_id}/config"
            await self.etcd.put(key, json.dumps(worker_config))
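After each membership change, the surviving workers must receive dense ranks 0..N-1 before the collective can restart. A minimal standalone sketch of that reassignment step (the worker names are hypothetical):

```python
from typing import Dict, List

def assign_ranks(active_workers: List[str]) -> Dict[str, int]:
    """Sort for a deterministic order, then hand out contiguous ranks."""
    return {w: rank for rank, w in enumerate(sorted(active_workers))}

# worker-1 has failed; the survivors get dense ranks starting at 0
print(assign_ranks(["worker-2", "worker-0", "worker-3"]))
```

Determinism matters here: every worker must compute (or receive) the same rank assignment, otherwise the new process group cannot form.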

3. Checkpoint Management

import os
import json
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


class CheckpointManager:
    """Checkpoint manager"""

    def __init__(
        self,
        storage_backend: 'StorageBackend',
        cache_dir: str = "/tmp/ckpt_cache"
    ):
        self.storage = storage_backend
        self.cache_dir = cache_dir
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def save_checkpoint(
        self,
        job_id: str,
        step: int,
        state_dict: dict,
        metadata: Optional[dict] = None
    ) -> str:
        """Save a checkpoint"""
        # Build the checkpoint path
        ckpt_path = f"{job_id}/step_{step}"

        # Write in a worker thread so the event loop is not blocked
        await asyncio.get_running_loop().run_in_executor(
            self.executor,
            self._save_to_storage,
            ckpt_path,
            state_dict,
            metadata
        )

        # Record the checkpoint in the metadata store
        await self._record_checkpoint(job_id, step, ckpt_path)

        # Garbage-collect old checkpoints
        await self._cleanup_old_checkpoints(job_id)

        return ckpt_path

    def _save_to_storage(
        self,
        path: str,
        state_dict: dict,
        metadata: Optional[dict]
    ):
        """Persist the checkpoint"""
        import torch

        os.makedirs(path, exist_ok=True)

        # Model state
        torch.save(state_dict.get("model"), f"{path}/model.pt")

        # Optimizer state
        torch.save(state_dict.get("optimizer"), f"{path}/optimizer.pt")

        # LR-scheduler state
        if "scheduler" in state_dict:
            torch.save(state_dict["scheduler"], f"{path}/scheduler.pt")

        # Metadata
        if metadata:
            with open(f"{path}/metadata.json", "w") as f:
                json.dump(metadata, f)

        # Upload to the distributed store
        self.storage.upload_directory(path)

    async def load_checkpoint(
        self,
        job_id: str,
        step: Optional[int] = None
    ) -> Optional[dict]:
        """Load a checkpoint"""
        # Default to the newest checkpoint
        if step is None:
            step = await self._get_latest_checkpoint_step(job_id)

        if step is None:
            return None

        ckpt_path = f"{job_id}/step_{step}"

        # Download into the local cache
        local_path = os.path.join(self.cache_dir, ckpt_path)
        await asyncio.get_running_loop().run_in_executor(
            self.executor,
            self.storage.download_directory,
            ckpt_path,
            local_path
        )

        # Deserialize the state
        return await asyncio.get_running_loop().run_in_executor(
            self.executor,
            self._load_from_local,
            local_path
        )

    def _load_from_local(self, local_path: str) -> dict:
        """Load from the local cache"""
        import torch

        state_dict = {}

        model_path = os.path.join(local_path, "model.pt")
        if os.path.exists(model_path):
            state_dict["model"] = torch.load(model_path)

        optimizer_path = os.path.join(local_path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            state_dict["optimizer"] = torch.load(optimizer_path)

        scheduler_path = os.path.join(local_path, "scheduler.pt")
        if os.path.exists(scheduler_path):
            state_dict["scheduler"] = torch.load(scheduler_path)

        meta_path = os.path.join(local_path, "metadata.json")
        if os.path.exists(meta_path):
            with open(meta_path, "r") as f:
                state_dict["metadata"] = json.load(f)

        return state_dict

    async def _cleanup_old_checkpoints(
        self,
        job_id: str,
        keep_last: int = 3
    ):
        """Delete old checkpoints"""
        checkpoints = await self._list_checkpoints(job_id)

        if len(checkpoints) <= keep_last:
            return

        # Drop everything except the newest keep_last
        to_delete = checkpoints[:-keep_last]
        for ckpt in to_delete:
            await asyncio.get_running_loop().run_in_executor(
                self.executor,
                self.storage.delete_directory,
                ckpt["path"]
            )
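The retention policy in `_cleanup_old_checkpoints` reduces to a small pure function, which makes it easy to reason about and unit-test in an interview; a sketch:

```python
from typing import List

def checkpoints_to_delete(steps: List[int], keep_last: int = 3) -> List[int]:
    """Return the checkpoint steps to delete, keeping the newest `keep_last`."""
    ordered = sorted(steps)
    return ordered[:-keep_last] if len(ordered) > keep_last else []

print(checkpoints_to_delete([1000, 2000, 3000, 4000, 5000]))  # → [1000, 2000]
```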

Question 2: Design a Model Inference Serving Platform

Problem Statement

Design a high-performance, highly available model inference serving platform that supports:

  • Multi-model deployment and version management
  • Dynamic batching
  • Autoscaling
  • A/B testing and canary releases
  • Hot model updates

Requirements Analysis

Functional requirements:

  • Model deployment and decommissioning
  • Multi-version management
  • Traffic distribution
  • Inference request handling
  • Monitoring and logging

Non-functional requirements:

  • P99 latency < 100 ms
  • Availability > 99.99%
  • 10,000+ QPS
  • Model switching transparent to callers
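An availability target such as 99.99% is easier to discuss once translated into a concrete error budget; a quick calculation (a 30-day month is assumed):

```python
def downtime_minutes(availability: float, days: int = 30) -> float:
    """Allowed downtime within a `days`-day window at the given availability."""
    return days * 24 * 60 * (1 - availability)

print(round(downtime_minutes(0.9999), 1))  # → 4.3
```

About four minutes of downtime a month rules out slow manual failover, which motivates the redundancy and health-checking in the design below.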

Architecture Design

Model Inference Serving Platform (top to bottom)

┌ Access Layer     Load Balancer (L4/L7) | API Gateway | AuthN/AuthZ |
│                  Rate Limiting & Circuit Breaking
├ Routing Layer    Service Discovery | Traffic Routing | A/B Test Controller |
│                  Canary Release Controller
├ Inference Layer  Inference instance cluster: replica Pods per model version,
│                  e.g. Model A v1.0 (3 Pods), Model A v1.1 (2 Pods),
│                  Model B v2.0 (2 Pods), Model C v1.0 (1 Pod)
└ Control Layer    Model Registry | Deployment Management | Version Management |
                   Autoscaling | Monitoring & Alerting

Core Component Design

1. Dynamic Batching Service

import asyncio
from collections import defaultdict
import time


@dataclass
class InferenceRequest:
    """推理请求"""
    request_id: str
    model_name: str
    model_version: str
    inputs: Dict[str, np.ndarray]
    created_at: float = field(default_factory=time.time)
    future: asyncio.Future = field(default_factory=asyncio.Future)


@dataclass
class InferenceBatch:
    """推理批次"""
    requests: List[InferenceRequest]
    model_name: str
    model_version: str
    created_at: float = field(default_factory=time.time)


class DynamicBatcher:
    """动态 Batching 服务"""

    def __init__(
        self,
        max_batch_size: int = 32,
        max_wait_time_ms: float = 10.0,
        preferred_batch_sizes: List[int] = None
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time_ms / 1000.0
        self.preferred_batch_sizes = preferred_batch_sizes or [1, 2, 4, 8, 16, 32]

        # 请求队列:model_key -> queue
        self.request_queues: Dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

        # 推理引擎
        self.inference_engines: Dict[str, 'InferenceEngine'] = {}

    async def submit_request(self, request: InferenceRequest) -> dict:
        """提交推理请求"""
        model_key = f"{request.model_name}:{request.model_version}"

        # 加入队列
        await self.request_queues[model_key].put(request)

        # 等待结果
        return await request.future

    async def start_batching_loop(self, model_key: str):
        """启动 Batching 循环"""
        queue = self.request_queues[model_key]
        engine = self.inference_engines[model_key]

        while True:
            batch = await self._collect_batch(queue)

            if batch:
                # 执行推理
                try:
                    results = await engine.infer_batch(batch)

                    # 分发结果
                    for request, result in zip(batch.requests, results):
                        request.future.set_result(result)

                except Exception as e:
                    # 设置异常
                    for request in batch.requests:
                        request.future.set_exception(e)

    async def _collect_batch(self, queue: asyncio.Queue) -> Optional[InferenceBatch]:
        """收集批次"""
        requests = []

        # 等待第一个请求
        try:
            first_request = await asyncio.wait_for(
                queue.get(),
                timeout=1.0
            )
            requests.append(first_request)
        except asyncio.TimeoutError:
            return None

        batch_start_time = time.time()

        # 收集更多请求
        while len(requests) < self.max_batch_size:
            remaining_time = self.max_wait_time - (time.time() - batch_start_time)

            if remaining_time <= 0:
                break

            # 检查是否达到首选批次大小
            if len(requests) in self.preferred_batch_sizes:
                # 快速检查是否有更多请求
                try:
                    request = await asyncio.wait_for(
                        queue.get(),
                        timeout=0.001  # 1ms 快速检查
                    )
                    requests.append(request)
                except asyncio.TimeoutError:
                    break
            else:
                try:
                    request = await asyncio.wait_for(
                        queue.get(),
                        timeout=remaining_time
                    )
                    requests.append(request)
                except asyncio.TimeoutError:
                    break

        return InferenceBatch(
            requests=requests,
            model_name=requests[0].model_name,
            model_version=requests[0].model_version
        )


class InferenceEngine:
    """推理引擎"""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda:0",
        optimization_level: int = 3
    ):
        self.model_path = model_path
        self.device = device
        self.model = None
        self.optimization_level = optimization_level

    async def load_model(self):
        """加载模型"""
        import torch

        # 加载模型
        self.model = torch.jit.load(self.model_path)
        self.model.to(self.device)
        self.model.eval()

        # 优化模型
        if self.optimization_level >= 2:
            self.model = torch.jit.optimize_for_inference(self.model)

    async def infer_batch(self, batch: InferenceBatch) -> List[dict]:
        """批量推理"""
        import torch

        # 合并输入
        batched_inputs = self._batch_inputs(batch.requests)

        # 推理
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                outputs = self.model(**batched_inputs)

        # 拆分输出
        return self._unbatch_outputs(outputs, len(batch.requests))

    def _batch_inputs(self, requests: List[InferenceRequest]) -> dict:
        """合并输入"""
        import torch

        batched = {}
        for key in requests[0].inputs.keys():
            arrays = [r.inputs[key] for r in requests]
            batched[key] = torch.from_numpy(
                np.stack(arrays)
            ).to(self.device)

        return batched

    def _unbatch_outputs(self, outputs, batch_size: int) -> List[dict]:
        """拆分输出"""
        results = []

        for i in range(batch_size):
            result = {}
            for key, tensor in outputs.items():
                result[key] = tensor[i].cpu().numpy()
            results.append(result)

        return results
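Note that `_batch_inputs` stacks request arrays with `np.stack`, which requires every request in a batch to share the same input shape. For variable-length inputs (e.g. token sequences) the arrays must be padded before stacking. A minimal sketch — the `pad_and_stack` helper, zero padding, and the validity mask are assumptions for illustration, not part of the engine above:

```python
import numpy as np

def pad_and_stack(arrays, pad_value=0):
    """Pad 1-D arrays to the longest length, then stack into one batch.

    Returns the batch plus a boolean mask marking real (non-padded) positions,
    which downstream models typically consume as an attention mask.
    """
    max_len = max(a.shape[0] for a in arrays)
    batch = np.full((len(arrays), max_len), pad_value, dtype=arrays[0].dtype)
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for i, a in enumerate(arrays):
        batch[i, : a.shape[0]] = a
        mask[i, : a.shape[0]] = True
    return batch, mask

batch, mask = pad_and_stack([np.array([1, 2, 3]), np.array([4, 5])])
# batch -> [[1, 2, 3], [4, 5, 0]]; mask marks the two real tokens in row 1
```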

2. Traffic Routing and A/B Testing

from abc import ABC, abstractmethod
import random
import hashlib


@dataclass
class TrafficRule:
    """流量规则"""
    rule_id: str
    model_name: str
    conditions: List['Condition']
    actions: List['Action']
    priority: int = 0


@dataclass
class Condition:
    """条件"""
    field: str  # header.user_id, query.region, etc.
    operator: str  # eq, ne, in, contains, regex
    value: Any


@dataclass
class Action:
    """动作"""
    type: str  # route, ab_test, canary
    config: dict


class TrafficRouter:
    """流量路由器"""

    def __init__(self):
        self.rules: Dict[str, List[TrafficRule]] = defaultdict(list)
        self.model_versions: Dict[str, List['ModelVersion']] = {}

    def add_rule(self, rule: TrafficRule):
        """添加路由规则"""
        self.rules[rule.model_name].append(rule)
        # 按优先级排序
        self.rules[rule.model_name].sort(key=lambda r: -r.priority)

    def route(self, request: 'InferenceRequest', context: dict) -> str:
        """路由请求到目标版本"""
        rules = self.rules.get(request.model_name, [])

        for rule in rules:
            if self._match_conditions(rule.conditions, context):
                return self._apply_actions(rule.actions, request, context)

        # 默认路由到稳定版本
        return self._get_stable_version(request.model_name)

    def _match_conditions(
        self,
        conditions: List[Condition],
        context: dict
    ) -> bool:
        """匹配条件"""
        for condition in conditions:
            value = self._get_field_value(condition.field, context)

            if condition.operator == "eq":
                if value != condition.value:
                    return False
            elif condition.operator == "ne":
                if value == condition.value:
                    return False
            elif condition.operator == "in":
                if value not in condition.value:
                    return False
            elif condition.operator == "contains":
                if condition.value not in str(value):
                    return False

        return True

    def _apply_actions(
        self,
        actions: List[Action],
        request: 'InferenceRequest',
        context: dict
    ) -> str:
        """应用动作"""
        for action in actions:
            if action.type == "route":
                return action.config["version"]

            elif action.type == "ab_test":
                return self._ab_test_route(action.config, context)

            elif action.type == "canary":
                return self._canary_route(action.config, context)

        return self._get_stable_version(request.model_name)

    def _ab_test_route(self, config: dict, context: dict) -> str:
        """A/B 测试路由"""
        # 基于用户 ID 的一致性哈希
        user_id = context.get("user_id", str(uuid.uuid4()))
        hash_value = int(hashlib.md5(user_id.encode()).hexdigest(), 16)

        # 计算分桶
        bucket = hash_value % 100

        cumulative = 0
        for variant in config["variants"]:
            cumulative += variant["weight"]
            if bucket < cumulative:
                return variant["version"]

        return config["variants"][0]["version"]

    def _canary_route(self, config: dict, context: dict) -> str:
        """金丝雀发布路由"""
        canary_weight = config.get("weight", 10)  # 默认 10%

        if random.randint(1, 100) <= canary_weight:
            return config["canary_version"]

        return config["stable_version"]
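The A/B route above hashes the user id into one of 100 buckets, so the same user always lands on the same variant while traffic still splits by weight. A standalone sketch of that bucketing logic (variant names and weights here are illustrative):

```python
import hashlib

def ab_bucket(user_id: str, variants):
    """Deterministically map a user to a variant.

    variants: list of (version, weight) pairs whose weights sum to 100.
    The same user_id always falls into the same bucket, so an experiment
    never flips a user between variants across requests.
    """
    bucket = int(hashlib.md5(user_id.encode()).hexdigest(), 16) % 100
    cumulative = 0
    for version, weight in variants:
        cumulative += weight
        if bucket < cumulative:
            return version
    return variants[0][0]  # fallback to control if weights sum to < 100

variants = [("v1", 50), ("v2", 50)]
assert ab_bucket("user-42", variants) == ab_bucket("user-42", variants)
```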


class ABTestManager:
    """A/B 测试管理器"""

    def __init__(self, metrics_store: 'MetricsStore'):
        self.metrics_store = metrics_store
        self.experiments: Dict[str, 'Experiment'] = {}

    async def create_experiment(
        self,
        name: str,
        model_name: str,
        variants: List[dict],
        metrics: List[str],
        traffic_allocation: int = 100
    ) -> str:
        """创建实验"""
        experiment = Experiment(
            experiment_id=str(uuid.uuid4()),
            name=name,
            model_name=model_name,
            variants=variants,
            metrics=metrics,
            traffic_allocation=traffic_allocation,
            status="running",
            created_at=datetime.now()
        )

        self.experiments[experiment.experiment_id] = experiment

        # 创建流量规则
        rule = self._create_traffic_rule(experiment)
        # ... 注册规则

        return experiment.experiment_id

    async def get_experiment_results(
        self,
        experiment_id: str
    ) -> 'ExperimentResults':
        """获取实验结果"""
        experiment = self.experiments[experiment_id]

        results = ExperimentResults(
            experiment_id=experiment_id,
            variants=[]
        )

        for variant in experiment.variants:
            variant_metrics = {}

            for metric in experiment.metrics:
                # 获取指标数据
                data = await self.metrics_store.get_metric(
                    model_name=experiment.model_name,
                    version=variant["version"],
                    metric=metric
                )

                variant_metrics[metric] = {
                    "mean": np.mean(data),
                    "std": np.std(data),
                    "p50": np.percentile(data, 50),
                    "p99": np.percentile(data, 99),
                    "sample_size": len(data)
                }

            results.variants.append({
                "version": variant["version"],
                "metrics": variant_metrics
            })

        # 计算统计显著性
        results.significance = self._calculate_significance(results)

        return results

    def _calculate_significance(
        self,
        results: 'ExperimentResults'
    ) -> dict:
        """计算统计显著性"""
        from scipy import stats

        significance = {}

        if len(results.variants) < 2:
            return significance

        control = results.variants[0]

        for variant in results.variants[1:]:
            for metric in control["metrics"].keys():
                # T 检验
                t_stat, p_value = stats.ttest_ind_from_stats(
                    control["metrics"][metric]["mean"],
                    control["metrics"][metric]["std"],
                    control["metrics"][metric]["sample_size"],
                    variant["metrics"][metric]["mean"],
                    variant["metrics"][metric]["std"],
                    variant["metrics"][metric]["sample_size"]
                )

                significance[f"{variant['version']}_{metric}"] = {
                    "t_statistic": t_stat,
                    "p_value": p_value,
                    "significant": p_value < 0.05
                }

        return significance
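The t-test above records `sample_size` per variant because a fixed effect only becomes significant with enough samples. A stdlib sketch using the normal approximation (illustrative numbers; the production path uses `scipy.stats.ttest_ind_from_stats`):

```python
import math

def two_sided_p(mean1, std1, n1, mean2, std2, n2):
    """Two-sided p-value for a difference of means via the normal
    approximation (close to the unpaired t-test for large samples)."""
    se = math.sqrt(std1 ** 2 / n1 + std2 ** 2 / n2)
    z = abs(mean1 - mean2) / se
    return 2 * (1 - 0.5 * (1 + math.erf(z / math.sqrt(2))))

# The same 2 ms latency gap: significant at 5000 samples per arm, not at 50.
p_big = two_sided_p(120.0, 15.0, 5000, 118.0, 15.0, 5000)
p_small = two_sided_p(120.0, 15.0, 50, 118.0, 15.0, 50)
assert p_big < 0.05 < p_small
```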

3. Model Hot Swapping

class ModelHotSwapper:
    """模型热更新器"""

    def __init__(
        self,
        model_registry: 'ModelRegistry',
        inference_engines: Dict[str, InferenceEngine]
    ):
        self.model_registry = model_registry
        self.inference_engines = inference_engines
        self.update_lock = asyncio.Lock()

    async def hot_swap(
        self,
        model_name: str,
        from_version: str,
        to_version: str,
        strategy: str = "rolling"
    ):
        """热更新模型"""
        if strategy == "rolling":
            await self._rolling_update(model_name, from_version, to_version)
        elif strategy == "blue_green":
            await self._blue_green_update(model_name, from_version, to_version)
        elif strategy == "canary":
            await self._canary_update(model_name, from_version, to_version)

    async def _rolling_update(
        self,
        model_name: str,
        from_version: str,
        to_version: str
    ):
        """滚动更新"""
        model_key = f"{model_name}:{from_version}"
        new_model_key = f"{model_name}:{to_version}"

        # 加载新模型
        new_engine = InferenceEngine(
            model_path=await self.model_registry.get_model_path(
                model_name, to_version
            )
        )
        await new_engine.load_model()

        # 预热新模型
        await self._warmup_model(new_engine)

        async with self.update_lock:
            # 注册新引擎
            self.inference_engines[new_model_key] = new_engine

            # 逐步切换流量
            for ratio in [10, 30, 50, 70, 90, 100]:
                await self._update_traffic_ratio(
                    model_name, from_version, to_version, ratio
                )

                # 监控新版本性能
                if not await self._check_health(new_model_key):
                    # 回滚
                    await self._rollback(model_name, from_version, to_version)
                    return

                await asyncio.sleep(30)  # 观察期

            # 卸载旧模型
            if model_key in self.inference_engines:
                del self.inference_engines[model_key]

    async def _blue_green_update(
        self,
        model_name: str,
        from_version: str,
        to_version: str
    ):
        """蓝绿部署"""
        new_model_key = f"{model_name}:{to_version}"
        old_model_key = f"{model_name}:{from_version}"

        # 加载新模型(绿色环境)
        new_engine = InferenceEngine(
            model_path=await self.model_registry.get_model_path(
                model_name, to_version
            )
        )
        await new_engine.load_model()
        await self._warmup_model(new_engine)

        # 健康检查
        self.inference_engines[new_model_key] = new_engine

        if not await self._check_health(new_model_key):
            del self.inference_engines[new_model_key]
            raise RuntimeError(f"Health check failed for {new_model_key}")

        async with self.update_lock:
            # 瞬间切换流量
            await self._switch_traffic(model_name, from_version, to_version)

            # 保留旧版本一段时间用于回滚
            await asyncio.sleep(300)  # 5 分钟观察期

            if await self._check_health(new_model_key):
                # 清理旧版本
                if old_model_key in self.inference_engines:
                    del self.inference_engines[old_model_key]
            else:
                # 回滚
                await self._rollback(model_name, from_version, to_version)

    async def _warmup_model(self, engine: InferenceEngine):
        """预热模型"""
        # 发送预热请求
        dummy_inputs = self._generate_dummy_inputs(engine)

        for _ in range(100):
            await engine.infer_batch(
                InferenceBatch(
                    requests=[InferenceRequest(
                        request_id=str(uuid.uuid4()),
                        model_name="warmup",
                        model_version="warmup",
                        inputs=dummy_inputs
                    )],
                    model_name="warmup",
                    model_version="warmup"
                )
            )

    async def _check_health(self, model_key: str) -> bool:
        """健康检查"""
        engine = self.inference_engines.get(model_key)
        if not engine:
            return False

        try:
            # 发送测试请求
            dummy_inputs = self._generate_dummy_inputs(engine)
            result = await asyncio.wait_for(
                engine.infer_batch(
                    InferenceBatch(
                        requests=[InferenceRequest(
                            request_id=str(uuid.uuid4()),
                            model_name="health",
                            model_version="health",
                            inputs=dummy_inputs
                        )],
                        model_name="health",
                        model_version="health"
                    )
                ),
                timeout=5.0
            )
            return result is not None
        except Exception:
            return False
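The helpers `_update_traffic_ratio`, `_rollback`, and `_generate_dummy_inputs` are referenced above but omitted. A minimal sketch of the rollback path, assuming a hypothetical `WeightedRouter` that holds per-version traffic weights (all names here are illustrative, not part of the code above):

```python
import asyncio

class WeightedRouter:
    """Hypothetical traffic router: holds per-version weights in percent."""
    def __init__(self):
        self.weights = {}  # (model, version) -> percent of traffic

    def set_weight(self, model: str, version: str, percent: int):
        self.weights[(model, version)] = percent

async def rollback(router: WeightedRouter, model: str,
                   from_version: str, to_version: str):
    # Route all traffic back to the known-good version before unloading
    # the new one, so no request window is left without a backend.
    router.set_weight(model, from_version, 100)
    router.set_weight(model, to_version, 0)

router = WeightedRouter()
router.set_weight("ranker", "v1", 30)
router.set_weight("ranker", "v2", 70)
asyncio.run(rollback(router, "ranker", from_version="v1", to_version="v2"))
```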

Question 3: Design a Feature Platform

Problem Statement

Design a unified feature platform that needs to support:

  • Feature registration and metadata management
  • Offline feature computation and storage
  • Online feature serving
  • Online/offline feature consistency guarantees
  • Feature monitoring and data quality

Architecture Design

┌─────────────────────────────────────────────────────────────────────────┐
│                           特征平台架构                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                      特征定义层                                   │   │
│  │                                                                  │   │
│  │   Feature Registry │ Schema Management │ Feature Discovery       │   │
│  │                                                                  │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                      特征计算层                                   │   │
│  │                                                                  │   │
│  │   ┌───────────────────────┐   ┌───────────────────────┐         │   │
│  │   │    离线计算引擎        │   │    实时计算引擎        │         │   │
│  │   │                       │   │                       │         │   │
│  │   │  Spark │ Flink Batch  │   │  Flink │ Kafka Streams│         │   │
│  │   │                       │   │                       │         │   │
│  │   └───────────────────────┘   └───────────────────────┘         │   │
│  │                                                                  │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                      特征存储层                                   │   │
│  │                                                                  │   │
│  │   ┌───────────────────────┐   ┌───────────────────────┐         │   │
│  │   │    离线存储            │   │    在线存储            │         │   │
│  │   │                       │   │                       │         │   │
│  │   │  Hive │ Delta Lake    │   │  Redis │ Cassandra    │         │   │
│  │   │                       │   │                       │         │   │
│  │   └───────────────────────┘   └───────────────────────┘         │   │
│  │                                                                  │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                      特征服务层                                   │   │
│  │                                                                  │   │
│  │   Online Serving │ Batch Serving │ Feature Vector │ Point-in-Time│   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Core Component Design

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
import pandas as pd
from datetime import datetime, timedelta


class FeatureType(Enum):
    """特征类型"""
    INT64 = "int64"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    STRING = "string"
    BOOL = "bool"
    ARRAY = "array"
    EMBEDDING = "embedding"


class AggregationType(Enum):
    """聚合类型"""
    SUM = "sum"
    AVG = "avg"
    COUNT = "count"
    MAX = "max"
    MIN = "min"
    LAST = "last"
    FIRST = "first"


@dataclass
class Entity:
    """实体定义"""
    name: str
    join_keys: List[str]
    description: str = ""


@dataclass
class FeatureView:
    """特征视图"""
    name: str
    entities: List[Entity]
    features: List['Feature']
    source: 'DataSource'
    ttl: Optional[timedelta] = None
    online: bool = True
    offline: bool = True
    description: str = ""


@dataclass
class Feature:
    """特征定义"""
    name: str
    dtype: FeatureType
    description: str = ""
    tags: Dict[str, str] = field(default_factory=dict)

    # 统计信息
    mean: Optional[float] = None
    std: Optional[float] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None


@dataclass
class DataSource:
    """数据源"""
    name: str
    type: str  # batch, stream
    config: dict
    timestamp_field: Optional[str] = None
    created_timestamp_field: Optional[str] = None


class FeatureRegistry:
    """特征注册中心"""

    def __init__(self, storage: 'RegistryStorage'):
        self.storage = storage
        self.entities: Dict[str, Entity] = {}
        self.feature_views: Dict[str, FeatureView] = {}

    def register_entity(self, entity: Entity):
        """注册实体"""
        self.entities[entity.name] = entity
        self.storage.save_entity(entity)

    def register_feature_view(self, feature_view: FeatureView):
        """注册特征视图"""
        # 验证特征视图
        self._validate_feature_view(feature_view)

        self.feature_views[feature_view.name] = feature_view
        self.storage.save_feature_view(feature_view)

        # 创建物化任务
        if feature_view.online or feature_view.offline:
            self._create_materialization_job(feature_view)

    def get_feature_view(self, name: str) -> Optional[FeatureView]:
        """获取特征视图"""
        return self.feature_views.get(name)

    def list_features(
        self,
        entity: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> List[Feature]:
        """列出特征"""
        features = []

        for fv in self.feature_views.values():
            if entity and entity not in [e.name for e in fv.entities]:
                continue

            for feature in fv.features:
                if tags:
                    if not all(feature.tags.get(k) == v for k, v in tags.items()):
                        continue
                features.append(feature)

        return features

    def _validate_feature_view(self, feature_view: FeatureView):
        """验证特征视图"""
        # 检查实体是否存在
        for entity in feature_view.entities:
            if entity.name not in self.entities:
                raise ValueError(f"Entity {entity.name} not registered")

        # 检查特征名称唯一性
        feature_names = [f.name for f in feature_view.features]
        if len(feature_names) != len(set(feature_names)):
            raise ValueError("Duplicate feature names in feature view")
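`list_features` accepts a feature only if every requested tag key/value matches. The same matching logic on plain dicts (a trivial standalone sketch with made-up feature names):

```python
features = [
    {"name": "age", "entities": ["user"], "tags": {"team": "growth", "pii": "true"}},
    {"name": "ctr_7d", "entities": ["user"], "tags": {"team": "ranking"}},
    {"name": "price", "entities": ["item"], "tags": {"team": "ranking"}},
]

def list_features(features, entity=None, tags=None):
    """Mirror of FeatureRegistry.list_features on plain dicts: a feature
    matches only if it serves the entity and carries every requested tag."""
    out = []
    for f in features:
        if entity and entity not in f["entities"]:
            continue
        if tags and not all(f["tags"].get(k) == v for k, v in tags.items()):
            continue
        out.append(f["name"])
    return out

assert list_features(features, entity="user", tags={"team": "ranking"}) == ["ctr_7d"]
```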


class OnlineFeatureStore:
    """在线特征存储"""

    def __init__(self, redis_client):
        self.redis = redis_client
        self.cache_ttl = 3600  # 1 hour

    async def get_features(
        self,
        feature_view: str,
        entity_keys: Dict[str, Any],
        features: List[str]
    ) -> Dict[str, Any]:
        """获取在线特征"""
        # 构建 key
        key = self._build_key(feature_view, entity_keys)

        # 从 Redis 获取
        if features:
            result = await self.redis.hmget(key, features)
            return {f: self._deserialize(v) for f, v in zip(features, result)}
        else:
            result = await self.redis.hgetall(key)
            return {k.decode(): self._deserialize(v) for k, v in result.items()}

    async def write_features(
        self,
        feature_view: str,
        entity_keys: Dict[str, Any],
        features: Dict[str, Any],
        timestamp: Optional[datetime] = None
    ):
        """写入在线特征"""
        key = self._build_key(feature_view, entity_keys)

        # 序列化并写入
        serialized = {k: self._serialize(v) for k, v in features.items()}
        await self.redis.hset(key, mapping=serialized)

        # 设置 TTL
        await self.redis.expire(key, self.cache_ttl)

    async def batch_get_features(
        self,
        feature_view: str,
        entity_keys_list: List[Dict[str, Any]],
        features: List[str]
    ) -> List[Dict[str, Any]]:
        """批量获取特征"""
        pipe = self.redis.pipeline()

        for entity_keys in entity_keys_list:
            key = self._build_key(feature_view, entity_keys)
            if features:
                pipe.hmget(key, features)
            else:
                pipe.hgetall(key)

        results = await pipe.execute()

        feature_list = []
        for result in results:
            if isinstance(result, list):
                feature_list.append({
                    f: self._deserialize(v)
                    for f, v in zip(features, result)
                })
            else:
                feature_list.append({
                    k.decode(): self._deserialize(v)
                    for k, v in result.items()
                })

        return feature_list

    def _build_key(self, feature_view: str, entity_keys: Dict[str, Any]) -> str:
        """构建存储 key"""
        sorted_keys = sorted(entity_keys.items())
        key_parts = [f"{k}:{v}" for k, v in sorted_keys]
        return f"feature:{feature_view}:{':'.join(key_parts)}"

    def _serialize(self, value: Any) -> bytes:
        """序列化值"""
        import pickle
        return pickle.dumps(value)

    def _deserialize(self, value: bytes) -> Any:
        """反序列化值"""
        import pickle
        if value is None:
            return None
        return pickle.loads(value)
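The key layout sorts entity keys so lookups are independent of dict ordering. (Note also that `pickle`, used above for brevity, ties the store to Python readers; cross-language consumers usually serialize with JSON or protobuf instead.) A standalone mirror of `_build_key`:

```python
def build_key(feature_view: str, entity_keys: dict) -> str:
    """Mirror of OnlineFeatureStore._build_key: sorting the entity keys
    makes the Redis key deterministic regardless of insertion order."""
    parts = [f"{k}:{v}" for k, v in sorted(entity_keys.items())]
    return f"feature:{feature_view}:{':'.join(parts)}"

# Insertion order does not matter:
assert build_key("user_stats", {"user_id": 1, "city": "sh"}) == \
       build_key("user_stats", {"city": "sh", "user_id": 1})
```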


class FeatureServer:
    """特征服务"""

    def __init__(
        self,
        registry: FeatureRegistry,
        online_store: OnlineFeatureStore,
        offline_store: 'OfflineFeatureStore'
    ):
        self.registry = registry
        self.online_store = online_store
        self.offline_store = offline_store

    async def get_online_features(
        self,
        feature_refs: List[str],  # "feature_view:feature_name"
        entity_rows: List[Dict[str, Any]]
    ) -> pd.DataFrame:
        """获取在线特征"""
        # 解析特征引用
        parsed_refs = self._parse_feature_refs(feature_refs)

        results = []

        for entity_row in entity_rows:
            row_features = dict(entity_row)

            for feature_view, features in parsed_refs.items():
                fv = self.registry.get_feature_view(feature_view)

                # 提取实体键
                entity_keys = {
                    key: entity_row[key]
                    for entity in fv.entities
                    for key in entity.join_keys
                }

                # 获取特征
                feature_values = await self.online_store.get_features(
                    feature_view,
                    entity_keys,
                    features
                )

                for feature, value in feature_values.items():
                    row_features[f"{feature_view}__{feature}"] = value

            results.append(row_features)

        return pd.DataFrame(results)

    async def get_historical_features(
        self,
        feature_refs: List[str],
        entity_df: pd.DataFrame,
        timestamp_column: str = "event_timestamp"
    ) -> pd.DataFrame:
        """获取历史特征(Point-in-Time Join)"""
        parsed_refs = self._parse_feature_refs(feature_refs)

        result_df = entity_df.copy()

        for feature_view, features in parsed_refs.items():
            fv = self.registry.get_feature_view(feature_view)

            # 获取离线特征
            feature_df = await self.offline_store.get_features(
                feature_view,
                features
            )

            # Point-in-Time Join
            result_df = self._point_in_time_join(
                result_df,
                feature_df,
                entity_df,
                fv,
                timestamp_column
            )

        return result_df

    def _point_in_time_join(
        self,
        entity_df: pd.DataFrame,
        feature_df: pd.DataFrame,
        original_entity_df: pd.DataFrame,
        feature_view: FeatureView,
        timestamp_column: str
    ) -> pd.DataFrame:
        """Point-in-Time Join"""
        # 获取 join keys
        join_keys = []
        for entity in feature_view.entities:
            join_keys.extend(entity.join_keys)

        # 确保时间戳列存在
        entity_df = entity_df.sort_values(timestamp_column)
        feature_df = feature_df.sort_values("feature_timestamp")

        # 执行 as-of join
        result = pd.merge_asof(
            entity_df,
            feature_df,
            left_on=timestamp_column,
            right_on="feature_timestamp",
            by=join_keys,
            direction="backward"
        )

        # 应用 TTL 过滤
        if feature_view.ttl:
            ttl_seconds = feature_view.ttl.total_seconds()
            time_diff = (
                result[timestamp_column] - result["feature_timestamp"]
            ).dt.total_seconds()

            # 超过 TTL 的特征设为 null
            expired_mask = time_diff > ttl_seconds
            for feature in feature_view.features:
                result.loc[expired_mask, feature.name] = None

        return result

    def _parse_feature_refs(
        self,
        feature_refs: List[str]
    ) -> Dict[str, List[str]]:
        """解析特征引用"""
        parsed = {}

        for ref in feature_refs:
            if ":" in ref:
                fv_name, feature_name = ref.split(":", 1)
            else:
                raise ValueError(f"Invalid feature ref: {ref}")

            if fv_name not in parsed:
                parsed[fv_name] = []
            parsed[fv_name].append(feature_name)

        return parsed
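The as-of join above is easiest to see on a tiny example: each label row picks the latest feature value at or before its event time, never a future one, which is what prevents label leakage. A sketch with made-up data (both frames must be sorted on their time keys, as `pd.merge_asof` requires):

```python
import pandas as pd

entity_df = pd.DataFrame({
    "user_id": [1, 1],
    "event_timestamp": pd.to_datetime(["2024-01-02", "2024-01-05"]),
})
feature_df = pd.DataFrame({
    "user_id": [1, 1],
    "feature_timestamp": pd.to_datetime(["2024-01-01", "2024-01-04"]),
    "ctr_7d": [0.10, 0.30],
})

joined = pd.merge_asof(
    entity_df.sort_values("event_timestamp"),
    feature_df.sort_values("feature_timestamp"),
    left_on="event_timestamp",
    right_on="feature_timestamp",
    by="user_id",
    direction="backward",  # only look backward in time: no label leakage
)
# Each event picks the newest feature value known at that moment
assert joined["ctr_7d"].tolist() == [0.10, 0.30]
```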

Question 4: Design a Vector Database

Problem Statement

Design a high-performance vector database that needs to support:

  • Large-scale vector storage (billions of vectors)
  • Efficient approximate nearest neighbor (ANN) similarity search
  • Metadata filtering
  • Distributed deployment
  • Real-time updates

Architecture Design

┌─────────────────────────────────────────────────────────────────────────┐
│                         向量数据库架构                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                        客户端层                                   │   │
│  │   Python SDK │ REST API │ gRPC │ 管理控制台                       │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                        协调层                                    │   │
│  │                                                                  │   │
│  │   ┌──────────┐   ┌──────────┐   ┌──────────┐   ┌──────────┐    │   │
│  │   │ 查询路由  │   │ 负载均衡  │   │ 分片管理  │   │ 元数据   │    │   │
│  │   │ Router   │   │ Balancer │   │ Sharding │   │ Service  │    │   │
│  │   └──────────┘   └──────────┘   └──────────┘   └──────────┘    │   │
│  │                                                                  │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                        查询层                                    │   │
│  │                                                                  │   │
│  │   ┌───────────────────────────────────────────────────────────┐ │   │
│  │   │                    Query Node 集群                         │ │   │
│  │   │                                                           │ │   │
│  │   │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐         │ │   │
│  │   │  │ Query Node 1│ │ Query Node 2│ │ Query Node N│         │ │   │
│  │   │  │             │ │             │ │             │         │ │   │
│  │   │  │ ┌─────────┐ │ │ ┌─────────┐ │ │ ┌─────────┐ │         │ │   │
│  │   │  │ │Index Seg│ │ │ │Index Seg│ │ │ │Index Seg│ │         │ │   │
│  │   │  │ └─────────┘ │ │ └─────────┘ │ │ └─────────┘ │         │ │   │
│  │   │  │ ┌─────────┐ │ │ ┌─────────┐ │ │ ┌─────────┐ │         │ │   │
│  │   │  │ │Index Seg│ │ │ │Index Seg│ │ │ │Index Seg│ │         │ │   │
│  │   │  │ └─────────┘ │ │ └─────────┘ │ │ └─────────┘ │         │ │   │
│  │   │  └─────────────┘ └─────────────┘ └─────────────┘         │ │   │
│  │   │                                                           │ │   │
│  │   └───────────────────────────────────────────────────────────┘ │   │
│  │                                                                  │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                        存储层                                    │   │
│  │                                                                  │   │
│  │   ┌───────────────────────┐   ┌───────────────────────┐         │   │
│  │   │    索引存储            │   │    原始数据存储        │         │   │
│  │   │                       │   │                       │         │   │
│  │   │  本地 SSD │ 共享存储   │   │  对象存储 (S3/MinIO)  │         │   │
│  │   │                       │   │                       │         │   │
│  │   └───────────────────────┘   └───────────────────────┘         │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Core Component Design

import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from enum import Enum
import heapq
from abc import ABC, abstractmethod


class IndexType(Enum):
    """索引类型"""
    FLAT = "flat"
    IVF_FLAT = "ivf_flat"
    IVF_PQ = "ivf_pq"
    HNSW = "hnsw"
    DISKANN = "diskann"


class MetricType(Enum):
    """距离度量类型"""
    L2 = "l2"
    IP = "ip"  # Inner Product
    COSINE = "cosine"


@dataclass
class VectorRecord:
    """向量记录"""
    id: str
    vector: np.ndarray
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SearchResult:
    """搜索结果"""
    id: str
    score: float
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SearchRequest:
    """搜索请求"""
    collection: str
    query_vector: np.ndarray
    top_k: int = 10
    filter: Optional[Dict[str, Any]] = None
    search_params: Optional[Dict[str, Any]] = None


class VectorIndex(ABC):
    """向量索引基类"""

    @abstractmethod
    def build(self, vectors: np.ndarray, ids: List[str]):
        """构建索引"""
        pass

    @abstractmethod
    def search(
        self,
        query: np.ndarray,
        k: int,
        params: Optional[dict] = None
    ) -> Tuple[List[str], List[float]]:
        """搜索"""
        pass

    @abstractmethod
    def add(self, vectors: np.ndarray, ids: List[str]):
        """添加向量"""
        pass

    @abstractmethod
    def delete(self, ids: List[str]):
        """删除向量"""
        pass


class HNSWIndex(VectorIndex):
    """HNSW 索引实现"""

    def __init__(
        self,
        dim: int,
        metric: MetricType = MetricType.L2,
        M: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50
    ):
        self.dim = dim
        self.metric = metric
        self.M = M
        self.ef_construction = ef_construction
        self.ef_search = ef_search

        # 图结构
        self.graphs: List[Dict[str, List[str]]] = []  # 每层的邻接表
        self.vectors: Dict[str, np.ndarray] = {}
        self.entry_point: Optional[str] = None
        self.max_level = 0
        self.node_levels: Dict[str, int] = {}

        # 概率参数
        self.ml = 1.0 / np.log(M)

    def build(self, vectors: np.ndarray, ids: List[str]):
        """构建索引"""
        for vector, id in zip(vectors, ids):
            self.add(np.array([vector]), [id])

    def add(self, vectors: np.ndarray, ids: List[str]):
        """添加向量"""
        for vector, id in zip(vectors, ids):
            self._insert(id, vector)

    def _insert(self, id: str, vector: np.ndarray):
        """插入单个向量"""
        self.vectors[id] = vector

        # 随机确定层数
        level = self._random_level()
        self.node_levels[id] = level

        # Grow the layer list if this node exceeds the current height
        while len(self.graphs) <= level:
            self.graphs.append({})

        if self.entry_point is None:
            self.entry_point = id
            self.max_level = level
            for l in range(level + 1):
                self.graphs[l][id] = []
            return

        # Locate the insertion region, starting from the entry point
        curr = self.entry_point

        # Greedy descent from the top layer down to level + 1
        for l in range(self.max_level, level, -1):
            curr = self._search_layer(vector, curr, 1, l)[0]

        # Link the node into every layer from level down to 0
        for l in range(min(level, self.max_level), -1, -1):
            neighbors = self._search_layer(vector, curr, self.ef_construction, l)

            # Pick at most M neighbors
            selected = self._select_neighbors(vector, neighbors, self.M)

            # Create bidirectional links
            self.graphs[l][id] = selected

            for neighbor in selected:
                if neighbor in self.graphs[l]:
                    self.graphs[l][neighbor].append(id)

                    # Prune the neighbor's list back down to M
                    if len(self.graphs[l][neighbor]) > self.M:
                        neighbor_vec = self.vectors[neighbor]
                        self.graphs[l][neighbor] = self._select_neighbors(
                            neighbor_vec,
                            self.graphs[l][neighbor],
                            self.M
                        )

            if neighbors:
                curr = neighbors[0]

        # Promote this node to entry point if it reached a new top layer
        if level > self.max_level:
            self.entry_point = id
            self.max_level = level

    def search(
        self,
        query: np.ndarray,
        k: int,
        params: Optional[dict] = None
    ) -> Tuple[List[str], List[float]]:
        """搜索最近邻"""
        ef = params.get("ef", self.ef_search) if params else self.ef_search

        if self.entry_point is None:
            return [], []

        curr = self.entry_point

        # Greedy descent from the top layer down to layer 1
        for l in range(self.max_level, 0, -1):
            curr = self._search_layer(query, curr, 1, l)[0]

        # Beam search with width ef on layer 0
        candidates = self._search_layer(query, curr, ef, 0)

        # Keep the top-k
        results = []
        for candidate in candidates[:k]:
            score = self._distance(query, self.vectors[candidate])
            results.append((candidate, score))

        ids, scores = zip(*results) if results else ([], [])
        return list(ids), list(scores)

    def _search_layer(
        self,
        query: np.ndarray,
        entry: str,
        ef: int,
        level: int
    ) -> List[str]:
        """在单层搜索"""
        visited = {entry}
        candidates = []  # min heap
        results = []  # max heap (negative distance)

        entry_dist = self._distance(query, self.vectors[entry])
        heapq.heappush(candidates, (entry_dist, entry))
        heapq.heappush(results, (-entry_dist, entry))

        while candidates:
            curr_dist, curr = heapq.heappop(candidates)

            # Stop: even the closest pending candidate is farther than the worst kept result
            if -results[0][0] < curr_dist:
                break

            # Expand the current node's neighbors
            neighbors = self.graphs[level].get(curr, [])
            for neighbor in neighbors:
                if neighbor in visited:
                    continue
                visited.add(neighbor)

                neighbor_dist = self._distance(query, self.vectors[neighbor])

                if len(results) < ef or neighbor_dist < -results[0][0]:
                    heapq.heappush(candidates, (neighbor_dist, neighbor))
                    heapq.heappush(results, (-neighbor_dist, neighbor))

                    if len(results) > ef:
                        heapq.heappop(results)

        # Return ids sorted by ascending distance
        sorted_results = sorted([(-d, id) for d, id in results])
        return [id for _, id in sorted_results]

    def _select_neighbors(
        self,
        query: np.ndarray,
        candidates: List[str],
        M: int
    ) -> List[str]:
        """选择邻居(启发式)"""
        if len(candidates) <= M:
            return candidates

        # Sort candidates by distance to the query
        scored = [(self._distance(query, self.vectors[c]), c) for c in candidates]
        scored.sort()

        # Plain truncation (the diversity heuristic from the HNSW paper is an upgrade)
        return [c for _, c in scored[:M]]

    def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
        """Distance between two vectors (smaller = more similar)."""
        if self.metric == MetricType.L2:
            return np.sum((a - b) ** 2)
        elif self.metric == MetricType.IP:
            return -np.dot(a, b)  # negated so smaller is better
        elif self.metric == MetricType.COSINE:
            return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        raise ValueError(f"Unsupported metric: {self.metric}")

    def _random_level(self) -> int:
        """Draw a layer from the HNSW geometric distribution."""
        # level = floor(-ln(u) * ml) with ml = 1/ln(M), so P(level >= l) = M**(-l);
        # capped at 16 to bound the hierarchy height
        u = max(np.random.random(), 1e-12)  # guard against log(0)
        return min(int(-np.log(u) * self.ml), 16)

    def delete(self, ids: List[str]):
        """删除向量"""
        for id in ids:
            if id not in self.vectors:
                continue

            level = self.node_levels.get(id, 0)

            # Remove the node from every layer it appears in
            for l in range(level + 1):
                if id in self.graphs[l]:
                    # Drop edges pointing at this node
                    for neighbor in self.graphs[l][id]:
                        if neighbor in self.graphs[l]:
                            self.graphs[l][neighbor] = [
                                n for n in self.graphs[l][neighbor] if n != id
                            ]
                    del self.graphs[l][id]

            del self.vectors[id]
            del self.node_levels[id]

            # Re-select the entry point if we just deleted it
            if id == self.entry_point:
                self._update_entry_point()

    def _update_entry_point(self):
        """更新入口点"""
        if not self.vectors:
            self.entry_point = None
            self.max_level = 0
            return

        # Pick a node on the highest remaining layer
        max_level = 0
        entry = None

        for id, level in self.node_levels.items():
            if level > max_level:
                max_level = level
                entry = id

        self.entry_point = entry
        self.max_level = max_level


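The `ml = 1.0 / np.log(M)` factor set in `__init__` is HNSW's level-generation parameter: drawing `level = floor(-ln(u) * ml)` yields `P(level >= l) = M**(-l)`, so layer populations shrink by a factor of M per level and the hierarchy height stays logarithmic in the collection size. A standalone sketch of that distribution (seeded for reproducibility; the variable names are illustrative):

```python
import numpy as np

# Geometric layer distribution implied by ml = 1/ln(M):
# level = floor(-ln(u) * ml)  =>  P(level >= l) = M**(-l)
rng = np.random.default_rng(0)
M = 16
ml = 1.0 / np.log(M)

u = rng.random(100_000)
levels = np.floor(-np.log(u) * ml).astype(int)

frac_layer0 = float(np.mean(levels == 0))  # expect ~ 1 - 1/16 = 0.9375
frac_above0 = float(np.mean(levels >= 1))  # expect ~ 1/16   = 0.0625
```

With M = 16, only about one node in sixteen rises above layer 0, which is what keeps the upper layers sparse enough for the greedy descent to be fast.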
class VectorCollection:
    """向量集合"""

    def __init__(
        self,
        name: str,
        dim: int,
        index_type: IndexType = IndexType.HNSW,
        metric: MetricType = MetricType.L2,
        index_params: Optional[dict] = None
    ):
        self.name = name
        self.dim = dim
        self.index_type = index_type
        self.metric = metric

        # Create the ANN index
        self.index = self._create_index(index_params or {})

        # Metadata store (kept in memory in this sketch)
        self.metadata: Dict[str, Dict[str, Any]] = {}

    def _create_index(self, params: dict) -> VectorIndex:
        """创建索引"""
        if self.index_type == IndexType.HNSW:
            return HNSWIndex(
                dim=self.dim,
                metric=self.metric,
                M=params.get("M", 16),
                ef_construction=params.get("ef_construction", 200)
            )
        # Other index types ...
        raise ValueError(f"Unsupported index type: {self.index_type}")

    def insert(self, records: List[VectorRecord]):
        """插入记录"""
        vectors = np.array([r.vector for r in records])
        ids = [r.id for r in records]

        # Add to the ANN index
        self.index.add(vectors, ids)

        # Store metadata
        for record in records:
            self.metadata[record.id] = record.metadata

    def search(
        self,
        query_vector: np.ndarray,
        top_k: int = 10,
        filter: Optional[Dict[str, Any]] = None,
        search_params: Optional[dict] = None
    ) -> List[SearchResult]:
        """搜索"""
        # 搜索向量索引
        ids, scores = self.index.search(query_vector, top_k * 10, search_params)

        results = []
        for id, score in zip(ids, scores):
            metadata = self.metadata.get(id, {})

            # Post-filter on metadata
            if filter and not self._match_filter(metadata, filter):
                continue

            results.append(SearchResult(
                id=id,
                score=score,
                metadata=metadata
            ))

            if len(results) >= top_k:
                break

        return results

    def _match_filter(self, metadata: dict, filter: dict) -> bool:
        """匹配过滤条件"""
        for key, condition in filter.items():
            if key not in metadata:
                return False

            value = metadata[key]

            if isinstance(condition, dict):
                # Operator condition, e.g. {"$gte": 10}
                for op, op_value in condition.items():
                    if op == "$eq" and value != op_value:
                        return False
                    elif op == "$ne" and value == op_value:
                        return False
                    elif op == "$gt" and value <= op_value:
                        return False
                    elif op == "$gte" and value < op_value:
                        return False
                    elif op == "$lt" and value >= op_value:
                        return False
                    elif op == "$lte" and value > op_value:
                        return False
                    elif op == "$in" and value not in op_value:
                        return False
            else:
                # Bare value: plain equality
                if value != condition:
                    return False

        return True

    def delete(self, ids: List[str]):
        """删除记录"""
        self.index.delete(ids)
        for id in ids:
            self.metadata.pop(id, None)


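`_match_filter` implements a small Mongo-style operator grammar: a bare value means equality, while a nested dict maps operators to arguments with an implicit AND. A self-contained sketch of that grammar (the `match` helper and the sample documents below are hypothetical, written only to illustrate the filter shapes a caller would pass):

```python
# Mongo-style metadata filters: a bare value means equality, a dict maps
# operators ($eq/$ne/$gt/$gte/$lt/$lte/$in) to their arguments.
def match(metadata: dict, flt: dict) -> bool:
    ops = {
        "$eq":  lambda v, x: v == x,
        "$ne":  lambda v, x: v != x,
        "$gt":  lambda v, x: v > x,
        "$gte": lambda v, x: v >= x,
        "$lt":  lambda v, x: v < x,
        "$lte": lambda v, x: v <= x,
        "$in":  lambda v, x: v in x,
    }
    for key, cond in flt.items():
        if key not in metadata:
            return False
        value = metadata[key]
        if isinstance(cond, dict):
            # Every operator in the dict must hold (implicit AND)
            if not all(ops[op](value, arg) for op, arg in cond.items()):
                return False
        elif value != cond:
            return False
    return True

meta = {"category": "news", "year": 2023}
ok1 = match(meta, {"category": "news"})                     # equality
ok2 = match(meta, {"year": {"$gte": 2020, "$lt": 2024}})    # range, AND
ok3 = match(meta, {"category": {"$in": ["blog", "docs"]}})  # membership
```

Because filtering happens after the ANN search (post-filtering), a selective filter can starve the result set, which is why `search` over-fetches from the index before applying it.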
class DistributedVectorDB:
    """分布式向量数据库"""

    def __init__(
        self,
        num_shards: int = 8,
        replication_factor: int = 2
    ):
        self.num_shards = num_shards
        self.replication_factor = replication_factor
        self.collections: Dict[str, 'ShardedCollection'] = {}

    def create_collection(
        self,
        name: str,
        dim: int,
        index_type: IndexType = IndexType.HNSW,
        metric: MetricType = MetricType.L2,
        index_params: Optional[dict] = None
    ):
        """创建集合"""
        shards = []
        for i in range(self.num_shards):
            shard = VectorCollection(
                name=f"{name}_shard_{i}",
                dim=dim,
                index_type=index_type,
                metric=metric,
                index_params=index_params
            )
            shards.append(shard)

        self.collections[name] = ShardedCollection(
            name=name,
            shards=shards,
            num_shards=self.num_shards
        )

    def insert(self, collection: str, records: List[VectorRecord]):
        """插入记录"""
        coll = self.collections[collection]

        # Group records by target shard
        shard_records: Dict[int, List[VectorRecord]] = {}
        for record in records:
            shard_id = self._get_shard(record.id)
            if shard_id not in shard_records:
                shard_records[shard_id] = []
            shard_records[shard_id].append(record)

        # Insert into each shard (sequential here; a real system fans out in parallel)
        for shard_id, shard_recs in shard_records.items():
            coll.shards[shard_id].insert(shard_recs)

    def search(
        self,
        collection: str,
        query_vector: np.ndarray,
        top_k: int = 10,
        filter: Optional[Dict[str, Any]] = None,
        search_params: Optional[dict] = None
    ) -> List[SearchResult]:
        """分布式搜索"""
        coll = self.collections[collection]

        # Query every shard (sequential here; fan out in parallel in practice)
        all_results = []
        for shard in coll.shards:
            shard_results = shard.search(
                query_vector,
                top_k,
                filter,
                search_params
            )
            all_results.extend(shard_results)

        # Merge: scores are distances, so ascending sort ranks best first
        all_results.sort(key=lambda r: r.score)
        return all_results[:top_k]

    def _get_shard(self, id: str) -> int:
        """Map an id to a shard (stable across processes, unlike built-in hash())."""
        import hashlib  # hash() is salted per process and would reshuffle shards
        return int(hashlib.md5(id.encode()).hexdigest(), 16) % self.num_shards


@dataclass
class ShardedCollection:
    """分片集合"""
    name: str
    shards: List[VectorCollection]
    num_shards: int

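The scatter-gather in `DistributedVectorDB.search` relies on a merge property worth stating: if every shard returns its exact local top-k by distance, re-sorting the union of shard results recovers the exact global top-k (with an ANN index per shard the guarantee softens to "best of each shard's approximation"). A self-contained check with brute-force distances and synthetic data, where round-robin assignment stands in for `_get_shard`:

```python
import numpy as np

# Scatter-gather merge: per-shard exact top-k, merged and re-sorted,
# equals the global top-k.
rng = np.random.default_rng(42)
vectors = rng.standard_normal((1000, 8))
query = rng.standard_normal(8)
dists = np.sum((vectors - query) ** 2, axis=1)   # squared L2 to the query

k, num_shards = 10, 4
shard_of = np.arange(len(vectors)) % num_shards  # round-robin "sharding"

merged = []
for s in range(num_shards):
    members = np.where(shard_of == s)[0]
    local_topk = members[np.argsort(dists[members])[:k]]  # shard-local top-k
    merged.extend(local_topk.tolist())

merged = sorted(merged, key=lambda i: dists[i])[:k]       # gather + re-rank
global_topk = np.argsort(dists)[:k].tolist()
same = merged == global_topk
```

The merge is lossless because any vector in the global top-k is necessarily in the top-k of its own shard, so each shard only needs to return k candidates regardless of how many shards exist.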
Summary

Core system-design takeaways

| System | Core challenges | Key techniques |
| --- | --- | --- |
| Training platform | Large-scale resource scheduling, failure recovery | Elastic training, checkpointing, preemptive scheduling |
| Inference platform | Low latency, high throughput | Dynamic batching, model optimization, traffic management |
| Feature platform | Consistency, freshness | Point-in-time joins, feature materialization |
| Vector database | Efficient ANN search, distribution | HNSW, sharding, metadata filtering |

Design methodology

  1. Requirements first: pin down functional and non-functional requirements
  2. Layered design: keep responsibilities clearly separated
  3. Trade-offs: understand what each option gives up
  4. Evolution mindset: design for growth and scalability
  5. Details matter: get the key algorithms and data structures right