Skip to content

安全最佳实践

安全性是RAG系统部署和维护的重要考虑因素。通过实施安全最佳实践,可以保护系统免受各种安全威胁,确保数据和用户信息的安全。本章节将详细介绍RAG系统的安全策略、措施和最佳实践。

1. 安全威胁分析

常见安全威胁

  • 数据泄露:敏感数据被未授权访问或泄露
  • API滥用:API被恶意使用或滥用
  • 注入攻击:SQL注入、命令注入,以及针对大模型的提示注入(Prompt Injection)等
  • 跨站脚本:XSS攻击
  • 跨站请求伪造:CSRF攻击
  • 拒绝服务:DoS/DDoS攻击
  • 权限提升:未授权用户获取更高权限
  • 供应链攻击:依赖库或组件被攻击

威胁评估

  • 风险识别:识别系统面临的安全风险
  • 风险分析:分析风险的可能性和影响
  • 风险评估:评估风险的严重程度
  • 风险缓解:制定风险缓解策略

2. 数据安全

数据加密

  • 传输加密:使用HTTPS加密传输数据
  • 存储加密:加密存储敏感数据
  • 端到端加密:数据在发送端加密、仅在接收端解密,中间经过的网关、代理等节点均无法读取明文

数据脱敏

  • 敏感信息脱敏:对手机号、身份证号、银行卡号等敏感字段进行掩码或替换后,再存储、展示或送入模型
  • 数据访问控制:限制对敏感数据的访问
  • 数据生命周期管理:管理数据的创建、使用和销毁

数据备份与恢复

  • 定期备份:按固定周期备份数据,包括原始文档、向量索引和系统配置
  • 异地备份:在不同地点备份数据
  • 备份验证:验证备份的有效性
  • 恢复测试:定期测试数据恢复流程

3. 访问控制

身份认证

  • 多因素认证:在密码之外增加第二种认证因素(如TOTP动态验证码),提高安全性。下方示例仅演示其基础部分——使用bcrypt存储密码哈希、登录成功后签发有时效的JWT令牌,第二因素需另行接入
python
# auth.py
import os
import secrets
from datetime import datetime, timedelta, timezone
from functools import wraps

import bcrypt
import jwt
from flask import request, jsonify, current_app

class AuthManager:
    """Password hashing plus HS256 JWT issuing/verification for Flask views.

    Args:
        secret_key: Key used to sign and verify tokens. Anyone who knows it
            can mint valid tokens, so load it from configuration.
    """

    def __init__(self, secret_key):
        self.secret_key = secret_key
        # username -> user record. In-memory dict for demonstration only;
        # a real deployment should use a database.
        self.users = {}

    def hash_password(self, password):
        """Return a bcrypt hash (bytes) of *password* using a fresh salt."""
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt())

    def verify_password(self, password, hashed):
        """Return True if *password* matches the stored bcrypt hash *hashed*.

        *hashed* may be ``bytes`` (as returned by :meth:`hash_password`) or a
        ``str`` (e.g. after a round trip through a text column); bcrypt
        itself only accepts bytes.
        """
        if isinstance(hashed, str):
            hashed = hashed.encode()
        return bcrypt.checkpw(password.encode(), hashed)

    def generate_token(self, user_id, expires_in=3600):
        """Return a signed HS256 JWT for *user_id*, valid for *expires_in* seconds."""
        # Timezone-aware "now": datetime.utcnow() is deprecated since
        # Python 3.12. PyJWT accepts aware datetimes for exp/iat.
        now = datetime.now(timezone.utc)
        payload = {
            'user_id': user_id,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def verify_token(self, token):
        """Decode *token* and return its payload, or None if it is expired or invalid."""
        try:
            # Pin the accepted algorithms so a token cannot choose its own.
            return jwt.decode(token, self.secret_key, algorithms=['HS256'])
        except jwt.InvalidTokenError:
            # ExpiredSignatureError subclasses InvalidTokenError, so expired
            # and malformed tokens both land here.
            return None

    def login_required(self, f):
        """Decorator that rejects requests lacking a valid ``Bearer`` token (401).

        On success the token's ``user_id`` claim is stored on
        ``request.user_id``. The wrapped view, and decorators applied below
        this one, can read it there.
        """
        @wraps(f)
        def decorated_function(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            if auth_header is None:
                return jsonify({'error': 'Token is missing'}), 401

            # Expect exactly "<scheme> <token>" with scheme "Bearer"
            # (case-insensitive). split() tolerates repeated whitespace.
            parts = auth_header.split()
            if len(parts) != 2 or parts[0].lower() != 'bearer':
                return jsonify({'error': 'Invalid token format'}), 401

            payload = self.verify_token(parts[1])
            # A token lacking the user_id claim is unusable. Reject it instead
            # of failing with a KeyError (HTTP 500) below.
            if not payload or 'user_id' not in payload:
                return jsonify({'error': 'Token is invalid or expired'}), 401

            request.user_id = payload['user_id']
            return f(*args, **kwargs)
        return decorated_function

# Usage example
# Read the signing key from the environment instead of hard-coding it: anyone
# who knows the key can mint valid tokens. The random fallback keeps local
# demos working. It changes on every restart and differs between worker
# processes, so set JWT_SECRET_KEY in any real deployment.
auth = AuthManager(
    secret_key=os.environ.get('JWT_SECRET_KEY') or secrets.token_urlsafe(32)
)


@app.route('/api/login', methods=['POST'])
def login():
    """Exchange a JSON ``{"username": ..., "password": ...}`` body for a JWT."""
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising. Anything that is not a JSON object is treated as empty.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')

    # verify_password calls password.encode(). Without this guard, a missing
    # or non-string value would crash the view with a 500.
    if isinstance(username, str) and isinstance(password, str):
        user = auth.users.get(username)
        if user and auth.verify_password(password, user['password']):
            token = auth.generate_token(user['id'])
            return jsonify({'token': token})

    # One generic message for every failure, so the response does not reveal
    # whether the username exists.
    return jsonify({'error': 'Invalid credentials'}), 401


@app.route('/api/protected')
@auth.login_required
def protected():
    """Example endpoint that requires a valid Bearer token."""
    return jsonify({'message': 'Access granted', 'user_id': request.user_id})

权限控制

python
# rbac.py
from functools import wraps
from flask import request, jsonify

class RBAC:
    """Minimal role-based access control: each user has one role, each role a permission list."""

    def __init__(self):
        # role name -> permissions granted by that role
        self.roles = {
            'admin': ['read', 'write', 'delete', 'manage'],
            'user': ['read', 'write'],
            'guest': ['read'],
        }
        # user_id -> role name
        self.user_roles = {}

    def assign_role(self, user_id, role):
        """Give *user_id* the role *role*, replacing any previous role.

        Raises:
            ValueError: if *role* is not a key of ``self.roles``.
        """
        if role not in self.roles:
            raise ValueError(f"Invalid role: {role}")
        self.user_roles[user_id] = role

    def check_permission(self, user_id, permission):
        """Return True if the role assigned to *user_id* grants *permission*."""
        role = self.user_roles.get(user_id)
        if role is None:
            return False
        return permission in self.roles.get(role, ())

    def require_permission(self, permission):
        """Decorator factory: 401 without ``request.user_id``, 403 without *permission*.

        Apply it below ``login_required``, so that ``request.user_id`` is
        already set when this check runs.
        """
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                user_id = getattr(request, 'user_id', None)
                # Compare with None: a falsy but valid id such as 0 must not
                # be treated as "not authenticated".
                if user_id is None:
                    return jsonify({'error': 'Authentication required'}), 401

                if not self.check_permission(user_id, permission):
                    return jsonify({'error': 'Permission denied'}), 403

                return f(*args, **kwargs)
            return wrapper
        return decorator

# Usage example
rbac = RBAC()


@app.route('/api/admin-only')
@auth.login_required
@rbac.require_permission('manage')
def admin_only():
    """Reachable only by authenticated users whose role grants 'manage'."""
    body = {'message': 'Admin access granted'}
    return jsonify(body)

4. API安全

速率限制

python
# rate_limiter.py
import time
from functools import wraps
from flask import request, jsonify

class RateLimiter:
    """In-memory sliding-window rate limiter for Flask views.

    A key (the client IP by default) may make at most ``max_requests``
    requests within any ``window``-second span. State lives in this object
    only, so each worker process enforces its own separate limit. Use a
    shared store (e.g. Redis) when a global limit is required.

    Args:
        max_requests: Requests allowed per key within one window.
        window: Window length in seconds.
    """

    def __init__(self, max_requests=100, window=3600):
        self.max_requests = max_requests
        self.window = window
        # key -> deque of accepted-request timestamps (time.monotonic()),
        # oldest first. NOTE(review): keys are never evicted, so memory grows
        # with the number of distinct keys seen.
        self.requests = {}

    def _prune(self, key, now):
        """Drop *key*'s timestamps that have left the window and return its deque."""
        from collections import deque

        timestamps = self.requests.get(key)
        if timestamps is None:
            timestamps = self.requests[key] = deque()
        # Timestamps are appended in non-decreasing order, so expired ones
        # are always at the left end.
        while timestamps and now - timestamps[0] >= self.window:
            timestamps.popleft()
        return timestamps

    def is_allowed(self, key):
        """Record a request for *key* and return True, or return False if over quota.

        Rejected requests are not recorded.
        """
        # monotonic() is immune to wall-clock jumps (NTP sync, manual changes),
        # which would otherwise shrink or stretch the window arbitrarily.
        now = time.monotonic()
        timestamps = self._prune(key, now)
        if len(timestamps) >= self.max_requests:
            return False
        timestamps.append(now)
        return True

    def retry_after(self, key):
        """Return whole seconds until *key* may make another request (0 if allowed now)."""
        import math

        now = time.monotonic()
        timestamps = self._prune(key, now)
        excess = len(timestamps) - self.max_requests
        if excess < 0:
            return 0
        if excess >= len(timestamps):
            # max_requests <= 0: no request is ever allowed.
            return max(1, math.ceil(self.window))
        # A slot frees once the entry at index `excess` leaves the window.
        wait = timestamps[excess] + self.window - now
        return max(1, math.ceil(wait))

    def limit(self, key_func=None):
        """Decorator factory that answers HTTP 429 once a key exceeds its quota.

        Args:
            key_func: Zero-argument callable returning the rate-limit key for
                the current request. Defaults to ``request.remote_addr``.
        """
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                key = key_func() if key_func else request.remote_addr

                if not self.is_allowed(key):
                    retry = self.retry_after(key)
                    # Also send the standard Retry-After header so generic
                    # HTTP clients back off for the right amount of time.
                    return jsonify({
                        'error': 'Rate limit exceeded',
                        'retry_after': retry
                    }), 429, {'Retry-After': str(retry)}

                return f(*args, **kwargs)
            return wrapper
        return decorator

# Usage example
limiter = RateLimiter(max_requests=100, window=3600)  # 100 requests per hour


# Key on the client address (the default) rather than a raw X-API-Key header:
# an unverified header is attacker-controlled, and sending a different value
# on every request would create a fresh bucket each time and bypass the limit.
# Key on an API key / user id only after it has been authenticated.
# NOTE(review): behind a reverse proxy remote_addr is the proxy's address —
# configure werkzeug's ProxyFix so it reflects the real client.
@app.route('/api/query')
@limiter.limit()
def query():
    # Placeholder: the real handler must return a response. A Flask view
    # that returns None makes Flask raise an error.
    return jsonify({'message': 'query endpoint placeholder'})

输入验证

python
# input_validation.py
import re
from flask import request, jsonify
from marshmallow import Schema, fields, validate, ValidationError

class QuerySchema(Schema):
    """Marshmallow schema for the JSON body of a query request.

    Fields:
        question: required string, 1 to 1000 characters.
        context: optional string up to 5000 characters, defaults to ''.
        max_results: optional integer in the range 1..10, defaults to 5.
    """

    # The user's question (mandatory).
    question = fields.String(
        metadata={'description': '查询问题'},
        required=True,
        validate=validate.Length(min=1, max=1000),
    )
    # Extra context supplied by the caller (optional).
    context = fields.String(
        metadata={'description': '上下文信息'},
        load_default='',
        validate=validate.Length(max=5000),
    )
    # Upper bound on the number of results to return (optional).
    max_results = fields.Integer(
        metadata={'description': '最大结果数'},
        load_default=5,
        validate=validate.Range(min=1, max=10),
    )

class InputValidator:
    """Validation and sanitisation helpers for query requests."""

    @staticmethod
    def sanitize_input(text):
        """Remove the characters ``< > & " '`` from *text* and strip surrounding whitespace.

        NOTE(review): deleting characters from input is a blunt XSS
        mitigation that also alters legitimate text (e.g. "A & B",
        "what's"). The robust defence is escaping at output time
        (``html.escape`` or template auto-escaping).
        """
        text = re.sub(r'[<>&\"\']', '', text)
        return text.strip()

    @staticmethod
    def validate_query(data):
        """Validate *data* against QuerySchema and sanitise its text fields.

        Returns:
            ``(result, None)`` on success, or ``(None, errors)``. In the
            failure case *errors* maps field names to lists of messages
            (marshmallow's ``ValidationError.messages`` shape).
        """
        schema = QuerySchema()
        try:
            result = schema.load(data)
        except ValidationError as err:
            return None, err.messages

        result['question'] = InputValidator.sanitize_input(result['question'])
        result['context'] = InputValidator.sanitize_input(result['context'])

        # The schema's length check ran on the raw text. A question made only
        # of whitespace and/or removed characters is empty now, and must be
        # rejected like a missing one.
        if not result['question']:
            return None, {'question': ['Question is empty after sanitization.']}
        return result, None

# Usage example
@app.route('/api/query', methods=['POST'])
def query():
    """Validate the request body, then forward the question to the RAG system."""
    # silent=True returns None for a missing or non-JSON body instead of
    # raising. The schema then reports it as an ordinary validation error (400).
    data = request.get_json(silent=True)
    validated_data, errors = InputValidator.validate_query(data)

    if errors:
        return jsonify({'error': 'Invalid input', 'details': errors}), 400

    # NOTE(review): the validated 'context' and 'max_results' are not passed
    # on yet. Forward them once rag_system.query's signature is confirmed.
    result = rag_system.query(validated_data['question'])
    return jsonify(result)

CORS配置

python
# cors_config.py
from flask_cors import CORS

# Basic CORS policy applied to every /api/* route.
_API_CORS_OPTIONS = {
    # Only these front-end origins may call the API from a browser.
    "origins": ["https://yourdomain.com", "https://app.yourdomain.com"],
    "methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["Content-Type", "Authorization"],
    # Allow credentialed requests; browsers reject a wildcard origin for
    # these, which is why the origins above are listed explicitly.
    "supports_credentials": True,
    # Let browsers cache the preflight response for an hour.
    "max_age": 3600,
}
cors = CORS(resources={r"/api/*": _API_CORS_OPTIONS})

# Bind the extension to the Flask application.
cors.init_app(app)

5. 模型安全

提示注入防护

python
# prompt_security.py
import re

class PromptSecurity:
    """Heuristic prompt-injection screening and prompt clean-up.

    Matching known phrases is only a first line of defence. Paraphrased or
    non-English injections are not covered by the English patterns below,
    so keep model-side safeguards in place as well.
    """

    # C0 control characters except tab (\x09), newline (\x0a) and carriage
    # return (\x0d). Shared by check_prompt and sanitize_prompt.
    _CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b-\x0c\x0e-\x1f]')

    def __init__(self):
        # Regexes matched against lower-cased text, each describing a phrase
        # typically used to override the system's instructions.
        self.forbidden_patterns = [
            r'ignore\s+previous\s+instructions',
            r'disregard\s+all\s+prior',
            r'system\s+prompt',
            r'you\s+are\s+now',
            r'forget\s+everything',
            r'act\s+as\s+if',
        ]

    def check_prompt(self, prompt):
        """Check *prompt* against ``forbidden_patterns``.

        Two views of the text are searched:

        - the lower-cased raw text;
        - a normalized copy that has control characters removed (as
          sanitize_prompt does), compatibility forms such as full-width
          letters folded (NFKC), and invisible format characters such as
          zero-width spaces dropped.

        These obfuscations therefore cannot hide a forbidden phrase.

        Returns:
            ``(True, None)`` if no pattern matches, otherwise
            ``(False, message)`` naming the first matching pattern.
        """
        import unicodedata

        raw_view = prompt.lower()
        folded = unicodedata.normalize('NFKC', self._CONTROL_CHARS.sub('', prompt))
        normalized_view = ''.join(
            ch for ch in folded if unicodedata.category(ch) != 'Cf'
        ).lower()

        for pattern in self.forbidden_patterns:
            if re.search(pattern, raw_view) or re.search(pattern, normalized_view):
                return False, f"Detected potential injection: {pattern}"

        return True, None

    def sanitize_prompt(self, prompt, max_length=4000):
        """Strip control characters from *prompt* and truncate it.

        Args:
            prompt: Text to clean.
            max_length: Maximum number of characters kept (default 4000).
                ``None`` disables truncation.
        """
        prompt = self._CONTROL_CHARS.sub('', prompt)
        # Slicing is a no-op for shorter input.
        return prompt[:max_length]

# Usage example
security = PromptSecurity()


@app.route('/api/query', methods=['POST'])
def query():
    """Screen the question for prompt injection before querying the RAG system."""
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    question = data.get('question', '')
    if not isinstance(question, str):
        return jsonify({'error': 'Invalid input', 'message': 'question must be a string'}), 400

    # Sanitize first, then check the text that will actually reach the model.
    # Checking before control characters are removed would let input such as
    # "ignore \x00previous instructions" pass the check, then collapse into
    # the forbidden phrase during sanitizing.
    clean_question = security.sanitize_prompt(question)

    is_safe, error = security.check_prompt(clean_question)
    if not is_safe:
        return jsonify({'error': 'Invalid input', 'message': error}), 400

    result = rag_system.query(clean_question)
    return jsonify(result)

输出过滤

python
# output_filter.py
import re

class OutputFilter:
    """Redacts sensitive data from model output and screens it for harmful keywords."""

    def __init__(self):
        # Digit runs are delimited with (?<!\d) / (?!\d) instead of \b. In
        # Python 3 str patterns CJK characters count as word characters, so
        # \b never fires between e.g. "手机" and a number that follows it.
        self.sensitive_patterns = [
            # Credit card: 16 digits, optionally grouped 4-4-4-4 with one
            # consistent space or hyphen separator.
            r'(?<!\d)\d{4}([ -]?)\d{4}\1\d{4}\1\d{4}(?!\d)',
            r'(?<!\d)\d{3}-\d{2}-\d{4}(?!\d)',  # US SSN
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',  # email address
            r'(?<!\d)\d{11}(?!\d)',  # 11-digit mobile phone number
        ]
        # Keywords that mark an answer as harmful. A placeholder for a real
        # content-moderation API; can be replaced per instance.
        self.harmful_keywords = ['暴力', '色情', '歧视', '仇恨']

    def filter_output(self, text):
        """Return *text* with every match of ``sensitive_patterns`` replaced by '[REDACTED]'."""
        filtered = text
        for pattern in self.sensitive_patterns:
            filtered = re.sub(pattern, '[REDACTED]', filtered)
        return filtered

    def check_harmful_content(self, text):
        """Return ``(False, message)`` for the first harmful keyword in *text*, else ``(True, None)``.

        Matching is case-insensitive. It is only a placeholder; a real
        deployment should call a content-moderation API.
        """
        haystack = text.lower()  # computed once, not per keyword
        for keyword in self.harmful_keywords:
            if keyword.lower() in haystack:
                return False, f"Harmful content detected: {keyword}"
        return True, None

# Usage example
output_filter = OutputFilter()


@app.route('/api/query', methods=['POST'])
def query():
    """Answer a question, redacting sensitive data and blocking harmful output."""
    # Input validation and prompt screening are shown in the previous
    # examples. Here the question is only read, so the example is complete.
    data = request.get_json(silent=True)
    question = data.get('question', '') if isinstance(data, dict) else ''
    # NOTE(review): assumes rag_system.query returns a dict with an 'answer' key.
    result = rag_system.query(question)

    # Redact sensitive data from the generated answer.
    filtered_answer = output_filter.filter_output(result['answer'])

    # Block answers that contain harmful content.
    is_safe, error = output_filter.check_harmful_content(filtered_answer)
    if not is_safe:
        return jsonify({'error': 'Content filtered', 'message': error}), 400

    return jsonify({'answer': filtered_answer})

6. 日志与审计

python
# audit_logging.py
import logging
import json
from datetime import datetime
from functools import wraps

class AuditLogger:
    """Writes one JSON line per audited request to an audit log file.

    Args:
        log_path: File the ``audit`` logger appends to (default
            ``logs/audit.log``). Its parent directory is created if needed.
    """

    def __init__(self, log_path='logs/audit.log'):
        import os

        self.logger = logging.getLogger('audit')
        target = os.path.abspath(log_path)
        # getLogger('audit') returns the same logger object every time.
        # Without this check, each new AuditLogger would attach another
        # handler and every entry would be written several times.
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in self.logger.handlers
        )
        if not already_attached:
            # FileHandler raises FileNotFoundError if the directory is missing.
            os.makedirs(os.path.dirname(target), exist_ok=True)
            # Explicit UTF-8: entries are dumped with ensure_ascii=False, so
            # non-ASCII text could fail to encode under a non-UTF-8 default
            # locale encoding.
            handler = logging.FileHandler(target, encoding='utf-8')
            handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    def log_access(self, user_id, action, resource, result):
        """Log one audit entry as JSON, including the client's IP and User-Agent.

        Must be called while a Flask request context is active.
        """
        from datetime import timezone

        # audit_logging.py does not import request at module level.
        from flask import request

        log_entry = {
            # Timezone-aware UTC (datetime.utcnow() is deprecated since 3.12).
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'user_id': user_id,
            'action': action,
            'resource': resource,
            'result': result,
            'ip_address': request.remote_addr,
            'user_agent': request.user_agent.string
        }
        self.logger.info(json.dumps(log_entry, ensure_ascii=False))

    def audit_trail(self, action):
        """Decorator factory that logs each call of the wrapped view as *action*.

        Logs 'success' when the view returns. Logs 'error: <exception>' and
        re-raises when it raises. A returned error response (e.g. a 4xx
        tuple) still counts as 'success' here.
        """
        from flask import request

        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                user_id = getattr(request, 'user_id', 'anonymous')
                resource = request.path

                try:
                    result = f(*args, **kwargs)
                except Exception as e:
                    self.log_access(user_id, action, resource, f'error: {e}')
                    raise
                # Logged outside the try, so a failure while writing the log
                # is not misreported as an error of the view itself.
                self.log_access(user_id, action, resource, 'success')
                return result
            return wrapper
        return decorator

# Usage example
audit_logger = AuditLogger()


# Decorator order matters. login_required is the outer decorator: it runs
# first and sets request.user_id, so audit_trail records the real user. The
# flip side is that requests rejected by login_required are never audited;
# add separate logging for failed authentication if those must be recorded.
@app.route('/api/sensitive-data')
@auth.login_required
@audit_logger.audit_trail('access_sensitive_data')
def sensitive_data():
    """Placeholder for a view that serves sensitive data."""
    from flask import jsonify  # not imported at the top of audit_logging.py

    # A Flask view must return a response. Returning None would make Flask
    # raise an error after audit_trail had already logged 'success'.
    return jsonify({'message': 'sensitive data placeholder'})

7. 安全测试

python
# security_tests.py
import unittest
from app import app

class SecurityTests(unittest.TestCase):
    """Black-box security checks run through the Flask test client."""

    # Credentials of an existing account WITHOUT the 'manage' permission,
    # used by test_permission_control.
    # NOTE(review): placeholder values — seed this user in the application
    # under test (or change them) before relying on that test.
    TEST_USER = {'username': 'test_user', 'password': 'test_password'}

    def setUp(self):
        self.app = app.test_client()

    def get_user_token(self):
        """Log in as TEST_USER through /api/login and return the issued JWT.

        Skips the calling test, instead of erroring, when no token comes back
        (for example because the test user has not been seeded).
        """
        response = self.app.post('/api/login', json=self.TEST_USER)
        body = response.get_json(silent=True)
        token = body.get('token') if isinstance(body, dict) else None
        if response.status_code != 200 or not token:
            self.skipTest('could not obtain a token for TEST_USER')
        return token

    def test_sql_injection(self):
        """A classic SQL-injection string must be treated as plain input."""
        malicious_input = "'; DROP TABLE users; --"
        response = self.app.post('/api/query', json={
            'question': malicious_input
        })
        # Either processed normally or rejected as invalid — never executed.
        self.assertIn(response.status_code, [200, 400])

    def test_xss_protection(self):
        """A script payload must not be echoed back unescaped."""
        xss_payload = "<script>alert('xss')</script>"
        response = self.app.post('/api/query', json={
            'question': xss_payload
        })
        self.assertNotIn('<script>', response.data.decode())

    def test_rate_limiting(self):
        """Requests beyond the limit must receive 429.

        Assumes the 100-requests-per-window limit of the rate-limiting example.
        """
        response = None
        for _ in range(110):  # more than the limit
            response = self.app.get('/api/query')

        self.assertEqual(response.status_code, 429)

    def test_authentication_required(self):
        """A protected endpoint called without a token must return 401."""
        response = self.app.get('/api/protected')
        self.assertEqual(response.status_code, 401)

    def test_permission_control(self):
        """A regular user's token must be refused (403) by an admin-only endpoint."""
        user_token = self.get_user_token()
        response = self.app.get(
            '/api/admin-only',
            headers={'Authorization': f'Bearer {user_token}'}
        )
        self.assertEqual(response.status_code, 403)


if __name__ == '__main__':
    unittest.main()