diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index eeae8f559f4b..afe5bc0c341c 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -149,6 +149,8 @@ Optimizations * GITHUB#15597, GITHUB#15777: Reduce memory usage of NeighborArray (Viliam Durina) +* GITHUB#16034: Sibling expansion as an optimization for KNN vector search over parent-child document relationships. (Anna Ruggero, Alessandro Benedetti) + Bug Fixes --------------------- * GITHUB#14049: Randomize KNN codec params in RandomCodec. Fixes scalar quantization div-by-zero diff --git a/lucene/benchmark-jmh/build.gradle b/lucene/benchmark-jmh/build.gradle index 6f874e410b9b..6d64bb7a4f72 100644 --- a/lucene/benchmark-jmh/build.gradle +++ b/lucene/benchmark-jmh/build.gradle @@ -20,6 +20,7 @@ description = 'Lucene JMH micro-benchmarking module' dependencies { moduleImplementation project(':lucene:core') moduleImplementation project(':lucene:expressions') + moduleImplementation project(':lucene:join') moduleImplementation project(':lucene:sandbox') moduleTestImplementation project(':lucene:test-framework') diff --git a/lucene/benchmark-jmh/src/java/module-info.java b/lucene/benchmark-jmh/src/java/module-info.java index bb6b9d516bb0..8090c7554739 100644 --- a/lucene/benchmark-jmh/src/java/module-info.java +++ b/lucene/benchmark-jmh/src/java/module-info.java @@ -24,6 +24,7 @@ requires jdk.unsupported; requires org.apache.lucene.core; requires org.apache.lucene.expressions; + requires org.apache.lucene.join; requires org.apache.lucene.sandbox; requires commons.math3; diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/DiversifyingChildrenKnnQueryBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/DiversifyingChildrenKnnQueryBenchmark.java new file mode 100644 index 000000000000..f6f0f10358c8 --- /dev/null +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/DiversifyingChildrenKnnQueryBenchmark.java @@ -0,0 +1,226 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.benchmark.jmh; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.IOUtils; +import 
org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +/** + * Benchmarks end-to-end latency of {@link DiversifyingChildrenFloatKnnVectorQuery} with sibling + * expansion enabled across three sibling-correlation scenarios: + * + * + * + * Run with: + * + *
+ *   ./gradlew -p lucene/benchmark-jmh assemble
+ *   java -jar lucene/benchmark-jmh/build/benchmarks/lucene-benchmark-jmh-*.jar DiversifyingChildrenKnnQueryBenchmark
+ * 
+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +// 4 iterations 1 second each - results discarded +@Warmup(iterations = 4, time = 1) +// 5 iterations 1 second each - results recorded (average time per call, reported in milliseconds) +@Measurement(iterations = 5, time = 1) +// 3 separate JVM processes +@Fork( + value = 3, + jvmArgsAppend = {"-Xmx4g", "-Xms4g", "-XX:+AlwaysPreTouch"}) +public class DiversifyingChildrenKnnQueryBenchmark { + + private static final String FIELD = "vec"; + private static final String PARENT_FIELD = "docType"; + private static final String PARENT_VALUE = "_parent"; + private static final int NUM_QUERY_VECTORS = 256; + + /** + * Sibling correlation scenario: + * + * <ul> + * <li>{@code best}: noise level 0.05, siblings are nearly identical + * <li>{@code standard}: noise level 0.30, siblings are moderately correlated + * <li>{@code worst}: fully random vectors, no shared centroid + * </ul> + */ + @Param({"best", "standard", "worst"}) + public String siblingCorrelation; + + @Param({"5000"}) + public int numParents; + + @Param({"4", "8", "16"}) + public int childrenPerParent; + + @Param({"10", "100"}) + public int k; + + @Param({"128"}) + public int dim; + + private Path tmpDir; + private Directory dir; + private IndexReader reader; + private IndexSearcher searcher; + private QueryBitSetProducer parentFilter; + private float[][] queryVectors; + private int queryIdx; + + @Setup(Level.Trial) + public void setup() throws IOException { + tmpDir = Files.createTempDirectory("DiversifyingChildrenKnnQueryBenchmark"); + dir = MMapDirectory.open(tmpDir); + + // How close the siblings of one parent are to each other + float noiseLevel = + switch (siblingCorrelation) { + case "best" -> 0.05f; // nearly identical + case "standard" -> 0.30f; // moderately correlated + default -> Float.NaN; // worst: fully random, no centroid + }; + + Random rnd = new Random(42); + // index creation + try (IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) { + // 5000 parents + for (int p = 0; p < numParents; p++) { + // vector of 128-dim + float[] centroid = Float.isNaN(noiseLevel) ? 
null : randomUnitVector(dim, rnd); + List block = new ArrayList<>(); + // 4 - 8 - 16 children per parent + for (int c = 0; c < childrenPerParent; c++) { + float[] vec = + centroid == null + ? randomUnitVector(dim, rnd) + : perturbedUnitVector(centroid, noiseLevel, rnd); + // create child doc + Document child = new Document(); + child.add(new KnnFloatVectorField(FIELD, vec, VectorSimilarityFunction.DOT_PRODUCT)); + // add to the index block + block.add(child); + } + // create parent document + Document parent = new Document(); + // docType = _parent + parent.add(new StringField(PARENT_FIELD, PARENT_VALUE, Field.Store.NO)); + // add to the index block + block.add(parent); + // add to the index writer + w.addDocuments(block); + } + // compress to one segment + w.forceMerge(1); + } + + reader = DirectoryReader.open(dir); + searcher = new IndexSearcher(reader); + // parent filter docType = _parent + parentFilter = new QueryBitSetProducer(new TermQuery(new Term(PARENT_FIELD, PARENT_VALUE))); + + Random qrnd = new Random(123); + queryVectors = new float[NUM_QUERY_VECTORS][]; + for (int i = 0; i < NUM_QUERY_VECTORS; i++) { + // random query vectors + queryVectors[i] = randomUnitVector(dim, qrnd); + } + } + + @TearDown(Level.Trial) + public void teardown() throws IOException { + IOUtils.close(reader, dir); + IOUtils.rm(tmpDir); + } + + @Benchmark + public TopDocs search() throws IOException { + // benchmarked part - search + // iterates on all the queries in a round-robin + float[] query = queryVectors[queryIdx++ & (NUM_QUERY_VECTORS - 1)]; + Query knnQuery = + new DiversifyingChildrenFloatKnnVectorQuery(FIELD, query, null, k, parentFilter); + return searcher.search(knnQuery, k); + } + + private static float[] randomUnitVector(int dim, Random rnd) { + float[] v = new float[dim]; + for (int i = 0; i < dim; i++) v[i] = rnd.nextFloat() * 2 - 1; + return normalise(v); + } + + /** Returns a unit vector near {@code centroid} with per-dimension noise scaled by noiseLevel. 
*/ + private static float[] perturbedUnitVector(float[] centroid, float noiseLevel, Random rnd) { + float[] v = new float[centroid.length]; + for (int i = 0; i < centroid.length; i++) { + v[i] = centroid[i] + noiseLevel * (rnd.nextFloat() * 2 - 1); + } + return normalise(v); + } + + // Since we use DOT PRODUCT + private static float[] normalise(float[] v) { + float norm = 0; + for (float x : v) norm += x * x; + norm = (float) Math.sqrt(norm); + for (int i = 0; i < v.length; i++) v[i] /= norm; + return v; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java index 2f668d966d41..db9ef79b2163 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java @@ -96,14 +96,73 @@ protected static void scoreEntryPoints( assert scores != null && scores.length >= eps.length; scorer.bulkScore(eps, scores, eps.length); results.incVisitedCount(eps.length); + float[] siblingScores = null; + int[] siblingsOrd = new int[0]; for (int i = 0; i < eps.length; i++) { float score = scores[i]; int ep = eps[i]; visited.set(ep); candidates.add(ep, score); if (acceptOrds == null || acceptOrds.get(ep)) { + // Fetch siblingsOrd BEFORE collect() so the parent is not yet in the heap + // The instanceof check is needed: this method is also called with a + // GraphBuilderKnnCollector + if (results instanceof OrdinalTranslatedKnnCollector collector) { + if (collector.isSiblingExpansionCollector()) { + siblingsOrd = collector.getSiblingOrdinals(ep, visited, siblingsOrd); + for (int ord : siblingsOrd) visited.set(ord); + } + } + // Collect the ep node here so after we have a correctly updated minCompetitiveSimilarity results.collect(ep, score); + if (siblingsOrd.length > 0) { + siblingScores = + scoreHnswNodes( + results, + scorer, + candidates, + acceptOrds, + 
siblingsOrd, + siblingScores); + } } } } + + /** + * Scores and collects siblings, adding competitive ones to the candidate queue. Reuses and + * returns the siblingScores buffer, reallocating only if too small. + */ + protected static float[] scoreHnswNodes( + KnnCollector results, + RandomVectorScorer scorer, + NeighborQueue candidates, + Bits acceptOrds, + int[] hnswNodesOrd, + float[] scores) + throws IOException { + int numNodes = hnswNodesOrd.length; + // If scores not defined yet or too small to collect scores a new one is created + // Otherwise we reuse the old one that will be overridden in bulkScore with new scores + if (scores == null || scores.length < numNodes) { + scores = new float[numNodes]; + } + float maxScore = scorer.bulkScore(hnswNodesOrd, scores, numNodes); + results.incVisitedCount(numNodes); + if (maxScore > results.minCompetitiveSimilarity()) { + float minSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + for (int j = 0; j < numNodes; j++) { + float sibScore = scores[j]; + // We avoid adding to candidates a sibling with a bad score + if (sibScore >= minSimilarity) { + candidates.add(hnswNodesOrd[j], sibScore); + if (acceptOrds == null || acceptOrds.get(hnswNodesOrd[j])) { + results.collect(hnswNodesOrd[j], sibScore); + minSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); + } + } + } + } + return scores; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/DocSiblingExpansion.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/DocSiblingExpansion.java new file mode 100644 index 000000000000..d0200f58b06c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/DocSiblingExpansion.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +/** + * Implemented by collectors that understand parent-child document relationships and can enumerate + * sibling document ids for a given child document id, as well as translate document ids back to + * vector ordinals. + * + *

This interface is used internally by {@link OrdinalTranslatedKnnCollector} to bridge between + * the ordinal space of the HNSW graph and the document-id space of the collector. + * + * @lucene.experimental + */ +// The interface cannot be removed. It exists for a module-boundary reason. +// DocSiblingExpansion is in lucene/core, while DiversifyingNearestChildrenKnnCollector is in +// lucene/join. +// The dependency is one-way: join depends on core, never the reverse. So +// OrdinalTranslatedKnnCollector (in core) +// has no way to reference DiversifyingNearestChildrenKnnCollector directly. +// +// The interface is the bridge — it lets core call findSiblingDocIds and docIdToOrdinal on the +// collector without +// creating a circular dependency. Removing it would require either moving +// OrdinalTranslatedKnnCollector into +// join (bigger refactor) or adding a core → join dependency (illegal in this architecture). +public interface DocSiblingExpansion { + + /** + * Returns the doc ids of all siblings of {@code childDocId}, or an empty array if there are no + * other siblings. + * + * @param childDocId the document id of the child that is about to be collected + * @return sibling doc ids, or an empty array + */ + int[] findSiblingDocIds(int childDocId); + + /** + * Translates a document id to its vector ordinal, or returns {@code -1} if the document has no + * vector in this field. 
+ * + * @param docId the document id + * @return the vector ordinal, or {@code -1} + */ + int docIdToOrdinal(int docId); +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index d739915ca078..786fc8ee53da 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -44,6 +44,8 @@ public class HnswGraphSearcher extends AbstractHnswGraphSearcher { protected int[] bulkNodes = null; protected float[] bulkScores = null; + protected float[] siblingScores = null; + protected int[] siblingsOrd = new int[0]; /** * HNSW search is roughly logarithmic. This doesn't take maxConn into account, but it is a pretty @@ -347,6 +349,15 @@ void searchLevel( if (score >= minAcceptedSimilarity) { candidates.add(node, score); if (acceptOrds == null || acceptOrds.get(node)) { + // Fetch siblingsOrd BEFORE collect() so the parent is not yet in the heap + // The instanceof check is needed: this method is also called with a + // GraphBuilderKnnCollector + if (results instanceof OrdinalTranslatedKnnCollector collector) { + if (collector.isSiblingExpansionCollector()) { + siblingsOrd = collector.getSiblingOrdinals(node, visited, siblingsOrd); + for (int ord : siblingsOrd) visited.set(ord); + } + } if (results.collect(node, score)) { float oldMinAcceptedSimilarity = minAcceptedSimilarity; minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity()); @@ -356,6 +367,22 @@ void searchLevel( shouldExploreMinSim = true; } } + // Score and collect all siblingsOrd of the newly-discovered parent + if (siblingsOrd.length > 0) { + float prevMinSim = results.minCompetitiveSimilarity(); + siblingScores = + scoreHnswNodes( + results, + scorer, + candidates, + acceptOrds, + siblingsOrd, + siblingScores); + if (results.minCompetitiveSimilarity() > prevMinSim) { + minAcceptedSimilarity = 
Math.nextUp(results.minCompetitiveSimilarity()); + shouldExploreMinSim = true; + } + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index 5225fe700ab9..d0afb7e0da93 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -20,9 +20,13 @@ import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BitSet; /** - * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId + * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId. + * Sibling expansion is active only when the wrapped collector also implements {@link + * DocSiblingExpansion}. */ public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { @@ -50,4 +54,45 @@ public TopDocs topDocs() { : TotalHits.Relation.EQUAL_TO), td.scoreDocs); } + + // Needed since we could have a TopKnnCollector at this point + public boolean isSiblingExpansionCollector() { + return collector instanceof DocSiblingExpansion; + } + + public int[] getSiblingOrdinals(int hnswNode, BitSet visitedHnswNodes, int[] siblingOrdinals) { + DocSiblingExpansion docExpanderCollector = (DocSiblingExpansion) collector; + int docId = vectorOrdinalToDocId.apply(hnswNode); + // We do not check whether the parent is already in the heap. Consider siblings A and B: + // - A was found and scored (parent added), but we reached the budget limit and were not able to + // score B, + // we then found B through graph traversal. We want to visit it even if we already visited A. + // We do not visit siblings in score order. 
+ int[] siblingDocIds = docExpanderCollector.findSiblingDocIds(docId); + if (siblingOrdinals.length < siblingDocIds.length) { + siblingOrdinals = new int[siblingDocIds.length]; + } + // The buffer may be reused and larger than needed, or only partly filled because visited + // siblings are skipped; unused slots hold 0 or stale values, which look like valid ordinals. + // So a separate count tracks how many leading entries are valid. + int count = 0; + for (int sibDocId : siblingDocIds) { + int sibOrd = docExpanderCollector.docIdToOrdinal(sibDocId); + // sibOrd = -1: sibling has no vector for this field → no HNSW node, cannot be scored. + // + // visitedHnswNodes.get(sibOrd): sibling was already reached via normal graph traversal, e.g. + // - B was scored via traversal but with score < minAcceptedSimilarity, so collect() + // was never called and the parent is not in the heap. Expansion from a different + // child A finds B already visited and skips it. + if (sibOrd >= 0 && !visitedHnswNodes.get(sibOrd)) { + siblingOrdinals[count++] = sibOrd; + } + } + if (count == 0) { + return new int[0]; + } + return count < siblingOrdinals.length + ? 
ArrayUtil.copyOfSubArray(siblingOrdinals, 0, count) + : siblingOrdinals; + } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 877d004d5c78..77dde715426e 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -156,7 +156,7 @@ protected TopDocs exactSearch( @Override protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { - return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher); + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher, field); } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index e2d6179ec426..c5f2f95ddabf 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -155,7 +155,7 @@ protected TopDocs exactSearch( @Override protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { - return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher); + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher, field); } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java index 59cb58d22888..404f8cd156c9 100644 --- 
a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollector.java @@ -25,40 +25,42 @@ import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.hnsw.DocSiblingExpansion; /** * This collects the nearest children vectors. Diversifying the results over the provided parent * filter. This means the nearest children vectors are returned, but only one per parent */ -class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector { +class DiversifyingNearestChildrenKnnCollector extends AbstractKnnCollector + implements DocSiblingExpansion { private final BitSet parentBitSet; private final NodeIdCachingHeap heap; + // docId → vector ordinal mapping; -1 means the doc has no vector. Empty (not null) when unused. + private final int[] docToOrd; /** - * Create a new object for joining nearest child kNN documents with a parent bitset - * - * @param k The number of joined parent documents to collect - * @param visitLimit how many child vectors can be visited - * @param parentBitSet The leaf parent bitset - */ - public DiversifyingNearestChildrenKnnCollector(int k, int visitLimit, BitSet parentBitSet) { - this(k, visitLimit, null, parentBitSet); - } - - /** - * Create a new object for joining nearest child kNN documents with a parent bitset + * Create a new object for joining nearest child kNN documents with a parent bitset and a + * precomputed docId-to-ordinal mapping that enables dynamic sibling expansion during HNSW graph + * search. 
* * @param k The number of joined parent documents to collect * @param visitLimit how many child vectors can be visited * @param searchStrategy The search strategy to use * @param parentBitSet The leaf parent bitset + * @param docToOrd precomputed array mapping docId to vector ordinal; {@code -1} entries mean the + * document has no vector. An empty array disables sibling expansion. */ public DiversifyingNearestChildrenKnnCollector( - int k, int visitLimit, KnnSearchStrategy searchStrategy, BitSet parentBitSet) { + int k, + int visitLimit, + KnnSearchStrategy searchStrategy, + BitSet parentBitSet, + int[] docToOrd) { super(k, visitLimit, searchStrategy); this.parentBitSet = parentBitSet; this.heap = new NodeIdCachingHeap(k); + this.docToOrd = docToOrd; } /** @@ -93,9 +95,6 @@ public String toString() { @Override public TopDocs topDocs() { assert heap.size() <= k() : "Tried to collect more results than the maximum number allowed"; - while (heap.size() > k()) { - heap.popToDrain(); - } ScoreDoc[] scoreDocs = new ScoreDoc[heap.size()]; for (int i = 1; i <= scoreDocs.length; i++) { scoreDocs[scoreDocs.length - i] = new ScoreDoc(heap.topNode(), heap.topScore()); @@ -114,6 +113,48 @@ public int numCollected() { return heap.size(); } + @Override + public int[] findSiblingDocIds(int childDocId) { + int parent = parentBitSet.nextSetBit(childDocId); + // Find siblings range(prevParent, parent). If parent is 0 there are not prevParents so -1 + int prevParent = parent > 0 ? parentBitSet.prevSetBit(parent - 1) : -1; + int from = prevParent + 1; + // Children of parent are all docIds in [from, parent); exclude childDocId itself + // The childDoc itself is scored in collect() + int siblingsSize = parent - from - 1; + // One child case + if (siblingsSize == 0) { + return new int[0]; + } + int[] siblings = new int[siblingsSize]; + int idx = 0; + for (int docId = from; docId < parent; docId++) { + if (docId != childDocId) { + siblings[idx++] = docId; + } + } + return idx > 0 ? 
siblings : new int[0]; + } + + @Override + public int docIdToOrdinal(int docId) { + // Why only a bounds check: + // There is no null check: docToOrd is expected to be non-null. When the segment has no vector + // values for the field, buildDocToOrd returns an empty array rather than null. + // docId >= docToOrd.length — with an empty array this is true for every docId, so each + // sibling resolves to -1 and is skipped; that is how expansion is switched off for such a + // segment. A populated array is sized to the segment's maxDoc, so there the check + // is only a defensive guard against: + // - bugs in how the array is built or passed + // - future refactoring that changes the array size + // - misuse when constructing the collector manually (as in tests, e.g. the + // docIdToOrdinal(9999) test case) + if (docId >= docToOrd.length) { + return -1; + } + return docToOrd[docId]; + } + /** * This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead * of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this 
Additionally, this @@ -181,15 +222,9 @@ private void pushIn(int nodeId, int parentId, float score) { private void updateElement(int heapIndex, int nodeId, int parentId, float score) { assert parentNodes[heapIndex] == parentId : "attempted to update heap element value but with a different parent id"; - float oldScore = scores[heapIndex]; childNodes[heapIndex] = nodeId; scores[heapIndex] = score; - // Since we are a min heap, if the new value is less, we need to make sure to bubble it up - if (score < oldScore) { - upHeap(heapIndex); - } else { - downHeap(heapIndex); - } + downHeap(heapIndex); } /** @@ -284,7 +319,7 @@ private void upHeap(int origPos) { scores[i] = savedScore; } - private int downHeap(int i) { + private void downHeap(int i) { int savedChild = childNodes[i]; int savedParent = parentNodes[i]; float savedScore = scores[i]; @@ -309,7 +344,6 @@ private int downHeap(int i) { childNodes[i] = savedChild; parentNodes[i] = savedParent; scores[i] = savedScore; - return i; } // Used only during popToDrain: the index map is never read again after closed=true, diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java index 7ebacdd8194d..845a0267c5e2 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java @@ -18,7 +18,12 @@ package org.apache.lucene.search.join; import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; 
import org.apache.lucene.search.knn.KnnCollectorManager; @@ -31,21 +36,30 @@ */ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollectorManager { + // Cache keyed by (segment core cache key → (field name → docToOrd array)). + // Entries are evicted automatically when the segment is closed via addClosedListener. + private static final ConcurrentHashMap> + DOC_TO_ORD_CACHE = new ConcurrentHashMap<>(); + // the number of docs to collect private final int k; // filter identifying the parent documents. private final BitSetProducer parentsFilter; + // vector field name; used to build the docId-to-ordinal mapping for sibling expansion + private final String field; /** * Constructor * * @param k - the number of top k vectors to collect * @param parentsFilter Filter identifying the parent documents. + * @param field the vector field name */ public DiversifyingNearestChildrenKnnCollectorManager( - int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) { + int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher, String field) { this.k = k; this.parentsFilter = parentsFilter; + this.field = field; } /** @@ -62,8 +76,9 @@ public KnnCollector newCollector( if (parentBitSet == null) { return null; } + int[] docToOrd = getCachedDocToOrd(context); return new DiversifyingNearestChildrenKnnCollector( - k, visitedLimit, searchStrategy, parentBitSet); + k, visitedLimit, searchStrategy, parentBitSet, docToOrd); } @Override @@ -74,12 +89,115 @@ public KnnCollector newOptimisticCollector( if (parentBitSet == null) { return null; } + int[] docToOrd = getCachedDocToOrd(context); return new DiversifyingNearestChildrenKnnCollector( - k, visitedLimit, searchStrategy, parentBitSet); + k, visitedLimit, searchStrategy, parentBitSet, docToOrd); } @Override public boolean isOptimistic() { return true; } + + /** + * Returns the docId-to-ordinal array for the given leaf, building and caching it on first access. 
+   * The cached array is evicted automatically when the segment closes.
+   */
+  private int[] getCachedDocToOrd(LeafReaderContext context) throws IOException {
+    IndexReader.CacheHelper cacheHelper = context.reader().getCoreCacheHelper();
+    if (cacheHelper == null) {
+      // No core cache key is available for this reader, so there is nothing to key the cache on:
+      // build the mapping for this call only.
+      return buildDocToOrd(context);
+    }
+    IndexReader.CacheKey cacheKey = cacheHelper.getKey();
+    // NOTE(review): assumes DOC_TO_ORD_CACHE is declared as
+    // ConcurrentHashMap<IndexReader.CacheKey, ConcurrentHashMap<String, int[]>> — confirm against
+    // the field declaration.
+    // Look up first so that a per-call throwaway map is only allocated when the segment has no
+    // entry yet.
+    ConcurrentHashMap<String, int[]> fieldMap = DOC_TO_ORD_CACHE.get(cacheKey);
+    if (fieldMap == null) {
+      ConcurrentHashMap<String, int[]> fresh = new ConcurrentHashMap<>();
+      fieldMap = DOC_TO_ORD_CACHE.putIfAbsent(cacheKey, fresh);
+      if (fieldMap == null) {
+        // We inserted the new entry — register cleanup when the segment closes.
+        fieldMap = fresh;
+        cacheHelper.addClosedListener(DOC_TO_ORD_CACHE::remove);
+      }
+    }
+    int[] cached = fieldMap.get(field);
+    if (cached != null) {
+      return cached;
+    }
+    // Two threads may build concurrently; putIfAbsent makes every caller return the same array.
+    int[] built = buildDocToOrd(context);
+    int[] race = fieldMap.putIfAbsent(field, built);
+    return race != null ? race : built;
+  }
+
+  /**
+   * Builds a docId-to-ordinal array for the given leaf, mapping each docId to its vector ordinal.
+   *
+   * <p>Returns an empty array if the field has no vector values in this segment at all — sibling
+   * expansion will be disabled for this leaf.
+   *
+   * <p>Otherwise returns an array of size {@code maxDoc} where each entry is the vector ordinal for
+   * that docId, or {@code -1} if that specific document has no vector (sparse indexing).
+   *
+   * <p>HNSW graph nodes are identified by vector ordinal, so the value stored for a docId is the
+   * graph node to use when a sibling docId has to be scored. The ordinal is taken from the vector
+   * iterator itself rather than from a running counter. With the flat vector format, ordinals are
+   * assigned in ascending docId order at index time, so both approaches agree there; asking the
+   * iterator avoids depending on that layout.
+   *
+   * @param context the leaf (segment) to build the mapping for
+   * @return the docId-to-ordinal mapping, or an empty array when the segment has no vectors for the
+   *     field
+   * @throws IOException if reading the vector values fails
+   */
+  private int[] buildDocToOrd(LeafReaderContext context) throws IOException {
+    var reader = context.reader();
+    FieldInfo fi = reader.getFieldInfos().fieldInfo(field);
+    // fi is null if the field does not exist in this segment at all; the vector dimension is 0 if
+    // the field exists in the segment but was not indexed as a vector field.
+    //
+    // This guard is needed here: newCollector builds the mapping before searchNearestVectors is
+    // called, i.e. before the search gets a chance to short-circuit on a segment without vectors.
+    if (fi == null || fi.getVectorDimension() == 0) {
+      return new int[0];
+    }
+    // FloatVectorValues or ByteVectorValues, depending on the encoding of the field.
+    var values =
+        switch (fi.getVectorEncoding()) {
+          case FLOAT32 -> reader.getFloatVectorValues(field);
+          case BYTE -> reader.getByteVectorValues(field);
+        };
+    if (values == null) {
+      // Defensive: a reader may report no vector values even though the FieldInfo declares them.
+      return new int[0];
+    }
+    int[] docToOrd = new int[reader.maxDoc()];
+    Arrays.fill(docToOrd, -1);
+    var iter = values.iterator();
+    for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
+      docToOrd[doc] = iter.index();
+    }
+    return docToOrd;
+  }
 }
diff --git a/lucene/join/src/test/org/apache/lucene/search/join/DiversifyingChildrenKnnCollectorTestCase.java b/lucene/join/src/test/org/apache/lucene/search/join/DiversifyingChildrenKnnCollectorTestCase.java new file mode 100644 index 000000000000..82ce76196a73 --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/DiversifyingChildrenKnnCollectorTestCase.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search.join;
+
+import java.io.IOException;
+import java.util.Arrays;
+import org.apache.lucene.search.knn.KnnSearchStrategy;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BitSet;
+
+/**
+ * Shared fixtures for {@link DiversifyingNearestChildrenKnnCollector} tests: the block-join parent
+ * {@link BitSet}, the matching docId-to-ordinal table, and a collector factory.
+ */
+abstract class DiversifyingChildrenKnnCollectorTestCase extends LuceneTestCase {
+
+  /**
+   * Creates the parent {@link BitSet} for a contiguous block-join layout in which every block is
+   * {@code [child_0 … child_{C-1}, parent_C]} and there are {@code numParents} such blocks.
+   *
+   * <p>Example with {@code childrenPerParent=3}: parent doc ids are 3, 7, 11, …
+   */
+  static BitSet parentBitSet(int numParents, int childrenPerParent) throws IOException {
+    int blockSize = childrenPerParent + 1;
+    int[] parentDocIds = new int[numParents];
+    for (int i = 0; i < numParents; i++) {
+      // the parent is the last document of block i
+      parentDocIds[i] = (i + 1) * blockSize - 1;
+    }
+    int totalDocs = numParents * blockSize;
+    TestToParentJoinKnnResults.IntArrayDocIdSetIterator parentIterator =
+        new TestToParentJoinKnnResults.IntArrayDocIdSetIterator(parentDocIds, numParents);
+    return BitSet.of(parentIterator, totalDocs + 1);
+  }
+
+  /**
+   * Creates the docId-to-ordinal table for the block-join layout: children are numbered
+   * consecutively from 0 in docId order, and parent slots hold -1 because parents carry no vector.
+   */
+  static int[] buildDocToOrd(int numParents, int childrenPerParent) {
+    int blockSize = childrenPerParent + 1;
+    int[] docToOrd = new int[numParents * blockSize];
+    Arrays.fill(docToOrd, -1);
+    int nextOrd = 0;
+    for (int block = 0; block < numParents; block++) {
+      int firstChild = block * blockSize;
+      for (int child = 0; child < childrenPerParent; child++) {
+        docToOrd[firstChild + child] = nextOrd++;
+      }
+    }
+    return docToOrd;
+  }
+
+  /** Creates a collector with an unlimited visit budget and no search strategy. */
+  static DiversifyingNearestChildrenKnnCollector makeCollector(
+      int k, BitSet parents, int[] docToOrd) {
+    KnnSearchStrategy noStrategy = null;
+    return new DiversifyingNearestChildrenKnnCollector(
+        k, Integer.MAX_VALUE, noStrategy, parents, docToOrd);
+  }
+}
diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingChildrenKnnSiblingExpansion.java b/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingChildrenKnnSiblingExpansion.java new file mode 100644 index 000000000000..2202384bb114 --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingChildrenKnnSiblingExpansion.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search.join;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StoredField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.BitSet;
+import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
+
+/**
+ * Tests for sibling expansion in {@link DiversifyingNearestChildrenKnnCollector} ({@link
+ * org.apache.lucene.util.hnsw.DocSiblingExpansion} contract) and end-to-end correctness via {@link
+ * DiversifyingChildrenFloatKnnVectorQuery}.
+ */
+public class TestDiversifyingChildrenKnnSiblingExpansion
+    extends DiversifyingChildrenKnnCollectorTestCase {
+
+  // ---------------------------------------------------------------------------
+  // Unit tests: DiversifyingNearestChildrenKnnCollector — findSiblingDocIds
+  // ---------------------------------------------------------------------------
+
+  public void testFindSiblingDocIds_returnsAllSiblings() throws IOException {
+    // 2 parents, 2 children each → blocks: [0,1|2], [3,4|5]
+    int numParents = 2, childrenPerParent = 2;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(10, parents, docToOrd);
+
+    // child 0: sibling is 1
+    int[] s0 = c.findSiblingDocIds(0);
+    assertNotNull(s0);
+    assertArrayEquals(new int[] {1}, s0);
+
+    // child 1: sibling is 0
+    int[] s1 = c.findSiblingDocIds(1);
+    assertNotNull(s1);
+    assertArrayEquals(new int[] {0}, s1);
+
+    // child 3 (second block): sibling is 4
+    int[] s3 = c.findSiblingDocIds(3);
+    assertNotNull(s3);
+    assertArrayEquals(new int[] {4}, s3);
+
+    // child 4 (second block): sibling is 3
+    int[] s4 = c.findSiblingDocIds(4);
+    assertNotNull(s4);
+    assertArrayEquals(new int[] {3}, s4);
+  }
+
+  /**
+   * A parent with a single child has no siblings. Despite the method name, the result asserted here
+   * is an empty array, not {@code null}.
+   */
+  public void testFindSiblingDocIds_singleChildParent_returnsNull() throws IOException {
+    // 1 parent, 1 child → block: [0|1]
+    int numParents = 1, childrenPerParent = 1;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(10, parents, docToOrd);
+
+    assertArrayEquals("sole child has no siblings", new int[0], c.findSiblingDocIds(0));
+  }
+
+  /**
+   * The trigger child must not appear in its own sibling list. Only the *other* children of the
+   * same parent are returned.
+   */
+  public void testFindSiblingDocIds_excludesTriggerChild() throws IOException {
+    // 1 parent, 3 children: [C0, C1, C2 | P0=3]
+    int numParents = 1, childrenPerParent = 3;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(10, parents, docToOrd);
+
+    // Trigger is C1 (docId=1); expected siblings: C0 and C2 only
+    int[] siblings = c.findSiblingDocIds(1);
+    assertNotNull(siblings);
+    assertEquals(2, siblings.length);
+    for (int s : siblings) {
+      assertNotEquals("trigger child must not appear in sibling list", 1, s);
+    }
+  }
+
+  // ---------------------------------------------------------------------------
+  // Unit tests: DiversifyingNearestChildrenKnnCollector — docIdToOrdinal
+  // ---------------------------------------------------------------------------
+
+  public void testDocIdToOrdinal_correctMapping() throws IOException {
+    // 2 parents, 2 children: docToOrd = [0, 1, -1, 2, 3, -1]
+    int numParents = 2, childrenPerParent = 2;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(10, parents, docToOrd);
+
+    assertEquals(0, c.docIdToOrdinal(0));
+    assertEquals(1, c.docIdToOrdinal(1));
+    assertEquals(-1, c.docIdToOrdinal(2)); // parent → no vector
+    assertEquals(2, c.docIdToOrdinal(3));
+    assertEquals(3, c.docIdToOrdinal(4));
+    assertEquals(-1, c.docIdToOrdinal(5)); // parent → no vector
+    assertEquals(-1, c.docIdToOrdinal(9999)); // beyond array bounds
+  }
+
+  public void testDocIdToOrdinal_emptyMapping_alwaysMinusOne() throws IOException {
+    // Collector created with an empty docToOrd array (sibling expansion disabled)
+    BitSet parents = parentBitSet(2, 2);
+    DiversifyingNearestChildrenKnnCollector c =
+        new DiversifyingNearestChildrenKnnCollector(
+            5, Integer.MAX_VALUE, null, parents, new int[0]);
+
+    assertEquals(-1, c.docIdToOrdinal(0));
+    assertEquals(-1, c.docIdToOrdinal(1));
+  }
+
+  // ---------------------------------------------------------------------------
+  // Unit tests: heap replacement behaviour
+  // ---------------------------------------------------------------------------
+
+  /**
+   * C0 is the entry point (score 0.85) but its siblings C1 (0.95) and C2 (0.90) are expanded
+   * immediately. The parent must be represented by C1, not C0.
+   */
+  public void testBestSiblingReplacesFirstFoundChild() throws IOException {
+    // 1 parent, 3 children: [C0=0, C1=1, C2=2 | P0=3]
+    int numParents = 1, childrenPerParent = 3;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(1, parents, docToOrd);
+
+    // C0 is found first, as the entry point
+    c.collect(0, 0.85f);
+    // Sibling expansion then scores C1 and C2
+    c.collect(1, 0.95f);
+    c.collect(2, 0.90f);
+
+    TopDocs td = c.topDocs();
+    assertEquals(1, td.scoreDocs.length);
+    assertEquals(
+        "best sibling must win over first-found child", 0.95f, td.scoreDocs[0].score, 1e-5f);
+    assertEquals("best child doc id must be C1", 1, td.scoreDocs[0].doc);
+  }
+
+  /**
+   * When two parents are found and the heap is full (k=2), a sibling of the second parent that
+   * scores better than the trigger child replaces it.
+   */
+  public void testBestSiblingReplacesWorseChildWhenHeapFull() throws IOException {
+    // 2 parents, 2 children each: [C0=0, C1=1 | P0=2], [C2=3, C3=4 | P1=5]
+    int numParents = 2, childrenPerParent = 2;
+    BitSet parents = parentBitSet(numParents, childrenPerParent);
+    int[] docToOrd = buildDocToOrd(numParents, childrenPerParent);
+    DiversifyingNearestChildrenKnnCollector c = makeCollector(2, parents, docToOrd);
+
+    // P0: C0 found first (0.85), C1 is better sibling (0.95) → P0 represented by C1
+    c.collect(0, 0.85f);
+    c.collect(1, 0.95f);
+
+    // P1: C2 found first (0.60), C3 is better sibling (0.65)
+    c.collect(3, 0.60f); // trigger child → P1 enters heap, now FULL (size=2=k)
+    c.collect(4, 0.65f); // better sibling → heap full but P1 already present, updates entry
+
+    TopDocs td = c.topDocs();
+    assertEquals(2, td.scoreDocs.length);
+    assertEquals("P0 best child score", 0.95f, td.scoreDocs[0].score, 1e-5f);
+    assertEquals("P1 best child score", 0.65f, td.scoreDocs[1].score, 1e-5f);
+    assertEquals("P0 best child doc", 1, td.scoreDocs[0].doc);
+    assertEquals("P1 best child doc", 4, td.scoreDocs[1].doc);
+  }
+
+  // ---------------------------------------------------------------------------
+  // Unit tests: OrdinalTranslatedKnnCollector — getSiblingOrdinals
+  // ---------------------------------------------------------------------------
+
+  /**
+   * Even when the parent is already in the heap (from a prior expansion), getSiblingOrdinals must
+   * still return unvisited siblings. This allows a budget-truncated first expansion to be continued
+   * when additional children of the same parent are later discovered via normal graph traversal.
+   */
+  public void testGetSiblingOrdinals_parentAlreadyInHeap_returnsUnvisitedSiblings()
+      throws IOException {
+    // 1 parent, 2 children: [C0=doc0, C1=doc1 | P0=doc2]; ordToDoc=[0,1]
+    BitSet parents = parentBitSet(1, 2);
+    int[] docToOrd = buildDocToOrd(1, 2);
+    int[] ordToDoc = {0, 1};
+    OrdinalTranslatedKnnCollector collector =
+        new OrdinalTranslatedKnnCollector(
+            makeCollector(10, parents, docToOrd), ord -> ordToDoc[ord]);
+
+    // Collect C0 → parent 2 enters the heap
+    collector.collect(0, 0.8f);
+
+    // C1 (ordinal 1) is not yet visited; must still be returned as an expansion candidate
+    FixedBitSet visited = new FixedBitSet(3);
+    int[] result = collector.getSiblingOrdinals(0, visited, new int[0]);
+    assertNotEquals(
+        "unvisited sibling must be returned even when parent is already in heap", 0, result.length);
+    assertArrayEquals(new int[] {1}, result);
+  }
+
+  /**
+   * Siblings whose ordinal is already in visitedOrds must be filtered out to prevent double-scoring
+   * when the sibling was independently discovered via normal graph traversal.
+   */
+  public void testPendingSiblingOrdinals_filtersAlreadyVisited() throws IOException {
+    // 1 parent, 3 children: [C0=doc0, C1=doc1, C2=doc2 | P0=doc3]
+    // docToOrd=[0,1,2,-1]; ordToDoc=[0,1,2]
+    BitSet parents = parentBitSet(1, 3);
+    int[] docToOrd = buildDocToOrd(1, 3);
+    int[] ordToDoc = {0, 1, 2};
+    OrdinalTranslatedKnnCollector collector =
+        new OrdinalTranslatedKnnCollector(
+            makeCollector(10, parents, docToOrd), ord -> ordToDoc[ord]);
+
+    // Mark C1 (ordinal 1) as already visited by normal graph traversal
+    FixedBitSet visited = new FixedBitSet(4);
+    visited.set(1);
+
+    // Trigger on C0 (ordinal 0 → docId 0); siblings are C1 and C2
+    int[] result = collector.getSiblingOrdinals(0, visited, new int[0]);
+    assertNotEquals(0, result.length);
+    assertArrayEquals("C1 must be filtered (visited); only C2 remains", new int[] {2}, result);
+  }
+
+  /**
+   * Siblings with no vector in this field (docIdToOrdinal returns -1) must be skipped because they
+   * have no node in the HNSW graph and cannot be scored.
+   */
+  public void testPendingSiblingOrdinals_filtersSparseSiblings() throws IOException {
+    // 1 parent, 3 children but C1 has no vector:
+    // [C0=doc0, C1(sparse)=doc1, C2=doc2 | P0=doc3]
+    // docToOrd=[0,-1,1,-1]; ordToDoc=[0,2]
+    BitSet parents = parentBitSet(1, 3);
+    int[] docToOrd = {0, -1, 1, -1};
+    int[] ordToDoc = {0, 2};
+    OrdinalTranslatedKnnCollector collector =
+        new OrdinalTranslatedKnnCollector(
+            makeCollector(10, parents, docToOrd), ord -> ordToDoc[ord]);
+
+    FixedBitSet visited = new FixedBitSet(4);
+
+    // Trigger on C0 (ordinal 0 → docId 0); siblings are C1 (sparse) and C2
+    int[] result = collector.getSiblingOrdinals(0, visited, new int[0]);
+    assertNotEquals(0, result.length);
+    assertArrayEquals("C1 (no vector) must be filtered; only C2 remains", new int[] {1}, result);
+  }
+
+  // ---------------------------------------------------------------------------
+  // Index fixtures (used by integration tests only)
+  // ---------------------------------------------------------------------------
+
+  private static final String FIELD = "vec";
+  private static final int DIM = 4;
+
+  /**
+   * Builds a block-join float-vector index. Parent p has numChildren children; child c of parent p
+   * gets the vector produced by {@link #childVector}, normalised for COSINE/DOT_PRODUCT.
+   *
+   * <p>The caller owns the returned {@link Directory} and must close it.
+   */
+  private Directory buildIndex(int numParents, int numChildren, VectorSimilarityFunction sim)
+      throws IOException {
+    Directory dir = newDirectory();
+    try (IndexWriter w =
+        new IndexWriter(
+            dir, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
+      for (int p = 0; p < numParents; p++) {
+        // children first, parent last: the block-join layout
+        List<Document> block = new ArrayList<>();
+        for (int c = 0; c < numChildren; c++) {
+          float[] vec = childVector(p, c, numChildren);
+          if (sim == VectorSimilarityFunction.COSINE
+              || sim == VectorSimilarityFunction.DOT_PRODUCT) {
+            normalise(vec);
+          }
+          Document child = new Document();
+          child.add(new KnnFloatVectorField(FIELD, vec, sim));
+          child.add(new StoredField("parent", p));
+          block.add(child);
+        }
+        Document parent = new Document();
+        parent.add(new StringField("docType", "_parent", Field.Store.NO));
+        parent.add(new StoredField("parent", p));
+        block.add(parent);
+        w.addDocuments(block);
+      }
+    }
+    return dir;
+  }
+
+  /** Returns the raw vector of a child: v[i] = (parent * numChildren + child + i + 1) * 0.1. */
+  private static float[] childVector(int parent, int child, int numChildren) {
+    float[] vec = new float[DIM];
+    for (int i = 0; i < DIM; i++) {
+      vec[i] = (parent * numChildren + child + i + 1) * 0.1f;
+    }
+    return vec;
+  }
+
+  /** Scales {@code vec} in place to unit length; a zero vector is left untouched. */
+  private static void normalise(float[] vec) {
+    float norm = 0;
+    for (float v : vec) {
+      norm += v * v;
+    }
+    norm = (float) Math.sqrt(norm);
+    if (norm > 0) {
+      // iterate over the vector's own length rather than DIM so any vector can be normalised
+      for (int i = 0; i < vec.length; i++) {
+        vec[i] /= norm;
+      }
+    }
+  }
+
+  /** Computes brute-force top-k scores: max similarity per parent, sorted descending. */
+  private static float[] bruteForceTopK(
+      float[] query, int numParents, int numChildren, VectorSimilarityFunction sim, int k) {
+    float[] parentBest = new float[numParents];
+    Arrays.fill(parentBest, Float.NEGATIVE_INFINITY);
+    for (int p = 0; p < numParents; p++) {
+      for (int c = 0; c < numChildren; c++) {
+        float[] vec = childVector(p, c, numChildren);
+        if (sim == VectorSimilarityFunction.COSINE || sim == VectorSimilarityFunction.DOT_PRODUCT) {
+          normalise(vec);
+        }
+        float score = sim.compare(query, vec);
+        if (score > parentBest[p]) {
+          parentBest[p] = score;
+        }
+      }
+    }
+    // Arrays.sort is ascending, so read the best scores from the end
+    float[] sorted = parentBest.clone();
+    Arrays.sort(sorted);
+    int n = Math.min(k, numParents);
+    float[] top = new float[n];
+    for (int i = 0; i < n; i++) {
+      top[i] = sorted[numParents - 1 - i];
+    }
+    return top;
+  }
+
+  // ---------------------------------------------------------------------------
+  // Integration tests: end-to-end correctness via DiversifyingChildrenFloatKnnVectorQuery
+  // ---------------------------------------------------------------------------
+
+  /**
+   * Each result doc must belong to a distinct parent. Sibling expansion must not cause the same
+   * parent to appear multiple times.
+   */
+  public void testSiblingExpansionNoDuplicateParents() throws Exception {
+    int numParents = 15, numChildren = 4, k = 8;
+    VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN;
+    float[] query = {1f, 0f, 0f, 0f};
+
+    try (Directory dir = buildIndex(numParents, numChildren, sim);
+        IndexReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = newSearcher(reader);
+      BitSetProducer parentFilter =
+          new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+
+      Query knnQuery =
+          new DiversifyingChildrenFloatKnnVectorQuery(FIELD, query, null, k, parentFilter);
+      TopDocs results = searcher.search(knnQuery, k);
+
+      Set<Integer> seenParents = new HashSet<>();
+      for (ScoreDoc sd : results.scoreDocs) {
+        int parentId =
+            reader.storedFields().document(sd.doc).getField("parent").numericValue().intValue();
+        assertTrue("parent " + parentId + " appeared more than once", seenParents.add(parentId));
+      }
+    }
+  }
+
+  /**
+   * When every parent has exactly one child there are no siblings to expand. Results must still be
+   * correct.
+   */
+  public void testSiblingExpansionSingleChildParents() throws Exception {
+    int numParents = 12, numChildren = 1, k = 5;
+    VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN;
+    float[] query = {0.2f, 0.3f, 0.4f, 0.5f};
+
+    try (Directory dir = buildIndex(numParents, numChildren, sim);
+        IndexReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = newSearcher(reader);
+      BitSetProducer parentFilter =
+          new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+
+      Query knnQuery =
+          new DiversifyingChildrenFloatKnnVectorQuery(FIELD, query, null, k, parentFilter);
+      TopDocs results = searcher.search(knnQuery, k);
+
+      float[] expected = bruteForceTopK(query, numParents, numChildren, sim, k);
+      assertEquals(Math.min(k, numParents), results.scoreDocs.length);
+      for (int i = 0; i < results.scoreDocs.length; i++) {
+        assertEquals("score at rank " + i, expected[i], results.scoreDocs[i].score, 1e-4f);
+      }
+    }
+  }
+
+  /**
+   * The query is close to one parent's children; sibling expansion must find the best child of each
+   * discovered parent rather than whichever child the graph traversal happens to reach first.
+   */
+  public void testSiblingExpansion_bestChildPerParentFound() throws Exception {
+    int numParents = 3, numChildren = 3, k = 2;
+    VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN;
+    float[] query = {0.9f, 0.9f, 0.9f, 0.9f};
+
+    try (Directory dir = buildIndex(numParents, numChildren, sim);
+        IndexReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = newSearcher(reader);
+      BitSetProducer parentFilter =
+          new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+      CheckJoinIndex.check(reader, parentFilter);
+
+      Query knnQuery =
+          new DiversifyingChildrenFloatKnnVectorQuery(FIELD, query, null, k, parentFilter);
+      TopDocs results = searcher.search(knnQuery, k);
+
+      assertEquals(k, results.scoreDocs.length);
+      for (ScoreDoc sd : results.scoreDocs) {
+        int parentIdx =
+            reader.storedFields().document(sd.doc).getField("parent").numericValue().intValue();
+        // verify no other child of the same parent scores higher than the returned one
+        // (vectors are compared un-normalised, matching how buildIndex indexes them for EUCLIDEAN)
+        for (int c = 0; c < numChildren; c++) {
+          float[] vec = childVector(parentIdx, c, numChildren);
+          float cScore = sim.compare(query, vec);
+          assertTrue(
+              "parent "
+                  + parentIdx
+                  + " has a better child (score "
+                  + cScore
+                  + ") than returned doc "
+                  + sd.doc
+                  + " (score "
+                  + sd.score
+                  + ")",
+              cScore <= sd.score + 1e-4f);
+        }
+      }
+    }
+  }
+
+  /**
+   * With a single parent and many children, sibling expansion must score all children so the best
+   * one is returned. Without expansion the graph might stop early and miss the best child.
+   */
+  public void testSiblingExpansion_singleParentManyChildren() throws Exception {
+    int numParents = 1, numChildren = 8, k = 1;
+    VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN;
+    float[] query = {0.9f, 0.8f, 0.7f, 0.6f};
+
+    try (Directory dir = buildIndex(numParents, numChildren, sim);
+        IndexReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = newSearcher(reader);
+      BitSetProducer parentFilter =
+          new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+      CheckJoinIndex.check(reader, parentFilter);
+
+      Query knnQuery =
+          new DiversifyingChildrenFloatKnnVectorQuery(FIELD, query, null, k, parentFilter);
+      TopDocs results = searcher.search(knnQuery, k);
+
+      float[] expected = bruteForceTopK(query, numParents, numChildren, sim, k);
+      assertEquals(1, results.scoreDocs.length);
+      assertEquals("best child of single parent", expected[0], results.scoreDocs[0].score, 1e-4f);
+    }
+  }
+}
diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingNearestChildrenKnnCollectorPerformance.java b/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingNearestChildrenKnnCollectorPerformance.java index 99a1384144b1..14addde94549 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingNearestChildrenKnnCollectorPerformance.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestDiversifyingNearestChildrenKnnCollectorPerformance.java @@ -19,7 +19,6 @@ import java.io.IOException; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.BitSet; /** @@ -28,27 +27,13 @@ *

Correctness tests verify behaviour of the collector in various scenarios, including edge * cases. */ -public class TestDiversifyingNearestChildrenKnnCollectorPerformance extends LuceneTestCase { - - /** Builds a BitSet whose set bits are the parent doc ids in a contiguous block-join layout. */ - private static BitSet parentBitSet(int numParents, int childrenPerParent) throws IOException { - int[] parentDocIds = new int[numParents]; - for (int p = 1; p <= numParents; p++) { - // layout: [child_0 … child_{C-1}, parent_C], repeated - // e.g. with 3 children per parent: [0,1,2,3, 4,5,6,7, 8,9,10,11, ...] → parent doc ids are - // 3,7,11,... - parentDocIds[p - 1] = p * (childrenPerParent + 1) - 1; - } - int totalDocs = numParents * (childrenPerParent + 1); // children + 1 parent per block - return BitSet.of( - new TestToParentJoinKnnResults.IntArrayDocIdSetIterator(parentDocIds, numParents), - totalDocs + 1); - } +public class TestDiversifyingNearestChildrenKnnCollectorPerformance + extends DiversifyingChildrenKnnCollectorTestCase { /** Collects all children in order and returns topDocs. 
*/ private static TopDocs collectAll(int k, BitSet parents, int[] childIds, float[] scores) { DiversifyingNearestChildrenKnnCollector collector = - new DiversifyingNearestChildrenKnnCollector(k, Integer.MAX_VALUE, parents); + new DiversifyingNearestChildrenKnnCollector(k, Integer.MAX_VALUE, null, parents, null); for (int i = 0; i < childIds.length; i++) { collector.collect(childIds[i], scores[i]); } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java b/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java index 175e1f264400..fe1b008b5f4d 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java @@ -33,7 +33,7 @@ public void testNeighborsProduct() throws IOException { // make sure we have the sign correct BitSet parentBitSet = BitSet.of(new IntArrayDocIdSetIterator(new int[] {1, 3, 5}, 3), 6); DiversifyingNearestChildrenKnnCollector nn = - new DiversifyingNearestChildrenKnnCollector(2, Integer.MAX_VALUE, parentBitSet); + new DiversifyingNearestChildrenKnnCollector(2, Integer.MAX_VALUE, null, parentBitSet, null); assertTrue(nn.collect(2, 0.5f)); assertTrue(nn.collect(0, 0.2f)); assertTrue(nn.collect(4, 1f)); @@ -48,7 +48,7 @@ public void testInsertions() throws IOException { float[] scores = new float[] {1f, 0.5f, 0.6f, 2f, 2f, 1.2f, 4f}; BitSet parentBitSet = BitSet.of(new IntArrayDocIdSetIterator(new int[] {3, 6, 9, 12}, 4), 13); DiversifyingNearestChildrenKnnCollector results = - new DiversifyingNearestChildrenKnnCollector(7, Integer.MAX_VALUE, parentBitSet); + new DiversifyingNearestChildrenKnnCollector(7, Integer.MAX_VALUE, null, parentBitSet, null); for (int i = 0; i < nodes.length; i++) { results.collect(nodes[i], scores[i]); } @@ -69,7 +69,7 @@ public void testInsertionWithOverflow() throws IOException { BitSet parentBitSet = BitSet.of(new 
IntArrayDocIdSetIterator(new int[] {3, 6, 9, 11, 13, 15}, 6), 16); DiversifyingNearestChildrenKnnCollector results = - new DiversifyingNearestChildrenKnnCollector(5, Integer.MAX_VALUE, parentBitSet); + new DiversifyingNearestChildrenKnnCollector(5, Integer.MAX_VALUE, null, parentBitSet, null); for (int i = 0; i < nodes.length - 1; i++) { results.collect(nodes[i], scores[i]); } @@ -104,7 +104,8 @@ public void testRandomInsertionsWithOverflow() throws IOException { BitSet parentBitSet = BitSet.of(new IntArrayDocIdSetIterator(parents, parents.length), nextParent + 1); DiversifyingNearestChildrenKnnCollector results = - new DiversifyingNearestChildrenKnnCollector(20, Integer.MAX_VALUE, parentBitSet); + new DiversifyingNearestChildrenKnnCollector( + 20, Integer.MAX_VALUE, null, parentBitSet, null); for (int i = 0; i < children.size(); i++) { results.collect(children.get(i), childrenScores.get(i)); }