Appearance
高级搜索技巧
多路召回与融合
什么是多路召回
多路召回是指从多个维度或路径检索相关结果,然后融合得到最终结果。常用于提升搜索的召回率和准确性。
RRF(Reciprocal Rank Fusion)
python
def reciprocal_rank_fusion(results_list, k=60):
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Args:
        results_list: multi-path search results; each element's first entry
            (``results[0]``) is an ordered list of hits exposing an ``id``
            attribute.
        k: RRF smoothing constant, conventionally 60.

    Returns:
        List of ``(doc_id, score)`` tuples sorted by fused score, descending.
    """
    fused = {}
    for result in results_list:
        # RRF: each path contributes 1 / (k + rank), with a 1-based rank.
        for position, hit in enumerate(result[0], start=1):
            fused[hit.id] = fused.get(hit.id, 0) + 1 / (k + position)
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)
# Usage example: run two search paths over different vector fields, then
# fuse their rankings with RRF.
# NOTE(review): `collection`, `query_vector` and `search_params` are assumed
# to be defined in the surrounding tutorial context — confirm before running.
vector_results = collection.search(
    data=[query_vector],
    anns_field="vector",
    param=search_params,
    limit=100
)
keyword_results = collection.search(
    data=[query_vector],
    anns_field="title_vector",
    param=search_params,
    limit=100
)
# Fuse the two ranked lists
fused_results = reciprocal_rank_fusion([vector_results, keyword_results])

加权融合
python
def weighted_fusion(results_list, weights):
    """Merge several ranked hit lists, weighting each path's contribution.

    Args:
        results_list: multi-path search results; each element's first entry
            is an ordered list of hits exposing an ``id`` attribute.
        weights: per-path weight, aligned element-wise with ``results_list``.

    Returns:
        ``(doc_id, score)`` pairs sorted by combined score, descending.
    """
    combined = {}
    for result, weight in zip(results_list, weights):
        for position, hit in enumerate(result[0], start=1):
            # Each path contributes weight * reciprocal rank (1-based).
            combined[hit.id] = combined.get(hit.id, 0) + weight * (1 / position)
    return sorted(combined.items(), key=lambda item: item[1], reverse=True)
# 使用示例
fused = weighted_fusion(
    [vector_results, keyword_results, semantic_results],
    weights=[0.5, 0.3, 0.2]
)

混合搜索
向量 + 关键词混合搜索
python
def hybrid_search(collection, query_vector, keyword, top_k=10):
    """Vector search followed by keyword post-filtering.

    Runs an ANN search over ``content_vector`` with an enlarged candidate
    pool (3x ``top_k``), then keeps only hits whose content or title
    contains ``keyword`` (case-insensitive), stopping at ``top_k`` matches.
    """
    # Over-fetch candidates so keyword filtering can still fill the page.
    candidates = collection.search(
        data=[query_vector],
        anns_field="content_vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=top_k * 3,
        output_fields=["title", "content"]
    )
    needle = keyword.lower()
    matches = []
    for candidate in candidates[0]:
        fields = (
            candidate.entity.get('content', ''),
            candidate.entity.get('title', ''),
        )
        # Case-insensitive substring match against content, then title.
        if any(needle in text.lower() for text in fields):
            matches.append(candidate)
            if len(matches) >= top_k:
                break
    return matches
多向量字段搜索
python
def multi_vector_search(collection, text_vector, image_vector, weights=(0.6, 0.4)):
    """Search two vector fields and fuse their normalized scores.

    Args:
        collection: Milvus collection with a ``text_vector`` (COSINE) field
            and an ``image_vector`` (L2) field.
        text_vector: query embedding for the text field.
        image_vector: query embedding for the image field.
        weights: (text_weight, image_weight) pair; default favors text.
            Changed from a mutable list default to an immutable tuple to
            avoid the shared-mutable-default pitfall (backward compatible,
            indexing still works).

    Returns:
        ``(doc_id, fused_score)`` pairs sorted by score, descending.
    """
    # Text path: COSINE metric.
    text_results = collection.search(
        data=[text_vector],
        anns_field="text_vector",
        param={"metric_type": "COSINE", "params": {"nprobe": 16}},
        limit=100
    )
    # Image path: L2 metric.
    image_results = collection.search(
        data=[image_vector],
        anns_field="image_vector",
        param={"metric_type": "L2", "params": {"nprobe": 16}},
        limit=100
    )
    # Normalize each metric into [0, 1] before the weighted sum.
    scores = {}
    for hit in text_results[0]:
        # COSINE score in [-1, 1] mapped to [0, 1].
        similarity = (hit.distance + 1) / 2
        scores[hit.id] = weights[0] * similarity
    for hit in image_results[0]:
        # L2 distance mapped to [0, 1]; the cap of 10 is a heuristic —
        # NOTE(review): confirm against the actual embedding scale.
        similarity = 1 - min(hit.distance / 10, 1)
        scores[hit.id] = scores.get(hit.id, 0) + weights[1] * similarity
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)
搜索结果重排序
基于业务规则重排
python
def rerank_by_business_rules(hits, user_profile):
    """Re-rank hits by vector similarity plus business-rule bonuses.

    Bonuses: +0.2 for a preferred category, +0.1 for popular content
    (>10k views), +0.15 for fresh content (published within one week).
    Returns the hits sorted by the adjusted score, best first.
    """
    one_week_ago = time.time() - 7 * 24 * 3600

    def _adjusted_score(hit):
        # Base similarity: inverse distance, epsilon-guarded.
        total = 1 / (hit.distance + 0.001)
        if hit.entity.get('category') in user_profile['preferred_categories']:
            total += 0.2
        if hit.entity.get('view_count', 0) > 10000:
            total += 0.1
        if hit.entity.get('publish_time', 0) > one_week_ago:
            total += 0.15
        return total

    # Stable sort, so equal-score hits keep their original relative order.
    return sorted(hits, key=_adjusted_score, reverse=True)
基于机器学习模型重排
python
def rerank_by_ml_model(hits, query_features, model):
    """Re-rank hits with a learned relevance model.

    Builds a feature vector per hit, scores it with ``model.predict``,
    and returns the hits ordered by predicted relevance, best first.
    """
    predictions = [
        (hit, model.predict([extract_features(hit, query_features)])[0])
        for hit in hits
    ]
    predictions.sort(key=lambda pair: pair[1], reverse=True)
    return [hit for hit, _ in predictions]

def extract_features(hit, query_features):
    """Assemble the re-ranking feature vector for a single hit."""
    # Inverse-distance similarity, epsilon-guarded against zero distance.
    features = [1 / (hit.distance + 0.001)]
    # Content-quality signals.
    features.append(hit.entity.get('quality_score', 0))
    features.append(hit.entity.get('view_count', 0) / 10000)
    # Recency: linear decay to zero over 30 days.
    age_days = (time.time() - hit.entity.get('publish_time', 0)) / 86400
    features.append(max(0, 1 - age_days / 30))
    # Query-side features are appended verbatim for feature crossing.
    features.extend(query_features)
    return features
分页与游标
基于偏移的分页
python
def search_with_pagination(collection, query_vector, page=1, page_size=10):
    """Offset-based pagination over a vector search.

    Fetches ``offset + page_size`` results in a single search and slices
    out the requested (1-based) page. Returns [] past the end.
    """
    offset = (page - 1) * page_size
    # Over-fetch so the slice for this page is fully covered.
    response = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=offset + page_size,
        output_fields=["title", "content"]
    )
    page_hits = response[0]
    # Past-the-end pages yield an empty list rather than raising.
    if offset >= len(page_hits):
        return []
    return page_hits[offset:offset + page_size]
基于游标的分页
python
class CursorPagination:
    """Cursor-style pagination that excludes already-returned IDs.

    Each ``next_page`` call searches again with an ``id not in [...]``
    filter built from everything returned so far. The exclusion set grows
    with every page, so this suits shallow pagination only.
    """

    def __init__(self, collection, query_vector, page_size=10):
        self.collection = collection
        self.query_vector = query_vector
        self.page_size = page_size
        self.last_distance = None
        self.excluded_ids = set()

    def next_page(self):
        """Fetch the next page of hits and remember their IDs."""
        expr = None
        if self.excluded_ids:
            # Exclude everything already handed out.
            seen = ",".join(map(str, self.excluded_ids))
            expr = f"id not in [{seen}]"
        response = self.collection.search(
            data=[self.query_vector],
            anns_field="vector",
            param={"metric_type": "L2", "params": {"nprobe": 32}},
            limit=self.page_size,
            expr=expr,
            output_fields=["title"]
        )
        page_hits = response[0]
        self.excluded_ids.update(hit.id for hit in page_hits)
        return page_hits
# 使用示例
pager = CursorPagination(collection, query_vector, page_size=10)
page1 = pager.next_page()
page2 = pager.next_page()
page3 = pager.next_page()

搜索结果去重
基于内容相似度去重
python
def deduplicate_results(hits, similarity_threshold=0.95):
    """Drop hits whose vector is near-identical to an already-kept hit.

    O(n^2) pairwise comparison; the first member of each near-duplicate
    group is the one kept.
    """
    kept = []
    for hit in hits:
        # Short-circuits on the first kept hit above the threshold.
        is_duplicate = any(
            calculate_similarity(
                hit.entity.get('vector'),
                kept_hit.entity.get('vector')
            ) > similarity_threshold
            for kept_hit in kept
        )
        if not is_duplicate:
            kept.append(hit)
    return kept

def calculate_similarity(vec1, vec2):
    """Cosine similarity between two vectors."""
    import numpy as np
    a = np.asarray(vec1)
    b = np.asarray(vec2)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
基于标题相似度去重
python
def deduplicate_by_title(hits, similarity_threshold=0.8):
    """Drop hits whose title closely matches an already-kept hit's title.

    Uses difflib's ratio as the text-similarity measure; keeps the first
    member of each near-duplicate group.
    """
    from difflib import SequenceMatcher
    kept = []
    for hit in hits:
        candidate_title = hit.entity.get('title', '')
        # Short-circuits on the first kept title above the threshold.
        duplicate = any(
            SequenceMatcher(None, candidate_title,
                            kept_hit.entity.get('title', '')).ratio()
            > similarity_threshold
            for kept_hit in kept
        )
        if not duplicate:
            kept.append(hit)
    return kept
搜索结果多样性
MMR(Maximal Marginal Relevance)
python
def mmr_search(collection, query_vector, lambda_param=0.5, top_k=10):
    """Diversified search via Maximal Marginal Relevance (MMR).

    Args:
        lambda_param: relevance/diversity trade-off in [0, 1];
            1 ranks purely by relevance, 0 purely by diversity.

    Greedily picks, from a 3x-oversampled candidate pool, the candidate
    maximizing ``lambda * relevance - (1 - lambda) * max_sim_to_selected``.
    Depends on the module-level ``calculate_similarity`` helper.
    """
    # Oversample so there is room to trade relevance for diversity.
    candidates = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=top_k * 3,
        output_fields=["vector"]
    )[0]
    selected = []
    while len(selected) < top_k and candidates:
        best_candidate = None
        best_score = -float('inf')
        for candidate in candidates:
            # Relevance: inverse distance to the query, epsilon-guarded.
            relevance = 1 / (candidate.distance + 0.001)
            # Redundancy: closest similarity to anything already selected.
            redundancy = 0
            for chosen in selected:
                redundancy = max(
                    redundancy,
                    calculate_similarity(
                        candidate.entity.get('vector'),
                        chosen.entity.get('vector')
                    )
                )
            score = lambda_param * relevance - (1 - lambda_param) * redundancy
            if score > best_score:
                best_score = score
                best_candidate = candidate
        if best_candidate:
            selected.append(best_candidate)
            candidates.remove(best_candidate)
    return selected
实时搜索优化
缓存热门查询
python
import hashlib
from collections import OrderedDict
from functools import lru_cache
class SearchCache:
    """In-memory LRU cache for search results.

    Keys are an MD5 digest of the (rounded) query vector plus the sorted
    search params, so equivalent queries share an entry. The original
    implementation used a plain dict and evicted in pure insertion order
    (FIFO) despite its "LRU" comment; this version tracks recency with an
    OrderedDict so the comment and behavior agree.
    """

    def __init__(self, maxsize=1000):
        # OrderedDict front = least recently used.
        self.cache = OrderedDict()
        self.maxsize = maxsize

    def _make_key(self, query_vector, params):
        """Build a stable digest key from the query vector and params."""
        # Round components so tiny float noise maps to the same key.
        vector_str = ",".join(f"{x:.4f}" for x in query_vector)
        param_str = str(sorted(params.items()))
        return hashlib.md5(f"{vector_str}:{param_str}".encode()).hexdigest()

    def get(self, query_vector, params):
        """Return cached results for this query, or None on a miss."""
        key = self._make_key(query_vector, params)
        results = self.cache.get(key)
        if results is not None:
            # Refresh recency so hot entries survive eviction.
            self.cache.move_to_end(key)
        return results

    def set(self, query_vector, params, results):
        """Store results, evicting the least recently used entry if full."""
        key = self._make_key(query_vector, params)
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.maxsize:
            # True LRU eviction: drop the coldest entry.
            self.cache.popitem(last=False)
        self.cache[key] = results
# Usage example: a module-level cache shared by cached_search.
cache = SearchCache(maxsize=1000)

def cached_search(collection, query_vector, params):
    """Search with a read-through cache in front of the collection.

    Returns the cached results when present; otherwise performs the
    search, stores the results, and returns them.
    """
    cached_results = cache.get(query_vector, params)
    # Explicit None check: an empty (falsy) result set is still a valid
    # cache hit — the original `if cached_results:` re-ran those searches.
    if cached_results is not None:
        return cached_results
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param=params,
        limit=10
    )
    cache.set(query_vector, params, results)
    return results
完整示例
python
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
import random
import time
def advanced_search_demo():
    """End-to-end demo: seed a collection, then run paginated search and
    business-rule re-ranking against it.

    Side effects: connects to a local Milvus at localhost:19530, drops /
    creates / drops the `advanced_search_demo` collection, and prints
    results. Relies on `search_with_pagination` and
    `rerank_by_business_rules` defined earlier in this document.
    """
    # Connect to Milvus
    connections.connect(host="localhost", port="19530")
    # Remove any stale collection left over from a previous run
    if utility.has_collection("advanced_search_demo"):
        utility.drop_collection("advanced_search_demo")
    # Define the collection schema
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="view_count", dtype=DataType.INT64),
        FieldSchema(name="publish_time", dtype=DataType.INT64)
    ]
    schema = CollectionSchema(fields, "高级搜索演示")
    collection = Collection("advanced_search_demo", schema)
    # Insert randomized test rows
    print("插入测试数据...")
    categories = ["科技", "生活", "娱乐", "教育", "体育"]
    data = []
    for i in range(1000):
        data.append({
            "vector": [random.random() for _ in range(128)],
            "title": f"文章_{i}_{random.choice(categories)}",
            "category": random.choice(categories),
            # View counts up to 100k so the popularity bonus can trigger
            "view_count": random.randint(100, 100000),
            # Publish time within the last 30 days
            "publish_time": int(time.time()) - random.randint(0, 30 * 86400)
        })
    collection.insert(data)
    # Build an IVF_FLAT index, then load the collection for search
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128}
    }
    collection.create_index("vector", index_params)
    utility.wait_for_index_building_complete("advanced_search_demo")
    collection.load()
    print("\n=== 分页搜索 ===")
    query_vector = [random.random() for _ in range(128)]
    for page in range(1, 4):
        results = search_with_pagination(collection, query_vector, page=page, page_size=5)
        print(f"第 {page} 页结果:")
        for hit in results:
            print(f" ID: {hit.id}, 距离: {hit.distance:.4f}")
    print("\n=== 业务规则重排序 ===")
    results = collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 32}},
        limit=20,
        output_fields=["title", "category", "view_count", "publish_time"]
    )
    user_profile = {"preferred_categories": ["科技", "教育"]}
    reranked = rerank_by_business_rules(results[0], user_profile)
    print("重排序后前 5 个结果:")
    for hit in reranked[:5]:
        print(f" {hit.entity.get('title')} - {hit.entity.get('category')}")
    # Clean up the demo collection
    utility.drop_collection("advanced_search_demo")
    print("\n演示完成!")
if __name__ == "__main__":
    advanced_search_demo()

下一步
掌握高级搜索技巧后,你可以: