Skip to content

文档问答

文档问答是RAG技术的经典应用场景,通过构建基于RAG的文档问答系统,可以实现对文档内容的智能理解和问答。本章节将详细介绍如何构建文档问答系统。

1. 文档问答概述

核心功能

  • 文档理解:理解文档的内容和结构
  • 智能问答:回答关于文档内容的问题
  • 信息检索:快速检索文档中的相关信息
  • 多文档处理:处理多个文档的问答
  • 上下文理解:理解问题的上下文

应用场景

  • 学术研究:辅助研究人员快速获取文献信息
  • 法律文档:辅助律师和法律工作者理解法律文档
  • 医疗文档:辅助医生和患者理解医疗记录
  • 企业文档:辅助员工理解企业政策和流程
  • 技术文档:辅助开发者理解技术文档

2. 系统架构

整体架构

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  用户界面   │────>│  文档处理   │────>│  RAG系统    │
└─────────────┘     └─────────────┘     └─────────────┘


┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  评估系统   │<────│  分析系统   │<────│  向量存储   │
└─────────────┘     └─────────────┘     └─────────────┘

技术栈选择

  • 前端:React、Vue、Angular等
  • 后端:Python、Node.js等
  • RAG框架:LangChain、LlamaIndex等
  • 向量数据库:Pinecone、Weaviate、Chroma等
  • LLM:OpenAI GPT、Claude、本地LLM等
  • 文档处理:PyPDF、python-docx等

3. 文档处理

文档类型

  • PDF文档:学术论文、报告、手册等
  • Word文档:合同、协议、政策文件等
  • 文本文件:日志、配置文件、代码等
  • 网页:在线文档、博客文章等

文档解析

python
from langchain.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredHTMLLoader
)

class DocumentParser:
    def __init__(self):
        self.loaders = {
            '.pdf': PyPDFLoader,
            '.docx': Docx2txtLoader,
            '.txt': TextLoader,
            '.html': UnstructuredHTMLLoader,
            '.md': TextLoader
        }
    
    def parse(self, file_path):
        """解析文档"""
        ext = file_path.split('.')[-1].lower()
        ext = '.' + ext
        
        if ext not in self.loaders:
            raise ValueError(f"不支持的文件格式: {ext}")
        
        loader = self.loaders[ext](file_path)
        documents = loader.load()
        
        # 添加元数据
        for doc in documents:
            doc.metadata.update({
                'source': file_path,
                'filename': file_path.split('/')[-1],
                'file_type': ext
            })
        
        return documents
    
    def parse_batch(self, file_paths):
        """批量解析文档"""
        all_documents = []
        for file_path in file_paths:
            try:
                documents = self.parse(file_path)
                all_documents.extend(documents)
            except Exception as e:
                print(f"解析文件失败 {file_path}: {e}")
        
        return all_documents

文档分块策略

python
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    MarkdownHeaderTextSplitter
)

class DocumentChunker:
    def __init__(self):
        self.splitters = {
            'default': RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            ),
            'code': RecursiveCharacterTextSplitter(
                chunk_size=1500,
                chunk_overlap=300,
                separators=["\nclass ", "\ndef ", "\n\n", "\n", " "]
            ),
            'markdown': MarkdownHeaderTextSplitter(
                headers_to_split_on=["#", "##", "###"]
            )
        }
    
    def chunk(self, documents, strategy='default'):
        """分块文档"""
        splitter = self.splitters.get(strategy, self.splitters['default'])
        
        chunks = []
        for doc in documents:
            # 根据文档类型选择策略
            if doc.metadata.get('file_type') == '.md':
                doc_chunks = self.splitters['markdown'].split_text(doc.page_content)
            elif self.is_code_file(doc.metadata.get('source', '')):
                doc_chunks = self.splitters['code'].split_documents([doc])
            else:
                doc_chunks = splitter.split_documents([doc])
            
            # 保留元数据
            for chunk in doc_chunks:
                chunk.metadata.update(doc.metadata)
            
            chunks.extend(doc_chunks)
        
        return chunks
    
    def is_code_file(self, file_path):
        """判断是否为代码文件"""
        code_extensions = ['.py', '.js', '.java', '.cpp', '.c', '.go', '.rs']
        return any(file_path.endswith(ext) for ext in code_extensions)

4. 文档问答系统实现

python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

class DocumentQASystem:
    def __init__(self, persist_directory="./doc_qa_db"):
        # 初始化嵌入模型
        self.embeddings = OpenAIEmbeddings()
        
        # 初始化向量存储
        self.vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings
        )
        
        # 自定义提示模板
        self.prompt_template = PromptTemplate(
            template="""基于以下文档内容回答问题。如果无法从文档中找到答案,请明确说明"根据提供的文档,我无法找到相关信息"。

文档内容:
{context}

问题:{question}

请提供准确、详细的回答,并尽可能引用文档中的相关内容:""",
            input_variables=["context", "question"]
        )
        
        # 创建RAG链
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": 5}
            ),
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt_template}
        )
        
        # 文档解析器
        self.parser = DocumentParser()
        self.chunker = DocumentChunker()
    
    def upload_document(self, file_path):
        """上传并索引文档"""
        # 解析文档
        documents = self.parser.parse(file_path)
        
        # 分块
        chunks = self.chunker.chunk(documents)
        
        # 添加到向量存储
        self.vectorstore.add_documents(chunks)
        self.vectorstore.persist()
        
        return len(chunks)
    
    def ask(self, question):
        """问答"""
        result = self.qa_chain({"query": question})
        
        return {
            "answer": result["result"],
            "sources": [
                {
                    "content": doc.page_content[:300],
                    "source": doc.metadata.get("source"),
                    "page": doc.metadata.get("page")
                }
                for doc in result.get("source_documents", [])
            ]
        }
    
    def search(self, query, k=5):
        """搜索相关段落"""
        results = self.vectorstore.similarity_search(query, k=k)
        
        return [
            {
                "content": doc.page_content,
                "source": doc.metadata.get("source"),
                "score": score
            }
            for doc, score in [(r, None) for r in results]
        ]

5. 多文档问答

python
class MultiDocumentQA:
    def __init__(self):
        self.doc_qa = DocumentQASystem()
        self.document_metadata = {}
    
    def add_document_collection(self, name, file_paths):
        """添加文档集合"""
        total_chunks = 0
        for file_path in file_paths:
            chunks = self.doc_qa.upload_document(file_path)
            total_chunks += chunks
            
            # 记录元数据
            self.document_metadata[file_path] = {
                'collection': name,
                'chunks': chunks,
                'uploaded_at': datetime.now()
            }
        
        return total_chunks
    
    def query_collection(self, collection_name, question):
        """查询特定文档集合"""
        # 过滤特定集合的文档
        # 这里简化处理,实际应该使用元数据过滤
        return self.doc_qa.ask(question)
    
    def compare_documents(self, doc_paths, question):
        """对比多个文档的内容"""
        answers = []
        for doc_path in doc_paths:
            # 为每个文档单独查询
            answer = self.query_single_document(doc_path, question)
            answers.append({
                'document': doc_path,
                'answer': answer
            })
        
        # 生成对比总结
        comparison = self.generate_comparison(answers, question)
        return comparison
    
    def query_single_document(self, doc_path, question):
        """查询单个文档"""
        # 使用元数据过滤
        retriever = self.doc_qa.vectorstore.as_retriever(
            search_kwargs={
                "k": 5,
                "filter": {"source": doc_path}
            }
        )
        
        qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=retriever
        )
        
        result = qa_chain({"query": question})
        return result["result"]
    
    def generate_comparison(self, answers, question):
        """生成对比总结"""
        comparison_prompt = f"""基于以下不同文档对同一问题的回答,生成一个对比总结:

问题:{question}

各文档回答:
"""
        for ans in answers:
            comparison_prompt += f"\n文档 {ans['document']}:\n{ans['answer']}\n"
        
        comparison_prompt += "\n请总结各文档观点的异同:"
        
        return self.doc_qa.qa_chain.llm.generate(comparison_prompt)

6. 高级功能

引用溯源

python
class CitationTracker:
    def __init__(self, qa_system):
        self.qa = qa_system
    
    def answer_with_citations(self, question):
        """带引用的回答"""
        result = self.qa.qa_chain({"query": question})
        
        answer = result["result"]
        sources = result.get("source_documents", [])
        
        # 为回答添加引用标记
        cited_answer = self.add_citations(answer, sources)
        
        return {
            "answer": cited_answer,
            "citations": [
                {
                    "id": i+1,
                    "source": doc.metadata.get("source"),
                    "page": doc.metadata.get("page"),
                    "content": doc.page_content[:200]
                }
                for i, doc in enumerate(sources)
            ]
        }
    
    def add_citations(self, answer, sources):
        """在回答中添加引用标记"""
        # 简化处理:在回答末尾添加引用列表
        cited_answer = answer + "\n\n参考来源:\n"
        for i, source in enumerate(sources):
            cited_answer += f"[{i+1}] {source.metadata.get('source')}"
            if source.metadata.get('page'):
                cited_answer += f", 第{source.metadata.get('page')}页"
            cited_answer += "\n"
        
        return cited_answer

文档摘要

python
class DocumentSummarizer:
    def __init__(self, qa_system):
        self.qa = qa_system
    
    def summarize(self, doc_path, max_length=500):
        """生成文档摘要"""
        # 检索文档的关键段落
        key_sections = self.extract_key_sections(doc_path)
        
        # 生成摘要
        summary_prompt = f"""基于以下内容生成文档摘要:

{key_sections}

请生成一个简洁的摘要(不超过{max_length}字):"""
        
        summary = self.qa.qa_chain.llm.generate(summary_prompt)
        return summary
    
    def extract_key_sections(self, doc_path):
        """提取关键段落"""
        # 检索文档开头的概述和结尾的结论
        sections = []
        
        # 这里简化处理,实际应该根据文档结构提取
        results = self.qa.vectorstore.similarity_search(
            "概述 介绍 总结 结论",
            k=10,
            filter={"source": doc_path}
        )
        
        return "\n\n".join([r.page_content for r in results])
    
    def generate_toc(self, doc_path):
        """生成文档目录"""
        # 检索文档中的标题
        toc_prompt = f"""基于文档内容生成目录结构:

文档路径:{doc_path}

请识别文档的主要章节和子章节,生成目录:"""
        
        toc = self.qa.qa_chain.llm.generate(toc_prompt)
        return toc

7. 评估与优化

python
class QAEvaluator:
    def __init__(self, qa_system):
        self.qa = qa_system
    
    def evaluate_answer(self, question, ground_truth):
        """评估回答质量"""
        result = self.qa.ask(question)
        predicted_answer = result["answer"]
        
        # 计算ROUGE分数
        from rouge import Rouge
        rouge = Rouge()
        scores = rouge.get_scores(predicted_answer, ground_truth)[0]
        
        # 评估引用准确性
        citation_accuracy = self.evaluate_citations(
            result.get("sources", []),
            ground_truth
        )
        
        return {
            "rouge-1": scores["rouge-1"]["f"],
            "rouge-2": scores["rouge-2"]["f"],
            "rouge-l": scores["rouge-l"]["f"],
            "citation_accuracy": citation_accuracy
        }
    
    def evaluate_citations(self, sources, ground_truth):
        """评估引用准确性"""
        # 检查引用的文档是否包含正确答案
        # 简化处理
        return 1.0 if sources else 0.0
    
    def find_failure_cases(self, test_set):
        """找出失败案例"""
        failures = []
        
        for test in test_set:
            result = self.qa.ask(test["question"])
            
            # 检查回答是否包含正确答案的关键信息
            if not self.contains_answer(result["answer"], test["ground_truth"]):
                failures.append({
                    "question": test["question"],
                    "predicted": result["answer"],
                    "ground_truth": test["ground_truth"]
                })
        
        return failures
    
    def contains_answer(self, predicted, ground_truth):
        """检查预测是否包含正确答案"""
        # 使用语义相似度或关键词匹配
        # 简化处理
        return any(kw in predicted for kw in ground_truth.split())