# 性能优化
RAG系统的性能直接影响用户体验和系统的实用性。通过合理的性能优化,可以显著提高RAG系统的响应速度、准确性和可靠性。本章节将详细介绍RAG系统的性能优化策略和最佳实践。
1. 性能瓶颈分析
常见瓶颈
- 检索延迟:向量检索的计算开销
- 生成延迟:LLM生成回答的时间
- 数据处理:文档加载和分块的时间
- 内存使用:内存不足导致的性能下降
- 网络延迟:API调用的网络开销
分析工具
- 性能分析器:cProfile、py-spy等
- 监控系统:Prometheus、Grafana等
- 日志分析:ELK Stack、Splunk等
- 负载测试:JMeter、Locust等
2. 检索优化
向量检索优化
- 索引优化:选择合适的索引类型和参数
- 批量检索:一次处理多个查询
- 缓存机制:缓存常见查询的结果
- 并行处理:使用多线程或异步IO
实现示例
python
# 批量检索示例
from concurrent.futures import ThreadPoolExecutor
class BatchRetriever:
    """Fan a batch of queries out to a shared thread pool and collect results in order."""

    def __init__(self, retriever, max_workers=4):
        self.retriever = retriever
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def retrieve_batch(self, queries, k=3):
        """Retrieve top-k results for every query; output order matches input order."""
        futures = [
            self.executor.submit(self.retriever.retrieve, q, k=k)
            for q in queries
        ]
        # Blocking on .result() in submission order keeps results aligned with queries.
        return [f.result() for f in futures]
# 使用缓存
from functools import lru_cache
class CachedRetriever:
    """Wrap a retriever with a simple bounded memo cache keyed by (query, k).

    NOTE(review): entries are never evicted — once the cache is full, new
    results are simply not cached. Confirm this is intentional for
    long-running processes.
    """

    def __init__(self, retriever, cache_size=1000):
        self.retriever = retriever
        self.cache = {}
        self.cache_size = cache_size

    def retrieve(self, query, k=3):
        """Return cached results when available, otherwise delegate and memoize."""
        key = f"{query}_{k}"
        if key in self.cache:
            return self.cache[key]
        results = self.retriever.retrieve(query, k=k)
        # Only insert while there is room; never evict.
        if len(self.cache) < self.cache_size:
            self.cache[key] = results
        return results

# 近似最近邻优化 (approximate nearest-neighbor optimization)
python
import faiss
class OptimizedVectorStore:
    """Vector store backed by a FAISS IVF index to speed up large-scale search."""

    def __init__(self, dimension, nlist=100):
        self.dimension = dimension
        # IVF index: vectors are clustered into `nlist` cells so a query only
        # probes a few cells instead of scanning everything.
        quantizer = faiss.IndexFlatIP(dimension)
        self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
        # Number of clusters visited per query (recall vs. speed trade-off).
        self.index.nprobe = 10

    def train(self, vectors):
        """Train the coarse quantizer on a representative sample of vectors."""
        self.index.train(vectors)

    def add(self, vectors):
        """Add vectors to the index (index must be trained first)."""
        self.index.add(vectors)

    def search(self, query, k=5):
        """Delegate to the FAISS index search for the k nearest neighbors."""
        return self.index.search(query, k)

# 3. 生成优化 (generation optimization)
流式生成
python
class StreamingGenerator:
    """Yield an LLM answer chunk by chunk instead of waiting for the full text."""

    def __init__(self, llm):
        self.llm = llm

    def generate_stream(self, prompt):
        """Generator over response chunks; requires an LLM client honoring stream=True."""
        yield from self.llm.generate(prompt, stream=True)
# Usage example: stream chunks to stdout as they arrive.
# NOTE(review): relies on `llm` and `prompt` being defined elsewhere in the document.
generator = StreamingGenerator(llm)
for chunk in generator.generate_stream(prompt):
    print(chunk, end='', flush=True)

# 提示优化 (prompt optimization)
python
class PromptOptimizer:
    """Build compact prompts to cut token usage before calling the LLM."""

    def __init__(self):
        self.prompt_templates = {}

    def optimize_prompt(self, query, context):
        """Return a structured prompt containing a compressed version of `context`."""
        compressed_context = self.compress_context(context)
        prompt = f"""基于以下信息回答问题(保持简洁):
上下文:
{compressed_context}
问题:{query}
回答:"""
        return prompt

    def compress_context(self, context, max_length=2000):
        """Trim `context` to at most max_length chars, keeping whole sentences.

        NOTE(review): sentences are split on '.', which does not match the
        Chinese full stop '。' — confirm this is acceptable for Chinese text.
        """
        if len(context) <= max_length:
            return context
        kept = []
        budget = max_length
        for sentence in context.split('.'):
            if len(sentence) > budget:
                break  # stop at the first sentence that no longer fits
            kept.append(sentence)
            budget -= len(sentence)
        return '.'.join(kept) + '...'

# 4. 缓存策略 (caching strategy)
多级缓存
python
import redis
import pickle
from functools import wraps
class MultiLevelCache:
    """Two-tier cache: in-process dict (L1) backed by Redis (L2)."""

    def __init__(self):
        # L1: plain dict in local memory (no TTL, no eviction).
        self.l1_cache = {}
        # L2: shared Redis instance; values are stored pickled.
        self.l2_cache = redis.Redis(host='localhost', port=6379, db=0)

    def get(self, key):
        """Look up `key` in L1 first, then L2; L2 hits are promoted into L1."""
        try:
            return self.l1_cache[key]
        except KeyError:
            pass
        raw = self.l2_cache.get(key)
        if raw:
            value = pickle.loads(raw)
            self.l1_cache[key] = value  # backfill L1 for next time
            return value
        return None

    def set(self, key, value, l1_ttl=300, l2_ttl=3600):
        """Store `value` in both tiers.

        NOTE(review): `l1_ttl` is currently unused — L1 entries never expire.
        """
        self.l1_cache[key] = value
        self.l2_cache.setex(key, l2_ttl, pickle.dumps(value))
# Decorator: transparent read-through caching for arbitrary functions.
def cached(cache, ttl=3600):
    """Decorator: memoize a function's results in `cache` for `ttl` seconds.

    The cache key combines the function name with the repr of its arguments,
    so arguments must have stable reprs. A result equal to None is treated
    as a cache miss and recomputed on every call.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
            result = cache.get(cache_key)
            if result is not None:
                return result
            result = func(*args, **kwargs)
            # BUG FIX: MultiLevelCache.set() has no `ttl` parameter — the
            # original `cache.set(key, result, ttl=ttl)` raised TypeError on
            # every cache miss. Map the decorator's ttl onto the L2 expiry.
            cache.set(cache_key, result, l2_ttl=ttl)
            return result
        return wrapper
    return decorator

# 5. 异步处理 (async processing)
python
import asyncio
import aiohttp
class AsyncRAG:
    """RAG pipeline whose retrieval and generation steps run asynchronously."""

    def __init__(self):
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.session.close()

    async def retrieve_async(self, query):
        """Query all sources concurrently and merge the successful results."""
        outcomes = await asyncio.gather(
            self.search_vector_store(query),
            self.search_knowledge_graph(query),
            self.search_web(query),
            return_exceptions=True,  # one failing source must not sink the batch
        )
        merged = []
        for outcome in outcomes:
            if not isinstance(outcome, Exception):
                merged.extend(outcome)
        return merged

    async def generate_async(self, prompt):
        """POST the prompt to the chat-completions API and return the answer text."""
        # NOTE(review): `api_key` is read from enclosing scope — confirm it is defined.
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": prompt}]
        }
        async with self.session.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json=payload
        ) as response:
            data = await response.json()
        return data['choices'][0]['message']['content']

    async def query(self, question):
        """Full async pipeline: retrieve -> build prompt -> generate."""
        context = await self.retrieve_async(question)
        prompt = self.build_prompt(question, context)
        return await self.generate_async(prompt)

# 6. 负载均衡 (load balancing)
python
import random
class LoadBalancer:
    """Pick a backend instance via round-robin, random, or least-connections."""

    def __init__(self, instances):
        self.instances = instances
        self.current_index = 0

    def get_instance_round_robin(self):
        """Cycle through instances in order, wrapping at the end."""
        chosen = self.instances[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.instances)
        return chosen

    def get_instance_random(self):
        """Pick a uniformly random instance."""
        return random.choice(self.instances)

    def get_instance_least_connections(self, connection_counts):
        """Pick randomly among the instances with the fewest active connections."""
        fewest = min(connection_counts.values())
        tied = [inst for inst, count in connection_counts.items() if count == fewest]
        return random.choice(tied)

# 7. 资源管理 (resource management)
连接池
python
from contextlib import contextmanager
from queue import Queue
class ConnectionPool:
    """Fixed-size pool of reusable connections created eagerly via `factory`."""

    def __init__(self, factory, max_size=10):
        self.factory = factory
        self.max_size = max_size
        self.pool = Queue(maxsize=max_size)
        self._fill_pool()

    def _fill_pool(self):
        # Eagerly create every connection up front.
        for _ in range(self.max_size):
            self.pool.put(self.factory())

    @contextmanager
    def acquire(self):
        """Check a connection out of the pool; it is returned even on error."""
        conn = self.pool.get()
        try:
            yield conn
        finally:
            self.pool.put(conn)
# Usage example: a small PostgreSQL connection pool.
def create_db_connection():
    # NOTE(review): relies on a `psycopg2` import not shown in this snippet.
    return psycopg2.connect(database="rag_db")

pool = ConnectionPool(create_db_connection, max_size=5)
with pool.acquire() as conn:
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM documents")

# 8. 监控与告警 (monitoring and alerting)
python
import time
from dataclasses import dataclass
from typing import Dict, List
import statistics
@dataclass
class PerformanceMetrics:
    """One sample of end-to-end query performance."""
    timestamp: float           # epoch seconds when the sample was taken
    query_latency: float       # total request latency, seconds
    retrieval_latency: float   # time spent in retrieval, seconds
    generation_latency: float  # time spent in LLM generation, seconds
    cache_hit_rate: float      # fraction of cache hits, 0.0-1.0


class PerformanceMonitor:
    """Collect performance samples, raise alerts, and summarize recent latency."""

    def __init__(self):
        self.metrics: List[PerformanceMetrics] = []
        self.alert_thresholds = {
            'query_latency': 2.0,  # seconds
            'error_rate': 0.05     # 5%
        }

    def record(self, metrics: PerformanceMetrics):
        """Append a sample and evaluate alert conditions for it."""
        self.metrics.append(metrics)
        self._check_alerts(metrics)

    def _check_alerts(self, metrics: PerformanceMetrics):
        """Fire an alert if this sample crosses the latency threshold.

        NOTE: only query_latency is checked here; the error_rate threshold
        has no corresponding check yet.
        """
        if metrics.query_latency > self.alert_thresholds['query_latency']:
            self._send_alert(
                f"查询延迟过高: {metrics.query_latency:.2f}s"
            )

    def get_statistics(self, window_minutes=5) -> Dict:
        """Summarize query latency over the trailing `window_minutes` window.

        Returns an empty dict when no samples fall inside the window.
        """
        cutoff = time.time() - window_minutes * 60
        recent = [m for m in self.metrics if m.timestamp > cutoff]
        if not recent:
            return {}
        latencies = [m.query_latency for m in recent]
        # Sort once and reuse for both percentile reads (the original sorted twice).
        ordered = sorted(latencies)
        n = len(ordered)
        return {
            'count': n,
            'avg_latency': statistics.mean(latencies),
            'p50_latency': statistics.median(latencies),
            'p95_latency': ordered[int(n * 0.95)],
            'p99_latency': ordered[int(n * 0.99)]
        }

    def _send_alert(self, message: str):
        """Deliver an alert (stub — integrate PagerDuty/Slack/etc. here)."""
        print(f"ALERT: {message}")

# 9. 性能测试 (performance testing)
python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
class PerformanceTester:
    """Drive concurrent load against a RAG system and report latency/throughput."""

    def __init__(self, rag_system):
        self.rag = rag_system

    async def run_load_test(
        self,
        queries: List[str],
        concurrency: int = 10,
        duration_seconds: int = 60
    ):
        """Run `concurrency` async workers for `duration_seconds` and aggregate results.

        Each worker repeatedly picks a random query and awaits the RAG system's
        `query` coroutine. Returns a dict of request counts, error rate,
        latency statistics, and throughput.
        """
        results = []
        start_time = time.time()

        async def worker():
            while time.time() - start_time < duration_seconds:
                query = random.choice(queries)
                began = time.time()
                try:
                    await self.rag.query(query)
                    results.append({'success': True, 'latency': time.time() - began})
                except Exception as e:
                    # Record the failure and keep going.
                    results.append({'success': False, 'error': str(e)})

        await asyncio.gather(*(worker() for _ in range(concurrency)))

        successful = [r for r in results if r['success']]
        failed = [r for r in results if not r['success']]
        latencies = [r['latency'] for r in successful]
        total = len(results)
        return {
            'total_requests': total,
            'successful_requests': len(successful),
            'failed_requests': len(failed),
            # BUG FIX: the original divided by len(results) unconditionally,
            # raising ZeroDivisionError when no request completed in the window.
            'error_rate': len(failed) / total if total else 0.0,
            'avg_latency': statistics.mean(latencies) if latencies else 0,
            'p95_latency': sorted(latencies)[int(len(latencies) * 0.95)] if latencies else 0,
            'throughput': total / duration_seconds if duration_seconds else 0.0
        }