RAG系统架构
概述
RAG（Retrieval-Augmented Generation，检索增强生成）是当前大模型应用中最主流的架构模式之一：它在生成回答之前先从外部知识库检索相关内容，并将其作为上下文提供给 LLM，从而提升回答的准确性和时效性。本章深入讲解 RAG 系统的架构设计、核心组件实现与性能优化。
RAG 架构演进
基础架构
┌─────────────────────────────────────────────────────────────────────────────┐
│ RAG 系统架构 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 用户查询 │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Query Processing │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ Query │ │ Query │ │ Query │ │ │
│ │ │ Rewriting │──│ Expansion │──│ Embedding │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ └──────────────────────────────┬──────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Retrieval Layer │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ Vector │ │ Keyword │ │ Knowledge │ │ │
│ │ │ Search │ │ Search │ │ Graph │ │ │
│ │ │ (Dense) │ │ (Sparse) │ │ (Structured)│ │ │
│ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ │
│ │ └────────────────┼────────────────┘ │ │
│ │ ▼ │ │
│ │ ┌─────────────┐ │ │
│ │ │ Reranker │ │ │
│ │ └─────────────┘ │ │
│ └──────────────────────────┬──────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Generation Layer │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ Context │ │ Prompt │ │ LLM │ │ │
│ │ │ Compression │──│ Assembly │──│ Generation │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ └──────────────────────────┬──────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 生成回答 │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Advanced RAG 架构
┌─────────────────────────────────────────────────────────────────────────────┐
│ Advanced RAG Pipeline │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Pre-Retrieval │ │
│ │ │ │
│ │ Query ──▶ [HyDE] ──▶ [Query Decomposition] ──▶ [Query Routing] │ │
│ │ │ │
│ │ HyDE: 生成假设性文档作为查询 │ │
│ │ Query Decomposition: 复杂问题分解为子问题 │ │
│ │ Query Routing: 选择最佳检索策略 │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Retrieval │ │
│ │ │ │
│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │ Multi-Index │ │ Recursive │ │ Self-Query │ │ │
│ │ │ Retrieval │ │ Retrieval │ │ Retrieval │ │ │
│ │ └──────────────┘ └──────────────┘ └──────────────┘ │ │
│ │ │ │ │ │ │
│ │ └───────────────────┼───────────────────┘ │ │
│ │ ▼ │ │
│ │ ┌──────────────┐ │ │
│ │ │ Fusion │ │ │
│ │ └──────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Post-Retrieval │ │
│ │ │ │
│ │ Documents ──▶ [Reranking] ──▶ [Filtering] ──▶ [Compression] │ │
│ │ │ │
│ │ Reranking: Cross-encoder 重排序 │ │
│ │ Filtering: 相关性/质量过滤 │ │
│ │ Compression: 长文档压缩/摘要 │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
核心组件实现
文档处理管道
"""
文档处理管道实现
"""
from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator
from abc import ABC, abstractmethod
import hashlib
import re
@dataclass
class Document:
    """A source document before it is split into chunks.

    Attributes:
        id: Unique identifier of the document.
        content: Full raw text of the document.
        metadata: Arbitrary document-level metadata.
        embedding: Optional document-level embedding vector (None by default).
    """

    id: str
    content: str
    metadata: Dict
    embedding: Optional[List[float]] = None

    @staticmethod
    def generate_id(content: str) -> str:
        """Return a deterministic ID for *content*: the hex MD5 digest of its UTF-8 bytes."""
        hasher = hashlib.md5()
        hasher.update(content.encode("utf-8"))
        return hasher.hexdigest()
@dataclass
class Chunk:
    """A piece of a document produced by a text splitter.

    Attributes:
        id: Chunk identifier.
        doc_id: ID of the parent document.
        content: Chunk text.
        metadata: Chunk-level metadata.
        embedding: Embedding vector; None until it is computed.
        start_index: Start offset of the chunk (defaults to 0).
        end_index: End offset of the chunk (defaults to 0).
    """

    # Identity.
    id: str
    doc_id: str
    # Payload.
    content: str
    metadata: Dict
    # Optional fields (fields with defaults must follow the required ones).
    embedding: Optional[List[float]] = None
    start_index: int = 0
    end_index: int = 0
class TextSplitter(ABC):
    """Abstract interface for text splitters; subclasses implement :meth:`split`."""

    @abstractmethod
    def split(self, text: str) -> List[str]:
        """Split *text* into a list of chunk strings."""
class RecursiveCharacterSplitter(TextSplitter):
    """Recursive character splitter.

    Splits text on the first separator (in priority order) that occurs in it,
    greedily merges the pieces back into chunks of at most ``chunk_size``
    characters, and recursively re-splits any piece that is still too long
    with the remaining, finer-grained separators. Finally, the last
    ``chunk_overlap`` characters of each chunk are prepended to the next one.

    Separators are dropped at chunk boundaries. The overlap prefix comes on top
    of ``chunk_size``, so an output chunk can be up to
    ``chunk_size + chunk_overlap`` characters long.
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None
    ):
        """
        Args:
            chunk_size: Maximum length of a merged chunk, before the overlap
                prefix is added.
            chunk_overlap: Trailing characters of the previous chunk prepended
                to each following chunk; 0 (or less) disables overlap.
            separators: Separators tried in order, coarsest first. ``""``
                means "split into single characters".
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", "。", ".", " ", ""]

    def split(self, text: str) -> List[str]:
        """Split *text* into (overlapping) chunks; returns [] for empty text."""
        return self._split_text(text, self.separators)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split *text* recursively, then add the inter-chunk overlap exactly once."""
        return self._add_overlap(self._split_recursive(text, separators))

    def _split_recursive(self, text: str, separators: List[str]) -> List[str]:
        """Split and merge *text* into chunks of at most ``chunk_size`` (no overlap).

        Overlap is deliberately not applied here. Applying it at every
        recursion level would prepend overlap twice to chunks produced by a
        nested split.
        """
        # Pick the first separator present in the text; "" (character level)
        # always applies and ends the search.
        separator = separators[-1]
        new_separators: List[str] = []
        for i, sep in enumerate(separators):
            if sep == "":
                separator = sep
                break
            if sep in text:
                separator = sep
                new_separators = separators[i + 1:]
                break

        splits = text.split(separator) if separator else list(text)

        # Greedily merge small pieces back together.
        good_splits: List[str] = []
        current_chunk = ""
        for piece in splits:
            if len(current_chunk) + len(piece) + len(separator) <= self.chunk_size:
                current_chunk += (separator if current_chunk else "") + piece
                continue
            if current_chunk:
                good_splits.append(current_chunk)
            if len(piece) > self.chunk_size and new_separators:
                # The piece alone is too large: split it with finer separators.
                good_splits.extend(self._split_recursive(piece, new_separators))
                current_chunk = ""
            else:
                current_chunk = piece
        if current_chunk:
            good_splits.append(current_chunk)
        return good_splits

    def _add_overlap(self, chunks: List[str]) -> List[str]:
        """Prepend the last ``chunk_overlap`` characters of each chunk to the next one."""
        if not chunks or self.chunk_overlap <= 0:
            return chunks
        result = [chunks[0]]
        for prev_chunk, chunk in zip(chunks, chunks[1:]):
            # Slicing returns the whole chunk when it is shorter than the overlap.
            result.append(prev_chunk[-self.chunk_overlap:] + chunk)
        return result
class SemanticSplitter(TextSplitter):
    """Semantic splitter: group consecutive sentences by embedding similarity.

    Each sentence is compared (cosine similarity) with the running mean
    embedding of the current chunk. A new chunk starts when either:

    * the similarity falls below ``threshold`` and the current chunk already
      has at least ``min_chunk_size`` characters (a topic shift), or
    * appending the sentence would make the chunk longer than
      ``max_chunk_size`` characters.

    Lengths are summed over sentences and exclude the joining spaces. A
    single sentence longer than ``max_chunk_size`` still becomes a chunk of
    its own.
    """

    def __init__(
        self,
        embedding_model,
        threshold: float = 0.5,
        min_chunk_size: int = 100,
        max_chunk_size: int = 2000
    ):
        """
        Args:
            embedding_model: Object whose ``encode(sentences)`` returns one
                vector per sentence (presumably a sentence-transformers-style
                model; TODO confirm).
            threshold: Cosine similarity below which a topic shift is assumed.
            min_chunk_size: Minimum characters before a topic shift may close a chunk.
            max_chunk_size: Maximum characters per chunk (see the class docstring).
        """
        self.embedding_model = embedding_model
        self.threshold = threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

    def split(self, text: str) -> List[str]:
        """Split *text* into semantically coherent chunks (sentences joined by spaces).

        Returns ``[]`` when *text* contains no sentence, and ``[text]`` when it
        contains exactly one.
        """
        import numpy as np

        sentences = self._split_sentences(text)
        if not sentences:
            return []
        if len(sentences) == 1:
            return [text]

        # asarray so the mean arithmetic below also works for list outputs.
        embeddings = np.asarray(self.embedding_model.encode(sentences), dtype=float)

        chunks: List[str] = []
        current_chunk = [sentences[0]]
        current_length = len(sentences[0])
        current_embedding = embeddings[0]

        for sentence, embedding in zip(sentences[1:], embeddings[1:]):
            similarity = self._cosine_similarity(current_embedding, embedding)
            topic_shift = (
                similarity < self.threshold
                and current_length >= self.min_chunk_size
            )
            too_long = current_length + len(sentence) > self.max_chunk_size
            if topic_shift or too_long:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sentence]
                current_length = len(sentence)
                current_embedding = embedding
            else:
                # Incremental mean: n is the number of sentences already
                # averaged, counted BEFORE appending the new one.
                n = len(current_chunk)
                current_embedding = (current_embedding * n + embedding) / (n + 1)
                current_chunk.append(sentence)
                current_length += len(sentence)

        chunks.append(" ".join(current_chunk))
        return chunks

    def _split_sentences(self, text: str) -> List[str]:
        """Split *text* into sentences after full-width 。！？ or ASCII . ! ?

        The punctuation stays attached to its sentence; empty pieces are dropped.
        """
        pattern = r'(?<=[。！？.!?])\s*'
        return [s.strip() for s in re.split(pattern, text) if s.strip()]

    def _cosine_similarity(self, a, b) -> float:
        """Return the cosine similarity of *a* and *b*, or 0.0 if either is a zero vector."""
        import numpy as np
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)
class DocumentProcessor:
    """Split documents into chunks and embed all chunks in one batched call."""

    def __init__(
        self,
        splitter: TextSplitter,
        embedding_model,
        metadata_extractors: List = None
    ):
        """
        Args:
            splitter: Splits a document's content into chunk texts.
            embedding_model: Object with ``encode(texts, batch_size=...)``
                returning one vector per text.
            metadata_extractors: Objects with ``extract(text)`` returning a
                dict that is merged into each chunk's metadata.
        """
        self.splitter = splitter
        self.embedding_model = embedding_model
        self.metadata_extractors = metadata_extractors or []

    def process(self, documents: List[Document]) -> List[Chunk]:
        """Chunk every document, then attach an embedding (list of floats) to each chunk."""
        all_chunks = [
            chunk
            for doc in documents
            for chunk in self._process_document(doc)
        ]
        if not all_chunks:
            return all_chunks

        # One batched encode call for all chunks of all documents.
        embeddings = self.embedding_model.encode(
            [chunk.content for chunk in all_chunks], batch_size=32
        )
        for chunk, embedding in zip(all_chunks, embeddings):
            # numpy arrays / tensors expose tolist(); plain sequences (e.g. a
            # list-returning encoder) have no tolist() and are copied instead.
            chunk.embedding = (
                embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
            )
        return all_chunks

    def _process_document(self, doc: Document) -> List[Chunk]:
        """Split one document into Chunk objects (embeddings not yet set).

        Chunk IDs are ``"{doc.id}_{i}"``. Each chunk's metadata is a shallow
        copy of the document's metadata plus ``chunk_index``, ``total_chunks``
        and any extractor output.

        ``start_index`` / ``end_index`` are cumulative lengths of the
        preceding chunk texts. They match character offsets into
        ``doc.content`` only if the splitter neither drops separators nor adds
        overlap.
        """
        texts = self.splitter.split(doc.content)
        chunks = []
        offset = 0
        for i, text in enumerate(texts):
            chunk = Chunk(
                id=f"{doc.id}_{i}",
                doc_id=doc.id,
                content=text,
                metadata={
                    **doc.metadata,
                    "chunk_index": i,
                    "total_chunks": len(texts)
                },
                start_index=offset,
                end_index=offset + len(text)
            )
            for extractor in self.metadata_extractors:
                chunk.metadata.update(extractor.extract(text))
            chunks.append(chunk)
            offset += len(text)
        return chunks
向量检索引擎
"""
向量检索引擎实现
"""
from typing import List, Tuple, Optional, Dict, Any
from abc import ABC, abstractmethod
import numpy as np
from dataclasses import dataclass
import threading
from collections import defaultdict
@dataclass
class SearchResult:
    """A single retrieval hit.

    Attributes:
        chunk_id: ID of the matched chunk.
        content: Text of the matched chunk.
        score: Relevance score. Its scale depends on the producer (vector
            similarity, BM25 score, fused RRF score or reranker score).
        metadata: Metadata stored alongside the chunk.
    """

    chunk_id: str
    content: str
    score: float
    metadata: Dict[str, Any]
class VectorStore(ABC):
    """Abstract vector-store interface: add, search and delete embeddings."""

    @abstractmethod
    def add(self, ids: List[str], embeddings: List[List[float]],
            contents: List[str], metadatas: List[Dict]) -> None:
        """Insert vectors with their content and metadata (parallel lists)."""

    @abstractmethod
    def search(self, query_embedding: List[float], top_k: int = 10,
               filters: Dict = None) -> List[SearchResult]:
        """Return results similar to *query_embedding*; *top_k* bounds how many are wanted."""

    @abstractmethod
    def delete(self, ids: List[str]) -> None:
        """Remove the entries with the given IDs."""
class FAISSVectorStore(VectorStore):
    """FAISS-backed vector store: inner product on L2-normalised vectors (cosine).

    Content and metadata live in Python dicts keyed by FAISS row position.
    FAISS rows are never physically removed. ``delete`` (and re-adding an
    existing id) only drops the row's metadata, and such rows are skipped at
    search time.
    """

    def __init__(
        self,
        dimension: int,
        index_type: str = "IVF",
        nlist: int = 100,
        use_gpu: bool = False
    ):
        """
        Args:
            dimension: Embedding dimensionality.
            index_type: ``"Flat"`` (exact), ``"IVF"`` (trained on the first
                ``add`` batch) or ``"HNSW"``.
            nlist: Number of IVF clusters (IVF only). NOTE: FAISS needs at
                least this many vectors in the first ``add`` batch to train.
            use_gpu: Move the index to GPU 0.

        Raises:
            ValueError: If *index_type* is unknown.
        """
        import faiss
        self.dimension = dimension
        self.use_gpu = use_gpu

        if index_type == "Flat":
            self.index = faiss.IndexFlatIP(dimension)  # inner product; vectors are normalised in add()
        elif index_type == "IVF":
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)
        elif index_type == "HNSW":
            self.index = faiss.IndexHNSWFlat(dimension, 32, faiss.METRIC_INNER_PRODUCT)
        else:
            raise ValueError(f"Unknown index type: {index_type}")

        self.gpu_resources = None
        if use_gpu:
            # Keep a reference so the GPU resources outlive the GPU index.
            self.gpu_resources = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, self.index)

        # Row bookkeeping: external id -> FAISS row, FAISS row -> stored data.
        self.id_to_idx: Dict[str, int] = {}
        self.idx_to_data: Dict[int, Dict] = {}
        self.current_idx = 0
        self.lock = threading.Lock()

    def add(self, ids: List[str], embeddings: List[List[float]],
            contents: List[str], metadatas: List[Dict]) -> None:
        """Normalise and add vectors with their content and metadata.

        Re-adding an existing id supersedes the old row: its data is dropped,
        so the old version is no longer returned by ``search``.

        Raises:
            ValueError: If the four input lists differ in length. A mismatch
                would desynchronise FAISS row positions from the metadata.
        """
        import faiss
        if not (len(ids) == len(embeddings) == len(contents) == len(metadatas)):
            raise ValueError(
                "ids, embeddings, contents and metadatas must have the same length"
            )
        if not ids:
            return

        vectors = np.array(embeddings, dtype=np.float32)
        faiss.normalize_L2(vectors)

        with self.lock:
            # IVF indexes must be trained before the first add; only the first
            # batch is used for training.
            if hasattr(self.index, 'is_trained') and not self.index.is_trained:
                self.index.train(vectors)
            self.index.add(vectors)

            for offset, (id_, content, metadata) in enumerate(zip(ids, contents, metadatas)):
                idx = self.current_idx + offset
                stale_idx = self.id_to_idx.get(id_)
                if stale_idx is not None:
                    self.idx_to_data.pop(stale_idx, None)
                self.id_to_idx[id_] = idx
                self.idx_to_data[idx] = {
                    "id": id_,
                    "content": content,
                    "metadata": metadata
                }
            self.current_idx += len(ids)

    def search(self, query_embedding: List[float], top_k: int = 10,
               filters: Dict = None) -> List[SearchResult]:
        """Return up to *top_k* live results, most similar first.

        Over-fetches ``10 * top_k`` rows when results will be discarded,
        either by metadata *filters* or because some rows are deleted or
        superseded. Fewer than *top_k* results can still be returned.
        """
        import faiss
        ntotal = self.index.ntotal
        if top_k <= 0 or ntotal == 0:
            return []

        query_vector = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_vector)

        has_dead_rows = len(self.idx_to_data) < ntotal
        search_k = top_k * 10 if (filters or has_dead_rows) else top_k
        search_k = min(search_k, ntotal)

        scores, indices = self.index.search(query_vector, search_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS pads missing results with -1
                continue
            data = self.idx_to_data.get(int(idx))
            if not data:  # deleted or superseded row
                continue
            if filters and not self._match_filters(data["metadata"], filters):
                continue
            results.append(SearchResult(
                chunk_id=data["id"],
                content=data["content"],
                score=float(score),
                metadata=data["metadata"]
            ))
            if len(results) >= top_k:
                break
        return results

    def delete(self, ids: List[str]) -> None:
        """Logically delete *ids*.

        FAISS rows stay in the index; only their metadata is dropped, so
        search skips them.
        """
        with self.lock:
            for id_ in ids:
                if id_ in self.id_to_idx:
                    idx = self.id_to_idx.pop(id_)
                    self.idx_to_data.pop(idx, None)

    def _match_filters(self, metadata: Dict, filters: Dict) -> bool:
        """Check *metadata* against *filters*.

        Every filter key must be present. A list value means "any of"; any
        other value must be equal.
        """
        for key, value in filters.items():
            if key not in metadata:
                return False
            if isinstance(value, list):
                if metadata[key] not in value:
                    return False
            elif metadata[key] != value:
                return False
        return True
class HybridRetriever:
    """Hybrid retriever: dense vector search plus BM25 keyword search.

    The two result lists are fused with weighted Reciprocal Rank Fusion (RRF).
    """

    def __init__(
        self,
        vector_store: VectorStore,
        bm25_index,
        embedding_model,
        vector_weight: float = 0.7
    ):
        """
        Args:
            vector_store: Dense index, queried with the query embedding.
            bm25_index: Keyword index with ``search(query, top_k)`` returning
                SearchResult objects (no filter support).
            embedding_model: Object with ``encode(text)`` returning the query vector.
            vector_weight: RRF weight of the vector results; BM25 gets
                ``1 - vector_weight``.
        """
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.embedding_model = embedding_model
        self.vector_weight = vector_weight
        self.keyword_weight = 1 - vector_weight

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Dict = None
    ) -> List[SearchResult]:
        """Return the *top_k* fused results for *query*.

        Each retrieval path contributes up to ``2 * top_k`` candidates.
        *filters* is passed to the vector store and also applied to the BM25
        candidates, so keyword hits outside the filter cannot appear in the
        output. Each returned result's ``score`` is its fused RRF score.
        """
        candidate_k = top_k * 2

        # Dense retrieval.
        query_embedding = self.embedding_model.encode(query)
        if hasattr(query_embedding, "tolist"):
            query_embedding = query_embedding.tolist()
        vector_results = self.vector_store.search(
            query_embedding,
            top_k=candidate_k,
            filters=filters
        )

        # Keyword retrieval. BM25 cannot filter by itself: over-fetch 10x,
        # then apply the same metadata filters.
        if filters:
            bm25_results = [
                result
                for result in self.bm25_index.search(query, top_k=candidate_k * 10)
                if self._match_filters(result.metadata, filters)
            ][:candidate_k]
        else:
            bm25_results = self.bm25_index.search(query, top_k=candidate_k)

        # Weighted Reciprocal Rank Fusion.
        fused_scores = defaultdict(float)
        chunk_data = {}
        k = 60  # RRF smoothing constant
        for rank, result in enumerate(vector_results):
            fused_scores[result.chunk_id] += self.vector_weight / (k + rank + 1)
            chunk_data[result.chunk_id] = result
        for rank, result in enumerate(bm25_results):
            fused_scores[result.chunk_id] += self.keyword_weight / (k + rank + 1)
            chunk_data.setdefault(result.chunk_id, result)

        sorted_ids = sorted(fused_scores, key=fused_scores.__getitem__, reverse=True)
        results = []
        for chunk_id in sorted_ids[:top_k]:
            result = chunk_data[chunk_id]
            result.score = fused_scores[chunk_id]
            results.append(result)
        return results

    @staticmethod
    def _match_filters(metadata: Dict, filters: Dict) -> bool:
        """Check *metadata* against *filters*.

        Every filter key must be present. A list value means "any of"; any
        other value must be equal.
        """
        for key, value in filters.items():
            if key not in metadata:
                return False
            if isinstance(value, list):
                if metadata[key] not in value:
                    return False
            elif metadata[key] != value:
                return False
        return True
class BM25Index:
    """In-memory Okapi BM25 keyword index (lower-cased jieba tokens).

    IDF is ``log((N - df + 0.5) / (df + 0.5) + 1)``, which stays positive even
    for terms that occur in most documents.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        Args:
            k1: Term-frequency saturation parameter.
            b: Document-length normalisation strength (0 = none, 1 = full).
        """
        self.k1 = k1
        self.b = b
        self.documents: Dict[str, Dict] = {}
        self.doc_lengths: Dict[str, int] = {}
        self.avg_doc_length = 0
        self.idf_scores: Dict[str, float] = {}
        self.inverted_index: Dict[str, List[str]] = defaultdict(list)

    def add(self, ids: List[str], contents: List[str], metadatas: List[Dict]):
        """Add documents (parallel lists) and refresh corpus statistics.

        Re-adding an existing id replaces that document. Its old postings are
        removed first, so document frequencies are not double-counted and
        terms it no longer contains stop matching it.
        """
        from collections import Counter

        for id_, content, metadata in zip(ids, contents, metadatas):
            if id_ in self.documents:
                self._remove_postings(id_)
            tokens = self._tokenize(content)
            self.documents[id_] = {
                "content": content,
                "metadata": metadata,
                "tokens": tokens,
                "term_counts": Counter(tokens),  # cached for scoring
            }
            self.doc_lengths[id_] = len(tokens)
            for token in set(tokens):
                self.inverted_index[token].append(id_)

        if self.doc_lengths:  # avoid ZeroDivisionError when nothing was ever added
            self.avg_doc_length = sum(self.doc_lengths.values()) / len(self.doc_lengths)
        self._compute_idf()

    def _remove_postings(self, doc_id: str) -> None:
        """Remove *doc_id* from the posting lists of the tokens it contained."""
        for token in set(self.documents[doc_id]["tokens"]):
            postings = [d for d in self.inverted_index.get(token, []) if d != doc_id]
            if postings:
                self.inverted_index[token] = postings
            else:
                self.inverted_index.pop(token, None)
                self.idf_scores.pop(token, None)

    def search(self, query: str, top_k: int = 10) -> List[SearchResult]:
        """Return up to *top_k* documents sharing at least one token with *query*.

        Results are ordered by BM25 score, highest first; ties are broken by
        id so the order is deterministic.
        """
        query_tokens = self._tokenize(query)

        candidate_docs = set()
        for token in query_tokens:
            candidate_docs.update(self.inverted_index.get(token, ()))

        scored = [
            (doc_id, self._compute_bm25(query_tokens, doc_id))
            for doc_id in candidate_docs
        ]
        scored.sort(key=lambda item: (-item[1], item[0]))

        results = []
        for doc_id, score in scored[:top_k]:
            doc = self.documents[doc_id]
            results.append(SearchResult(
                chunk_id=doc_id,
                content=doc["content"],
                score=score,
                metadata=doc["metadata"]
            ))
        return results

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case *text* and tokenise it with jieba (mixed Chinese/English).

        Whitespace-only tokens are dropped.
        """
        import jieba
        text = text.lower()
        tokens = list(jieba.cut(text))
        return [t for t in tokens if len(t.strip()) > 0]

    def _compute_idf(self):
        """Recompute IDF for every indexed token from the current corpus."""
        import math
        N = len(self.documents)
        for token, doc_ids in self.inverted_index.items():
            df = len(doc_ids)
            self.idf_scores[token] = math.log((N - df + 0.5) / (df + 0.5) + 1)

    def _compute_bm25(self, query_tokens: List[str], doc_id: str) -> float:
        """Return the BM25 score of document *doc_id* for *query_tokens*.

        Repeated query tokens contribute once per occurrence.
        """
        from collections import Counter

        doc = self.documents[doc_id]
        term_counts = doc.get("term_counts") or Counter(doc["tokens"])
        # Length normalisation is the same for every query token of this doc.
        length_norm = (
            1 - self.b
            + self.b * self.doc_lengths[doc_id] / (self.avg_doc_length or 1)
        )

        score = 0.0
        for token in query_tokens:
            idf = self.idf_scores.get(token)
            tf = term_counts.get(token, 0)
            if idf is None or not tf:
                continue
            score += idf * tf * (self.k1 + 1) / (tf + self.k1 * length_norm)
        return score
重排序器
"""
重排序器实现
"""
from typing import List, Tuple
from abc import ABC, abstractmethod
import numpy as np
class Reranker(ABC):
    """Abstract reranker interface."""

    @abstractmethod
    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        """Score *documents* against *query* and order them by relevance.

        Returns:
            A list of ``(original_index, score)`` pairs, where
            ``original_index`` is the document's position in *documents*.
        """
class CrossEncoderReranker(Reranker):
    """Reranker backed by a sentence-transformers ``CrossEncoder``."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model *model_name*."""
        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        """Score every (query, document) pair with the cross-encoder.

        Returns:
            ``(original_index, score)`` pairs sorted by descending score, cut
            to *top_k* when it is truthy. Scores are converted to plain Python
            floats, because ``predict`` returns numpy values, which are not
            JSON-serialisable. Returns ``[]`` for empty *documents* without
            calling the model.
        """
        if not documents:
            return []

        pairs = [[query, doc] for doc in documents]
        scores = self.model.predict(pairs)

        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )
        return ranked[:top_k] if top_k else ranked
class CohereReranker(Reranker):
    """Reranker backed by the Cohere Rerank API."""

    def __init__(self, api_key: str, model: str = "rerank-english-v2.0"):
        """
        Args:
            api_key: Cohere API key.
            model: Rerank model name. NOTE(review): confirm this model is
                still offered by Cohere; newer rerank models exist.
        """
        import cohere
        self.client = cohere.Client(api_key)
        self.model = model

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        """Return ``(original_index, relevance_score)`` pairs from the API.

        Pairs are in the order the API returns them (presumably by descending
        relevance). A falsy *top_k* requests all documents. Empty *documents*
        returns ``[]`` without an API call, which would otherwise be sent with
        ``top_n=0``.
        """
        if not documents:
            return []

        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=documents,
            top_n=top_k or len(documents)
        )
        return [(result.index, result.relevance_score) for result in response.results]
class LLMReranker(Reranker):
    """Reranker that asks an LLM to rate documents 0-10, a batch at a time."""

    def __init__(self, llm_client, batch_size: int = 5):
        """
        Args:
            llm_client: Client whose ``complete(prompt)`` is called
                synchronously and must return the reply text.
                NOTE(review): RAGPipeline awaits ``llm.complete``; passing that
                async client here would hand a coroutine to the parser, so
                confirm a sync client is used.
            batch_size: Number of documents rated per LLM call.
        """
        self.llm = llm_client
        self.batch_size = batch_size

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        """Rate *documents* for *query* and return ``(original_index, score)`` pairs.

        Each document is shown truncated to 500 characters. Scores are
        normalised from 0-10 to 0-1, and documents the LLM did not score get
        0.5. Pairs are sorted by descending score and cut to *top_k* when it is
        truthy.
        """
        prompt_template = """Given the following query and documents, rate the relevance of each document to the query on a scale of 0-10.
Query: {query}
Documents:
{documents}
For each document, provide a relevance score (0-10). Output format:
Document 1: [score]
Document 2: [score]
...
Only output the scores, nothing else."""

        all_scores = []
        for start in range(0, len(documents), self.batch_size):
            batch = documents[start:start + self.batch_size]
            doc_text = "\n".join(
                f"Document {j + 1}: {self._truncate(doc)}"
                for j, doc in enumerate(batch)
            )
            prompt = prompt_template.format(query=query, documents=doc_text)
            response = self.llm.complete(prompt)

            for j, score in enumerate(self._parse_scores(response, len(batch))):
                all_scores.append((start + j, score / 10.0))  # normalise to 0-1

        all_scores.sort(key=lambda item: item[1], reverse=True)
        return all_scores[:top_k] if top_k else all_scores

    @staticmethod
    def _truncate(doc: str, limit: int = 500) -> str:
        """Cut *doc* to *limit* characters, appending "..." only if it was cut."""
        return doc if len(doc) <= limit else doc[:limit] + "..."

    def _parse_scores(self, response: str, expected_count: int) -> List[float]:
        """Parse one 0-10 score per document from the LLM reply.

        Accepts ``Document 2: 7``, ``Document 2: 7.5`` and the bracketed
        ``Document 2: [7]`` form that the prompt itself shows. Scores are
        assigned by document number and clamped to [0, 10]. Documents with no
        (or an out-of-range) score default to 5.0.
        """
        import re
        scores = [5.0] * expected_count
        pattern = r'Document\s*(\d+)\s*:\s*\[?\s*(\d+(?:\.\d+)?)'
        for number, value in re.findall(pattern, response):
            position = int(number) - 1
            if 0 <= position < expected_count:
                scores[position] = min(max(float(value), 0.0), 10.0)
        return scores
RAG 主管道
"""
RAG 主管道实现
"""
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
import asyncio
@dataclass
class RAGResponse:
    """Result of one RAG query.

    Attributes:
        answer: Generated answer text.
        sources: Display-oriented summaries of the reranked chunks.
        query: The user's original query.
        retrieved_chunks: Candidates returned by the retrieval step.
        reranked_chunks: Chunks kept after reranking; the context is built from these.
        metadata: Diagnostic information, e.g. the original and processed queries.
    """

    answer: str
    sources: List[Dict]
    query: str
    retrieved_chunks: List[SearchResult]
    reranked_chunks: List[SearchResult]
    metadata: Dict[str, Any]
class RAGPipeline:
    """End-to-end RAG pipeline.

    Stages: rewrite the query into a standalone question, apply optional
    HyDE / expansion, retrieve, rerank, build the context, generate.

    ``llm`` must provide awaitable ``complete(prompt)`` and ``chat(messages)``
    methods that resolve to the reply text (both are awaited below).
    """

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: Reranker,
        llm,
        embedding_model,
        config: Dict = None
    ):
        """
        Args:
            retriever: Object with ``search(query=, top_k=, filters=)``.
            reranker: Reranker used to pick the context chunks.
            llm: Async LLM client (see the class docstring).
            embedding_model: Stored for subclasses; not used directly here.
            config: Optional settings with these defaults:
                ``top_k_retrieval`` (20), ``top_k_rerank`` (5),
                ``max_context_length`` in characters (4000),
                ``enable_hyde`` (False), ``enable_query_expansion`` (False).
        """
        self.retriever = retriever
        self.reranker = reranker
        self.llm = llm
        self.embedding_model = embedding_model
        self.config = config or {}

        self.top_k_retrieval = self.config.get("top_k_retrieval", 20)
        self.top_k_rerank = self.config.get("top_k_rerank", 5)
        self.max_context_length = self.config.get("max_context_length", 4000)
        self.enable_hyde = self.config.get("enable_hyde", False)
        self.enable_query_expansion = self.config.get("enable_query_expansion", False)

    async def query(
        self,
        query: str,
        filters: Dict = None,
        chat_history: List[Dict] = None
    ) -> RAGResponse:
        """Answer *query*, optionally in the context of *chat_history*.

        ``metadata`` records ``original_query``, ``standalone_query`` (after
        the history-based rewrite) and ``processed_query`` (the text actually
        sent to retrieval).
        """
        metadata = {"original_query": query}

        # 1. Make the question self-contained, then apply HyDE / expansion
        #    (these only affect retrieval).
        standalone_query = await self._rewrite_query(query, chat_history)
        processed_query = await self._preprocess_query(standalone_query)
        metadata["standalone_query"] = standalone_query
        metadata["processed_query"] = processed_query

        # 2. Retrieve.
        retrieved_chunks = await self._retrieve(processed_query, filters)

        # 3. Rerank against the standalone question. A HyDE passage or an
        #    expansion blob is a poor rerank query, and a raw follow-up
        #    question may only make sense given the chat history.
        reranked_chunks = await self._rerank(standalone_query, retrieved_chunks)

        # 4-5. Build the context and generate (the history goes to the LLM directly).
        context = self._build_context(reranked_chunks)
        answer = await self._generate(query, context, chat_history)

        sources = [
            {
                "chunk_id": chunk.chunk_id,
                "content": self._preview(chunk.content),
                "score": chunk.score,
                "metadata": chunk.metadata
            }
            for chunk in reranked_chunks
        ]

        return RAGResponse(
            answer=answer,
            sources=sources,
            query=query,
            retrieved_chunks=retrieved_chunks,
            reranked_chunks=reranked_chunks,
            metadata=metadata
        )

    async def _rewrite_query(
        self,
        query: str,
        chat_history: List[Dict] = None
    ) -> str:
        """Rewrite *query* into a self-contained question using *chat_history*.

        Returns *query* unchanged when there is no history, or when the LLM
        returns an empty rewrite.
        """
        if not chat_history:
            return query
        rewrite_prompt = f"""Given the conversation history and the latest question,
rewrite the question to be self-contained.
Conversation history:
{self._format_history(chat_history)}
Latest question: {query}
Rewritten question:"""
        rewritten = (await self.llm.complete(rewrite_prompt)).strip()
        return rewritten or query

    async def _preprocess_query(
        self,
        query: str,
        chat_history: List[Dict] = None
    ) -> str:
        """Turn *query* into the text used for retrieval.

        Each enabled step builds on the previous one:

        1. with *chat_history*, rewrite *query* into a standalone question;
        2. HyDE: use a hypothetical answer passage as the retrieval text;
        3. query expansion: append alternative phrasings of the question.

        With neither HyDE nor expansion enabled, the (rewritten) question is returned.
        """
        query = await self._rewrite_query(query, chat_history)
        processed_query = query

        if self.enable_hyde:
            hyde_prompt = f"""Based on the following question, write a short passage that would answer it.
Question: {query}
Passage:"""
            processed_query = (await self.llm.complete(hyde_prompt)).strip()

        if self.enable_query_expansion:
            expansion_prompt = f"""Generate 3 alternative phrasings for the following question.
Output only the questions, one per line.
Original question: {query}
Alternative questions:"""
            expansions = await self.llm.complete(expansion_prompt)
            processed_query = processed_query + " " + expansions.replace("\n", " ")

        return processed_query

    async def _retrieve(
        self,
        query: str,
        filters: Dict = None
    ) -> List[SearchResult]:
        """Retrieve ``top_k_retrieval`` candidates for *query*.

        NOTE: ``retriever.search`` is synchronous and runs on the event loop thread.
        """
        return self.retriever.search(
            query=query,
            top_k=self.top_k_retrieval,
            filters=filters
        )

    async def _rerank(
        self,
        query: str,
        chunks: List[SearchResult]
    ) -> List[SearchResult]:
        """Rerank *chunks* for *query*, keeping up to ``top_k_rerank``.

        Returns new SearchResult copies that carry the rerank score. The
        input chunks are left untouched, so ``retrieved_chunks`` keeps its
        retrieval scores.
        """
        from dataclasses import replace

        if not chunks:
            return []

        reranked = self.reranker.rerank(
            query=query,
            documents=[chunk.content for chunk in chunks],
            top_k=self.top_k_rerank
        )
        return [replace(chunks[idx], score=float(score)) for idx, score in reranked]

    def _build_context(self, chunks: List[SearchResult]) -> str:
        """Join chunks as ``[Source N]`` blocks within ``max_context_length`` characters.

        Chunks are added in order until the next one would exceed the budget;
        the "\\n" joiners count towards it. If even the first chunk does not
        fit, it is truncated instead of returning an empty context.
        """
        context_parts = []
        total_length = 0
        for i, chunk in enumerate(chunks):
            chunk_text = f"[Source {i+1}]\n{chunk.content}\n"
            joiner_length = 1 if context_parts else 0
            if total_length + joiner_length + len(chunk_text) > self.max_context_length:
                if not context_parts:
                    context_parts.append(chunk_text[:self.max_context_length])
                break
            context_parts.append(chunk_text)
            total_length += joiner_length + len(chunk_text)
        return "\n".join(context_parts)

    async def _generate(
        self,
        query: str,
        context: str,
        chat_history: List[Dict] = None
    ) -> str:
        """Generate an answer to *query* from *context*, after any *chat_history* messages."""
        system_prompt = """You are a helpful assistant. Answer the user's question based on the provided context.
If the context doesn't contain relevant information, say so.
Always cite your sources using [Source N] format."""
        user_prompt = f"""Context:
{context}
Question: {query}
Please provide a comprehensive answer based on the context above."""

        messages = [{"role": "system", "content": system_prompt}]
        if chat_history:
            messages.extend(chat_history)
        messages.append({"role": "user", "content": user_prompt})

        return await self.llm.chat(messages)

    def _format_history(self, history: List[Dict]) -> str:
        """Format the last 6 messages as ``role: content`` lines."""
        formatted = []
        for msg in history[-6:]:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted.append(f"{role}: {content}")
        return "\n".join(formatted)

    @staticmethod
    def _preview(text: str, limit: int = 200) -> str:
        """Cut *text* to *limit* characters, appending "..." only if it was cut."""
        return text if len(text) <= limit else text[:limit] + "..."
class MultiQueryRAG(RAGPipeline):
    """Multi-query RAG.

    Retrieves with the query plus several LLM-generated rephrasings, then
    fuses the result lists with Reciprocal Rank Fusion (RRF).
    """

    def __init__(self, *args, num_queries: int = 3, **kwargs):
        """
        Args:
            num_queries: Maximum number of generated variants in addition to
                the original query. Other arguments go to RAGPipeline.
        """
        super().__init__(*args, **kwargs)
        self.num_queries = num_queries

    async def _retrieve(
        self,
        query: str,
        filters: Dict = None
    ) -> List[SearchResult]:
        """Retrieve for every query variant and fuse the result lists.

        NOTE: ``_single_retrieve`` calls the synchronous retriever without
        awaiting anything, so the variants are searched one after another even
        though they are gathered. Offloading to threads would require a
        thread-safe retriever (GPU FAISS indexes, for example, are not).
        """
        queries = await self._generate_query_variants(query)
        results_list = await asyncio.gather(
            *(self._single_retrieve(q, filters) for q in queries)
        )
        return self._fuse_results(results_list)

    async def _generate_query_variants(self, query: str) -> List[str]:
        """Return *query* followed by up to ``num_queries`` distinct LLM rephrasings.

        Leading list markers such as ``1.``, ``2)``, ``3、``, ``-`` or ``*``
        are stripped from each generated line. Empty lines and
        (case-insensitive) duplicates of earlier variants are skipped.
        """
        import re

        prompt = f"""Generate {self.num_queries} different versions of the following question.
Each version should capture a different aspect or phrasing.
Output only the questions, one per line.
Original question: {query}
Versions:"""
        response = await self.llm.complete(prompt)

        variants = [query]  # always keep the original query
        seen = {query.strip().lower()}
        for line in response.strip().split("\n"):
            text = re.sub(r"^\s*(?:\d+\s*[.)、:]|[-*•])\s*", "", line).strip()
            key = text.lower()
            if text and key not in seen:
                seen.add(key)
                variants.append(text)
        return variants[:self.num_queries + 1]

    async def _single_retrieve(
        self,
        query: str,
        filters: Dict
    ) -> List[SearchResult]:
        """Run one synchronous retriever search for *query*."""
        return self.retriever.search(
            query=query,
            top_k=self.top_k_retrieval,
            filters=filters
        )

    def _fuse_results(
        self,
        results_list: List[List[SearchResult]]
    ) -> List[SearchResult]:
        """Fuse several ranked lists with unweighted RRF (k = 60).

        Returns every distinct chunk, best fused score first, with ``score``
        set to that fused score. The list is not truncated; the reranker picks
        the final top-k.
        """
        from collections import defaultdict

        fused_scores = defaultdict(float)
        chunk_data = {}
        k = 60
        for results in results_list:
            for rank, result in enumerate(results):
                fused_scores[result.chunk_id] += 1.0 / (k + rank + 1)
                chunk_data[result.chunk_id] = result

        final_results = []
        for chunk_id in sorted(fused_scores, key=fused_scores.__getitem__, reverse=True):
            result = chunk_data[chunk_id]
            result.score = fused_scores[chunk_id]
            final_results.append(result)
        return final_results
RAG 评估体系
评估指标
"""
RAG 评估系统
"""
from typing import List, Dict, Tuple
from dataclasses import dataclass
import numpy as np
@dataclass
class RAGEvalResult:
    """Aggregated RAG evaluation metrics (averages over the evaluated queries)."""

    # Retrieval quality.
    retrieval_precision: float
    retrieval_recall: float
    retrieval_mrr: float
    retrieval_ndcg: float
    # Generation quality.
    answer_relevance: float
    answer_faithfulness: float
    answer_completeness: float
    # Weighted combination of the above.
    overall_score: float
class RAGEvaluator:
    """Evaluate the retrieval and generation quality of a RAG system.

    Generation metrics are judged by ``llm_judge`` when one is provided.
    Otherwise embedding similarity and a word-overlap heuristic are used.
    """

    def __init__(self, llm_judge=None, embedding_model=None):
        """
        Args:
            llm_judge: Optional judge whose synchronous ``complete(prompt)``
                returns the reply text.
            embedding_model: Optional model with ``encode(text)`` returning a
                vector; without it, similarity-based metrics default to 0.5.
        """
        self.llm_judge = llm_judge
        self.embedding_model = embedding_model

    def evaluate(
        self,
        queries: List[str],
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> RAGEvalResult:
        """Evaluate aligned lists of queries, pipeline responses and ground truths.

        Each ground-truth dict may provide ``"relevant_chunk_ids"`` (used by
        the retrieval metrics) and ``"answer"`` (used for completeness).

        The overall score is ``0.3 * MRR + 0.3 * relevance + 0.4 * faithfulness``.
        Completeness is reported but not included in it.
        """
        retrieval_metrics = self._evaluate_retrieval(responses, ground_truths)
        generation_metrics = self._evaluate_generation(queries, responses, ground_truths)

        overall = (
            retrieval_metrics["mrr"] * 0.3 +
            generation_metrics["relevance"] * 0.3 +
            generation_metrics["faithfulness"] * 0.4
        )

        return RAGEvalResult(
            retrieval_precision=retrieval_metrics["precision"],
            retrieval_recall=retrieval_metrics["recall"],
            retrieval_mrr=retrieval_metrics["mrr"],
            retrieval_ndcg=retrieval_metrics["ndcg"],
            answer_relevance=generation_metrics["relevance"],
            answer_faithfulness=generation_metrics["faithfulness"],
            answer_completeness=generation_metrics["completeness"],
            overall_score=overall
        )

    def _evaluate_retrieval(
        self,
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> Dict[str, float]:
        """Average Precision, Recall, MRR and NDCG@10 of ``reranked_chunks``.

        Queries without labelled relevant chunks are skipped. A metric is 0 if
        no query was evaluated.
        """
        precisions, recalls, mrrs, ndcgs = [], [], [], []

        for response, gt in zip(responses, ground_truths):
            gt_chunk_ids = set(gt.get("relevant_chunk_ids", []))
            if not gt_chunk_ids:
                continue
            retrieved_ids = [c.chunk_id for c in response.reranked_chunks]

            hits = sum(1 for id_ in retrieved_ids if id_ in gt_chunk_ids)
            precisions.append(hits / len(retrieved_ids) if retrieved_ids else 0)
            recalls.append(hits / len(gt_chunk_ids))
            # Reciprocal rank of the first relevant hit (0 if there is none).
            mrrs.append(next(
                (1.0 / rank for rank, id_ in enumerate(retrieved_ids, start=1)
                 if id_ in gt_chunk_ids),
                0
            ))
            ndcgs.append(self._compute_ndcg(retrieved_ids, gt_chunk_ids))

        return {
            "precision": np.mean(precisions) if precisions else 0,
            "recall": np.mean(recalls) if recalls else 0,
            "mrr": np.mean(mrrs) if mrrs else 0,
            "ndcg": np.mean(ndcgs) if ndcgs else 0
        }

    def _evaluate_generation(
        self,
        queries: List[str],
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> Dict[str, float]:
        """Average answer relevance, faithfulness and completeness.

        With ``llm_judge``, all three are judged by the LLM. Otherwise:
        relevance and completeness use embedding similarity (completeness is
        0 without a reference answer), and faithfulness uses a word-overlap
        heuristic.
        """
        relevance_scores = []
        faithfulness_scores = []
        completeness_scores = []

        for query, response, gt in zip(queries, responses, ground_truths):
            gt_answer = gt.get("answer", "")
            contexts = [c.content for c in response.reranked_chunks]

            if self.llm_judge:
                relevance = self._judge_relevance(query, response.answer)
                faithfulness = self._judge_faithfulness(response.answer, contexts)
                completeness = self._judge_completeness(query, response.answer, gt_answer)
            else:
                relevance = self._embedding_similarity(query, response.answer)
                faithfulness = self._compute_faithfulness_heuristic(response.answer, contexts)
                completeness = self._embedding_similarity(gt_answer, response.answer) if gt_answer else 0

            relevance_scores.append(relevance)
            faithfulness_scores.append(faithfulness)
            completeness_scores.append(completeness)

        return {
            "relevance": np.mean(relevance_scores) if relevance_scores else 0,
            "faithfulness": np.mean(faithfulness_scores) if faithfulness_scores else 0,
            "completeness": np.mean(completeness_scores) if completeness_scores else 0
        }

    def _compute_ndcg(self, retrieved: List[str], relevant: set, k: int = 10) -> float:
        """Return binary-relevance NDCG@k of *retrieved* against the *relevant* id set."""
        dcg = 0
        for i, id_ in enumerate(retrieved[:k]):
            rel = 1 if id_ in relevant else 0
            dcg += rel / np.log2(i + 2)

        # Ideal DCG: every relevant item ranked first.
        idcg = sum(1 / np.log2(i + 2) for i in range(min(len(relevant), k)))
        return dcg / idcg if idcg > 0 else 0

    @staticmethod
    def _parse_judge_score(response) -> float:
        """Convert a judge reply on a 0-10 scale into a score in [0, 1].

        Accepts a bare number ("7", "7.5") or a number inside text
        ("Score: 7", "7/10"); an echoed "(0-10)" range label is ignored.
        Returns 0.5 when no finite number can be found.
        """
        import math
        import re

        text = str(response).strip()
        try:
            raw = float(text)
        except ValueError:
            text = re.sub(r"\(?\s*0\s*-\s*10\s*\)?", " ", text)
            match = re.search(r"\d+(?:\.\d+)?", text)
            if match is None:
                return 0.5
            raw = float(match.group())
        if not math.isfinite(raw):
            return 0.5
        return min(max(raw / 10.0, 0.0), 1.0)

    def _judge_relevance(self, query: str, answer: str) -> float:
        """Ask the LLM judge how relevant *answer* is to *query* (0-1)."""
        prompt = f"""Rate the relevance of the following answer to the question on a scale of 0-10.
Question: {query}
Answer: {answer}
Score (0-10):"""
        return self._parse_judge_score(self.llm_judge.complete(prompt))

    def _judge_faithfulness(self, answer: str, contexts: List[str]) -> float:
        """Ask the LLM judge whether *answer* is grounded in *contexts* (0-1).

        Only the first 3000 characters of the joined contexts are shown.
        """
        context_text = "\n".join(contexts)
        prompt = f"""Evaluate if the answer is faithful to the provided context.
Rate on a scale of 0-10, where:
- 0: Answer contains information not in the context (hallucination)
- 10: Answer is completely faithful to the context
Context:
{context_text[:3000]}
Answer: {answer}
Faithfulness Score (0-10):"""
        return self._parse_judge_score(self.llm_judge.complete(prompt))

    def _judge_completeness(self, query: str, answer: str, gt_answer: str) -> float:
        """Ask the LLM judge how completely *answer* covers *gt_answer* (0-1)."""
        prompt = f"""Compare the given answer with the reference answer and rate completeness on 0-10.
Question: {query}
Given Answer: {answer}
Reference Answer: {gt_answer}
Completeness Score (0-10):"""
        return self._parse_judge_score(self.llm_judge.complete(prompt))

    def _embedding_similarity(self, text1: str, text2: str) -> float:
        """Return the cosine similarity of the two texts' embeddings.

        Returns 0.5 without an embedding model, and 0.0 if an embedding is a
        zero vector.
        """
        if not self.embedding_model:
            return 0.5

        emb1 = self.embedding_model.encode(text1)
        emb2 = self.embedding_model.encode(text2)
        denominator = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        if denominator == 0:
            return 0.0
        return float(np.dot(emb1, emb2) / denominator)

    def _compute_faithfulness_heuristic(self, answer: str, contexts: List[str]) -> float:
        """Return the fraction of answer sentences whose words mostly appear in *contexts*.

        Sentences end at ". ! ?" followed by whitespace, or at "。！？". Words
        are lower-cased, whitespace-separated tokens with surrounding
        punctuation stripped. A sentence counts as supported when more than
        half of its words occur in the contexts. Whitespace tokenisation makes
        this unreliable for unsegmented text such as Chinese. Returns 0 for an
        answer without any sentence.
        """
        import re
        import string

        strip_chars = string.punctuation + "，。！？；：、“”‘’（）《》"

        def to_words(text: str) -> set:
            return {w.strip(strip_chars) for w in text.lower().split()} - {""}

        context_words = to_words(" ".join(contexts))  # built once, not per sentence
        sentences = [
            s for s in re.split(r"(?<=[.!?])\s+|(?<=[。！？])", answer) if s.strip()
        ]
        if not sentences:
            return 0

        supported = 0
        for sentence in sentences:
            words = to_words(sentence)
            if words and len(words & context_words) / len(words) > 0.5:
                supported += 1
        return supported / len(sentences)
生产部署最佳实践
系统架构
# rag-deployment.yaml
# RAG API: 3 replicas behind a ClusterIP Service, plus a single-node Milvus
# StatefulSet and a Service exposing it.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: rag-service
spec:
  replicas: 3
  selector:
    matchLabels:
      app: rag-service
  template:
    metadata:
      labels:
        app: rag-service
    spec:
      containers:
        - name: rag-api
          # NOTE(review): pin an immutable tag or digest instead of :latest so
          # rollouts and rollbacks are reproducible.
          image: rag-service:latest
          ports:
            - containerPort: 8000
          resources:
            requests:
              memory: "4Gi"
              cpu: "2"
            limits:
              memory: "8Gi"
              cpu: "4"
          env:
            - name: VECTOR_DB_HOST
              # Served by the "milvus-service" Service at the end of this file.
              value: "milvus-service:19530"
            - name: LLM_ENDPOINT
              valueFrom:
                secretKeyRef:
                  name: llm-credentials
                  key: endpoint
            - name: EMBEDDING_MODEL
              value: "sentence-transformers/all-MiniLM-L6-v2"
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /ready
              port: 8000
            initialDelaySeconds: 5
            periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: rag-service
spec:
  selector:
    app: rag-service
  ports:
    - port: 80
      targetPort: 8000
  type: ClusterIP
---
# Vector database
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: milvus
spec:
  # NOTE(review): serviceName refers to a governing (normally headless)
  # Service named "milvus", which this manifest does not define; confirm it
  # exists elsewhere.
  serviceName: milvus
  replicas: 1
  selector:
    matchLabels:
      app: milvus
  template:
    metadata:
      labels:
        app: milvus
    spec:
      containers:
        - name: milvus
          # NOTE(review): a standalone Milvus container presumably also needs a
          # start command and its etcd / object-storage dependencies (or the
          # official Helm chart / operator). The GPU request below presumably
          # needs a GPU build of the image. Confirm both.
          image: milvusdb/milvus:v2.3.0
          ports:
            - containerPort: 19530
          volumeMounts:
            - name: data
              mountPath: /var/lib/milvus
          resources:
            requests:
              memory: "8Gi"
              cpu: "4"
              nvidia.com/gpu: 1
            limits:
              memory: "16Gi"
              nvidia.com/gpu: 1
  volumeClaimTemplates:
    - metadata:
        name: data
      spec:
        accessModes: ["ReadWriteOnce"]
        resources:
          requests:
            storage: 100Gi
---
# Provides the in-cluster address "milvus-service:19530" that VECTOR_DB_HOST
# in the rag-service Deployment points to.
apiVersion: v1
kind: Service
metadata:
  name: milvus-service
spec:
  selector:
    app: milvus
  ports:
    - port: 19530
      targetPort: 19530
  type: ClusterIP
性能优化配置
"""
RAG 性能优化配置
"""
# Cache settings. TTLs are in seconds; every cache uses the Redis backend.
CACHE_CONFIG = dict(
    # Cache of complete query results.
    query_cache=dict(
        enabled=True,
        backend="redis",
        ttl=60 * 60,  # 1 hour
        max_size=10_000,
    ),
    # Cache of computed embeddings.
    embedding_cache=dict(
        enabled=True,
        backend="redis",
        ttl=24 * 60 * 60,  # 1 day
    ),
    # Cache of LLM responses.
    llm_cache=dict(
        enabled=True,
        backend="redis",
        ttl=30 * 60,  # 30 minutes
        # Reuse a cached response for sufficiently similar queries
        # (presumably an embedding-similarity threshold; TODO confirm).
        similarity_threshold=0.95,
    ),
)

# Batch sizes per operation.
BATCH_CONFIG = dict(
    embedding_batch_size=64,
    rerank_batch_size=32,
    index_batch_size=1000,
)

# Concurrency limits and timeouts (seconds).
CONCURRENCY_CONFIG = dict(
    max_concurrent_queries=100,
    retrieval_timeout=5.0,
    llm_timeout=30.0,
    total_timeout=60.0,
)

# Degradation strategy when a stage fails.
FALLBACK_CONFIG = dict(
    on_retrieval_failure="use_llm_only",
    on_rerank_failure="skip_rerank",
    on_llm_failure="return_top_chunks",
    max_retries=3,
)
小结
本章深入讲解了 RAG 系统的核心架构和实现：
- 架构设计：从基础 RAG 到 Advanced RAG 的演进
- 文档处理：多种分割策略（递归字符分割、语义分割）
- 检索引擎：向量检索、BM25 关键词检索，以及基于 RRF 融合的混合检索
- 重排序：Cross-Encoder、Cohere Rerank API、基于 LLM 打分的重排序
- 评估体系：检索指标（Precision、Recall、MRR、NDCG）与生成质量评估（相关性、忠实度、完整性）
- 生产部署：Kubernetes 部署、性能优化配置
下一章我们将探讨 Agent 架构,讲解如何构建具备规划和工具调用能力的智能代理。