Advanced Search Techniques

Multi-Path Retrieval and Fusion

What Is Multi-Path Retrieval?

Multi-path retrieval (multi-way recall) retrieves candidates along several dimensions or routes and then fuses them into a single result list. It is commonly used to improve both the recall and the accuracy of search.

RRF (Reciprocal Rank Fusion)

python
def reciprocal_rank_fusion(results_list, k=60):
    """
    RRF fusion algorithm.

    Args:
        results_list: list of result sets, one per retrieval route
        k: RRF constant, conventionally 60

    Returns:
        fused (doc_id, score) pairs sorted by score
    """
    scores = {}

    for results in results_list:
        for rank, hit in enumerate(results[0]):
            doc_id = hit.id
            # RRF formula: score = sum(1 / (k + rank)) over routes, with
            # 1-based ranks, hence `rank + 1` since enumerate() starts at 0
            scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank + 1)

    # Sort by fused score, descending
    sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    return sorted_results

# Usage example
# Route 1: search the content vector field
vector_results = collection.search(
    data=[query_vector],
    anns_field="vector",
    param=search_params,
    limit=100
)

# Route 2: a search over the title vector field stands in for keyword retrieval here
keyword_results = collection.search(
    data=[query_vector],
    anns_field="title_vector",
    param=search_params,
    limit=100
)

# Fuse the two routes
fused_results = reciprocal_rank_fusion([vector_results, keyword_results])
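
To see how RRF behaves, here is a minimal self-contained check with mock hits. `MockHit` is purely illustrative; it only mimics the `id` attribute the fusion code reads from a Milvus hit:

python
from collections import namedtuple

# Stand-in for a Milvus Hit; only the `id` attribute is used by the fusion code
MockHit = namedtuple("MockHit", ["id"])

route_a = [[MockHit(1), MockHit(2), MockHit(3)]]  # first retrieval route
route_b = [[MockHit(2), MockHit(3), MockHit(1)]]  # second retrieval route

print(reciprocal_rank_fusion([route_a, route_b]))
# [(2, 0.0325...), (1, 0.0322...), (3, 0.0320...)]
# Doc 2 wins: ranked 1st in one route and 2nd in the other.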

Weighted Fusion

python
def weighted_fusion(results_list, weights):
    """Weighted fusion of results from multiple retrieval routes."""
    scores = {}

    for results, weight in zip(results_list, weights):
        for rank, hit in enumerate(results[0]):
            doc_id = hit.id
            # Scale the reciprocal-rank score by the route's weight
            score = weight * (1 / (rank + 1))
            scores[doc_id] = scores.get(doc_id, 0) + score
    
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)

# Usage example (semantic_results would be a third route, produced the same way)
fused = weighted_fusion(
    [vector_results, keyword_results, semantic_results],
    weights=[0.5, 0.3, 0.2]
)
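
The reciprocal-rank scores above are scale-free, but if you fuse raw scores from routes with different metrics instead, normalize them to a common range first. A minimal min-max sketch; the `normalize_scores` helper is ours, not a pymilvus API:

python
def normalize_scores(scores):
    """Min-max normalize a list of raw scores into [0, 1]."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)  # all scores equal; treat them as ties
    return [(s - lo) / (hi - lo) for s in scores]

# Example: L2 distances (lower is better) must also be inverted
# before they can be weighted against similarity scores
l2_distances = [0.8, 1.2, 2.5]
l2_similarities = [1 - s for s in normalize_scores(l2_distances)]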

Hybrid Search

Vector + Keyword Hybrid Search

python
def hybrid_search(collection, query_vector, keyword, top_k=10):
    """Vector search followed by keyword filtering."""

    # Vector search first
    vector_results = collection.search(
        data=[query_vector],
        anns_field="content_vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=top_k * 3,  # over-fetch candidates to survive filtering
        output_fields=["title", "content"]
    )

    # Filter by keyword
    filtered_results = []
    for hit in vector_results[0]:
        content = hit.entity.get('content', '')
        title = hit.entity.get('title', '')

        # Keep hits whose title or content contains the keyword
        if keyword.lower() in content.lower() or keyword.lower() in title.lower():
            filtered_results.append(hit)
        
        if len(filtered_results) >= top_k:
            break
    
    return filtered_results
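
Post-filtering like this can return fewer than top_k hits when the keyword is rare. Depending on your Milvus version, the filter can instead be pushed to the server with an expr on a VARCHAR field; infix `like` matching needs a recent release (2.3+), so treat this as a sketch to verify against your deployment:

python
def hybrid_search_expr(collection, query_vector, keyword, top_k=10):
    """Vector search with a server-side keyword filter via `expr`."""
    return collection.search(
        data=[query_vector],
        anns_field="content_vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=top_k,
        # Infix `like` on VARCHAR fields; confirm support in your Milvus version
        expr=f'content like "%{keyword}%"',
        output_fields=["title", "content"]
    )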

Multi-Vector Field Search

python
def multi_vector_search(collection, text_vector, image_vector, weights=(0.6, 0.4)):
    """Fused search across multiple vector fields."""

    # Text vector search
    text_results = collection.search(
        data=[text_vector],
        anns_field="text_vector",
        param={"metric_type": "COSINE", "params": {"nprobe": 16}},
        limit=100
    )

    # Image vector search
    image_results = collection.search(
        data=[image_vector],
        anns_field="image_vector",
        param={"metric_type": "L2", "params": {"nprobe": 16}},
        limit=100
    )

    # Normalize scores onto [0, 1] and fuse
    scores = {}

    # Text results: map the COSINE score from [-1, 1] to [0, 1]
    for hit in text_results[0]:
        similarity = (hit.distance + 1) / 2
        scores[hit.id] = weights[0] * similarity

    # Image results: convert L2 distance to a similarity
    for hit in image_results[0]:
        # assumes distances are capped at roughly 10
        similarity = 1 - min(hit.distance / 10, 1)
        scores[hit.id] = scores.get(hit.id, 0) + weights[1] * similarity
    
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)
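
Milvus 2.4+ ships native multi-vector search, so the manual fusion above can often be replaced by hybrid_search with a built-in ranker. A sketch assuming a 2.4+ server and matching pymilvus client:

python
from pymilvus import AnnSearchRequest, WeightedRanker

def multi_vector_search_native(collection, text_vector, image_vector, top_k=10):
    """Server-side multi-vector fusion (requires Milvus 2.4+)."""
    text_req = AnnSearchRequest(
        data=[text_vector],
        anns_field="text_vector",
        param={"metric_type": "COSINE", "params": {"nprobe": 16}},
        limit=100
    )
    image_req = AnnSearchRequest(
        data=[image_vector],
        anns_field="image_vector",
        param={"metric_type": "L2", "params": {"nprobe": 16}},
        limit=100
    )
    # WeightedRanker normalizes each route's scores before weighting;
    # RRFRanker is an alternative that ignores raw scores entirely
    return collection.hybrid_search(
        [text_req, image_req],
        rerank=WeightedRanker(0.6, 0.4),
        limit=top_k
    )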

Reranking Search Results

Reranking with Business Rules

python
import time

def rerank_by_business_rules(hits, user_profile):
    """Rerank hits according to business rules."""
    scored_hits = []

    for hit in hits:
        base_score = 1 / (hit.distance + 0.001)  # base similarity score

        # Business-rule bonuses
        bonus = 0

        # Matches one of the user's preferred categories
        category = hit.entity.get('category')
        if category in user_profile['preferred_categories']:
            bonus += 0.2

        # Popular content
        view_count = hit.entity.get('view_count', 0)
        if view_count > 10000:
            bonus += 0.1

        # Fresh content (published within the last week)
        publish_time = hit.entity.get('publish_time', 0)
        if publish_time > time.time() - 7 * 24 * 3600:
            bonus += 0.15

        final_score = base_score + bonus
        scored_hits.append((hit, final_score))

    # Sort by final score
    scored_hits.sort(key=lambda x: x[1], reverse=True)
    return [hit for hit, _ in scored_hits]

Reranking with a Machine Learning Model

python
import time

def rerank_by_ml_model(hits, query_features, model):
    """Rerank hits with a machine learning model."""
    scored_hits = []

    for hit in hits:
        # Build the feature vector for this hit
        features = extract_features(hit, query_features)

        # Predict a relevance score
        relevance_score = model.predict([features])[0]

        scored_hits.append((hit, relevance_score))

    # Sort by predicted relevance
    scored_hits.sort(key=lambda x: x[1], reverse=True)
    return [hit for hit, _ in scored_hits]

def extract_features(hit, query_features):
    """Extract reranking features for one hit."""
    features = []

    # Vector similarity feature
    features.append(1 / (hit.distance + 0.001))

    # Content quality features
    features.append(hit.entity.get('quality_score', 0))
    features.append(hit.entity.get('view_count', 0) / 10000)

    # Recency feature, decaying to 0 over 30 days
    age_days = (time.time() - hit.entity.get('publish_time', 0)) / 86400
    features.append(max(0, 1 - age_days / 30))

    # Query-side features
    features.extend(query_features)
    
    return features
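
If you have no trained feature-based model, an off-the-shelf cross-encoder can serve as the reranker: it scores (query, document) text pairs directly. A sketch using sentence-transformers; the model name is one common choice, not a requirement:

python
from sentence_transformers import CrossEncoder

def rerank_by_cross_encoder(hits, query_text, top_k=10):
    """Rerank hits with a cross-encoder over (query, document) text pairs."""
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    pairs = [(query_text, hit.entity.get("content", "")) for hit in hits]
    scores = model.predict(pairs)  # one relevance score per pair
    ranked = sorted(zip(hits, scores), key=lambda x: x[1], reverse=True)
    return [hit for hit, _ in ranked[:top_k]]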

Pagination and Cursors

Offset-Based Pagination

python
def search_with_pagination(collection, query_vector, page=1, page_size=10):
    """Paginated search."""
    offset = (page - 1) * page_size

    # Fetch enough results to cover everything up to the requested page
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=offset + page_size,
        output_fields=["title", "content"]
    )
    
    hits = results[0]
    if offset >= len(hits):
        return []
    
    return hits[offset:offset + page_size]
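
Over-fetching and slicing client-side works everywhere, but it re-transfers every earlier page. Milvus 2.2+ also accepts an offset search parameter so the server skips those results itself; parameter placement has shifted between versions, so check the docs for your client:

python
def search_with_offset(collection, query_vector, page=1, page_size=10):
    """Server-side pagination via the `offset` search parameter (Milvus 2.2+)."""
    return collection.search(
        data=[query_vector],
        anns_field="vector",
        # `offset` tells the server how many top results to skip
        param={"metric_type": "L2", "offset": (page - 1) * page_size,
               "params": {"nprobe": 32}},
        limit=page_size,
        output_fields=["title", "content"]
    )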

Cursor-Based Pagination

python
class CursorPagination:
    """Cursor-based pagination."""

    def __init__(self, collection, query_vector, page_size=10):
        self.collection = collection
        self.query_vector = query_vector
        self.page_size = page_size
        self.excluded_ids = set()

    def next_page(self):
        """Fetch the next page."""
        expr = None

        # Exclude IDs already returned on earlier pages
        if self.excluded_ids:
            id_list = ",".join(map(str, self.excluded_ids))
            expr = f"id not in [{id_list}]"

        results = self.collection.search(
            data=[self.query_vector],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 32}},
            limit=self.page_size,
            expr=expr,
            output_fields=["title"]
        )

        hits = results[0]

        # Remember the IDs we have handed out
        for hit in hits:
            self.excluded_ids.add(hit.id)

        return hits

# Usage example; note the exclusion expression grows with each page,
# so this pattern is best suited to shallow pagination
pager = CursorPagination(collection, query_vector, page_size=10)
page1 = pager.next_page()
page2 = pager.next_page()
page3 = pager.next_page()

Deduplicating Search Results

Deduplication by Content Similarity

python
import numpy as np

def deduplicate_results(hits, similarity_threshold=0.95):
    """Deduplicate hits by vector similarity."""
    unique_hits = []

    for hit in hits:
        is_duplicate = False

        for unique_hit in unique_hits:
            # Compare this hit's vector against each kept hit's vector
            similarity = calculate_similarity(
                hit.entity.get('vector'),
                unique_hit.entity.get('vector')
            )

            if similarity > similarity_threshold:
                is_duplicate = True
                break

        if not is_duplicate:
            unique_hits.append(hit)

    return unique_hits

def calculate_similarity(vec1, vec2):
    """Cosine similarity between two vectors."""
    vec1 = np.array(vec1)
    vec2 = np.array(vec2)

    dot_product = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)

    return dot_product / (norm1 * norm2)
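
The pairwise loop above does O(n²) comparisons in pure Python. For larger candidate sets the cosine computations can be batched in NumPy; a sketch assuming every hit carries a non-zero vector in its output fields:

python
import numpy as np

def deduplicate_results_fast(hits, similarity_threshold=0.95):
    """Greedy dedup with batched cosine similarity (same results, less Python overhead)."""
    kept_hits = []
    kept_matrix = None  # rows are L2-normalized vectors of kept hits

    for hit in hits:
        vec = np.asarray(hit.entity.get("vector"), dtype=np.float32)
        vec = vec / np.linalg.norm(vec)

        if kept_matrix is not None:
            # One matrix-vector product yields similarities to all kept hits
            if (kept_matrix @ vec).max() > similarity_threshold:
                continue

        kept_hits.append(hit)
        kept_matrix = vec[None, :] if kept_matrix is None else np.vstack([kept_matrix, vec])

    return kept_hits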

Deduplication by Title Similarity

python
from difflib import SequenceMatcher

def deduplicate_by_title(hits, similarity_threshold=0.8):
    """Deduplicate hits by title text similarity."""
    unique_hits = []

    for hit in hits:
        title = hit.entity.get('title', '')
        is_duplicate = False

        for unique_hit in unique_hits:
            unique_title = unique_hit.entity.get('title', '')

            # Text similarity between the two titles
            similarity = SequenceMatcher(None, title, unique_title).ratio()
            
            if similarity > similarity_threshold:
                is_duplicate = True
                break
        
        if not is_duplicate:
            unique_hits.append(hit)
    
    return unique_hits

Diversifying Search Results

MMR (Maximal Marginal Relevance)

python
def mmr_search(collection, query_vector, lambda_param=0.5, top_k=10):
    """
    Diversify search results with the MMR algorithm.

    Args:
        lambda_param: trade-off between relevance and diversity;
                      0 = diversity only, 1 = relevance only
    """
    # Fetch the candidate pool (as a list, so items can be removed below)
    candidates = list(collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=top_k * 3,
        output_fields=["vector"]
    )[0])

    selected = []

    while len(selected) < top_k and candidates:
        best_score = -float('inf')
        best_candidate = None

        for candidate in candidates:
            # Relevance: similarity to the query
            relevance = 1 / (candidate.distance + 0.001)

            # Redundancy: maximum similarity to anything already selected
            max_similarity = 0
            for selected_item in selected:
                sim = calculate_similarity(
                    candidate.entity.get('vector'),
                    selected_item.entity.get('vector')
                )
                max_similarity = max(max_similarity, sim)

            # MMR score
            mmr_score = lambda_param * relevance - (1 - lambda_param) * max_similarity

            if mmr_score > best_score:
                best_score = mmr_score
                best_candidate = candidate

        if best_candidate:
            selected.append(best_candidate)
            candidates.remove(best_candidate)
    
    return selected
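
Usage mirrors a plain search; raising lambda_param pushes results toward relevance, lowering it toward diversity:

python
# Mostly relevance, with some diversity mixed in
diverse_hits = mmr_search(collection, query_vector, lambda_param=0.7, top_k=10)
for hit in diverse_hits:
    print(hit.id, hit.distance)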

Real-Time Search Optimization

Caching Popular Queries

python
import hashlib

class SearchCache:
    """Search result cache."""

    def __init__(self, maxsize=1000):
        self.cache = {}
        self.maxsize = maxsize

    def _make_key(self, query_vector, params):
        """Build a cache key."""
        # Serialize the vector and params into a stable string, then hash it
        vector_str = ",".join([f"{x:.4f}" for x in query_vector])
        param_str = str(sorted(params.items()))
        key = hashlib.md5(f"{vector_str}:{param_str}".encode()).hexdigest()
        return key

    def get(self, query_vector, params):
        """Look up cached results."""
        key = self._make_key(query_vector, params)
        return self.cache.get(key)

    def set(self, query_vector, params, results):
        """Store results in the cache."""
        if len(self.cache) >= self.maxsize:
            # Evict the oldest entry (FIFO; dicts preserve insertion order)
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        key = self._make_key(query_vector, params)
        self.cache[key] = results

# Usage example
cache = SearchCache(maxsize=1000)

def cached_search(collection, query_vector, params):
    """Search with caching."""
    # Try the cache first
    cached_results = cache.get(query_vector, params)
    if cached_results:
        return cached_results

    # Run the actual search
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param=params,
        limit=10
    )

    # Cache the results
    cache.set(query_vector, params, results)
    
    return results
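
Cached results go stale as soon as the collection changes, so it is worth capping each entry's lifetime. A minimal TTL variant of the class above; the 300-second default is an arbitrary illustration:

python
import time

class TTLSearchCache(SearchCache):
    """SearchCache with per-entry expiry, so stale results age out."""

    def __init__(self, maxsize=1000, ttl_seconds=300):
        super().__init__(maxsize)
        self.ttl = ttl_seconds

    def get(self, query_vector, params):
        entry = self.cache.get(self._make_key(query_vector, params))
        if entry is None:
            return None
        results, stored_at = entry
        if time.time() - stored_at > self.ttl:
            return None  # expired; the caller falls through to a real search
        return results

    def set(self, query_vector, params, results):
        # Store the results together with the time they were cached
        super().set(query_vector, params, (results, time.time()))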

Complete Example

python
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
import random
import time

def advanced_search_demo():
    """End-to-end demo of the advanced search techniques."""

    # Connect to Milvus
    connections.connect(host="localhost", port="19530")

    # Clean up any previous run
    if utility.has_collection("advanced_search_demo"):
        utility.drop_collection("advanced_search_demo")

    # Create the collection
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="view_count", dtype=DataType.INT64),
        FieldSchema(name="publish_time", dtype=DataType.INT64)
    ]

    schema = CollectionSchema(fields, "Advanced search demo")
    collection = Collection("advanced_search_demo", schema)

    # Insert test data
    print("Inserting test data...")
    categories = ["tech", "life", "entertainment", "education", "sports"]
    data = []
    for i in range(1000):
        data.append({
            "vector": [random.random() for _ in range(128)],
            "title": f"article_{i}_{random.choice(categories)}",
            "category": random.choice(categories),
            "view_count": random.randint(100, 100000),
            "publish_time": int(time.time()) - random.randint(0, 30 * 86400)
        })

    collection.insert(data)

    # Build the index
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128}
    }
    collection.create_index("vector", index_params)
    utility.wait_for_index_building_complete("advanced_search_demo")
    collection.load()

    print("\n=== Paginated search ===")
    query_vector = [random.random() for _ in range(128)]

    for page in range(1, 4):
        results = search_with_pagination(collection, query_vector, page=page, page_size=5)
        print(f"Page {page} results:")
        for hit in results:
            print(f"  ID: {hit.id}, distance: {hit.distance:.4f}")

    print("\n=== Business-rule reranking ===")
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=20,
        output_fields=["title", "category", "view_count", "publish_time"]
    )

    user_profile = {"preferred_categories": ["tech", "education"]}
    reranked = rerank_by_business_rules(results[0], user_profile)

    print("Top 5 results after reranking:")
    for hit in reranked[:5]:
        print(f"  {hit.entity.get('title')} - {hit.entity.get('category')}")

    # Clean up
    utility.drop_collection("advanced_search_demo")
    print("\nDemo complete!")

if __name__ == "__main__":
    advanced_search_demo()

Next Steps

Once you have mastered these advanced search techniques, you can:

  1. Learn about data backup and recovery
  2. Study performance optimization
  3. Explore real-world use cases