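RAG System Architecture

Overview

RAG (Retrieval-Augmented Generation) is currently the dominant architectural pattern for LLM applications: it retrieves from an external knowledge base to improve the accuracy and freshness of the LLM's answers. This chapter covers the architecture of a RAG system, the implementation of its core components, and performance optimization.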

Evolution of RAG Architectures

Basic Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                           RAG System Architecture                           │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│   User query                                                                │
│      │                                                                      │
│      ▼                                                                      │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                        Query Processing                              │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐                  │   │
│  │  │ Query       │  │ Query       │  │ Query       │                  │   │
│  │  │ Rewriting   │──│ Expansion   │──│ Embedding   │                  │   │
│  │  └─────────────┘  └─────────────┘  └─────────────┘                  │   │
│  └──────────────────────────────┬──────────────────────────────────────┘   │
│                                 │                                          │
│                                 ▼                                          │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                         Retrieval Layer                              │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐                  │   │
│  │  │ Vector      │  │ Keyword     │  │ Knowledge   │                  │   │
│  │  │ Search      │  │ Search      │  │ Graph       │                  │   │
│  │  │ (Dense)     │  │ (Sparse)    │  │ (Structured)│                  │   │
│  │  └──────┬──────┘  └──────┬──────┘  └──────┬──────┘                  │   │
│  │         └────────────────┼────────────────┘                         │   │
│  │                          ▼                                          │   │
│  │                  ┌─────────────┐                                    │   │
│  │                  │ Reranker    │                                    │   │
│  │                  └─────────────┘                                    │   │
│  └──────────────────────────┬──────────────────────────────────────────┘   │
│                             │                                              │
│                             ▼                                              │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                      Generation Layer                                │   │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐                  │   │
│  │  │ Context     │  │ Prompt      │  │ LLM         │                  │   │
│  │  │ Compression │──│ Assembly    │──│ Generation  │                  │   │
│  │  └─────────────┘  └─────────────┘  └─────────────┘                  │   │
│  └──────────────────────────┬──────────────────────────────────────────┘   │
│                             │                                              │
│                             ▼                                              │
│                        Generated answer                                     │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
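
The three layers in the diagram can be read as three functions chained together. The following is a minimal sketch under loud assumptions: the corpus, the function names (`process_query`, `retrieve`, `assemble_prompt`), and the token-overlap scoring are all hypothetical stand-ins for the real embedding, retrieval, and LLM components, and no model is called.

```python
from typing import Dict, List, Tuple

# Hypothetical three-document corpus, for illustration only
CORPUS: Dict[str, str] = {
    "doc1": "FAISS is a library for vector similarity search.",
    "doc2": "BM25 is a keyword-based ranking function.",
    "doc3": "Kubernetes schedules containers across a cluster.",
}


def process_query(text: str) -> List[str]:
    """Query processing: lowercase and tokenize (stands in for rewriting, expansion, embedding)."""
    tokens = (word.strip("?.,!").lower() for word in text.split())
    return [t for t in tokens if t]


def retrieve(tokens: List[str], top_k: int = 2) -> List[Tuple[str, int]]:
    """Retrieval: rank documents by how many distinct query tokens they contain."""
    query_set = set(tokens)
    scored = []
    for doc_id, text in CORPUS.items():
        overlap = len(query_set & set(process_query(text)))
        if overlap > 0:
            scored.append((doc_id, overlap))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


def assemble_prompt(question: str, hits: List[Tuple[str, int]]) -> str:
    """Generation (first half): build the prompt that would be sent to the LLM."""
    context = "\n".join(
        f"[Source {i + 1}] {CORPUS[doc_id]}" for i, (doc_id, _) in enumerate(hits)
    )
    return f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"


question = "What is BM25?"
hits = retrieve(process_query(question))
prompt = assemble_prompt(question, hits)
print(prompt)
```

Because the toy scorer counts stop words such as "is", `doc1` is also returned, ranked below `doc2`. Filtering out this kind of weak match is what the Reranker box in the diagram is for.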

Advanced RAG Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                        Advanced RAG Pipeline                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                      Pre-Retrieval                                   │   │
│  │                                                                      │   │
│  │   Query ──▶ [HyDE] ──▶ [Query Decomposition] ──▶ [Query Routing]    │   │
│  │                                                                      │   │
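│  │   HyDE: generate a hypothetical document to use as the query         │   │
│  │   Query Decomposition: split a complex question into sub-questions   │   │
│  │   Query Routing: choose the best retrieval strategy                  │   │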
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                 │                                          │
│                                 ▼                                          │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                      Retrieval                                       │   │
│  │                                                                      │   │
│  │   ┌──────────────┐    ┌──────────────┐    ┌──────────────┐          │   │
│  │   │ Multi-Index  │    │ Recursive    │    │ Self-Query   │          │   │
│  │   │ Retrieval    │    │ Retrieval    │    │ Retrieval    │          │   │
│  │   └──────────────┘    └──────────────┘    └──────────────┘          │   │
│  │           │                   │                   │                  │   │
│  │           └───────────────────┼───────────────────┘                  │   │
│  │                               ▼                                      │   │
│  │                    ┌──────────────┐                                  │   │
│  │                    │   Fusion     │                                  │   │
│  │                    └──────────────┘                                  │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                 │                                          │
│                                 ▼                                          │
│  ┌─────────────────────────────────────────────────────────────────────┐   │
│  │                      Post-Retrieval                                  │   │
│  │                                                                      │   │
│  │   Documents ──▶ [Reranking] ──▶ [Filtering] ──▶ [Compression]       │   │
│  │                                                                      │   │
│  │   Reranking: cross-encoder reranking                                 │   │
│  │   Filtering: relevance/quality filtering                             │   │
│  │   Compression: compress/summarize long documents                     │   │
│  └─────────────────────────────────────────────────────────────────────┘   │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
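
Of the pre-retrieval steps, Query Routing is the easiest to illustrate without an LLM. The sketch below is one possible rule-based router, not the method used in this chapter: the `Route` enum, the regular expression, and the relation phrases are hypothetical heuristics chosen for the example. Production routers often use a classifier or an LLM call instead.

```python
import re
from enum import Enum


class Route(Enum):
    KEYWORD = "keyword"          # sparse / BM25 search
    GRAPH = "knowledge_graph"    # structured lookup
    VECTOR = "vector"            # dense embedding search


# Hypothetical heuristics: a quoted phrase or a ticket-style ID suggests exact matching
_EXACT_TERM = re.compile(r'"[^"]+"|\b[A-Z]{2,}-\d+\b')
# Hypothetical phrases hinting at an entity-relationship question
_RELATION_PHRASES = ("depends on", "related to", "relationship between")


def route_query(query: str) -> Route:
    """Pick a retrieval strategy for the query using simple surface rules."""
    if _EXACT_TERM.search(query):
        return Route.KEYWORD
    lowered = query.lower()
    if any(phrase in lowered for phrase in _RELATION_PHRASES):
        return Route.GRAPH
    return Route.VECTOR


for q in (
    'Where is "connection reset by peer" logged?',
    "What depends on the auth service?",
    "How do I speed up model training?",
):
    print(f"{route_query(q).value:16} <- {q}")
```

The rules are checked in priority order, so a query with both a quoted phrase and a relation phrase goes to keyword search. Whether that priority is right depends on the corpus.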

Core Component Implementations

Document Processing Pipeline

"""
Document processing pipeline implementation
"""
from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator
from abc import ABC, abstractmethod
import hashlib
import re

@dataclass
class Document:
    """Document data structure"""
    id: str
    content: str
    metadata: Dict
    embedding: Optional[List[float]] = None

    @staticmethod
    def generate_id(content: str) -> str:
        return hashlib.md5(content.encode()).hexdigest()

@dataclass
class Chunk:
    """A chunk of a document"""
    id: str
    doc_id: str
    content: str
    metadata: Dict
    embedding: Optional[List[float]] = None
    start_index: int = 0
    end_index: int = 0


class TextSplitter(ABC):
    """Base class for text splitters"""

    @abstractmethod
    def split(self, text: str) -> List[str]:
        pass


class RecursiveCharacterSplitter(TextSplitter):
    """Recursive character splitter"""

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: List[str] = None
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", "。", ".", " ", ""]

    def split(self, text: str) -> List[str]:
        # Split recursively first, then add overlap exactly once at the top level
        # (adding it inside the recursion would apply overlap twice to nested chunks)
        return self._add_overlap(self._split_text(text, self.separators))

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Recursively split text into chunks of at most chunk_size characters (before overlap)"""
        # Pick the first separator that occurs in the text; "" means split per character
        separator = separators[-1]
        new_separators = []

        for i, sep in enumerate(separators):
            if sep == "":
                separator = sep
                break
            if sep in text:
                separator = sep
                new_separators = separators[i + 1:]
                break

        # Split on the chosen separator
        splits = text.split(separator) if separator else list(text)

        # Merge small pieces back together, up to chunk_size
        good_splits = []
        current_chunk = ""

        for split in splits:
            if len(current_chunk) + len(split) + len(separator) <= self.chunk_size:
                current_chunk += (separator if current_chunk else "") + split
            else:
                if current_chunk:
                    good_splits.append(current_chunk)

                # A single piece that is still too large is split again with the finer separators
                if len(split) > self.chunk_size and new_separators:
                    good_splits.extend(self._split_text(split, new_separators))
                    current_chunk = ""
                else:
                    current_chunk = split

        if current_chunk:
            good_splits.append(current_chunk)

        return good_splits

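    def _add_overlap(self, chunks: List[str]) -> List[str]:
        """Prefix each chunk with the tail of the previous one (a chunk may exceed chunk_size by up to chunk_overlap)"""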
        if not chunks or self.chunk_overlap == 0:
            return chunks

        result = [chunks[0]]

        for i in range(1, len(chunks)):
            # Take the overlap from the end of the previous chunk
            prev_chunk = chunks[i - 1]
            overlap_text = prev_chunk[-self.chunk_overlap:] if len(prev_chunk) > self.chunk_overlap else prev_chunk

            result.append(overlap_text + chunks[i])

        return result


class SemanticSplitter(TextSplitter):
    """Semantic splitter: starts a new chunk where sentence-embedding similarity drops"""

    def __init__(
        self,
        embedding_model,
        threshold: float = 0.5,
        min_chunk_size: int = 100,
        max_chunk_size: int = 2000
    ):
        self.embedding_model = embedding_model
        self.threshold = threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

    def split(self, text: str) -> List[str]:
        # Split into sentences first
        sentences = self._split_sentences(text)

        if len(sentences) <= 1:
            return [text]

        # Embed every sentence (assumes encode() returns NumPy arrays)
        embeddings = self.embedding_model.encode(sentences)

        # Group consecutive sentences by similarity to the running chunk embedding
        chunks = []
        current_chunk = [sentences[0]]
        current_embedding = embeddings[0]

        for i in range(1, len(sentences)):
            similarity = self._cosine_similarity(current_embedding, embeddings[i])
            current_length = sum(len(s) for s in current_chunk)

            # Decide whether to start a new chunk
            if similarity < self.threshold and current_length >= self.min_chunk_size:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sentences[i]]
                current_embedding = embeddings[i]
            elif current_length >= self.max_chunk_size:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sentences[i]]
                current_embedding = embeddings[i]
            else:
                current_chunk.append(sentences[i])
                # Update the running mean embedding; n already counts the new sentence
                n = len(current_chunk)
                current_embedding = (current_embedding * (n - 1) + embeddings[i]) / n

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        # Split after Chinese or English sentence-ending punctuation
        pattern = r'(?<=[。！？.!?])\s*'
        sentences = re.split(pattern, text)
        return [s.strip() for s in sentences if s.strip()]

    def _cosine_similarity(self, a, b) -> float:
        """Compute cosine similarity"""
        import numpy as np
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


class DocumentProcessor:
    """Document processor"""

    def __init__(
        self,
        splitter: TextSplitter,
        embedding_model,
        metadata_extractors: List = None
    ):
        self.splitter = splitter
        self.embedding_model = embedding_model
        self.metadata_extractors = metadata_extractors or []

    def process(self, documents: List[Document]) -> List[Chunk]:
        """Process a list of documents"""
        all_chunks = []

        for doc in documents:
            chunks = self._process_document(doc)
            all_chunks.extend(chunks)

        # Compute embeddings for all chunks in batches
        if all_chunks:
            contents = [c.content for c in all_chunks]
            embeddings = self.embedding_model.encode(contents, batch_size=32)

            for chunk, embedding in zip(all_chunks, embeddings):
                chunk.embedding = embedding.tolist()

        return all_chunks

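    def _process_document(self, doc: Document) -> List[Chunk]:
        """Process a single document"""
        # Split the text
        texts = self.splitter.split(doc.content)

        chunks = []
        # Offsets are approximate: cumulative chunk lengths ignore overlap and dropped separators
        current_index = 0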

        for i, text in enumerate(texts):
            chunk = Chunk(
                id=f"{doc.id}_{i}",
                doc_id=doc.id,
                content=text,
                metadata={
                    **doc.metadata,
                    "chunk_index": i,
                    "total_chunks": len(texts)
                },
                start_index=current_index,
                end_index=current_index + len(text)
            )

            # Extract additional metadata
            for extractor in self.metadata_extractors:
                chunk.metadata.update(extractor.extract(text))

            chunks.append(chunk)
            current_index += len(text)

        return chunks
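
The `start_index`/`end_index` values in `_process_document` are cumulative chunk lengths, so they drift from the true positions once overlap is prepended or separators are dropped. As a point of comparison, here is a minimal sketch of a fixed-size sliding-window chunker that keeps exact offsets. It is a simplification (no recursive separators), and `sliding_window_chunks` is a hypothetical helper, not part of the pipeline above.

```python
from typing import List, Tuple


def sliding_window_chunks(text: str, chunk_size: int, overlap: int) -> List[Tuple[int, int, str]]:
    """Return (start, end, chunk) triples; consecutive chunks share `overlap` characters."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")

    step = chunk_size - overlap  # how far the window advances each time
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append((start, end, text[start:end]))
        if end == len(text):
            break  # the last window already reaches the end of the text
        start += step
    return chunks


for start, end, chunk in sliding_window_chunks("abcdefghij", chunk_size=4, overlap=1):
    print(start, end, chunk)
```

Each triple satisfies `text[start:end] == chunk`, so a retrieved chunk can be mapped back to its exact span in the source document, for example to highlight citations.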

Vector Retrieval Engine

"""
Vector retrieval engine implementation
"""
from typing import List, Tuple, Optional, Dict, Any
from abc import ABC, abstractmethod
import numpy as np
from dataclasses import dataclass
import threading
from collections import defaultdict

@dataclass
class SearchResult:
    """A retrieval result"""
    chunk_id: str
    content: str
    score: float
    metadata: Dict[str, Any]


class VectorStore(ABC):
    """Base class for vector stores"""

    @abstractmethod
    def add(self, ids: List[str], embeddings: List[List[float]],
            contents: List[str], metadatas: List[Dict]) -> None:
        pass

    @abstractmethod
    def search(self, query_embedding: List[float], top_k: int = 10,
               filters: Dict = None) -> List[SearchResult]:
        pass

    @abstractmethod
    def delete(self, ids: List[str]) -> None:
        pass


class FAISSVectorStore(VectorStore):
    """FAISS-backed vector store"""

    def __init__(
        self,
        dimension: int,
        index_type: str = "IVF",
        nlist: int = 100,
        use_gpu: bool = False
    ):
        import faiss

        self.dimension = dimension
        self.use_gpu = use_gpu

        # Create the index
        if index_type == "Flat":
            self.index = faiss.IndexFlatIP(dimension)  # inner product (requires L2-normalized vectors)
        elif index_type == "IVF":
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)
        elif index_type == "HNSW":
            self.index = faiss.IndexHNSWFlat(dimension, 32, faiss.METRIC_INNER_PRODUCT)
        else:
            raise ValueError(f"Unknown index type: {index_type}")

        # GPU acceleration
        if use_gpu:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

        # Metadata storage
        self.id_to_idx: Dict[str, int] = {}
        self.idx_to_data: Dict[int, Dict] = {}
        self.current_idx = 0
        self.lock = threading.Lock()

    def add(self, ids: List[str], embeddings: List[List[float]],
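            contents: List[str], metadatas: List[Dict]) -> None:
        """Add vectors"""
        import faiss

        vectors = np.array(embeddings, dtype=np.float32)

        # L2-normalize so that inner product equals cosine similarity
        faiss.normalize_L2(vectors)

        with self.lock:
            # Train the index if needed (IVF trains on the first batch, which must contain at least nlist vectors)
            if hasattr(self.index, 'is_trained') and not self.index.is_trained:
                self.index.train(vectors)

            # Add the vectors
            self.index.add(vectors)

            # Store the metadata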
            for i, (id_, content, metadata) in enumerate(zip(ids, contents, metadatas)):
                idx = self.current_idx + i
                self.id_to_idx[id_] = idx
                self.idx_to_data[idx] = {
                    "id": id_,
                    "content": content,
                    "metadata": metadata
                }

            self.current_idx += len(ids)

    def search(self, query_embedding: List[float], top_k: int = 10,
               filters: Dict = None) -> List[SearchResult]:
        """Search for similar vectors"""
        import faiss

        query_vector = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_vector)

        # With filters, over-fetch and filter afterwards
        search_k = top_k * 10 if filters else top_k

        # Search
        scores, indices = self.index.search(query_vector, search_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:
                continue

            data = self.idx_to_data.get(int(idx))
            if not data:
                continue

            # Apply metadata filters
            if filters and not self._match_filters(data["metadata"], filters):
                continue

            results.append(SearchResult(
                chunk_id=data["id"],
                content=data["content"],
                score=float(score),
                metadata=data["metadata"]
            ))

            if len(results) >= top_k:
                break

        return results

    def delete(self, ids: List[str]) -> None:
        """Soft delete: drop the metadata so search skips the entry; the vector itself stays in the FAISS index"""
        with self.lock:
            for id_ in ids:
                if id_ in self.id_to_idx:
                    idx = self.id_to_idx.pop(id_)
                    self.idx_to_data.pop(idx, None)

    def _match_filters(self, metadata: Dict, filters: Dict) -> bool:
        """Check whether the metadata matches the filter conditions"""
        for key, value in filters.items():
            if key not in metadata:
                return False
            if isinstance(value, list):
                if metadata[key] not in value:
                    return False
            elif metadata[key] != value:
                return False
        return True


class HybridRetriever:
    """Hybrid retriever combining vector search and keyword search"""

    def __init__(
        self,
        vector_store: VectorStore,
        bm25_index,
        embedding_model,
        vector_weight: float = 0.7
    ):
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.embedding_model = embedding_model
        self.vector_weight = vector_weight
        self.keyword_weight = 1 - vector_weight

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Dict = None
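    ) -> List[SearchResult]:
        """Hybrid search"""
        # Vector search
        query_embedding = self.embedding_model.encode(query)
        vector_results = self.vector_store.search(
            query_embedding.tolist(),
            top_k=top_k * 2,
            filters=filters
        )

        # BM25 search (BM25Index.search takes no filters, so apply them here)
        bm25_results = self.bm25_index.search(query, top_k=top_k * 2)
        if filters:
            bm25_results = [
                r for r in bm25_results
                if all(
                    r.metadata.get(key) in value if isinstance(value, list)
                    else r.metadata.get(key) == value
                    for key, value in filters.items()
                )
            ]

        # Fuse the two rankings (weighted Reciprocal Rank Fusion)
        fused_scores = defaultdict(float)
        chunk_data = {}

        k = 60  # RRF smoothing constant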

        for rank, result in enumerate(vector_results):
            fused_scores[result.chunk_id] += self.vector_weight / (k + rank + 1)
            chunk_data[result.chunk_id] = result

        for rank, result in enumerate(bm25_results):
            fused_scores[result.chunk_id] += self.keyword_weight / (k + rank + 1)
            if result.chunk_id not in chunk_data:
                chunk_data[result.chunk_id] = result

        # Sort by fused score, highest first
        sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)

        results = []
        for chunk_id in sorted_ids[:top_k]:
            result = chunk_data[chunk_id]
            result.score = fused_scores[chunk_id]
            results.append(result)

        return results


class BM25Index:
    """BM25 index"""

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.documents: Dict[str, Dict] = {}
        self.doc_lengths: Dict[str, int] = {}
        self.avg_doc_length = 0
        self.idf_scores: Dict[str, float] = {}
        self.inverted_index: Dict[str, List[str]] = defaultdict(list)

    def add(self, ids: List[str], contents: List[str], metadatas: List[Dict]):
        """Add documents"""
        for id_, content, metadata in zip(ids, contents, metadatas):
            tokens = self._tokenize(content)
            self.documents[id_] = {
                "content": content,
                "metadata": metadata,
                "tokens": tokens
            }
            self.doc_lengths[id_] = len(tokens)

            # Update the inverted index
            for token in set(tokens):
                self.inverted_index[token].append(id_)

        # Refresh corpus statistics (skipped while the index is empty)
        if self.doc_lengths:
            self.avg_doc_length = sum(self.doc_lengths.values()) / len(self.doc_lengths)
            self._compute_idf()

    def search(self, query: str, top_k: int = 10) -> List[SearchResult]:
        """BM25 search"""
        query_tokens = self._tokenize(query)

        # Collect candidate documents that contain at least one query token
        candidate_docs = set()
        for token in query_tokens:
            candidate_docs.update(self.inverted_index.get(token, []))

        # Compute BM25 scores
        scores = []
        for doc_id in candidate_docs:
            score = self._compute_bm25(query_tokens, doc_id)
            scores.append((doc_id, score))

        # Sort by score, highest first
        scores.sort(key=lambda x: x[1], reverse=True)

        results = []
        for doc_id, score in scores[:top_k]:
            doc = self.documents[doc_id]
            results.append(SearchResult(
                chunk_id=doc_id,
                content=doc["content"],
                score=score,
                metadata=doc["metadata"]
            ))

        return results

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize"""
        import jieba
        # Tokenize mixed Chinese/English text
        text = text.lower()
        tokens = list(jieba.cut(text))
        return [t for t in tokens if len(t.strip()) > 0]

    def _compute_idf(self):
        """Compute IDF"""
        import math
        N = len(self.documents)

        for token, doc_ids in self.inverted_index.items():
            df = len(doc_ids)
            self.idf_scores[token] = math.log((N - df + 0.5) / (df + 0.5) + 1)

    def _compute_bm25(self, query_tokens: List[str], doc_id: str) -> float:
        """Compute the BM25 score"""
        doc_tokens = self.documents[doc_id]["tokens"]
        doc_length = self.doc_lengths[doc_id]

        score = 0.0
        token_counts = defaultdict(int)
        for token in doc_tokens:
            token_counts[token] += 1

        for token in query_tokens:
            if token not in self.idf_scores:
                continue

            tf = token_counts.get(token, 0)
            idf = self.idf_scores[token]

            numerator = tf * (self.k1 + 1)
            denominator = tf + self.k1 * (1 - self.b + self.b * doc_length / self.avg_doc_length)

            score += idf * numerator / denominator

        return score
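
The fusion step in `HybridRetriever.search` is a weighted form of Reciprocal Rank Fusion: each ranked list contributes `weight / (k + rank + 1)` per document, with `rank` starting at 0. The sketch below isolates that arithmetic. `weighted_rrf` and the sample rankings are hypothetical names and data for illustration; the 0.7/0.3 weights and `k = 60` mirror the defaults in the listing.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def weighted_rrf(
    rankings: Dict[str, List[str]],
    weights: Dict[str, float],
    k: int = 60,
) -> List[Tuple[str, float]]:
    """Fuse several ranked ID lists; each contributes weight / (k + rank + 1) per document."""
    fused: Dict[str, float] = defaultdict(float)
    for source, ranked_ids in rankings.items():
        for rank, doc_id in enumerate(ranked_ids):  # rank is 0-based
            fused[doc_id] += weights[source] / (k + rank + 1)
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


rankings = {
    "vector": ["a", "b", "c"],  # dense retrieval order
    "bm25": ["b", "d"],         # keyword retrieval order
}
weights = {"vector": 0.7, "bm25": 0.3}

for doc_id, score in weighted_rrf(rankings, weights):
    print(f"{doc_id}: {score:.6f}")
```

Document `b` is ranked second by vector search and first by BM25, so it scores 0.7/62 + 0.3/61 and overtakes `a`, which scores only 0.7/61. Because only ranks are used, the raw cosine and BM25 scores never need to be put on a common scale.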

Reranker

"""
Reranker implementation
"""
from typing import List, Tuple
from abc import ABC, abstractmethod
import numpy as np

class Reranker(ABC):
    """Base class for rerankers"""

    @abstractmethod
    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        """
        Rerank documents
        Returns: [(original index, score), ...]
        """
        pass


class CrossEncoderReranker(Reranker):
    """Cross-encoder reranker"""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
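        # Build (query, document) pairs
        pairs = [[query, doc] for doc in documents]

        # Score each pair with the cross-encoder
        scores = self.model.predict(pairs)

        # Sort by score, highest first; cast NumPy scalars to plain floats
        indexed_scores = [(i, float(s)) for i, s in enumerate(scores)]
        indexed_scores.sort(key=lambda x: x[1], reverse=True)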

        if top_k:
            indexed_scores = indexed_scores[:top_k]

        return indexed_scores


class CohereReranker(Reranker):
    """Cohere Rerank API"""

    def __init__(self, api_key: str, model: str = "rerank-english-v2.0"):
        import cohere
        self.client = cohere.Client(api_key)
        self.model = model

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=documents,
            top_n=top_k or len(documents)
        )

        results = []
        for result in response.results:
            results.append((result.index, result.relevance_score))

        return results


class LLMReranker(Reranker):
    """Reranker that uses an LLM"""

    def __init__(self, llm_client, batch_size: int = 5):
        self.llm = llm_client
        self.batch_size = batch_size

    def rerank(self, query: str, documents: List[str], top_k: int = None) -> List[Tuple[int, float]]:
        prompt_template = """Given the following query and documents, rate the relevance of each document to the query on a scale of 0-10.

Query: {query}

Documents:
{documents}

For each document, provide a relevance score (0-10). Output format:
Document 1: [score]
Document 2: [score]
...

Only output the scores, nothing else."""

        all_scores = []

        # Process in batches
        for i in range(0, len(documents), self.batch_size):
            batch = documents[i:i + self.batch_size]

            doc_text = "\n".join([f"Document {j+1}: {doc[:500]}..."
                                  for j, doc in enumerate(batch)])

            prompt = prompt_template.format(query=query, documents=doc_text)

            response = self.llm.complete(prompt)
            batch_scores = self._parse_scores(response, len(batch))

            for j, score in enumerate(batch_scores):
                all_scores.append((i + j, score / 10.0))  # normalize to 0-1

        # Sort by score, highest first
        all_scores.sort(key=lambda x: x[1], reverse=True)

        if top_k:
            all_scores = all_scores[:top_k]

        return all_scores

    def _parse_scores(self, response: str, expected_count: int) -> List[float]:
        """Parse the scores returned by the LLM, keyed by document number"""
        import re

        default_score = 5.0  # neutral fallback when a score is missing
        scores = [default_score] * expected_count

        # Accept both "Document 1: 8" and "Document 1: [8]"
        pattern = r'Document\s*(\d+):\s*\[?(\d+(?:\.\d+)?)\]?'
        for doc_num, value in re.findall(pattern, response):
            idx = int(doc_num) - 1
            if 0 <= idx < expected_count:
                # Clamp to the 0-10 scale requested in the prompt
                scores[idx] = min(float(value), 10.0)

        return scores
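
All three rerankers share one contract: `rerank(query, documents, top_k)` returns `(original index, score)` pairs sorted by descending score. When testing a pipeline it can help to have a stand-in that honors this contract without downloading a model or calling an API. The class below is such a sketch. `TokenOverlapReranker` is a hypothetical name, and its whitespace-token overlap score is a toy measure, not a substitute for a cross-encoder.

```python
from typing import List, Optional, Tuple


class TokenOverlapReranker:
    """Toy reranker: scores a document by the fraction of query tokens it contains."""

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int] = None,
    ) -> List[Tuple[int, float]]:
        query_tokens = set(query.lower().split())
        scored = []
        for idx, doc in enumerate(documents):
            doc_tokens = set(doc.lower().split())
            # Fraction of query tokens present in the document (0.0 for an empty query)
            score = len(query_tokens & doc_tokens) / len(query_tokens) if query_tokens else 0.0
            scored.append((idx, score))

        # Highest score first; Python's sort is stable, so ties keep retrieval order
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k] if top_k else scored


docs = ["kubernetes gpu scheduling guide", "cooking recipes", "gpu drivers"]
ranked = TokenOverlapReranker().rerank("gpu scheduling", docs, top_k=2)
print([(docs[i], score) for i, score in ranked])
```

The returned indices point back into the input list, which is how a caller maps reranked scores onto the original retrieved chunks.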

RAG Main Pipeline

"""
RAG main pipeline implementation
"""
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
import asyncio

@dataclass
class RAGResponse:
    """RAG response"""
    answer: str
    sources: List[Dict]
    query: str
    retrieved_chunks: List[SearchResult]
    reranked_chunks: List[SearchResult]
    metadata: Dict[str, Any]


class RAGPipeline:
    """RAG pipeline"""

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: Reranker,
        llm,
        embedding_model,
        config: Dict = None
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.llm = llm
        self.embedding_model = embedding_model
        self.config = config or {}

        # Configuration
        self.top_k_retrieval = self.config.get("top_k_retrieval", 20)
        self.top_k_rerank = self.config.get("top_k_rerank", 5)
        self.max_context_length = self.config.get("max_context_length", 4000)
        self.enable_hyde = self.config.get("enable_hyde", False)
        self.enable_query_expansion = self.config.get("enable_query_expansion", False)

    async def query(
        self,
        query: str,
        filters: Dict = None,
        chat_history: List[Dict] = None
    ) -> RAGResponse:
        """Run a RAG query"""
        metadata = {"original_query": query}

        # 1. Query preprocessing
        processed_query = await self._preprocess_query(query, chat_history)
        metadata["processed_query"] = processed_query

        # 2. Retrieval
        retrieved_chunks = await self._retrieve(processed_query, filters)

        # 3. Reranking
        reranked_chunks = await self._rerank(query, retrieved_chunks)

        # 4. Context construction
        context = self._build_context(reranked_chunks)

        # 5. Answer generation
        answer = await self._generate(query, context, chat_history)

        # 6. Build the response
        sources = [
            {
                "chunk_id": chunk.chunk_id,
                "content": chunk.content[:200] + "...",
                "score": chunk.score,
                "metadata": chunk.metadata
            }
            for chunk in reranked_chunks
        ]

        return RAGResponse(
            answer=answer,
            sources=sources,
            query=query,
            retrieved_chunks=retrieved_chunks,
            reranked_chunks=reranked_chunks,
            metadata=metadata
        )

    async def _preprocess_query(
        self,
        query: str,
        chat_history: List[Dict] = None
    ) -> str:
        """Preprocess the query: rewrite with history first, then optionally expand and apply HyDE"""
        standalone_query = query

        # Rewrite with the chat history first, so later steps work on a self-contained question
        if chat_history:
            rewrite_prompt = f"""Given the conversation history and the latest question,
rewrite the question to be self-contained.

Conversation history:
{self._format_history(chat_history)}

Latest question: {query}

Rewritten question:"""
            standalone_query = (await self.llm.complete(rewrite_prompt)).strip()

        # Pieces of the final retrieval query; if this stays empty the standalone question is used
        retrieval_parts = []

        # Query expansion: keep the question and append alternative phrasings
        if self.enable_query_expansion:
            expansion_prompt = f"""Generate 3 alternative phrasings for the following question.
Output only the questions, one per line.

Original question: {standalone_query}

Alternative questions:"""
            expansions = await self.llm.complete(expansion_prompt)
            retrieval_parts += [standalone_query, expansions.replace("\n", " ")]

        # HyDE: generate a hypothetical answer passage and retrieve with it
        if self.enable_hyde:
            hyde_prompt = f"""Based on the following question, write a short passage that would answer it.

Question: {standalone_query}

Passage:"""
            retrieval_parts.append(await self.llm.complete(hyde_prompt))

        return " ".join(retrieval_parts) if retrieval_parts else standalone_query

    async def _retrieve(
        self,
        query: str,
        filters: Dict = None
    ) -> List[SearchResult]:
        """Retrieve relevant chunks."""
        results = self.retriever.search(
            query=query,
            top_k=self.top_k_retrieval,
            filters=filters
        )
        return results

    async def _rerank(
        self,
        query: str,
        chunks: List[SearchResult]
    ) -> List[SearchResult]:
        """Rerank the retrieved chunks."""
        if not chunks:
            return []

        documents = [chunk.content for chunk in chunks]
        reranked = self.reranker.rerank(
            query=query,
            documents=documents,
            top_k=self.top_k_rerank
        )

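        # Reorder the chunks according to the reranker output
        import copy

        result = []
        for idx, score in reranked:
            # Work on a shallow copy so the retrieval score on the original
            # object (still referenced by retrieved_chunks) is not overwritten
            chunk = copy.copy(chunks[idx])
            chunk.score = score
            result.append(chunk)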

        return result

    def _build_context(self, chunks: List[SearchResult]) -> str:
        """Build the context string, stopping once the character budget is reached."""
        context_parts = []
        total_length = 0

        for i, chunk in enumerate(chunks):
            chunk_text = f"[Source {i+1}]\n{chunk.content}\n"

            if total_length + len(chunk_text) > self.max_context_length:
                break

            context_parts.append(chunk_text)
            total_length += len(chunk_text)

        return "\n".join(context_parts)

    async def _generate(
        self,
        query: str,
        context: str,
        chat_history: List[Dict] = None
    ) -> str:
        """Generate the answer."""
        system_prompt = """You are a helpful assistant. Answer the user's question based on the provided context.
If the context doesn't contain relevant information, say so.
Always cite your sources using [Source N] format."""

        user_prompt = f"""Context:
{context}

Question: {query}

Please provide a comprehensive answer based on the context above."""

        messages = [{"role": "system", "content": system_prompt}]

        if chat_history:
            messages.extend(chat_history)

        messages.append({"role": "user", "content": user_prompt})

        response = await self.llm.chat(messages)

        return response

    def _format_history(self, history: List[Dict]) -> str:
        """Format the conversation history as plain text."""
        formatted = []
        for msg in history[-6:]:  # keep only the 6 most recent messages
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted.append(f"{role}: {content}")
        return "\n".join(formatted)


class MultiQueryRAG(RAGPipeline):
    """Multi-query RAG: retrieves with several query variants concurrently."""

    def __init__(self, *args, num_queries: int = 3, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_queries = num_queries

    async def _retrieve(
        self,
        query: str,
        filters: Dict = None
    ) -> List[SearchResult]:
        """Multi-query retrieval."""
        # Generate several variants of the query
        queries = await self._generate_query_variants(query)

        # Retrieve for all variants concurrently
        tasks = [
            self._single_retrieve(q, filters)
            for q in queries
        ]
        results_list = await asyncio.gather(*tasks)

        # Fuse the per-query rankings
        return self._fuse_results(results_list)

    async def _generate_query_variants(self, query: str) -> List[str]:
        """Generate query variants."""
        import re

        prompt = f"""Generate {self.num_queries} different versions of the following question.
Each version should capture a different aspect or phrasing.
Output only the questions, one per line.

Original question: {query}

Versions:"""

        response = await self.llm.complete(prompt)
        variants = [query]  # always include the original query

        for line in response.strip().split("\n"):
            # Strip list markers such as "1.", "2)" or "-" that LLMs often add,
            # instead of discarding the numbered line altogether
            line = re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line).strip()
            if line and line not in variants:
                variants.append(line)

        return variants[:self.num_queries + 1]

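    async def _single_retrieve(
        self,
        query: str,
        filters: Dict
    ) -> List[SearchResult]:
        # retriever.search is a blocking call, so asyncio.gather alone would run
        # the searches one after another. Running it in a worker thread makes them
        # overlap (asyncio.to_thread needs Python 3.9+ and assumes the retriever
        # is safe to call from several threads).
        return await asyncio.to_thread(
            self.retriever.search,
            query=query,
            top_k=self.top_k_retrieval,
            filters=filters
        )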

    def _fuse_results(
        self,
        results_list: List[List[SearchResult]]
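    ) -> List[SearchResult]:
        """Fuse multiple ranked result lists with Reciprocal Rank Fusion (RRF)."""
        from collections import defaultdict

        fused_scores = defaultdict(float)
        chunk_data = {}
        k = 60  # RRF smoothing constant

        for results in results_list:
            for rank, result in enumerate(results):
                # rank is 0-based, so the 1-based RRF term is 1 / (k + rank + 1)
                fused_scores[result.chunk_id] += 1.0 / (k + rank + 1)
                chunk_data[result.chunk_id] = result

        # Sort by fused score, highest first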
        sorted_ids = sorted(
            fused_scores.keys(),
            key=lambda x: fused_scores[x],
            reverse=True
        )

        final_results = []
        for chunk_id in sorted_ids:
            result = chunk_data[chunk_id]
            result.score = fused_scores[chunk_id]
            final_results.append(result)

        return final_results
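
To make the fusion step concrete, here is a small standalone sketch of Reciprocal Rank Fusion over plain lists of chunk IDs. The function name `rrf_fuse` and the sample IDs are illustrative and not part of the pipeline above; only the constant `k = 60` is taken from `_fuse_results`.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def rrf_fuse(rankings: List[List[str]], k: int = 60) -> List[Tuple[str, float]]:
    """Fuse ranked ID lists with Reciprocal Rank Fusion.

    Every list contributes 1 / (k + rank) per item, with rank starting at 1.
    """
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for position, chunk_id in enumerate(ranking):
            # position is 0-based, hence the + 1
            scores[chunk_id] += 1.0 / (k + position + 1)
    # Highest fused score first; ties are broken by ID for a stable order
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))


# Hypothetical results for three query variants
rankings = [
    ["c1", "c2", "c3"],
    ["c2", "c1", "c4"],
    ["c2", "c5"],
]
fused = rrf_fuse(rankings)
for chunk_id, score in fused:
    print(f"{chunk_id}: {score:.4f}")
```

A chunk that sits near the top of several lists (`c2` here) outranks one that tops only a single list. RRF uses ranks rather than raw scores, so it still works when the scores returned for different queries are not directly comparable.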

RAG Evaluation Framework

Evaluation Metrics

"""
RAG evaluation system
"""
from typing import List, Dict, Tuple
from dataclasses import dataclass
import numpy as np

@dataclass
class RAGEvalResult:
    """RAG evaluation result."""
    retrieval_precision: float
    retrieval_recall: float
    retrieval_mrr: float
    retrieval_ndcg: float
    answer_relevance: float
    answer_faithfulness: float
    answer_completeness: float
    overall_score: float


class RAGEvaluator:
    """RAG evaluator."""

    def __init__(self, llm_judge=None, embedding_model=None):
        self.llm_judge = llm_judge
        self.embedding_model = embedding_model

    def evaluate(
        self,
        queries: List[str],
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> RAGEvalResult:
        """Evaluate the RAG system on a set of queries."""

        retrieval_metrics = self._evaluate_retrieval(responses, ground_truths)
        generation_metrics = self._evaluate_generation(queries, responses, ground_truths)

        overall = (
            retrieval_metrics["mrr"] * 0.3 +
            generation_metrics["relevance"] * 0.3 +
            generation_metrics["faithfulness"] * 0.4
        )

        return RAGEvalResult(
            retrieval_precision=retrieval_metrics["precision"],
            retrieval_recall=retrieval_metrics["recall"],
            retrieval_mrr=retrieval_metrics["mrr"],
            retrieval_ndcg=retrieval_metrics["ndcg"],
            answer_relevance=generation_metrics["relevance"],
            answer_faithfulness=generation_metrics["faithfulness"],
            answer_completeness=generation_metrics["completeness"],
            overall_score=overall
        )

    def _evaluate_retrieval(
        self,
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> Dict[str, float]:
        """Evaluate retrieval quality."""
        precisions = []
        recalls = []
        mrrs = []
        ndcgs = []

        for response, gt in zip(responses, ground_truths):
            gt_chunk_ids = set(gt.get("relevant_chunk_ids", []))
            retrieved_ids = [c.chunk_id for c in response.reranked_chunks]

            if not gt_chunk_ids:
                continue

            # Precision@K
            hits = sum(1 for id_ in retrieved_ids if id_ in gt_chunk_ids)
            precision = hits / len(retrieved_ids) if retrieved_ids else 0
            precisions.append(precision)

            # Recall
            recall = hits / len(gt_chunk_ids) if gt_chunk_ids else 0
            recalls.append(recall)

            # MRR
            mrr = 0
            for i, id_ in enumerate(retrieved_ids):
                if id_ in gt_chunk_ids:
                    mrr = 1.0 / (i + 1)
                    break
            mrrs.append(mrr)

            # NDCG
            ndcg = self._compute_ndcg(retrieved_ids, gt_chunk_ids)
            ndcgs.append(ndcg)

        return {
            "precision": np.mean(precisions) if precisions else 0,
            "recall": np.mean(recalls) if recalls else 0,
            "mrr": np.mean(mrrs) if mrrs else 0,
            "ndcg": np.mean(ndcgs) if ndcgs else 0
        }

    def _evaluate_generation(
        self,
        queries: List[str],
        responses: List[RAGResponse],
        ground_truths: List[Dict]
    ) -> Dict[str, float]:
        """Evaluate generation quality."""
        relevance_scores = []
        faithfulness_scores = []
        completeness_scores = []

        for query, response, gt in zip(queries, responses, ground_truths):
            gt_answer = gt.get("answer", "")

            # Score with an LLM judge when one is configured
            if self.llm_judge:
                relevance = self._judge_relevance(query, response.answer)
                faithfulness = self._judge_faithfulness(
                    response.answer,
                    [c.content for c in response.reranked_chunks]
                )
                completeness = self._judge_completeness(
                    query, response.answer, gt_answer
                )
            else:
                # Otherwise approximate with embedding similarity
                relevance = self._embedding_similarity(query, response.answer)
                faithfulness = self._compute_faithfulness_heuristic(
                    response.answer,
                    [c.content for c in response.reranked_chunks]
                )
                completeness = self._embedding_similarity(gt_answer, response.answer) if gt_answer else 0

            relevance_scores.append(relevance)
            faithfulness_scores.append(faithfulness)
            completeness_scores.append(completeness)

        return {
            "relevance": np.mean(relevance_scores) if relevance_scores else 0,
            "faithfulness": np.mean(faithfulness_scores) if faithfulness_scores else 0,
            "completeness": np.mean(completeness_scores) if completeness_scores else 0
        }

    def _compute_ndcg(self, retrieved: List[str], relevant: set, k: int = 10) -> float:
        """Compute NDCG@K with binary relevance."""
        dcg = 0
        for i, id_ in enumerate(retrieved[:k]):
            rel = 1 if id_ in relevant else 0
            dcg += rel / np.log2(i + 2)

        # Ideal DCG: all relevant items ranked at the top
        idcg = sum(1 / np.log2(i + 2) for i in range(min(len(relevant), k)))

        return dcg / idcg if idcg > 0 else 0

    def _judge_relevance(self, query: str, answer: str) -> float:
        """Judge answer relevance with an LLM."""
        prompt = f"""Rate the relevance of the following answer to the question on a scale of 0-10.

Question: {query}

Answer: {answer}

Score (0-10):"""

        response = self.llm_judge.complete(prompt)
        try:
            score = float(response.strip()) / 10.0
            return min(max(score, 0), 1)
        except (ValueError, AttributeError):
            # Judge output was not a bare number: fall back to a neutral score
            return 0.5

    def _judge_faithfulness(self, answer: str, contexts: List[str]) -> float:
        """Judge faithfulness: whether the answer is grounded in the retrieved content."""
        context_text = "\n".join(contexts)

        prompt = f"""Evaluate if the answer is faithful to the provided context.
Rate on a scale of 0-10, where:
- 0: Answer contains information not in the context (hallucination)
- 10: Answer is completely faithful to the context

Context:
{context_text[:3000]}

Answer: {answer}

Faithfulness Score (0-10):"""

        response = self.llm_judge.complete(prompt)
        try:
            score = float(response.strip()) / 10.0
            return min(max(score, 0), 1)
        except (ValueError, AttributeError):
            # Judge output was not a bare number: fall back to a neutral score
            return 0.5

    def _judge_completeness(self, query: str, answer: str, gt_answer: str) -> float:
        """Judge answer completeness against the reference answer."""
        prompt = f"""Compare the given answer with the reference answer and rate completeness on 0-10.

Question: {query}

Given Answer: {answer}

Reference Answer: {gt_answer}

Completeness Score (0-10):"""

        response = self.llm_judge.complete(prompt)
        try:
            score = float(response.strip()) / 10.0
            return min(max(score, 0), 1)
        except (ValueError, AttributeError):
            # Judge output was not a bare number: fall back to a neutral score
            return 0.5

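    def _embedding_similarity(self, text1: str, text2: str) -> float:
        """Cosine similarity between the embeddings of two texts, clamped to [0, 1]."""
        if not self.embedding_model:
            return 0.5

        emb1 = self.embedding_model.encode(text1)
        emb2 = self.embedding_model.encode(text2)

        norm = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        if norm == 0:
            # A zero vector has no direction; avoid dividing by zero
            return 0.0

        # Cosine similarity lies in [-1, 1]; clamp it to the 0-1 range of the judge scores
        similarity = np.dot(emb1, emb2) / norm
        return float(min(max(similarity, 0.0), 1.0))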

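    def _compute_faithfulness_heuristic(self, answer: str, contexts: List[str]) -> float:
        """Heuristic faithfulness: share of answer sentences supported by the context."""
        # A sentence counts as supported when more than half of its words
        # also appear in the retrieved context
        answer_sentences = [s for s in answer.split(". ") if s.strip()]
        # Build the context vocabulary once instead of once per sentence
        context_words = set(" ".join(contexts).lower().split())

        supported = 0
        for sentence in answer_sentences:
            # Word overlap between the sentence and the context
            words = set(sentence.lower().split())
            overlap = len(words & context_words) / len(words) if words else 0
            if overlap > 0.5:
                supported += 1

        return supported / len(answer_sentences) if answer_sentences else 0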
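
The retrieval metrics are easier to check on a toy ranking. The sketch below recomputes Precision, Recall, reciprocal rank and NDCG@K for a single query with plain functions. The function names and the sample chunk IDs are made up for illustration; the formulas follow `_evaluate_retrieval` and `_compute_ndcg`.

```python
import math
from typing import List, Set, Tuple


def precision_recall(retrieved: List[str], relevant: Set[str]) -> Tuple[float, float]:
    """Precision = hits / retrieved, Recall = hits / relevant."""
    hits = sum(1 for cid in retrieved if cid in relevant)
    precision = hits / len(retrieved) if retrieved else 0.0
    recall = hits / len(relevant) if relevant else 0.0
    return precision, recall


def reciprocal_rank(retrieved: List[str], relevant: Set[str]) -> float:
    """1 / position of the first relevant item (1-based), or 0 if none is found."""
    for position, cid in enumerate(retrieved, start=1):
        if cid in relevant:
            return 1.0 / position
    return 0.0


def ndcg_at_k(retrieved: List[str], relevant: Set[str], k: int = 10) -> float:
    """NDCG@K with binary relevance."""
    dcg = sum(
        1.0 / math.log2(i + 2)
        for i, cid in enumerate(retrieved[:k])
        if cid in relevant
    )
    # Ideal ranking: every relevant item placed at the top
    idcg = sum(1.0 / math.log2(i + 2) for i in range(min(len(relevant), k)))
    return dcg / idcg if idcg > 0 else 0.0


# Hypothetical ranking: relevant chunks show up at positions 2 and 4
retrieved = ["c7", "c2", "c9", "c4"]
relevant = {"c2", "c4", "c5"}

precision, recall = precision_recall(retrieved, relevant)
rr = reciprocal_rank(retrieved, relevant)
ndcg = ndcg_at_k(retrieved, relevant)
print(f"precision={precision:.3f} recall={recall:.3f} rr={rr:.3f} ndcg={ndcg:.3f}")
```

MRR is the mean of the per-query reciprocal ranks, which is what `_evaluate_retrieval` computes with `np.mean(mrrs)`.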

Production Deployment Best Practices

System Architecture

# rag-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: rag-service
spec:
  replicas: 3
  selector:
    matchLabels:
      app: rag-service
  template:
    metadata:
      labels:
        app: rag-service
    spec:
      containers:
      - name: rag-api
        image: rag-service:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "4Gi"
            cpu: "2"
          limits:
            memory: "8Gi"
            cpu: "4"
        env:
        - name: VECTOR_DB_HOST
          value: "milvus-service:19530"
        - name: LLM_ENDPOINT
          valueFrom:
            secretKeyRef:
              name: llm-credentials
              key: endpoint
        - name: EMBEDDING_MODEL
          value: "sentence-transformers/all-MiniLM-L6-v2"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /ready
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: rag-service
spec:
  selector:
    app: rag-service
  ports:
  - port: 80
    targetPort: 8000
  type: ClusterIP
---
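# Vector database
# Note: rag-service above points VECTOR_DB_HOST at "milvus-service", while this
# StatefulSet uses serviceName "milvus"; neither Service is defined in this file.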
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: milvus
spec:
  serviceName: milvus
  replicas: 1
  selector:
    matchLabels:
      app: milvus
  template:
    metadata:
      labels:
        app: milvus
    spec:
      containers:
      - name: milvus
        image: milvusdb/milvus:v2.3.0
        ports:
        - containerPort: 19530
        volumeMounts:
        - name: data
          mountPath: /var/lib/milvus
        resources:
          requests:
            memory: "8Gi"
            cpu: "4"
            nvidia.com/gpu: 1
          limits:
            memory: "16Gi"
            nvidia.com/gpu: 1
  volumeClaimTemplates:
  - metadata:
      name: data
    spec:
      accessModes: ["ReadWriteOnce"]
      resources:
        requests:
          storage: 100Gi
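
The Deployment above probes `/health` and `/ready` on port 8000, so the service has to expose both paths. Below is a minimal sketch of how they could differ, using only the Python standard library. `ServiceState`, `probe_status` and the two readiness flags are hypothetical names chosen for illustration; the real service may well use a web framework instead.

```python
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Tuple


class ServiceState:
    """Hypothetical readiness flags, flipped once the dependencies are usable."""
    vector_db_connected = False
    embedding_model_loaded = False


def probe_status(path: str, state) -> Tuple[int, str]:
    """Map a probe path to an HTTP status code and a short body."""
    if path == "/health":
        # Liveness: the process is up and able to answer at all
        return 200, "ok"
    if path == "/ready":
        # Readiness: accept traffic only once the dependencies are usable
        if state.vector_db_connected and state.embedding_model_loaded:
            return 200, "ready"
        return 503, "not ready"
    return 404, "not found"


class ProbeHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        status, body = probe_status(self.path, ServiceState)
        self.send_response(status)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write(body.encode("utf-8"))


def serve(port: int = 8000) -> None:
    """Blocking entry point; port 8000 matches containerPort in the manifest."""
    HTTPServer(("0.0.0.0", port), ProbeHandler).serve_forever()
```

Keeping the two checks separate matters: a failing liveness probe makes Kubernetes restart the container, while a failing readiness probe only removes the pod from the Service endpoints until it recovers.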

Performance Tuning Configuration

"""
RAG performance tuning configuration
"""

# Cache configuration
CACHE_CONFIG = {
    # Query result cache
    "query_cache": {
        "enabled": True,
        "backend": "redis",
        "ttl": 3600,  # 1 hour
        "max_size": 10000
    },
    # Embedding cache
    "embedding_cache": {
        "enabled": True,
        "backend": "redis",
        "ttl": 86400,  # 1 day
    },
    # LLM response cache
    "llm_cache": {
        "enabled": True,
        "backend": "redis",
        "ttl": 1800,  # 30 minutes
        "similarity_threshold": 0.95  # reuse responses for sufficiently similar queries
    }
}

# Batching configuration
BATCH_CONFIG = {
    "embedding_batch_size": 64,
    "rerank_batch_size": 32,
    "index_batch_size": 1000
}

# Concurrency configuration
CONCURRENCY_CONFIG = {
    "max_concurrent_queries": 100,
    "retrieval_timeout": 5.0,
    "llm_timeout": 30.0,
    "total_timeout": 60.0
}

# Fallback strategy
FALLBACK_CONFIG = {
    "on_retrieval_failure": "use_llm_only",
    "on_rerank_failure": "skip_rerank",
    "on_llm_failure": "return_top_chunks",
    "max_retries": 3
}
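
The timeout and fallback settings only take effect if the request path enforces them. The sketch below shows one way `retrieval_timeout` and `on_retrieval_failure: "use_llm_only"` could be wired together with `asyncio.wait_for`. It assumes the timeout values are in seconds. `slow_retrieve`, `fake_llm` and `answer_with_fallback` are stand-ins invented for this example, and the timeout is shortened so the demo finishes quickly.

```python
import asyncio
from typing import Awaitable, Callable, List

RETRIEVAL_TIMEOUT = 0.05  # seconds; shortened from the configured 5.0 for the demo


async def slow_retrieve(query: str) -> List[str]:
    """Stand-in for a vector search that takes too long."""
    await asyncio.sleep(1.0)
    return ["chunk-a", "chunk-b"]


async def fake_llm(query: str, context: List[str]) -> str:
    """Stand-in for the LLM call; reports how much context it received."""
    return f"answer({query!r}, {len(context)} chunks)"


async def answer_with_fallback(
    query: str,
    retrieve: Callable[[str], Awaitable[List[str]]],
    llm: Callable[[str, List[str]], Awaitable[str]],
    timeout: float,
) -> str:
    try:
        # Cancel the retrieval if it exceeds the timeout
        chunks = await asyncio.wait_for(retrieve(query), timeout=timeout)
    except asyncio.TimeoutError:
        # on_retrieval_failure = "use_llm_only": answer without any context
        chunks = []
    return await llm(query, chunks)


result = asyncio.run(
    answer_with_fallback("what is RAG?", slow_retrieve, fake_llm, RETRIEVAL_TIMEOUT)
)
print(result)
```

With a generous timeout the same call returns an answer built from both chunks. The other fallback entries (`skip_rerank`, `return_top_chunks`) would follow the same try/except pattern around their respective stages.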

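Summary

This chapter covered the core architecture and implementation of a RAG system:

  1. Architecture: the evolution from basic RAG to Advanced RAG
  2. Document processing: multiple chunking strategies (recursive, semantic)
  3. Retrieval engine: vector search, hybrid search, BM25
  4. Reranking: Cross-Encoder, LLM Judge
  5. Evaluation: retrieval metrics and generation quality assessment
  6. Production deployment: Kubernetes deployment and performance tuning

The next chapter turns to Agent architecture and how to build intelligent agents with planning and tool-calling capabilities.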