# Source code for milvus.client.utils
from ..grpc_gen import status_pb2
from ..grpc_gen.milvus_pb2 import TopKQueryResult as Grpc_Result
from ..client.abstract import TopKQueryResult
from ..client.exceptions import ParamError
def merge_results(results_list, topk, *args, **kwargs):
    """
    Merge several top-k query results into a single result.

    Each element of ``results_list`` carries ``row_num`` query rows whose ids
    and distances are stored flattened, one equally sized batch per row. The
    batches belonging to the same row index are merged across all elements and
    the best ``batch_len`` entries are kept, where ``batch_len`` is the per-row
    batch size of the element being merged in.

    The merge assumes every batch is already ordered best-first (ascending
    distances by default, descending when ``reverse`` is true).

    :param results_list: iterable of ``Grpc_Result`` / ``TopKQueryResult``
        objects to merge.
    :param topk: accepted for interface compatibility; the per-row batch
        length is derived from the data instead.
    :param kwargs:
        ``reverse`` (bool, default ``False``) -- treat larger distances as
        better, i.e. keep the largest values and order them descending.
        ``raw`` (bool, default ``False``) -- return the ``Grpc_Result``
        message itself instead of wrapping it in ``TopKQueryResult``.
    :return: ``(status, [], [])`` when ``results_list`` is empty; otherwise
        the merged ``Grpc_Result`` (``raw=True``) or a ``TopKQueryResult``
        built from it.
    :raises ParamError: if an element of ``results_list`` is neither a
        ``Grpc_Result`` nor a ``TopKQueryResult``.
    """

    def _reduce(source_ids, ids, source_diss, diss, k, reverse):
        """
        Merge one incoming batch into the accumulated batch of the same row.

        Returns ``(ids, distances)`` holding the best ``k`` entries of both
        batches. The input sequences are never modified.
        """
        # An empty side contributes nothing; indexing it would raise.
        if not diss:
            return source_ids, source_diss
        if not source_diss:
            return ids, diss

        def _ranks_first(left, right):
            # True when `left` is at least as good as `right` for the
            # requested ordering direction.
            return left >= right if reverse else left <= right

        # Fast paths: one batch already holds k entries that all rank ahead
        # of everything in the other batch, so no sorting is needed.
        if len(source_diss) >= k and _ranks_first(source_diss[k - 1], diss[0]):
            return source_ids[:k], source_diss[:k]
        if len(diss) >= k and _ranks_first(diss[k - 1], source_diss[0]):
            return ids[:k], diss[:k]

        # General case: concatenate into fresh lists and pick the best k
        # positions. sorted() is stable, so on equal distances the entries
        # accumulated earlier keep priority.
        all_ids = list(source_ids) + list(ids)
        all_diss = list(source_diss) + list(diss)
        best = sorted(range(len(all_diss)),
                      key=all_diss.__getitem__,
                      reverse=reverse)[:k]
        return [all_ids[i] for i in best], [all_diss[i] for i in best]

    status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                               reason="Success")

    reverse = kwargs.get('reverse', False)
    raw = kwargs.get('raw', False)

    # Kept as a tuple for backward compatibility with existing callers.
    if not results_list:
        return status, [], []

    merge_id_results = []
    merge_dis_results = []

    for files_collection in results_list:
        if not isinstance(files_collection, (Grpc_Result, TopKQueryResult)):
            # Raise instead of returning the exception object as a "result".
            raise ParamError("Result type is unknown.")

        row_num = files_collection.row_num
        if not row_num:
            continue

        ids = files_collection.ids
        diss = files_collection.distances  # distance collections
        # Notice: batch_len is equal to topk, may need to compare with topk
        batch_len = len(ids) // row_num

        for row_index in range(row_num):
            start = row_index * batch_len
            stop = start + batch_len
            id_batch = ids[start:stop]
            dis_batch = diss[start:stop]

            if row_index < len(merge_id_results):
                merge_id_results[row_index], merge_dis_results[row_index] = \
                    _reduce(merge_id_results[row_index], id_batch,
                            merge_dis_results[row_index], dis_batch,
                            batch_len,
                            reverse)
            else:
                # First time this row index is seen: rows are visited in
                # order, so it is always the next slot to append.
                merge_id_results.append(id_batch)
                merge_dis_results.append(dis_batch)

    id_merge_list = [id_ for row in merge_id_results for id_ in row]
    dis_merge_list = [dis for row in merge_dis_results for dis in row]

    raw_result = Grpc_Result(
        status=status,
        # Count the rows actually merged; the row_num of the last collection
        # iterated may be 0 or smaller than the number of merged rows.
        row_num=len(merge_id_results),
        ids=id_merge_list,
        distances=dis_merge_list
    )

    return raw_result if raw else TopKQueryResult(raw_result)