Skip to content

自定义RAG

虽然LangChain等框架提供了便捷的RAG实现,但在某些场景下,我们可能需要构建自定义的RAG系统,以满足特定需求。本章节将详细介绍如何构建自定义RAG系统,包括核心组件的实现和集成。

1. 自定义RAG系统架构

核心组件

  • 数据处理模块:负责文档加载、清洗和分块
  • 嵌入模块:负责将文本转换为向量
  • 检索模块:负责从向量数据库中检索相关文档
  • 生成模块:负责基于检索结果生成回答
  • 评估模块:负责评估系统性能

系统架构图

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│ 数据处理   │────>│  嵌入模块   │────>│ 向量存储   │
└─────────────┘     └─────────────┘     └─────────────┘


┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  用户查询   │────>│  检索模块   │<────│  向量检索  │
└─────────────┘     └─────────────┘     └─────────────┘


┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  评估模块   │<────│  生成模块   │<────│  上下文构建 │
└─────────────┘     └─────────────┘     └─────────────┘

2. 数据处理模块

文档加载

python
import os
from abc import ABC, abstractmethod

class DocumentLoader(ABC):
    """Abstract interface for document loaders.

    Concrete subclasses read a file and return a list of document dicts
    of the form {"content": str, "metadata": {"source": path}}.
    """

    @abstractmethod
    def load(self, path):
        """Load the file at *path* and return a list of document dicts."""
        ...

class TextLoader(DocumentLoader):
    """Loads a plain-text file as a single UTF-8 document."""

    def load(self, path):
        """Read the whole file and wrap it in the standard document dict."""
        with open(path, 'r', encoding='utf-8') as handle:
            text = handle.read()
        return [{"content": text, "metadata": {"source": path}}]

class PDFLoader(DocumentLoader):
    """Loads a PDF file by concatenating the extracted text of every page."""

    def load(self, path):
        # Import lazily so PyPDF2 is only required when PDFs are actually loaded.
        from PyPDF2 import PdfReader
        pages = PdfReader(path).pages
        full_text = "".join(page.extract_text() for page in pages)
        return [{"content": full_text, "metadata": {"source": path}}]

class DirectoryLoader:
    """Walks a directory tree and dispatches each file to a loader by extension."""

    def __init__(self, directory, loader_map):
        self.directory = directory
        self.loader_map = loader_map  # e.g. {'.txt': TextLoader, '.pdf': PDFLoader}

    def load(self):
        """Recursively load every file whose extension has a registered loader."""
        collected = []
        for root, _dirs, filenames in os.walk(self.directory):
            for name in filenames:
                _, ext = os.path.splitext(name)
                loader_cls = self.loader_map.get(ext)
                if loader_cls is None:
                    continue  # silently skip unsupported file types
                full_path = os.path.join(root, name)
                collected.extend(loader_cls().load(full_path))
        return collected

文本分割

python
class TextSplitter:
    """Splits text into fixed-size character chunks with overlap.

    Fixes over the original version:
    - The original re-emitted the tail of the text as an extra chunk whenever
      the final chunk reached end-of-text (it advanced ``start = end - overlap``
      even after consuming everything), producing a redundant chunk fully
      contained in the previous one.
    - ``chunk_overlap >= chunk_size`` made ``start`` never advance, looping
      forever; this is now rejected up front.
    """

    def __init__(self, chunk_size=1000, chunk_overlap=200):
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if chunk_overlap < 0 or chunk_overlap >= chunk_size:
            # overlap >= size would make the window never advance (infinite loop)
            raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split(self, text):
        """Character-based split; returns [] for empty text.

        Consecutive chunks share ``chunk_overlap`` characters.
        """
        chunks = []
        start = 0
        length = len(text)
        while start < length:
            end = start + self.chunk_size
            chunks.append(text[start:end])
            if end >= length:
                break  # consumed the whole text; do not re-emit the tail
            start = end - self.chunk_overlap
        return chunks

    def split_documents(self, documents):
        """Split a list of document dicts, tagging each chunk with its index.

        Each input metadata dict is copied (not shared) into every chunk.
        """
        chunks = []
        for doc in documents:
            for i, chunk in enumerate(self.split(doc["content"])):
                chunks.append({
                    "content": chunk,
                    "metadata": {**doc["metadata"], "chunk_index": i},
                })
        return chunks

3. 嵌入模块

python
import numpy as np
from sentence_transformers import SentenceTransformer

class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer encoder."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        # Cached output dimensionality, used when sizing the vector index.
        self.dimension = self.model.get_sentence_embedding_dimension()

    def embed(self, texts):
        """Encode one string or a list of strings into a 2-D numpy array."""
        batch = [texts] if isinstance(texts, str) else texts
        return self.model.encode(batch, convert_to_numpy=True)

    def embed_query(self, query):
        """Encode a single query string and return its 1-D vector."""
        return self.embed([query])[0]

# Usage example
embedding_model = EmbeddingModel()
example_vectors = embedding_model.embed(["这是第一个文档", "这是第二个文档"])
print(f"向量维度: {example_vectors.shape}")

4. 检索模块

python
import faiss

class VectorStore:
    """In-memory FAISS index paired with the documents it indexes.

    Fixes over the original version:
    - The original used ``IndexFlatIP`` with a "cosine similarity" comment but
      never normalized the vectors, so scores were raw inner products. Vectors
      are now L2-normalized on insertion and at query time, making the inner
      product equal to cosine similarity.
    - FAISS pads missing results with index ``-1``; the original's
      ``idx < len(self.documents)`` check let ``-1`` through and (via Python's
      negative indexing) silently returned the LAST document. Padded slots are
      now skipped.
    """

    def __init__(self, dimension):
        self.dimension = dimension
        # Inner product over L2-normalized vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        self.documents = []  # position i holds the document for FAISS id i

    def add(self, documents, embeddings):
        """Add parallel lists of documents and their embedding vectors."""
        self.documents.extend(documents)
        matrix = np.ascontiguousarray(np.asarray(embeddings, dtype='float32'))
        faiss.normalize_L2(matrix)  # in-place; makes IP scores cosine similarities
        self.index.add(matrix)

    def search(self, query_embedding, k=5):
        """Return up to *k* most similar documents with their cosine scores."""
        query = np.ascontiguousarray(np.asarray([query_embedding], dtype='float32'))
        faiss.normalize_L2(query)
        distances, indices = self.index.search(query, k)

        results = []
        for rank, idx in enumerate(indices[0]):
            # Skip FAISS's -1 padding (k larger than the index) and stale ids.
            if 0 <= idx < len(self.documents):
                results.append({
                    "document": self.documents[idx],
                    "score": float(distances[0][rank]),
                })
        return results

5. 生成模块

python
import openai

class Generator:
    """Answer generator backed by the OpenAI chat completion API.

    NOTE(review): uses the legacy ``openai.ChatCompletion`` interface
    (pre-1.0 SDK) — confirm the pinned openai package version.
    """

    def __init__(self, api_key=None, model="gpt-3.5-turbo"):
        if api_key:
            openai.api_key = api_key
        self.model = model

    def generate(self, query, context_documents):
        """Generate an answer to *query* grounded in the retrieved documents."""
        # Number each retrieved document and cap it at 500 characters to
        # bound the prompt size.
        numbered = []
        for position, doc in enumerate(context_documents):
            snippet = doc['document']['content'][:500]
            numbered.append(f"[文档 {position+1}] {snippet}")
        context = "\n\n".join(numbered)

        # Prompt instructs the model to admit ignorance rather than hallucinate.
        prompt = f"""基于以下上下文回答问题。如果无法从上下文中找到答案,请说"我不知道"。

上下文:
{context}

问题:{query}

请提供详细且准确的回答:"""

        # temperature=0 for deterministic, extractive-style answers.
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "你是一个有帮助的助手。"},
                {"role": "user", "content": prompt},
            ],
            temperature=0,
        )

        return response.choices[0].message.content

6. 集成RAG系统

python
class CustomRAG:
    """End-to-end RAG pipeline: chunk -> embed -> index -> retrieve -> generate."""

    def __init__(self, embedding_model=None, generator=None):
        self.embedding_model = embedding_model or EmbeddingModel()
        self.vector_store = None  # built lazily by load_documents()
        self.generator = generator or Generator()

    def load_documents(self, documents):
        """Split, embed and index *documents* into a fresh vector store."""
        splitter = TextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(documents)

        chunk_texts = [chunk["content"] for chunk in chunks]
        chunk_vectors = self.embedding_model.embed(chunk_texts)

        store = VectorStore(self.embedding_model.dimension)
        store.add(chunks, chunk_vectors)
        self.vector_store = store

        print(f"已索引 {len(chunks)} 个文档块")

    def query(self, question, k=3):
        """Run a RAG query: embed, retrieve top-k chunks, generate an answer.

        Returns {"answer": str, "sources": retrieval results}.
        Raises ValueError if no documents have been loaded yet.
        """
        if not self.vector_store:
            raise ValueError("请先加载文档")

        question_vector = self.embedding_model.embed_query(question)
        hits = self.vector_store.search(question_vector, k=k)
        answer = self.generator.generate(question, hits)

        return {
            "answer": answer,
            "sources": hits,
        }

# 完整使用示例
def main():
    # 初始化RAG系统
    rag = CustomRAG()
    
    # 加载文档
    loader = DirectoryLoader(
        "./docs",
        {".txt": TextLoader, ".pdf": PDFLoader}
    )
    documents = loader.load()
    rag.load_documents(documents)
    
    # 查询
    result = rag.query("什么是RAG技术?")
    print("回答:", result["answer"])
    print("\n来源:")
    for source in result["sources"]:
        print(f"- {source['document']['metadata']['source']} (得分: {source['score']:.4f})")

if __name__ == "__main__":
    main()

7. 扩展功能

混合检索

python
class HybridRetriever:
    """Combines dense (vector) and sparse (keyword) retrieval with weighted fusion.

    Fix over the original version: ``_fuse_results`` dropped the documents and
    returned bare ``(doc_id, score)`` tuples, which is inconsistent with the
    ``{"document": ..., "score": ...}`` format every other retriever in this
    file produces (and unusable by the generator). It now returns that format.
    """

    def __init__(self, vector_store, sparse_retriever, vector_weight=0.7):
        self.vector_store = vector_store
        self.sparse_retriever = sparse_retriever
        # Weight of dense scores; sparse scores get (1 - vector_weight).
        self.vector_weight = vector_weight

    def search(self, query, query_embedding, k=5):
        """Run both retrievers (over-fetching 2k each) and return top-k fused results."""
        dense_results = self.vector_store.search(query_embedding, k=k * 2)
        sparse_results = self.sparse_retriever.search(query, k=k * 2)
        return self._fuse_results(dense_results, sparse_results, k)

    def _fuse_results(self, dense_results, sparse_results, k):
        """Fuse the two result lists by weighted score sum; return top-k results.

        Results are keyed by ``chunk_index`` metadata (falling back to the
        list position, as in the original) so the same chunk found by both
        retrievers accumulates both weighted scores.
        """
        scores = {}
        docs_by_id = {}

        def accumulate(results, weight):
            # Sum the weighted score per document id and remember the document.
            for fallback_id, result in enumerate(results):
                doc_id = result["document"]["metadata"].get("chunk_index", fallback_id)
                docs_by_id[doc_id] = result["document"]
                scores[doc_id] = scores.get(doc_id, 0) + result["score"] * weight

        accumulate(dense_results, self.vector_weight)
        accumulate(sparse_results, 1 - self.vector_weight)

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [
            {"document": docs_by_id[doc_id], "score": score}
            for doc_id, score in ranked[:k]
        ]

缓存机制

python
from functools import lru_cache
import hashlib

class CachedRAG(CustomRAG):
    """RAG wrapper that memoizes query results in a bounded LRU cache.

    Fix over the original version: once the cache reached ``cache_size`` it
    silently stopped caching anything new forever. It now evicts the
    least-recently-used entry instead, using the dict's insertion order
    (hits are moved to the end; the head is the oldest entry).
    """

    def __init__(self, *args, cache_size=1000, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache = {}  # insertion-ordered: head = least recently used
        self.cache_size = cache_size

    def query(self, question, k=3):
        """Answer *question*, serving repeated (question, k) pairs from the cache."""
        # Key on both the question and k, since k changes the retrieved context.
        cache_key = hashlib.md5(f"{question}_{k}".encode()).hexdigest()

        if cache_key in self.cache:
            print("返回缓存结果")
            # Re-insert to refresh recency so popular questions stay cached.
            self.cache[cache_key] = self.cache.pop(cache_key)
            return self.cache[cache_key]

        result = super().query(question, k)

        # Evict the least-recently-used entry when at capacity, then cache.
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[cache_key] = result

        return result