模型服务框架
概述
模型服务框架是将训练好的模型部署为生产级 API 服务的关键基础设施。对于大语言模型,服务框架需要处理高并发请求、优化推理性能、管理 GPU 资源。本章深入讲解主流模型服务框架的架构设计、核心特性和最佳实践。
主流框架对比
框架选型矩阵
┌─────────────────────────────────────────────────────────────────┐
│ LLM 服务框架对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 框架 │ 特点 │ 适用场景 │
│ ─────────────┼─────────────────────────┼───────────────────── │
│ vLLM │ PagedAttention │ 高吞吐量生产服务 │
│ │ Continuous Batching │ 资源受限环境 │
│ │ 高内存效率 │ │
│ ─────────────┼─────────────────────────┼───────────────────── │
│ TGI │ 生产就绪 │ HuggingFace 生态 │
│ (Text │ 多模型支持 │ 企业级部署 │
│ Generation │ 分布式推理 │ │
│ Inference) │ │ │
│ ─────────────┼─────────────────────────┼───────────────────── │
│ TensorRT-LLM │ NVIDIA 优化 │ 最高性能需求 │
│ │ 最佳 GPU 利用率 │ NVIDIA GPU 专用 │
│ │ 量化支持完善 │ │
│ ─────────────┼─────────────────────────┼───────────────────── │
│ Triton │ 多框架支持 │ 异构模型服务 │
│ Inference │ 动态批处理 │ 多模型编排 │
│ Server │ 模型管理 │ │
│ ─────────────┼─────────────────────────┼───────────────────── │
│ llama.cpp │ CPU 推理 │ 边缘设备 │
│ │ 低资源消耗 │ 本地部署 │
│ │ 量化支持 │ │
│ │
│ 性能对比 (LLaMA-7B, A100, batch=1): │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 框架 │ Tokens/s │ TTFT (ms) │ 内存占用 (GB) │ │
│ │─────────────────┼───────────┼───────────┼───────────────│ │
│ │ vLLM │ ~80 │ ~40 │ ~14 │ │
│ │ TGI │ ~75 │ ~45 │ ~15 │ │
│ │ TensorRT-LLM │ ~100 │ ~30 │ ~13 │ │
│ │ PyTorch (基准) │ ~50 │ ~80 │ ~16 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
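上表中的 Tokens/s 与 TTFT 这类指标可以自行复测。下面给出一个最小压测脚本的草图:假设被测服务暴露 OpenAI 兼容的流式 /v1/completions 接口(例如 vLLM),其中端点地址与模型名均为示意参数;脚本用首个 SSE 数据块的到达时间近似 TTFT,用数据块个数近似生成 token 数,结果仅作量级参考。
# bench_ttft.py — 粗略测量 TTFT 与吞吐(非严格基准)
import json
import time

import requests


def measure(endpoint: str, model: str, prompt: str, max_tokens: int = 128):
    """对 OpenAI 兼容的流式 /v1/completions 做单请求测量,返回 (TTFT 秒, tokens/s)"""
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,  # 只有流式返回才能测 TTFT
    }
    start = time.perf_counter()
    first_token_at = None
    num_chunks = 0
    with requests.post(f"{endpoint}/v1/completions", json=payload, stream=True, timeout=120) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data:"):
                continue
            body = line[len(b"data:"):].strip()
            if body == b"[DONE]":
                break
            json.loads(body)  # 每个 SSE 块近似对应一个新生成的 token
            num_chunks += 1
            if first_token_at is None:
                first_token_at = time.perf_counter()
    total = time.perf_counter() - start
    ttft = (first_token_at or time.perf_counter()) - start
    tokens_per_sec = num_chunks / total if total > 0 else 0.0
    return ttft, tokens_per_sec


if __name__ == "__main__":
    # 端点与模型名为示意,按实际部署修改
    ttft, tps = measure("http://localhost:8000", "meta-llama/Llama-2-7b-chat-hf", "Hello, my name is")
    print(f"TTFT: {ttft * 1000:.1f} ms, throughput: {tps:.1f} tokens/s")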
vLLM 深度解析
vLLM 架构
┌─────────────────────────────────────────────────────────────────┐
│ vLLM 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ API Server │ │
│ │ (OpenAI Compatible API) │ │
│ └─────────────────────────┬───────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────▼───────────────────────────────┐ │
│ │ LLM Engine │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ Scheduler │ │ │
│ │ │ • Continuous Batching │ │ │
│ │ │ • Preemption (swap/recompute) │ │ │
│ │ │ • Priority Scheduling │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ┌─────────────────────▼─────────────────────────┐ │ │
│ │ │ Block Manager │ │ │
│ │ │ • Physical Block Allocation │ │ │
│ │ │ • Block Table Management │ │ │
│ │ │ • Copy-on-Write for Beam Search │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ┌─────────────────────▼─────────────────────────┐ │ │
│ │ │ Model Executor │ │ │
│ │ │ • PagedAttention Kernel │ │ │
│ │ │ • Flash Attention │ │ │
│ │ │ • Tensor Parallelism │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────▼───────────────────────────────┐ │
│ │ GPU Workers │ │
│ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │
│ │ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │ GPU 3 │ │ │
│ │ │ Model │ │ Model │ │ Model │ │ Model │ │ │
│ │ │ Shard │ │ Shard │ │ Shard │ │ Shard │ │ │
│ │ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
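图中 Block Manager 的核心思想是把 KV Cache 按固定大小的块(默认每块 16 个 token)做物理分配、按需映射。下面的小脚本按这一思路粗略估算单卡可分配的 KV Cache 块数与可同时缓存的 token 数,其中显存与权重体积取 LLaMA-7B 在 80GB A100 上的典型量级,仅作示意,未计入激活与运行时开销。
# kv_cache_budget.py — PagedAttention KV Cache 容量粗估(量级示意)
def kv_cache_budget(
    gpu_mem_gib: float = 80.0,            # 显存总量 (GiB)
    gpu_memory_utilization: float = 0.9,  # 与 vLLM 的同名参数含义一致
    weights_gib: float = 14.0,            # FP16 权重占用,7B 模型约 13~14 GiB
    num_layers: int = 32,                 # LLaMA-7B: 32 层
    num_kv_heads: int = 32,               # LLaMA-7B 无 GQA,KV 头数等于注意力头数
    head_dim: int = 128,
    block_size: int = 16,                 # 每块缓存的 token 数
    dtype_bytes: int = 2,                 # FP16
):
    # 每个 token 的 KV Cache 字节数 = 2 (K/V) * 层数 * KV 头数 * head_dim * 每元素字节数
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    bytes_per_block = bytes_per_token * block_size
    # 可分给 KV Cache 的显存 ≈ 总显存 * 利用率 - 权重
    free_bytes = (gpu_mem_gib * gpu_memory_utilization - weights_gib) * 1024**3
    num_blocks = int(free_bytes // bytes_per_block)
    return num_blocks, num_blocks * block_size


if __name__ == "__main__":
    blocks, tokens = kv_cache_budget()
    print(f"可分配 KV Cache 块数: {blocks}, 可同时缓存 token 数: {tokens}")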
vLLM 部署配置
# vllm_server.py
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI, Request
from typing import Optional, List, Dict, Any
import uvicorn
# vLLM 配置类
class VLLMConfig:
def __init__(
self,
model: str,
# 硬件配置
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
gpu_memory_utilization: float = 0.9,
# 批处理配置
max_num_batched_tokens: int = 8192,
max_num_seqs: int = 256,
max_model_len: Optional[int] = None,
# KV Cache 配置
block_size: int = 16,
swap_space: int = 4, # GB
# 量化配置
quantization: Optional[str] = None, # awq, gptq, squeezellm
# 调度配置
scheduler_delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
# 推测解码
speculative_model: Optional[str] = None,
num_speculative_tokens: int = 5,
):
self.model = model
self.tensor_parallel_size = tensor_parallel_size
self.pipeline_parallel_size = pipeline_parallel_size
self.gpu_memory_utilization = gpu_memory_utilization
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.block_size = block_size
self.swap_space = swap_space
self.quantization = quantization
self.scheduler_delay_factor = scheduler_delay_factor
self.enable_chunked_prefill = enable_chunked_prefill
self.speculative_model = speculative_model
self.num_speculative_tokens = num_speculative_tokens
def to_engine_args(self) -> AsyncEngineArgs:
return AsyncEngineArgs(
model=self.model,
tensor_parallel_size=self.tensor_parallel_size,
pipeline_parallel_size=self.pipeline_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=self.max_model_len,
block_size=self.block_size,
swap_space=self.swap_space,
quantization=self.quantization,
scheduler_delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
trust_remote_code=True,
dtype="auto",
)
# 自定义 vLLM 服务
class VLLMServer:
def __init__(self, config: VLLMConfig):
self.config = config
self.engine: Optional[AsyncLLMEngine] = None
self.app = FastAPI(title="vLLM API Server")
self._setup_routes()
async def initialize(self):
"""初始化引擎"""
engine_args = self.config.to_engine_args()
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
def _setup_routes(self):
"""设置 API 路由"""
@self.app.post("/v1/completions")
async def completions(request: Request):
"""OpenAI Completions API"""
body = await request.json()
return await self._handle_completion(body)
@self.app.post("/v1/chat/completions")
async def chat_completions(request: Request):
"""OpenAI Chat Completions API"""
body = await request.json()
return await self._handle_chat_completion(body)
@self.app.get("/health")
async def health():
"""健康检查"""
return {"status": "healthy"}
@self.app.get("/v1/models")
async def list_models():
"""列出模型"""
return {
"data": [
{
"id": self.config.model,
"object": "model",
"owned_by": "vllm",
}
]
}
async def _handle_completion(self, body: Dict[str, Any]):
"""处理 Completion 请求"""
prompt = body.get("prompt", "")
sampling_params = SamplingParams(
temperature=body.get("temperature", 1.0),
top_p=body.get("top_p", 1.0),
top_k=body.get("top_k", -1),
max_tokens=body.get("max_tokens", 16),
stop=body.get("stop"),
presence_penalty=body.get("presence_penalty", 0.0),
frequency_penalty=body.get("frequency_penalty", 0.0),
)
request_id = f"cmpl-{self._generate_request_id()}"
# 生成
results = []
async for output in self.engine.generate(prompt, sampling_params, request_id):
results.append(output)
final_output = results[-1]
return {
"id": request_id,
"object": "text_completion",
"model": self.config.model,
"choices": [
{
"text": output.text,
"index": i,
"finish_reason": output.finish_reason,
}
for i, output in enumerate(final_output.outputs)
],
"usage": {
"prompt_tokens": len(final_output.prompt_token_ids),
"completion_tokens": sum(
len(o.token_ids) for o in final_output.outputs
),
"total_tokens": len(final_output.prompt_token_ids)
+ sum(len(o.token_ids) for o in final_output.outputs),
},
}
async def _handle_chat_completion(self, body: Dict[str, Any]):
"""处理 Chat Completion 请求"""
messages = body.get("messages", [])
# 转换消息为 prompt
prompt = self._messages_to_prompt(messages)
# 复用 completion 逻辑
body["prompt"] = prompt
return await self._handle_completion(body)
def _messages_to_prompt(self, messages: List[Dict]) -> str:
"""将消息列表转换为 prompt"""
# 根据模型类型使用不同的模板
# 这里使用 ChatML 格式作为示例
prompt_parts = []
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
elif role == "user":
prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
elif role == "assistant":
prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
prompt_parts.append("<|im_start|>assistant\n")
return "\n".join(prompt_parts)
def _generate_request_id(self) -> str:
import uuid
return str(uuid.uuid4())[:8]
    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """运行服务"""
        # 引擎必须在 uvicorn 的事件循环中初始化,否则其后台任务会绑定到已关闭的 loop
        self.app.add_event_handler("startup", self.initialize)
        uvicorn.run(self.app, host=host, port=port)
# 启动脚本
if __name__ == "__main__":
config = VLLMConfig(
model="meta-llama/Llama-2-7b-chat-hf",
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
max_num_seqs=256,
max_model_len=4096,
)
server = VLLMServer(config)
server.run()
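上面的服务暴露的是 OpenAI 兼容接口,可用任意 HTTP 客户端直接调用。下面是与上述实现对应的最小调用示例(假设服务运行在本机 8000 端口,模型名与启动脚本一致;注意该实现的 choices 中返回的是 text 字段)。
# vllm_client_demo.py — 调用上文 VLLMServer 的最小示例
import requests

BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "用一句话解释什么是 PagedAttention。"},
        ],
        "max_tokens": 128,
        "temperature": 0.7,
    },
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0]["text"])
print(data["usage"])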
vLLM Kubernetes 部署
# vllm-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: vllm-server
namespace: ai-inference
spec:
replicas: 1
selector:
matchLabels:
app: vllm-server
template:
metadata:
labels:
app: vllm-server
spec:
containers:
- name: vllm
image: vllm/vllm-openai:latest
command:
- python
- -m
- vllm.entrypoints.openai.api_server
args:
- --model=/models/llama-2-7b-chat
- --tensor-parallel-size=1
- --gpu-memory-utilization=0.9
- --max-num-seqs=256
- --max-model-len=4096
- --host=0.0.0.0
- --port=8000
ports:
- containerPort: 8000
name: http
resources:
limits:
nvidia.com/gpu: 1
memory: "32Gi"
requests:
cpu: "4"
memory: "16Gi"
volumeMounts:
- name: model-storage
mountPath: /models
- name: shm
mountPath: /dev/shm
env:
- name: CUDA_VISIBLE_DEVICES
value: "0"
- name: NCCL_DEBUG
value: "WARN"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 60
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 5
volumes:
- name: model-storage
persistentVolumeClaim:
claimName: model-pvc
- name: shm
emptyDir:
medium: Memory
sizeLimit: 16Gi
nodeSelector:
nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
---
apiVersion: v1
kind: Service
metadata:
name: vllm-server
namespace: ai-inference
spec:
selector:
app: vllm-server
ports:
- port: 8000
targetPort: 8000
name: http
type: ClusterIP
---
# HPA 配置
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: vllm-server-hpa
namespace: ai-inference
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: vllm-server
minReplicas: 1
maxReplicas: 4
metrics:
- type: Pods
pods:
metric:
name: vllm_num_requests_running
target:
type: AverageValue
averageValue: "200"
# 注: autoscaling/v2 的 Resource 指标仅支持 cpu / memory,
# GPU 利用率需通过 DCGM Exporter + Prometheus Adapter 暴露为自定义指标
- type: Pods
  pods:
    metric:
      name: DCGM_FI_DEV_GPU_UTIL
    target:
      type: AverageValue
      averageValue: "80"
Text Generation Inference (TGI)
TGI 架构
┌─────────────────────────────────────────────────────────────────┐
│ TGI 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Router (Rust) │ │
│ │ • 请求接收和验证 │ │
│ │ • Token 预估 │ │
│ │ • 负载均衡 │ │
│ └─────────────────────────┬───────────────────────────────┘ │
│ │ gRPC │
│ ┌─────────────────────────▼───────────────────────────────┐ │
│ │ Inference Shard (Python) │ │
│ │ │ │
│ │ ┌────────────────────────────────────────────────┐ │ │
│ │ │ Batcher │ │ │
│ │ │ • Continuous Batching │ │ │
│ │ │ • Token Budget Management │ │ │
│ │ │ • Request Queuing │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ┌────────────────────▼─────────────────────────┐ │ │
│ │ │ Model │ │ │
│ │ │ • Flash Attention 2 │ │ │
│ │ │ • Tensor Parallelism │ │ │
│ │ │ • Quantization (GPTQ, AWQ, EETQ) │ │ │
│ │ │ • Paged Attention │ │ │
│ │ └────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 支持的模型: │
│ • LLaMA, Mistral, Mixtral │
│ • Falcon, GPT-NeoX, StarCoder │
│ • BLOOM, T5, BERT │
│ • Gemma, Phi, Qwen │
│ │
└─────────────────────────────────────────────────────────────────┘
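图中 Batcher 的 "Token Budget Management" 可以这样理解:调度器为每个请求估算 prefill 与 decode 的 token 开销,只有当累计开销不超过 max-batch-prefill-tokens / max-batch-total-tokens 时才把排队请求并入当前批次。下面是一个按这两个参数语义写的示意性准入逻辑,并非 TGI 源码,仅帮助理解后文部署参数之间的关系。
# token_budget.py — 连续批处理的 token 预算准入示意(非 TGI 实现)
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    request_id: str
    input_tokens: int       # prefill 长度
    max_new_tokens: int     # 预留的 decode 预算


def admit(
    queue: List[Request],
    running_budget: int,
    max_batch_total_tokens: int = 16384,
    max_batch_prefill_tokens: int = 8192,
) -> List[Request]:
    """从等待队列中挑出能并入当前批次的请求"""
    admitted: List[Request] = []
    prefill_budget = 0
    total_budget = running_budget
    for req in queue:
        cost = req.input_tokens + req.max_new_tokens  # 该请求最坏情况下的 token 占用
        if (prefill_budget + req.input_tokens <= max_batch_prefill_tokens
                and total_budget + cost <= max_batch_total_tokens):
            admitted.append(req)
            prefill_budget += req.input_tokens
            total_budget += cost
        else:
            break  # 保持 FIFO,避免长请求被小请求无限插队
    return admitted


if __name__ == "__main__":
    waiting = [Request("a", 3000, 512), Request("b", 6000, 1024), Request("c", 1000, 256)]
    batch = admit(waiting, running_budget=4096)
    print([r.request_id for r in batch])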
TGI 配置和部署
# tgi-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: tgi-server
namespace: ai-inference
spec:
replicas: 1
selector:
matchLabels:
app: tgi-server
template:
metadata:
labels:
app: tgi-server
annotations:
prometheus.io/scrape: "true"
prometheus.io/port: "80"
spec:
containers:
- name: tgi
image: ghcr.io/huggingface/text-generation-inference:2.0
args:
- --model-id=meta-llama/Llama-2-7b-chat-hf
- --num-shard=1
- --max-input-length=4096
- --max-total-tokens=8192
- --max-batch-prefill-tokens=8192
- --max-batch-total-tokens=16384
- --max-concurrent-requests=256
# 注: --quantize 需要预量化权重(如 AWQ 检查点),且与 --dtype 不能同时指定;
# 如需 AWQ 请换用已量化的检查点并移除 --dtype,这里按 FP16 原始权重运行
- --dtype=float16
- --trust-remote-code
- --hostname=0.0.0.0
- --port=80
ports:
- containerPort: 80
name: http
resources:
limits:
nvidia.com/gpu: 1
memory: "32Gi"
requests:
cpu: "4"
memory: "16Gi"
volumeMounts:
- name: model-cache
mountPath: /data
- name: shm
mountPath: /dev/shm
env:
- name: HUGGING_FACE_HUB_TOKEN
valueFrom:
secretKeyRef:
name: hf-token
key: token
- name: CUDA_MEMORY_FRACTION
value: "0.9"
livenessProbe:
httpGet:
path: /health
port: 80
initialDelaySeconds: 120
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 80
initialDelaySeconds: 60
periodSeconds: 5
volumes:
- name: model-cache
persistentVolumeClaim:
claimName: tgi-model-cache
- name: shm
emptyDir:
medium: Memory
sizeLimit: 16Gi
nodeSelector:
nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
---
# TGI 服务配置
apiVersion: v1
kind: Service
metadata:
name: tgi-server
namespace: ai-inference
spec:
selector:
app: tgi-server
ports:
- port: 80
targetPort: 80
name: http
type: ClusterIP
TGI 客户端使用
# tgi_client.py
import asyncio
from typing import AsyncIterator, Optional, List, Dict, Any
import aiohttp
import json
class TGIClient:
"""TGI 异步客户端"""
def __init__(self, base_url: str, timeout: float = 60.0):
self.base_url = base_url.rstrip("/")
self.timeout = aiohttp.ClientTimeout(total=timeout)
    def _build_payload(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        stop_sequences: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """构建 TGI /generate 请求体"""
        payload: Dict[str, Any] = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "repetition_penalty": repetition_penalty,
                "do_sample": temperature > 0,
            },
        }
        if top_k is not None:
            payload["parameters"]["top_k"] = top_k
        if stop_sequences:
            payload["parameters"]["stop"] = stop_sequences
        return payload

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """非流式生成,返回完整 JSON 响应"""
        payload = self._build_payload(prompt, **kwargs)
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.base_url}/generate",
                json=payload,
            ) as response:
                response.raise_for_status()
                return await response.json()

    async def generate_stream(self, prompt: str, **kwargs) -> AsyncIterator[str]:
        """流式生成,逐 token 返回文本(SSE)"""
        payload = self._build_payload(prompt, **kwargs)
        # 流式使用独立端点 /generate_stream;session 必须在整个迭代期间保持打开
        async with aiohttp.ClientSession(timeout=self.timeout) as session:
            async with session.post(
                f"{self.base_url}/generate_stream",
                json=payload,
            ) as response:
                response.raise_for_status()
                async for line in response.content:
                    text = line.decode("utf-8").strip()
                    if not text.startswith("data:"):
                        continue
                    data = json.loads(text[len("data:"):])
                    if "token" in data:
                        yield data["token"]["text"]
async def health(self) -> Dict[str, Any]:
"""健康检查"""
async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(f"{self.base_url}/health") as response:
return await response.json()
async def info(self) -> Dict[str, Any]:
"""获取模型信息"""
async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(f"{self.base_url}/info") as response:
return await response.json()
# 使用示例
async def main():
client = TGIClient("http://tgi-server:80")
# 检查健康状态
health = await client.health()
print(f"Health: {health}")
# 获取模型信息
info = await client.info()
print(f"Model: {info['model_id']}")
# 非流式生成
response = await client.generate(
prompt="What is the capital of France?",
max_new_tokens=50,
temperature=0.7,
)
print(f"Response: {response['generated_text']}")
# 流式生成
print("Streaming response: ", end="")
async for token in client.generate_stream(
prompt="Tell me a short story about a robot.",
max_new_tokens=200,
temperature=0.8,
):
print(token, end="", flush=True)
print()
if __name__ == "__main__":
asyncio.run(main())
TensorRT-LLM
TensorRT-LLM 架构
┌─────────────────────────────────────────────────────────────────┐
│ TensorRT-LLM 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 构建阶段 (Offline): │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Model Builder │ │
│ │ ┌─────────┐ ┌──────────┐ ┌──────────────┐ │ │
│ │ │ HF │───▶│ TRT-LLM │───▶│ TensorRT │ │ │
│ │ │ Model │ │ Network │ │ Engine │ │ │
│ │ └─────────┘ └──────────┘ └──────────────┘ │ │
│ │ │ │
│ │ 优化过程: │ │
│ │ • 图优化 (层融合、常量折叠) │ │
│ │ • 量化校准 (INT8/FP8) │ │
│ │ • Kernel 选择 (Auto-tuning) │ │
│ │ • 内存优化 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 运行阶段 (Online): │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Triton Backend │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ Inflight Batching │ │ │
│ │ │ • Continuous Batching │ │ │
│ │ │ • Chunked Context │ │ │
│ │ │ • KV Cache Management │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ┌─────────────────────▼─────────────────────────┐ │ │
│ │ │ TensorRT Runtime │ │ │
│ │ │ • Optimized CUDA Kernels │ │ │
│ │ │ • Multi-GPU Support (TP/PP) │ │ │
│ │ │ • FP8/INT8/INT4 Inference │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
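构建阶段选择哪种量化,通常先估算权重能否装进目标 GPU。下面的小脚本按参数量粗估不同精度下的权重显存占用(只计权重,不含 KV Cache 与运行时开销,数值仅作量级参考)。
# quant_footprint.py — 不同精度下权重显存占用粗估
BYTES_PER_PARAM = {
    "fp16": 2.0,
    "bf16": 2.0,
    "fp8": 1.0,
    "int8": 1.0,
    "int4": 0.5,
}


def weight_footprint_gib(num_params_b: float, precision: str) -> float:
    """num_params_b: 参数量(单位: 十亿)"""
    return num_params_b * 1e9 * BYTES_PER_PARAM[precision] / 1024**3


if __name__ == "__main__":
    for model, params in [("LLaMA-7B", 7), ("LLaMA-70B", 70)]:
        line = ", ".join(
            f"{p}: {weight_footprint_gib(params, p):.1f} GiB" for p in BYTES_PER_PARAM
        )
        print(f"{model} -> {line}")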
TensorRT-LLM 模型构建
# build_trt_llm.py
import tensorrt_llm
from tensorrt_llm import BuildConfig, Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.quantization import QuantMode
from typing import Optional
import torch
def build_trt_llm_engine(
model_dir: str,
output_dir: str,
dtype: str = "float16",
tp_size: int = 1,
pp_size: int = 1,
max_batch_size: int = 256,
max_input_len: int = 2048,
max_output_len: int = 2048,
quantization: Optional[str] = None,  # int8, int4_awq, fp8
):
"""构建 TensorRT-LLM 引擎"""
# 设置量化模式
quant_mode = QuantMode(0)
if quantization == "int8":
quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
elif quantization == "int4_awq":
quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
elif quantization == "fp8":
quant_mode = QuantMode.use_fp8_qdq()
# 创建映射配置(张量并行/流水线并行)
mapping = Mapping(
world_size=tp_size * pp_size,
tp_size=tp_size,
pp_size=pp_size,
)
# 构建配置
build_config = BuildConfig(
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_beam_width=1,
max_num_tokens=max_batch_size * max_input_len,
strongly_typed=True,
builder_opt=5, # 优化级别
)
# 从 HuggingFace 格式转换(dtype 以字符串形式直接传入)
model = LLaMAForCausalLM.from_hugging_face(
model_dir,
dtype=dtype,
mapping=mapping,
quant_mode=quant_mode,
)
# 构建引擎
engine = tensorrt_llm.build(
model,
build_config,
)
# 保存引擎
engine.save(output_dir)
print(f"Engine saved to {output_dir}")
return output_dir
# 多 GPU 并行构建脚本
def build_multi_gpu_engine():
"""构建多 GPU 引擎"""
import subprocess
import os
# 使用 mpirun 启动多进程构建
cmd = [
"mpirun",
"-n", "4", # 4 GPU
"--allow-run-as-root",
"python", "build_trt_llm.py",
"--model_dir", "/models/llama-70b",
"--output_dir", "/engines/llama-70b-tp4",
"--tp_size", "4",
"--dtype", "float16",
"--max_batch_size", "64",
]
subprocess.run(cmd, check=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", required=True)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--tp_size", type=int, default=1)
parser.add_argument("--dtype", default="float16")
parser.add_argument("--max_batch_size", type=int, default=256)
parser.add_argument("--quantization", default=None)
args = parser.parse_args()
build_trt_llm_engine(
model_dir=args.model_dir,
output_dir=args.output_dir,
dtype=args.dtype,
tp_size=args.tp_size,
max_batch_size=args.max_batch_size,
quantization=args.quantization,
)
Triton + TensorRT-LLM 部署
# triton_trt_llm_config.py
# Triton 模型配置生成
import json
import os
def generate_triton_config(
model_name: str,
engine_dir: str,
output_dir: str,
max_batch_size: int = 256,
tp_size: int = 1,
pp_size: int = 1,
):
"""生成 Triton 配置文件"""
config = {
"name": model_name,
"backend": "tensorrtllm",
"max_batch_size": max_batch_size,
"model_transaction_policy": {
"decoupled": True # 流式输出
},
"input": [
{
"name": "input_ids",
"data_type": "TYPE_INT32",
"dims": [-1]
},
{
"name": "input_lengths",
"data_type": "TYPE_INT32",
"dims": [1]
},
{
"name": "request_output_len",
"data_type": "TYPE_INT32",
"dims": [1]
},
{
"name": "end_id",
"data_type": "TYPE_INT32",
"dims": [1],
"optional": True
},
{
"name": "pad_id",
"data_type": "TYPE_INT32",
"dims": [1],
"optional": True
},
{
"name": "beam_width",
"data_type": "TYPE_INT32",
"dims": [1],
"optional": True
},
{
"name": "temperature",
"data_type": "TYPE_FP32",
"dims": [1],
"optional": True
},
{
"name": "top_k",
"data_type": "TYPE_INT32",
"dims": [1],
"optional": True
},
{
"name": "top_p",
"data_type": "TYPE_FP32",
"dims": [1],
"optional": True
},
{
"name": "streaming",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True
}
],
"output": [
{
"name": "output_ids",
"data_type": "TYPE_INT32",
"dims": [-1, -1]
},
{
"name": "sequence_length",
"data_type": "TYPE_INT32",
"dims": [-1]
}
],
"instance_group": [
{
"count": 1,
"kind": "KIND_GPU",
"gpus": list(range(tp_size * pp_size))
}
],
"parameters": {
"gpt_model_type": {"string_value": "llama"},
"gpt_model_path": {"string_value": engine_dir},
"batch_scheduler_policy": {"string_value": "inflight_fused_batching"},
"decoupled_mode": {"string_value": "true"},
"max_tokens_in_paged_kv_cache": {"string_value": ""},
"kv_cache_free_gpu_mem_fraction": {"string_value": "0.9"},
"enable_chunked_context": {"string_value": "true"},
}
}
# 创建目录
model_dir = os.path.join(output_dir, model_name, "1")
os.makedirs(model_dir, exist_ok=True)
# 写入配置
config_path = os.path.join(output_dir, model_name, "config.pbtxt")
with open(config_path, "w") as f:
f.write(dict_to_pbtxt(config))
print(f"Triton config saved to {config_path}")
return config_path
def dict_to_pbtxt(d: dict, indent: int = 0) -> str:
"""将字典转换为 pbtxt 格式"""
lines = []
prefix = " " * indent
for key, value in d.items():
if isinstance(value, dict):
lines.append(f"{prefix}{key} {{")
lines.append(dict_to_pbtxt(value, indent + 1))
lines.append(f"{prefix}}}")
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
lines.append(f"{prefix}{key} {{")
lines.append(dict_to_pbtxt(item, indent + 1))
lines.append(f"{prefix}}}")
else:
lines.append(f'{prefix}{key}: {json.dumps(item)}')
elif isinstance(value, bool):
    lines.append(f"{prefix}{key}: {str(value).lower()}")
elif isinstance(value, str):
    # data_type / kind 等 protobuf 枚举字段在 pbtxt 中不能加引号
    if key in ("data_type", "kind"):
        lines.append(f"{prefix}{key}: {value}")
    else:
        lines.append(f'{prefix}{key}: "{value}"')
else:
    lines.append(f"{prefix}{key}: {value}")
return "\n".join(lines)
NVIDIA Triton Inference Server
Triton 多模型编排
# triton-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: triton-server
namespace: ai-inference
spec:
replicas: 1
selector:
matchLabels:
app: triton-server
template:
metadata:
labels:
app: triton-server
annotations:
prometheus.io/scrape: "true"
prometheus.io/port: "8002"
spec:
containers:
- name: triton
image: nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3
args:
- tritonserver
- --model-repository=/models
- --model-control-mode=explicit
- --load-model=*
- --http-port=8000
- --grpc-port=8001
- --metrics-port=8002
- --log-verbose=1
ports:
- containerPort: 8000
name: http
- containerPort: 8001
name: grpc
- containerPort: 8002
name: metrics
resources:
limits:
nvidia.com/gpu: 4
memory: "128Gi"
requests:
cpu: "16"
memory: "64Gi"
volumeMounts:
- name: model-repository
mountPath: /models
- name: shm
mountPath: /dev/shm
env:
- name: CUDA_VISIBLE_DEVICES
value: "0,1,2,3"
livenessProbe:
httpGet:
path: /v2/health/live
port: 8000
initialDelaySeconds: 120
periodSeconds: 10
readinessProbe:
httpGet:
path: /v2/health/ready
port: 8000
initialDelaySeconds: 60
periodSeconds: 5
volumes:
- name: model-repository
persistentVolumeClaim:
claimName: triton-models
- name: shm
emptyDir:
medium: Memory
sizeLimit: 32Gi
---
# Triton 服务
apiVersion: v1
kind: Service
metadata:
name: triton-server
namespace: ai-inference
spec:
selector:
app: triton-server
ports:
- port: 8000
targetPort: 8000
name: http
- port: 8001
targetPort: 8001
name: grpc
- port: 8002
targetPort: 8002
name: metrics
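由于上面的 Triton 启用了 --model-control-mode=explicit,模型的加载与卸载可以在运行时通过管理 API 控制。下面用官方 Python 客户端 tritonclient 给出一个就绪检查与显式加载/卸载的最小示例(服务地址取集群内 Service 名,模型名为示意;该功能与后文的 Go gRPC 客户端等价)。
# triton_admin_demo.py — 就绪检查与显式加载/卸载(需 pip install tritonclient[http])
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="triton-server.ai-inference:8000")

# 服务级健康检查
print("live:", client.is_server_live())
print("ready:", client.is_server_ready())

# explicit 模式下按需加载模型
model_name = "llama-2-7b"
client.load_model(model_name)
print("model ready:", client.is_model_ready(model_name))

# 查看模型仓库中的模型及状态
for m in client.get_model_repository_index():
    print(m)

# 不再需要时卸载以释放显存
client.unload_model(model_name)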
Triton 模型集成
// triton_client.go
package inference
import (
	"context"
	"fmt"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "github.com/triton-inference-server/client/src/grpc_generated/go/grpc-client"
)
// TritonClient Triton gRPC 客户端
type TritonClient struct {
conn *grpc.ClientConn
client pb.GRPCInferenceServiceClient
}
// NewTritonClient 创建客户端
func NewTritonClient(address string) (*TritonClient, error) {
conn, err := grpc.Dial(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to triton: %w", err)
}
return &TritonClient{
conn: conn,
client: pb.NewGRPCInferenceServiceClient(conn),
}, nil
}
// Close 关闭连接
func (c *TritonClient) Close() error {
return c.conn.Close()
}
// ServerLive 检查服务是否存活
func (c *TritonClient) ServerLive(ctx context.Context) (bool, error) {
resp, err := c.client.ServerLive(ctx, &pb.ServerLiveRequest{})
if err != nil {
return false, err
}
return resp.Live, nil
}
// ServerReady 检查服务是否就绪
func (c *TritonClient) ServerReady(ctx context.Context) (bool, error) {
resp, err := c.client.ServerReady(ctx, &pb.ServerReadyRequest{})
if err != nil {
return false, err
}
return resp.Ready, nil
}
// ModelReady 检查模型是否就绪
func (c *TritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) {
resp, err := c.client.ModelReady(ctx, &pb.ModelReadyRequest{
Name: modelName,
})
if err != nil {
return false, err
}
return resp.Ready, nil
}
// Infer 推理请求
func (c *TritonClient) Infer(
ctx context.Context,
modelName string,
inputs []*pb.ModelInferRequest_InferInputTensor,
outputs []*pb.ModelInferRequest_InferRequestedOutputTensor,
) (*pb.ModelInferResponse, error) {
request := &pb.ModelInferRequest{
ModelName: modelName,
Inputs: inputs,
Outputs: outputs,
}
return c.client.ModelInfer(ctx, request)
}
// InferLLM 推理 LLM
func (c *TritonClient) InferLLM(
ctx context.Context,
modelName string,
inputIDs []int32,
maxNewTokens int32,
temperature float32,
topP float32,
stream bool,
) ([]int32, error) {
// 构建输入
inputs := []*pb.ModelInferRequest_InferInputTensor{
{
Name: "input_ids",
Datatype: "INT32",
Shape: []int64{1, int64(len(inputIDs))},
Contents: &pb.InferTensorContents{
IntContents: inputIDs,
},
},
{
Name: "input_lengths",
Datatype: "INT32",
Shape: []int64{1},
Contents: &pb.InferTensorContents{
IntContents: []int32{int32(len(inputIDs))},
},
},
{
Name: "request_output_len",
Datatype: "INT32",
Shape: []int64{1},
Contents: &pb.InferTensorContents{
IntContents: []int32{maxNewTokens},
},
},
{
Name: "temperature",
Datatype: "FP32",
Shape: []int64{1},
Contents: &pb.InferTensorContents{
Fp32Contents: []float32{temperature},
},
},
{
Name: "top_p",
Datatype: "FP32",
Shape: []int64{1},
Contents: &pb.InferTensorContents{
Fp32Contents: []float32{topP},
},
},
}
// 构建输出
outputs := []*pb.ModelInferRequest_InferRequestedOutputTensor{
{Name: "output_ids"},
{Name: "sequence_length"},
}
// 执行推理
resp, err := c.Infer(ctx, modelName, inputs, outputs)
if err != nil {
return nil, err
}
// 解析输出
for _, output := range resp.Outputs {
if output.Name == "output_ids" {
return output.Contents.IntContents, nil
}
}
return nil, fmt.Errorf("output_ids not found in response")
}
// StreamInfer 流式推理
func (c *TritonClient) StreamInfer(
ctx context.Context,
modelName string,
inputs []*pb.ModelInferRequest_InferInputTensor,
) (<-chan *pb.ModelStreamInferResponse, error) {
stream, err := c.client.ModelStreamInfer(ctx)
if err != nil {
return nil, err
}
// 发送请求
request := &pb.ModelInferRequest{
ModelName: modelName,
Inputs: inputs,
}
if err := stream.Send(request); err != nil {
return nil, err
}
// 接收响应
resultCh := make(chan *pb.ModelStreamInferResponse, 100)
go func() {
defer close(resultCh)
for {
resp, err := stream.Recv()
if err != nil {
return
}
resultCh <- resp
}
}()
return resultCh, nil
}
// GetModelMetadata 获取模型元数据
func (c *TritonClient) GetModelMetadata(ctx context.Context, modelName string) (*pb.ModelMetadataResponse, error) {
return c.client.ModelMetadata(ctx, &pb.ModelMetadataRequest{
Name: modelName,
})
}
// LoadModel 加载模型
func (c *TritonClient) LoadModel(ctx context.Context, modelName string) error {
_, err := c.client.RepositoryModelLoad(ctx, &pb.RepositoryModelLoadRequest{
ModelName: modelName,
})
return err
}
// UnloadModel 卸载模型
func (c *TritonClient) UnloadModel(ctx context.Context, modelName string) error {
_, err := c.client.RepositoryModelUnload(ctx, &pb.RepositoryModelUnloadRequest{
ModelName: modelName,
})
return err
}
统一服务网关
模型路由器
// model_router.go
package gateway
import (
"context"
"fmt"
"net/http"
"sync"
"time"
)
// ModelRouter 模型路由器
type ModelRouter struct {
// 后端配置
backends map[string]*Backend
// 路由规则
routes map[string]*Route
// 负载均衡器
loadBalancer LoadBalancer
// 熔断器
circuitBreaker *CircuitBreaker
// 限流器
rateLimiter *RateLimiter
mu sync.RWMutex
}
// Backend 后端服务
type Backend struct {
Name string
Type string // vllm, tgi, triton
Address string
Weight int
Healthy bool
LastCheck time.Time
// 连接池
client interface{}
}
// Route 路由规则
type Route struct {
ModelName string
Backends []string
Strategy string // round_robin, least_conn, weighted
Timeout time.Duration
RetryCount int
FallbackModel string
}
// NewModelRouter 创建路由器
func NewModelRouter() *ModelRouter {
router := &ModelRouter{
backends: make(map[string]*Backend),
routes: make(map[string]*Route),
loadBalancer: NewWeightedRoundRobin(),
circuitBreaker: NewCircuitBreaker(),
rateLimiter: NewRateLimiter(1000), // 1000 QPS
}
// 启动健康检查
go router.healthCheckLoop()
return router
}
// RegisterBackend 注册后端
func (r *ModelRouter) RegisterBackend(backend *Backend) {
r.mu.Lock()
defer r.mu.Unlock()
r.backends[backend.Name] = backend
}
// RegisterRoute 注册路由
func (r *ModelRouter) RegisterRoute(route *Route) {
r.mu.Lock()
defer r.mu.Unlock()
r.routes[route.ModelName] = route
}
// Route 路由请求
func (r *ModelRouter) Route(ctx context.Context, req *InferenceRequest) (*InferenceResponse, error) {
// 限流检查
if !r.rateLimiter.Allow(req.ModelName) {
return nil, fmt.Errorf("rate limit exceeded")
}
// 获取路由规则
r.mu.RLock()
route, ok := r.routes[req.ModelName]
r.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("model not found: %s", req.ModelName)
}
// 选择后端
backend, err := r.selectBackend(route)
if err != nil {
return nil, err
}
// 熔断检查
if !r.circuitBreaker.Allow(backend.Name) {
// 尝试回退模型
if route.FallbackModel != "" {
req.ModelName = route.FallbackModel
return r.Route(ctx, req)
}
return nil, fmt.Errorf("circuit breaker open for backend: %s", backend.Name)
}
// 设置超时
ctx, cancel := context.WithTimeout(ctx, route.Timeout)
defer cancel()
// 执行请求(带重试)
var resp *InferenceResponse
var lastErr error
for i := 0; i <= route.RetryCount; i++ {
resp, lastErr = r.executeRequest(ctx, backend, req)
if lastErr == nil {
r.circuitBreaker.RecordSuccess(backend.Name)
return resp, nil
}
r.circuitBreaker.RecordFailure(backend.Name)
// 选择新后端重试
backend, err = r.selectBackend(route)
if err != nil {
break
}
}
return nil, lastErr
}
// selectBackend 选择后端
func (r *ModelRouter) selectBackend(route *Route) (*Backend, error) {
r.mu.RLock()
defer r.mu.RUnlock()
// 过滤健康的后端
var healthyBackends []*Backend
for _, name := range route.Backends {
if backend, ok := r.backends[name]; ok && backend.Healthy {
healthyBackends = append(healthyBackends, backend)
}
}
if len(healthyBackends) == 0 {
return nil, fmt.Errorf("no healthy backends available")
}
// 负载均衡
return r.loadBalancer.Select(healthyBackends), nil
}
// executeRequest 执行请求
func (r *ModelRouter) executeRequest(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
switch backend.Type {
case "vllm":
return r.callVLLM(ctx, backend, req)
case "tgi":
return r.callTGI(ctx, backend, req)
case "triton":
return r.callTriton(ctx, backend, req)
default:
return nil, fmt.Errorf("unknown backend type: %s", backend.Type)
}
}
func (r *ModelRouter) callVLLM(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
// 调用 vLLM OpenAI 兼容 API
return nil, nil
}
func (r *ModelRouter) callTGI(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
// 调用 TGI API
return nil, nil
}
func (r *ModelRouter) callTriton(ctx context.Context, backend *Backend, req *InferenceRequest) (*InferenceResponse, error) {
// 调用 Triton gRPC
return nil, nil
}
// healthCheckLoop 健康检查循环
func (r *ModelRouter) healthCheckLoop() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
r.checkAllBackends()
}
}
func (r *ModelRouter) checkAllBackends() {
r.mu.RLock()
backends := make([]*Backend, 0, len(r.backends))
for _, b := range r.backends {
backends = append(backends, b)
}
r.mu.RUnlock()
for _, backend := range backends {
healthy := r.checkBackendHealth(backend)
r.mu.Lock()
backend.Healthy = healthy
backend.LastCheck = time.Now()
r.mu.Unlock()
}
}
func (r *ModelRouter) checkBackendHealth(backend *Backend) bool {
// 执行健康检查
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
switch backend.Type {
case "vllm", "tgi":
// HTTP 健康检查
req, _ := http.NewRequestWithContext(ctx, "GET", backend.Address+"/health", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
case "triton":
// gRPC 健康检查
// 使用 Triton client
return true
}
return false
}
// InferenceRequest 推理请求
type InferenceRequest struct {
ModelName string
Prompt string
MaxTokens int
Temperature float32
TopP float32
Stream bool
}
// InferenceResponse 推理响应
type InferenceResponse struct {
Text string
TokensUsed int
FinishReason string
Latency time.Duration
}
// LoadBalancer 负载均衡器接口
type LoadBalancer interface {
Select(backends []*Backend) *Backend
}
// WeightedRoundRobin 加权轮询
type WeightedRoundRobin struct {
current int
mu sync.Mutex
}
func NewWeightedRoundRobin() *WeightedRoundRobin {
return &WeightedRoundRobin{}
}
func (w *WeightedRoundRobin) Select(backends []*Backend) *Backend {
w.mu.Lock()
defer w.mu.Unlock()
// 计算总权重
totalWeight := 0
for _, b := range backends {
totalWeight += b.Weight
}
if totalWeight == 0 {
w.current = (w.current + 1) % len(backends)
return backends[w.current]
}
// 加权选择
w.current = (w.current + 1) % totalWeight
cumWeight := 0
for _, b := range backends {
cumWeight += b.Weight
if w.current < cumWeight {
return b
}
}
return backends[0]
}
// CircuitBreaker 熔断器
type CircuitBreaker struct {
states map[string]*CircuitState
mu sync.RWMutex
}
type CircuitState struct {
Failures int
Successes int
LastFailure time.Time
State string // closed, open, half-open
}
func NewCircuitBreaker() *CircuitBreaker {
return &CircuitBreaker{
states: make(map[string]*CircuitState),
}
}
func (cb *CircuitBreaker) Allow(backend string) bool {
cb.mu.RLock()
state, ok := cb.states[backend]
cb.mu.RUnlock()
if !ok {
return true
}
switch state.State {
case "open":
// 检查是否可以进入 half-open
if time.Since(state.LastFailure) > 30*time.Second {
cb.mu.Lock()
state.State = "half-open"
cb.mu.Unlock()
return true
}
return false
default:
return true
}
}
func (cb *CircuitBreaker) RecordSuccess(backend string) {
cb.mu.Lock()
defer cb.mu.Unlock()
state, ok := cb.states[backend]
if !ok {
cb.states[backend] = &CircuitState{State: "closed"}
return
}
state.Successes++
if state.State == "half-open" && state.Successes >= 3 {
state.State = "closed"
state.Failures = 0
}
}
func (cb *CircuitBreaker) RecordFailure(backend string) {
cb.mu.Lock()
defer cb.mu.Unlock()
state, ok := cb.states[backend]
if !ok {
cb.states[backend] = &CircuitState{
Failures: 1,
LastFailure: time.Now(),
State: "closed",
}
return
}
state.Failures++
state.LastFailure = time.Now()
if state.Failures >= 5 {
state.State = "open"
}
}
// RateLimiter 基于令牌桶的限流器(按模型名维度限流)
type RateLimiter struct {
	defaultQPS float64
	limits     map[string]*TokenBucket
	mu         sync.RWMutex
}

type TokenBucket struct {
	tokens     float64
	maxTokens  float64
	refillRate float64
	lastRefill time.Time
}

func NewRateLimiter(defaultQPS int) *RateLimiter {
	return &RateLimiter{
		defaultQPS: float64(defaultQPS),
		limits:     make(map[string]*TokenBucket),
	}
}

func (rl *RateLimiter) Allow(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	bucket, ok := rl.limits[key]
	if !ok {
		// 首次出现的 key 按默认 QPS 初始化令牌桶
		bucket = &TokenBucket{
			tokens:     rl.defaultQPS,
			maxTokens:  rl.defaultQPS,
			refillRate: rl.defaultQPS, // 每秒补充的 token 数
			lastRefill: time.Now(),
		}
rl.limits[key] = bucket
}
// 补充 tokens
now := time.Now()
elapsed := now.Sub(bucket.lastRefill).Seconds()
bucket.tokens += elapsed * bucket.refillRate
if bucket.tokens > bucket.maxTokens {
bucket.tokens = bucket.maxTokens
}
bucket.lastRefill = now
// 检查并消耗 token
if bucket.tokens >= 1 {
bucket.tokens--
return true
}
return false
}
最佳实践
框架选型指南
┌─────────────────────────────────────────────────────────────────┐
│ 框架选型决策树 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 确定优先级 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 最高吞吐量 → TensorRT-LLM + Triton │ │
│ │ 快速部署 → vLLM / TGI │ │
│ │ HF 生态集成 → TGI │ │
│ │ 内存效率 → vLLM (PagedAttention) │ │
│ │ 多模型服务 → Triton │ │
│ │ 边缘部署 → llama.cpp / ONNX Runtime │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 2. 硬件考虑 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ NVIDIA GPU → TensorRT-LLM (最佳性能) │ │
│ │ vLLM / TGI (易用性) │ │
│ │ AMD GPU → vLLM (ROCm 支持) │ │
│ │ CPU → llama.cpp / ONNX Runtime │ │
│ │ 多 GPU → 所有框架都支持,TRT-LLM 最优 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
│ 3. 生产建议 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ • 小规模(< 10 QPS): vLLM 单卡部署 │ │
│ │ • 中规模(10-100 QPS): TGI/vLLM + HPA │ │
│ │ • 大规模(> 100 QPS): TRT-LLM + Triton 集群 │ │
│ │ • 多模型: Triton 统一管理 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
小结
本章详细介绍了主流模型服务框架的架构和使用:
- 框架对比:vLLM、TGI、TensorRT-LLM、Triton 的特点和适用场景
- vLLM:PagedAttention、Continuous Batching 的实现和部署
- TGI:HuggingFace 生态的生产级推理服务
- TensorRT-LLM:NVIDIA 优化的高性能推理引擎
- Triton:多模型编排和统一服务
- 统一网关:模型路由、负载均衡、熔断限流
下一章我们将探讨动态批处理,讲解如何通过批处理优化提升推理吞吐量。