
Performance Optimization

Performance Optimization Overview

Performance Metrics

Metric                 Description                        Target
Latency                Response time of a single query    < 100 ms
Throughput             Queries per second (QPS)           Maximize
Recall                 Accuracy of the search results     > 95%
Resource utilization   CPU, memory, and disk usage        Balanced allocation

Optimization Layers

Performance optimization
├── Data layer
│   ├── Data modeling
│   ├── Partitioning strategy
│   └── Data distribution
├── Index layer
│   ├── Index type selection
│   ├── Index parameter tuning
│   └── Index maintenance
├── Query layer
│   ├── Search parameter tuning
│   ├── Batch queries
│   └── Filter optimization
└── System layer
    ├── Resource configuration
    ├── Concurrency control
    └── Caching strategy

Data-Layer Optimization

Sensible Schema Design

python
from pymilvus import FieldSchema, CollectionSchema, DataType

# Before: too many fields, redundant data
fields_unoptimized = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=4096),  # oversized
    FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),  # oversized
    FieldSchema(name="metadata", dtype=DataType.JSON),  # stores large blobs
]

# After: lean fields with appropriate lengths
fields_optimized = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),  # reasonable length
    FieldSchema(name="summary", dtype=DataType.VARCHAR, max_length=1024),  # summary instead of full text
    FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=64),
    FieldSchema(name="tags", dtype=DataType.ARRAY, element_type=DataType.VARCHAR,
                max_length=32, max_capacity=10),
]
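
To put the lean field list to work, wrap it in a CollectionSchema as usual. A minimal sketch; the collection name and description here are illustrative:

python
from pymilvus import Collection

schema = CollectionSchema(fields_optimized, description="lean document schema")  # illustrative
collection = Collection("documents_optimized", schema)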

Vector Dimension Optimization

python
# Pick a dimension that matches the embedding model
DIMENSION_RECOMMENDATIONS = {
    "text_embedding": {
        "openai": 1536,      # OpenAI embeddings
        "bge": 1024,         # BGE embeddings
        "m3e": 768,          # M3E embeddings
    },
    "image_embedding": {
        "resnet50": 2048,
        "clip": 512,
        "efficientnet": 1280,
    },
    "audio_embedding": {
        "wav2vec": 768,
        "whisper": 512,
    }
}

# Guidelines for choosing a dimension
# - Higher dimensions are more expressive but cost more to compute and store
# - Lower dimensions are faster but may lose precision
# - 128-768 dimensions work well for most applications
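
When an embedding model produces more dimensions than you need, one common way to trade a little precision for speed is to project the vectors down before insertion. A minimal PCA sketch using only NumPy; the target dimension and the `embeddings` input are assumptions:

python
import numpy as np

def reduce_dimension(vectors, target_dim=256):
    """Project embeddings to a lower dimension with PCA (via SVD)."""
    X = np.asarray(vectors, dtype=np.float32)
    X = X - X.mean(axis=0)  # center the data
    # Rows of vt are the principal directions, strongest first
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    return X @ vt[:target_dim].T  # keep the top target_dim components

# e.g. reduced = reduce_dimension(embeddings, target_dim=256)  # 'embeddings' is assumed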

Partition Strategy Optimization

python
def optimize_partition_strategy(collection, partition_key):
    """Analyze the data distribution and suggest a partition count."""

    # 1. Sample the stored values of the partition key.
    #    Milvus caps the number of results a single query can return,
    #    so sample instead of scanning every entity.
    stats = collection.query(
        expr="",
        output_fields=[partition_key],
        limit=min(collection.num_entities, 16384)
    )

    # 2. Count how often each value occurs
    from collections import Counter
    distribution = Counter(s[partition_key] for s in stats)

    # 3. Choose a scheme based on the distribution:
    #    - time-based keys: partition by time range
    #    - categorical keys: merge small categories
    #    - otherwise: hash-partition for a uniform spread

    print(f"Data distribution: {distribution}")

    # 4. Suggested partition count:
    #    aim for roughly 1M-10M records per partition
    recommended_partitions = max(1, collection.num_entities // 5_000_000)
    print(f"Suggested number of partitions: {recommended_partitions}")

Index-Layer Optimization

Index Type Selection Decision Tree

python
def select_index_type(data_size, recall_requirement, memory_budget):
    """
    Choose an index type for the scenario at hand.
    
    Args:
        data_size: number of vectors, in millions
        recall_requirement: required recall (0-1)
        memory_budget: memory budget in GB
    
    Returns:
        Recommended index type and parameters
    """
    
    if data_size < 1:
        return {
            "index_type": "FLAT",
            "params": {},
            "reason": "Small dataset; exact search is affordable"
        }
    
    if recall_requirement > 0.99:
        return {
            "index_type": "IVF_FLAT",
            "params": {"nlist": min(4 * int(data_size ** 0.5), 65536)},
            "reason": "Very high recall requirement; use an unquantized index"
        }
    
    if memory_budget < data_size * 0.5:
        return {
            "index_type": "IVF_SQ8",
            "params": {"nlist": min(4 * int(data_size ** 0.5), 65536)},
            "reason": "Memory-constrained; use a quantized index"
        }
    
    if data_size > 100:
        return {
            "index_type": "IVF_PQ",
            "params": {
                "nlist": min(4 * int(data_size ** 0.5), 65536),
                "m": 16,
                "nbits": 8
            },
            "reason": "Very large dataset; use product quantization"
        }
    
    # Default recommendation: HNSW
    return {
        "index_type": "HNSW",
        "params": {
            "M": 16,
            "efConstruction": 200
        },
        "reason": "Balances speed and accuracy"
    }
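
A quick example of wiring the decision function to index creation; it assumes an existing collection object, and the workload numbers are made up (this combination falls through to the HNSW default):

python
choice = select_index_type(data_size=50, recall_requirement=0.95, memory_budget=64)
print(choice["index_type"], "-", choice["reason"])  # HNSW for these inputs

collection.create_index(
    field_name="vector",
    index_params={
        "index_type": choice["index_type"],
        "metric_type": "L2",
        "params": choice["params"],
    },
)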

Automatic Index Parameter Tuning

python
def auto_tune_index_params(collection, target_recall=0.95, sample_size=1000):
    """Grid-search IVF parameters against a recall target."""
    
    collection.load()
    
    # 1. Pull a sample of stored vectors to reuse as queries
    sample_data = collection.query(
        expr="",
        output_fields=["vector"],
        limit=sample_size
    )
    
    sample_vectors = [d["vector"] for d in sample_data]
    
    # 2. Try each parameter combination
    best_params = None
    best_score = 0
    
    # IVF parameter grid
    nlist_options = [64, 128, 256, 512]
    nprobe_options = [8, 16, 32, 64]
    
    for nlist in nlist_options:
        # Rebuild the index for this nlist; the collection must be
        # released and the old index dropped before creating a new one
        collection.release()
        collection.drop_index()
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {"nlist": nlist}
        }
        collection.create_index("vector", index_params)
        collection.load()
        
        for nprobe in nprobe_options:
            search_params = {
                "metric_type": "L2",
                "params": {"nprobe": nprobe}
            }
            
            # Measure recall for this combination
            recall = test_recall(collection, sample_vectors, search_params)
            
            # Measure latency
            latency = test_latency(collection, sample_vectors, search_params)
            
            # Weighted score: favor recall, penalize latency
            score = recall * 0.7 + (1 / (latency + 1)) * 0.3
            
            if score > best_score and recall >= target_recall:
                best_score = score
                best_params = {
                    "nlist": nlist,
                    "nprobe": nprobe,
                    "recall": recall,
                    "latency": latency
                }
    
    return best_params

def test_recall(collection, query_vectors, search_params, top_k=10):
    """Measure recall; must return a fraction in [0, 1]."""
    # Compare the ids returned by the approximate search against exact
    # ground truth (e.g. a FLAT index or a brute-force scan) and return
    # the overlap ratio -- see the sketch after this block.
    pass

def test_latency(collection, query_vectors, search_params):
    """Measure average search latency in seconds over 100 queries."""
    import time
    
    start = time.time()
    for vec in query_vectors[:100]:
        collection.search(
            data=[vec],
            anns_field="vector",
            param=search_params,
            limit=10
        )
    end = time.time()
    
    return (end - start) / 100
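
test_recall is left as a stub above. A minimal sketch of the overlap computation it needs, assuming you keep a raw in-memory sample of the vectors to brute-force against; mapping the sample positions back to Milvus ids is up to the caller:

python
import numpy as np

def brute_force_topk(corpus_vectors, query_vector, top_k=10):
    """Exact L2 top-k over an in-memory sample, used as ground truth."""
    dists = np.linalg.norm(np.asarray(corpus_vectors) - np.asarray(query_vector), axis=1)
    return set(np.argsort(dists)[:top_k].tolist())

def recall_at_k(ground_truth_ids, result_ids):
    """Fraction of the true nearest neighbors that the ANN search recovered."""
    return len(ground_truth_ids & set(result_ids)) / len(ground_truth_ids)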

Query-Layer Optimization

Batch Query Optimization

python
import time
from concurrent.futures import ThreadPoolExecutor

def batch_search_optimized(collection, query_vectors, batch_size=100, max_workers=4):
    """Compare single-threaded and multi-threaded batch search."""
    
    all_results = []
    
    # Method 1: single thread, batched requests
    start = time.time()
    for i in range(0, len(query_vectors), batch_size):
        batch = query_vectors[i:i + batch_size]
        results = collection.search(
            data=batch,
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 16}},
            limit=10
        )
        all_results.extend(results)
    single_thread_time = time.time() - start
    
    # Method 2: batches spread across threads
    def search_batch(batch):
        return collection.search(
            data=batch,
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 16}},
            limit=10
        )
    
    start = time.time()
    batches = [query_vectors[i:i + batch_size] 
               for i in range(0, len(query_vectors), batch_size)]
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(search_batch, batches))
    
    multi_thread_time = time.time() - start
    
    print(f"Single-threaded: {single_thread_time:.2f}s")
    print(f"Multi-threaded: {multi_thread_time:.2f}s")
    print(f"Speedup: {single_thread_time / multi_thread_time:.2f}x")
    
    return results
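
Thread-level parallelism pays off here because each search call spends most of its time waiting on gRPC I/O, during which the interpreter releases the GIL. A quick way to exercise the comparison; the vector count and dimension are illustrative and must match your collection:

python
import random

query_vectors = [[random.random() for _ in range(768)] for _ in range(1000)]
results = batch_search_optimized(collection, query_vectors, batch_size=100, max_workers=4)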

Dynamic Search Parameter Adjustment

python
import time

class AdaptiveSearchParams:
    """Adaptively adjust search parameters based on observed latency."""
    
    def __init__(self, collection):
        self.collection = collection
        self.query_history = []
        self.latency_threshold = 100  # ms
    
    def get_optimal_params(self, query_vector, required_recall=None):
        """Return search parameters tuned to recent query latency."""
        
        base_params = {
            "metric_type": "L2",
            "params": {"nprobe": 16}
        }
        
        # Adjust based on recent query history
        if self.query_history:
            recent = self.query_history[-100:]
            avg_latency = sum(h["latency"] for h in recent) / len(recent)
            
            if avg_latency > self.latency_threshold:
                # Too slow: trade accuracy for speed
                base_params["params"]["nprobe"] = max(8, base_params["params"]["nprobe"] // 2)
            else:
                # Fast enough: spend the headroom on accuracy
                base_params["params"]["nprobe"] = min(128, base_params["params"]["nprobe"] * 2)
        
        return base_params
    
    def record_query(self, params, latency, recall):
        """Record the outcome of one query."""
        self.query_history.append({
            "params": params,
            "latency": latency,
            "recall": recall,
            "timestamp": time.time()
        })
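
A sketch of the feedback loop in use; `query_vector` and `collection` are assumed to exist:

python
adaptive = AdaptiveSearchParams(collection)

params = adaptive.get_optimal_params(query_vector)
start = time.time()
results = collection.search(data=[query_vector], anns_field="vector", param=params, limit=10)
adaptive.record_query(params, latency=(time.time() - start) * 1000, recall=None)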

Pre-Filtering Optimization

python
def optimized_filtered_search(collection, query_vector, filter_expr, top_k=10):
    """Pick a filtering strategy based on how selective the filter is."""
    
    # Strategy 1: a highly selective filter -> filter during the ANN search
    # Strategy 2: a loose filter -> search first, then filter the candidates
    
    # Estimate how many entities pass the filter
    # (Milvus caps query result sizes, so treat this as a sample)
    estimated_filtered = collection.query(
        expr=filter_expr,
        output_fields=["id"],
        limit=16384  # estimation cap
    )
    
    filter_ratio = len(estimated_filtered) / collection.num_entities
    
    if filter_ratio < 0.1:
        # Strong filter: let the engine apply it during the search
        results = collection.search(
            data=[query_vector],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 32}},
            limit=top_k,
            expr=filter_expr
        )
    else:
        # Weak filter: search first, then filter
        results = collection.search(
            data=[query_vector],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 16}},
            limit=top_k * 3  # over-fetch candidates
        )
        
        # Apply the filter to the candidates
        filtered_results = []
        for hit in results[0]:
            # Keep the hit only if it actually satisfies the filter
            entity = collection.query(
                expr=f"id == {hit.id} and ({filter_expr})",
                output_fields=["id"],
                limit=1
            )
            if entity:
                filtered_results.append(hit)
            
            if len(filtered_results) >= top_k:
                break
        
        results = [filtered_results]
    
    return results

System-Layer Optimization

Resource Configuration Optimization

yaml
# milvus.yaml tuning
queryNode:
  cacheSize: 8192  # query node cache size (MB)
  
dataNode:
  segment:
    maxSize: 512  # maximum segment size (MB)
    sealProportion: 0.25  # fill proportion at which a segment is sealed

indexNode:
  scheduler:
    buildParallel: 1  # number of index builds to run in parallel
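
If you manage milvus.yaml programmatically, a small helper can patch these values in place. A sketch using PyYAML; the helper name and file path are assumptions:

python
import yaml  # requires PyYAML

def patch_query_cache(path="milvus.yaml", cache_mb=8192):
    """Rewrite queryNode.cacheSize in an existing milvus.yaml."""
    with open(path) as f:
        config = yaml.safe_load(f) or {}
    config.setdefault("queryNode", {})["cacheSize"] = cache_mb
    with open(path, "w") as f:
        yaml.safe_dump(config, f, sort_keys=False)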

Connection Pool Optimization

python
from pymilvus import connections, Collection
import threading

class ConnectionPool:
    """A simple round-robin pool of Milvus connection aliases."""
    
    def __init__(self, host="localhost", port="19530", pool_size=10):
        self.host = host
        self.port = port
        self.pool_size = pool_size
        self.connections = []
        self.lock = threading.Lock()
        self.current = 0
        
        # Open the connections up front
        for i in range(pool_size):
            alias = f"conn_{i}"
            connections.connect(
                alias=alias,
                host=host,
                port=port
            )
            self.connections.append(alias)
    
    def get_connection(self):
        """Return the next connection alias, round-robin."""
        with self.lock:
            alias = self.connections[self.current]
            self.current = (self.current + 1) % self.pool_size
            return alias
    
    def close_all(self):
        """Disconnect every connection in the pool."""
        for alias in self.connections:
            connections.disconnect(alias)

# Usage
pool = ConnectionPool(pool_size=10)
alias = pool.get_connection()
collection = Collection("my_collection", using=alias)

Caching Strategy

python
import hashlib
import time

class SearchCache:
    """TTL cache for search results."""
    
    def __init__(self, maxsize=10000, ttl=3600):
        self.cache = {}
        self.maxsize = maxsize
        self.ttl = ttl
    
    def _make_key(self, query_vector, params):
        """Derive a cache key from the query vector and search parameters."""
        vector_str = ",".join(f"{x:.6f}" for x in query_vector)
        param_str = str(sorted(params.items()))
        return hashlib.md5(f"{vector_str}:{param_str}".encode()).hexdigest()
    
    def get(self, query_vector, params):
        """Return the cached result, or None if missing or expired."""
        key = self._make_key(query_vector, params)
        if key in self.cache:
            result, timestamp = self.cache[key]
            if time.time() - timestamp < self.ttl:
                return result
            else:
                del self.cache[key]
        return None
    
    def set(self, query_vector, params, result):
        """Cache a result, evicting the stalest entry when full."""
        if len(self.cache) >= self.maxsize:
            # Evict the entry with the oldest timestamp
            oldest = min(self.cache.items(), key=lambda x: x[1][1])
            del self.cache[oldest[0]]
        
        key = self._make_key(query_vector, params)
        self.cache[key] = (result, time.time())
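
A sketch of the cache wrapped around a search call; `query_vector` and `collection` are assumed to exist:

python
cache = SearchCache(maxsize=5000, ttl=600)

cache_params = {"metric_type": "L2", "nprobe": 16}
results = cache.get(query_vector, cache_params)
if results is None:
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 16}},
        limit=10,
    )
    cache.set(query_vector, cache_params, results)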

Performance Monitoring

Query Performance Monitoring

python
import time
import statistics

class PerformanceMonitor:
    """Query performance monitor."""
    
    def __init__(self):
        self.query_times = []
        self.recall_rates = []
    
    def record_query(self, latency, recall=None):
        """Record one query's latency (and optionally its recall)."""
        self.query_times.append(latency)
        if recall is not None:
            self.recall_rates.append(recall)
    
    def get_stats(self):
        """Return summary statistics."""
        if not self.query_times:
            return {}
        
        sorted_times = sorted(self.query_times)
        return {
            "query_count": len(self.query_times),
            "avg_latency": statistics.mean(self.query_times),
            "p50_latency": statistics.median(self.query_times),
            "p95_latency": sorted_times[int(len(sorted_times) * 0.95)],
            "p99_latency": sorted_times[int(len(sorted_times) * 0.99)],
            "avg_recall": statistics.mean(self.recall_rates) if self.recall_rates else None
        }
    
    def print_report(self):
        """Print a performance report."""
        stats = self.get_stats()
        
        print("=== Performance Report ===")
        print(f"Queries:     {stats['query_count']}")
        print(f"Avg latency: {stats['avg_latency']:.2f}ms")
        print(f"P50 latency: {stats['p50_latency']:.2f}ms")
        print(f"P95 latency: {stats['p95_latency']:.2f}ms")
        print(f"P99 latency: {stats['p99_latency']:.2f}ms")
        if stats['avg_recall'] is not None:
            print(f"Avg recall:  {stats['avg_recall']:.2%}")

# Usage
monitor = PerformanceMonitor()

# Run a query and record it
start = time.time()
results = collection.search(...)
latency = (time.time() - start) * 1000
monitor.record_query(latency)

# Print the report
monitor.print_report()

Complete Optimization Example

python
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
import time
import random

def performance_optimization_demo():
    """End-to-end performance optimization demo."""
    
    # Connect to Milvus
    connections.connect(host="localhost", port="19530")
    
    # Clean up leftovers from earlier runs
    if utility.has_collection("perf_demo"):
        utility.drop_collection("perf_demo")
    
    # Create a lean collection
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
        FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="timestamp", dtype=DataType.INT64)
    ]
    
    schema = CollectionSchema(fields, "Performance optimization demo")
    collection = Collection("perf_demo", schema)
    
    # Create partitions (use ASCII names so they satisfy Milvus naming rules)
    categories = ["tech", "life", "entertainment", "education", "sports"]
    for category in categories:
        collection.create_partition(f"category_{category}")
    
    # Insert data
    print("Inserting test data...")
    for category in categories:
        data = []
        for i in range(10000):
            data.append({
                "vector": [random.random() for _ in range(128)],
                "category": category,
                "timestamp": int(time.time()) - random.randint(0, 86400 * 30)
            })
        collection.insert(data, partition_name=f"category_{category}")
    
    print(f"Inserted {collection.num_entities} records in total")
    
    # Build an optimized index
    print("\nBuilding the index...")
    index_params = {
        "index_type": "HNSW",
        "metric_type": "L2",
        "params": {
            "M": 16,
            "efConstruction": 200
        }
    }
    collection.create_index("vector", index_params)
    utility.wait_for_index_building_complete("perf_demo")
    
    # Load only the partitions we need
    print("\nLoading selected partitions...")
    collection.load(partition_names=["category_tech", "category_life"])
    
    # Benchmarks
    print("\n=== Benchmarks ===")
    query_vectors = [[random.random() for _ in range(128)] for _ in range(100)]
    
    # Test 1: one query at a time
    start = time.time()
    for vec in query_vectors:
        collection.search(
            data=[vec],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"ef": 64}},
            limit=10
        )
    base_time = time.time() - start
    print(f"Single queries: {base_time:.2f}s ({len(query_vectors)/base_time:.1f} QPS)")
    
    # Test 2: batched queries
    start = time.time()
    for i in range(0, len(query_vectors), 10):
        batch = query_vectors[i:i + 10]
        collection.search(
            data=batch,
            anns_field="vector",
            param={"metric_type": "L2", "params": {"ef": 64}},
            limit=10
        )
    batch_time = time.time() - start
    print(f"Batched queries: {batch_time:.2f}s ({len(query_vectors)/batch_time:.1f} QPS)")
    
    # Test 3: partition-restricted queries
    start = time.time()
    for vec in query_vectors:
        collection.search(
            data=[vec],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"ef": 64}},
            partition_names=["category_tech"],
            limit=10
        )
    partition_time = time.time() - start
    print(f"Partition queries: {partition_time:.2f}s ({len(query_vectors)/partition_time:.1f} QPS)")
    
    # Clean up
    collection.release()
    utility.drop_collection("perf_demo")
    
    print("\nPerformance optimization demo complete!")

if __name__ == "__main__":
    performance_optimization_demo()

Optimization Checklist

Data Modeling

  • [ ] Choose an appropriate vector dimension
  • [ ] Keep the schema lean; avoid redundant fields
  • [ ] Set sensible lengths for string fields
  • [ ] Use ARRAY instead of JSON for storing lists

Index Optimization

  • [ ] Choose the index type based on data scale
  • [ ] Tune the nlist/nprobe parameters
  • [ ] Rebuild indexes periodically
  • [ ] Monitor index build times

Query Optimization

  • [ ] Use batch queries
  • [ ] Set a sensible limit
  • [ ] Use partitions to narrow the search scope
  • [ ] Optimize filter expressions

System Optimization

  • [ ] Configure an appropriate cache size
  • [ ] Use a connection pool
  • [ ] Monitor resource usage
  • [ ] Set sensible timeouts

Next Steps

Once you have performance optimization down, you can:

  1. Learn about monitoring and operations
  2. Study cluster deployment
  3. Explore real-world use cases