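Model Serving Frameworks

Overview

A model serving framework is the infrastructure that turns a trained model into a production-grade API service. For large language models, it has to handle highly concurrent requests, optimize inference performance, and manage GPU resources. This chapter covers the architecture, core features, and best practices of the mainstream model serving frameworks.

Comparison of Mainstream Frameworks

Framework Selection Matrix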

┌──────────────────────────────────────────────────────────────────────┐
│                   LLM Serving Framework Comparison                   │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  Framework     │ Characteristics         │ Best suited for           │
│  ──────────────┼─────────────────────────┼────────────────────────── │
│  vLLM          │ PagedAttention          │ High-throughput serving   │
│                │ Continuous batching     │ Resource-limited setups   │
│                │ High memory efficiency  │                           │
│  ──────────────┼─────────────────────────┼────────────────────────── │
│  TGI           │ Production ready        │ HuggingFace ecosystem     │
│  (Text         │ Multi-model support     │ Enterprise deployment     │
│  Generation    │ Distributed inference   │                           │
│  Inference)    │                         │                           │
│  ──────────────┼─────────────────────────┼────────────────────────── │
│  TensorRT-LLM  │ NVIDIA-optimized        │ Peak performance needs    │
│                │ Best GPU utilization    │ NVIDIA GPUs only          │
│                │ Mature quantization     │                           │
│  ──────────────┼─────────────────────────┼────────────────────────── │
│  Triton        │ Multi-framework support │ Heterogeneous models      │
│  Inference     │ Dynamic batching        │ Multi-model orchestration │
│  Server        │ Model management        │                           │
│  ──────────────┼─────────────────────────┼────────────────────────── │
│  llama.cpp     │ CPU inference           │ Edge devices              │
│                │ Low resource usage      │ Local deployment          │
│                │ Quantization support    │                           │
│                                                                      │
│  Performance comparison (LLaMA-7B, A100, batch=1):                   │
│  ┌──────────────────────────────────────────────────────────┐        │
│  │ Framework          │ Tokens/s  │ TTFT (ms) │ Memory (GB) │        │
│  │────────────────────┼───────────┼───────────┼─────────────│        │
│  │ vLLM               │ ~80       │ ~40       │ ~14         │        │
│  │ TGI                │ ~75       │ ~45       │ ~15         │        │
│  │ TensorRT-LLM       │ ~100      │ ~30       │ ~13         │        │
│  │ PyTorch (baseline) │ ~50       │ ~80       │ ~16         │        │
│  └──────────────────────────────────────────────────────────┘        │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘
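
The Tokens/s and TTFT columns can be derived from per-token arrival times. The following is a minimal sketch, not a benchmark harness: `measure` is a hypothetical helper, the timestamps are made up, and throughput is taken here to mean decode throughput (tokens after the first one).

```python
from typing import List, Tuple


def measure(request_start: float, token_times: List[float]) -> Tuple[float, float]:
    """Return (ttft_ms, decode_tokens_per_s) for one request.

    request_start: wall-clock time (seconds) when the request was sent.
    token_times:   wall-clock time (seconds) at which each output token arrived.
    """
    if not token_times:
        raise ValueError("no tokens were generated")
    ttft_ms = (token_times[0] - request_start) * 1000.0
    decode_window = token_times[-1] - token_times[0]
    if len(token_times) < 2 or decode_window <= 0:
        return ttft_ms, 0.0
    # Tokens after the first one, divided by the time spent producing them.
    return ttft_ms, (len(token_times) - 1) / decode_window


# Hypothetical timestamps: first token after 40 ms, then one token every 12.5 ms.
times = [0.040 + 0.0125 * i for i in range(81)]
ttft, tps = measure(0.0, times)
print(f"TTFT = {ttft:.0f} ms, decode throughput = {tps:.0f} tokens/s")
```

Measuring TTFT separately from decode throughput matters because the first is dominated by prefill and queueing, while the second reflects steady-state generation.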

vLLM Deep Dive

vLLM Architecture

┌─────────────────────────────────────────────────────────────────┐
│                      vLLM Architecture                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    API Server                            │    │
│  │              (OpenAI Compatible API)                     │    │
│  └─────────────────────────┬───────────────────────────────┘    │
│                            │                                     │
│  ┌─────────────────────────▼───────────────────────────────┐    │
│  │                    LLM Engine                            │    │
│  │  ┌─────────────────────────────────────────────────┐    │    │
│  │  │              Scheduler                           │    │    │
│  │  │  • Continuous Batching                          │    │    │
│  │  │  • Preemption (swap/recompute)                  │    │    │
│  │  │  • Priority Scheduling                          │    │    │
│  │  └─────────────────────────────────────────────────┘    │    │
│  │                         │                                │    │
│  │  ┌─────────────────────▼─────────────────────────┐      │    │
│  │  │           Block Manager                        │      │    │
│  │  │  • Physical Block Allocation                   │      │    │
│  │  │  • Block Table Management                      │      │    │
│  │  │  • Copy-on-Write for Beam Search              │      │    │
│  │  └─────────────────────────────────────────────────┘    │    │
│  │                         │                                │    │
│  │  ┌─────────────────────▼─────────────────────────┐      │    │
│  │  │           Model Executor                       │      │    │
│  │  │  • PagedAttention Kernel                       │      │    │
│  │  │  • Flash Attention                             │      │    │
│  │  │  • Tensor Parallelism                          │      │    │
│  │  └─────────────────────────────────────────────────┘    │    │
│  └─────────────────────────────────────────────────────────┘    │
│                            │                                     │
│  ┌─────────────────────────▼───────────────────────────────┐    │
│  │                   GPU Workers                            │    │
│  │  ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐    │    │
│  │  │ GPU 0   │  │ GPU 1   │  │ GPU 2   │  │ GPU 3   │    │    │
│  │  │ Model   │  │ Model   │  │ Model   │  │ Model   │    │    │
│  │  │ Shard   │  │ Shard   │  │ Shard   │  │ Shard   │    │    │
│  │  └─────────┘  └─────────┘  └─────────┘  └─────────┘    │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
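
The Block Manager's job (physical block allocation plus a per-sequence block table) can be illustrated with a toy allocator. This is only a sketch of the idea under simplifying assumptions: `ToyBlockManager` is a hypothetical class, not vLLM's implementation, and it leaves out copy-on-write and swapping.

```python
from typing import Dict, List


class ToyBlockManager:
    """Toy allocator: maps each sequence to a list of fixed-size physical blocks."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[str, List[int]] = {}
        self.num_tokens: Dict[str, int] = {}

    def append_tokens(self, seq_id: str, n: int) -> bool:
        """Reserve room for n more tokens; return False if memory runs out."""
        table = self.block_tables.get(seq_id, [])
        total = self.num_tokens.get(seq_id, 0) + n
        needed = -(-total // self.block_size) - len(table)  # ceiling division
        if needed > len(self.free_blocks):
            return False  # a real scheduler would preempt (swap or recompute) here
        self.block_tables[seq_id] = table + [self.free_blocks.pop() for _ in range(needed)]
        self.num_tokens[seq_id] = total
        return True

    def free(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.num_tokens.pop(seq_id, None)


manager = ToyBlockManager(num_blocks=8, block_size=16)
manager.append_tokens("req-a", 40)               # 40-token prompt -> 3 blocks
manager.append_tokens("req-a", 1)                # token 41 still fits in the third block
admitted = manager.append_tokens("req-b", 100)   # needs 7 blocks, only 5 are free
print(manager.block_tables["req-a"], len(manager.free_blocks), admitted)
```

Because blocks are allocated on demand and need not be contiguous, a sequence wastes at most one partially filled block, which is the source of the memory efficiency listed in the comparison table.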

vLLM Deployment Configuration

# vllm_server.py
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
from fastapi import FastAPI, Request
from typing import Optional, List, Dict, Any
import uvicorn

# vLLM configuration class
class VLLMConfig:
    def __init__(
        self,
        model: str,
        # Hardware settings
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        # Batching settings
        max_num_batched_tokens: int = 8192,
        max_num_seqs: int = 256,
        max_model_len: Optional[int] = None,
        # KV cache settings
        block_size: int = 16,
        swap_space: int = 4,  # GiB of CPU swap space per GPU
        # Quantization settings
        quantization: Optional[str] = None,  # awq, gptq, squeezellm
        # Scheduling settings
        scheduler_delay_factor: float = 0.0,
        enable_chunked_prefill: bool = False,
        # Speculative decoding (set both options together, or neither)
        speculative_model: Optional[str] = None,
        num_speculative_tokens: Optional[int] = None,
    ):
        self.model = model
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.block_size = block_size
        self.swap_space = swap_space
        self.quantization = quantization
        self.scheduler_delay_factor = scheduler_delay_factor
        self.enable_chunked_prefill = enable_chunked_prefill
        self.speculative_model = speculative_model
        self.num_speculative_tokens = num_speculative_tokens

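    def to_engine_args(self) -> AsyncEngineArgs:
        # Argument names follow the vLLM release this example targets; some of them
        # (the speculative decoding options in particular) differ in other releases.
        return AsyncEngineArgs(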
            model=self.model,
            tensor_parallel_size=self.tensor_parallel_size,
            pipeline_parallel_size=self.pipeline_parallel_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=self.max_model_len,
            block_size=self.block_size,
            swap_space=self.swap_space,
            quantization=self.quantization,
            scheduler_delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            trust_remote_code=True,
            dtype="auto",
        )


# Custom vLLM server
class VLLMServer:
    def __init__(self, config: VLLMConfig):
        self.config = config
        self.engine: Optional[AsyncLLMEngine] = None
        self.app = FastAPI(title="vLLM API Server")
        self._setup_routes()

    async def initialize(self):
        """Initialize the engine."""
        engine_args = self.config.to_engine_args()
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    def _setup_routes(self):
        """Register the API routes."""

        @self.app.post("/v1/completions")
        async def completions(request: Request):
            """OpenAI Completions API"""
            body = await request.json()
            return await self._handle_completion(body)

        @self.app.post("/v1/chat/completions")
        async def chat_completions(request: Request):
            """OpenAI Chat Completions API"""
            body = await request.json()
            return await self._handle_chat_completion(body)

        @self.app.get("/health")
        async def health():
            """Health check."""
            return {"status": "healthy"}

        @self.app.get("/v1/models")
        async def list_models():
            """List the served models."""
            return {
                "data": [
                    {
                        "id": self.config.model,
                        "object": "model",
                        "owned_by": "vllm",
                    }
                ]
            }

    async def _handle_completion(self, body: Dict[str, Any]):
        """Handle a completion request."""
        prompt = body.get("prompt", "")
        sampling_params = SamplingParams(
            temperature=body.get("temperature", 1.0),
            top_p=body.get("top_p", 1.0),
            top_k=body.get("top_k", -1),
            max_tokens=body.get("max_tokens", 16),
            stop=body.get("stop"),
            presence_penalty=body.get("presence_penalty", 0.0),
            frequency_penalty=body.get("frequency_penalty", 0.0),
        )

        request_id = f"cmpl-{self._generate_request_id()}"

        # Generate. Each yielded output is cumulative, so only the last one is kept.
        final_output = None
        async for output in self.engine.generate(prompt, sampling_params, request_id):
            final_output = output

        return {
            "id": request_id,
            "object": "text_completion",
            "model": self.config.model,
            "choices": [
                {
                    "text": output.text,
                    "index": i,
                    "finish_reason": output.finish_reason,
                }
                for i, output in enumerate(final_output.outputs)
            ],
            "usage": {
                "prompt_tokens": len(final_output.prompt_token_ids),
                "completion_tokens": sum(
                    len(o.token_ids) for o in final_output.outputs
                ),
                "total_tokens": len(final_output.prompt_token_ids)
                + sum(len(o.token_ids) for o in final_output.outputs),
            },
        }

    async def _handle_chat_completion(self, body: Dict[str, Any]):
        """Handle a chat completion request."""
        messages = body.get("messages", [])

        # Convert the messages into a prompt
        prompt = self._messages_to_prompt(messages)

        # Reuse the completion logic
        body["prompt"] = prompt
        return await self._handle_completion(body)

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """Convert a list of messages into a prompt."""
        # Different models expect different templates.
        # ChatML is used here as an example; Llama 2 chat models use their own [INST] format.
        prompt_parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")

        prompt_parts.append("<|im_start|>assistant\n")
        return "\n".join(prompt_parts)

    def _generate_request_id(self) -> str:
        import uuid
        return str(uuid.uuid4())[:8]

    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """Run the server."""
        asyncio.run(self.initialize())
        uvicorn.run(self.app, host=host, port=port)


# Launch script
if __name__ == "__main__":
    config = VLLMConfig(
        model="meta-llama/Llama-2-7b-chat-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        max_num_seqs=256,
        max_model_len=4096,
    )

    server = VLLMServer(config)
    server.run()
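
How `gpu_memory_utilization` and `block_size` translate into KV cache capacity can be estimated with simple arithmetic. The sketch below is a rough estimate under stated assumptions, not vLLM's own profiling: it assumes the Llama-2-7B shape (32 layers, 32 KV heads, head dimension 128), FP16 values, an 80 GiB GPU, roughly 12.6 GiB of weights, and a hypothetical 2 GiB allowance for activations and other overhead.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """Bytes of KV cache one token occupies (keys and values, all layers)."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def estimate_kv_blocks(gpu_mem_gib: float, gpu_memory_utilization: float,
                       weights_gib: float, overhead_gib: float,
                       bytes_per_token: int, block_size: int = 16) -> int:
    """Rough count of KV cache blocks that fit after weights and overhead."""
    budget_bytes = (gpu_mem_gib * gpu_memory_utilization
                    - weights_gib - overhead_gib) * 1024**3
    return max(int(budget_bytes // (bytes_per_token * block_size)), 0)


per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=128)
blocks = estimate_kv_blocks(gpu_mem_gib=80, gpu_memory_utilization=0.9,
                            weights_gib=12.6, overhead_gib=2.0,
                            bytes_per_token=per_token, block_size=16)
print(per_token // 1024, "KiB per token")
print(blocks, "blocks, about", blocks * 16, "cacheable tokens")
```

Under these assumptions each token costs 512 KiB of cache, so the number of cacheable tokens, rather than `max_num_seqs` alone, usually bounds how many long sequences can run concurrently.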

vLLM Kubernetes Deployment

# vllm-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: vllm-server
  namespace: ai-inference
spec:
  replicas: 1
  selector:
    matchLabels:
      app: vllm-server
  template:
    metadata:
      labels:
        app: vllm-server
    spec:
      containers:
      - name: vllm
        image: vllm/vllm-openai:latest
        command:
        - python3
        - -m
        - vllm.entrypoints.openai.api_server
        args:
        - --model=/models/llama-2-7b-chat
        - --tensor-parallel-size=1
        - --gpu-memory-utilization=0.9
        - --max-num-seqs=256
        - --max-model-len=4096
        - --host=0.0.0.0
        - --port=8000
        ports:
        - containerPort: 8000
          name: http
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "32Gi"
          requests:
            cpu: "4"
            memory: "16Gi"
        volumeMounts:
        - name: model-storage
          mountPath: /models
        - name: shm
          mountPath: /dev/shm
        env:
        - name: CUDA_VISIBLE_DEVICES
          value: "0"
        - name: NCCL_DEBUG
          value: "WARN"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 60
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 5
      volumes:
      - name: model-storage
        persistentVolumeClaim:
          claimName: model-pvc
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: 16Gi
      nodeSelector:
        nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"

---
apiVersion: v1
kind: Service
metadata:
  name: vllm-server
  namespace: ai-inference
spec:
  selector:
    app: vllm-server
  ports:
  - port: 8000
    targetPort: 8000
    name: http
  type: ClusterIP

---
# HPA configuration
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: vllm-server-hpa
  namespace: ai-inference
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: vllm-server
  minReplicas: 1
  maxReplicas: 4
  metrics:
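  # Both metrics below are custom per-pod metrics, so they assume Prometheus plus an
  # adapter (such as Prometheus Adapter) that exposes them through the custom metrics API.
  - type: Pods
    pods:
      metric:
        name: vllm_num_requests_running
      target:
        type: AverageValue
        averageValue: "200"
  # Resource metrics only cover cpu and memory, so GPU utilization is read from the
  # DCGM exporter metric instead (assumed to be mapped onto pods by the adapter).
  - type: Pods
    pods:
      metric:
        name: DCGM_FI_DEV_GPU_UTIL
      target:
        type: AverageValue
        averageValue: "80"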
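
To reason about how the autoscaler reacts to the running-requests metric, it helps to apply the scaling rule by hand. The sketch below follows the formula from the Kubernetes HPA documentation, desired = ceil(current replicas × current value / target value), with the default 10% tolerance; `desired_replicas` is a hypothetical helper and it ignores stabilization windows and scaling policies.

```python
import math


def desired_replicas(current_replicas: int, current_value: float, target_value: float,
                     min_replicas: int = 1, max_replicas: int = 4,
                     tolerance: float = 0.1) -> int:
    """Replica count the HPA would aim for, given an average per-pod metric."""
    ratio = current_value / target_value
    if abs(ratio - 1.0) <= tolerance:
        return current_replicas  # within tolerance: no scaling
    wanted = math.ceil(current_replicas * ratio)
    return max(min_replicas, min(max_replicas, wanted))


# Target of 200 running requests per pod, as in the manifest above.
print(desired_replicas(2, current_value=250, target_value=200))  # scale up
print(desired_replicas(1, current_value=210, target_value=200))  # inside tolerance
print(desired_replicas(4, current_value=50, target_value=200))   # scale down
```

With `--max-num-seqs=256` per pod, a target of 200 leaves some headroom before a replica saturates, which gives new pods time to load the model.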

Text Generation Inference (TGI)

TGI Architecture

┌─────────────────────────────────────────────────────────────────┐
│                      TGI Architecture                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                   Router (Rust)                          │    │
│  │  • Request intake and validation                         │    │
│  │  • Token estimation                                      │    │
│  │  • Load balancing                                        │    │
│  └─────────────────────────┬───────────────────────────────┘    │
│                            │ gRPC                                │
│  ┌─────────────────────────▼───────────────────────────────┐    │
│  │              Inference Shard (Python)                    │    │
│  │                                                          │    │
│  │  ┌────────────────────────────────────────────────┐     │    │
│  │  │              Batcher                            │     │    │
│  │  │  • Continuous Batching                         │     │    │
│  │  │  • Token Budget Management                     │     │    │
│  │  │  • Request Queuing                             │     │    │
│  │  └────────────────────────────────────────────────┘     │    │
│  │                         │                                │    │
│  │  ┌────────────────────▼─────────────────────────┐       │    │
│  │  │              Model                            │       │    │
│  │  │  • Flash Attention 2                          │       │    │
│  │  │  • Tensor Parallelism                         │       │    │
│  │  │  • Quantization (GPTQ, AWQ, EETQ)            │       │    │
│  │  │  • Paged Attention                            │       │    │
│  │  └────────────────────────────────────────────────┘     │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
│  Supported models:                                               │
│  • LLaMA, Mistral, Mixtral                                      │
│  • Falcon, GPT-NeoX, StarCoder                                  │
│  • BLOOM, T5                                                     │
│  • Gemma, Phi, Qwen                                             │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
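
Continuous batching under a token budget can be sketched as a scheduling loop that admits queued requests whenever finished ones free up budget. This is a toy model for illustration only, not TGI's batcher: `Req` and `step` are hypothetical names, and each request is charged its worst-case size (input plus `max_new_tokens`).

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List


@dataclass
class Req:
    name: str
    input_tokens: int
    max_new_tokens: int
    generated: int = 0

    @property
    def budget(self) -> int:
        # Worst-case number of tokens this request can occupy in the batch.
        return self.input_tokens + self.max_new_tokens


def step(running: List[Req], waiting: Deque[Req], max_batch_total_tokens: int) -> List[str]:
    """One scheduling iteration: retire finished requests, admit new ones, decode."""
    finished = [r.name for r in running if r.generated >= r.max_new_tokens]
    running[:] = [r for r in running if r.generated < r.max_new_tokens]
    used = sum(r.budget for r in running)
    # Admit queued requests in FIFO order while they fit in the token budget.
    while waiting and used + waiting[0].budget <= max_batch_total_tokens:
        req = waiting.popleft()
        used += req.budget
        running.append(req)
    for r in running:
        r.generated += 1  # every running request produces one token per step
    return finished


waiting = deque([Req("a", 10, 2), Req("b", 10, 4), Req("c", 10, 2)])
running: List[Req] = []
for i in range(3):
    done = step(running, waiting, max_batch_total_tokens=26)
    print(i, [r.name for r in running], done)
```

In the third iteration request `c` joins the batch while `b` is still generating, which is the difference from static batching, where `c` would wait for the whole batch to finish.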

TGI Configuration and Deployment

# tgi-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: tgi-server
  namespace: ai-inference
spec:
  replicas: 1
  selector:
    matchLabels:
      app: tgi-server
  template:
    metadata:
      labels:
        app: tgi-server
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "80"
    spec:
      containers:
      - name: tgi
        image: ghcr.io/huggingface/text-generation-inference:2.0
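        args:
        - --model-id=meta-llama/Llama-2-7b-chat-hf
        - --num-shard=1
        # Llama 2 has a 4096-token context window, so input plus output must fit in it
        - --max-input-length=3072
        - --max-total-tokens=4096
        - --max-batch-prefill-tokens=8192
        - --max-batch-total-tokens=16384
        - --max-concurrent-requests=256
        # --quantize=awq needs an AWQ-quantized checkpoint and cannot be combined with
        # --dtype, so this FP16 checkpoint is served without quantization
        - --dtype=float16
        - --trust-remote-code
        - --hostname=0.0.0.0
        - --port=80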
        ports:
        - containerPort: 80
          name: http
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "32Gi"
          requests:
            cpu: "4"
            memory: "16Gi"
        volumeMounts:
        - name: model-cache
          mountPath: /data
        - name: shm
          mountPath: /dev/shm
        env:
        - name: HUGGING_FACE_HUB_TOKEN
          valueFrom:
            secretKeyRef:
              name: hf-token
              key: token
        - name: CUDA_MEMORY_FRACTION
          value: "0.9"
        livenessProbe:
          httpGet:
            path: /health
            port: 80
          initialDelaySeconds: 120
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 80
          initialDelaySeconds: 60
          periodSeconds: 5
      volumes:
      - name: model-cache
        persistentVolumeClaim:
          claimName: tgi-model-cache
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: 16Gi
      nodeSelector:
        nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"

---
# TGI Service configuration
apiVersion: v1
kind: Service
metadata:
  name: tgi-server
  namespace: ai-inference
spec:
  selector:
    app: tgi-server
  ports:
  - port: 80
    targetPort: 80
    name: http
  type: ClusterIP
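
The token limits passed to the TGI launcher constrain one another, and a quick consistency check before deploying can catch a bad combination. The sketch below encodes the relationships as understood from the launcher's validation; `check_tgi_limits` and `worst_case_concurrency` are hypothetical helpers, and the numbers are illustrative rather than taken from the manifest.

```python
from typing import List


def check_tgi_limits(max_input_length: int, max_total_tokens: int,
                     max_batch_prefill_tokens: int,
                     max_batch_total_tokens: int) -> List[str]:
    """Return a list of problems found in a set of TGI token limits."""
    problems = []
    if max_input_length >= max_total_tokens:
        problems.append("max-input-length must be smaller than max-total-tokens")
    if max_batch_prefill_tokens < max_input_length:
        problems.append("max-batch-prefill-tokens must cover one full-length prompt")
    if max_batch_total_tokens < max_total_tokens:
        problems.append("max-batch-total-tokens must cover one full-length request")
    return problems


def worst_case_concurrency(max_total_tokens: int, max_batch_total_tokens: int) -> int:
    """Requests that fit in one batch if each uses its full token allowance."""
    return max_batch_total_tokens // max_total_tokens


ok = check_tgi_limits(1024, 2048, 4096, 8192)
bad = check_tgi_limits(2048, 2048, 1024, 1024)
print(ok, worst_case_concurrency(2048, 8192))
print(bad)
```

The worst-case concurrency is a lower bound: requests that stop early or have short prompts leave room for more sequences in the same batch.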

TGI Client Usage

# tgi_client.py
import asyncio
from typing import AsyncIterator, Optional, List, Dict, Any
import aiohttp
import json

class TGIClient:
    """Asynchronous TGI client."""

    def __init__(self, base_url: str, timeout: float = 60.0):
        self.base_url = base_url.rstrip("/")
        self.timeout = aiohttp.ClientTimeout(total=timeout)

    def _build_payload(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        stop_sequences: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Build the request body shared by both endpoints."""
        parameters: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "do_sample": temperature > 0,
        }
        # top_p is sent only when it restricts sampling (TGI validation may reject 1.0)
        if top_p < 1.0:
            parameters["top_p"] = top_p
        if top_k is not None:
            parameters["top_k"] = top_k
        if stop_sequences:
            parameters["stop"] = stop_sequences
        return {"inputs": prompt, "parameters": parameters}

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate text and return the full response (non-streaming)."""
        payload = self._build_payload(prompt, **kwargs)
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.base_url}/generate",
                json=payload,
            ) as response:
                response.raise_for_status()
                return await response.json()

    async def _stream_generate(
        self,
        session: aiohttp.ClientSession,
        payload: Dict[str, Any],
    ) -> AsyncIterator[Dict[str, Any]]:
        """Streaming generation: yield each server-sent event as a dict."""
        async with session.post(
            f"{self.base_url}/generate_stream",
            json=payload,
        ) as response:
            response.raise_for_status()

            async for line in response.content:
                line = line.decode("utf-8").strip()
                if line.startswith("data:"):
                    data = json.loads(line[5:])
                    yield data

    async def generate_stream(
        self,
        prompt: str,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Streaming generation: yield the text of each token."""
        payload = self._build_payload(prompt, **kwargs)
        # The session must stay open while the stream is consumed, and an async
        # generator is iterated directly rather than awaited.
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async for chunk in self._stream_generate(session, payload):
                if "token" in chunk:
                    yield chunk["token"]["text"]

    async def health(self) -> Dict[str, Any]:
        """Health check."""
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.get(f"{self.base_url}/health") as response:
                # /health signals status through the HTTP code and may return an empty body
                return {"healthy": response.status == 200}

    async def info(self) -> Dict[str, Any]:
        """Fetch model information."""
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.get(f"{self.base_url}/info") as response:
                return await response.json()


# Usage example
async def main():
    client = TGIClient("http://tgi-server:80")

    # Check health
    health = await client.health()
    print(f"Health: {health}")

    # Fetch model information
    info = await client.info()
    print(f"Model: {info['model_id']}")

    # Non-streaming generation
    response = await client.generate(
        prompt="What is the capital of France?",
        max_new_tokens=50,
        temperature=0.7,
    )
    print(f"Response: {response['generated_text']}")

    # Streaming generation
    print("Streaming response: ", end="")
    async for token in client.generate_stream(
        prompt="Tell me a short story about a robot.",
        max_new_tokens=200,
        temperature=0.8,
    ):
        print(token, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())
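
The streaming path depends on parsing server-sent event lines, which can be exercised offline without a running server. The sketch below is a standalone illustration: `iter_token_texts` is a hypothetical helper, and the sample payload shape is inferred from the fields the client reads (`token` and `text`), not captured from a real server.

```python
import json
from typing import Iterable, Iterator


def iter_token_texts(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield the token text carried by each `data:` line of an event stream."""
    for raw in lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue  # blank separators and other fields are ignored
        event = json.loads(line[len("data:"):])
        if "token" in event:
            yield event["token"]["text"]


# Hypothetical captured stream, shaped like the chunks the client above reads.
sample = [
    b'data:{"token": {"id": 1, "text": "Hello"}}\n',
    b"\n",
    b'data:{"token": {"id": 2, "text": " world"}}\n',
]
text = "".join(iter_token_texts(sample))
print(text)
```

Keeping the parsing in a pure function like this makes it easy to unit-test separately from the network code.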

TensorRT-LLM

TensorRT-LLM Architecture

┌─────────────────────────────────────────────────────────────────┐
│                   TensorRT-LLM Architecture                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Build phase (offline):                                          │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    Model Builder                         │    │
│  │  ┌─────────┐    ┌──────────┐    ┌──────────────┐       │    │
│  │  │ HF      │───▶│ TRT-LLM  │───▶│ TensorRT     │       │    │
│  │  │ Model   │    │ Network  │    │ Engine       │       │    │
│  │  └─────────┘    └──────────┘    └──────────────┘       │    │
│  │                                                          │    │
│  │  Optimizations:                                          │    │
│  │  • Graph optimization (layer fusion, constant folding)   │    │
│  │  • Quantization calibration (INT8/FP8)                   │    │
│  │  • Kernel selection (auto-tuning)                        │    │
│  │  • Memory optimization                                   │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
│  Runtime phase (online):                                         │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                   Triton Backend                         │    │
│  │  ┌─────────────────────────────────────────────────┐    │    │
│  │  │              Inflight Batching                   │    │    │
│  │  │  • Continuous Batching                          │    │    │
│  │  │  • Chunked Context                              │    │    │
│  │  │  • KV Cache Management                          │    │    │
│  │  └─────────────────────────────────────────────────┘    │    │
│  │                         │                                │    │
│  │  ┌─────────────────────▼─────────────────────────┐      │    │
│  │  │           TensorRT Runtime                     │      │    │
│  │  │  • Optimized CUDA Kernels                      │      │    │
│  │  │  • Multi-GPU Support (TP/PP)                   │      │    │
│  │  │  • FP8/INT8/INT4 Inference                     │      │    │
│  │  └─────────────────────────────────────────────────┘    │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
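
The INT8 step in the build phase comes down to choosing a scale and rounding values to 8-bit integers. The sketch below shows the simplest variant, symmetric abs-max quantization of a weight vector in pure Python. It is an illustration of the arithmetic only: TensorRT-LLM's calibration derives scales from sample data and works per channel or per group, which this sketch does not attempt.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric INT8 quantization with a single abs-max scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map INT8 codes back to approximate floating-point values."""
    return [v * scale for v in q]


weights = [0.50, -1.27, 0.03, 0.90]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, round(scale, 4), max_error)
```

The rounding error is bounded by half the scale, so a single outlier weight (which inflates the scale) degrades precision for every other value; this is why finer-grained scales are used in practice.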

TensorRT-LLM Model Build

# build_trt_llm.py
# The TensorRT-LLM Python API changes between releases; names such as max_output_len
# and quant_mode follow an early release and may need adjusting for the installed one.
import tensorrt_llm
from tensorrt_llm import BuildConfig, Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.quantization import QuantMode
from typing import Optional

def build_trt_llm_engine(
    model_dir: str,
    output_dir: str,
    dtype: str = "float16",
    tp_size: int = 1,
    pp_size: int = 1,
    max_batch_size: int = 256,
    max_input_len: int = 2048,
    max_output_len: int = 2048,
    quantization: Optional[str] = None,  # int8, int4_awq, fp8
):
    """Build a TensorRT-LLM engine."""

    # Select the quantization mode
    quant_mode = QuantMode(0)
    if quantization == "int8":
        quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
    elif quantization == "int4_awq":
        # AWQ quantizes weights per group
        quant_mode = QuantMode.use_weight_only(use_int4_weights=True, per_group=True)
    elif quantization == "fp8":
        quant_mode = QuantMode.from_description(use_fp8_qdq=True)

    # Mapping for tensor/pipeline parallelism; under mpirun each rank builds its own shard
    mapping = Mapping(
        world_size=tp_size * pp_size,
        rank=tensorrt_llm.mpi_rank(),
        tp_size=tp_size,
        pp_size=pp_size,
    )

    # Build configuration
    build_config = BuildConfig(
        max_batch_size=max_batch_size,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        max_beam_width=1,
        max_num_tokens=max_batch_size * max_input_len,
        strongly_typed=True,
        builder_opt=5,  # optimization level
    )

    # Convert from the HuggingFace format (dtype is passed as a string)
    model = LLaMAForCausalLM.from_hugging_face(
        model_dir,
        dtype=dtype,
        mapping=mapping,
        quant_mode=quant_mode,
    )

    # Build the engine
    engine = tensorrt_llm.build(
        model,
        build_config,
    )

    # Save the engine
    engine.save(output_dir)

    print(f"Engine saved to {output_dir}")
    return output_dir


# Multi-GPU build script
def build_multi_gpu_engine():
    """Build a multi-GPU engine."""
    import subprocess

    # Launch one build process per GPU with mpirun
    cmd = [
        "mpirun",
        "-n", "4",  # 4 GPUs
        "--allow-run-as-root",
        "python", "build_trt_llm.py",
        "--model_dir", "/models/llama-70b",
        "--output_dir", "/engines/llama-70b-tp4",
        "--tp_size", "4",
        "--dtype", "float16",
        "--max_batch_size", "64",
    ]

    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--dtype", default="float16")
    parser.add_argument("--max_batch_size", type=int, default=256)
    parser.add_argument("--quantization", default=None)

    args = parser.parse_args()

    build_trt_llm_engine(
        model_dir=args.model_dir,
        output_dir=args.output_dir,
        dtype=args.dtype,
        tp_size=args.tp_size,
        max_batch_size=args.max_batch_size,
        quantization=args.quantization,
    )
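
The relationship between `tp_size`, `pp_size`, and the world size can be made concrete by listing which parallel group each rank belongs to. The sketch below assumes the common layout in which tensor-parallel ranks are contiguous within a pipeline stage; `rank_layout` is a hypothetical helper and the layout is an assumption, not something the build script above states.

```python
from typing import Dict, List


def rank_layout(tp_size: int, pp_size: int) -> List[Dict[str, int]]:
    """List rank, pipeline stage, and tensor-parallel index for every process."""
    world_size = tp_size * pp_size
    return [
        {"rank": rank, "pp_rank": rank // tp_size, "tp_rank": rank % tp_size}
        for rank in range(world_size)
    ]


layout = rank_layout(tp_size=2, pp_size=2)
for entry in layout:
    print(entry)
```

With `tp_size=4` and `pp_size=1`, as in the `mpirun -n 4` command above, all four ranks fall in a single pipeline stage and differ only in their tensor-parallel index.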

Triton + TensorRT-LLM Deployment

# triton_trt_llm_config.py
# Generates the Triton model configuration

import json
import os

def generate_triton_config(
    model_name: str,
    engine_dir: str,
    output_dir: str,
    max_batch_size: int = 256,
    tp_size: int = 1,
    pp_size: int = 1,
):
    """Generate the Triton configuration file."""

    config = {
        "name": model_name,
        "backend": "tensorrtllm",
        "max_batch_size": max_batch_size,

        "model_transaction_policy": {
            "decoupled": True  # streaming output
        },

        "input": [
            {
                "name": "input_ids",
                "data_type": "TYPE_INT32",
                "dims": [-1]
            },
            {
                "name": "input_lengths",
                "data_type": "TYPE_INT32",
                "dims": [1]
            },
            {
                "name": "request_output_len",
                "data_type": "TYPE_INT32",
                "dims": [1]
            },
            {
                "name": "end_id",
                "data_type": "TYPE_INT32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "pad_id",
                "data_type": "TYPE_INT32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "beam_width",
                "data_type": "TYPE_INT32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "temperature",
                "data_type": "TYPE_FP32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "top_k",
                "data_type": "TYPE_INT32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "top_p",
                "data_type": "TYPE_FP32",
                "dims": [1],
                "optional": True
            },
            {
                "name": "streaming",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True
            }
        ],

        "output": [
            {
                "name": "output_ids",
                "data_type": "TYPE_INT32",
                "dims": [-1, -1]
            },
            {
                "name": "sequence_length",
                "data_type": "TYPE_INT32",
                "dims": [-1]
            }
        ],

        "instance_group": [
            {
                "count": 1,
                "kind": "KIND_GPU",
                "gpus": list(range(tp_size * pp_size))
            }
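        ],

        "parameters": {
            # gpt_model_type selects the batching strategy
            # ("inflight_fused_batching" or "V1"), not the model architecture
            "gpt_model_type": {"string_value": "inflight_fused_batching"},
            "gpt_model_path": {"string_value": engine_dir},
            "batch_scheduler_policy": {"string_value": "guaranteed_no_evict"},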
            "decoupled_mode": {"string_value": "true"},
            "max_tokens_in_paged_kv_cache": {"string_value": ""},
            "kv_cache_free_gpu_mem_fraction": {"string_value": "0.9"},
            "enable_chunked_context": {"string_value": "true"},
        }
    }

    # Create the model version directory
    model_dir = os.path.join(output_dir, model_name, "1")
    os.makedirs(model_dir, exist_ok=True)

    # Write the configuration
    config_path = os.path.join(output_dir, model_name, "config.pbtxt")
    with open(config_path, "w") as f:
        f.write(dict_to_pbtxt(config))

    print(f"Triton config saved to {config_path}")
    return config_path


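# Fields whose values are protobuf enums; text format requires them unquoted
ENUM_FIELDS = {"data_type", "kind"}


def dict_to_pbtxt(d: dict, indent: int = 0) -> str:
    """Convert a dict to protobuf text format (pbtxt)."""
    lines = []
    prefix = "  " * indent

    for key, value in d.items():
        if key == "parameters" and isinstance(value, dict):
            # `parameters` is a map field: emit one key/value entry per item
            for name, param in value.items():
                lines.append(f"{prefix}parameters {{")
                lines.append(f'{prefix}  key: "{name}"')
                lines.append(f"{prefix}  value {{")
                lines.append(dict_to_pbtxt(param, indent + 2))
                lines.append(f"{prefix}  }}")
                lines.append(f"{prefix}}}")
        elif isinstance(value, dict):
            lines.append(f"{prefix}{key} {{")
            lines.append(dict_to_pbtxt(value, indent + 1))
            lines.append(f"{prefix}}}")
        elif isinstance(value, list):
            # Repeated fields are written as one entry per element
            for item in value:
                if isinstance(item, dict):
                    lines.append(f"{prefix}{key} {{")
                    lines.append(dict_to_pbtxt(item, indent + 1))
                    lines.append(f"{prefix}}}")
                else:
                    lines.append(f'{prefix}{key}: {json.dumps(item)}')
        elif isinstance(value, bool):
            lines.append(f'{prefix}{key}: {str(value).lower()}')
        elif isinstance(value, str) and key not in ENUM_FIELDS:
            lines.append(f'{prefix}{key}: "{value}"')
        else:
            # Numbers and enum names (TYPE_INT32, KIND_GPU) are written bare
            lines.append(f'{prefix}{key}: {value}')

    return "\n".join(lines)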
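
In Triton's model configuration, `parameters` is a protobuf map from a name to a message with a `string_value` field, and protobuf text format writes each map entry as a block with explicit `key` and `value` fields. The sketch below shows only that shape. `render_parameters` is a hypothetical helper written for illustration, and the engine path is a made-up example value.

```python
def render_parameters(params: dict[str, str], indent: str = "  ") -> str:
    """Render a {name: value} dict as repeated `parameters` map entries."""
    blocks = []
    for name, value in params.items():
        blocks.append(
            "parameters {\n"
            f'{indent}key: "{name}"\n'
            f"{indent}value {{\n"
            f'{indent}{indent}string_value: "{value}"\n'
            f"{indent}}}\n"
            "}"
        )
    return "\n".join(blocks)


if __name__ == "__main__":
    # Hypothetical engine path, used only to show the output shape
    print(render_parameters({"gpt_model_path": "/engines/llama-70b-tp4"}))
```

A generic dict walker that writes `name { string_value: ... }` directly under `parameters` does not produce this key/value shape, which is why map fields need their own case.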

NVIDIA Triton Inference Server

Triton Multi-Model Orchestration

# triton-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: triton-server
  namespace: ai-inference
spec:
  replicas: 1
  selector:
    matchLabels:
      app: triton-server
  template:
    metadata:
      labels:
        app: triton-server
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "8002"
    spec:
      containers:
      - name: triton
        image: nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3
        args:
        - tritonserver
        - --model-repository=/models
        - --model-control-mode=explicit
        - --load-model=*
        - --http-port=8000
        - --grpc-port=8001
        - --metrics-port=8002
        - --log-verbose=1
        ports:
        - containerPort: 8000
          name: http
        - containerPort: 8001
          name: grpc
        - containerPort: 8002
          name: metrics
        resources:
          limits:
            nvidia.com/gpu: 4
            memory: "128Gi"
          requests:
            cpu: "16"
            memory: "64Gi"
        volumeMounts:
        - name: model-repository
          mountPath: /models
        - name: shm
          mountPath: /dev/shm
        env:
        - name: CUDA_VISIBLE_DEVICES
          value: "0,1,2,3"
        livenessProbe:
          httpGet:
            path: /v2/health/live
            port: 8000
          initialDelaySeconds: 120
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /v2/health/ready
            port: 8000
          initialDelaySeconds: 60
          periodSeconds: 5
      volumes:
      - name: model-repository
        persistentVolumeClaim:
          claimName: triton-models
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: 32Gi

---
# Triton Service
apiVersion: v1
kind: Service
metadata:
  name: triton-server
  namespace: ai-inference
spec:
  selector:
    app: triton-server
  ports:
  - port: 8000
    targetPort: 8000
    name: http
  - port: 8001
    targetPort: 8001
    name: grpc
  - port: 8002
    targetPort: 8002
    name: metrics
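
The liveness and readiness probes in the manifest above poll `/v2/health/live` and `/v2/health/ready` on the HTTP port. The same endpoints can be queried by hand when debugging a rollout. This is a minimal sketch using only the standard library; `triton_health` and `http_status` are hypothetical helper names, and the base URL assumes the `triton-server` Service defined above is reachable from where the script runs.

```python
import urllib.error
import urllib.request
from typing import Callable


def http_status(url: str, timeout: float = 2.0) -> int:
    """Return the HTTP status of a GET request, or 0 if the server is unreachable."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status
    except urllib.error.HTTPError as exc:
        # The server answered, but with an error status
        return exc.code
    except (urllib.error.URLError, OSError):
        return 0


def triton_health(base_url: str,
                  fetch: Callable[[str], int] = http_status) -> dict[str, bool]:
    """Query the two health endpoints that the probes in the manifest use."""
    base = base_url.rstrip("/")
    return {
        "live": fetch(f"{base}/v2/health/live") == 200,
        "ready": fetch(f"{base}/v2/health/ready") == 200,
    }


if __name__ == "__main__":
    print(triton_health("http://triton-server.ai-inference:8000"))
```

The `fetch` parameter exists so the function can be exercised without a running server, for example by passing a stub that returns fixed status codes.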

Triton Model Integration

// triton_client.go
package inference

import (
    "context"
    "fmt"

    pb "github.com/triton-inference-server/client/src/grpc_generated/go/grpc-client"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

// TritonClient is a gRPC client for Triton.
type TritonClient struct {
    conn   *grpc.ClientConn
    client pb.GRPCInferenceServiceClient
}

// NewTritonClient creates a client. The connection is unencrypted, which is
// only appropriate inside a trusted network.
func NewTritonClient(address string) (*TritonClient, error) {
    conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        return nil, fmt.Errorf("connect to triton: %w", err)
    }

    return &TritonClient{
        conn:   conn,
        client: pb.NewGRPCInferenceServiceClient(conn),
    }, nil
}

// Close closes the connection.
func (c *TritonClient) Close() error {
    return c.conn.Close()
}

// ServerLive reports whether the server is live.
func (c *TritonClient) ServerLive(ctx context.Context) (bool, error) {
    resp, err := c.client.ServerLive(ctx, &pb.ServerLiveRequest{})
    if err != nil {
        return false, err
    }
    return resp.Live, nil
}

// ServerReady reports whether the server is ready.
func (c *TritonClient) ServerReady(ctx context.Context) (bool, error) {
    resp, err := c.client.ServerReady(ctx, &pb.ServerReadyRequest{})
    if err != nil {
        return false, err
    }
    return resp.Ready, nil
}

// ModelReady reports whether the model is ready.
func (c *TritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) {
    resp, err := c.client.ModelReady(ctx, &pb.ModelReadyRequest{
        Name: modelName,
    })
    if err != nil {
        return false, err
    }
    return resp.Ready, nil
}

// Infer sends an inference request.
func (c *TritonClient) Infer(
    ctx context.Context,
    modelName string,
    inputs []*pb.ModelInferRequest_InferInputTensor,
    outputs []*pb.ModelInferRequest_InferRequestedOutputTensor,
) (*pb.ModelInferResponse, error) {

    request := &pb.ModelInferRequest{
        ModelName: modelName,
        Inputs:    inputs,
        Outputs:   outputs,
    }

    return c.client.ModelInfer(ctx, request)
}

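// InferLLM runs LLM inference over the unary ModelInfer RPC.
// Note: models configured with decoupled: true are served through
// ModelStreamInfer; use StreamInfer for those.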
func (c *TritonClient) InferLLM(
    ctx context.Context,
    modelName string,
    inputIDs []int32,
    maxNewTokens int32,
    temperature float32,
    topP float32,
    stream bool,
) ([]int32, error) {

    // Build the inputs. The model config sets max_batch_size > 0, so every
    // tensor carries a leading batch dimension of 1.
    inputs := []*pb.ModelInferRequest_InferInputTensor{
        {
            Name:     "input_ids",
            Datatype: "INT32",
            Shape:    []int64{1, int64(len(inputIDs))},
            Contents: &pb.InferTensorContents{
                IntContents: inputIDs,
            },
        },
        {
            Name:     "input_lengths",
            Datatype: "INT32",
            Shape:    []int64{1, 1},
            Contents: &pb.InferTensorContents{
                IntContents: []int32{int32(len(inputIDs))},
            },
        },
        {
            Name:     "request_output_len",
            Datatype: "INT32",
            Shape:    []int64{1, 1},
            Contents: &pb.InferTensorContents{
                IntContents: []int32{maxNewTokens},
            },
        },
        {
            Name:     "temperature",
            Datatype: "FP32",
            Shape:    []int64{1, 1},
            Contents: &pb.InferTensorContents{
                Fp32Contents: []float32{temperature},
            },
        },
        {
            Name:     "top_p",
            Datatype: "FP32",
            Shape:    []int64{1, 1},
            Contents: &pb.InferTensorContents{
                Fp32Contents: []float32{topP},
            },
        },
    }

    // Build the requested outputs
    outputs := []*pb.ModelInferRequest_InferRequestedOutputTensor{
        {Name: "output_ids"},
        {Name: "sequence_length"},
    }

    // Run inference
    resp, err := c.Infer(ctx, modelName, inputs, outputs)
    if err != nil {
        return nil, err
    }

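    // Parse the output. Triton usually returns tensor data in
    // RawOutputContents (little-endian bytes) rather than in Contents.
    for i, output := range resp.Outputs {
        if output.Name != "output_ids" {
            continue
        }
        if i < len(resp.RawOutputContents) {
            raw := resp.RawOutputContents[i]
            ids := make([]int32, len(raw)/4)
            for j := range ids {
                ids[j] = int32(uint32(raw[4*j]) | uint32(raw[4*j+1])<<8 |
                    uint32(raw[4*j+2])<<16 | uint32(raw[4*j+3])<<24)
            }
            return ids, nil
        }
        // Nil-safe getters avoid a panic when Contents is absent
        return output.GetContents().GetIntContents(), nil
    }

    return nil, fmt.Errorf("output_ids not found in response")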
}

// StreamInfer runs streaming inference.
func (c *TritonClient) StreamInfer(
    ctx context.Context,
    modelName string,
    inputs []*pb.ModelInferRequest_InferInputTensor,
) (<-chan *pb.ModelStreamInferResponse, error) {

    stream, err := c.client.ModelStreamInfer(ctx)
    if err != nil {
        return nil, err
    }

    // Send the request
    request := &pb.ModelInferRequest{
        ModelName: modelName,
        Inputs:    inputs,
    }

    if err := stream.Send(request); err != nil {
        return nil, err
    }

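    // Receive responses until the stream ends or ctx is cancelled
    resultCh := make(chan *pb.ModelStreamInferResponse, 100)

    go func() {
        defer close(resultCh)
        for {
            resp, err := stream.Recv()
            if err != nil {
                // End of stream or transport error: closing the channel
                // tells the consumer there is nothing more to read
                return
            }
            select {
            case resultCh <- resp:
            case <-ctx.Done():
                // The consumer gave up; exit instead of blocking forever
                return
            }
        }
    }()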

    return resultCh, nil
}

// GetModelMetadata fetches the model metadata.
func (c *TritonClient) GetModelMetadata(ctx context.Context, modelName string) (*pb.ModelMetadataResponse, error) {
    return c.client.ModelMetadata(ctx, &pb.ModelMetadataRequest{
        Name: modelName,
    })
}

// LoadModel loads a model.
func (c *TritonClient) LoadModel(ctx context.Context, modelName string) error {
    _, err := c.client.RepositoryModelLoad(ctx, &pb.RepositoryModelLoadRequest{
        ModelName: modelName,
    })
    return err
}

// UnloadModel unloads a model.
func (c *TritonClient) UnloadModel(ctx context.Context, modelName string) error {
    _, err := c.client.RepositoryModelUnload(ctx, &pb.RepositoryModelUnloadRequest{
        ModelName: modelName,
    })
    return err
}
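
Triton's gRPC responses commonly carry tensor data as raw bytes (`raw_output_contents`) rather than in the typed `contents` fields, so a client has to decode them itself. A minimal sketch of that decoding for a 2-D INT32 tensor such as `output_ids`, assuming little-endian byte order and row-major layout. `decode_int32_tensor` is a hypothetical helper, and the sample bytes are synthetic rather than real model output.

```python
import struct


def decode_int32_tensor(raw: bytes, shape: list[int]) -> list[list[int]]:
    """Decode little-endian INT32 bytes into rows of a 2-D tensor."""
    rows, cols = shape
    expected = rows * cols * 4  # 4 bytes per INT32 element
    if len(raw) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(raw)}")
    flat = struct.unpack(f"<{rows * cols}i", raw)
    return [list(flat[r * cols:(r + 1) * cols]) for r in range(rows)]


if __name__ == "__main__":
    # Synthetic payload: one beam holding four token IDs
    raw = struct.pack("<4i", 1, 15043, 3186, 2)
    print(decode_int32_tensor(raw, [1, 4]))
```

The length check catches a mismatch between the reported shape and the payload early, instead of returning silently truncated token IDs.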

Unified Serving Gateway

Model Router

// model_router.go
package gateway

import (
    "context"
    "fmt"
    "net/http"
    "sync"
    "time"
)

// ModelRouter routes inference requests to model backends.
type ModelRouter struct {
    // Backend configuration
    backends map[string]*Backend

    // Routing rules
    routes map[string]*Route

    // Load balancer
    loadBalancer LoadBalancer

    // Circuit breaker
    circuitBreaker *CircuitBreaker

    // Rate limiter
    rateLimiter *RateLimiter

    mu sync.RWMutex
}

// Backend is a backend inference service.
type Backend struct {
    Name      string
    Type      string // vllm, tgi, triton
    Address   string
    Weight    int
    Healthy   bool
    LastCheck time.Time

    // Connection pool
    client interface{}
}

// Route is a routing rule.
type Route struct {
    ModelName     string
    Backends      []string
    Strategy      string // round_robin, least_conn, weighted
    Timeout       time.Duration
    RetryCount    int
    FallbackModel string
}

// NewModelRouter creates a router and starts its health-check loop.
func NewModelRouter() *ModelRouter {
    router := &ModelRouter{
        backends:       make(map[string]*Backend),
        routes:         make(map[string]*Route),
        loadBalancer:   NewWeightedRoundRobin(),
        circuitBreaker: NewCircuitBreaker(),
        rateLimiter:    NewRateLimiter(1000), // 1000 QPS
    }

    // Start health checks
    go router.healthCheckLoop()

    return router
}

// RegisterBackend registers a backend.
func (r *ModelRouter) RegisterBackend(backend *Backend) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.backends[backend.Name] = backend
}

// RegisterRoute registers a route.
func (r *ModelRouter) RegisterRoute(route *Route) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.routes[route.ModelName] = route
}

// Route dispatches a request to a backend.
func (r *ModelRouter) Route(ctx context.Context, req *InferenceRequest) (*InferenceResponse, error) {
    // Rate-limit check
    if !r.rateLimiter.Allow(req.ModelName) {
        return nil, fmt.Errorf("rate limit exceeded")
    }

    // Look up the routing rule
    r.mu.RLock()
    route, ok := r.routes[req.ModelName]
    r.mu.RUnlock()

    if !ok {
        return nil, fmt.Errorf("model not found: %s", req.ModelName)
    }

    // Select a backend
    backend, err := r.selectBackend(route)
    if err != nil {
        return nil, err
    }

    // Circuit-breaker check
    if !r.circuitBreaker.Allow(backend.Name) {
        // Try the fallback model on a copy of the request, so the caller's
        // request is not mutated; skip a fallback that points at itself
        if route.FallbackModel != "" && route.FallbackModel != req.ModelName {
            fallback := *req
            fallback.ModelName = route.FallbackModel
            return r.Route(ctx, &fallback)
        }
        return nil, fmt.Errorf("circuit breaker open for backend: %s", backend.Name)
    }

    // Apply the route timeout
    ctx, cancel := context.WithTimeout(ctx, route.Timeout)
    defer cancel()

    // Execute the request, retrying on failure
    var resp *InferenceResponse
    var lastErr error

    for i := 0; i <= route.RetryCount; i++ {
        resp, lastErr = r.executeRequest(ctx, backend, req)
        if lastErr == nil {
            r.circuitBreaker.RecordSuccess(backend.Name)
            return resp, nil
        }

        r.circuitBreaker.RecordFailure(backend.Name)

        // Pick a backend again for the retry
        backend, err = r.selectBackend(route)
        if err != nil {
            break
        }
    }

    return nil, lastErr
}

// selectBackend picks a healthy backend for the route.
func (r *ModelRouter) selectBackend(route *Route) (*Backend, error) {
    r.mu.RLock()
    defer r.mu.RUnlock()

    // Keep only healthy backends
    var healthyBackends []*Backend
    for _, name := range route.Backends {
        if backend, ok := r.backends[name]; ok && backend.Healthy {
            healthyBackends = append(healthyBackends, backend)
        }
    }

    if len(healthyBackends) == 0 {
        return nil, fmt.Errorf("no healthy backends available")
    }

    // Load balancing
    return r.loadBalancer.Select(healthyBackends), nil
}

// executeRequest dispatches the request by backend type.
func (r *ModelRouter) executeRequest(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
    switch backend.Type {
    case "vllm":
        return r.callVLLM(ctx, backend, req)
    case "tgi":
        return r.callTGI(ctx, backend, req)
    case "triton":
        return r.callTriton(ctx, backend, req)
    default:
        return nil, fmt.Errorf("unknown backend type: %s", backend.Type)
    }
}

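func (r *ModelRouter) callVLLM(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
    // Placeholder: call vLLM's OpenAI-compatible API here.
    // Returning an error keeps the stub from being counted as a success.
    return nil, fmt.Errorf("callVLLM: not implemented")
}

func (r *ModelRouter) callTGI(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
    // Placeholder: call the TGI API here
    return nil, fmt.Errorf("callTGI: not implemented")
}

func (r *ModelRouter) callTriton(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
    // Placeholder: call Triton over gRPC here
    return nil, fmt.Errorf("callTriton: not implemented")
}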

// healthCheckLoop runs periodic health checks.
func (r *ModelRouter) healthCheckLoop() {
    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()

    for range ticker.C {
        r.checkAllBackends()
    }
}

func (r *ModelRouter) checkAllBackends() {
    r.mu.RLock()
    backends := make([]*Backend, 0, len(r.backends))
    for _, b := range r.backends {
        backends = append(backends, b)
    }
    r.mu.RUnlock()

    for _, backend := range backends {
        healthy := r.checkBackendHealth(backend)

        r.mu.Lock()
        backend.Healthy = healthy
        backend.LastCheck = time.Now()
        r.mu.Unlock()
    }
}

func (r *ModelRouter) checkBackendHealth(backend *Backend) bool {
    // Run the health check with a 5-second timeout
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    switch backend.Type {
    case "vllm", "tgi":
        // HTTP health check
        req, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.Address+"/health", nil)
        if err != nil {
            return false
        }
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            return false
        }
        defer resp.Body.Close()
        return resp.StatusCode == http.StatusOK

    case "triton":
        // gRPC health check: placeholder that always reports healthy.
        // A real check would call ServerReady on the Triton client.
        return true
    }

    return false
}

// InferenceRequest is an inference request.
type InferenceRequest struct {
    ModelName   string
    Prompt      string
    MaxTokens   int
    Temperature float32
    TopP        float32
    Stream      bool
}

// InferenceResponse is an inference response.
type InferenceResponse struct {
    Text         string
    TokensUsed   int
    FinishReason string
    Latency      time.Duration
}

// LoadBalancer is the load-balancer interface.
type LoadBalancer interface {
    Select(backends []*Backend) *Backend
}

// WeightedRoundRobin implements weighted round-robin selection.
type WeightedRoundRobin struct {
    current int
    mu      sync.Mutex
}

func NewWeightedRoundRobin() *WeightedRoundRobin {
    return &WeightedRoundRobin{}
}

func (w *WeightedRoundRobin) Select(backends []*Backend) *Backend {
    w.mu.Lock()
    defer w.mu.Unlock()

    // Compute the total weight
    totalWeight := 0
    for _, b := range backends {
        totalWeight += b.Weight
    }

    if totalWeight == 0 {
        w.current = (w.current + 1) % len(backends)
        return backends[w.current]
    }

    // Weighted selection
    w.current = (w.current + 1) % totalWeight
    cumWeight := 0
    for _, b := range backends {
        cumWeight += b.Weight
        if w.current < cumWeight {
            return b
        }
    }

    return backends[0]
}

// CircuitBreaker tracks a circuit state per backend.
type CircuitBreaker struct {
    states map[string]*CircuitState
    mu     sync.RWMutex
}

type CircuitState struct {
    Failures    int
    Successes   int
    LastFailure time.Time
    State       string // closed, open, half-open
}

func NewCircuitBreaker() *CircuitBreaker {
    return &CircuitBreaker{
        states: make(map[string]*CircuitState),
    }
}

func (cb *CircuitBreaker) Allow(backend string) bool {
    // Take the write lock up front: the state is both read and modified below
    cb.mu.Lock()
    defer cb.mu.Unlock()

    state, ok := cb.states[backend]
    if !ok {
        return true
    }

    switch state.State {
    case "open":
        // After a 30-second cooldown, move to half-open and let a probe through
        if time.Since(state.LastFailure) > 30*time.Second {
            state.State = "half-open"
            state.Successes = 0
            return true
        }
        return false
    default:
        return true
    }
}

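func (cb *CircuitBreaker) RecordSuccess(backend string) {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    state, ok := cb.states[backend]
    if !ok {
        cb.states[backend] = &CircuitState{State: "closed"}
        return
    }

    switch state.State {
    case "half-open":
        // Require three successes in half-open before closing the circuit
        state.Successes++
        if state.Successes >= 3 {
            state.State = "closed"
            state.Failures = 0
            state.Successes = 0
        }
    case "closed":
        // Count consecutive failures only: any success resets the counter
        state.Failures = 0
    }
}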

func (cb *CircuitBreaker) RecordFailure(backend string) {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    state, ok := cb.states[backend]
    if !ok {
        cb.states[backend] = &CircuitState{
            Failures:    1,
            LastFailure: time.Now(),
            State:       "closed",
        }
        return
    }

    state.Failures++
    state.LastFailure = time.Now()

    if state.Failures >= 5 {
        state.State = "open"
    }
}

// RateLimiter is a per-key token-bucket rate limiter.
type RateLimiter struct {
    limits     map[string]*TokenBucket
    defaultQPS float64
    mu         sync.RWMutex
}

type TokenBucket struct {
    tokens     float64
    maxTokens  float64
    refillRate float64 // tokens per second
    lastRefill time.Time
}

func NewRateLimiter(defaultQPS int) *RateLimiter {
    return &RateLimiter{
        limits:     make(map[string]*TokenBucket),
        defaultQPS: float64(defaultQPS),
    }
}

func (rl *RateLimiter) Allow(key string) bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    bucket, ok := rl.limits[key]
    if !ok {
        // New keys start with a full bucket sized to one second of traffic
        bucket = &TokenBucket{
            tokens:     rl.defaultQPS,
            maxTokens:  rl.defaultQPS,
            refillRate: rl.defaultQPS,
            lastRefill: time.Now(),
        }
        rl.limits[key] = bucket
    }

    // Refill tokens for the elapsed time
    now := time.Now()
    elapsed := now.Sub(bucket.lastRefill).Seconds()
    bucket.tokens += elapsed * bucket.refillRate
    if bucket.tokens > bucket.maxTokens {
        bucket.tokens = bucket.maxTokens
    }
    bucket.lastRefill = now

    // Consume one token if available
    if bucket.tokens >= 1 {
        bucket.tokens--
        return true
    }

    return false
}
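
The circuit breaker in `model_router.go` is a three-state machine: closed, open, and half-open. The sketch below restates those transitions in a compact form with an injectable clock so they can be followed step by step. It is an illustration under stated assumptions, not a port of the Go code: it tracks a single backend, is not thread-safe, counts consecutive failures, and reopens immediately on a half-open failure. The thresholds (5 failures, 3 successes, 30 seconds) mirror the constants used in the listing.

```python
import time
from typing import Callable


class CircuitBreaker:
    """Single-backend circuit breaker: closed -> open -> half-open -> closed."""

    def __init__(self, failure_threshold: int = 5, success_threshold: int = 3,
                 cooldown: float = 30.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.cooldown = cooldown
        self.clock = clock
        self.state = "closed"
        self.failures = 0
        self.successes = 0
        self.opened_at = 0.0

    def allow(self) -> bool:
        """Return True if a request may be sent to the backend."""
        if self.state == "open":
            if self.clock() - self.opened_at >= self.cooldown:
                # Cooldown elapsed: let probe requests through
                self.state = "half-open"
                self.successes = 0
                return True
            return False
        return True

    def record_success(self) -> None:
        if self.state == "half-open":
            self.successes += 1
            if self.successes >= self.success_threshold:
                self.state = "closed"
                self.failures = 0
        elif self.state == "closed":
            # Only consecutive failures count toward opening the circuit
            self.failures = 0

    def record_failure(self) -> None:
        self.failures += 1
        if self.state == "half-open" or self.failures >= self.failure_threshold:
            self.state = "open"
            self.opened_at = self.clock()


if __name__ == "__main__":
    now = [0.0]
    cb = CircuitBreaker(clock=lambda: now[0])
    for _ in range(5):
        cb.record_failure()
    print(cb.state, cb.allow())   # circuit is open, request rejected
    now[0] = 30.0
    print(cb.allow(), cb.state)   # cooldown elapsed, probe allowed
```

Injecting the clock keeps the cooldown logic testable without sleeping; production code would also need a lock around the state, as the Go version has.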

Best Practices

Framework Selection Guide

┌─────────────────────────────────────────────────────────────────┐
│                Framework Selection Decision Tree                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Decide on priorities                                        │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ Highest throughput  → TensorRT-LLM + Triton              │   │
│  │ Fast deployment     → vLLM / TGI                         │   │
│  │ HF ecosystem        → TGI                                │   │
│  │ Memory efficiency   → vLLM (PagedAttention)              │   │
│  │ Multi-model serving → Triton                             │   │
│  │ Edge deployment     → llama.cpp / ONNX Runtime           │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  2. Hardware considerations                                     │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ NVIDIA GPU  → TensorRT-LLM (best performance)            │   │
│  │               vLLM / TGI (ease of use)                   │   │
│  │ AMD GPU     → vLLM (ROCm support)                        │   │
│  │ CPU         → llama.cpp / ONNX Runtime                   │   │
│  │ Multi-GPU   → all frameworks; TRT-LLM is best            │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  3. Production recommendations                                  │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ • Small scale (< 10 QPS): single-GPU vLLM                │   │
│  │ • Medium scale (10-100 QPS): TGI/vLLM + HPA              │   │
│  │ • Large scale (> 100 QPS): TRT-LLM + Triton cluster      │   │
│  │ • Multi-model: unified management via Triton             │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
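
The production recommendations in the diagram reduce to a few threshold checks. A minimal sketch that encodes them as a function. `recommend_framework` is a hypothetical helper, the QPS boundaries are the rough guideline values from the diagram rather than measured limits, and treating exactly 100 QPS as medium scale is an assumption.

```python
def recommend_framework(qps: float, multi_model: bool = False) -> str:
    """Map an expected load to the deployment suggested in the diagram."""
    if qps < 0:
        raise ValueError("qps must be non-negative")
    if multi_model:
        # Multi-model serving is managed through Triton regardless of load
        return "Triton"
    if qps < 10:
        return "vLLM (single GPU)"
    if qps <= 100:
        return "TGI/vLLM + HPA"
    return "TRT-LLM + Triton cluster"


if __name__ == "__main__":
    for load in (5, 50, 500):
        print(load, recommend_framework(load))
```

Real selection also depends on hardware and model size, which the first two sections of the diagram cover; this function captures the load dimension only.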

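Summary

This chapter covered the architecture and use of the mainstream model serving frameworks:

  1. Framework comparison: the characteristics and use cases of vLLM, TGI, TensorRT-LLM, and Triton
  2. vLLM: implementing and deploying PagedAttention and Continuous Batching
  3. TGI: production-grade inference serving in the HuggingFace ecosystem
  4. TensorRT-LLM: NVIDIA's optimized high-performance inference engine
  5. Triton: multi-model orchestration and unified serving
  6. Unified gateway: model routing, load balancing, circuit breaking, and rate limiting

The next chapter, Dynamic Batching, explains how batching raises inference throughput.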
