Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions paimon-python/pypaimon/globalindex/vector_search_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def or_(self, other: GlobalIndexResult) -> GlobalIndexResult:
other_score_getter = other.score_getter()

result_or = RoaringBitmap64.or_(this_row_ids, other_row_ids)

def combined_score_getter(row_id: int) -> Optional[float]:
if row_id in this_row_ids:
return this_score_getter(row_id)
return other_score_getter(row_id)

return SimpleScoredGlobalIndexResult(result_or, combined_score_getter)

merged_scores = {}
for row_id in other_row_ids:
merged_scores[row_id] = other_score_getter(row_id)
for row_id in this_row_ids:
merged_scores[row_id] = this_score_getter(row_id)

return SimpleScoredGlobalIndexResult(result_or, lambda row_id: merged_scores.get(row_id))

def top_k(self, k: int) -> 'ScoredGlobalIndexResult':
"""Return the top-k results by score."""
Expand Down
13 changes: 8 additions & 5 deletions paimon-python/pypaimon/table/source/full_text_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta
from pypaimon.globalindex.global_index_result import GlobalIndexResult
from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader
from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult
from pypaimon.globalindex.vector_search_result import DictBasedScoredIndexResult
from pypaimon.table.source.full_text_search_split import FullTextSearchSplit
from pypaimon.table.source.full_text_scan import FullTextScanPlan

Expand Down Expand Up @@ -60,19 +60,22 @@ def read(self, splits: List[FullTextSearchSplit]) -> GlobalIndexResult:
if not splits:
return GlobalIndexResult.create_empty()

result = ScoredGlobalIndexResult.create_empty()
merged_scores = {}
for split in splits:
split_result = self._eval(
split.row_range_start, split.row_range_end,
split.full_text_index_files
)
if split_result is not None:
result = result.or_(split_result)
score_getter = split_result.score_getter()
for row_id in split_result.results():
if row_id not in merged_scores:
merged_scores[row_id] = score_getter(row_id)

return result.top_k(self._limit)
return DictBasedScoredIndexResult(merged_scores).top_k(self._limit)

def _eval(self, row_range_start, row_range_end, full_text_index_files
) -> Optional[ScoredGlobalIndexResult]:
) -> Optional[GlobalIndexResult]:
index_io_meta_list = []
for index_file in full_text_index_files:
meta = index_file.global_index_meta
Expand Down
11 changes: 7 additions & 4 deletions paimon-python/pypaimon/table/source/vector_search_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pypaimon.globalindex.global_index_result import GlobalIndexResult
from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader
from pypaimon.globalindex.vector_search import VectorSearch
from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult
from pypaimon.globalindex.vector_search_result import DictBasedScoredIndexResult


class VectorSearchRead(ABC):
Expand Down Expand Up @@ -57,16 +57,19 @@ def read(self, splits):

pre_filter = self._pre_filter(splits)

result = ScoredGlobalIndexResult.create_empty()
merged_scores = {}
for split in splits:
split_result = self._eval(
split.row_range_start, split.row_range_end,
split.vector_index_files, pre_filter
)
if split_result is not None:
result = result.or_(split_result)
score_getter = split_result.score_getter()
for row_id in split_result.results():
if row_id not in merged_scores:
merged_scores[row_id] = score_getter(row_id)

return result.top_k(self._limit)
return DictBasedScoredIndexResult(merged_scores).top_k(self._limit)

def _pre_filter(self, splits):
# type: (list) -> Optional[RoaringBitmap64]
Expand Down
122 changes: 122 additions & 0 deletions paimon-python/pypaimon/tests/vector_search_filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,5 +541,127 @@ def test_with_partition_filter_rejects_non_partition_field(self):
self.assertIn("non-partition", str(ctx.exception))


class VectorSearchManySplitsTest(unittest.TestCase):
    """Regression test: vector search must merge scores across many splits.

    Builds 1200 single-row splits whose fake index readers each return one
    row id scored as float(row_id), then checks that ``read`` keeps exactly
    the top-``limit`` rows by score.
    """

    def test_vector_search_with_many_splits(self):
        from pypaimon.globalindex.vector_search_result import (
            DictBasedScoredIndexResult,
        )
        from pypaimon.table.source.vector_search_read import VectorSearchReadImpl
        from pypaimon.table.source.vector_search_split import VectorSearchSplit

        num_splits = 1200
        embedding_field = _field(1, "embedding", "FLOAT")
        entries = [
            _entry(None, field_id=1, index_type="lumina-vector-ann",
                   file_name="vec-%d.index" % i,
                   row_range_start=i, row_range_end=i)
            for i in range(num_splits)
        ]
        table = _StubTable(fields=[embedding_field], entries=entries)
        _patch_snapshot(self, entries)

        def _fake_create(index_type, file_io, index_path,
                         index_io_meta_list, options=None):
            # File "vec-<i>.index" yields row id i with score float(i).
            row_id = index_io_meta_list[0].file_name
            row_id = int(row_id.split("-")[1].split(".")[0])

            class _FakeReader:
                def visit_vector_search(self_inner, vs):
                    return DictBasedScoredIndexResult({row_id: float(row_id)})

                def close(self_inner):
                    pass

                def __enter__(self_inner):
                    return self_inner

                def __exit__(self_inner, *a):
                    return False
            return _FakeReader()

        splits = [
            VectorSearchSplit(
                row_range_start=i, row_range_end=i,
                vector_index_files=[entries[i].index_file])
            for i in range(num_splits)
        ]

        with mock.patch(
                "pypaimon.table.source.vector_search_read._create_vector_reader",
                side_effect=_fake_create):
            reader = VectorSearchReadImpl(
                table, limit=10, vector_column=embedding_field,
                query_vector=[1.0], filter_=None)
            result = reader.read(splits)

        # Exactly `limit` rows survive, and they are the highest-scored ones
        # (row ids 1190..1199). The previous assertLessEqual was redundant:
        # assertEqual already pins the cardinality exactly.
        self.assertEqual(result.results().cardinality(), 10)
        scores = sorted(result.score_getter()(rid) for rid in result.results())
        self.assertEqual(scores, [float(i) for i in range(1190, 1200)])

    def tearDown(self):
        # Undo any patches left active if the test body raised mid-patch.
        mock.patch.stopall()


class FullTextSearchManySplitsTest(unittest.TestCase):
    """Regression test: full-text search must merge scores across many splits.

    Builds 1200 single-row splits whose fake index readers each return one
    row id scored as float(row_id), then checks that ``read`` keeps exactly
    the top-``limit`` rows by score.
    """

    def test_full_text_search_with_many_splits(self):
        from pypaimon.globalindex.vector_search_result import (
            DictBasedScoredIndexResult,
        )
        from pypaimon.table.source.full_text_read import FullTextReadImpl
        from pypaimon.table.source.full_text_search_split import (
            FullTextSearchSplit,
        )

        num_splits = 1200
        text_field = _field(1, "content", "STRING")
        entries = [
            _entry(None, field_id=1, index_type="tantivy-fulltext",
                   file_name="ft-%d.index" % i,
                   row_range_start=i, row_range_end=i)
            for i in range(num_splits)
        ]
        table = _StubTable(fields=[text_field], entries=entries)
        _patch_snapshot(self, entries)

        def _fake_create(index_type, file_io, index_path,
                         index_io_meta_list):
            # File "ft-<i>.index" yields row id i with score float(i).
            row_id = index_io_meta_list[0].file_name
            row_id = int(row_id.split("-")[1].split(".")[0])

            class _FakeReader:
                def visit_full_text_search(self_inner, fts):
                    return DictBasedScoredIndexResult({row_id: float(row_id)})

                def close(self_inner):
                    pass
            return _FakeReader()

        splits = [
            FullTextSearchSplit(
                row_range_start=i, row_range_end=i,
                full_text_index_files=[entries[i].index_file])
            for i in range(num_splits)
        ]

        with mock.patch(
                "pypaimon.table.source.full_text_read._create_full_text_reader",
                side_effect=_fake_create):
            reader = FullTextReadImpl(
                table, limit=10, text_column=text_field,
                query_text="test")
            result = reader.read(splits)

        # Exactly `limit` rows survive, and they are the highest-scored ones
        # (row ids 1190..1199). The previous assertLessEqual was redundant:
        # assertEqual already pins the cardinality exactly.
        self.assertEqual(result.results().cardinality(), 10)
        scores = sorted(result.score_getter()(rid) for rid in result.results())
        self.assertEqual(scores, [float(i) for i in range(1190, 1200)])

    def tearDown(self):
        # Undo any patches left active if the test body raised mid-patch.
        mock.patch.stopall()


# Allow running this test module directly: `python vector_search_filter_test.py`.
if __name__ == "__main__":
    unittest.main()
Loading