Performance Optimization

The performance of a RAG system directly affects the user experience and the system's practical value. With well-chosen optimizations, you can significantly improve a RAG system's response speed, accuracy, and reliability. This chapter walks through performance optimization strategies and best practices for RAG systems.

1. Performance Bottleneck Analysis

Common Bottlenecks

  • Retrieval latency: the computational cost of vector search
  • Generation latency: the time the LLM takes to produce an answer
  • Data processing: time spent loading and chunking documents
  • Memory usage: performance degradation caused by insufficient memory
  • Network latency: network overhead of API calls

Analysis Tools

  • Profilers: cProfile, py-spy, etc. (see the timing sketch after this list)
  • Monitoring systems: Prometheus, Grafana, etc.
  • Log analysis: ELK Stack, Splunk, etc.
  • Load testing: JMeter, Locust, etc.
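
Before optimizing, measure where time actually goes. Below is a minimal per-stage timing sketch, with cProfile for a whole-pipeline profile; the `rag` object and its `retrieve`/`generate` methods are hypothetical placeholders for your own pipeline.

python
import cProfile
import pstats
import time

def profile_query(rag, question):
    """Time the retrieval and generation stages of one query separately."""
    t0 = time.perf_counter()
    context = rag.retrieve(question)          # retrieval stage
    t1 = time.perf_counter()
    answer = rag.generate(question, context)  # generation stage
    t2 = time.perf_counter()
    print(f"retrieval: {t1 - t0:.3f}s, generation: {t2 - t1:.3f}s")
    return answer

# Whole-pipeline profile, sorted by cumulative time
# cProfile.run('profile_query(rag, "What is RAG?")', 'query.prof')
# pstats.Stats('query.prof').sort_stats('cumulative').print_stats(10)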

2. Retrieval Optimization

Vector Retrieval Optimization

  • Index optimization: choose an appropriate index type and parameters
  • Batch retrieval: process multiple queries at once
  • Caching: cache results for common queries
  • Parallelism: use multiple threads or asynchronous I/O

Implementation Example

python
# Batch retrieval example
from concurrent.futures import ThreadPoolExecutor

class BatchRetriever:
    def __init__(self, retriever, max_workers=4):
        self.retriever = retriever
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
    
    def retrieve_batch(self, queries, k=3):
        tasks = [self.executor.submit(self.retriever.retrieve, query, k=k) for query in queries]
        results = [task.result() for task in tasks]
        return results

# Cached retrieval (a simple bounded dict; entries are never evicted)

class CachedRetriever:
    def __init__(self, retriever, cache_size=1000):
        self.retriever = retriever
        self.cache = {}
        self.cache_size = cache_size
    
    def retrieve(self, query, k=3):
        cache_key = f"{query}_{k}"
        if cache_key in self.cache:
            return self.cache[cache_key]
        
        results = self.retriever.retrieve(query, k=k)
        
        if len(self.cache) < self.cache_size:
            self.cache[cache_key] = results
        
        return results
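
The cache above simply stops inserting once it reaches capacity, so a cold entry can occupy a slot forever. A minimal LRU-eviction variant, sketched here with collections.OrderedDict, keeps hot queries cached under sustained load:

python
from collections import OrderedDict

class LRUCachedRetriever:
    def __init__(self, retriever, cache_size=1000):
        self.retriever = retriever
        self.cache = OrderedDict()
        self.cache_size = cache_size
    
    def retrieve(self, query, k=3):
        cache_key = f"{query}_{k}"
        if cache_key in self.cache:
            self.cache.move_to_end(cache_key)  # mark as most recently used
            return self.cache[cache_key]
        
        results = self.retriever.retrieve(query, k=k)
        self.cache[cache_key] = results
        if len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)  # evict the least recently used entry
        return results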

Approximate Nearest Neighbor Optimization

python
import faiss

class OptimizedVectorStore:
    def __init__(self, dimension, nlist=100):
        self.dimension = dimension
        
        # Use an IVF index to speed up large-scale search; the metric must
        # match the inner-product quantizer
        quantizer = faiss.IndexFlatIP(dimension)
        self.index = faiss.IndexIVFFlat(
            quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT
        )
        self.index.nprobe = 10  # number of clusters probed at search time
    
    def train(self, vectors):
        """Train the index (required for IVF indexes before adding vectors)"""
        self.index.train(vectors)
    
    def add(self, vectors):
        """Add vectors"""
        self.index.add(vectors)
    
    def search(self, query, k=5):
        """Search for the k nearest neighbors"""
        return self.index.search(query, k)
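
A brief usage sketch; the dimension and vector counts are arbitrary, and the random float32 arrays stand in for real embeddings (inner-product search typically assumes L2-normalized vectors):

python
import numpy as np

store = OptimizedVectorStore(dimension=768, nlist=100)

vectors = np.random.random((10000, 768)).astype('float32')
faiss.normalize_L2(vectors)  # normalized vectors make inner product = cosine

store.train(vectors)  # IVF indexes must be trained before adding vectors
store.add(vectors)

distances, indices = store.search(vectors[:1], k=5)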

3. Generation Optimization

Streaming Generation

python
class StreamingGenerator:
    def __init__(self, llm):
        self.llm = llm
    
    def generate_stream(self, prompt):
        """Generate the answer as a stream"""
        response = self.llm.generate(
            prompt,
            stream=True  # enable streaming output
        )
        
        for chunk in response:
            yield chunk

# Usage example (llm is your model client)
generator = StreamingGenerator(llm)
for chunk in generator.generate_stream(prompt):
    print(chunk, end='', flush=True)

Prompt Optimization

python
class PromptOptimizer:
    def __init__(self):
        self.prompt_templates = {}
    
    def optimize_prompt(self, query, context):
        """Optimize the prompt to reduce token usage"""
        # Compress the context
        compressed_context = self.compress_context(context)
        
        # Use a structured prompt
        prompt = f"""Answer the question based on the following information (be concise):

Context:
{compressed_context}

Question: {query}

Answer:"""
        
        return prompt
    
    def compress_context(self, context, max_length=2000):
        """Compress the context to a character budget"""
        if len(context) <= max_length:
            return context
        
        # Keep leading sentences until the budget is exhausted
        sentences = context.split('.')
        compressed = []
        current_length = 0
        
        for sentence in sentences:
            if current_length + len(sentence) <= max_length:
                compressed.append(sentence)
                current_length += len(sentence)
            else:
                break
        
        return '.'.join(compressed) + '...'
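
Note that compress_context budgets by character count, while model context limits are measured in tokens. A token-based budget is usually more faithful; here is a minimal sketch using the tiktoken library, assuming the cl100k_base encoding:

python
import tiktoken

def compress_context_tokens(context, max_tokens=1500):
    """Truncate the context to a token budget rather than a character budget."""
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(context)
    if len(tokens) <= max_tokens:
        return context
    return enc.decode(tokens[:max_tokens]) + '...'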

4. Caching Strategies

Multi-Level Cache

python
import redis
import pickle
from functools import wraps

class MultiLevelCache:
    def __init__(self):
        # L1 cache: in-process memory (unbounded here; bound it in production)
        self.l1_cache = {}
        # L2 cache: Redis
        self.l2_cache = redis.Redis(host='localhost', port=6379, db=0)
    
    def get(self, key):
        """Get a cached value"""
        # Check L1 first
        if key in self.l1_cache:
            return self.l1_cache[key]
        
        # Then check L2
        value = self.l2_cache.get(key)
        if value:
            value = pickle.loads(value)
            # Backfill L1
            self.l1_cache[key] = value
            return value
        
        return None
    
    def set(self, key, value, l1_ttl=300, l2_ttl=3600):
        """Set a cached value (l1_ttl is not enforced by the plain dict)"""
        # Set L1
        self.l1_cache[key] = value
        
        # Set L2 with an expiry
        self.l2_cache.setex(
            key,
            l2_ttl,
            pickle.dumps(value)
        )

# Caching decorator
def cached(cache, ttl=3600):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key from the function name and arguments
            cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
            
            # Check the cache first
            result = cache.get(cache_key)
            if result is not None:
                return result
            
            # Execute the function
            result = func(*args, **kwargs)
            
            # Cache the result (ttl maps to the Redis L2 expiry)
            cache.set(cache_key, result, l2_ttl=ttl)
            
            return result
        return wrapper
    return decorator
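
A brief usage sketch of the decorator; expensive_vector_search is a hypothetical stand-in for a real retrieval call:

python
cache = MultiLevelCache()

@cached(cache, ttl=1800)
def retrieve_documents(query, k=3):
    # Placeholder for an expensive retrieval call
    return expensive_vector_search(query, k)

# The first call executes the search; repeated calls hit the cache
docs = retrieve_documents("What is RAG?", k=3)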

5. Asynchronous Processing

python
import asyncio
import os

import aiohttp

api_key = os.getenv("OPENAI_API_KEY")  # read the API key from the environment

class AsyncRAG:
    def __init__(self):
        self.session = None
    
    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.session.close()
    
    async def retrieve_async(self, query):
        """Retrieve asynchronously"""
        # Query multiple sources in parallel
        tasks = [
            self.search_vector_store(query),
            self.search_knowledge_graph(query),
            self.search_web(query)
        ]
        
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Merge results, skipping any source that failed
        merged_results = []
        for result in results:
            if not isinstance(result, Exception):
                merged_results.extend(result)
        
        return merged_results
    
    async def generate_async(self, prompt):
        """Generate asynchronously via the chat completions API"""
        async with self.session.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": prompt}]
            }
        ) as response:
            data = await response.json()
            return data['choices'][0]['message']['content']
    
    async def query(self, question):
        """Answer a question end to end, asynchronously"""
        # Retrieve asynchronously
        context = await self.retrieve_async(question)
        
        # Build the prompt (build_prompt is assumed to be implemented)
        prompt = self.build_prompt(question, context)
        
        # Generate asynchronously
        answer = await self.generate_async(prompt)
        
        return answer
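
A brief usage sketch; it assumes you have implemented the search_* helpers and build_prompt, which the class above leaves undefined:

python
async def main():
    async with AsyncRAG() as rag:
        answer = await rag.query("What is RAG?")
        print(answer)

asyncio.run(main())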

6. Load Balancing

python
import random

class LoadBalancer:
    def __init__(self, instances):
        self.instances = instances
        self.current_index = 0
    
    def get_instance_round_robin(self):
        """Select an instance via round-robin"""
        instance = self.instances[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.instances)
        return instance
    
    def get_instance_random(self):
        """Select an instance at random"""
        return random.choice(self.instances)
    
    def get_instance_least_connections(self, connection_counts):
        """Select the instance with the fewest active connections"""
        min_connections = min(connection_counts.values())
        candidates = [
            inst for inst, count in connection_counts.items()
            if count == min_connections
        ]
        return random.choice(candidates)
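
A brief usage sketch; the instance URLs are placeholders:

python
balancer = LoadBalancer([
    "http://rag-node-1:8000",
    "http://rag-node-2:8000",
    "http://rag-node-3:8000",
])

# Round-robin cycles through the instances in order
for _ in range(4):
    print(balancer.get_instance_round_robin())

# Least-connections picks among the idlest instances
counts = {
    "http://rag-node-1:8000": 2,
    "http://rag-node-2:8000": 0,
    "http://rag-node-3:8000": 5,
}
print(balancer.get_instance_least_connections(counts))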

7. Resource Management

Connection Pool

python
from contextlib import contextmanager
from queue import Queue

class ConnectionPool:
    def __init__(self, factory, max_size=10):
        self.factory = factory
        self.max_size = max_size
        self.pool = Queue(maxsize=max_size)
        self._fill_pool()
    
    def _fill_pool(self):
        """Pre-fill the pool with connections"""
        for _ in range(self.max_size):
            conn = self.factory()
            self.pool.put(conn)
    
    @contextmanager
    def acquire(self):
        """Borrow a connection; it is returned to the pool on exit"""
        conn = self.pool.get()
        try:
            yield conn
        finally:
            self.pool.put(conn)

# Usage example
import psycopg2

def create_db_connection():
    return psycopg2.connect(database="rag_db")

pool = ConnectionPool(create_db_connection, max_size=5)

with pool.acquire() as conn:
    # Use the connection
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM documents")

8. Monitoring and Alerting

python
import time
from dataclasses import dataclass
from typing import Dict, List
import statistics

@dataclass
class PerformanceMetrics:
    timestamp: float
    query_latency: float
    retrieval_latency: float
    generation_latency: float
    cache_hit_rate: float

class PerformanceMonitor:
    def __init__(self):
        self.metrics: List[PerformanceMetrics] = []
        self.alert_thresholds = {
            'query_latency': 2.0,  # seconds
            'error_rate': 0.05     # 5%
        }
    
    def record(self, metrics: PerformanceMetrics):
        """Record a set of performance metrics"""
        self.metrics.append(metrics)
        
        # Check whether an alert should fire
        self._check_alerts(metrics)
    
    def _check_alerts(self, metrics: PerformanceMetrics):
        """Check alert conditions"""
        if metrics.query_latency > self.alert_thresholds['query_latency']:
            self._send_alert(
                f"Query latency too high: {metrics.query_latency:.2f}s"
            )
    
    def get_statistics(self, window_minutes=5) -> Dict:
        """Compute statistics over a recent time window"""
        cutoff = time.time() - window_minutes * 60
        recent_metrics = [m for m in self.metrics if m.timestamp > cutoff]
        
        if not recent_metrics:
            return {}
        
        latencies = [m.query_latency for m in recent_metrics]
        
        return {
            'count': len(recent_metrics),
            'avg_latency': statistics.mean(latencies),
            'p50_latency': statistics.median(latencies),
            'p95_latency': sorted(latencies)[int(len(latencies) * 0.95)],
            'p99_latency': sorted(latencies)[int(len(latencies) * 0.99)]
        }
    
    def _send_alert(self, message: str):
        """Send an alert"""
        # Integrate with an alerting system (e.g. PagerDuty or Slack)
        print(f"ALERT: {message}")

9. Performance Testing

python
import asyncio
import random
import statistics
import time
from typing import List

class PerformanceTester:
    def __init__(self, rag_system):
        self.rag = rag_system
    
    async def run_load_test(
        self,
        queries: List[str],
        concurrency: int = 10,
        duration_seconds: int = 60
    ):
        """运行负载测试"""
        results = []
        start_time = time.time()
        
        async def worker():
            while time.time() - start_time < duration_seconds:
                query = random.choice(queries)
                start = time.time()
                try:
                    await self.rag.query(query)
                    latency = time.time() - start
                    results.append({'success': True, 'latency': latency})
                except Exception as e:
                    results.append({'success': False, 'error': str(e)})
        
        # Launch concurrent workers (coroutines sharing one event loop)
        tasks = [worker() for _ in range(concurrency)]
        await asyncio.gather(*tasks)
        
        # Compute metrics
        successful = [r for r in results if r['success']]
        failed = [r for r in results if not r['success']]
        
        latencies = [r['latency'] for r in successful]
        
        return {
            'total_requests': len(results),
            'successful_requests': len(successful),
            'failed_requests': len(failed),
            'error_rate': len(failed) / max(len(results), 1),
            'avg_latency': statistics.mean(latencies) if latencies else 0,
            'p95_latency': sorted(latencies)[int(len(latencies) * 0.95)] if latencies else 0,
            'throughput': len(results) / duration_seconds
        }
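
A brief sketch of driving the load test; rag_system is a placeholder for an object exposing an async query() method:

python
async def main():
    tester = PerformanceTester(rag_system)  # rag_system: your RAG pipeline
    report = await tester.run_load_test(
        queries=["What is RAG?", "How does vector search work?"],
        concurrency=10,
        duration_seconds=30,
    )
    print(report)

asyncio.run(main())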