Appearance
安全最佳实践
安全性是RAG系统部署和维护的重要考虑因素。通过实施安全最佳实践,可以保护系统免受各种安全威胁,确保数据和用户信息的安全。本章节将详细介绍RAG系统的安全策略、措施和最佳实践。
1. 安全威胁分析
常见安全威胁
- 数据泄露:敏感数据被未授权访问或泄露
- API滥用:API被恶意使用或滥用
- 注入攻击:SQL注入、命令注入,以及针对大模型的提示注入(Prompt Injection)等
- 跨站脚本:XSS攻击
- 跨站请求伪造:CSRF攻击
- 拒绝服务:DoS/DDoS攻击
- 权限提升:未授权用户获取更高权限
- 供应链攻击:依赖库或组件被攻击
威胁评估
- 风险识别:识别系统面临的安全风险
- 风险分析:分析风险的可能性和影响
- 风险评估:综合可能性与影响评定风险等级,确定处理优先级
- 风险缓解:制定风险缓解策略
2. 数据安全
数据加密
- 传输加密:使用HTTPS加密传输数据
- 存储加密:加密存储敏感数据
- 端到端加密:数据从发送方到最终接收方全程保持加密,中间节点(包括转发服务器)无法解密
数据脱敏
- 敏感信息脱敏:对身份证号、手机号、邮箱等敏感信息进行掩码或替换(例如将手机号显示为138****5678)
- 数据访问控制:限制对敏感数据的访问
- 数据生命周期管理:管理数据的创建、使用和销毁
数据备份与恢复
- 定期备份:按固定周期(如每日)备份数据,包括原始文档、向量索引和系统配置
- 异地备份:在不同地点备份数据
- 备份验证:验证备份的有效性
- 恢复测试:定期测试数据恢复流程
3. 访问控制
身份认证
- 多因素认证:在密码之外增加第二种验证因素(如一次性验证码),提高账户安全性。以下示例演示基础的密码哈希与JWT令牌认证,多因素认证可在此基础上扩展
python
# auth.py
from functools import wraps
from flask import request, jsonify, current_app
import jwt
import bcrypt
from datetime import datetime, timedelta
class AuthManager:
    """Password hashing plus JWT issuance/verification for a Flask app.

    Tokens are HS256-signed with ``secret_key`` and carry ``user_id``,
    ``exp`` and ``iat`` claims.
    """

    def __init__(self, secret_key):
        self.secret_key = secret_key
        # In-memory user store keyed by username; a real deployment should use a database.
        self.users = {}

    def hash_password(self, password):
        """Return a bcrypt hash (bytes) of ``password`` using a fresh random salt."""
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt())

    def verify_password(self, password, hashed):
        """Return True if ``password`` matches the bcrypt hash ``hashed``.

        ``hashed`` may be the ``bytes`` produced by ``hash_password`` or the
        same value stored as ``str`` (e.g. read back from a database).
        """
        if isinstance(hashed, str):
            hashed = hashed.encode()
        return bcrypt.checkpw(password.encode(), hashed)

    def generate_token(self, user_id, expires_in=3600):
        """Return a signed JWT for ``user_id`` valid for ``expires_in`` seconds."""
        from datetime import timezone

        # One timezone-aware UTC reading for both claims (datetime.utcnow() is
        # deprecated since Python 3.12 and two reads could disagree).
        now = datetime.now(timezone.utc)
        payload = {
            'user_id': user_id,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def verify_token(self, token):
        """Decode ``token`` and return its payload dict, or None if it is not acceptable.

        Expired, malformed or wrongly-signed tokens, and tokens missing the
        ``exp`` or ``user_id`` claims, all yield None.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=['HS256'],  # pin the algorithm; never trust the token header
                options={'require': ['exp']},
            )
        except jwt.InvalidTokenError:  # also covers ExpiredSignatureError (a subclass)
            return None
        if 'user_id' not in payload:
            return None
        return payload

    def login_required(self, f):
        """Decorator: answer 401 unless the request carries a valid Bearer token.

        On success the token's ``user_id`` is stored on ``request.user_id``
        before the wrapped view is called.
        """
        @wraps(f)
        def decorated_function(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            if not auth_header:
                return jsonify({'error': 'Token is missing'}), 401
            # Expect "Bearer <token>"; the scheme name is case-insensitive.
            scheme, _, token = auth_header.partition(' ')
            token = token.strip()
            if scheme.lower() != 'bearer' or not token:
                return jsonify({'error': 'Invalid token format'}), 401
            payload = self.verify_token(token)
            if not payload:
                return jsonify({'error': 'Token is invalid or expired'}), 401
            request.user_id = payload['user_id']
            return f(*args, **kwargs)
        return decorated_function
import os

# Usage example
# Read the JWT signing key from the environment: a key committed to source
# lets anyone who can read the code forge valid tokens.
auth = AuthManager(secret_key=os.environ['JWT_SECRET_KEY'])


@app.route('/api/login', methods=['POST'])
def login():
    """Exchange a JSON ``{"username", "password"}`` body for a JWT."""
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({'error': 'username and password are required'}), 400
    # Verify the user's credentials
    user = auth.users.get(username)
    if user and auth.verify_password(password, user['password']):
        token = auth.generate_token(user['id'])
        return jsonify({'token': token})
    # Same message for unknown user and wrong password so usernames cannot be
    # probed via the response body (response timing still differs).
    return jsonify({'error': 'Invalid credentials'}), 401


@app.route('/api/protected')
@auth.login_required
def protected():
    """Example endpoint reachable only with a valid Bearer token."""
    return jsonify({'message': 'Access granted', 'user_id': request.user_id})


### 权限控制
python
# rbac.py
from functools import wraps
from flask import request, jsonify
class RBAC:
    """Minimal role-based access control: each user has one role, each role a set of permissions."""

    def __init__(self, roles=None):
        # role name -> permissions granted to that role (defaults shown below)
        if roles is None:
            roles = {
                'admin': ['read', 'write', 'delete', 'manage'],
                'user': ['read', 'write'],
                'guest': ['read'],
            }
        self.roles = dict(roles)
        # user_id -> role name
        self.user_roles = {}

    def assign_role(self, user_id, role):
        """Give ``user_id`` the role ``role``.

        Raises:
            ValueError: if ``role`` is not a known role.
        """
        if role not in self.roles:
            raise ValueError(f"Invalid role: {role}")
        self.user_roles[user_id] = role

    def check_permission(self, user_id, permission):
        """Return True if the role assigned to ``user_id`` grants ``permission``."""
        role = self.user_roles.get(user_id)
        if role is None:
            return False
        return permission in self.roles.get(role, ())

    def require_permission(self, permission):
        """Decorator factory: 401 without an authenticated user, 403 without ``permission``.

        Reads ``request.user_id``, so it must run after an authentication
        decorator that sets it (e.g. ``AuthManager.login_required``).
        """
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                user_id = getattr(request, 'user_id', None)
                # `is None`, not falsiness: 0 or '' can be legitimate user ids.
                if user_id is None:
                    return jsonify({'error': 'Authentication required'}), 401
                if not self.check_permission(user_id, permission):
                    return jsonify({'error': 'Permission denied'}), 403
                return f(*args, **kwargs)
            return wrapper
        return decorator
# Usage example
rbac = RBAC()
# Users without a role are denied (check_permission returns False), so grant
# roles with rbac.assign_role(user_id, role) when users are created/loaded.


@app.route('/api/admin-only')
@auth.login_required  # outermost: sets request.user_id before the permission check
@rbac.require_permission('manage')
def admin_only():
    """Endpoint restricted to roles holding the 'manage' permission."""
    return jsonify({'message': 'Admin access granted'})


## 4. API安全
速率限制
python
# rate_limiter.py
import time
from functools import wraps
from flask import request, jsonify
class RateLimiter:
    """In-memory sliding-window rate limiter.

    A key (the client IP by default) may make at most ``max_requests``
    requests in any rolling ``window``-second interval. State lives in this
    process only: each worker keeps its own counters and they reset on restart.
    """

    def __init__(self, max_requests=100, window=3600, clock=None):
        import threading

        self.max_requests = max_requests
        self.window = window
        # Time source in seconds. Monotonic by default so system clock
        # adjustments cannot shrink or stretch the window; injectable for tests.
        self.clock = clock if clock is not None else time.monotonic
        # key -> timestamps (from self.clock) of accepted requests still in the window
        self.requests = {}
        # Flask may serve requests on several threads; guard the shared dict.
        self._lock = threading.Lock()
        self._last_sweep = self.clock()

    def _live(self, timestamps, now):
        """Return the timestamps that are still inside the window at ``now``."""
        return [ts for ts in timestamps if now - ts < self.window]

    def _sweep(self, now):
        """Forget keys with no requests left in the window so memory stays bounded."""
        for key in list(self.requests):
            live = self._live(self.requests[key], now)
            if live:
                self.requests[key] = live
            else:
                del self.requests[key]
        self._last_sweep = now

    def is_allowed(self, key):
        """Record a request for ``key`` and return True, or return False if it is over the limit."""
        now = self.clock()
        with self._lock:
            # A key is otherwise only pruned when it is seen again, so clients
            # that never return would accumulate forever; sweep once per window.
            if now - self._last_sweep >= self.window:
                self._sweep(now)
            live = self._live(self.requests.get(key, ()), now)
            self.requests[key] = live
            if len(live) >= self.max_requests:
                return False
            live.append(now)
            return True

    def retry_after(self, key):
        """Return whole seconds until ``key`` may make another request (0 means now)."""
        import math

        now = self.clock()
        with self._lock:
            live = sorted(self._live(self.requests.get(key, ()), now))
        if self.max_requests <= 0:
            # No request is ever allowed; report one full window.
            return math.ceil(self.window)
        if len(live) < self.max_requests:
            return 0
        # A slot frees up once enough of the oldest requests have expired to
        # bring the count below the limit.
        blocking = live[len(live) - self.max_requests]
        return max(0, math.ceil(self.window - (now - blocking)))

    def limit(self, key_func=None):
        """Decorator factory: answer 429 once the caller's key exceeds the limit.

        ``key_func`` returns the bucket key for the current request; when it is
        not given, the client IP (``request.remote_addr``) is used.
        """
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                key = key_func() if key_func else request.remote_addr
                if not self.is_allowed(key):
                    wait = self.retry_after(key)
                    response = jsonify({
                        'error': 'Rate limit exceeded',
                        'retry_after': wait,
                    })
                    # Standard HTTP header so generic clients know when to retry.
                    response.headers['Retry-After'] = str(wait)
                    return response, 429
                return f(*args, **kwargs)
            return wrapper
        return decorator
# Usage example
limiter = RateLimiter(max_requests=100, window=3600)  # 100 requests per hour per key


# NOTE(security): X-API-Key is supplied by the client. Unless the key is
# validated before limiting, a client can send a new arbitrary value per
# request to get a fresh bucket and bypass the limit; key on a verified
# identity (or on remote_addr) instead.
@app.route('/api/query')
@limiter.limit(key_func=lambda: request.headers.get('X-API-Key', request.remote_addr))
def query():
    """Rate-limited query endpoint (placeholder)."""
    # Handle the query here. A view must return a response; returning None
    # (a bare `pass`) makes Flask raise an error, so answer 501 for now.
    return jsonify({'error': 'Not implemented'}), 501


### 输入验证
python
# input_validation.py
import re
from flask import request, jsonify
from marshmallow import Schema, fields, validate, ValidationError
class QuerySchema(Schema):
    """Marshmallow schema describing the JSON body of a query request."""

    # Mandatory question text, between 1 and 1000 characters.
    question = fields.String(
        metadata={'description': '查询问题'},
        validate=validate.Length(min=1, max=1000),
        required=True,
    )

    # Optional supporting context, up to 5000 characters; '' when omitted.
    context = fields.String(
        metadata={'description': '上下文信息'},
        validate=validate.Length(max=5000),
        load_default='',
    )

    # Optional cap on the number of results, 1 to 10; 5 when omitted.
    max_results = fields.Integer(
        metadata={'description': '最大结果数'},
        validate=validate.Range(min=1, max=10),
        load_default=5,
    )
class InputValidator:
    """Validation and light sanitization for incoming query requests."""

    @staticmethod
    def sanitize_input(text):
        """Remove the characters < > & " ' and trim surrounding whitespace.

        ``None`` is treated as an empty string.

        NOTE: this is lossy ("don't" -> "dont", "A & B" -> "A  B") and only a
        coarse defence; escape data for its output context (HTML, SQL
        parameters, ...) rather than relying on it.
        """
        if text is None:
            return ''
        return re.sub(r'[<>&\"\']', '', text).strip()

    @staticmethod
    def validate_query(data):
        """Validate ``data`` against QuerySchema and sanitize its text fields.

        Returns:
            ``(result, None)`` on success, or ``(None, errors)`` where
            ``errors`` maps field names to lists of messages.
        """
        try:
            result = QuerySchema().load(data)
        except ValidationError as err:
            return None, err.messages
        result['question'] = InputValidator.sanitize_input(result['question'])
        result['context'] = InputValidator.sanitize_input(result['context'])
        # Sanitizing can empty a question that passed Length(min=1), e.g. "<>" or "   ".
        if not result['question']:
            return None, {'question': ['Question is empty after removing disallowed characters.']}
        return result, None
# Usage example
@app.route('/api/query', methods=['POST'])
def query():
    """Query endpoint that validates and sanitizes its JSON body first."""
    # silent=True: a missing or malformed JSON body becomes None, which the
    # schema rejects, so the client gets a 400 with details.
    data = request.get_json(silent=True)
    validated_data, errors = InputValidator.validate_query(data)
    if errors:
        return jsonify({'error': 'Invalid input', 'details': errors}), 400
    # Process the validated data.
    # NOTE(review): 'context' and 'max_results' are validated but not forwarded;
    # pass them on if rag_system.query accepts them.
    result = rag_system.query(validated_data['question'])
    return jsonify(result)


### CORS配置
python
# cors_config.py
from flask_cors import CORS
# Basic CORS configuration: only the listed origins may call /api/* from a browser.
cors = CORS(resources={
    r"/api/*": {
        "origins": ["https://yourdomain.com", "https://app.yourdomain.com"],
        "methods": ["GET", "POST", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization"],
        # Lets browsers send cookies/credentials; this requires an explicit origin
        # list (never "*"). Bearer tokens set in the Authorization header only
        # need allow_headers, so disable this if cookies are not used.
        "supports_credentials": True,
        # Seconds a browser may cache the preflight response.
        "max_age": 3600
    }
})

# Attach the configuration to the Flask app.
cors.init_app(app)


## 5. 模型安全
提示注入防护
python
# prompt_security.py
import re
class PromptSecurity:
    """Heuristic prompt-injection screening and basic prompt cleanup.

    Pattern matching is only a coarse first line of defence; paraphrased or
    encoded injections will not be caught.
    """

    def __init__(self, max_length=4000):
        # Regexes searched (re.search) in the normalized, lower-cased prompt.
        self.forbidden_patterns = [
            r'ignore\s+previous\s+instructions',
            r'disregard\s+all\s+prior',
            r'system\s+prompt',
            r'you\s+are\s+now',
            r'forget\s+everything',
            r'act\s+as\s+if',
        ]
        # Prompts longer than this many characters are truncated by sanitize_prompt.
        self.max_length = max_length

    @staticmethod
    def _normalize(text):
        """Fold ``text`` so simple obfuscations do not defeat the pattern check.

        NFKC maps compatibility forms (e.g. full-width letters) to their plain
        equivalents, and control/format characters that are not whitespace
        (NUL, zero-width space, ...) are dropped so they cannot split a phrase.
        """
        import unicodedata

        text = unicodedata.normalize('NFKC', text)
        kept = (
            ch for ch in text
            if ch.isspace() or unicodedata.category(ch) not in ('Cc', 'Cf')
        )
        return ''.join(kept).lower()

    def check_prompt(self, prompt):
        """Return (True, None) if no forbidden pattern matches, else (False, message)."""
        normalized = self._normalize(prompt)
        for pattern in self.forbidden_patterns:
            if re.search(pattern, normalized):
                return False, f"Detected potential injection: {pattern}"
        return True, None

    def sanitize_prompt(self, prompt):
        """Remove ASCII control characters (keeping tab, LF and CR) and truncate to max_length."""
        prompt = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f]', '', prompt)
        return prompt[:self.max_length]
# Usage example
security = PromptSecurity()


@app.route('/api/query', methods=['POST'])
def query():
    """Answer a question after prompt-injection screening."""
    data = request.get_json(silent=True)
    question = data.get('question', '') if isinstance(data, dict) else ''
    if not isinstance(question, str):
        return jsonify({'error': 'Invalid input', 'message': 'question must be a string'}), 400
    # Clean first, then check: otherwise a control character that
    # sanitize_prompt strips (e.g. a NUL byte between two words) can hide a
    # forbidden phrase from check_prompt and the phrase reaches the model intact.
    clean_question = security.sanitize_prompt(question)
    is_safe, error = security.check_prompt(clean_question)
    if not is_safe:
        # Log which pattern matched but do not echo it: that would tell an
        # attacker exactly which phrasing to avoid.
        app.logger.warning('Blocked query: %s', error)
        return jsonify({'error': 'Invalid input', 'message': 'Potential prompt injection detected'}), 400
    # Process the query
    result = rag_system.query(clean_question)
    return jsonify(result)


### 输出过滤
python
# output_filter.py
import re
class OutputFilter:
    """Redact common personal data from model output and screen for harmful keywords."""

    def __init__(self):
        # Digit runs are delimited with (?<!\d) / (?!\d) instead of \b: in Python 3
        # CJK characters count as word characters, so \b finds no boundary when a
        # Chinese character directly precedes the digits and the number would leak.
        self.sensitive_patterns = [
            r'(?<!\d)\d{16}(?!\d)',  # credit card number (16 consecutive digits)
            r'(?<!\d)\d{3}-\d{2}-\d{4}(?!\d)',  # US SSN
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',  # email address
            r'(?<!\d)\d{11}(?!\d)',  # mobile phone number (11 digits)
        ]

    def filter_output(self, text):
        """Return ``text`` with every match of ``sensitive_patterns`` replaced by '[REDACTED]'."""
        filtered = text
        for pattern in self.sensitive_patterns:
            filtered = re.sub(pattern, '[REDACTED]', filtered)
        return filtered

    def check_harmful_content(self, text):
        """Return (True, None) if no harmful keyword occurs, else (False, message).

        Simple keyword matching only; a dedicated moderation API should be used
        in production.
        """
        harmful_keywords = ['暴力', '色情', '歧视', '仇恨']
        lowered = text.lower()
        for keyword in harmful_keywords:
            if keyword in lowered:
                return False, f"Harmful content detected: {keyword}"
        return True, None
# Usage example
output_filter = OutputFilter()


@app.route('/api/query', methods=['POST'])
def query():
    """Answer a question, redacting personal data and blocking harmful output."""
    # ... validate the request and obtain `question` (see the examples above) ...
    result = rag_system.query(question)
    # Redact sensitive data from the answer
    filtered_answer = output_filter.filter_output(result['answer'])
    # Check for harmful content
    is_safe, error = output_filter.check_harmful_content(filtered_answer)
    if not is_safe:
        return jsonify({'error': 'Content filtered', 'message': error}), 400
    return jsonify({'answer': filtered_answer})


## 6. 日志与审计
python
# audit_logging.py
import logging
import json
from datetime import datetime
from functools import wraps
class AuditLogger:
    """Write one JSON audit record per audited API access to a dedicated log file."""

    def __init__(self, log_file='logs/audit.log'):
        import os

        self.logger = logging.getLogger('audit')
        self.logger.setLevel(logging.INFO)
        path = os.path.abspath(log_file)
        # All instances share the 'audit' logger; attach one handler per file
        # so each record is not written several times.
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == path
            for h in self.logger.handlers
        )
        if not already_attached:
            # FileHandler does not create missing parent directories.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # Explicit UTF-8: records are dumped with ensure_ascii=False.
            handler = logging.FileHandler(path, encoding='utf-8')
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(message)s'
            ))
            self.logger.addHandler(handler)

    def log_access(self, user_id, action, resource, result):
        """Log one audit record for the current Flask request.

        Must be called inside a request context: the client IP and User-Agent
        are read from ``flask.request``.
        """
        from datetime import timezone

        from flask import request

        log_entry = {
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'user_id': user_id,
            'action': action,
            'resource': resource,
            'result': result,
            'ip_address': request.remote_addr,
            'user_agent': request.user_agent.string
        }
        self.logger.info(json.dumps(log_entry, ensure_ascii=False))

    @staticmethod
    def _status_of(result):
        """Best-effort HTTP status of a Flask view return value (200 if not stated)."""
        if isinstance(result, tuple) and len(result) >= 2 and isinstance(result[1], int):
            return result[1]
        status = getattr(result, 'status_code', None)
        return status if isinstance(status, int) else 200

    def audit_trail(self, action):
        """Decorator factory: audit every call of the wrapped view as ``action``.

        The logged result is 'success' for a returned status below 400,
        'failure: HTTP <status>' for a returned 4xx/5xx, and 'error: <message>'
        when the view raises (the exception is re-raised).
        """
        from flask import request

        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                resource = request.path
                try:
                    result = f(*args, **kwargs)
                except Exception as e:
                    user_id = getattr(request, 'user_id', 'anonymous')
                    self.log_access(user_id, action, resource, f'error: {str(e)}')
                    raise
                status = self._status_of(result)
                outcome = 'success' if status < 400 else f'failure: HTTP {status}'
                # Read user_id after the call so code running inside this
                # decorator (e.g. an authentication decorator) can set it.
                user_id = getattr(request, 'user_id', 'anonymous')
                self.log_access(user_id, action, resource, outcome)
                return result
            return wrapper
        return decorator
# Usage example
audit_logger = AuditLogger()


# NOTE: in this order login_required runs before audit_trail, so requests it
# rejects with 401 are not recorded in the audit log.
@app.route('/api/sensitive-data')
@auth.login_required
@audit_logger.audit_trail('access_sensitive_data')
def sensitive_data():
    """Placeholder for an endpoint that serves sensitive data."""
    # Handle the sensitive data here. Returning None (a bare `pass`) makes
    # Flask raise an error, so answer 501 until this is implemented.
    return {'error': 'Not implemented'}, 501


## 7. 安全测试
python
# security_tests.py
import unittest
from app import app
class SecurityTests(unittest.TestCase):
    """Black-box security checks run through the Flask test client."""

    def setUp(self):
        self.app = app.test_client()

    def get_user_token(self, user_id='security-test-user', role='user'):
        """Return a JWT for a test user holding ``role`` (a non-admin role by default).

        NOTE(review): assumes the ``app`` module exposes the ``auth``
        (AuthManager) and ``rbac`` (RBAC) instances from the examples; the
        calling test is skipped when it does not.
        """
        try:
            from app import auth, rbac
        except ImportError:
            self.skipTest('app module does not expose auth and rbac')
        rbac.assign_role(user_id, role)
        return auth.generate_token(user_id)

    def test_sql_injection(self):
        """SQL injection payloads must be handled or rejected, never executed."""
        malicious_input = "'; DROP TABLE users; --"
        response = self.app.post('/api/query', json={
            'question': malicious_input
        })
        # Processed normally or rejected as invalid; anything else (e.g. 500) fails.
        self.assertIn(response.status_code, [200, 400])

    def test_xss_protection(self):
        """Script tags in the input must not be reflected unescaped."""
        xss_payload = "<script>alert('xss')</script>"
        response = self.app.post('/api/query', json={
            'question': xss_payload
        })
        self.assertNotIn('<script>', response.data.decode())

    def test_rate_limiting(self):
        """A client exceeding the request limit is throttled with HTTP 429."""
        # Use a dedicated client address: the limiter's counters are shared
        # in memory, and exhausting 127.0.0.1's bucket here would make the
        # other tests that call /api/query fail with 429.
        environ = {'REMOTE_ADDR': '203.0.113.10'}
        statuses = [
            self.app.get('/api/query', environ_base=environ).status_code
            for _ in range(110)  # more than the 100-request limit
        ]
        self.assertIn(429, statuses)
        self.assertEqual(statuses[-1], 429)

    def test_authentication_required(self):
        """Protected endpoints reject requests without a token."""
        response = self.app.get('/api/protected')
        self.assertEqual(response.status_code, 401)

    def test_permission_control(self):
        """A regular user's token must not open the admin-only endpoint."""
        user_token = self.get_user_token()
        response = self.app.get(
            '/api/admin-only',
            headers={'Authorization': f'Bearer {user_token}'}
        )
        self.assertEqual(response.status_code, 403)


if __name__ == '__main__':
    unittest.main()