15-RAG系统设计与实现
RAG基础
什么是RAG(检索增强生成)
RAG(Retrieval-Augmented Generation)是一种将信息检索与大语言模型生成能力结合的技术架构。它在生成答案前先检索相关文档,为LLM提供外部知识,从而突破传统LLM的知识截止时间限制,并显著缓解幻觉问题。
RAG的工作流程:
- 文档准备:将知识库文档切分成块(chunks)
- 向量化:使用Embedding模型将文本块转换为向量
- 存储:将向量存入向量数据库
- 查询:用户提问时,将问题也转换为向量
- 检索:在向量数据库中查找最相似的文档块
- 生成:将检索到的上下文与问题一起发送给LLM生成答案
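下面给出一个极简的端到端示意(假设已安装 sentence-transformers;为突出流程,用内存中的列表代替向量数据库,且只拼装提示词、不真正调用LLM),用来说明上述步骤如何串联:
from sentence_transformers import SentenceTransformer
import numpy as np

# 1-3. 文档切块、向量化并"存储"(此处直接用两句话充当文档块)
chunks = ["RAG先检索相关文档,再让LLM基于文档生成答案。", "向量数据库负责存储文档块的向量。"]
encoder = SentenceTransformer('all-MiniLM-L6-v2')
chunk_vecs = encoder.encode(chunks, normalize_embeddings=True)

# 4-5. 查询向量化并检索最相似的文档块
query = "什么是RAG?"
query_vec = encoder.encode([query], normalize_embeddings=True)[0]
top_idx = int(np.argmax(chunk_vecs @ query_vec))

# 6. 将检索到的上下文与问题拼成提示词,交给LLM生成(此处仅打印提示词)
prompt = f"参考资料:{chunks[top_idx]}\n问题:{query}\n请基于参考资料回答。"
print(prompt)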
核心优势:
- 无需重新训练模型即可更新知识
- 降低幻觉,提供可追溯的信息来源
- 成本低于微调
- 适合动态更新的知识库
RAG vs 微调的区别
| 维度 | RAG | 微调(Fine-tuning) |
|---|---|---|
| 知识更新 | 实时更新,只需更新文档库 | 需要重新训练模型 |
| 成本 | 低,主要是存储和检索成本 | 高,需要GPU算力 |
| 响应时间 | 稍慢(需检索) | 快 |
| 可解释性 | 高,可查看检索的文档 | 低,黑盒 |
| 专业领域 | 适合知识密集型任务 | 适合行为/风格调整 |
| 幻觉控制 | 较好 | 一般 |
| 实施难度 | 简单 | 复杂 |
选择建议:
- 使用RAG:企业知识库问答、文档助手、实时信息查询
- 使用微调:特定风格输出、领域专用术语、任务格式调整
- 结合使用:先微调模型适应领域,再用RAG提供实时知识
RAG的优势和适用场景
优势:
- 知识可追溯:每个答案都能指向具体的文档来源
- 动态更新:新增文档无需重训模型
- 隐私控制:敏感数据保留在企业内部
- 多源整合:可整合多个数据源
- 降低幻觉:基于真实文档生成
适用场景:
- 企业内部知识库问答
- 法律/医疗文档查询
- 技术文档助手
- 客服机器人
- 研究论文分析
- 产品手册查询
不适用场景:
- 需要复杂推理的数学题
- 创意写作
- 通用代码生成(除非有可供检索的大型代码库)
- 对延迟极其敏感的实时对话(检索会增加响应时间)
向量数据库
Embedding原理
Embedding是将文本转换为高维向量的过程,相似的文本在向量空间中距离更近。
Embedding模型演进:
# 1. Word2Vec时代(2013)- 词级别
# "king" - "man" + "woman" ≈ "queen"
# 2. BERT时代(2018)- 句子级别
from transformers import BertModel, BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
text = "RAG is powerful"
inputs = tokenizer(text, return_tensors='pt')
outputs = model(**inputs)
embedding = outputs.last_hidden_state.mean(dim=1) # [1, 768]
# 3. 专用Embedding模型(2020+)- 优化相似度
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
embedding = model.encode("RAG is powerful") # [384]
常用Embedding模型对比:
| 模型 | 维度 | 语言 | 性能 | 速度 |
|---|---|---|---|---|
| text-embedding-ada-002 | 1536 | 多语言 | 最高 | 中 |
| all-MiniLM-L6-v2 | 384 | 英文 | 中 | 快 |
| paraphrase-multilingual-MiniLM-L12-v2 | 384 | 多语言 | 中 | 快 |
| bge-large-zh | 1024 | 中文 | 高 | 中 |
| m3e-base | 768 | 中文 | 高 | 快 |
中文Embedding示例:
from sentence_transformers import SentenceTransformer
import numpy as np
# 加载中文模型
model = SentenceTransformer('moka-ai/m3e-base')
# 多个文本
texts = [
"RAG系统是检索增强生成",
"向量数据库存储文档嵌入",
"今天天气很好",
]
# 生成embeddings
embeddings = model.encode(texts)
print(f"Embedding shape: {embeddings.shape}") # (3, 768)
# 计算相似度
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
print("相似度矩阵:")
print(similarity_matrix)
# 示例输出(具体数值因模型而异):
# [[1.   0.65 0.12]
#  [0.65 1.   0.10]
#  [0.12 0.10 1.  ]]
向量相似度计算
1. 余弦相似度(Cosine Similarity)
最常用的相似度度量,范围[-1, 1],不受向量长度影响。
import numpy as np
def cosine_similarity(vec1, vec2):
"""余弦相似度"""
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
return dot_product / (norm1 * norm2)
# 示例
vec1 = np.array([1, 2, 3])
vec2 = np.array([4, 5, 6])
print(f"余弦相似度: {cosine_similarity(vec1, vec2):.4f}") # 0.9746
2. 欧氏距离(Euclidean Distance)
计算两点间的直线距离,距离越小越相似。
def euclidean_distance(vec1, vec2):
"""欧氏距离"""
return np.linalg.norm(vec1 - vec2)
print(f"欧氏距离: {euclidean_distance(vec1, vec2):.4f}") # 5.1962
3. 点积(Dot Product)
如果向量已归一化,点积等同于余弦相似度。
def dot_product_similarity(vec1, vec2):
"""点积相似度(假设已归一化)"""
# 先归一化
vec1_norm = vec1 / np.linalg.norm(vec1)
vec2_norm = vec2 / np.linalg.norm(vec2)
return np.dot(vec1_norm, vec2_norm)
print(f"点积相似度: {dot_product_similarity(vec1, vec2):.4f}") # 0.9746
性能对比:
import time
# 生成大量向量
n_vectors = 10000
dim = 768
vectors = np.random.rand(n_vectors, dim).astype(np.float32)
query = np.random.rand(dim).astype(np.float32)
# 测试余弦相似度
start = time.time()
# 归一化
vectors_norm = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
query_norm = query / np.linalg.norm(query)
similarities = np.dot(vectors_norm, query_norm)
print(f"余弦相似度耗时: {(time.time() - start) * 1000:.2f}ms")
# 测试欧氏距离
start = time.time()
distances = np.linalg.norm(vectors - query, axis=1)
print(f"欧氏距离耗时: {(time.time() - start) * 1000:.2f}ms")
# 测试点积(预归一化)
start = time.time()
dot_products = np.dot(vectors_norm, query_norm)
print(f"点积耗时: {(time.time() - start) * 1000:.2f}ms")
常用向量数据库
1. Milvus - 开源分布式向量数据库
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
# 连接Milvus
connections.connect(host='localhost', port='19530')
# 定义schema
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1000),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768)
]
schema = CollectionSchema(fields, description="RAG documents")
# 创建集合
collection = Collection(name="rag_docs", schema=schema)
# 创建索引
index_params = {
"metric_type": "COSINE",
"index_type": "IVF_FLAT",
"params": {"nlist": 128}
}
collection.create_index(field_name="embedding", index_params=index_params)
# 插入数据
texts = ["RAG系统", "向量数据库"]
embeddings = model.encode(texts).tolist()
entities = [
texts,
embeddings
]
collection.insert(entities)
collection.flush()
# 搜索
collection.load()
query_embedding = model.encode(["什么是RAG"]).tolist()
results = collection.search(
data=query_embedding,
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": 10}},
limit=5,
output_fields=["text"]
)
for hits in results:
for hit in hits:
print(f"距离: {hit.distance:.4f}, 文本: {hit.entity.get('text')}")
2. Pinecone - 托管向量数据库
import pinecone
# 初始化(此处为旧版 pinecone-client 2.x 的写法;3.x 起改为 from pinecone import Pinecone 并实例化客户端)
pinecone.init(api_key="your-api-key", environment="us-west1-gcp")
# 创建索引
index_name = "rag-index"
if index_name not in pinecone.list_indexes():
pinecone.create_index(
name=index_name,
dimension=768,
metric="cosine"
)
index = pinecone.Index(index_name)
# 插入向量
# embeddings沿用上文 model.encode(...).tolist() 的结果,已是Python列表,无需再次tolist()
vectors_to_upsert = [
    ("doc1", embeddings[0], {"text": texts[0]}),
    ("doc2", embeddings[1], {"text": texts[1]})
]
index.upsert(vectors=vectors_to_upsert)
# 查询
query_result = index.query(
vector=query_embedding[0],
top_k=5,
include_metadata=True
)
for match in query_result['matches']:
print(f"得分: {match['score']:.4f}, 文本: {match['metadata']['text']}")
3. Qdrant - Rust编写的高性能向量数据库
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
# 连接
client = QdrantClient(host="localhost", port=6333)
# 创建集合
collection_name = "rag_collection"
client.recreate_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=768, distance=Distance.COSINE)
)
# 插入点
points = [
PointStruct(
id=i,
        vector=embeddings[i],  # embeddings沿用上文,已是Python列表
payload={"text": texts[i]}
)
for i in range(len(texts))
]
client.upsert(collection_name=collection_name, points=points)
# 搜索
search_result = client.search(
collection_name=collection_name,
query_vector=query_embedding[0],
limit=5
)
for scored_point in search_result:
print(f"得分: {scored_point.score:.4f}, 文本: {scored_point.payload['text']}")
4. Chroma - 轻量级嵌入式数据库
import chromadb
# 创建持久化客户端(chromadb >= 0.4 的写法;旧版通过 Settings(chroma_db_impl="duckdb+parquet") 配置)
client = chromadb.PersistentClient(path="./chroma_db")
# 创建集合
collection = client.create_collection(
name="rag_docs",
metadata={"hnsw:space": "cosine"}
)
# 添加文档
collection.add(
    embeddings=embeddings,  # 沿用上文的embeddings,已是Python列表
documents=texts,
ids=[f"doc{i}" for i in range(len(texts))]
)
# 查询
results = collection.query(
query_embeddings=query_embedding,
n_results=5
)
print(results['documents'])
print(results['distances'])
向量数据库对比:
| 数据库 | 类型 | 语言 | 特点 | 适用场景 |
|---|---|---|---|---|
| Milvus | 开源 | Go/C++ | 分布式、高性能 | 大规模生产 |
| Pinecone | 托管 | - | 全托管、易用 | 快速原型 |
| Qdrant | 开源 | Rust | 高性能、功能丰富 | 中大规模 |
| Chroma | 开源 | Python | 轻量、易集成 | 开发测试 |
| Weaviate | 开源 | Go | GraphQL、模块化 | 知识图谱 |
FAISS索引原理
FAISS(Facebook AI Similarity Search)是Meta开源的高效相似度搜索库。
FAISS索引类型:
import faiss
import numpy as np
# 准备数据
d = 768 # 维度
nb = 10000 # 数据库大小
nq = 10 # 查询数量
np.random.seed(1234)
xb = np.random.random((nb, d)).astype('float32')
xq = np.random.random((nq, d)).astype('float32')
# 1. Flat索引 - 暴力搜索,最精确但最慢
index_flat = faiss.IndexFlatL2(d)
index_flat.add(xb)
D, I = index_flat.search(xq, 5) # 搜索top5
print(f"Flat索引结果: {I[0]}")
# 2. IVF索引 - 倒排索引,速度快
nlist = 100 # 聚类中心数
quantizer = faiss.IndexFlatL2(d)
index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist)
index_ivf.train(xb) # 需要训练
index_ivf.add(xb)
index_ivf.nprobe = 10 # 搜索的聚类数
D, I = index_ivf.search(xq, 5)
print(f"IVF索引结果: {I[0]}")
# 3. HNSW索引 - 分层图,高召回率
M = 32 # 每层的连接数
index_hnsw = faiss.IndexHNSWFlat(d, M)
index_hnsw.add(xb)
D, I = index_hnsw.search(xq, 5)
print(f"HNSW索引结果: {I[0]}")
# 4. PQ索引 - 乘积量化,压缩存储
m = 8 # 子向量数
index_pq = faiss.IndexPQ(d, m, 8)
index_pq.train(xb)
index_pq.add(xb)
D, I = index_pq.search(xq, 5)
print(f"PQ索引结果: {I[0]}")
# 5. 组合索引 - IVF + PQ
index_ivfpq = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
index_ivfpq.train(xb)
index_ivfpq.add(xb)
index_ivfpq.nprobe = 10
D, I = index_ivfpq.search(xq, 5)
print(f"IVFPQ索引结果: {I[0]}")
FAISS性能对比:
import time
def benchmark_index(index, xb, xq, name):
"""测试索引性能"""
# 添加数据
start = time.time()
    if not index.is_trained:  # Flat/HNSW无需训练,IVF/PQ需要先训练
        index.train(xb)
index.add(xb)
add_time = time.time() - start
# 搜索
start = time.time()
D, I = index.search(xq, 5)
search_time = (time.time() - start) / len(xq) * 1000 # ms per query
print(f"{name:15} - 添加: {add_time:.3f}s, 搜索: {search_time:.3f}ms/query")
return D, I
# 大规模测试
d = 768
nb = 100000
nq = 1000
xb = np.random.random((nb, d)).astype('float32')
xq = np.random.random((nq, d)).astype('float32')
# 测试各种索引
indexes = {
"Flat": faiss.IndexFlatL2(d),
"IVF100": faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, 100),
"HNSW32": faiss.IndexHNSWFlat(d, 32),
"PQ8": faiss.IndexPQ(d, 8, 8),
}
for name, index in indexes.items():
benchmark_index(index, xb, xq, name)
FAISS持久化:
# 保存索引
faiss.write_index(index_hnsw, "rag_index.faiss")
# 加载索引
index_loaded = faiss.read_index("rag_index.faiss")
# GPU加速
if faiss.get_num_gpus() > 0:
res = faiss.StandardGpuResources()
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_flat)
D, I = index_gpu.search(xq, 5)
RAG Pipeline详解
文档加载和分块(Chunking)
1. 文档加载器
from typing import List
import os
class Document:
"""文档对象"""
def __init__(self, content: str, metadata: dict = None):
self.content = content
self.metadata = metadata or {}
class DocumentLoader:
"""通用文档加载器"""
@staticmethod
def load_txt(file_path: str) -> Document:
"""加载文本文件"""
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return Document(content, {"source": file_path})
@staticmethod
def load_pdf(file_path: str) -> List[Document]:
"""加载PDF文件"""
from PyPDF2 import PdfReader
reader = PdfReader(file_path)
documents = []
for i, page in enumerate(reader.pages):
text = page.extract_text()
doc = Document(
content=text,
metadata={"source": file_path, "page": i + 1}
)
documents.append(doc)
return documents
@staticmethod
def load_docx(file_path: str) -> Document:
"""加载Word文档"""
from docx import Document as DocxDocument
doc = DocxDocument(file_path)
content = '\n'.join([p.text for p in doc.paragraphs])
return Document(content, {"source": file_path})
@staticmethod
def load_markdown(file_path: str) -> Document:
"""加载Markdown文件"""
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return Document(content, {"source": file_path, "format": "markdown"})
@staticmethod
def load_directory(directory: str) -> List[Document]:
"""加载目录下所有文档"""
documents = []
for root, dirs, files in os.walk(directory):
for file in files:
file_path = os.path.join(root, file)
ext = os.path.splitext(file)[1].lower()
try:
if ext == '.txt':
documents.append(DocumentLoader.load_txt(file_path))
elif ext == '.pdf':
documents.extend(DocumentLoader.load_pdf(file_path))
elif ext == '.docx':
documents.append(DocumentLoader.load_docx(file_path))
elif ext == '.md':
documents.append(DocumentLoader.load_markdown(file_path))
except Exception as e:
print(f"加载 {file_path} 失败: {e}")
return documents
2. 文本分块策略
from typing import List
import re
import numpy as np
class TextSplitter:
"""文本分块基类"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def split_text(self, text: str) -> List[str]:
"""分割文本"""
raise NotImplementedError
class CharacterTextSplitter(TextSplitter):
"""按字符数分块"""
def split_text(self, text: str) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + self.chunk_size
chunk = text[start:end]
chunks.append(chunk)
start = end - self.chunk_overlap
return chunks
class RecursiveCharacterTextSplitter(TextSplitter):
"""递归字符分块 - 保持段落完整性"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50,
separators: List[str] = None):
super().__init__(chunk_size, chunk_overlap)
self.separators = separators or ["\n\n", "\n", "。", "!", "?", ";", " ", ""]
def split_text(self, text: str) -> List[str]:
return self._split_text(text, self.separators)
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""递归分割"""
if not separators:
return [text]
separator = separators[0]
chunks = []
if separator:
splits = text.split(separator)
else:
splits = list(text)
current_chunk = []
current_length = 0
for split in splits:
split_length = len(split)
if current_length + split_length > self.chunk_size:
if current_chunk:
chunk_text = separator.join(current_chunk)
if len(chunk_text) > self.chunk_size:
# 继续用下一个分隔符分割
chunks.extend(self._split_text(chunk_text, separators[1:]))
else:
chunks.append(chunk_text)
# 保留重叠
overlap_text = separator.join(current_chunk[-2:]) if len(current_chunk) > 1 else ""
current_chunk = [overlap_text, split] if overlap_text else [split]
current_length = len(overlap_text) + split_length
else:
current_chunk = [split]
current_length = split_length
else:
current_chunk.append(split)
current_length += split_length
if current_chunk:
chunks.append(separator.join(current_chunk))
return chunks
class SemanticTextSplitter(TextSplitter):
"""语义分块 - 基于句子相似度"""
def __init__(self, chunk_size: int = 500, model=None):
super().__init__(chunk_size)
from sentence_transformers import SentenceTransformer
self.model = model or SentenceTransformer('all-MiniLM-L6-v2')
def split_text(self, text: str) -> List[str]:
# 按句子分割
sentences = re.split(r'[。!?!?]', text)
sentences = [s.strip() for s in sentences if s.strip()]
        if not sentences:
            return []
        if len(sentences) == 1:
            return sentences
# 计算句子embeddings
embeddings = self.model.encode(sentences)
# 计算相邻句子相似度
from sklearn.metrics.pairwise import cosine_similarity
similarities = []
for i in range(len(embeddings) - 1):
sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
similarities.append(sim)
# 在相似度低的地方分割
threshold = np.percentile(similarities, 30) # 取30%分位数
chunks = []
current_chunk = [sentences[0]]
current_length = len(sentences[0])
for i, sentence in enumerate(sentences[1:], 1):
if current_length + len(sentence) > self.chunk_size or similarities[i-1] < threshold:
chunks.append(''.join(current_chunk))
current_chunk = [sentence]
current_length = len(sentence)
else:
current_chunk.append(sentence)
current_length += len(sentence)
if current_chunk:
chunks.append(''.join(current_chunk))
return chunks
class MarkdownTextSplitter(TextSplitter):
"""Markdown专用分块 - 保持标题结构"""
def split_text(self, text: str) -> List[str]:
# 按标题分割
sections = re.split(r'\n(#{1,6}\s+.+)\n', text)
chunks = []
current_chunk = ""
current_headers = []
for i, section in enumerate(sections):
if re.match(r'#{1,6}\s+', section):
# 是标题
level = len(section.split()[0])
current_headers = current_headers[:level-1] + [section]
else:
# 是内容
header_text = '\n'.join(current_headers)
full_text = f"{header_text}\n{section}" if header_text else section
if len(current_chunk) + len(full_text) > self.chunk_size:
if current_chunk:
chunks.append(current_chunk)
current_chunk = full_text
else:
current_chunk += "\n" + full_text
if current_chunk:
chunks.append(current_chunk)
return chunks
3. 分块效果对比
# 示例文本
text = """
人工智能的发展历程可以分为几个阶段。
第一阶段是符号主义时期。在这个时期,研究者们相信智能可以通过符号操作实现。他们开发了各种专家系统。
第二阶段是连接主义复兴。神经网络重新获得关注。深度学习开始兴起。
第三阶段是大模型时代。Transformer架构的出现改变了一切。GPT、BERT等模型展现了惊人的能力。
"""
# 测试不同分块策略
splitters = {
"字符分块": CharacterTextSplitter(chunk_size=100, chunk_overlap=20),
"递归分块": RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20),
}
for name, splitter in splitters.items():
chunks = splitter.split_text(text)
print(f"\n{name}:")
for i, chunk in enumerate(chunks, 1):
print(f" 块{i} ({len(chunk)}字): {chunk[:50]}...")
Embedding生成
from typing import List
import numpy as np
class EmbeddingGenerator:
"""Embedding生成器"""
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_name)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""批量生成embeddings"""
embeddings = self.model.encode(
texts,
batch_size=batch_size,
show_progress_bar=True,
normalize_embeddings=True # 归一化便于余弦相似度计算
)
return embeddings
def embed_query(self, query: str) -> np.ndarray:
"""生成查询embedding"""
return self.model.encode([query], normalize_embeddings=True)[0]
# 使用OpenAI Embedding
class OpenAIEmbedding:
"""OpenAI Embedding"""
    def __init__(self, api_key: str):
        from openai import OpenAI
        # openai >= 1.0 的客户端写法
        self.client = OpenAI(api_key=api_key)
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""批量生成embeddings"""
response = self.client.embeddings.create(
model="text-embedding-ada-002",
input=texts
)
return [item.embedding for item in response.data]
def embed_query(self, query: str) -> List[float]:
"""生成查询embedding"""
return self.embed_texts([query])[0]
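EmbeddingGenerator 的一个简单使用示意(模型首次运行会自动下载;embed_texts 已做归一化,因此点积即余弦相似度):
generator = EmbeddingGenerator("all-MiniLM-L6-v2")
doc_embeddings = generator.embed_texts(["RAG系统介绍", "向量数据库原理"])
query_embedding = generator.embed_query("什么是RAG?")
print(doc_embeddings @ query_embedding)  # 两个文档与查询的余弦相似度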
向量存储和索引
import pickle
import faiss
import numpy as np
from typing import List, Tuple, Dict
class VectorStore:
"""向量存储基类"""
def add_texts(self, texts: List[str], embeddings: np.ndarray, metadatas: List[Dict] = None):
"""添加文本和向量"""
raise NotImplementedError
def similarity_search(self, query_embedding: np.ndarray, k: int = 5) -> List[Tuple[str, float, Dict]]:
"""相似度搜索"""
raise NotImplementedError
def save(self, path: str):
"""保存索引"""
raise NotImplementedError
def load(self, path: str):
"""加载索引"""
raise NotImplementedError
class FAISSVectorStore(VectorStore):
"""基于FAISS的向量存储"""
def __init__(self, dimension: int, index_type: str = "Flat"):
self.dimension = dimension
self.texts = []
self.metadatas = []
        # 创建索引(embedding已归一化,内积即余弦相似度,得分越大越相关)
        if index_type == "Flat":
            self.index = faiss.IndexFlatIP(dimension)
        elif index_type == "IVF":
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, 100, faiss.METRIC_INNER_PRODUCT)
        elif index_type == "HNSW":
            self.index = faiss.IndexHNSWFlat(dimension, 32, faiss.METRIC_INNER_PRODUCT)
        else:
            raise ValueError(f"Unknown index type: {index_type}")
self.is_trained = False
def add_texts(self, texts: List[str], embeddings: np.ndarray, metadatas: List[Dict] = None):
"""添加文本和向量"""
# 确保embeddings是float32
embeddings = embeddings.astype('float32')
# 训练索引(如果需要)
if hasattr(self.index, 'train') and not self.is_trained:
self.index.train(embeddings)
self.is_trained = True
# 添加向量
self.index.add(embeddings)
# 保存文本和元数据
self.texts.extend(texts)
if metadatas:
self.metadatas.extend(metadatas)
else:
self.metadatas.extend([{} for _ in texts])
def similarity_search(self, query_embedding: np.ndarray, k: int = 5) -> List[Tuple[str, float, Dict]]:
"""相似度搜索"""
query_embedding = query_embedding.astype('float32').reshape(1, -1)
        # 搜索(返回内积得分,越大越相关)
        scores, indices = self.index.search(query_embedding, k)
        # 组装结果
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or idx >= len(self.texts):
                continue
            results.append((
                self.texts[idx],
                float(score),
                self.metadatas[idx]
            ))
return results
def save(self, path: str):
"""保存索引"""
faiss.write_index(self.index, f"{path}.faiss")
with open(f"{path}.pkl", 'wb') as f:
pickle.dump({
'texts': self.texts,
'metadatas': self.metadatas,
'dimension': self.dimension,
'is_trained': self.is_trained
}, f)
def load(self, path: str):
"""加载索引"""
self.index = faiss.read_index(f"{path}.faiss")
with open(f"{path}.pkl", 'rb') as f:
data = pickle.load(f)
self.texts = data['texts']
self.metadatas = data['metadatas']
self.dimension = data['dimension']
self.is_trained = data['is_trained']
相似度检索
class Retriever:
"""检索器"""
def __init__(self, vector_store: VectorStore, embedding_generator: EmbeddingGenerator):
self.vector_store = vector_store
self.embedding_generator = embedding_generator
def retrieve(self, query: str, k: int = 5) -> List[Tuple[str, float, Dict]]:
"""检索相关文档"""
# 生成查询embedding
query_embedding = self.embedding_generator.embed_query(query)
# 搜索
results = self.vector_store.similarity_search(query_embedding, k)
return results
def retrieve_with_threshold(self, query: str, k: int = 5, threshold: float = 0.7) -> List[Tuple[str, float, Dict]]:
"""带阈值的检索"""
results = self.retrieve(query, k)
# 过滤低分结果
filtered_results = [(text, score, meta) for text, score, meta in results if score >= threshold]
return filtered_results
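把 EmbeddingGenerator、FAISSVectorStore 和 Retriever 串起来的最小示意(沿用上文定义的类,文本与来源仅为演示用假设数据):
generator = EmbeddingGenerator("all-MiniLM-L6-v2")
store = FAISSVectorStore(dimension=generator.dimension, index_type="Flat")
texts = ["RAG结合检索与生成", "向量数据库存储文档向量", "今天天气很好"]
store.add_texts(texts, generator.embed_texts(texts), [{"source": "demo"}] * len(texts))

retriever = Retriever(vector_store=store, embedding_generator=generator)
for text, score, meta in retriever.retrieve("什么是RAG?", k=2):
    print(f"{score:.3f} {text} ({meta['source']})")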
上下文注入
class ContextBuilder:
"""上下文构建器"""
def __init__(self, max_context_length: int = 2000):
self.max_context_length = max_context_length
def build_context(self, query: str, retrieved_docs: List[Tuple[str, float, Dict]]) -> str:
"""构建上下文"""
context_parts = []
current_length = 0
for i, (text, score, metadata) in enumerate(retrieved_docs, 1):
# 添加来源信息
source = metadata.get('source', 'Unknown')
doc_context = f"[文档{i}] (来源: {source}, 相关度: {score:.3f})\n{text}\n"
doc_length = len(doc_context)
if current_length + doc_length > self.max_context_length:
break
context_parts.append(doc_context)
current_length += doc_length
context = "\n".join(context_parts)
return context
def build_prompt(self, query: str, context: str) -> str:
"""构建完整提示"""
prompt = f"""基于以下文档内容回答问题。如果文档中没有相关信息,请明确说明。
文档内容:
{context}
问题: {query}
回答:"""
return prompt
LLM生成
from typing import Optional
class LLMGenerator:
"""LLM生成器"""
def __init__(self, model_name: str = "gpt-3.5-turbo", api_key: Optional[str] = None):
import openai
if api_key:
openai.api_key = api_key
self.client = openai
self.model_name = model_name
def generate(self, prompt: str, temperature: float = 0.7, max_tokens: int = 500) -> str:
"""生成回答"""
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": "你是一个helpful的AI助手,基于提供的文档回答问题。"},
{"role": "user", "content": prompt}
],
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
def generate_stream(self, prompt: str, temperature: float = 0.7):
"""流式生成"""
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": "你是一个helpful的AI助手,基于提供的文档回答问题。"},
{"role": "user", "content": prompt}
],
temperature=temperature,
stream=True
)
for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
完整RAG实现
LangChain RAG
# 注:以下导入对应旧版LangChain;0.1+ 版本中相关模块已拆分到 langchain_community 与 langchain_openai
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
class LangChainRAG:
"""基于LangChain的RAG系统"""
def __init__(self, docs_path: str, openai_api_key: str):
self.docs_path = docs_path
self.openai_api_key = openai_api_key
self.vectorstore = None
self.qa_chain = None
def load_documents(self):
"""加载文档"""
loader = DirectoryLoader(
self.docs_path,
glob="**/*.txt",
loader_cls=TextLoader
)
documents = loader.load()
print(f"加载了 {len(documents)} 个文档")
return documents
def split_documents(self, documents):
"""分割文档"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
)
chunks = text_splitter.split_documents(documents)
print(f"分割成 {len(chunks)} 个块")
return chunks
def create_vectorstore(self, chunks):
"""创建向量存储"""
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
self.vectorstore = FAISS.from_documents(chunks, embeddings)
print("向量存储创建完成")
def setup_qa_chain(self):
"""设置问答链"""
# LLM
llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0.7,
openai_api_key=self.openai_api_key
)
# 自定义提示模板
prompt_template = """使用以下文档片段来回答问题。如果你不知道答案,就说不知道,不要编造答案。
{context}
问题: {question}
答案:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
# 创建QA链
self.qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": PROMPT},
return_source_documents=True
)
def query(self, question: str):
"""查询"""
result = self.qa_chain({"query": question})
print(f"\n问题: {question}")
print(f"答案: {result['result']}")
print("\n来源文档:")
for i, doc in enumerate(result['source_documents'], 1):
print(f" {i}. {doc.metadata.get('source', 'Unknown')}: {doc.page_content[:100]}...")
return result
def save_vectorstore(self, path: str):
"""保存向量存储"""
self.vectorstore.save_local(path)
def load_vectorstore(self, path: str):
"""加载向量存储"""
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
self.vectorstore = FAISS.load_local(path, embeddings)
# 使用示例
if __name__ == "__main__":
rag = LangChainRAG(
docs_path="./documents",
openai_api_key="your-api-key"
)
# 构建索引
docs = rag.load_documents()
chunks = rag.split_documents(docs)
rag.create_vectorstore(chunks)
rag.save_vectorstore("./vectorstore")
# 设置问答
rag.setup_qa_chain()
# 查询
rag.query("什么是RAG?")
LlamaIndex RAG
# 注:以下写法对应旧版LlamaIndex(0.9及以前);0.10+ 版本导入路径改为 llama_index.core,ServiceContext 被全局 Settings 取代
from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    StorageContext,
    load_index_from_storage
)
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.llms import OpenAI
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import SentenceSplitter
class LlamaIndexRAG:
"""基于LlamaIndex的RAG系统"""
def __init__(self, docs_path: str, openai_api_key: str):
self.docs_path = docs_path
self.openai_api_key = openai_api_key
self.index = None
def build_index(self):
"""构建索引"""
# 加载文档
documents = SimpleDirectoryReader(self.docs_path).load_data()
print(f"加载了 {len(documents)} 个文档")
# 配置embedding模型
embed_model = HuggingFaceEmbedding(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# 配置LLM
llm = OpenAI(
model="gpt-3.5-turbo",
temperature=0.7,
api_key=self.openai_api_key
)
# 配置文本分割器
text_splitter = SentenceSplitter(
chunk_size=500,
chunk_overlap=50
)
# 配置节点解析器
node_parser = SimpleNodeParser.from_defaults(
text_splitter=text_splitter
)
# 创建服务上下文
service_context = ServiceContext.from_defaults(
embed_model=embed_model,
llm=llm,
node_parser=node_parser
)
# 构建索引
self.index = VectorStoreIndex.from_documents(
documents,
service_context=service_context,
show_progress=True
)
print("索引构建完成")
def query(self, question: str, similarity_top_k: int = 3):
"""查询"""
query_engine = self.index.as_query_engine(
similarity_top_k=similarity_top_k
)
response = query_engine.query(question)
print(f"\n问题: {question}")
print(f"答案: {response.response}")
print("\n来源节点:")
for i, node in enumerate(response.source_nodes, 1):
print(f" {i}. (分数: {node.score:.3f}): {node.text[:100]}...")
return response
def save_index(self, persist_dir: str = "./storage"):
"""保存索引"""
self.index.storage_context.persist(persist_dir=persist_dir)
def load_index(self, persist_dir: str = "./storage", openai_api_key: str = None):
"""加载索引"""
# 配置服务上下文
embed_model = HuggingFaceEmbedding(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
llm = OpenAI(
model="gpt-3.5-turbo",
api_key=openai_api_key or self.openai_api_key
)
service_context = ServiceContext.from_defaults(
embed_model=embed_model,
llm=llm
)
# 加载存储上下文
storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
# 加载索引
self.index = load_index_from_storage(
storage_context,
service_context=service_context
)
# 使用示例
if __name__ == "__main__":
rag = LlamaIndexRAG(
docs_path="./documents",
openai_api_key="your-api-key"
)
# 构建索引
rag.build_index()
rag.save_index()
# 查询
rag.query("什么是RAG?")
从零实现RAG系统
from typing import List, Dict, Tuple, Optional
import numpy as np
from dataclasses import dataclass
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class RAGConfig:
"""RAG配置"""
chunk_size: int = 500
chunk_overlap: int = 50
embedding_model: str = "all-MiniLM-L6-v2"
llm_model: str = "gpt-3.5-turbo"
top_k: int = 5
similarity_threshold: float = 0.0
max_context_length: int = 2000
temperature: float = 0.7
class RAGSystem:
"""完整的RAG系统实现"""
def __init__(self, config: RAGConfig = None, openai_api_key: str = None):
self.config = config or RAGConfig()
self.openai_api_key = openai_api_key
# 初始化组件
self._init_components()
logger.info("RAG系统初始化完成")
def _init_components(self):
"""初始化各个组件"""
# 文本分割器
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.config.chunk_size,
chunk_overlap=self.config.chunk_overlap
)
# Embedding生成器
self.embedding_generator = EmbeddingGenerator(
model_name=self.config.embedding_model
)
# 向量存储
self.vector_store = FAISSVectorStore(
dimension=self.embedding_generator.dimension,
index_type="HNSW"
)
# 检索器
self.retriever = Retriever(
vector_store=self.vector_store,
embedding_generator=self.embedding_generator
)
# 上下文构建器
self.context_builder = ContextBuilder(
max_context_length=self.config.max_context_length
)
# LLM生成器
self.llm_generator = LLMGenerator(
model_name=self.config.llm_model,
api_key=self.openai_api_key
)
def add_documents(self, documents: List[Document]):
"""添加文档到系统"""
logger.info(f"开始处理 {len(documents)} 个文档")
all_chunks = []
all_metadatas = []
for doc in documents:
# 分割文档
chunks = self.text_splitter.split_text(doc.content)
# 添加元数据
for chunk in chunks:
all_chunks.append(chunk)
all_metadatas.append(doc.metadata)
logger.info(f"文档分割完成,共 {len(all_chunks)} 个块")
# 生成embeddings
logger.info("生成embeddings...")
embeddings = self.embedding_generator.embed_texts(all_chunks)
# 存储
logger.info("存储向量...")
self.vector_store.add_texts(all_chunks, embeddings, all_metadatas)
logger.info("文档添加完成")
def query(self, question: str, return_sources: bool = True) -> Dict:
"""查询系统"""
logger.info(f"查询: {question}")
# 检索相关文档
retrieved_docs = self.retriever.retrieve(
question,
k=self.config.top_k
)
# 过滤低分文档
retrieved_docs = [
(text, score, meta)
for text, score, meta in retrieved_docs
if score >= self.config.similarity_threshold
]
if not retrieved_docs:
return {
"answer": "抱歉,我没有找到相关信息来回答您的问题。",
"sources": []
}
logger.info(f"检索到 {len(retrieved_docs)} 个相关文档")
# 构建上下文
context = self.context_builder.build_context(question, retrieved_docs)
# 构建提示
prompt = self.context_builder.build_prompt(question, context)
# 生成答案
logger.info("生成答案...")
answer = self.llm_generator.generate(
prompt,
temperature=self.config.temperature
)
result = {
"answer": answer,
"sources": [
{
"text": text[:200],
"score": score,
"metadata": meta
}
for text, score, meta in retrieved_docs
] if return_sources else []
}
return result
def query_stream(self, question: str):
"""流式查询"""
# 检索
retrieved_docs = self.retriever.retrieve(question, k=self.config.top_k)
# 构建提示
context = self.context_builder.build_context(question, retrieved_docs)
prompt = self.context_builder.build_prompt(question, context)
# 流式生成
for chunk in self.llm_generator.generate_stream(prompt):
yield chunk
def save(self, path: str):
"""保存系统"""
self.vector_store.save(path)
logger.info(f"系统已保存到 {path}")
def load(self, path: str):
"""加载系统"""
self.vector_store.load(path)
logger.info(f"系统已从 {path} 加载")
# 使用示例
def main():
# 配置
config = RAGConfig(
chunk_size=500,
chunk_overlap=50,
top_k=3,
temperature=0.7
)
# 创建系统
rag = RAGSystem(
config=config,
openai_api_key="your-api-key"
)
# 加载文档
docs = DocumentLoader.load_directory("./documents")
# 添加到系统
rag.add_documents(docs)
# 保存
rag.save("./rag_system")
# 查询
result = rag.query("什么是RAG系统?")
print(f"\n答案: {result['answer']}")
print("\n来源:")
for i, source in enumerate(result['sources'], 1):
print(f" {i}. (分数: {source['score']:.3f}): {source['text'][:100]}...")
# 流式查询
print("\n流式回答:")
for chunk in rag.query_stream("RAG有什么优势?"):
print(chunk, end='', flush=True)
print()
if __name__ == "__main__":
main()
RAG优化技巧
Hybrid Search(混合检索)
混合检索结合了稠密向量检索和稀疏关键词检索的优势。
from rank_bm25 import BM25Okapi
import jieba
class HybridRetriever:
"""混合检索器"""
def __init__(self, vector_store: VectorStore, embedding_generator: EmbeddingGenerator):
self.vector_store = vector_store
self.embedding_generator = embedding_generator
self.bm25 = None
self.texts = []
def add_texts(self, texts: List[str]):
"""添加文本"""
self.texts = texts
# 构建BM25索引
tokenized_corpus = [list(jieba.cut(text)) for text in texts]
self.bm25 = BM25Okapi(tokenized_corpus)
def retrieve(self, query: str, k: int = 5, alpha: float = 0.5) -> List[Tuple[str, float]]:
"""
混合检索
alpha: 向量检索权重,1-alpha为BM25权重
"""
# 向量检索
query_embedding = self.embedding_generator.embed_query(query)
vector_results = self.vector_store.similarity_search(query_embedding, k=k*2)
# BM25检索
tokenized_query = list(jieba.cut(query))
bm25_scores = self.bm25.get_scores(tokenized_query)
# 归一化分数
vector_scores = {text: score for text, score, _ in vector_results}
max_vector_score = max(vector_scores.values()) if vector_scores else 1
max_bm25_score = max(bm25_scores) if len(bm25_scores) > 0 else 1
# 计算混合分数
hybrid_scores = {}
for i, text in enumerate(self.texts):
vector_score = vector_scores.get(text, 0) / max_vector_score
bm25_score = bm25_scores[i] / max_bm25_score
hybrid_score = alpha * vector_score + (1 - alpha) * bm25_score
hybrid_scores[text] = hybrid_score
# 排序
sorted_results = sorted(
hybrid_scores.items(),
key=lambda x: x[1],
reverse=True
)[:k]
return sorted_results
# 使用示例
hybrid_retriever = HybridRetriever(vector_store, embedding_generator)
hybrid_retriever.add_texts(texts)
results = hybrid_retriever.retrieve("RAG系统", k=5, alpha=0.7)
Reranking(重排序)
使用交叉编码器对检索结果重新排序,提高相关性。
from sentence_transformers import CrossEncoder
class Reranker:
"""重排序器"""
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.model = CrossEncoder(model_name)
def rerank(self, query: str, documents: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
"""重排序"""
# 构建查询-文档对
pairs = [[query, doc] for doc in documents]
# 计算相关性分数
scores = self.model.predict(pairs)
# 排序
doc_score_pairs = list(zip(documents, scores))
doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
return doc_score_pairs[:top_k]
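Reranker 也可以单独使用,下面是一个最小示意(cross-encoder模型首次运行会自动下载,候选文档为演示用假设数据):
reranker = Reranker()
candidates = ["RAG结合检索与生成", "向量数据库存储文档向量", "今天天气很好"]
for doc, score in reranker.rerank("什么是RAG?", candidates, top_k=2):
    print(f"{score:.3f} {doc}")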
# 集成到RAG
class RAGWithReranking(RAGSystem):
"""带重排序的RAG系统"""
def __init__(self, config: RAGConfig = None, openai_api_key: str = None):
super().__init__(config, openai_api_key)
self.reranker = Reranker()
def query(self, question: str, return_sources: bool = True) -> Dict:
"""查询(带重排序)"""
# 初步检索更多文档
retrieved_docs = self.retriever.retrieve(
question,
k=self.config.top_k * 2 # 检索更多文档
)
# 提取文本
texts = [text for text, _, _ in retrieved_docs]
# 重排序
reranked = self.reranker.rerank(question, texts, top_k=self.config.top_k)
# 重构retrieved_docs
text_to_meta = {text: meta for text, _, meta in retrieved_docs}
retrieved_docs = [
(text, score, text_to_meta[text])
for text, score in reranked
]
# 后续流程同RAGSystem
context = self.context_builder.build_context(question, retrieved_docs)
prompt = self.context_builder.build_prompt(question, context)
answer = self.llm_generator.generate(prompt)
return {
"answer": answer,
"sources": [
{"text": text[:200], "score": score, "metadata": meta}
for text, score, meta in retrieved_docs
] if return_sources else []
}
Query改写
通过改写查询来提高检索效果。
class QueryRewriter:
"""查询改写器"""
def __init__(self, llm_generator: LLMGenerator):
self.llm = llm_generator
def rewrite_multi_query(self, query: str, n: int = 3) -> List[str]:
"""生成多个查询变体"""
prompt = f"""请为以下问题生成{n}个不同的表述方式,每个表述一行:
原问题: {query}
改写后的问题:"""
response = self.llm.generate(prompt, temperature=0.7)
queries = [q.strip() for q in response.strip().split('\n') if q.strip()]
return queries[:n]
def rewrite_step_back(self, query: str) -> str:
"""Step-back改写 - 生成更通用的问题"""
prompt = f"""给定一个具体问题,请生成一个更通用、更高层次的问题。
具体问题: {query}
通用问题:"""
general_query = self.llm.generate(prompt, temperature=0.3)
return general_query.strip()
def rewrite_with_context(self, query: str, chat_history: List[Dict]) -> str:
"""基于对话历史改写查询"""
history_text = "\n".join([
f"{'用户' if msg['role'] == 'user' else 'AI'}: {msg['content']}"
for msg in chat_history
])
prompt = f"""基于对话历史,将用户的最新问题改写为一个独立的问题。
对话历史:
{history_text}
最新问题: {query}
独立问题:"""
standalone_query = self.llm.generate(prompt, temperature=0.3)
return standalone_query.strip()
# 多查询检索
class MultiQueryRetriever:
"""多查询检索器"""
def __init__(self, retriever: Retriever, query_rewriter: QueryRewriter):
self.retriever = retriever
self.query_rewriter = query_rewriter
def retrieve(self, query: str, k: int = 5) -> List[Tuple[str, float, Dict]]:
"""使用多个查询检索"""
# 生成查询变体
queries = [query] + self.query_rewriter.rewrite_multi_query(query, n=2)
# 收集所有结果
all_results = {}
for q in queries:
results = self.retriever.retrieve(q, k=k)
for text, score, meta in results:
if text in all_results:
# 取最高分
all_results[text] = max(all_results[text], (score, meta), key=lambda x: x[0])
else:
all_results[text] = (score, meta)
# 排序并返回
sorted_results = sorted(
[(text, score, meta) for text, (score, meta) in all_results.items()],
key=lambda x: x[1],
reverse=True
)[:k]
return sorted_results
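MultiQueryRetriever 的使用示意(沿用上文已构建好的 retriever 与 llm_generator 实例):
query_rewriter = QueryRewriter(llm_generator)
multi_retriever = MultiQueryRetriever(retriever, query_rewriter)
for text, score, meta in multi_retriever.retrieve("RAG系统有哪些优势?", k=5):
    print(f"{score:.3f} {text[:50]}")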
HyDE(假设文档嵌入)
先让LLM生成假设的答案文档,然后用这个文档来检索。
class HyDERetriever:
"""HyDE检索器"""
def __init__(self, vector_store: VectorStore,
embedding_generator: EmbeddingGenerator,
llm_generator: LLMGenerator):
self.vector_store = vector_store
self.embedding_generator = embedding_generator
self.llm = llm_generator
def generate_hypothetical_document(self, query: str) -> str:
"""生成假设文档"""
prompt = f"""请为以下问题生成一个详细的答案(假设你知道答案)。
问题: {query}
答案:"""
hypothetical_doc = self.llm.generate(prompt, temperature=0.7)
return hypothetical_doc
def retrieve(self, query: str, k: int = 5) -> List[Tuple[str, float, Dict]]:
"""使用HyDE检索"""
# 生成假设文档
hypothetical_doc = self.generate_hypothetical_document(query)
# 使用假设文档的embedding来检索
doc_embedding = self.embedding_generator.embed_query(hypothetical_doc)
results = self.vector_store.similarity_search(doc_embedding, k=k)
return results
# 使用示例
hyde_retriever = HyDERetriever(vector_store, embedding_generator, llm_generator)
results = hyde_retriever.retrieve("RAG系统的优势是什么?", k=5)
实战案例
企业知识库问答
class EnterpriseKnowledgeBase:
"""企业知识库系统"""
def __init__(self, openai_api_key: str):
self.rag = RAGWithReranking(
config=RAGConfig(
chunk_size=800,
chunk_overlap=100,
top_k=5
),
openai_api_key=openai_api_key
)
self.query_rewriter = QueryRewriter(self.rag.llm_generator)
self.chat_history = []
def load_knowledge_base(self, paths: List[str]):
"""加载知识库"""
all_docs = []
for path in paths:
if os.path.isdir(path):
docs = DocumentLoader.load_directory(path)
else:
ext = os.path.splitext(path)[1]
if ext == '.pdf':
docs = DocumentLoader.load_pdf(path)
elif ext == '.txt':
docs = [DocumentLoader.load_txt(path)]
else:
continue
all_docs.extend(docs)
self.rag.add_documents(all_docs)
logger.info(f"知识库加载完成,共 {len(all_docs)} 个文档")
def ask(self, question: str) -> Dict:
"""提问"""
# 基于历史改写查询
if self.chat_history:
standalone_query = self.query_rewriter.rewrite_with_context(
question,
self.chat_history
)
else:
standalone_query = question
# 查询
result = self.rag.query(standalone_query)
# 更新历史
self.chat_history.append({"role": "user", "content": question})
self.chat_history.append({"role": "assistant", "content": result['answer']})
# 保持历史长度
if len(self.chat_history) > 10:
self.chat_history = self.chat_history[-10:]
return result
def reset_conversation(self):
"""重置对话"""
self.chat_history = []
# 使用示例
kb = EnterpriseKnowledgeBase(openai_api_key="your-api-key")
kb.load_knowledge_base([
"./company_docs/policies",
"./company_docs/procedures",
"./company_docs/faqs.pdf"
])
result = kb.ask("公司的休假政策是什么?")
print(result['answer'])
result = kb.ask("那产假呢?") # 基于上下文的追问
print(result['answer'])
文档助手
class DocumentAssistant:
"""文档助手 - 针对单个或少量文档的深度问答"""
def __init__(self, openai_api_key: str):
self.openai_api_key = openai_api_key
self.rag = None
self.document_summary = None
def load_document(self, file_path: str):
"""加载文档"""
ext = os.path.splitext(file_path)[1]
if ext == '.pdf':
docs = DocumentLoader.load_pdf(file_path)
elif ext == '.txt':
docs = [DocumentLoader.load_txt(file_path)]
elif ext == '.docx':
docs = [DocumentLoader.load_docx(file_path)]
else:
raise ValueError(f"不支持的文件格式: {ext}")
# 配置 - 使用更小的chunk以获得更精确的检索
config = RAGConfig(
chunk_size=300,
chunk_overlap=50,
top_k=8
)
self.rag = RAGWithReranking(config, self.openai_api_key)
self.rag.add_documents(docs)
# 生成文档摘要
self._generate_summary(docs)
logger.info(f"文档加载完成: {file_path}")
def _generate_summary(self, docs: List[Document]):
"""生成文档摘要"""
# 合并所有文档内容
full_text = "\n\n".join([doc.content for doc in docs])
# 如果文档太长,先分块总结再合并
if len(full_text) > 10000:
chunks = self.rag.text_splitter.split_text(full_text)
chunk_summaries = []
for chunk in chunks[:10]: # 最多总结10个块
prompt = f"请简要总结以下内容(2-3句话):\n\n{chunk}"
summary = self.rag.llm_generator.generate(prompt, temperature=0.3, max_tokens=200)
chunk_summaries.append(summary)
# 合并摘要
combined = "\n".join(chunk_summaries)
prompt = f"请总结以下内容(一段话):\n\n{combined}"
self.document_summary = self.rag.llm_generator.generate(prompt, temperature=0.3)
else:
prompt = f"请总结以下文档内容(一段话):\n\n{full_text[:5000]}"
self.document_summary = self.rag.llm_generator.generate(prompt, temperature=0.3)
def ask(self, question: str, mode: str = "detailed") -> Dict:
"""
提问
mode: 'quick' 快速模式, 'detailed' 详细模式, 'summary' 摘要模式
"""
if mode == "summary":
return {"answer": self.document_summary, "sources": []}
result = self.rag.query(question)
if mode == "detailed":
# 详细模式:提供更多上下文
result['summary'] = self.document_summary
return result
def extract_info(self, fields: List[str]) -> Dict:
"""从文档中提取特定信息"""
results = {}
for field in fields:
question = f"文档中关于{field}的信息是什么?"
result = self.ask(question, mode="quick")
results[field] = result['answer']
return results
# 使用示例
assistant = DocumentAssistant(openai_api_key="your-api-key")
assistant.load_document("./contract.pdf")
# 获取摘要
summary = assistant.ask("", mode="summary")
print(f"文档摘要: {summary['answer']}")
# 提问
result = assistant.ask("合同的有效期是多久?")
print(f"答案: {result['answer']}")
# 批量提取信息
info = assistant.extract_info(["甲方", "乙方", "合同金额", "签订日期"])
for field, value in info.items():
print(f"{field}: {value}")
多模态RAG
from PIL import Image
import base64
from io import BytesIO
class MultimodalRAG:
"""多模态RAG - 支持文本和图像"""
def __init__(self, openai_api_key: str):
self.openai_api_key = openai_api_key
# 文本RAG
self.text_rag = RAGSystem(
config=RAGConfig(chunk_size=500),
openai_api_key=openai_api_key
)
# 图像索引
self.image_index = {} # {image_path: description}
self.image_embeddings = []
self.image_paths = []
from sentence_transformers import SentenceTransformer
self.text_encoder = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
def add_text_documents(self, documents: List[Document]):
"""添加文本文档"""
self.text_rag.add_documents(documents)
def add_images(self, image_folder: str):
"""添加图像"""
import openai
openai.api_key = self.openai_api_key
for filename in os.listdir(image_folder):
if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
continue
image_path = os.path.join(image_folder, filename)
# 使用GPT-4V生成图像描述
with open(image_path, 'rb') as f:
image_data = base64.b64encode(f.read()).decode('utf-8')
            # 使用多模态模型生成图像描述(gpt-4-vision-preview已逐步被gpt-4o等取代,请按实际可用模型调整)
            response = openai.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "请详细描述这张图片的内容。"},
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
                            }
                        ]
                    }
                ],
                max_tokens=300
            )
description = response.choices[0].message.content
self.image_index[image_path] = description
# 生成embedding
embedding = self.text_encoder.encode(description)
self.image_embeddings.append(embedding)
self.image_paths.append(image_path)
logger.info(f"添加图像: {filename}")
self.image_embeddings = np.array(self.image_embeddings)
def search_images(self, query: str, k: int = 3) -> List[Tuple[str, str, float]]:
"""搜索相关图像"""
query_embedding = self.text_encoder.encode(query)
# 计算相似度
from sklearn.metrics.pairwise import cosine_similarity
similarities = cosine_similarity([query_embedding], self.image_embeddings)[0]
# 排序
top_indices = np.argsort(similarities)[::-1][:k]
results = [
(self.image_paths[i], self.image_index[self.image_paths[i]], similarities[i])
for i in top_indices
]
return results
def query(self, question: str, include_images: bool = True) -> Dict:
"""多模态查询"""
# 文本检索
text_result = self.text_rag.query(question)
result = {
"answer": text_result['answer'],
"text_sources": text_result['sources']
}
# 图像检索
if include_images:
image_results = self.search_images(question, k=3)
result['image_sources'] = [
{
"path": path,
"description": desc,
"score": score
}
for path, desc, score in image_results
]
return result
# 使用示例
mm_rag = MultimodalRAG(openai_api_key="your-api-key")
# 添加文本
docs = DocumentLoader.load_directory("./product_docs")
mm_rag.add_text_documents(docs)
# 添加图像
mm_rag.add_images("./product_images")
# 查询
result = mm_rag.query("这个产品的外观是什么样的?", include_images=True)
print(f"答案: {result['answer']}")
print("\n相关图像:")
for img in result['image_sources']:
print(f" {img['path']}: {img['description'][:100]}... (分数: {img['score']:.3f})")