03-系统设计面试题
概述
本文聚焦 AI 基础设施相关的系统设计面试题,涵盖分布式训练平台、模型服务平台、特征平台、向量数据库等核心系统的设计与实现。
系统设计方法论
设计流程框架
┌─────────────────────────────────────────────────────────────────────────┐
│ 系统设计四步法 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Step 1: 需求澄清 (5分钟) │
│ ├── 功能需求:核心功能、用户场景 │
│ ├── 非功能需求:性能、可用性、一致性 │
│ ├── 规模估算:QPS、数据量、用户数 │
│ └── 约束条件:技术栈、成本、时间 │
│ │
│ Step 2: 高层设计 (10分钟) │
│ ├── 核心组件:识别主要模块 │
│ ├── 数据流:请求如何流转 │
│ ├── 接口设计:API 定义 │
│ └── 技术选型:存储、计算、通信 │
│ │
│ Step 3: 详细设计 (15分钟) │
│ ├── 核心模块深入 │
│ ├── 数据模型设计 │
│ ├── 关键算法 │
│ └── 扩展性设计 │
│ │
│ Step 4: 优化讨论 (10分钟) │
│ ├── 瓶颈分析 │
│ ├── 容错设计 │
│ ├── 监控告警 │
│ └── 演进路径 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
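以 Step 1 中的"规模估算"为例,面试时通常只做数量级估算,用来校准后续架构选型(以下数字仅为示意假设):
- 流量:假设峰值 10000 QPS、单请求平均占用 50ms 计算时间,则稳态并发约为 10000 × 0.05 = 500;若单实例通过 batching 可承载约 50 路并发,则约需 10 个实例,叠加 2 倍冗余约 20 个实例。
- 存储:假设每天新增 1TB 数据、保留 90 天,则热数据规模约 90TB,以此判断是否需要分层存储与数据分片。
估算不追求精确,目的是确认单机还是分布式、要不要缓存与分片等决策处于正确的数量级。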
题目一:设计分布式模型训练平台
题目描述
设计一个支持大规模分布式训练的平台,需要支持:
- 多种并行策略(数据并行、模型并行、流水线并行)
- 弹性训练(节点故障恢复、动态扩缩容)
- 多租户资源隔离
- 训练任务生命周期管理
需求分析
功能需求:
- 提交训练任务(代码、数据、配置)
- 资源调度(GPU、内存、存储)
- 分布式训练协调
- 训练监控与日志
- Checkpoint 管理
- 任务生命周期管理
非功能需求:
- 支持 1000+ GPU 同时训练
- 任务启动延迟 < 5 分钟
- 单点故障自动恢复 < 10 分钟
- 资源利用率 > 80%
高层架构设计
┌─────────────────────────────────────────────────────────────────────────┐
│ 分布式训练平台架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 用户接口层 │ │
│ │ Web Console │ CLI Tool │ SDK │ REST API │ Notebook │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 任务管理层 │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 任务提交 │ │ 任务调度 │ │ 生命周期 │ │ 队列管理 │ │ │
│ │ │ Service │ │ Service │ │ Manager │ │ Service │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 资源管理层 │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 资源池 │ │ 配额管理 │ │ 资源调度 │ │ 弹性伸缩 │ │ │
│ │ │ Manager │ │ Service │ │ Service │ │ Service │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 训练运行层 │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 训练控制 │ │ 通信协调 │ │ Checkpoint│ │ 故障恢复 │ │ │
│ │ │ Controller│ │ Service │ │ Manager │ │ Service │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 基础设施层 │ │
│ │ │ │
│ │ Kubernetes │ GPU Operator │ 分布式存储 │ 高速网络(RDMA) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
核心组件设计
1. 任务提交与调度
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
import asyncio
from datetime import datetime
import uuid
class ParallelStrategy(Enum):
"""并行策略"""
DATA_PARALLEL = "data_parallel"
MODEL_PARALLEL = "model_parallel"
PIPELINE_PARALLEL = "pipeline_parallel"
HYBRID = "hybrid"
class JobStatus(Enum):
"""任务状态"""
PENDING = "pending"
QUEUED = "queued"
SCHEDULING = "scheduling"
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class ResourceRequirements:
"""资源需求"""
gpu_count: int
gpu_type: str = "A100"
gpu_memory_gb: int = 80
cpu_count: int = 8
memory_gb: int = 64
storage_gb: int = 100
network_bandwidth_gbps: int = 100
@dataclass
class TrainingConfig:
"""训练配置"""
framework: str # pytorch, tensorflow
parallel_strategy: ParallelStrategy
world_size: int
batch_size: int
gradient_accumulation_steps: int = 1
mixed_precision: bool = True
checkpoint_interval: int = 1000
max_steps: Optional[int] = None
max_epochs: Optional[int] = None
@dataclass
class TrainingJob:
"""训练任务"""
job_id: str
name: str
user_id: str
namespace: str
# 代码和数据
code_path: str
data_path: str
output_path: str
# 配置
config: TrainingConfig
resources: ResourceRequirements
# 状态
status: JobStatus = JobStatus.PENDING
created_at: datetime = field(default_factory=datetime.now)
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
# 运行时信息
worker_pods: List[str] = field(default_factory=list)
current_step: int = 0
current_epoch: int = 0
# 元数据
priority: int = 0
preemptible: bool = True
tags: Dict[str, str] = field(default_factory=dict)
class QuotaExceededException(Exception):
    """任务所在命名空间配额超限"""
    pass
class JobScheduler:
    """任务调度器"""
def __init__(self, resource_manager: 'ResourceManager'):
self.resource_manager = resource_manager
self.pending_queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
self.running_jobs: Dict[str, TrainingJob] = {}
async def submit_job(self, job: TrainingJob) -> str:
"""提交任务"""
# 验证任务配置
self._validate_job(job)
# 检查配额
if not await self.resource_manager.check_quota(
job.namespace, job.resources
):
raise QuotaExceededException(f"Quota exceeded for {job.namespace}")
        # 加入队列
        # 优先级键:(-priority, 提交时间, job_id),PriorityQueue 取最小值,
        # 因此高优先级、先提交的任务先出队;job_id 作为兜底,避免键相等时比较 TrainingJob 对象报错
        priority_key = (-job.priority, job.created_at.timestamp(), job.job_id)
        await self.pending_queue.put((priority_key, job))
job.status = JobStatus.QUEUED
return job.job_id
async def schedule_loop(self):
"""调度循环"""
while True:
try:
# 获取待调度任务
priority_key, job = await asyncio.wait_for(
self.pending_queue.get(),
timeout=1.0
)
# 尝试调度
job.status = JobStatus.SCHEDULING
scheduled = await self._try_schedule(job)
if not scheduled:
# 放回队列
await self.pending_queue.put((priority_key, job))
job.status = JobStatus.QUEUED
await asyncio.sleep(5)
except asyncio.TimeoutError:
continue
except Exception as e:
print(f"Schedule error: {e}")
async def _try_schedule(self, job: TrainingJob) -> bool:
"""尝试调度任务"""
# 获取可用资源
available = await self.resource_manager.get_available_resources(
job.resources.gpu_type
)
required_gpus = job.config.world_size
if available.gpu_count < required_gpus:
# 检查是否可以抢占
if job.priority > 0:
preempted = await self._try_preempt(job, required_gpus)
if not preempted:
return False
else:
return False
# 分配资源
allocation = await self.resource_manager.allocate(
job.job_id,
job.resources,
job.config.world_size
)
if not allocation:
return False
# 启动训练
await self._start_training(job, allocation)
job.status = JobStatus.RUNNING
job.started_at = datetime.now()
self.running_jobs[job.job_id] = job
return True
async def _try_preempt(self, job: TrainingJob, required_gpus: int) -> bool:
"""尝试抢占低优先级任务"""
preemptible_jobs = [
j for j in self.running_jobs.values()
if j.preemptible and j.priority < job.priority
]
# 按优先级排序,优先抢占低优先级
preemptible_jobs.sort(key=lambda j: j.priority)
freed_gpus = 0
jobs_to_preempt = []
for j in preemptible_jobs:
jobs_to_preempt.append(j)
freed_gpus += j.config.world_size
if freed_gpus >= required_gpus:
break
if freed_gpus < required_gpus:
return False
# 执行抢占
for j in jobs_to_preempt:
await self._preempt_job(j)
return True
async def _preempt_job(self, job: TrainingJob):
"""抢占任务"""
# 保存 checkpoint
await self._save_checkpoint(job)
# 释放资源
await self.resource_manager.release(job.job_id)
# 重新入队
del self.running_jobs[job.job_id]
job.status = JobStatus.QUEUED
        priority_key = (-job.priority, job.created_at.timestamp(), job.job_id)
        await self.pending_queue.put((priority_key, job))
async def _start_training(
self,
job: TrainingJob,
allocation: 'ResourceAllocation'
):
"""启动训练"""
# 生成训练配置
training_spec = self._generate_training_spec(job, allocation)
# 创建 K8s 资源
await self._create_k8s_resources(training_spec)
def _validate_job(self, job: TrainingJob):
"""验证任务配置"""
if job.config.world_size <= 0:
raise ValueError("world_size must be positive")
if job.resources.gpu_count <= 0:
raise ValueError("gpu_count must be positive")
2. 弹性训练支持
import json
from datetime import datetime
from typing import Dict, List
class ElasticTrainingController:
"""弹性训练控制器"""
def __init__(self, etcd_client, k8s_client):
self.etcd = etcd_client
self.k8s = k8s_client
self.rendezvous_handlers: Dict[str, RendezvousHandler] = {}
async def register_job(self, job: TrainingJob):
"""注册弹性训练任务"""
handler = RendezvousHandler(
job_id=job.job_id,
min_workers=max(1, job.config.world_size // 2),
max_workers=job.config.world_size * 2,
etcd=self.etcd
)
self.rendezvous_handlers[job.job_id] = handler
async def handle_worker_failure(self, job_id: str, worker_id: str):
"""处理 Worker 故障"""
handler = self.rendezvous_handlers.get(job_id)
if not handler:
return
# 标记 Worker 失败
await handler.mark_worker_failed(worker_id)
# 检查是否需要重新 rendezvous
active_workers = await handler.get_active_workers()
if len(active_workers) < handler.min_workers:
# 等待新 Worker 加入或启动新 Worker
await self._scale_up_workers(job_id)
else:
# 触发重新 rendezvous
await handler.trigger_rendezvous()
async def handle_worker_join(self, job_id: str, worker_id: str):
"""处理 Worker 加入"""
handler = self.rendezvous_handlers.get(job_id)
if not handler:
return
await handler.register_worker(worker_id)
# 检查是否达到启动条件
active_workers = await handler.get_active_workers()
if len(active_workers) >= handler.min_workers:
await handler.trigger_rendezvous()
async def _scale_up_workers(self, job_id: str):
"""扩容 Worker"""
job = await self._get_job(job_id)
handler = self.rendezvous_handlers[job_id]
current_workers = len(await handler.get_active_workers())
target_workers = min(
handler.max_workers,
current_workers + 1
)
# 创建新的 Worker Pod
await self._create_worker_pod(job, target_workers - current_workers)
class RendezvousHandler:
"""Rendezvous 处理器"""
def __init__(
self,
job_id: str,
min_workers: int,
max_workers: int,
etcd
):
self.job_id = job_id
self.min_workers = min_workers
self.max_workers = max_workers
self.etcd = etcd
self.rendezvous_version = 0
async def register_worker(self, worker_id: str):
"""注册 Worker"""
key = f"/rendezvous/{self.job_id}/workers/{worker_id}"
value = {
"worker_id": worker_id,
"status": "active",
"registered_at": datetime.now().isoformat()
}
await self.etcd.put(key, json.dumps(value))
async def mark_worker_failed(self, worker_id: str):
"""标记 Worker 失败"""
key = f"/rendezvous/{self.job_id}/workers/{worker_id}"
await self.etcd.delete(key)
async def get_active_workers(self) -> List[str]:
"""获取活跃 Worker 列表"""
prefix = f"/rendezvous/{self.job_id}/workers/"
result = await self.etcd.get_prefix(prefix)
workers = []
for key, value in result:
worker_info = json.loads(value)
if worker_info.get("status") == "active":
workers.append(worker_info["worker_id"])
return workers
async def trigger_rendezvous(self):
"""触发 Rendezvous"""
self.rendezvous_version += 1
workers = await self.get_active_workers()
# 更新 rendezvous 信息
rendezvous_info = {
"version": self.rendezvous_version,
"workers": workers,
"world_size": len(workers),
"timestamp": datetime.now().isoformat()
}
key = f"/rendezvous/{self.job_id}/current"
await self.etcd.put(key, json.dumps(rendezvous_info))
# 通知所有 Worker
await self._notify_workers(workers, rendezvous_info)
async def _notify_workers(
self,
workers: List[str],
rendezvous_info: dict
):
"""通知 Worker 重新配置"""
for i, worker_id in enumerate(workers):
worker_config = {
**rendezvous_info,
"rank": i,
"local_rank": 0, # 简化处理
}
key = f"/rendezvous/{self.job_id}/workers/{worker_id}/config"
await self.etcd.put(key, json.dumps(worker_config))
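控制面之外,Worker 侧还需要监听 rendezvous 配置的版本变化,在 world_size 变化时重建通信组并从 checkpoint 恢复。下面是一个简化示意(假设训练框架为 PyTorch,etcd 客户端提供 get 接口,run_training_until_next_rendezvous 为假设的训练入口函数):
import asyncio
import json
import torch.distributed as dist
async def elastic_worker_loop(etcd, job_id: str, worker_id: str, ckpt_mgr):
    """Worker 侧弹性训练主循环(简化示意)"""
    last_version = -1
    while True:
        # 读取由 RendezvousHandler 写入的本 Worker 配置
        raw = await etcd.get(f"/rendezvous/{job_id}/workers/{worker_id}/config")
        if raw is None:
            await asyncio.sleep(1)
            continue
        config = json.loads(raw)
        if config["version"] == last_version:
            await asyncio.sleep(1)
            continue
        last_version = config["version"]
        # 拓扑变化:销毁旧通信组,按新的 rank / world_size 重建
        if dist.is_initialized():
            dist.destroy_process_group()
        dist.init_process_group(
            backend="nccl",            # MASTER_ADDR / MASTER_PORT 等假定由平台注入
            rank=config["rank"],
            world_size=config["world_size"],
        )
        # 从最近 checkpoint 恢复后继续训练,直到下一次 rendezvous
        state = await ckpt_mgr.load_checkpoint(job_id)
        await run_training_until_next_rendezvous(state, config)  # 假设的训练入口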
3. Checkpoint 管理
import os
import json
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
class CheckpointManager:
"""Checkpoint 管理器"""
def __init__(
self,
storage_backend: 'StorageBackend',
cache_dir: str = "/tmp/ckpt_cache"
):
self.storage = storage_backend
self.cache_dir = cache_dir
self.executor = ThreadPoolExecutor(max_workers=4)
async def save_checkpoint(
self,
job_id: str,
step: int,
state_dict: dict,
metadata: Optional[dict] = None
) -> str:
"""保存 Checkpoint"""
# 生成 checkpoint 路径
ckpt_path = f"{job_id}/step_{step}"
# 异步保存
await asyncio.get_event_loop().run_in_executor(
self.executor,
self._save_to_storage,
ckpt_path,
state_dict,
metadata
)
# 记录 checkpoint 信息
await self._record_checkpoint(job_id, step, ckpt_path)
# 清理旧 checkpoint
await self._cleanup_old_checkpoints(job_id)
return ckpt_path
def _save_to_storage(
self,
path: str,
state_dict: dict,
metadata: Optional[dict]
):
"""保存到存储"""
import torch
# 保存模型状态
model_path = f"{path}/model.pt"
torch.save(state_dict.get("model"), model_path)
# 保存优化器状态
optimizer_path = f"{path}/optimizer.pt"
torch.save(state_dict.get("optimizer"), optimizer_path)
# 保存调度器状态
if "scheduler" in state_dict:
scheduler_path = f"{path}/scheduler.pt"
torch.save(state_dict["scheduler"], scheduler_path)
# 保存元数据
if metadata:
meta_path = f"{path}/metadata.json"
with open(meta_path, "w") as f:
json.dump(metadata, f)
# 上传到分布式存储
self.storage.upload_directory(path)
async def load_checkpoint(
self,
job_id: str,
step: Optional[int] = None
) -> Optional[dict]:
"""加载 Checkpoint"""
# 获取最新的 checkpoint
if step is None:
step = await self._get_latest_checkpoint_step(job_id)
if step is None:
return None
ckpt_path = f"{job_id}/step_{step}"
# 下载到本地缓存
local_path = os.path.join(self.cache_dir, ckpt_path)
await asyncio.get_event_loop().run_in_executor(
self.executor,
self.storage.download_directory,
ckpt_path,
local_path
)
# 加载状态
return await asyncio.get_event_loop().run_in_executor(
self.executor,
self._load_from_local,
local_path
)
def _load_from_local(self, local_path: str) -> dict:
"""从本地加载"""
import torch
state_dict = {}
model_path = os.path.join(local_path, "model.pt")
if os.path.exists(model_path):
state_dict["model"] = torch.load(model_path)
optimizer_path = os.path.join(local_path, "optimizer.pt")
if os.path.exists(optimizer_path):
state_dict["optimizer"] = torch.load(optimizer_path)
scheduler_path = os.path.join(local_path, "scheduler.pt")
if os.path.exists(scheduler_path):
state_dict["scheduler"] = torch.load(scheduler_path)
meta_path = os.path.join(local_path, "metadata.json")
if os.path.exists(meta_path):
with open(meta_path, "r") as f:
state_dict["metadata"] = json.load(f)
return state_dict
async def _cleanup_old_checkpoints(
self,
job_id: str,
keep_last: int = 3
):
"""清理旧 Checkpoint"""
checkpoints = await self._list_checkpoints(job_id)
if len(checkpoints) <= keep_last:
return
# 删除旧的
to_delete = checkpoints[:-keep_last]
for ckpt in to_delete:
await asyncio.get_event_loop().run_in_executor(
self.executor,
self.storage.delete_directory,
ckpt["path"]
)
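训练侧按固定步数间隔调用 CheckpointManager,启动时先尝试恢复,示意如下(train_one_step 为假设的单步训练函数):
async def training_loop(ckpt_mgr: CheckpointManager, job: TrainingJob, model, optimizer):
    """带断点续训的训练循环(示意)"""
    # 启动时尝试从最近的 checkpoint 恢复
    state = await ckpt_mgr.load_checkpoint(job.job_id)
    start_step = 0
    if state:
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        start_step = state.get("metadata", {}).get("step", 0)
    for step in range(start_step, job.config.max_steps or 10000):
        loss = train_one_step(model, optimizer)  # 假设的单步训练函数
        if step > 0 and step % job.config.checkpoint_interval == 0:
            await ckpt_mgr.save_checkpoint(
                job_id=job.job_id,
                step=step,
                state_dict={
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                metadata={"step": step, "loss": float(loss)},
            )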
题目二:设计模型推理服务平台
题目描述
设计一个高性能、高可用的模型推理服务平台,需要支持:
- 多模型部署与版本管理
- 动态 Batching
- 自动扩缩容
- A/B 测试与灰度发布
- 模型热更新
需求分析
功能需求:
- 模型部署与下线
- 多版本管理
- 流量分发
- 推理请求处理
- 监控与日志
非功能需求:
- P99 延迟 < 100ms
- 可用性 > 99.99%
- 支持 10000+ QPS(容量估算见下方示例)
- 模型切换无感知
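针对上述目标可以先做粗略容量估算(数字为示意假设):若单实例通过动态 Batching 以 batch=32、单批推理约 40ms 计,单实例吞吐约 32 / 0.04 = 800 QPS;支撑 10000 QPS 约需 13 个实例,考虑峰值余量与滚动发布期间的冗余,可部署 16~20 个实例,并以 GPU 利用率(目标约 60%~70%)作为自动扩缩容的指标。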
架构设计
┌─────────────────────────────────────────────────────────────────────────┐
│ 模型推理服务平台 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 接入层 │ │
│ │ 负载均衡 (L4/L7) │ API Gateway │ 认证鉴权 │ 限流熔断 │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 路由层 │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 服务发现 │ │ 流量路由 │ │ A/B测试 │ │ 灰度发布 │ │ │
│ │ │ │ │ │ │ Controller│ │ Controller│ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 推理层 │ │
│ │ │ │
│ │ ┌──────────────────────────────────────────────────────────┐ │ │
│ │ │ 推理实例集群 │ │ │
│ │ │ │ │ │
│ │ │ ┌─────────────────────┐ ┌─────────────────────┐ │ │ │
│ │ │ │ 模型 A v1.0 │ │ 模型 A v1.1 │ │ │ │
│ │ │ │ ┌───┐ ┌───┐ ┌───┐ │ │ ┌───┐ ┌───┐ │ │ │ │
│ │ │ │ │Pod│ │Pod│ │Pod│ │ │ │Pod│ │Pod│ │ │ │ │
│ │ │ │ └───┘ └───┘ └───┘ │ │ └───┘ └───┘ │ │ │ │
│ │ │ └─────────────────────┘ └─────────────────────┘ │ │ │
│ │ │ │ │ │
│ │ │ ┌─────────────────────┐ ┌─────────────────────┐ │ │ │
│ │ │ │ 模型 B v2.0 │ │ 模型 C v1.0 │ │ │ │
│ │ │ │ ┌───┐ ┌───┐ │ │ ┌───┐ │ │ │ │
│ │ │ │ │Pod│ │Pod│ │ │ │Pod│ │ │ │ │
│ │ │ │ └───┘ └───┘ │ │ └───┘ │ │ │ │
│ │ │ └─────────────────────┘ └─────────────────────┘ │ │ │
│ │ │ │ │ │
│ │ └──────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 管控层 │ │
│ │ │ │
│ │ 模型仓库 │ 部署管理 │ 版本管理 │ 自动扩缩 │ 监控告警 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
核心组件设计
1. 动态 Batching 服务
import asyncio
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
@dataclass
class InferenceRequest:
"""推理请求"""
request_id: str
model_name: str
model_version: str
inputs: Dict[str, np.ndarray]
created_at: float = field(default_factory=time.time)
future: asyncio.Future = field(default_factory=asyncio.Future)
@dataclass
class InferenceBatch:
"""推理批次"""
requests: List[InferenceRequest]
model_name: str
model_version: str
created_at: float = field(default_factory=time.time)
class DynamicBatcher:
"""动态 Batching 服务"""
def __init__(
self,
max_batch_size: int = 32,
max_wait_time_ms: float = 10.0,
preferred_batch_sizes: List[int] = None
):
self.max_batch_size = max_batch_size
self.max_wait_time = max_wait_time_ms / 1000.0
self.preferred_batch_sizes = preferred_batch_sizes or [1, 2, 4, 8, 16, 32]
# 请求队列:model_key -> queue
self.request_queues: Dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
# 推理引擎
self.inference_engines: Dict[str, 'InferenceEngine'] = {}
async def submit_request(self, request: InferenceRequest) -> dict:
"""提交推理请求"""
model_key = f"{request.model_name}:{request.model_version}"
# 加入队列
await self.request_queues[model_key].put(request)
# 等待结果
return await request.future
async def start_batching_loop(self, model_key: str):
"""启动 Batching 循环"""
queue = self.request_queues[model_key]
engine = self.inference_engines[model_key]
while True:
batch = await self._collect_batch(queue)
if batch:
# 执行推理
try:
results = await engine.infer_batch(batch)
# 分发结果
for request, result in zip(batch.requests, results):
request.future.set_result(result)
except Exception as e:
# 设置异常
for request in batch.requests:
request.future.set_exception(e)
async def _collect_batch(self, queue: asyncio.Queue) -> Optional[InferenceBatch]:
"""收集批次"""
requests = []
# 等待第一个请求
try:
first_request = await asyncio.wait_for(
queue.get(),
timeout=1.0
)
requests.append(first_request)
except asyncio.TimeoutError:
return None
batch_start_time = time.time()
# 收集更多请求
while len(requests) < self.max_batch_size:
remaining_time = self.max_wait_time - (time.time() - batch_start_time)
if remaining_time <= 0:
break
# 检查是否达到首选批次大小
if len(requests) in self.preferred_batch_sizes:
# 快速检查是否有更多请求
try:
request = await asyncio.wait_for(
queue.get(),
timeout=0.001 # 1ms 快速检查
)
requests.append(request)
except asyncio.TimeoutError:
break
else:
try:
request = await asyncio.wait_for(
queue.get(),
timeout=remaining_time
)
requests.append(request)
except asyncio.TimeoutError:
break
return InferenceBatch(
requests=requests,
model_name=requests[0].model_name,
model_version=requests[0].model_version
)
class InferenceEngine:
"""推理引擎"""
def __init__(
self,
model_path: str,
device: str = "cuda:0",
optimization_level: int = 3
):
self.model_path = model_path
self.device = device
self.model = None
self.optimization_level = optimization_level
async def load_model(self):
"""加载模型"""
import torch
# 加载模型
self.model = torch.jit.load(self.model_path)
self.model.to(self.device)
self.model.eval()
# 优化模型
if self.optimization_level >= 2:
self.model = torch.jit.optimize_for_inference(self.model)
async def infer_batch(self, batch: InferenceBatch) -> List[dict]:
"""批量推理"""
import torch
# 合并输入
batched_inputs = self._batch_inputs(batch.requests)
# 推理
with torch.no_grad():
with torch.cuda.amp.autocast():
outputs = self.model(**batched_inputs)
# 拆分输出
return self._unbatch_outputs(outputs, len(batch.requests))
def _batch_inputs(self, requests: List[InferenceRequest]) -> dict:
"""合并输入"""
import torch
batched = {}
for key in requests[0].inputs.keys():
arrays = [r.inputs[key] for r in requests]
batched[key] = torch.from_numpy(
np.stack(arrays)
).to(self.device)
return batched
def _unbatch_outputs(self, outputs, batch_size: int) -> List[dict]:
"""拆分输出"""
results = []
for i in range(batch_size):
result = {}
for key, tensor in outputs.items():
result[key] = tensor[i].cpu().numpy()
results.append(result)
return results
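下面是动态 Batching 服务的最小接入示例(模型路径、输入张量形状均为示意假设):
async def serve_example() -> dict:
    """为单个模型启动 batching 循环并提交一个请求(示意)"""
    batcher = DynamicBatcher(max_batch_size=32, max_wait_time_ms=10.0)
    engine = InferenceEngine(model_path="/models/resnet50.pt")  # 示意路径
    await engine.load_model()
    model_key = "resnet50:v1"
    batcher.inference_engines[model_key] = engine
    # batching 循环作为后台任务常驻运行
    asyncio.create_task(batcher.start_batching_loop(model_key))
    request = InferenceRequest(
        request_id="req-001",
        model_name="resnet50",
        model_version="v1",
        inputs={"input": np.random.rand(3, 224, 224).astype(np.float32)},
    )
    # submit_request 会等待该请求所在批次完成并返回对应结果
    return await batcher.submit_request(request)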
2. 流量路由与 A/B 测试
import random
import hashlib
import uuid
from typing import Any
@dataclass
class TrafficRule:
"""流量规则"""
rule_id: str
model_name: str
conditions: List['Condition']
actions: List['Action']
priority: int = 0
@dataclass
class Condition:
"""条件"""
field: str # header.user_id, query.region, etc.
operator: str # eq, ne, in, contains, regex
value: Any
@dataclass
class Action:
"""动作"""
type: str # route, ab_test, canary
config: dict
class TrafficRouter:
"""流量路由器"""
def __init__(self):
self.rules: Dict[str, List[TrafficRule]] = defaultdict(list)
self.model_versions: Dict[str, List['ModelVersion']] = {}
def add_rule(self, rule: TrafficRule):
"""添加路由规则"""
self.rules[rule.model_name].append(rule)
# 按优先级排序
self.rules[rule.model_name].sort(key=lambda r: -r.priority)
def route(self, request: 'InferenceRequest', context: dict) -> str:
"""路由请求到目标版本"""
rules = self.rules.get(request.model_name, [])
for rule in rules:
if self._match_conditions(rule.conditions, context):
return self._apply_actions(rule.actions, request, context)
# 默认路由到稳定版本
return self._get_stable_version(request.model_name)
def _match_conditions(
self,
conditions: List[Condition],
context: dict
) -> bool:
"""匹配条件"""
for condition in conditions:
value = self._get_field_value(condition.field, context)
if condition.operator == "eq":
if value != condition.value:
return False
elif condition.operator == "ne":
if value == condition.value:
return False
elif condition.operator == "in":
if value not in condition.value:
return False
elif condition.operator == "contains":
if condition.value not in str(value):
return False
return True
def _apply_actions(
self,
actions: List[Action],
request: 'InferenceRequest',
context: dict
) -> str:
"""应用动作"""
for action in actions:
if action.type == "route":
return action.config["version"]
elif action.type == "ab_test":
return self._ab_test_route(action.config, context)
elif action.type == "canary":
return self._canary_route(action.config, context)
return self._get_stable_version(request.model_name)
def _ab_test_route(self, config: dict, context: dict) -> str:
"""A/B 测试路由"""
# 基于用户 ID 的一致性哈希
user_id = context.get("user_id", str(uuid.uuid4()))
hash_value = int(hashlib.md5(user_id.encode()).hexdigest(), 16)
# 计算分桶
bucket = hash_value % 100
cumulative = 0
for variant in config["variants"]:
cumulative += variant["weight"]
if bucket < cumulative:
return variant["version"]
return config["variants"][0]["version"]
def _canary_route(self, config: dict, context: dict) -> str:
"""金丝雀发布路由"""
canary_weight = config.get("weight", 10) # 默认 10%
if random.randint(1, 100) <= canary_weight:
return config["canary_version"]
return config["stable_version"]
class ABTestManager:
"""A/B 测试管理器"""
def __init__(self, metrics_store: 'MetricsStore'):
self.metrics_store = metrics_store
self.experiments: Dict[str, 'Experiment'] = {}
async def create_experiment(
self,
name: str,
model_name: str,
variants: List[dict],
metrics: List[str],
traffic_allocation: int = 100
) -> str:
"""创建实验"""
experiment = Experiment(
experiment_id=str(uuid.uuid4()),
name=name,
model_name=model_name,
variants=variants,
metrics=metrics,
traffic_allocation=traffic_allocation,
status="running",
created_at=datetime.now()
)
self.experiments[experiment.experiment_id] = experiment
# 创建流量规则
rule = self._create_traffic_rule(experiment)
# ... 注册规则
return experiment.experiment_id
async def get_experiment_results(
self,
experiment_id: str
) -> 'ExperimentResults':
"""获取实验结果"""
experiment = self.experiments[experiment_id]
results = ExperimentResults(
experiment_id=experiment_id,
variants=[]
)
for variant in experiment.variants:
variant_metrics = {}
for metric in experiment.metrics:
# 获取指标数据
data = await self.metrics_store.get_metric(
model_name=experiment.model_name,
version=variant["version"],
metric=metric
)
variant_metrics[metric] = {
"mean": np.mean(data),
"std": np.std(data),
"p50": np.percentile(data, 50),
"p99": np.percentile(data, 99),
"sample_size": len(data)
}
results.variants.append({
"version": variant["version"],
"metrics": variant_metrics
})
# 计算统计显著性
results.significance = self._calculate_significance(results)
return results
def _calculate_significance(
self,
results: 'ExperimentResults'
) -> dict:
"""计算统计显著性"""
from scipy import stats
significance = {}
if len(results.variants) < 2:
return significance
control = results.variants[0]
for variant in results.variants[1:]:
for metric in control["metrics"].keys():
# T 检验
t_stat, p_value = stats.ttest_ind_from_stats(
control["metrics"][metric]["mean"],
control["metrics"][metric]["std"],
control["metrics"][metric]["sample_size"],
variant["metrics"][metric]["mean"],
variant["metrics"][metric]["std"],
variant["metrics"][metric]["sample_size"]
)
significance[f"{variant['version']}_{metric}"] = {
"t_statistic": t_stat,
"p_value": p_value,
"significant": p_value < 0.05
}
return significance
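流量路由器的典型配置方式如下:为 recommend_model 配置一条仅对某地域生效的 A/B 测试规则(字段与取值均为示意,假设 request 是该模型的 InferenceRequest,_get_field_value / _get_stable_version 已按实际上下文实现):
router = TrafficRouter()
router.add_rule(TrafficRule(
    rule_id="exp-001",
    model_name="recommend_model",
    conditions=[
        Condition(field="region", operator="eq", value="cn-north"),
    ],
    actions=[
        Action(type="ab_test", config={
            "variants": [
                {"version": "v1.0", "weight": 90},   # 90% 流量
                {"version": "v1.1", "weight": 10},   # 10% 流量
            ],
        }),
    ],
    priority=10,
))
# 命中规则时按 user_id 一致性哈希分桶,否则回落到稳定版本
version = router.route(request, context={"region": "cn-north", "user_id": "u_42"})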
3. 模型热更新
class ModelHotSwapper:
"""模型热更新器"""
def __init__(
self,
model_registry: 'ModelRegistry',
inference_engines: Dict[str, InferenceEngine]
):
self.model_registry = model_registry
self.inference_engines = inference_engines
self.update_lock = asyncio.Lock()
async def hot_swap(
self,
model_name: str,
from_version: str,
to_version: str,
strategy: str = "rolling"
):
"""热更新模型"""
if strategy == "rolling":
await self._rolling_update(model_name, from_version, to_version)
elif strategy == "blue_green":
await self._blue_green_update(model_name, from_version, to_version)
elif strategy == "canary":
await self._canary_update(model_name, from_version, to_version)
async def _rolling_update(
self,
model_name: str,
from_version: str,
to_version: str
):
"""滚动更新"""
model_key = f"{model_name}:{from_version}"
new_model_key = f"{model_name}:{to_version}"
# 加载新模型
new_engine = InferenceEngine(
model_path=await self.model_registry.get_model_path(
model_name, to_version
)
)
await new_engine.load_model()
# 预热新模型
await self._warmup_model(new_engine)
async with self.update_lock:
# 注册新引擎
self.inference_engines[new_model_key] = new_engine
# 逐步切换流量
for ratio in [10, 30, 50, 70, 90, 100]:
await self._update_traffic_ratio(
model_name, from_version, to_version, ratio
)
# 监控新版本性能
if not await self._check_health(new_model_key):
# 回滚
await self._rollback(model_name, from_version, to_version)
return
await asyncio.sleep(30) # 观察期
# 卸载旧模型
if model_key in self.inference_engines:
del self.inference_engines[model_key]
async def _blue_green_update(
self,
model_name: str,
from_version: str,
to_version: str
):
"""蓝绿部署"""
new_model_key = f"{model_name}:{to_version}"
old_model_key = f"{model_name}:{from_version}"
# 加载新模型(绿色环境)
new_engine = InferenceEngine(
model_path=await self.model_registry.get_model_path(
model_name, to_version
)
)
await new_engine.load_model()
await self._warmup_model(new_engine)
# 健康检查
self.inference_engines[new_model_key] = new_engine
if not await self._check_health(new_model_key):
del self.inference_engines[new_model_key]
raise Exception(f"Health check failed for {new_model_key}")
async with self.update_lock:
# 瞬间切换流量
await self._switch_traffic(model_name, from_version, to_version)
# 保留旧版本一段时间用于回滚
await asyncio.sleep(300) # 5 分钟观察期
if await self._check_health(new_model_key):
# 清理旧版本
if old_model_key in self.inference_engines:
del self.inference_engines[old_model_key]
else:
# 回滚
await self._rollback(model_name, from_version, to_version)
async def _warmup_model(self, engine: InferenceEngine):
"""预热模型"""
# 发送预热请求
dummy_inputs = self._generate_dummy_inputs(engine)
for _ in range(100):
await engine.infer_batch(
InferenceBatch(
requests=[InferenceRequest(
request_id=str(uuid.uuid4()),
model_name="warmup",
model_version="warmup",
inputs=dummy_inputs
)],
model_name="warmup",
model_version="warmup"
)
)
async def _check_health(self, model_key: str) -> bool:
"""健康检查"""
engine = self.inference_engines.get(model_key)
if not engine:
return False
try:
# 发送测试请求
dummy_inputs = self._generate_dummy_inputs(engine)
result = await asyncio.wait_for(
engine.infer_batch(
InferenceBatch(
requests=[InferenceRequest(
request_id=str(uuid.uuid4()),
model_name="health",
model_version="health",
inputs=dummy_inputs
)],
model_name="health",
model_version="health"
)
),
timeout=5.0
)
return result is not None
except Exception:
return False
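热更新通常由部署管理服务在新版本通过离线评估后触发,在异步上下文中的调用方式示意如下(model_registry、inference_engines 假定已初始化):
swapper = ModelHotSwapper(model_registry, inference_engines)
# 以滚动方式将 recommend_model 从 v1.0 升级到 v1.1:
# 流量按 10% → 30% → ... → 100% 逐步切换,期间健康检查失败则自动回滚
await swapper.hot_swap(
    model_name="recommend_model",
    from_version="v1.0",
    to_version="v1.1",
    strategy="rolling",
)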
题目三:设计特征平台
题目描述
设计一个统一的特征平台,需要支持:
- 特征注册与元数据管理
- 离线特征计算与存储
- 在线特征服务
- 特征一致性保证
- 特征监控与数据质量
架构设计
┌─────────────────────────────────────────────────────────────────────────┐
│ 特征平台架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 特征定义层 │ │
│ │ │ │
│ │ Feature Registry │ Schema Management │ Feature Discovery │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 特征计算层 │ │
│ │ │ │
│ │ ┌───────────────────────┐ ┌───────────────────────┐ │ │
│ │ │ 离线计算引擎 │ │ 实时计算引擎 │ │ │
│ │ │ │ │ │ │ │
│ │ │ Spark │ Flink Batch │ │ Flink │ Kafka Streams│ │ │
│ │ │ │ │ │ │ │
│ │ └───────────────────────┘ └───────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 特征存储层 │ │
│ │ │ │
│ │ ┌───────────────────────┐ ┌───────────────────────┐ │ │
│ │ │ 离线存储 │ │ 在线存储 │ │ │
│ │ │ │ │ │ │ │
│ │ │ Hive │ Delta Lake │ │ Redis │ Cassandra │ │ │
│ │ │ │ │ │ │ │
│ │ └───────────────────────┘ └───────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 特征服务层 │ │
│ │ │ │
│ │ Online Serving │ Batch Serving │ Feature Vector │ Point-in-Time│ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
核心组件设计
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
import pandas as pd
from datetime import datetime, timedelta
class FeatureType(Enum):
"""特征类型"""
INT64 = "int64"
FLOAT32 = "float32"
FLOAT64 = "float64"
STRING = "string"
BOOL = "bool"
ARRAY = "array"
EMBEDDING = "embedding"
class AggregationType(Enum):
"""聚合类型"""
SUM = "sum"
AVG = "avg"
COUNT = "count"
MAX = "max"
MIN = "min"
LAST = "last"
FIRST = "first"
@dataclass
class Entity:
"""实体定义"""
name: str
join_keys: List[str]
description: str = ""
@dataclass
class FeatureView:
"""特征视图"""
name: str
entities: List[Entity]
features: List['Feature']
source: 'DataSource'
ttl: Optional[timedelta] = None
online: bool = True
offline: bool = True
description: str = ""
@dataclass
class Feature:
"""特征定义"""
name: str
dtype: FeatureType
description: str = ""
tags: Dict[str, str] = field(default_factory=dict)
# 统计信息
mean: Optional[float] = None
std: Optional[float] = None
min_value: Optional[float] = None
max_value: Optional[float] = None
@dataclass
class DataSource:
"""数据源"""
name: str
type: str # batch, stream
config: dict
timestamp_field: Optional[str] = None
created_timestamp_field: Optional[str] = None
class FeatureRegistry:
"""特征注册中心"""
def __init__(self, storage: 'RegistryStorage'):
self.storage = storage
self.entities: Dict[str, Entity] = {}
self.feature_views: Dict[str, FeatureView] = {}
def register_entity(self, entity: Entity):
"""注册实体"""
self.entities[entity.name] = entity
self.storage.save_entity(entity)
def register_feature_view(self, feature_view: FeatureView):
"""注册特征视图"""
# 验证特征视图
self._validate_feature_view(feature_view)
self.feature_views[feature_view.name] = feature_view
self.storage.save_feature_view(feature_view)
# 创建物化任务
if feature_view.online or feature_view.offline:
self._create_materialization_job(feature_view)
def get_feature_view(self, name: str) -> Optional[FeatureView]:
"""获取特征视图"""
return self.feature_views.get(name)
def list_features(
self,
entity: Optional[str] = None,
tags: Optional[Dict[str, str]] = None
) -> List[Feature]:
"""列出特征"""
features = []
for fv in self.feature_views.values():
if entity and entity not in [e.name for e in fv.entities]:
continue
for feature in fv.features:
if tags:
if not all(feature.tags.get(k) == v for k, v in tags.items()):
continue
features.append(feature)
return features
def _validate_feature_view(self, feature_view: FeatureView):
"""验证特征视图"""
# 检查实体是否存在
for entity in feature_view.entities:
if entity.name not in self.entities:
raise ValueError(f"Entity {entity.name} not registered")
# 检查特征名称唯一性
feature_names = [f.name for f in feature_view.features]
if len(feature_names) != len(set(feature_names)):
raise ValueError("Duplicate feature names in feature view")
class OnlineFeatureStore:
"""在线特征存储"""
def __init__(self, redis_client):
self.redis = redis_client
self.cache_ttl = 3600 # 1 hour
async def get_features(
self,
feature_view: str,
entity_keys: Dict[str, Any],
features: List[str]
) -> Dict[str, Any]:
"""获取在线特征"""
# 构建 key
key = self._build_key(feature_view, entity_keys)
# 从 Redis 获取
if features:
result = await self.redis.hmget(key, features)
return {f: self._deserialize(v) for f, v in zip(features, result)}
else:
result = await self.redis.hgetall(key)
return {k.decode(): self._deserialize(v) for k, v in result.items()}
async def write_features(
self,
feature_view: str,
entity_keys: Dict[str, Any],
features: Dict[str, Any],
timestamp: Optional[datetime] = None
):
"""写入在线特征"""
key = self._build_key(feature_view, entity_keys)
# 序列化并写入
serialized = {k: self._serialize(v) for k, v in features.items()}
await self.redis.hset(key, mapping=serialized)
# 设置 TTL
await self.redis.expire(key, self.cache_ttl)
async def batch_get_features(
self,
feature_view: str,
entity_keys_list: List[Dict[str, Any]],
features: List[str]
) -> List[Dict[str, Any]]:
"""批量获取特征"""
pipe = self.redis.pipeline()
for entity_keys in entity_keys_list:
key = self._build_key(feature_view, entity_keys)
if features:
pipe.hmget(key, features)
else:
pipe.hgetall(key)
results = await pipe.execute()
feature_list = []
for result in results:
if isinstance(result, list):
feature_list.append({
f: self._deserialize(v)
for f, v in zip(features, result)
})
else:
feature_list.append({
k.decode(): self._deserialize(v)
for k, v in result.items()
})
return feature_list
def _build_key(self, feature_view: str, entity_keys: Dict[str, Any]) -> str:
"""构建存储 key"""
sorted_keys = sorted(entity_keys.items())
key_parts = [f"{k}:{v}" for k, v in sorted_keys]
return f"feature:{feature_view}:{':'.join(key_parts)}"
def _serialize(self, value: Any) -> bytes:
"""序列化值"""
import pickle
return pickle.dumps(value)
def _deserialize(self, value: bytes) -> Any:
"""反序列化值"""
import pickle
if value is None:
return None
return pickle.loads(value)
class FeatureServer:
"""特征服务"""
def __init__(
self,
registry: FeatureRegistry,
online_store: OnlineFeatureStore,
offline_store: 'OfflineFeatureStore'
):
self.registry = registry
self.online_store = online_store
self.offline_store = offline_store
async def get_online_features(
self,
feature_refs: List[str], # "feature_view:feature_name"
entity_rows: List[Dict[str, Any]]
) -> pd.DataFrame:
"""获取在线特征"""
# 解析特征引用
parsed_refs = self._parse_feature_refs(feature_refs)
results = []
for entity_row in entity_rows:
row_features = dict(entity_row)
for feature_view, features in parsed_refs.items():
fv = self.registry.get_feature_view(feature_view)
# 提取实体键
entity_keys = {
key: entity_row[key]
for entity in fv.entities
for key in entity.join_keys
}
# 获取特征
feature_values = await self.online_store.get_features(
feature_view,
entity_keys,
features
)
for feature, value in feature_values.items():
row_features[f"{feature_view}__{feature}"] = value
results.append(row_features)
return pd.DataFrame(results)
async def get_historical_features(
self,
feature_refs: List[str],
entity_df: pd.DataFrame,
timestamp_column: str = "event_timestamp"
) -> pd.DataFrame:
"""获取历史特征(Point-in-Time Join)"""
parsed_refs = self._parse_feature_refs(feature_refs)
result_df = entity_df.copy()
for feature_view, features in parsed_refs.items():
fv = self.registry.get_feature_view(feature_view)
# 获取离线特征
feature_df = await self.offline_store.get_features(
feature_view,
features
)
# Point-in-Time Join
result_df = self._point_in_time_join(
result_df,
feature_df,
entity_df,
fv,
timestamp_column
)
return result_df
def _point_in_time_join(
self,
entity_df: pd.DataFrame,
feature_df: pd.DataFrame,
original_entity_df: pd.DataFrame,
feature_view: FeatureView,
timestamp_column: str
) -> pd.DataFrame:
"""Point-in-Time Join"""
# 获取 join keys
join_keys = []
for entity in feature_view.entities:
join_keys.extend(entity.join_keys)
# 确保时间戳列存在
entity_df = entity_df.sort_values(timestamp_column)
feature_df = feature_df.sort_values("feature_timestamp")
# 执行 as-of join
result = pd.merge_asof(
entity_df,
feature_df,
left_on=timestamp_column,
right_on="feature_timestamp",
by=join_keys,
direction="backward"
)
# 应用 TTL 过滤
if feature_view.ttl:
ttl_seconds = feature_view.ttl.total_seconds()
time_diff = (
result[timestamp_column] - result["feature_timestamp"]
).dt.total_seconds()
# 超过 TTL 的特征设为 null
expired_mask = time_diff > ttl_seconds
for feature in feature_view.features:
result.loc[expired_mask, feature.name] = None
return result
def _parse_feature_refs(
self,
feature_refs: List[str]
) -> Dict[str, List[str]]:
"""解析特征引用"""
parsed = {}
for ref in feature_refs:
if ":" in ref:
fv_name, feature_name = ref.split(":", 1)
else:
raise ValueError(f"Invalid feature ref: {ref}")
if fv_name not in parsed:
parsed[fv_name] = []
parsed[fv_name].append(feature_name)
return parsed
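以用户维度统计特征为例,特征注册与在线查询(异步上下文中)的典型用法如下(表名、字段均为示意,registry 与 feature_server 假定已初始化,物化等内部方法已实现):
# 注册实体与特征视图
user_entity = Entity(name="user", join_keys=["user_id"])
registry.register_entity(user_entity)
registry.register_feature_view(FeatureView(
    name="user_stats",
    entities=[user_entity],
    features=[
        Feature(name="purchase_count_7d", dtype=FeatureType.INT64),
        Feature(name="avg_order_amount_30d", dtype=FeatureType.FLOAT32),
    ],
    source=DataSource(
        name="user_stats_hive",
        type="batch",
        config={"table": "dw.user_stats_daily"},   # 示意表名
        timestamp_field="event_timestamp",
    ),
    ttl=timedelta(days=1),
))
# 在线查询:训练与在线共用同一份特征定义,保证一致性
features_df = await feature_server.get_online_features(
    feature_refs=["user_stats:purchase_count_7d", "user_stats:avg_order_amount_30d"],
    entity_rows=[{"user_id": "u_1001"}, {"user_id": "u_1002"}],
)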
题目四:设计向量数据库
题目描述
设计一个高性能向量数据库,需要支持:
- 大规模向量存储(十亿级别)
- 高效相似性搜索(ANN)
- 元数据过滤
- 分布式部署
- 实时更新
架构设计
┌─────────────────────────────────────────────────────────────────────────┐
│ 向量数据库架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 客户端层 │ │
│ │ Python SDK │ REST API │ gRPC │ 管理控制台 │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 协调层 │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 查询路由 │ │ 负载均衡 │ │ 分片管理 │ │ 元数据 │ │ │
│ │ │ Router │ │ Balancer │ │ Sharding │ │ Service │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 查询层 │ │
│ │ │ │
│ │ ┌───────────────────────────────────────────────────────────┐ │ │
│ │ │ Query Node 集群 │ │ │
│ │ │ │ │ │
│ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ │
│ │ │ │ Query Node 1│ │ Query Node 2│ │ Query Node N│ │ │ │
│ │ │ │ │ │ │ │ │ │ │ │
│ │ │ │ ┌─────────┐ │ │ ┌─────────┐ │ │ ┌─────────┐ │ │ │ │
│ │ │ │ │Index Seg│ │ │ │Index Seg│ │ │ │Index Seg│ │ │ │ │
│ │ │ │ └─────────┘ │ │ └─────────┘ │ │ └─────────┘ │ │ │ │
│ │ │ │ ┌─────────┐ │ │ ┌─────────┐ │ │ ┌─────────┐ │ │ │ │
│ │ │ │ │Index Seg│ │ │ │Index Seg│ │ │ │Index Seg│ │ │ │ │
│ │ │ │ └─────────┘ │ │ └─────────┘ │ │ └─────────┘ │ │ │ │
│ │ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ │
│ │ │ │ │ │
│ │ └───────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ 存储层 │ │
│ │ │ │
│ │ ┌───────────────────────┐ ┌───────────────────────┐ │ │
│ │ │ 索引存储 │ │ 原始数据存储 │ │ │
│ │ │ │ │ │ │ │
│ │ │ 本地 SSD │ 共享存储 │ │ 对象存储 (S3/MinIO) │ │ │
│ │ │ │ │ │ │ │
│ │ └───────────────────────┘ └───────────────────────┘ │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
核心组件设计
import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from enum import Enum
import heapq
from abc import ABC, abstractmethod
class IndexType(Enum):
"""索引类型"""
FLAT = "flat"
IVF_FLAT = "ivf_flat"
IVF_PQ = "ivf_pq"
HNSW = "hnsw"
DISKANN = "diskann"
class MetricType(Enum):
"""距离度量类型"""
L2 = "l2"
IP = "ip" # Inner Product
COSINE = "cosine"
@dataclass
class VectorRecord:
"""向量记录"""
id: str
vector: np.ndarray
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SearchResult:
"""搜索结果"""
id: str
score: float
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SearchRequest:
"""搜索请求"""
collection: str
query_vector: np.ndarray
top_k: int = 10
filter: Optional[Dict[str, Any]] = None
search_params: Optional[Dict[str, Any]] = None
class VectorIndex(ABC):
"""向量索引基类"""
@abstractmethod
def build(self, vectors: np.ndarray, ids: List[str]):
"""构建索引"""
pass
@abstractmethod
def search(
self,
query: np.ndarray,
k: int,
params: Optional[dict] = None
) -> Tuple[List[str], List[float]]:
"""搜索"""
pass
@abstractmethod
def add(self, vectors: np.ndarray, ids: List[str]):
"""添加向量"""
pass
@abstractmethod
def delete(self, ids: List[str]):
"""删除向量"""
pass
class HNSWIndex(VectorIndex):
"""HNSW 索引实现"""
def __init__(
self,
dim: int,
metric: MetricType = MetricType.L2,
M: int = 16,
ef_construction: int = 200,
ef_search: int = 50
):
self.dim = dim
self.metric = metric
self.M = M
self.ef_construction = ef_construction
self.ef_search = ef_search
# 图结构
self.graphs: List[Dict[str, List[str]]] = [] # 每层的邻接表
self.vectors: Dict[str, np.ndarray] = {}
self.entry_point: Optional[str] = None
self.max_level = 0
self.node_levels: Dict[str, int] = {}
# 概率参数
self.ml = 1.0 / np.log(M)
def build(self, vectors: np.ndarray, ids: List[str]):
"""构建索引"""
for vector, id in zip(vectors, ids):
self.add(np.array([vector]), [id])
def add(self, vectors: np.ndarray, ids: List[str]):
"""添加向量"""
for vector, id in zip(vectors, ids):
self._insert(id, vector)
def _insert(self, id: str, vector: np.ndarray):
"""插入单个向量"""
self.vectors[id] = vector
# 随机确定层数
level = self._random_level()
self.node_levels[id] = level
# 扩展图结构
while len(self.graphs) <= level:
self.graphs.append({})
if self.entry_point is None:
self.entry_point = id
self.max_level = level
for l in range(level + 1):
self.graphs[l][id] = []
return
# 搜索插入位置
curr = self.entry_point
# 从最高层向下搜索
for l in range(self.max_level, level, -1):
curr = self._search_layer(vector, curr, 1, l)[0]
# 在 level 及以下层插入
for l in range(min(level, self.max_level), -1, -1):
neighbors = self._search_layer(vector, curr, self.ef_construction, l)
# 选择邻居
selected = self._select_neighbors(vector, neighbors, self.M)
# 建立双向连接
self.graphs[l][id] = selected
for neighbor in selected:
if neighbor in self.graphs[l]:
self.graphs[l][neighbor].append(id)
# 限制邻居数量
if len(self.graphs[l][neighbor]) > self.M:
neighbor_vec = self.vectors[neighbor]
self.graphs[l][neighbor] = self._select_neighbors(
neighbor_vec,
self.graphs[l][neighbor],
self.M
)
if neighbors:
curr = neighbors[0]
# 更新入口点
if level > self.max_level:
self.entry_point = id
self.max_level = level
def search(
self,
query: np.ndarray,
k: int,
params: Optional[dict] = None
) -> Tuple[List[str], List[float]]:
"""搜索最近邻"""
ef = params.get("ef", self.ef_search) if params else self.ef_search
if self.entry_point is None:
return [], []
curr = self.entry_point
# 从最高层向下搜索到第 1 层
for l in range(self.max_level, 0, -1):
curr = self._search_layer(query, curr, 1, l)[0]
# 在第 0 层搜索
candidates = self._search_layer(query, curr, ef, 0)
# 返回 top-k
results = []
for candidate in candidates[:k]:
score = self._distance(query, self.vectors[candidate])
results.append((candidate, score))
ids, scores = zip(*results) if results else ([], [])
return list(ids), list(scores)
def _search_layer(
self,
query: np.ndarray,
entry: str,
ef: int,
level: int
) -> List[str]:
"""在单层搜索"""
visited = {entry}
candidates = [] # min heap
results = [] # max heap (negative distance)
entry_dist = self._distance(query, self.vectors[entry])
heapq.heappush(candidates, (entry_dist, entry))
heapq.heappush(results, (-entry_dist, entry))
while candidates:
curr_dist, curr = heapq.heappop(candidates)
# 最远结果比当前候选近,停止
if -results[0][0] < curr_dist:
break
# 遍历邻居
neighbors = self.graphs[level].get(curr, [])
for neighbor in neighbors:
if neighbor in visited:
continue
visited.add(neighbor)
neighbor_dist = self._distance(query, self.vectors[neighbor])
if len(results) < ef or neighbor_dist < -results[0][0]:
heapq.heappush(candidates, (neighbor_dist, neighbor))
heapq.heappush(results, (-neighbor_dist, neighbor))
if len(results) > ef:
heapq.heappop(results)
# 按距离排序返回
sorted_results = sorted([(-d, id) for d, id in results])
return [id for _, id in sorted_results]
def _select_neighbors(
self,
query: np.ndarray,
candidates: List[str],
M: int
) -> List[str]:
"""选择邻居(启发式)"""
if len(candidates) <= M:
return candidates
# 按距离排序
scored = [(self._distance(query, self.vectors[c]), c) for c in candidates]
scored.sort()
# 简单截断(可改为更复杂的启发式)
return [c for _, c in scored[:M]]
def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
"""计算距离"""
if self.metric == MetricType.L2:
return np.sum((a - b) ** 2)
elif self.metric == MetricType.IP:
return -np.dot(a, b)
elif self.metric == MetricType.COSINE:
return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    def _random_level(self) -> int:
        """随机生成层数:指数衰减,期望层数由 ml = 1/ln(M) 控制"""
        u = max(np.random.random(), 1e-12)
        level = int(-np.log(u) * self.ml)
        return min(level, 16)
def delete(self, ids: List[str]):
"""删除向量"""
for id in ids:
if id not in self.vectors:
continue
level = self.node_levels.get(id, 0)
# 从每层移除
for l in range(level + 1):
if id in self.graphs[l]:
# 移除指向该节点的边
for neighbor in self.graphs[l][id]:
if neighbor in self.graphs[l]:
self.graphs[l][neighbor] = [
n for n in self.graphs[l][neighbor] if n != id
]
del self.graphs[l][id]
del self.vectors[id]
del self.node_levels[id]
# 更新入口点
if id == self.entry_point:
self._update_entry_point()
def _update_entry_point(self):
"""更新入口点"""
if not self.vectors:
self.entry_point = None
self.max_level = 0
return
# 找到最高层的节点
max_level = 0
entry = None
for id, level in self.node_levels.items():
if level > max_level:
max_level = level
entry = id
self.entry_point = entry
self.max_level = max_level
class VectorCollection:
"""向量集合"""
def __init__(
self,
name: str,
dim: int,
index_type: IndexType = IndexType.HNSW,
metric: MetricType = MetricType.L2,
index_params: Optional[dict] = None
):
self.name = name
self.dim = dim
self.index_type = index_type
self.metric = metric
# 创建索引
self.index = self._create_index(index_params or {})
# 元数据存储
self.metadata: Dict[str, Dict[str, Any]] = {}
def _create_index(self, params: dict) -> VectorIndex:
"""创建索引"""
if self.index_type == IndexType.HNSW:
return HNSWIndex(
dim=self.dim,
metric=self.metric,
M=params.get("M", 16),
ef_construction=params.get("ef_construction", 200)
)
# 其他索引类型...
raise ValueError(f"Unsupported index type: {self.index_type}")
def insert(self, records: List[VectorRecord]):
"""插入记录"""
vectors = np.array([r.vector for r in records])
ids = [r.id for r in records]
# 添加到索引
self.index.add(vectors, ids)
# 存储元数据
for record in records:
self.metadata[record.id] = record.metadata
def search(
self,
query_vector: np.ndarray,
top_k: int = 10,
filter: Optional[Dict[str, Any]] = None,
search_params: Optional[dict] = None
) -> List[SearchResult]:
"""搜索"""
        # 搜索向量索引:召回量放大 10 倍,用于补偿后续元数据过滤造成的结果损失(后过滤策略)
        ids, scores = self.index.search(query_vector, top_k * 10, search_params)
results = []
for id, score in zip(ids, scores):
metadata = self.metadata.get(id, {})
# 应用过滤器
if filter and not self._match_filter(metadata, filter):
continue
results.append(SearchResult(
id=id,
score=score,
metadata=metadata
))
if len(results) >= top_k:
break
return results
def _match_filter(self, metadata: dict, filter: dict) -> bool:
"""匹配过滤条件"""
for key, condition in filter.items():
if key not in metadata:
return False
value = metadata[key]
if isinstance(condition, dict):
# 复杂条件
for op, op_value in condition.items():
if op == "$eq" and value != op_value:
return False
elif op == "$ne" and value == op_value:
return False
elif op == "$gt" and value <= op_value:
return False
elif op == "$gte" and value < op_value:
return False
elif op == "$lt" and value >= op_value:
return False
elif op == "$lte" and value > op_value:
return False
elif op == "$in" and value not in op_value:
return False
else:
# 简单等值条件
if value != condition:
return False
return True
def delete(self, ids: List[str]):
"""删除记录"""
self.index.delete(ids)
for id in ids:
self.metadata.pop(id, None)
class DistributedVectorDB:
"""分布式向量数据库"""
def __init__(
self,
num_shards: int = 8,
replication_factor: int = 2
):
self.num_shards = num_shards
self.replication_factor = replication_factor
self.collections: Dict[str, 'ShardedCollection'] = {}
def create_collection(
self,
name: str,
dim: int,
index_type: IndexType = IndexType.HNSW,
metric: MetricType = MetricType.L2,
index_params: Optional[dict] = None
):
"""创建集合"""
shards = []
for i in range(self.num_shards):
shard = VectorCollection(
name=f"{name}_shard_{i}",
dim=dim,
index_type=index_type,
metric=metric,
index_params=index_params
)
shards.append(shard)
self.collections[name] = ShardedCollection(
name=name,
shards=shards,
num_shards=self.num_shards
)
def insert(self, collection: str, records: List[VectorRecord]):
"""插入记录"""
coll = self.collections[collection]
# 按 shard 分组
shard_records: Dict[int, List[VectorRecord]] = {}
for record in records:
shard_id = self._get_shard(record.id)
if shard_id not in shard_records:
shard_records[shard_id] = []
shard_records[shard_id].append(record)
# 并行插入各 shard
for shard_id, shard_recs in shard_records.items():
coll.shards[shard_id].insert(shard_recs)
def search(
self,
collection: str,
query_vector: np.ndarray,
top_k: int = 10,
filter: Optional[Dict[str, Any]] = None,
search_params: Optional[dict] = None
) -> List[SearchResult]:
"""分布式搜索"""
coll = self.collections[collection]
# 并行搜索所有 shard
all_results = []
for shard in coll.shards:
shard_results = shard.search(
query_vector,
top_k,
filter,
search_params
)
all_results.extend(shard_results)
# 合并结果
all_results.sort(key=lambda r: r.score)
return all_results[:top_k]
def _get_shard(self, id: str) -> int:
"""计算 shard ID"""
return hash(id) % self.num_shards
@dataclass
class ShardedCollection:
"""分片集合"""
name: str
shards: List[VectorCollection]
num_shards: int
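最后给出分布式向量数据库的端到端用法示意(集合名、维度与过滤条件均为示例):
db = DistributedVectorDB(num_shards=4)
db.create_collection(name="docs", dim=128, index_type=IndexType.HNSW, metric=MetricType.L2)
# 写入带元数据的向量
records = [
    VectorRecord(
        id=f"doc_{i}",
        vector=np.random.rand(128).astype(np.float32),
        metadata={"lang": "zh" if i % 2 == 0 else "en", "length": 100 + i},
    )
    for i in range(1000)
]
db.insert("docs", records)
# 带元数据过滤的相似性搜索:只在 length >= 200 的中文文档中取 top-5
results = db.search(
    collection="docs",
    query_vector=np.random.rand(128).astype(np.float32),
    top_k=5,
    filter={"lang": "zh", "length": {"$gte": 200}},
)
for r in results:
    print(r.id, r.score, r.metadata)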
总结
系统设计核心要点
| 系统 | 核心挑战 | 关键技术 |
|---|---|---|
| 训练平台 | 大规模资源调度、故障恢复 | 弹性训练、Checkpoint、抢占调度 |
| 推理平台 | 低延迟、高吞吐 | 动态 Batching、模型优化、流量管理 |
| 特征平台 | 一致性、实时性 | Point-in-Time Join、特征物化 |
| 向量数据库 | 高效 ANN、分布式 | HNSW、分片、元数据过滤 |
设计方法论
- 需求先行:明确功能和非功能需求
- 分层设计:清晰的职责划分
- 权衡取舍:理解各方案的 trade-off
- 演进思维:考虑系统的可扩展性
- 细节到位:关键算法和数据结构