diff --git a/paimon-python/pypaimon/globalindex/vector_search_result.py b/paimon-python/pypaimon/globalindex/vector_search_result.py index a93247e69063..5e6d2fb46404 100644 --- a/paimon-python/pypaimon/globalindex/vector_search_result.py +++ b/paimon-python/pypaimon/globalindex/vector_search_result.py @@ -70,13 +70,14 @@ def or_(self, other: GlobalIndexResult) -> GlobalIndexResult: other_score_getter = other.score_getter() result_or = RoaringBitmap64.or_(this_row_ids, other_row_ids) - - def combined_score_getter(row_id: int) -> Optional[float]: - if row_id in this_row_ids: - return this_score_getter(row_id) - return other_score_getter(row_id) - - return SimpleScoredGlobalIndexResult(result_or, combined_score_getter) + + merged_scores = {} + for row_id in other_row_ids: + merged_scores[row_id] = other_score_getter(row_id) + for row_id in this_row_ids: + merged_scores[row_id] = this_score_getter(row_id) + + return SimpleScoredGlobalIndexResult(result_or, lambda row_id: merged_scores.get(row_id)) def top_k(self, k: int) -> 'ScoredGlobalIndexResult': """Return the top-k results by score.""" diff --git a/paimon-python/pypaimon/table/source/full_text_read.py b/paimon-python/pypaimon/table/source/full_text_read.py index 9be1c2e4caa4..f8fb1373fb65 100644 --- a/paimon-python/pypaimon/table/source/full_text_read.py +++ b/paimon-python/pypaimon/table/source/full_text_read.py @@ -25,7 +25,7 @@ from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader -from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult +from pypaimon.globalindex.vector_search_result import DictBasedScoredIndexResult from pypaimon.table.source.full_text_search_split import FullTextSearchSplit from pypaimon.table.source.full_text_scan import FullTextScanPlan @@ -60,19 +60,22 @@ def read(self, splits: List[FullTextSearchSplit]) -> GlobalIndexResult: if not splits: return GlobalIndexResult.create_empty() - result = ScoredGlobalIndexResult.create_empty() + merged_scores = {} for split in splits: split_result = self._eval( split.row_range_start, split.row_range_end, split.full_text_index_files ) if split_result is not None: - result = result.or_(split_result) + score_getter = split_result.score_getter() + for row_id in split_result.results(): + if row_id not in merged_scores: + merged_scores[row_id] = score_getter(row_id) - return result.top_k(self._limit) + return DictBasedScoredIndexResult(merged_scores).top_k(self._limit) def _eval(self, row_range_start, row_range_end, full_text_index_files - ) -> Optional[ScoredGlobalIndexResult]: + ) -> Optional[GlobalIndexResult]: index_io_meta_list = [] for index_file in full_text_index_files: meta = index_file.global_index_meta diff --git a/paimon-python/pypaimon/table/source/vector_search_read.py b/paimon-python/pypaimon/table/source/vector_search_read.py index 24c399a19213..3c4b3f8e1dde 100644 --- a/paimon-python/pypaimon/table/source/vector_search_read.py +++ b/paimon-python/pypaimon/table/source/vector_search_read.py @@ -24,7 +24,7 @@ from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader from pypaimon.globalindex.vector_search import VectorSearch -from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult +from pypaimon.globalindex.vector_search_result import DictBasedScoredIndexResult class VectorSearchRead(ABC): @@ -57,16 +57,19 @@ def read(self, splits): pre_filter = self._pre_filter(splits) - result = ScoredGlobalIndexResult.create_empty() + merged_scores = {} for split in splits: split_result = self._eval( split.row_range_start, split.row_range_end, split.vector_index_files, pre_filter ) if split_result is not None: - result = result.or_(split_result) + score_getter = split_result.score_getter() + for row_id in split_result.results(): + if row_id not in merged_scores: + merged_scores[row_id] = score_getter(row_id) - return result.top_k(self._limit) + return DictBasedScoredIndexResult(merged_scores).top_k(self._limit) def _pre_filter(self, splits): # type: (list) -> Optional[RoaringBitmap64] diff --git a/paimon-python/pypaimon/tests/vector_search_filter_test.py b/paimon-python/pypaimon/tests/vector_search_filter_test.py index 857c62e28677..6a136f7a0df7 100644 --- a/paimon-python/pypaimon/tests/vector_search_filter_test.py +++ b/paimon-python/pypaimon/tests/vector_search_filter_test.py @@ -541,5 +541,127 @@ def test_with_partition_filter_rejects_non_partition_field(self): self.assertIn("non-partition", str(ctx.exception)) +class VectorSearchManySplitsTest(unittest.TestCase): + + def test_vector_search_with_many_splits(self): + from pypaimon.globalindex.vector_search_result import ( + DictBasedScoredIndexResult, + ) + from pypaimon.table.source.vector_search_read import VectorSearchReadImpl + from pypaimon.table.source.vector_search_split import VectorSearchSplit + + num_splits = 1200 + embedding_field = _field(1, "embedding", "FLOAT") + entries = [ + _entry(None, field_id=1, index_type="lumina-vector-ann", + file_name="vec-%d.index" % i, + row_range_start=i, row_range_end=i) + for i in range(num_splits) + ] + table = _StubTable(fields=[embedding_field], entries=entries) + _patch_snapshot(self, entries) + + def _fake_create(index_type, file_io, index_path, + index_io_meta_list, options=None): + row_id = index_io_meta_list[0].file_name + row_id = int(row_id.split("-")[1].split(".")[0]) + + class _FakeReader: + def visit_vector_search(self_inner, vs): + return DictBasedScoredIndexResult({row_id: float(row_id)}) + + def close(self_inner): + pass + + def __enter__(self_inner): + return self_inner + + def __exit__(self_inner, *a): + return False + return _FakeReader() + + splits = [ + VectorSearchSplit( + row_range_start=i, row_range_end=i, + vector_index_files=[entries[i].index_file]) + for i in range(num_splits) + ] + + with mock.patch( + "pypaimon.table.source.vector_search_read._create_vector_reader", + side_effect=_fake_create): + reader = VectorSearchReadImpl( + table, limit=10, vector_column=embedding_field, + query_vector=[1.0], filter_=None) + result = reader.read(splits) + + self.assertLessEqual(result.results().cardinality(), 10) + self.assertEqual(result.results().cardinality(), 10) + scores = sorted(result.score_getter()(rid) for rid in result.results()) + self.assertEqual(scores, [float(i) for i in range(1190, 1200)]) + + def tearDown(self): + mock.patch.stopall() + + +class FullTextSearchManySplitsTest(unittest.TestCase): + + def test_full_text_search_with_many_splits(self): + from pypaimon.globalindex.vector_search_result import ( + DictBasedScoredIndexResult, + ) + from pypaimon.table.source.full_text_read import FullTextReadImpl + from pypaimon.table.source.full_text_search_split import ( + FullTextSearchSplit, + ) + + num_splits = 1200 + text_field = _field(1, "content", "STRING") + entries = [ + _entry(None, field_id=1, index_type="tantivy-fulltext", + file_name="ft-%d.index" % i, + row_range_start=i, row_range_end=i) + for i in range(num_splits) + ] + table = _StubTable(fields=[text_field], entries=entries) + _patch_snapshot(self, entries) + + def _fake_create(index_type, file_io, index_path, + index_io_meta_list): + row_id = index_io_meta_list[0].file_name + row_id = int(row_id.split("-")[1].split(".")[0]) + + class _FakeReader: + def visit_full_text_search(self_inner, fts): + return DictBasedScoredIndexResult({row_id: float(row_id)}) + + def close(self_inner): + pass + return _FakeReader() + + splits = [ + FullTextSearchSplit( + row_range_start=i, row_range_end=i, + full_text_index_files=[entries[i].index_file]) + for i in range(num_splits) + ] + + with mock.patch( + "pypaimon.table.source.full_text_read._create_full_text_reader", + side_effect=_fake_create): + reader = FullTextReadImpl( + table, limit=10, text_column=text_field, + query_text="test") + result = reader.read(splits) + + self.assertLessEqual(result.results().cardinality(), 10) + self.assertEqual(result.results().cardinality(), 10) + scores = sorted(result.score_getter()(rid) for rid in result.results()) + self.assertEqual(scores, [float(i) for i in range(1190, 1200)]) + + def tearDown(self): + mock.patch.stopall() + + if __name__ == "__main__": unittest.main()