From 04ff73e8fe7f9af8a167e31db1990508c67e5329 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sat, 2 May 2026 21:51:59 -0700 Subject: [PATCH 1/8] [slop] data blind quantization Allow callers to disable centering at the format level, which also disables writing of float vectors since they are no longer needed. Includes a path to handle of mix of centered and uncentered segments as input. In this case the uncentered/no float vectors will be dequantized and requantized but this case should be relatively uncommon. Includes OSQ changes to allow a zero vector for COSINE if the vector is not a unit vector. Maybe fix this in upstream callers? --- ...Lucene104ScalarQuantizedVectorsFormat.java | 21 +- ...Lucene104ScalarQuantizedVectorsReader.java | 23 +- ...Lucene104ScalarQuantizedVectorsWriter.java | 386 ++++++++++++++++-- .../OptimizedScalarQuantizer.java | 8 +- ...Lucene104ScalarQuantizedVectorsFormat.java | 160 ++++++++ 5 files changed, 567 insertions(+), 31 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index fb8221deffb0..d312d528aa49 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -111,22 +111,35 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); private final ScalarEncoding encoding; + private final boolean enableCentering; - /** Creates a new instance with UNSIGNED_BYTE encoding. */ + /** Creates a new instance with UNSIGNED_BYTE encoding and centering enabled. */ public Lucene104ScalarQuantizedVectorsFormat() { this(ScalarEncoding.UNSIGNED_BYTE); } - /** Creates a new instance with the chosen quantization encoding. */ + /** Creates a new instance with the chosen quantization encoding and centering enabled. */ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { + this(encoding, true); + } + + /** + * Creates a new instance with the chosen quantization encoding and centering setting. + * + *

When {@code enableCentering} is {@code false} (data-blind mode), no centroid is computed + * and no raw float vectors are written. This reduces storage at the cost of slightly lower + * quantization quality. + */ + public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding, boolean enableCentering) { super(NAME); this.encoding = encoding; + this.enableCentering = enableCentering; } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene104ScalarQuantizedVectorsWriter( - state, encoding, rawVectorFormat.fieldsWriter(state), scorer); + state, encoding, enableCentering, rawVectorFormat.fieldsWriter(state), scorer); } @Override @@ -146,6 +159,8 @@ public String toString() { + NAME + ", encoding=" + encoding + + ", enableCentering=" + + enableCentering + ", flatVectorScorer=" + scorer + ", rawVectorFormat=" diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index 15fe4950036f..a95e7ea5f528 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -237,9 +237,9 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { + VectorEncoding.FLOAT32); } - FloatVectorValues rawFloatVectorValues = rawVectorsReader.getFloatVectorValues(field); + FloatVectorValues rawFloatVectorValues = getRawFloatVectorValues(field); - if (rawFloatVectorValues.size() == 0) { + if (rawFloatVectorValues == null || rawFloatVectorValues.size() == 0) { return OffHeapScalarQuantizedFloatVectorValues.load( fi.ordToDocDISIReaderConfiguration, fi.dimension, @@ -358,6 +358,25 @@ public float[] getCentroid(String field) { return null; } + /** Returns {@code true} if raw float vectors are stored for {@code field}. */ + boolean hasRawFloatVectors(String field) throws IOException { + FloatVectorValues raw = getRawFloatVectorValues(field); + return raw != null && raw.size() > 0; + } + + /** + * Returns raw float vectors from the underlying flat reader, or {@code null} if the field was + * not written there (data-blind mode). Some flat readers throw {@link IllegalArgumentException} + * instead of returning null for missing fields, so we catch that here. + */ + private FloatVectorValues getRawFloatVectorValues(String field) throws IOException { + try { + return rawVectorsReader.getFloatVectorValues(field); + } catch (IllegalArgumentException e) { + return null; // field not present in raw reader (written in data-blind mode) + } + } + private static IndexInput openDataInput( SegmentReadState state, int versionMeta, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 5373ff85be3d..6c29f0eddf69 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -34,6 +34,7 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -49,6 +50,7 @@ import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; @@ -67,6 +69,7 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { private final List fields = new ArrayList<>(); private final IndexOutput meta, vectorData; private final ScalarEncoding encoding; + private final boolean enableCentering; private final FlatVectorsWriter rawVectorDelegate; private boolean finished; @@ -74,11 +77,13 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { public Lucene104ScalarQuantizedVectorsWriter( SegmentWriteState state, ScalarEncoding encoding, + boolean enableCentering, FlatVectorsWriter rawVectorDelegate, Lucene104ScalarQuantizedVectorScorer vectorsScorer) throws IOException { super(vectorsScorer); this.encoding = encoding; + this.enableCentering = enableCentering; this.segmentWriteState = state; String metaFileName = IndexFileNames.segmentFileName( @@ -116,15 +121,17 @@ public Lucene104ScalarQuantizedVectorsWriter( @Override public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { @SuppressWarnings("unchecked") - FieldWriter fieldWriter = - new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + FlatFieldVectorsWriter storage = + enableCentering + ? (FlatFieldVectorsWriter) this.rawVectorDelegate.addField(fieldInfo) + : new InMemoryFloatFieldWriter(); + FieldWriter fieldWriter = new FieldWriter(fieldInfo, storage, enableCentering); fields.add(fieldWriter); return fieldWriter; } - return rawVectorDelegate; + return this.rawVectorDelegate.addField(fieldInfo); } @Override @@ -137,13 +144,17 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } final float[] clusterCenter; int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); - clusterCenter = new float[field.dimensionSums.length]; - if (vectorCount > 0) { - for (int i = 0; i < field.dimensionSums.length; i++) { - clusterCenter[i] = field.dimensionSums[i] / vectorCount; - } - if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { - VectorUtil.l2normalize(clusterCenter); + if (!enableCentering) { + clusterCenter = new float[field.fieldInfo.getVectorDimension()]; + } else { + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } } } if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { @@ -321,15 +332,22 @@ public void finish() throws IOException { @Override public void mergeOneFlatVectorField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - // Don't need access to the random vectors, we can just use the merged - rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); return; } - final float[] centroid; + if (enableCentering) { + mergeOneFlatVectorFieldCentered(fieldInfo, mergeState); + } else { + mergeOneFlatVectorFieldDataBlind(fieldInfo, mergeState); + } + } + + private void mergeOneFlatVectorFieldCentered(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); - centroid = mergedCentroid; if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { segmentWriteState.infoStream.message( QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); @@ -344,22 +362,84 @@ public void mergeOneFlatVectorField(FieldInfo fieldInfo, MergeState mergeState) floatVectorValues, new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), encoding, - centroid); + mergedCentroid); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; float centroidDp = - docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(mergedCentroid, mergedCentroid) : 0; writeMeta( fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, - centroid, + mergedCentroid, centroidDp, docsWithField); } + private void mergeOneFlatVectorFieldDataBlind(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + float[] zeroCentroid = new float[fieldInfo.getVectorDimension()]; + // Classify each contributing segment as quantized-only or re-quantizable (has raw floats). + // Quantized-only segments must have a matching encoding; otherwise re-quantization from raw + // floats would be required, which is not possible when raw floats were never written. + boolean anyHasRawFloats = false; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; + if (reader == null || reader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + if (hasRawFloatVectors(reader, fieldInfo.name)) { + anyHasRawFloats = true; + } else { + QuantizedByteVectorValues qvv = getQuantizedVectorValues(reader, fieldInfo.name); + if (qvv != null && qvv.getScalarEncoding() != encoding) { + throw new IllegalStateException( + "Cannot merge field \"" + + fieldInfo.name + + "\" from data-blind segment with encoding " + + qvv.getScalarEncoding() + + " into data-blind format with encoding " + + encoding + + ": re-quantization requires raw float vectors"); + } + } + } + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField; + if (anyHasRawFloats) { + // At least one segment has raw floats; use the float path with zero centroid. + // Quantized-only segments serve reconstructed floats via getFloatVectorValues(). + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + QuantizedFloatVectorValues quantizedVectorValues = + new QuantizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + encoding, + zeroCentroid); + docsWithField = writeVectorData(vectorData, quantizedVectorValues); + } else { + // All segments are quantized-only with matching encoding: copy bytes directly. + MergedQuantizedByteVectorValues mergedQBVV = + MergedQuantizedByteVectorValues.merge(fieldInfo, mergeState, zeroCentroid, encoding); + docsWithField = writeVectorData(vectorData, mergedQBVV); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + zeroCentroid, + 0f, + docsWithField); + } + static DocsWithFieldSet writeVectorData( IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); @@ -427,6 +507,24 @@ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { return null; } + static boolean hasRawFloatVectors(KnnVectorsReader vectorsReader, String fieldName) + throws IOException { + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); + if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { + return reader.hasRawFloatVectors(fieldName); + } + return true; // non-Lucene104 format; assume raw floats are available + } + + static QuantizedByteVectorValues getQuantizedVectorValues( + KnnVectorsReader vectorsReader, String fieldName) throws IOException { + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); + if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { + return reader.getQuantizedVectorValues(fieldName); + } + return null; + } + static int mergeAndRecalculateCentroids( MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { boolean recalculate = false; @@ -518,15 +616,20 @@ public long ramBytesUsed() { static class FieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); private final FieldInfo fieldInfo; + private final boolean enableCentering; private boolean finished; private final FlatFieldVectorsWriter flatFieldVectorsWriter; - private final float[] dimensionSums; + final float[] dimensionSums; private final FloatArrayList magnitudes = new FloatArrayList(); - FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + FieldWriter( + FieldInfo fieldInfo, + FlatFieldVectorsWriter flatFieldVectorsWriter, + boolean enableCentering) { this.fieldInfo = fieldInfo; this.flatFieldVectorsWriter = flatFieldVectorsWriter; - this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + this.enableCentering = enableCentering; + this.dimensionSums = enableCentering ? new float[fieldInfo.getVectorDimension()] : null; } @Override @@ -554,6 +657,10 @@ public void finish() throws IOException { if (finished) { return; } + if (!flatFieldVectorsWriter.isFinished()) { + // InMemoryFloatFieldWriter is not flushed through the raw delegate, so finish it here. + flatFieldVectorsWriter.finish(); + } assert flatFieldVectorsWriter.isFinished(); finished = true; } @@ -570,10 +677,12 @@ public void addValue(int docID, float[] vectorValue) throws IOException { float dp = VectorUtil.dotProduct(vectorValue, vectorValue); float divisor = (float) Math.sqrt(dp); magnitudes.add(divisor); - for (int i = 0; i < vectorValue.length; i++) { - dimensionSums[i] += (vectorValue[i] / divisor); + if (enableCentering) { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } } - } else { + } else if (enableCentering) { for (int i = 0; i < vectorValue.length; i++) { dimensionSums[i] += vectorValue[i]; } @@ -594,6 +703,61 @@ public long ramBytesUsed() { } } + /** + * In-memory storage for float vectors when centering is disabled. Holds vectors buffered during + * indexing so they can be quantized at flush time, without writing any raw float data to disk. + * + *

TODO: replace with immediate per-vector quantization in addValue() to avoid buffering raw + * floats in memory altogether. + */ + private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = + shallowSizeOfInstance(InMemoryFloatFieldWriter.class); + private final List vectors = new ArrayList<>(); + private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + private boolean finished; + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + vectors.add(Arrays.copyOf(vectorValue, vectorValue.length)); + docsWithField.add(docID); + } + + @Override + public float[] copyValue(float[] vectorValue) { + return Arrays.copyOf(vectorValue, vectorValue.length); + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() { + finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + for (float[] v : vectors) { + size += RamUsageEstimator.sizeOf(v); + } + return size; + } + } + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] quantized; @@ -709,6 +873,180 @@ public int ordToDoc(int ord) { } } + private static final class QuantizedByteVectorValuesSub extends DocIDMerger.Sub { + final QuantizedByteVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; + + QuantizedByteVectorValuesSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { + super(docMap); + this.values = values; + this.iterator = values.iterator(); + assert iterator.docID() == -1; + } + + @Override + public int nextDoc() throws IOException { + return iterator.nextDoc(); + } + } + + /** Merged view of {@link QuantizedByteVectorValues} from multiple segments. */ + static final class MergedQuantizedByteVectorValues extends QuantizedByteVectorValues { + private final List subs; + private final DocIDMerger docIdMerger; + private final int size; + private final float[] centroid; + private final float centroidDP; + private final ScalarEncoding scalarEncoding; + private int docId = -1; + private int lastOrd = -1; + private QuantizedByteVectorValuesSub current; + + private MergedQuantizedByteVectorValues( + List subs, + MergeState mergeState, + float[] centroid, + ScalarEncoding scalarEncoding) + throws IOException { + this.subs = subs; + this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (QuantizedByteVectorValuesSub sub : subs) { + totalSize += sub.values.size(); + } + this.size = totalSize; + this.centroid = centroid; + this.centroidDP = VectorUtil.dotProduct(centroid, centroid); + this.scalarEncoding = scalarEncoding; + } + + static MergedQuantizedByteVectorValues merge( + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + ScalarEncoding encoding) + throws IOException { + List subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; + if (reader == null) { + continue; + } + QuantizedByteVectorValues qbvv = getQuantizedVectorValues(reader, fieldInfo.name); + if (qbvv == null || qbvv.size() == 0) { + continue; + } + subs.add(new QuantizedByteVectorValuesSub(mergeState.docMaps[i], qbvv)); + } + return new MergedQuantizedByteVectorValues(subs, mergeState, centroid, encoding); + } + + @Override + public DocIndexIterator iterator() { + return new DocIndexIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++lastOrd; + ++index; + } + return docId; + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } + }; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } + return current.values.vectorValue(current.iterator.index()); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) + throws IOException { + if (ord != lastOrd) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } + return current.values.getCorrectiveTerms(current.iterator.index()); + } + + @Override + public int dimension() { + return subs.isEmpty() ? 0 : subs.get(0).values.dimension(); + } + + @Override + public int size() { + return size; + } + + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + + @Override + public ScalarEncoding getScalarEncoding() { + return scalarEncoding; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public float getCentroidDP() { + return centroidDP; + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } + + @Override + public QuantizedByteVectorValues copy() { + throw new UnsupportedOperationException(); + } + } + static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java index 179799ff83af..8a96162855b8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java @@ -119,7 +119,9 @@ public record QuantizationResult( public QuantizationResult[] multiScalarQuantize( float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); - assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert similarityFunction != COSINE + || VectorUtil.isUnitVector(centroid) + || VectorUtil.dotProduct(centroid, centroid) == 0; assert bits.length == destinations.length; float[] intervalScratch = new float[2]; double vecMean = 0; @@ -187,7 +189,9 @@ public QuantizationResult[] multiScalarQuantize( public QuantizationResult scalarQuantize( float[] vector, byte[] destination, byte bits, float[] centroid) { assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); - assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert similarityFunction != COSINE + || VectorUtil.isUnitVector(centroid) + || VectorUtil.dotProduct(centroid, centroid) == 0; assert vector.length <= destination.length; assert bits > 0 && bits <= 8; float[] intervalScratch = new float[2]; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 7825d92706e5..51e8d7ad06da 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.oneOf; import java.io.IOException; +import java.util.Arrays; import java.util.Locale; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; @@ -38,6 +39,7 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -117,6 +119,7 @@ public KnnVectorsFormat knnVectorsFormat() { "Lucene104ScalarQuantizedVectorsFormat(" + "name=Lucene104ScalarQuantizedVectorsFormat, " + "encoding=UNSIGNED_BYTE, " + + "enableCentering=true, " + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; var defaultScorer = @@ -255,6 +258,163 @@ private void replaceWithEmptyVectorFile(Directory dir, String fileName) throws E } } + // ---- data-blind (enableCentering=false) tests ---- + + private Codec datablindCodec(ScalarEncoding enc) { + return TestUtil.alwaysKnnVectorsFormat( + new Lucene104ScalarQuantizedVectorsFormat(enc, false)); + } + + public void testDataBlindSearchCorrectness() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(50, 200); + int dims = random().nextInt(4, 65); + VectorSimilarityFunction sim = randomSimilarity(); + ScalarEncoding enc = ScalarEncoding.UNSIGNED_BYTE; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(enc)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + int k = Math.min(10, numVectors); + TopDocs hits = searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + assertEquals(k, hits.totalHits.value()); + } + } + } + } + + public void testDataBlindNoRawFloatVectors() throws Exception { + String fieldName = "field"; + int dims = 8; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 20; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader leaf = getOnlyLeafReader(reader); + FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); + assertNotNull(fvv); + assertEquals(20, fvv.size()); + // In data-blind mode the float values are reconstructed from quantized data, not backed + // by raw stored floats — so the returned type must NOT be a ScalarQuantizedVectorValues. + assertFalse( + "data-blind mode must not store raw float vectors", + fvv instanceof Lucene104ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues); + } + } + } + } + + public void testDataBlindMultiSegmentMerge() throws Exception { + String fieldName = "field"; + int dims = 16; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + int numPerSegment = 30; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + // write two separate segments + for (int s = 0; s < 2; s++) { + for (int i = 0; i < numPerSegment; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + } + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + assertEquals(1, reader.leaves().size()); + LeafReader leaf = reader.leaves().get(0).reader(); + FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); + assertEquals(numPerSegment * 2, fvv.size()); + + // search should still return results + IndexSearcher searcher = new IndexSearcher(reader); + int k = 10; + TopDocs hits = searcher.search( + new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + assertEquals(k, hits.totalHits.value()); + } + } + } + } + + public void testDataBlindIncompatibleEncodingMerge() throws Exception { + String fieldName = "field"; + int dims = 16; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + + try (Directory dir1 = newDirectory(); Directory dir2 = newDirectory()) { + // segment 1: UNSIGNED_BYTE, data-blind + IndexWriterConfig iwc1 = newIndexWriterConfig(); + iwc1.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir1, iwc1)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + } + + // segment 2: PACKED_NIBBLE, data-blind + IndexWriterConfig iwc2 = newIndexWriterConfig(); + iwc2.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); + try (IndexWriter w = new IndexWriter(dir2, iwc2)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + } + + // merge both into a PACKED_NIBBLE data-blind index — UNSIGNED_BYTE segment has no raw + // floats, so re-quantization to PACKED_NIBBLE is impossible: expect an error. + // SerialMergeScheduler makes merges synchronous so the exception propagates directly. + try (Directory dirMerge = newDirectory()) { + IndexWriterConfig iwcMerge = newIndexWriterConfig(); + iwcMerge.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); + iwcMerge.setMergeScheduler(new SerialMergeScheduler()); + try (IndexWriter w = new IndexWriter(dirMerge, iwcMerge)) { + w.addIndexes(dir1, dir2); + expectThrows(Exception.class, () -> w.forceMerge(1)); + } + } + } + } + /** Updates vector metadata file to indicate zero vector length. */ private void updateVectorMetadataFile(Directory dir, String fileName) throws Exception { // Read original metadata From af0070500d4871764ca285c191ac470f2eb05d93 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sat, 2 May 2026 22:05:16 -0700 Subject: [PATCH 2/8] audit of format + reader --- ...Lucene104ScalarQuantizedVectorsFormat.java | 6 ++-- ...Lucene104ScalarQuantizedVectorsReader.java | 31 +++++++++---------- ...Lucene104ScalarQuantizedVectorsWriter.java | 8 ++--- ...Lucene104ScalarQuantizedVectorsFormat.java | 14 ++++----- 4 files changed, 27 insertions(+), 32 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index d312d528aa49..8a54792ffeba 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -126,9 +126,9 @@ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { /** * Creates a new instance with the chosen quantization encoding and centering setting. * - *

When {@code enableCentering} is {@code false} (data-blind mode), no centroid is computed - * and no raw float vectors are written. This reduces storage at the cost of slightly lower - * quantization quality. + *

When {@code enableCentering} is {@code false} (data-blind mode), no centroid is computed and + * no raw float vectors are written. This reduces storage costs by 4x or more but reduces + * quantization accuracy, particularly at lower bit rates. */ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding, boolean enableCentering) { super(NAME); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index a95e7ea5f528..c1311fe933a2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -237,7 +237,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { + VectorEncoding.FLOAT32); } - FloatVectorValues rawFloatVectorValues = getRawFloatVectorValues(field); + FloatVectorValues rawFloatVectorValues = + fi.isDataBlind() ? null : rawVectorsReader.getFloatVectorValues(field); if (rawFloatVectorValues == null || rawFloatVectorValues.size() == 0) { return OffHeapScalarQuantizedFloatVectorValues.load( @@ -358,23 +359,13 @@ public float[] getCentroid(String field) { return null; } - /** Returns {@code true} if raw float vectors are stored for {@code field}. */ boolean hasRawFloatVectors(String field) throws IOException { - FloatVectorValues raw = getRawFloatVectorValues(field); - return raw != null && raw.size() > 0; - } - - /** - * Returns raw float vectors from the underlying flat reader, or {@code null} if the field was - * not written there (data-blind mode). Some flat readers throw {@link IllegalArgumentException} - * instead of returning null for missing fields, so we catch that here. - */ - private FloatVectorValues getRawFloatVectorValues(String field) throws IOException { - try { - return rawVectorsReader.getFloatVectorValues(field); - } catch (IllegalArgumentException e) { - return null; // field not present in raw reader (written in data-blind mode) + FieldEntry fi = fields.get(field); + if (fi == null || fi.isDataBlind()) { + return false; } + FloatVectorValues raw = rawVectorsReader.getFloatVectorValues(field); + return raw != null && raw.size() > 0; } private static IndexInput openDataInput( @@ -583,6 +574,14 @@ private record FieldEntry( float centroidDP, OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + boolean isDataBlind() { + if (centroid == null) return false; + for (float v : centroid) { + if (v != 0f) return false; + } + return true; + } + static FieldEntry create( IndexInput input, VectorEncoding vectorEncoding, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 6c29f0eddf69..0f1f8a715bae 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -711,8 +711,7 @@ public long ramBytesUsed() { * floats in memory altogether. */ private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { - private static final long SHALLOW_SIZE = - shallowSizeOfInstance(InMemoryFloatFieldWriter.class); + private static final long SHALLOW_SIZE = shallowSizeOfInstance(InMemoryFloatFieldWriter.class); private final List vectors = new ArrayList<>(); private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); private boolean finished; @@ -921,10 +920,7 @@ private MergedQuantizedByteVectorValues( } static MergedQuantizedByteVectorValues merge( - FieldInfo fieldInfo, - MergeState mergeState, - float[] centroid, - ScalarEncoding encoding) + FieldInfo fieldInfo, MergeState mergeState, float[] centroid, ScalarEncoding encoding) throws IOException { List subs = new ArrayList<>(); for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 51e8d7ad06da..d0268cf11382 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -23,7 +23,6 @@ import static org.hamcrest.Matchers.oneOf; import java.io.IOException; -import java.util.Arrays; import java.util.Locale; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; @@ -261,8 +260,7 @@ private void replaceWithEmptyVectorFile(Directory dir, String fileName) throws E // ---- data-blind (enableCentering=false) tests ---- private Codec datablindCodec(ScalarEncoding enc) { - return TestUtil.alwaysKnnVectorsFormat( - new Lucene104ScalarQuantizedVectorsFormat(enc, false)); + return TestUtil.alwaysKnnVectorsFormat(new Lucene104ScalarQuantizedVectorsFormat(enc, false)); } public void testDataBlindSearchCorrectness() throws Exception { @@ -288,7 +286,8 @@ public void testDataBlindSearchCorrectness() throws Exception { try (IndexReader reader = DirectoryReader.open(w)) { IndexSearcher searcher = new IndexSearcher(reader); int k = Math.min(10, numVectors); - TopDocs hits = searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + TopDocs hits = + searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); assertEquals(k, hits.totalHits.value()); } } @@ -360,8 +359,8 @@ public void testDataBlindMultiSegmentMerge() throws Exception { // search should still return results IndexSearcher searcher = new IndexSearcher(reader); int k = 10; - TopDocs hits = searcher.search( - new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + TopDocs hits = + searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); assertEquals(k, hits.totalHits.value()); } } @@ -373,7 +372,8 @@ public void testDataBlindIncompatibleEncodingMerge() throws Exception { int dims = 16; VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; - try (Directory dir1 = newDirectory(); Directory dir2 = newDirectory()) { + try (Directory dir1 = newDirectory(); + Directory dir2 = newDirectory()) { // segment 1: UNSIGNED_BYTE, data-blind IndexWriterConfig iwc1 = newIndexWriterConfig(); iwc1.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); From fc84a8183ad7ade94d8d7d2b90d9b6a892f35946 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 11:20:10 -0700 Subject: [PATCH 3/8] fix check --- .../Lucene104ScalarQuantizedVectorsWriter.java | 12 +++++++++--- ...estLucene104HnswScalarQuantizedVectorsFormat.java | 1 + 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 0f1f8a715bae..3463095a41cf 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -49,6 +49,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; @@ -126,7 +127,7 @@ public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExceptio FlatFieldVectorsWriter storage = enableCentering ? (FlatFieldVectorsWriter) this.rawVectorDelegate.addField(fieldInfo) - : new InMemoryFloatFieldWriter(); + : new InMemoryFloatFieldWriter(fieldInfo); FieldWriter fieldWriter = new FieldWriter(fieldInfo, storage, enableCentering); fields.add(fieldWriter); return fieldWriter; @@ -712,19 +713,24 @@ public long ramBytesUsed() { */ private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(InMemoryFloatFieldWriter.class); + private final int dim; private final List vectors = new ArrayList<>(); private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); private boolean finished; + public InMemoryFloatFieldWriter(FieldInfo fieldInfo) { + dim = fieldInfo.getVectorDimension(); + } + @Override public void addValue(int docID, float[] vectorValue) throws IOException { - vectors.add(Arrays.copyOf(vectorValue, vectorValue.length)); + vectors.add(copyValue(vectorValue)); docsWithField.add(docID); } @Override public float[] copyValue(float[] vectorValue) { - return Arrays.copyOf(vectorValue, vectorValue.length); + return ArrayUtil.copyOfSubArray(vectorValue, 0, dim); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index 0ec1689c36e4..68edf0bc87ea 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -91,6 +91,7 @@ public KnnVectorsFormat knnVectorsFormat() { + " maxConn=10, beamWidth=20, tinySegmentsThreshold=100," + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + " encoding=UNSIGNED_BYTE," + + " enableCentering=true," + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; From f37bdabdd61480ae20565930ab08a9986b1138d8 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 11:39:05 -0700 Subject: [PATCH 4/8] fix field writer to better match lucene99 raw writer --- ...Lucene104ScalarQuantizedVectorsWriter.java | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 3463095a41cf..88ecf344dbba 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -704,33 +704,38 @@ public long ramBytesUsed() { } } - /** - * In-memory storage for float vectors when centering is disabled. Holds vectors buffered during - * indexing so they can be quantized at flush time, without writing any raw float data to disk. - * - *

TODO: replace with immediate per-vector quantization in addValue() to avoid buffering raw - * floats in memory altogether. - */ private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(InMemoryFloatFieldWriter.class); - private final int dim; + private final FieldInfo fieldInfo; private final List vectors = new ArrayList<>(); private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); private boolean finished; + private int lastDocID = -1; public InMemoryFloatFieldWriter(FieldInfo fieldInfo) { - dim = fieldInfo.getVectorDimension(); + this.fieldInfo = fieldInfo; } @Override public void addValue(int docID, float[] vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + assert docID > lastDocID; vectors.add(copyValue(vectorValue)); docsWithField.add(docID); + lastDocID = docID; } @Override public float[] copyValue(float[] vectorValue) { - return ArrayUtil.copyOfSubArray(vectorValue, 0, dim); + return ArrayUtil.copyOfSubArray(vectorValue, 0, fieldInfo.getVectorDimension()); } @Override @@ -756,10 +761,14 @@ public boolean isFinished() { @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; - for (float[] v : vectors) { - size += RamUsageEstimator.sizeOf(v); + if (vectors.isEmpty()) { + return size; } - return size; + return size + + docsWithField.ramBytesUsed() + + (long) vectors.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + (long) vectors.size() * fieldInfo.getVectorDimension() * Float.BYTES; } } From 818fe8523763e43db19a8ee1c2f89b12b6002e55 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 12:11:58 -0700 Subject: [PATCH 5/8] move to lucene105 --- lucene/core/src/java/module-info.java | 5 +- ...Lucene104ScalarQuantizedVectorsFormat.java | 21 +- ...Lucene104ScalarQuantizedVectorsReader.java | 22 +- ...Lucene104ScalarQuantizedVectorsWriter.java | 397 +----- ...ne105HnswScalarQuantizedVectorsFormat.java | 225 ++++ .../Lucene105ScalarQuantizedVectorScorer.java | 294 +++++ ...Lucene105ScalarQuantizedVectorsFormat.java | 170 +++ ...Lucene105ScalarQuantizedVectorsReader.java | 687 ++++++++++ ...Lucene105ScalarQuantizedVectorsWriter.java | 1102 +++++++++++++++++ ...fHeapScalarQuantizedFloatVectorValues.java | 402 ++++++ .../OffHeapScalarQuantizedVectorValues.java | 490 ++++++++ .../lucene/codecs/lucene105/package-info.java | 22 + .../org.apache.lucene.codecs.KnnVectorsFormat | 2 + ...ne104HnswScalarQuantizedVectorsFormat.java | 1 - ...Lucene104ScalarQuantizedVectorsFormat.java | 160 --- ...ne105HnswScalarQuantizedVectorsFormat.java | 215 ++++ ...Lucene105ScalarQuantizedVectorsFormat.java | 459 +++++++ 17 files changed, 4101 insertions(+), 573 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorScorer.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedFloatVectorValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedVectorValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene105/package-info.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105HnswScalarQuantizedVectorsFormat.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105ScalarQuantizedVectorsFormat.java diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 1358b9fe068d..c369c8302248 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -32,6 +32,7 @@ exports org.apache.lucene.codecs.lucene99; exports org.apache.lucene.codecs.lucene103.blocktree; exports org.apache.lucene.codecs.lucene104; + exports org.apache.lucene.codecs.lucene105; exports org.apache.lucene.codecs.perfield; exports org.apache.lucene.codecs; exports org.apache.lucene.document; @@ -86,7 +87,9 @@ provides org.apache.lucene.codecs.KnnVectorsFormat with org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat, org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat, - org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat; + org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene105.Lucene105HnswScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 8a54792ffeba..fb8221deffb0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -111,35 +111,22 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); private final ScalarEncoding encoding; - private final boolean enableCentering; - /** Creates a new instance with UNSIGNED_BYTE encoding and centering enabled. */ + /** Creates a new instance with UNSIGNED_BYTE encoding. */ public Lucene104ScalarQuantizedVectorsFormat() { this(ScalarEncoding.UNSIGNED_BYTE); } - /** Creates a new instance with the chosen quantization encoding and centering enabled. */ + /** Creates a new instance with the chosen quantization encoding. */ public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { - this(encoding, true); - } - - /** - * Creates a new instance with the chosen quantization encoding and centering setting. - * - *

When {@code enableCentering} is {@code false} (data-blind mode), no centroid is computed and - * no raw float vectors are written. This reduces storage costs by 4x or more but reduces - * quantization accuracy, particularly at lower bit rates. - */ - public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding, boolean enableCentering) { super(NAME); this.encoding = encoding; - this.enableCentering = enableCentering; } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene104ScalarQuantizedVectorsWriter( - state, encoding, enableCentering, rawVectorFormat.fieldsWriter(state), scorer); + state, encoding, rawVectorFormat.fieldsWriter(state), scorer); } @Override @@ -159,8 +146,6 @@ public String toString() { + NAME + ", encoding=" + encoding - + ", enableCentering=" - + enableCentering + ", flatVectorScorer=" + scorer + ", rawVectorFormat=" diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index c1311fe933a2..15fe4950036f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -237,10 +237,9 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { + VectorEncoding.FLOAT32); } - FloatVectorValues rawFloatVectorValues = - fi.isDataBlind() ? null : rawVectorsReader.getFloatVectorValues(field); + FloatVectorValues rawFloatVectorValues = rawVectorsReader.getFloatVectorValues(field); - if (rawFloatVectorValues == null || rawFloatVectorValues.size() == 0) { + if (rawFloatVectorValues.size() == 0) { return OffHeapScalarQuantizedFloatVectorValues.load( fi.ordToDocDISIReaderConfiguration, fi.dimension, @@ -359,15 +358,6 @@ public float[] getCentroid(String field) { return null; } - boolean hasRawFloatVectors(String field) throws IOException { - FieldEntry fi = fields.get(field); - if (fi == null || fi.isDataBlind()) { - return false; - } - FloatVectorValues raw = rawVectorsReader.getFloatVectorValues(field); - return raw != null && raw.size() > 0; - } - private static IndexInput openDataInput( SegmentReadState state, int versionMeta, @@ -574,14 +564,6 @@ private record FieldEntry( float centroidDP, OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { - boolean isDataBlind() { - if (centroid == null) return false; - for (float v : centroid) { - if (v != 0f) return false; - } - return true; - } - static FieldEntry create( IndexInput input, VectorEncoding vectorEncoding, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 88ecf344dbba..5373ff85be3d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -34,7 +34,6 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -49,9 +48,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; @@ -70,7 +67,6 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { private final List fields = new ArrayList<>(); private final IndexOutput meta, vectorData; private final ScalarEncoding encoding; - private final boolean enableCentering; private final FlatVectorsWriter rawVectorDelegate; private boolean finished; @@ -78,13 +74,11 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { public Lucene104ScalarQuantizedVectorsWriter( SegmentWriteState state, ScalarEncoding encoding, - boolean enableCentering, FlatVectorsWriter rawVectorDelegate, Lucene104ScalarQuantizedVectorScorer vectorsScorer) throws IOException { super(vectorsScorer); this.encoding = encoding; - this.enableCentering = enableCentering; this.segmentWriteState = state; String metaFileName = IndexFileNames.segmentFileName( @@ -122,17 +116,15 @@ public Lucene104ScalarQuantizedVectorsWriter( @Override public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { @SuppressWarnings("unchecked") - FlatFieldVectorsWriter storage = - enableCentering - ? (FlatFieldVectorsWriter) this.rawVectorDelegate.addField(fieldInfo) - : new InMemoryFloatFieldWriter(fieldInfo); - FieldWriter fieldWriter = new FieldWriter(fieldInfo, storage, enableCentering); + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); fields.add(fieldWriter); return fieldWriter; } - return this.rawVectorDelegate.addField(fieldInfo); + return rawVectorDelegate; } @Override @@ -145,17 +137,13 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } final float[] clusterCenter; int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); - if (!enableCentering) { - clusterCenter = new float[field.fieldInfo.getVectorDimension()]; - } else { - clusterCenter = new float[field.dimensionSums.length]; - if (vectorCount > 0) { - for (int i = 0; i < field.dimensionSums.length; i++) { - clusterCenter[i] = field.dimensionSums[i] / vectorCount; - } - if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { - VectorUtil.l2normalize(clusterCenter); - } + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); } } if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { @@ -333,22 +321,15 @@ public void finish() throws IOException { @Override public void mergeOneFlatVectorField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); return; } - if (enableCentering) { - mergeOneFlatVectorFieldCentered(fieldInfo, mergeState); - } else { - mergeOneFlatVectorFieldDataBlind(fieldInfo, mergeState); - } - } - - private void mergeOneFlatVectorFieldCentered(FieldInfo fieldInfo, MergeState mergeState) - throws IOException { - rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); + final float[] centroid; final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + centroid = mergedCentroid; if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { segmentWriteState.infoStream.message( QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); @@ -363,84 +344,22 @@ private void mergeOneFlatVectorFieldCentered(FieldInfo fieldInfo, MergeState mer floatVectorValues, new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), encoding, - mergedCentroid); + centroid); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; float centroidDp = - docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(mergedCentroid, mergedCentroid) : 0; + docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; writeMeta( fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, - mergedCentroid, + centroid, centroidDp, docsWithField); } - private void mergeOneFlatVectorFieldDataBlind(FieldInfo fieldInfo, MergeState mergeState) - throws IOException { - float[] zeroCentroid = new float[fieldInfo.getVectorDimension()]; - // Classify each contributing segment as quantized-only or re-quantizable (has raw floats). - // Quantized-only segments must have a matching encoding; otherwise re-quantization from raw - // floats would be required, which is not possible when raw floats were never written. - boolean anyHasRawFloats = false; - for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; - if (reader == null || reader.getFloatVectorValues(fieldInfo.name) == null) { - continue; - } - if (hasRawFloatVectors(reader, fieldInfo.name)) { - anyHasRawFloats = true; - } else { - QuantizedByteVectorValues qvv = getQuantizedVectorValues(reader, fieldInfo.name); - if (qvv != null && qvv.getScalarEncoding() != encoding) { - throw new IllegalStateException( - "Cannot merge field \"" - + fieldInfo.name - + "\" from data-blind segment with encoding " - + qvv.getScalarEncoding() - + " into data-blind format with encoding " - + encoding - + ": re-quantization requires raw float vectors"); - } - } - } - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - DocsWithFieldSet docsWithField; - if (anyHasRawFloats) { - // At least one segment has raw floats; use the float path with zero centroid. - // Quantized-only segments serve reconstructed floats via getFloatVectorValues(). - FloatVectorValues floatVectorValues = - MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); - } - QuantizedFloatVectorValues quantizedVectorValues = - new QuantizedFloatVectorValues( - floatVectorValues, - new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), - encoding, - zeroCentroid); - docsWithField = writeVectorData(vectorData, quantizedVectorValues); - } else { - // All segments are quantized-only with matching encoding: copy bytes directly. - MergedQuantizedByteVectorValues mergedQBVV = - MergedQuantizedByteVectorValues.merge(fieldInfo, mergeState, zeroCentroid, encoding); - docsWithField = writeVectorData(vectorData, mergedQBVV); - } - long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; - writeMeta( - fieldInfo, - segmentWriteState.segmentInfo.maxDoc(), - vectorDataOffset, - vectorDataLength, - zeroCentroid, - 0f, - docsWithField); - } - static DocsWithFieldSet writeVectorData( IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); @@ -508,24 +427,6 @@ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { return null; } - static boolean hasRawFloatVectors(KnnVectorsReader vectorsReader, String fieldName) - throws IOException { - vectorsReader = vectorsReader.unwrapReaderForField(fieldName); - if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { - return reader.hasRawFloatVectors(fieldName); - } - return true; // non-Lucene104 format; assume raw floats are available - } - - static QuantizedByteVectorValues getQuantizedVectorValues( - KnnVectorsReader vectorsReader, String fieldName) throws IOException { - vectorsReader = vectorsReader.unwrapReaderForField(fieldName); - if (vectorsReader instanceof Lucene104ScalarQuantizedVectorsReader reader) { - return reader.getQuantizedVectorValues(fieldName); - } - return null; - } - static int mergeAndRecalculateCentroids( MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { boolean recalculate = false; @@ -617,20 +518,15 @@ public long ramBytesUsed() { static class FieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); private final FieldInfo fieldInfo; - private final boolean enableCentering; private boolean finished; private final FlatFieldVectorsWriter flatFieldVectorsWriter; - final float[] dimensionSums; + private final float[] dimensionSums; private final FloatArrayList magnitudes = new FloatArrayList(); - FieldWriter( - FieldInfo fieldInfo, - FlatFieldVectorsWriter flatFieldVectorsWriter, - boolean enableCentering) { + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { this.fieldInfo = fieldInfo; this.flatFieldVectorsWriter = flatFieldVectorsWriter; - this.enableCentering = enableCentering; - this.dimensionSums = enableCentering ? new float[fieldInfo.getVectorDimension()] : null; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; } @Override @@ -658,10 +554,6 @@ public void finish() throws IOException { if (finished) { return; } - if (!flatFieldVectorsWriter.isFinished()) { - // InMemoryFloatFieldWriter is not flushed through the raw delegate, so finish it here. - flatFieldVectorsWriter.finish(); - } assert flatFieldVectorsWriter.isFinished(); finished = true; } @@ -678,12 +570,10 @@ public void addValue(int docID, float[] vectorValue) throws IOException { float dp = VectorUtil.dotProduct(vectorValue, vectorValue); float divisor = (float) Math.sqrt(dp); magnitudes.add(divisor); - if (enableCentering) { - for (int i = 0; i < vectorValue.length; i++) { - dimensionSums[i] += (vectorValue[i] / divisor); - } + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); } - } else if (enableCentering) { + } else { for (int i = 0; i < vectorValue.length; i++) { dimensionSums[i] += vectorValue[i]; } @@ -704,74 +594,6 @@ public long ramBytesUsed() { } } - private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { - private static final long SHALLOW_SIZE = shallowSizeOfInstance(InMemoryFloatFieldWriter.class); - private final FieldInfo fieldInfo; - private final List vectors = new ArrayList<>(); - private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - private boolean finished; - private int lastDocID = -1; - - public InMemoryFloatFieldWriter(FieldInfo fieldInfo) { - this.fieldInfo = fieldInfo; - } - - @Override - public void addValue(int docID, float[] vectorValue) throws IOException { - if (finished) { - throw new IllegalStateException("already finished, cannot add more values"); - } - if (docID == lastDocID) { - throw new IllegalArgumentException( - "VectorValuesField \"" - + fieldInfo.name - + "\" appears more than once in this document (only one value is allowed per field)"); - } - assert docID > lastDocID; - vectors.add(copyValue(vectorValue)); - docsWithField.add(docID); - lastDocID = docID; - } - - @Override - public float[] copyValue(float[] vectorValue) { - return ArrayUtil.copyOfSubArray(vectorValue, 0, fieldInfo.getVectorDimension()); - } - - @Override - public List getVectors() { - return vectors; - } - - @Override - public DocsWithFieldSet getDocsWithFieldSet() { - return docsWithField; - } - - @Override - public void finish() { - finished = true; - } - - @Override - public boolean isFinished() { - return finished; - } - - @Override - public long ramBytesUsed() { - long size = SHALLOW_SIZE; - if (vectors.isEmpty()) { - return size; - } - return size - + docsWithField.ramBytesUsed() - + (long) vectors.size() - * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) - + (long) vectors.size() * fieldInfo.getVectorDimension() * Float.BYTES; - } - } - static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private OptimizedScalarQuantizer.QuantizationResult corrections; private final byte[] quantized; @@ -887,177 +709,6 @@ public int ordToDoc(int ord) { } } - private static final class QuantizedByteVectorValuesSub extends DocIDMerger.Sub { - final QuantizedByteVectorValues values; - final KnnVectorValues.DocIndexIterator iterator; - - QuantizedByteVectorValuesSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { - super(docMap); - this.values = values; - this.iterator = values.iterator(); - assert iterator.docID() == -1; - } - - @Override - public int nextDoc() throws IOException { - return iterator.nextDoc(); - } - } - - /** Merged view of {@link QuantizedByteVectorValues} from multiple segments. */ - static final class MergedQuantizedByteVectorValues extends QuantizedByteVectorValues { - private final List subs; - private final DocIDMerger docIdMerger; - private final int size; - private final float[] centroid; - private final float centroidDP; - private final ScalarEncoding scalarEncoding; - private int docId = -1; - private int lastOrd = -1; - private QuantizedByteVectorValuesSub current; - - private MergedQuantizedByteVectorValues( - List subs, - MergeState mergeState, - float[] centroid, - ScalarEncoding scalarEncoding) - throws IOException { - this.subs = subs; - this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); - int totalSize = 0; - for (QuantizedByteVectorValuesSub sub : subs) { - totalSize += sub.values.size(); - } - this.size = totalSize; - this.centroid = centroid; - this.centroidDP = VectorUtil.dotProduct(centroid, centroid); - this.scalarEncoding = scalarEncoding; - } - - static MergedQuantizedByteVectorValues merge( - FieldInfo fieldInfo, MergeState mergeState, float[] centroid, ScalarEncoding encoding) - throws IOException { - List subs = new ArrayList<>(); - for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { - KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; - if (reader == null) { - continue; - } - QuantizedByteVectorValues qbvv = getQuantizedVectorValues(reader, fieldInfo.name); - if (qbvv == null || qbvv.size() == 0) { - continue; - } - subs.add(new QuantizedByteVectorValuesSub(mergeState.docMaps[i], qbvv)); - } - return new MergedQuantizedByteVectorValues(subs, mergeState, centroid, encoding); - } - - @Override - public DocIndexIterator iterator() { - return new DocIndexIterator() { - private int index = -1; - - @Override - public int docID() { - return docId; - } - - @Override - public int index() { - return index; - } - - @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; - index = NO_MORE_DOCS; - } else { - docId = current.mappedDocID; - ++lastOrd; - ++index; - } - return docId; - } - - @Override - public int advance(int target) { - throw new UnsupportedOperationException(); - } - - @Override - public long cost() { - return size; - } - }; - } - - @Override - public byte[] vectorValue(int ord) throws IOException { - if (ord != lastOrd) { - throw new IllegalStateException( - "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); - } - return current.values.vectorValue(current.iterator.index()); - } - - @Override - public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) - throws IOException { - if (ord != lastOrd) { - throw new IllegalStateException( - "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); - } - return current.values.getCorrectiveTerms(current.iterator.index()); - } - - @Override - public int dimension() { - return subs.isEmpty() ? 0 : subs.get(0).values.dimension(); - } - - @Override - public int size() { - return size; - } - - @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - } - - @Override - public ScalarEncoding getScalarEncoding() { - return scalarEncoding; - } - - @Override - public float[] getCentroid() { - return centroid; - } - - @Override - public float getCentroidDP() { - return centroidDP; - } - - @Override - public OptimizedScalarQuantizer getQuantizer() { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); - } - - @Override - public QuantizedByteVectorValues copy() { - throw new UnsupportedOperationException(); - } - } - static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..101174cf0b7d --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.HNSW_GRAPH_THRESHOLD; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; + +/** + * A vectors format that uses HNSW graph to store and search for vectors. But vectors are binary + * quantized using {@link Lucene105ScalarQuantizedVectorsFormat} before being stored in the graph. + * + * @lucene.experimental + */ +public class Lucene105HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "Lucene105HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private final Lucene105ScalarQuantizedVectorsFormat flatVectorsFormat; + + /** + * The threshold to use to bypass HNSW graph building for tiny segments in terms of k for a graph + * i.e. number of docs to match the query (default is {@link + * Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD}). + * + *

+ */ + private final int tinySegmentsThreshold; + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public Lucene105HnswScalarQuantizedVectorsFormat() { + this( + ScalarEncoding.UNSIGNED_BYTE, + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + DEFAULT_NUM_MERGE_WORKER, + null, + HNSW_GRAPH_THRESHOLD); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public Lucene105HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { + this( + ScalarEncoding.UNSIGNED_BYTE, + maxConn, + beamWidth, + DEFAULT_NUM_MERGE_WORKER, + null, + HNSW_GRAPH_THRESHOLD); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param encoding the quantization encoding used to encode the vectors + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding encoding, int maxConn, int beamWidth) { + this(encoding, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param encoding the quantization encoding used to encode the vectors + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding encoding, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec) { + this(encoding, maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding encoding, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + int tinySegmentsThreshold) { + super(NAME); + flatVectorsFormat = new Lucene105ScalarQuantizedVectorsFormat(encoding); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + this.tinySegmentsThreshold = tinySegmentsThreshold; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + tinySegmentsThreshold); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene105HnswScalarQuantizedVectorsFormat(name=Lucene105HnswScalarQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", tinySegmentsThreshold=" + + tinySegmentsThreshold + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..9dc5a9aadb5b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorScorer.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; + +/** + * Vector scorer over OptimizedScalarQuantized vectors + * + * @lucene.experimental + */ +public class Lucene105ScalarQuantizedVectorScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene105ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues qv) { + FlatVectorsScorer.checkDimensions(target.length, qv.dimension()); + OptimizedScalarQuantizer quantizer = qv.getQuantizer(); + ScalarEncoding scalarEncoding = qv.getScalarEncoding(); + byte[] scratch = new byte[scalarEncoding.getDiscreteDimensions(qv.dimension())]; + final byte[] targetQuantized; + if (scalarEncoding.isAsymmetric() == false) { + targetQuantized = scratch; + } else { + // This is asymmetric quantization, we will pack the vector + targetQuantized = new byte[scalarEncoding.getQueryPackedLength(scratch.length)]; + } + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + var targetCorrectiveTerms = + quantizer.scalarQuantize( + target, scratch, scalarEncoding.getQueryBits(), qv.getCentroid()); + // for asymmetric encodings with 4-bit query, we need to transpose the nibbles for fast + // scoring comparisons + if (scalarEncoding == ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE + || scalarEncoding == ScalarEncoding.DIBIT_QUERY_NIBBLE) { + OptimizedScalarQuantizer.transposeHalfByte(scratch, targetQuantized); + } + return new RandomVectorScorer.AbstractRandomVectorScorer(qv) { + @Override + public float score(int node) throws IOException { + return quantizedScore( + targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction); + } + }; + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + FlatVectorsScorer.checkDimensions(target.length, vectorValues.dimension()); + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues scoringVectors, + QuantizedByteVectorValues targetVectors) { + return new AsymmetricQuantizedRandomVectorScorerSupplier( + scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "Lucene105ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + + nonQuantizedDelegate + + ")"; + } + + static class AsymmetricQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final QuantizedByteVectorValues queryVectors; + private final QuantizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + AsymmetricQuantizedRandomVectorScorerSupplier( + QuantizedByteVectorValues queryVectors, + QuantizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) { + assert targetVectors.getScalarEncoding().isAsymmetric(); + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + final QuantizedByteVectorValues targetVectors = this.targetVectors.copy(); + final QuantizedByteVectorValues queryVectors = this.queryVectors.copy(); + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetVectors) { + private OptimizedScalarQuantizer.QuantizationResult queryCorrections = null; + private byte[] vector = null; + + @Override + public void setScoringOrdinal(int node) throws IOException { + vector = queryVectors.vectorValue(node); + queryCorrections = queryVectors.getCorrectiveTerms(node); + } + + @Override + public float score(int node) throws IOException { + if (vector == null || queryCorrections == null) { + throw new IllegalStateException("setScoringOrdinal was not called"); + } + + return quantizedScore(vector, queryCorrections, targetVectors, node, similarityFunction); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new AsymmetricQuantizedRandomVectorScorerSupplier( + queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + private static final class ScalarQuantizedVectorScorerSupplier + implements RandomVectorScorerSupplier { + private final QuantizedByteVectorValues targetValues; + private final QuantizedByteVectorValues values; + private final VectorSimilarityFunction similarity; + + public ScalarQuantizedVectorScorerSupplier( + QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException { + assert values.getScalarEncoding().isAsymmetric() == false; + this.targetValues = values.copy(); + this.values = values; + this.similarity = similarity; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private byte[] targetVector; + private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms; + + @Override + public float score(int node) throws IOException { + return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + var rawTargetVector = targetValues.vectorValue(node); + switch (values.getScalarEncoding()) { + case UNSIGNED_BYTE, SEVEN_BIT -> targetVector = rawTargetVector; + case PACKED_NIBBLE -> { + if (targetVector == null) { + targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)]; + } + OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector); + } + case SINGLE_BIT_QUERY_NIBBLE, DIBIT_QUERY_NIBBLE -> { + throw new IllegalStateException( + values.getScalarEncoding().name() + + " encoding is not supported for symmetric quantization"); + } + } + targetCorrectiveTerms = targetValues.getCorrectiveTerms(node); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity); + } + } + + private static final float[] SCALE_LUT = + new float[] { + 1f, + 1f / ((1 << 2) - 1), + 1f / ((1 << 3) - 1), + 1f / ((1 << 4) - 1), + 1f / ((1 << 5) - 1), + 1f / ((1 << 6) - 1), + 1f / ((1 << 7) - 1), + 1f / ((1 << 8) - 1), + }; + + private static float quantizedScore( + byte[] quantizedQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + QuantizedByteVectorValues targetVectors, + int targetOrd, + VectorSimilarityFunction similarityFunction) + throws IOException { + var scalarEncoding = targetVectors.getScalarEncoding(); + byte[] quantizedDoc = targetVectors.vectorValue(targetOrd); + float qcDist = + switch (scalarEncoding) { + case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc); + case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc); + case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc); + case SINGLE_BIT_QUERY_NIBBLE -> + VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc); + case DIBIT_QUERY_NIBBLE -> VectorUtil.int4DibitDotProduct(quantizedQuery, quantizedDoc); + }; + OptimizedScalarQuantizer.QuantizationResult indexCorrections = + targetVectors.getCorrectiveTerms(targetOrd); + float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1]; + float scale = SCALE_LUT[scalarEncoding.getBits() - 1]; + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we must scale according to the bits + float lx = (indexCorrections.upperInterval() - ax) * scale; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * queryScale; + float y1 = queryCorrections.quantizedComponentSum(); + float score = + ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + score = + queryCorrections.additionalCorrection() + + indexCorrections.additionalCorrection() + - 2 * score; + // Ensure that 'score' (the squared euclidean distance) is non-negative. The computed value + // may be negative as a result of quantization loss. + return 1 / (1f + Math.max(score, 0f)); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += + queryCorrections.additionalCorrection() + + indexCorrections.additionalCorrection() + - targetVectors.getCentroidDP(); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + // Ensure that 'score' (a normalized dot product) is in [-1,1]. The computed value may be out + // of bounds as a result of quantization loss. + score = Math.clamp(score, -1, 1); + return (1f + score) / 2f; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..45c688e11e24 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsFormat.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; + +/** + * The quantization format used here is a per-vector optimized scalar quantization. These ideas are + * evolutions of LVQ proposed in Similarity search in the + * blink of an eye with compressed indices by Cecilia Aguerrebere et al., the previous work on + * globally optimized scalar quantization in Apache Lucene, and Accelerating Large-Scale Inference with Anisotropic + * Vector Quantization by Ruiqi Guo et. al. Also see {@link + * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of key features are: + * + * + * + * A previous work related to improvements over regular LVQ is Practical and Asymptotically Optimal Quantization of + * High-Dimensional Vectors in Euclidean Space for Approximate Nearest Neighbor Search by + * Jianyang Gao, et. al. + * + *

The format is stored within two files: + * + *

.veq (vector data) file

+ * + *

Stores the quantized vectors in a flat format. Additionally, it stores each vector's + * corrective factors. At the end of the file, additional information is stored for vector ordinal + * to centroid ordinal mapping and sparse vector information. + * + *

+ * + *

.vemq (vector metadata) file

+ * + *

Stores the metadata for the vectors. This includes the number of vectors, the number of + * dimensions, and file offset information. + * + *

+ * + * @lucene.experimental + */ +public class Lucene105ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; + public static final String NAME = "Lucene105ScalarQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "Lucene105ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene105ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemq"; + static final String VECTOR_DATA_EXTENSION = "veq"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private static final Lucene105ScalarQuantizedVectorScorer scorer = + new Lucene105ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private final ScalarEncoding encoding; + private final boolean enableCentering; + + /** Creates a new instance with UNSIGNED_BYTE encoding and centering enabled. */ + public Lucene105ScalarQuantizedVectorsFormat() { + this(ScalarEncoding.UNSIGNED_BYTE); + } + + /** Creates a new instance with the chosen quantization encoding and centering enabled. */ + public Lucene105ScalarQuantizedVectorsFormat(ScalarEncoding encoding) { + this(encoding, true); + } + + /** + * Creates a new instance with the chosen quantization encoding and centering setting. + * + *

When {@code enableCentering} is {@code false} (data-blind mode), no centroid is computed and + * no raw float vectors are written. This reduces storage costs by 4x or more but reduces + * quantization accuracy, particularly at lower bit rates. + */ + public Lucene105ScalarQuantizedVectorsFormat(ScalarEncoding encoding, boolean enableCentering) { + super(NAME); + this.encoding = encoding; + this.enableCentering = enableCentering; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene105ScalarQuantizedVectorsWriter( + state, encoding, enableCentering, rawVectorFormat.fieldsWriter(state), scorer); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene105ScalarQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene105ScalarQuantizedVectorsFormat(name=" + + NAME + + ", encoding=" + + encoding + + ", enableCentering=" + + enableCentering + + ", flatVectorScorer=" + + scorer + + ", rawVectorFormat=" + + rawVectorFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java new file mode 100644 index 000000000000..59c198e75b1a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; +import org.apache.lucene.util.quantization.QuantizedVectorsReader; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +/** + * Reader for scalar quantized vectors in the Lucene 10.5 format. + * + * @lucene.experimental + */ +public class Lucene105ScalarQuantizedVectorsReader extends FlatVectorsReader + implements QuantizedVectorsReader { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene105ScalarQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final Lucene105ScalarQuantizedVectorScorer vectorScorer; + public static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64; + + public Lucene105ScalarQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + Lucene105ScalarQuantizedVectorScorer vectorsScorer) + throws IOException { + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. + this(state, rawVectorsReader, vectorsScorer, DataAccessHint.RANDOM); + } + + public Lucene105ScalarQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + Lucene105ScalarQuantizedVectorScorer vectorsScorer, + DataAccessHint accessHint) + throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene105ScalarQuantizedVectorsFormat.META_EXTENSION); + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene105ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene105ScalarQuantizedVectorsFormat.VERSION_START, + Lucene105ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + + final IOContext.FileOpenHint[] hints = + Stream.of(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, accessHint) + .filter(Objects::nonNull) + .toArray(IOContext.FileOpenHint[]::new); + quantizedVectorData = + openDataInput( + state, + versionMeta, + VECTOR_DATA_EXTENSION, + Lucene105ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + state.context.withHints(hints)); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + long numQuantizedVectorBytes = + Math.multiplyExact( + (fieldEntry.scalarEncoding.getDocPackedLength(dimension) + + (Float.BYTES * 3) + + Integer.BYTES), + (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (dims=" + + dimension + + " + 16" + + ") = " + + numQuantizedVectorBytes); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData), + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + + FloatVectorValues rawFloatVectorValues = + fi.isDataBlind() ? null : rawVectorsReader.getFloatVectorValues(field); + + if (rawFloatVectorValues == null || rawFloatVectorValues.size() == 0) { + return OffHeapScalarQuantizedFloatVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + } + + OffHeapScalarQuantizedVectorValues sqvv = + OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + return new ScalarQuantizedVectorValues(rawFloatVectorValues, sqvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits()); + // if k is larger than the number of vectors we expect to visit in an HNSW search, + // we can just iterate over all vectors and collect them. + int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS]; + float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS]; + int numOrds = 0; + int numVectors = scorer.maxOrd(); + for (int i = 0; i < numVectors; i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + if (knnCollector.earlyTerminated()) { + break; + } + ords[numOrds++] = i; + if (numOrds == ords.length) { + knnCollector.incVisitedCount(numOrds); + if (scorer.bulkScore(ords, scores, numOrds) > knnCollector.minCompetitiveSimilarity()) { + for (int j = 0; j < numOrds; j++) { + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); + } + } + numOrds = 0; + } + } + } + + if (numOrds > 0) { + knnCollector.incVisitedCount(numOrds); + if (scorer.bulkScore(ords, scores, numOrds) > knnCollector.minCompetitiveSimilarity()) { + for (int j = 0; j < numOrds; j++) { + knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]); + } + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + Objects.requireNonNull(fieldInfo); + var raw = rawVectorsReader.getOffHeapByteSize(fieldInfo); + var fieldEntry = fields.get(fieldInfo.name); + if (fieldEntry == null) { + assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE; + return raw; + } + var quant = Map.of(VECTOR_DATA_EXTENSION, fieldEntry.vectorDataLength()); + return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant); + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + boolean hasRawFloatVectors(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null || fi.isDataBlind()) { + return false; + } + FloatVectorValues raw = rawVectorsReader.getFloatVectorValues(field); + return raw != null && raw.size() > 0; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene105ScalarQuantizedVectorsFormat.VERSION_START, + Lucene105ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + return in; + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, in); + throw t; + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + @Override + public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + return OffHeapScalarQuantizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.scalarEncoding, + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + } + + @Override + public ScalarQuantizer getQuantizationState(String fieldName) { + return null; + } + + @Override + public CloseableRandomVectorScorerSupplier getRandomVectorScorerSupplierForMerge( + FieldInfo fieldInfo, SegmentWriteState segmentWriteState) throws IOException { + FieldEntry fi = fields.get(fieldInfo.name); + if (fi == null) { + return null; + } + QuantizedByteVectorValues vectorValues = getQuantizedVectorValues(fieldInfo.name); + if (fi.scalarEncoding.isAsymmetric() == false) { + RandomVectorScorerSupplier supplier = + vectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), vectorValues); + return CloseableRandomVectorScorerSupplier.create(supplier, vectorValues.size(), () -> {}); + } + FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo.name); + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + String tempScoreQuantizedVectorName = null; + DocsWithFieldSet docsWithField; + try (IndexOutput tempScoreQuantizedVector = + segmentWriteState.directory.createTempOutput( + segmentWriteState.segmentInfo.name, "queries", segmentWriteState.context)) { + tempScoreQuantizedVectorName = tempScoreQuantizedVector.getName(); + docsWithField = + writeBinarizedQueryData( + vectorValues, + fi.scalarEncoding, + tempScoreQuantizedVector, + floatVectorValues, + quantizer); + CodecUtil.writeFooter(tempScoreQuantizedVector); + } catch (Throwable t) { + if (tempScoreQuantizedVectorName != null) { + IOUtils.deleteFilesSuppressingExceptions( + t, segmentWriteState.directory, tempScoreQuantizedVectorName); + } + throw t; + } + IndexInput quantizedScoreDataInput = + segmentWriteState.directory.openInput( + tempScoreQuantizedVectorName, segmentWriteState.context); + try { + OffHeapScalarQuantizedVectorValues scoreVectorValues = + new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues( + true, + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + vectorValues.getCentroid(), + vectorValues.getCentroidDP(), + quantizer, + fi.scalarEncoding, + fieldInfo.getVectorSimilarityFunction(), + vectorScorer, + quantizedScoreDataInput); + RandomVectorScorerSupplier scorerSupplier = + vectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), scoreVectorValues, vectorValues); + final String finalTempScoreQuantizedVectorName = tempScoreQuantizedVectorName; + return CloseableRandomVectorScorerSupplier.create( + scorerSupplier, + vectorValues.size(), + () -> { + IOUtils.close(quantizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, finalTempScoreQuantizedVectorName); + }); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, quantizedScoreDataInput); + throw t; + } + } + + static DocsWithFieldSet writeBinarizedQueryData( + QuantizedByteVectorValues quantizedByteVectorValues, + ScalarEncoding encoding, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + OptimizedScalarQuantizer binaryQuantizer) + throws IOException { + if (encoding.isAsymmetric() == false) { + throw new IllegalArgumentException("encoding and queryEncoding must be different"); + } + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + int discretizedDims = encoding.getDiscreteDimensions(floatVectorValues.dimension()); + byte[] quantizationScratch = new byte[discretizedDims]; + byte[] toQuery = new byte[encoding.getQueryPackedLength(discretizedDims)]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult r = + binaryQuantizer.scalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + encoding.getQueryBits(), + quantizedByteVectorValues.getCentroid()); + docsWithField.add(docV); + // pack and store the 4bit query vector + transposeHalfByte(quantizationScratch, toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r.lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r.upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r.additionalCorrection())); + binarizedQueryData.writeInt(r.quantizedComponentSum()); + } + return docsWithField; + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + long vectorDataOffset, + long vectorDataLength, + int size, + ScalarEncoding scalarEncoding, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + + boolean isDataBlind() { + if (centroid == null) return false; + for (float v : centroid) { + if (v != 0f) return false; + } + return true; + } + + static FieldEntry create( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + ScalarEncoding scalarEncoding = ScalarEncoding.UNSIGNED_BYTE; + if (size > 0) { + int wireNumber = input.readVInt(); + scalarEncoding = + ScalarEncoding.fromWireNumber(wireNumber) + .orElseThrow( + () -> + new IllegalStateException( + "Could not get ScalarEncoding from wire number: " + wireNumber)); + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = + OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + vectorDataOffset, + vectorDataLength, + size, + scalarEncoding, + centroid, + centroidDP, + conf); + } + } + + /** Vector values holding row and quantized vector values */ + protected static final class ScalarQuantizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final QuantizedByteVectorValues quantizedVectorValues; + + ScalarQuantizedVectorValues( + FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public ScalarQuantizedVectorValues copy() throws IOException { + return new ScalarQuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + @Override + public VectorScorer rescorer(float[] target) throws IOException { + return rawVectorValues.rescorer(target); + } + + QuantizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java new file mode 100644 index 000000000000..c659ddc71e97 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java @@ -0,0 +1,1102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; + +/** + * Writes quantized vector values and metadata to index segments in the format for Lucene 10.5. + * + * @lucene.experimental + */ +public class Lucene105ScalarQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene105ScalarQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, vectorData; + private final ScalarEncoding encoding; + private final boolean enableCentering; + private final FlatVectorsWriter rawVectorDelegate; + private boolean finished; + + /** Sole constructor */ + public Lucene105ScalarQuantizedVectorsWriter( + SegmentWriteState state, + ScalarEncoding encoding, + boolean enableCentering, + FlatVectorsWriter rawVectorDelegate, + Lucene105ScalarQuantizedVectorScorer vectorsScorer) + throws IOException { + super(vectorsScorer); + this.encoding = encoding; + this.enableCentering = enableCentering; + this.segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene105ScalarQuantizedVectorsFormat.META_EXTENSION); + + String vectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene105ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene105ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene105ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + vectorData, + Lucene105ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene105ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FlatFieldVectorsWriter storage = + enableCentering + ? (FlatFieldVectorsWriter) this.rawVectorDelegate.addField(fieldInfo) + : new InMemoryFloatFieldWriter(fieldInfo); + FieldWriter fieldWriter = new FieldWriter(fieldInfo, storage, enableCentering); + fields.add(fieldWriter); + return fieldWriter; + } + return this.rawVectorDelegate.addField(fieldInfo); + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + if (!enableCentering) { + clusterCenter = new float[field.fieldInfo.getVectorDimension()]; + } else { + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + } + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField( + FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + !fieldData.getVectors().isEmpty() ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet()); + } + + private void writeVectors( + FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + byte[] scratch = + new byte[encoding.getDiscreteDimensions(fieldData.fieldInfo.getVectorDimension())]; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE, SEVEN_BIT -> scratch; + case PACKED_NIBBLE, SINGLE_BIT_QUERY_NIBBLE, DIBIT_QUERY_NIBBLE -> + new byte[encoding.getDocPackedLength(scratch.length)]; + }; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + switch (encoding) { + case PACKED_NIBBLE -> OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); + case SINGLE_BIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.packAsBinary(scratch, vector); + case DIBIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.transposeDibit(scratch, vector); + case UNSIGNED_BYTE, SEVEN_BIT -> {} + } + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + writeSortedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = vectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + clusterCenter, + centroidDp, + newDocsWithField); + } + + private void writeSortedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + byte[] scratch = + new byte[encoding.getDiscreteDimensions(fieldData.fieldInfo.getVectorDimension())]; + byte[] vector = + switch (encoding) { + case UNSIGNED_BYTE, SEVEN_BIT -> scratch; + case PACKED_NIBBLE, SINGLE_BIT_QUERY_NIBBLE, DIBIT_QUERY_NIBBLE -> + new byte[encoding.getDocPackedLength(scratch.length)]; + }; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(), clusterCenter); + switch (encoding) { + case PACKED_NIBBLE -> OffHeapScalarQuantizedVectorValues.packNibbles(scratch, vector); + case SINGLE_BIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.packAsBinary(scratch, vector); + case DIBIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.transposeDibit(scratch, vector); + case UNSIGNED_BYTE, SEVEN_BIT -> {} + } + vectorData.writeBytes(vector, vector.length); + vectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + vectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + vectorData.writeInt(corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + meta.writeVInt(encoding.getWireNumber()); + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public void mergeOneFlatVectorField(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); + return; + } + if (enableCentering) { + mergeOneFlatVectorFieldCentered(fieldInfo, mergeState); + } else { + mergeOneFlatVectorFieldDataBlind(fieldInfo, mergeState); + } + } + + private void mergeOneFlatVectorFieldCentered(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + rawVectorDelegate.mergeOneFlatVectorField(fieldInfo, mergeState); + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + QuantizedFloatVectorValues quantizedVectorValues = + new QuantizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + encoding, + mergedCentroid); + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeVectorData(vectorData, quantizedVectorValues); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(mergedCentroid, mergedCentroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + mergedCentroid, + centroidDp, + docsWithField); + } + + private void mergeOneFlatVectorFieldDataBlind(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + float[] zeroCentroid = new float[fieldInfo.getVectorDimension()]; + // Classify each contributing segment as quantized-only or re-quantizable (has raw floats). + // Quantized-only segments must have a matching encoding; otherwise re-quantization from raw + // floats would be required, which is not possible when raw floats were never written. + boolean anyHasRawFloats = false; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; + if (reader == null || reader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + if (hasRawFloatVectors(reader, fieldInfo.name)) { + anyHasRawFloats = true; + } else { + QuantizedByteVectorValues qvv = getQuantizedVectorValues(reader, fieldInfo.name); + if (qvv != null && qvv.getScalarEncoding() != encoding) { + throw new IllegalStateException( + "Cannot merge field \"" + + fieldInfo.name + + "\" from data-blind segment with encoding " + + qvv.getScalarEncoding() + + " into data-blind format with encoding " + + encoding + + ": re-quantization requires raw float vectors"); + } + } + } + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField; + if (anyHasRawFloats) { + // At least one segment has raw floats; use the float path with zero centroid. + // Quantized-only segments serve reconstructed floats via getFloatVectorValues(). + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + QuantizedFloatVectorValues quantizedVectorValues = + new QuantizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + encoding, + zeroCentroid); + docsWithField = writeVectorData(vectorData, quantizedVectorValues); + } else { + // All segments are quantized-only with matching encoding: copy bytes directly. + MergedQuantizedByteVectorValues mergedQBVV = + MergedQuantizedByteVectorValues.merge(fieldInfo, mergeState, zeroCentroid, encoding); + docsWithField = writeVectorData(vectorData, mergedQBVV); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + zeroCentroid, + 0f, + docsWithField); + } + + static DocsWithFieldSet writeVectorData( + IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = quantizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + output.writeInt(corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedQueryData( + QuantizedByteVectorValues quantizedByteVectorValues, + ScalarEncoding encoding, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + OptimizedScalarQuantizer binaryQuantizer) + throws IOException { + if (encoding.isAsymmetric() == false) { + throw new IllegalArgumentException("encoding and queryEncoding must be different"); + } + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + int discretizedDims = encoding.getDiscreteDimensions(floatVectorValues.dimension()); + byte[] quantizationScratch = new byte[discretizedDims]; + byte[] toQuery = new byte[encoding.getQueryPackedLength(discretizedDims)]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult r = + binaryQuantizer.scalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + encoding.getQueryBits(), + quantizedByteVectorValues.getCentroid()); + docsWithField.add(docV); + // pack and store the 4bit query vector + transposeHalfByte(quantizationScratch, toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r.lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r.upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r.additionalCorrection())); + binarizedQueryData.writeInt(r.quantizedComponentSum()); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); + if (vectorsReader instanceof Lucene105ScalarQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static boolean hasRawFloatVectors(KnnVectorsReader vectorsReader, String fieldName) + throws IOException { + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); + if (vectorsReader instanceof Lucene105ScalarQuantizedVectorsReader reader) { + return reader.hasRawFloatVectors(fieldName); + } + return true; // non-Lucene105 format; assume raw floats are available + } + + static QuantizedByteVectorValues getQuantizedVectorValues( + KnnVectorsReader vectorsReader, String fieldName) throws IOException { + vectorsReader = vectorsReader.unwrapReaderForField(fieldName); + if (vectorsReader instanceof Lucene105ScalarQuantizedVectorsReader reader) { + return reader.getQuantizedVectorValues(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids( + MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null + || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (totalVectorCount == 0) { + return 0; + } else if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) + throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = + mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = iterator.nextDoc()) { + ++count; + float[] vector = vectorValues.vectorValue(iterator.index()); + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final boolean enableCentering; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter( + FieldInfo fieldInfo, + FlatFieldVectorsWriter flatFieldVectorsWriter, + boolean enableCentering) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.enableCentering = enableCentering; + this.dimensionSums = enableCentering ? new float[fieldInfo.getVectorDimension()] : null; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + if (!flatFieldVectorsWriter.isFinished()) { + // InMemoryFloatFieldWriter is not flushed through the raw delegate, so finish it here. + flatFieldVectorsWriter.finish(); + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + if (enableCentering) { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } + } else if (enableCentering) { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + private static class InMemoryFloatFieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(InMemoryFloatFieldWriter.class); + private final FieldInfo fieldInfo; + private final List vectors = new ArrayList<>(); + private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + private boolean finished; + private int lastDocID = -1; + + public InMemoryFloatFieldWriter(FieldInfo fieldInfo) { + this.fieldInfo = fieldInfo; + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + assert docID > lastDocID; + vectors.add(copyValue(vectorValue)); + docsWithField.add(docID); + lastDocID = docID; + } + + @Override + public float[] copyValue(float[] vectorValue) { + return ArrayUtil.copyOfSubArray(vectorValue, 0, fieldInfo.getVectorDimension()); + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() { + finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + if (vectors.isEmpty()) { + return size; + } + return size + + docsWithField.ramBytesUsed() + + (long) vectors.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + (long) vectors.size() * fieldInfo.getVectorDimension() * Float.BYTES; + } + } + + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] quantized; + private final byte[] packed; + private final float[] centroid; + private final float centroidDP; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + private final ScalarEncoding encoding; + + private int lastOrd = -1; + + QuantizedFloatVectorValues( + FloatVectorValues delegate, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.encoding = encoding; + this.quantized = new byte[encoding.getDiscreteDimensions(delegate.dimension())]; + this.packed = + switch (encoding) { + case UNSIGNED_BYTE, SEVEN_BIT -> this.quantized; + case PACKED_NIBBLE, SINGLE_BIT_QUERY_NIBBLE, DIBIT_QUERY_NIBBLE -> + new byte[encoding.getDocPackedLength(quantized.length)]; + }; + this.centroid = centroid; + this.centroidDP = VectorUtil.dotProduct(centroid, centroid); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + quantize(ord); + lastOrd = ord; + } + return packed; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public ScalarEncoding getScalarEncoding() { + return encoding; + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public float getCentroidDP() { + return centroidDP; + } + + @Override + public int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public QuantizedByteVectorValues copy() throws IOException { + return new QuantizedFloatVectorValues(values.copy(), quantizer, encoding, centroid); + } + + private void quantize(int ord) throws IOException { + corrections = + quantizer.scalarQuantize( + values.vectorValue(ord), quantized, encoding.getBits(), centroid); + switch (encoding) { + case PACKED_NIBBLE -> OffHeapScalarQuantizedVectorValues.packNibbles(quantized, packed); + case SINGLE_BIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.packAsBinary(quantized, packed); + case DIBIT_QUERY_NIBBLE -> OptimizedScalarQuantizer.transposeDibit(quantized, packed); + case UNSIGNED_BYTE, SEVEN_BIT -> {} + } + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + private static final class QuantizedByteVectorValuesSub extends DocIDMerger.Sub { + final QuantizedByteVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; + + QuantizedByteVectorValuesSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { + super(docMap); + this.values = values; + this.iterator = values.iterator(); + assert iterator.docID() == -1; + } + + @Override + public int nextDoc() throws IOException { + return iterator.nextDoc(); + } + } + + /** Merged view of {@link QuantizedByteVectorValues} from multiple segments. */ + static final class MergedQuantizedByteVectorValues extends QuantizedByteVectorValues { + private final List subs; + private final DocIDMerger docIdMerger; + private final int size; + private final float[] centroid; + private final float centroidDP; + private final ScalarEncoding scalarEncoding; + private int docId = -1; + private int lastOrd = -1; + private QuantizedByteVectorValuesSub current; + + private MergedQuantizedByteVectorValues( + List subs, + MergeState mergeState, + float[] centroid, + ScalarEncoding scalarEncoding) + throws IOException { + this.subs = subs; + this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); + int totalSize = 0; + for (QuantizedByteVectorValuesSub sub : subs) { + totalSize += sub.values.size(); + } + this.size = totalSize; + this.centroid = centroid; + this.centroidDP = VectorUtil.dotProduct(centroid, centroid); + this.scalarEncoding = scalarEncoding; + } + + static MergedQuantizedByteVectorValues merge( + FieldInfo fieldInfo, MergeState mergeState, float[] centroid, ScalarEncoding encoding) + throws IOException { + List subs = new ArrayList<>(); + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader reader = mergeState.knnVectorsReaders[i]; + if (reader == null) { + continue; + } + QuantizedByteVectorValues qbvv = getQuantizedVectorValues(reader, fieldInfo.name); + if (qbvv == null || qbvv.size() == 0) { + continue; + } + subs.add(new QuantizedByteVectorValuesSub(mergeState.docMaps[i], qbvv)); + } + return new MergedQuantizedByteVectorValues(subs, mergeState, centroid, encoding); + } + + @Override + public DocIndexIterator iterator() { + return new DocIndexIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++lastOrd; + ++index; + } + return docId; + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size; + } + }; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } + return current.values.vectorValue(current.iterator.index()); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) + throws IOException { + if (ord != lastOrd) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } + return current.values.getCorrectiveTerms(current.iterator.index()); + } + + @Override + public int dimension() { + return subs.isEmpty() ? 0 : subs.get(0).values.dimension(); + } + + @Override + public int size() { + return size; + } + + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + + @Override + public ScalarEncoding getScalarEncoding() { + return scalarEncoding; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public float getCentroidDP() { + return centroidDP; + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); + } + + @Override + public QuantizedByteVectorValues copy() { + throw new UnsupportedOperationException(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedFloatVectorValues.java new file mode 100644 index 000000000000..927daceb9620 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedFloatVectorValues.java @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene105; + +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.deQuantize; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; + +/** + * Reads quantized vector values from the index input and returns float vector values after + * dequantizing them. + * + *

This class provides functionality to read quantized vectors which are stored in the index, and + * then dequantize them back to float vectors with some precision loss. The implementation is based + * on {@code OffHeapScalarQuantizedVectorValues} with modifications to the {@code vectorValue()} + * method to return float vectors after dequantizing the vectors. + * + *

Usage: This class is used for read-only indexes where full-precision float vectors have been + * dropped from the index to save storage space. Full-precision vectors can be removed from an index + * using a method as implemented in {@code + * TestLucene105ScalarQuantizedVectorsFormat.simulateEmptyRawVectors()}. + * + * @lucene.internal + */ +abstract class OffHeapScalarQuantizedFloatVectorValues extends FloatVectorValues + implements HasIndexSlice { + + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final float[] vectorValue; + final byte[] byteValue; + final ByteBuffer byteBuffer; + final byte[] unpackedByteVectorValue; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final ScalarEncoding encoding; + final float[] centroid; + + OffHeapScalarQuantizedFloatVectorValues( + int dimension, + int size, + float[] centroid, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.correctiveValues = new float[3]; + this.encoding = encoding; + int docPackedLength = encoding.getDocPackedLength(dimension); + this.byteSize = docPackedLength + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(docPackedLength); + this.vectorValue = new float[dimension]; + this.byteValue = byteBuffer.array(); + this.unpackedByteVectorValue = new byte[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return vectorValue; + } + + // read quantized byte vector, correctiveValues and quantizedComponentSum + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + + // unpack bytes + switch (encoding) { + case PACKED_NIBBLE -> + OffHeapScalarQuantizedVectorValues.unpackNibbles(byteValue, unpackedByteVectorValue); + case SINGLE_BIT_QUERY_NIBBLE -> + OptimizedScalarQuantizer.unpackBinary(byteValue, unpackedByteVectorValue); + case DIBIT_QUERY_NIBBLE -> + OptimizedScalarQuantizer.untransposeDibit(byteValue, unpackedByteVectorValue); + case UNSIGNED_BYTE, SEVEN_BIT -> { + deQuantize( + byteValue, + vectorValue, + encoding.getBits(), + correctiveValues[0], + correctiveValues[1], + centroid); + lastOrd = targetOrd; + return vectorValue; + } + } + + // dequantize + deQuantize( + unpackedByteVectorValue, + vectorValue, + encoding.getBits(), + correctiveValues[0], + correctiveValues[1], + centroid); + + lastOrd = targetOrd; + return vectorValue; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + slice.seek(((long) targetOrd * byteSize) + byteValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + + @Override + public int getVectorByteLength() { + return dimension; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + static OffHeapScalarQuantizedFloatVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData) + throws IOException { + if (configuration.isEmpty()) { + return new OffHeapScalarQuantizedFloatVectorValues.EmptyOffHeapVectorValues( + dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = + vectorData.slice( + "scalar-quantized-float-vector-data", + quantizedVectorDataOffset, + quantizedVectorDataLength); + if (configuration.isDense()) { + return new OffHeapScalarQuantizedFloatVectorValues.DenseOffHeapVectorValues( + dimension, size, centroid, encoding, similarityFunction, vectorsScorer, bytesSlice); + } else { + return new OffHeapScalarQuantizedFloatVectorValues.SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + encoding, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); + } + } + + /** Dense off-heap scalar quantized vector values */ + private static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedFloatVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super(dimension, size, centroid, encoding, similarityFunction, vectorsScorer, slice); + } + + @Override + public OffHeapScalarQuantizedFloatVectorValues.DenseOffHeapVectorValues copy() + throws IOException { + return new OffHeapScalarQuantizedFloatVectorValues.DenseOffHeapVectorValues( + dimension, size, centroid, encoding, similarityFunction, vectorsScorer, slice.clone()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + OffHeapScalarQuantizedFloatVectorValues.DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) { + return Bulk.fromRandomScorerDense(scorer, iterator, matchingDocs); + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap scalar quantized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedFloatVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + ScalarEncoding encoding, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) + throws IOException { + super(dimension, size, centroid, encoding, similarityFunction, vectorsScorer, slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public OffHeapScalarQuantizedFloatVectorValues.SparseOffHeapVectorValues copy() + throws IOException { + return new OffHeapScalarQuantizedFloatVectorValues.SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + encoding, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + OffHeapScalarQuantizedFloatVectorValues.SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) { + return Bulk.fromRandomScorerSparse(scorer, iterator, matchingDocs); + } + }; + } + } + + /** Empty vector values */ + private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedFloatVectorValues { + EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super( + dimension, + 0, + null, + ScalarEncoding.UNSIGNED_BYTE, + similarityFunction, + vectorsScorer, + null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public OffHeapScalarQuantizedFloatVectorValues.DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedVectorValues.java new file mode 100644 index 000000000000..927d94754135 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/OffHeapScalarQuantizedVectorValues.java @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; + +/** + * Scalar quantized vector values loaded from off-heap + * + * @lucene.internal + */ +public abstract class OffHeapScalarQuantizedVectorValues extends QuantizedByteVectorValues { + final int dimension; + final int size; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] vectorValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer quantizer; + final ScalarEncoding encoding; + final float[] centroid; + final float centroidDp; + final boolean isQuerySide; + + OffHeapScalarQuantizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + this( + false, + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + slice); + } + + OffHeapScalarQuantizedVectorValues( + boolean isQuerySide, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + assert isQuerySide == false || encoding.isAsymmetric(); + this.isQuerySide = isQuerySide; + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.correctiveValues = new float[3]; + this.encoding = encoding; + int docPackedLength = + isQuerySide + ? encoding.getQueryPackedLength(dimension) + : encoding.getDocPackedLength(dimension); + this.byteSize = docPackedLength + (Float.BYTES * 3) + Integer.BYTES; + this.byteBuffer = ByteBuffer.allocate(docPackedLength); + this.vectorValue = byteBuffer.array(); + this.quantizer = quantizer; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return vectorValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + lastOrd = targetOrd; + return vectorValue; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + slice.seek(((long) targetOrd * byteSize) + vectorValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = slice.readInt(); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + return quantizer; + } + + @Override + public ScalarEncoding getScalarEncoding() { + return encoding; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return vectorValue.length; + } + + static void packNibbles(byte[] unpacked, byte[] packed) { + assert unpacked.length == packed.length * 2; + for (int i = 0; i < packed.length; i++) { + int x = unpacked[i] << 4 | unpacked[packed.length + i]; + packed[i] = (byte) x; + } + } + + static void unpackNibbles(byte[] packed, byte[] unpacked) { + assert unpacked.length == packed.length * 2; + for (int i = 0; i < packed.length; i++) { + unpacked[i] = (byte) ((packed[i] >> 4) & 0x0F); + unpacked[packed.length + i] = (byte) (packed[i] & 0x0F); + } + } + + static OffHeapScalarQuantizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData) + throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = + vectorData.slice( + "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + bytesSlice); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); + } + } + + /** Dense off-heap scalar quantized vector values */ + static class DenseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super( + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + slice); + } + + DenseOffHeapVectorValues( + boolean isQuerySide, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super( + isQuerySide, + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + slice); + } + + @Override + public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() throws IOException { + return new OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues( + isQuerySide, + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + assert isQuerySide == false; + OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) { + return Bulk.fromRandomScorerDense(scorer, iterator, matchingDocs); + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap scalar quantized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + ScalarEncoding encoding, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) + throws IOException { + super( + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + similarityFunction, + vectorsScorer, + slice); + assert isQuerySide == false; + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + assert isQuerySide == false; + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + quantizer, + encoding, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + assert isQuerySide == false; + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public VectorScorer.Bulk bulk(DocIdSetIterator matchingDocs) { + return Bulk.fromRandomScorerSparse(scorer, iterator, matchingDocs); + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapScalarQuantizedVectorValues { + EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super( + dimension, + 0, + null, + Float.NaN, + null, + ScalarEncoding.UNSIGNED_BYTE, + similarityFunction, + vectorsScorer, + null); + assert isQuerySide == false; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public OffHeapScalarQuantizedVectorValues.DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/package-info.java new file mode 100644 index 000000000000..acb2c18cdb6e --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Lucene 10.5 scalar quantized vector format, extending 10.4 with data-blind mode ({@code + * enableCentering=false}) which omits raw float storage and enables lossless quantized-only merges. + */ +package org.apache.lucene.codecs.lucene105; diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 3ac106d11c84..8237ae45ea92 100644 --- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -16,3 +16,5 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat +org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat +org.apache.lucene.codecs.lucene105.Lucene105HnswScalarQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index 68edf0bc87ea..0ec1689c36e4 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -91,7 +91,6 @@ public KnnVectorsFormat knnVectorsFormat() { + " maxConn=10, beamWidth=20, tinySegmentsThreshold=100," + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + " encoding=UNSIGNED_BYTE," - + " enableCentering=true," + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index d0268cf11382..7825d92706e5 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -38,7 +38,6 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -118,7 +117,6 @@ public KnnVectorsFormat knnVectorsFormat() { "Lucene104ScalarQuantizedVectorsFormat(" + "name=Lucene104ScalarQuantizedVectorsFormat, " + "encoding=UNSIGNED_BYTE, " - + "enableCentering=true, " + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; var defaultScorer = @@ -257,164 +255,6 @@ private void replaceWithEmptyVectorFile(Directory dir, String fileName) throws E } } - // ---- data-blind (enableCentering=false) tests ---- - - private Codec datablindCodec(ScalarEncoding enc) { - return TestUtil.alwaysKnnVectorsFormat(new Lucene104ScalarQuantizedVectorsFormat(enc, false)); - } - - public void testDataBlindSearchCorrectness() throws Exception { - String fieldName = "field"; - int numVectors = random().nextInt(50, 200); - int dims = random().nextInt(4, 65); - VectorSimilarityFunction sim = randomSimilarity(); - ScalarEncoding enc = ScalarEncoding.UNSIGNED_BYTE; - - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setCodec(datablindCodec(enc)); - try (IndexWriter w = new IndexWriter(dir, iwc)) { - KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); - for (int i = 0; i < numVectors; i++) { - Document doc = new Document(); - field.setVectorValue(randomVector(dims)); - doc.add(field); - w.addDocument(doc); - } - w.commit(); - - try (IndexReader reader = DirectoryReader.open(w)) { - IndexSearcher searcher = new IndexSearcher(reader); - int k = Math.min(10, numVectors); - TopDocs hits = - searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); - assertEquals(k, hits.totalHits.value()); - } - } - } - } - - public void testDataBlindNoRawFloatVectors() throws Exception { - String fieldName = "field"; - int dims = 8; - VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; - - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); - try (IndexWriter w = new IndexWriter(dir, iwc)) { - KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); - for (int i = 0; i < 20; i++) { - Document doc = new Document(); - field.setVectorValue(randomVector(dims)); - doc.add(field); - w.addDocument(doc); - } - w.commit(); - - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader leaf = getOnlyLeafReader(reader); - FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); - assertNotNull(fvv); - assertEquals(20, fvv.size()); - // In data-blind mode the float values are reconstructed from quantized data, not backed - // by raw stored floats — so the returned type must NOT be a ScalarQuantizedVectorValues. - assertFalse( - "data-blind mode must not store raw float vectors", - fvv instanceof Lucene104ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues); - } - } - } - } - - public void testDataBlindMultiSegmentMerge() throws Exception { - String fieldName = "field"; - int dims = 16; - VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; - int numPerSegment = 30; - - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); - try (IndexWriter w = new IndexWriter(dir, iwc)) { - KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); - // write two separate segments - for (int s = 0; s < 2; s++) { - for (int i = 0; i < numPerSegment; i++) { - Document doc = new Document(); - field.setVectorValue(randomVector(dims)); - doc.add(field); - w.addDocument(doc); - } - w.commit(); - } - w.forceMerge(1); - - try (IndexReader reader = DirectoryReader.open(w)) { - assertEquals(1, reader.leaves().size()); - LeafReader leaf = reader.leaves().get(0).reader(); - FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); - assertEquals(numPerSegment * 2, fvv.size()); - - // search should still return results - IndexSearcher searcher = new IndexSearcher(reader); - int k = 10; - TopDocs hits = - searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); - assertEquals(k, hits.totalHits.value()); - } - } - } - } - - public void testDataBlindIncompatibleEncodingMerge() throws Exception { - String fieldName = "field"; - int dims = 16; - VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; - - try (Directory dir1 = newDirectory(); - Directory dir2 = newDirectory()) { - // segment 1: UNSIGNED_BYTE, data-blind - IndexWriterConfig iwc1 = newIndexWriterConfig(); - iwc1.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); - try (IndexWriter w = new IndexWriter(dir1, iwc1)) { - KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); - for (int i = 0; i < 10; i++) { - Document doc = new Document(); - field.setVectorValue(randomVector(dims)); - doc.add(field); - w.addDocument(doc); - } - } - - // segment 2: PACKED_NIBBLE, data-blind - IndexWriterConfig iwc2 = newIndexWriterConfig(); - iwc2.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); - try (IndexWriter w = new IndexWriter(dir2, iwc2)) { - KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); - for (int i = 0; i < 10; i++) { - Document doc = new Document(); - field.setVectorValue(randomVector(dims)); - doc.add(field); - w.addDocument(doc); - } - } - - // merge both into a PACKED_NIBBLE data-blind index — UNSIGNED_BYTE segment has no raw - // floats, so re-quantization to PACKED_NIBBLE is impossible: expect an error. - // SerialMergeScheduler makes merges synchronous so the exception propagates directly. - try (Directory dirMerge = newDirectory()) { - IndexWriterConfig iwcMerge = newIndexWriterConfig(); - iwcMerge.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); - iwcMerge.setMergeScheduler(new SerialMergeScheduler()); - try (IndexWriter w = new IndexWriter(dirMerge, iwcMerge)) { - w.addIndexes(dir1, dir2); - expectThrows(Exception.class, () -> w.forceMerge(1)); - } - } - } - } - /** Updates vector metadata file to indicate zero vector length. */ private void updateVectorMetadataFile(Directory dir, String fileName) throws Exception { // Read original metadata diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..f16d0e9c0f0e --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; +import org.junit.Before; + +public class TestLucene105HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private KnnVectorsFormat format; + private ScalarEncoding encoding; + + @Before + @Override + public void setUp() throws Exception { + var encodingValues = ScalarEncoding.values(); + encoding = encodingValues[random().nextInt(encodingValues.length)]; + format = + new Lucene105HnswScalarQuantizedVectorsFormat( + encoding, + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + 1, + null); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding.UNSIGNED_BYTE, 10, 20, 1, null); + } + }; + String expectedPattern = + "Lucene105HnswScalarQuantizedVectorsFormat(name=Lucene105HnswScalarQuantizedVectorsFormat," + + " maxConn=10, beamWidth=20, tinySegmentsThreshold=100," + + " flatVectorFormat=Lucene105ScalarQuantizedVectorsFormat(name=Lucene105ScalarQuantizedVectorsFormat," + + " encoding=UNSIGNED_BYTE," + + " enableCentering=true," + + " flatVectorScorer=Lucene105ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; + + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + float[] docVector = + similarityFunction == VectorSimilarityFunction.DOT_PRODUCT + ? VectorUtil.l2normalize(ArrayUtil.copyArray(vector)) + : vector; + doc.add(new KnnFloatVectorField("f", docVector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals( + docVector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = + similarityFunction == VectorSimilarityFunction.DOT_PRODUCT + ? randomNormalizedVector(vector.length) + : randomVector(vector.length); + float trueScore = similarityFunction.compare(docVector, randomVector); + TopDocs td = + r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(null, r.maxDoc()), + Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true + // score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows( + IllegalArgumentException.class, + () -> new Lucene105HnswScalarQuantizedVectorsFormat(-1, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene105HnswScalarQuantizedVectorsFormat(0, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene105HnswScalarQuantizedVectorsFormat(20, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene105HnswScalarQuantizedVectorsFormat(20, -1)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene105HnswScalarQuantizedVectorsFormat(512 + 1, 20)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene105HnswScalarQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> + new Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding.UNSIGNED_BYTE, 20, 100, 1, new SameThreadExecutorService())); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + knnVectorsReader = knnVectorsReader.unwrapReaderForField("f"); + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + assertNotNull(offHeap.get("vex")); + long corrections = Float.BYTES + Float.BYTES + Float.BYTES + Integer.BYTES; + long expected = encoding.getDocPackedLength(fieldInfo.getVectorDimension()) + corrections; + assertEquals(expected, (long) offHeap.get("veq")); + assertEquals(3, offHeap.size()); + } + } + } + } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105ScalarQuantizedVectorsFormat.java new file mode 100644 index 000000000000..09f427ae8d6f --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene105/TestLucene105ScalarQuantizedVectorsFormat.java @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene105; + +import static java.lang.String.format; +import static org.apache.lucene.codecs.lucene105.Lucene105ScalarQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SerialMergeScheduler; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues.ScalarEncoding; +import org.junit.Before; + +public class TestLucene105ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private ScalarEncoding encoding; + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + var encodingValues = ScalarEncoding.values(); + encoding = encodingValues[random().nextInt(encodingValues.length)]; + format = new Lucene105ScalarQuantizedVectorsFormat(encoding); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene105ScalarQuantizedVectorsFormat(); + } + }; + String expectedPattern = + "Lucene105ScalarQuantizedVectorsFormat(" + + "name=Lucene105ScalarQuantizedVectorsFormat, " + + "encoding=UNSIGNED_BYTE, " + + "enableCentering=true, " + + "flatVectorScorer=Lucene105ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; + var defaultScorer = + format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + var memSegScorer = + format( + Locale.ROOT, + expectedPattern, + "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + QuantizedByteVectorValues qvectorValues = + ((Lucene105ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] scratch = new byte[encoding.getDiscreteDimensions(dims)]; + byte[] expectedVector = new byte[encoding.getDocPackedLength(scratch.length)]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = + new Lucene105ScalarQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + scratch, + encoding.getBits(), + centroid); + switch (encoding) { + case UNSIGNED_BYTE, SEVEN_BIT -> + System.arraycopy(scratch, 0, expectedVector, 0, dims); + case PACKED_NIBBLE -> + OffHeapScalarQuantizedVectorValues.packNibbles(scratch, expectedVector); + case SINGLE_BIT_QUERY_NIBBLE -> + OptimizedScalarQuantizer.packAsBinary(scratch, expectedVector); + case DIBIT_QUERY_NIBBLE -> + OptimizedScalarQuantizer.transposeDibit(scratch, expectedVector); + } + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + var actualCorrections = qvectorValues.getCorrectiveTerms(docIndexIterator.index()); + assertEquals(corrections.lowerInterval(), actualCorrections.lowerInterval(), 0.00001f); + assertEquals(corrections.upperInterval(), actualCorrections.upperInterval(), 0.00001f); + assertEquals( + corrections.additionalCorrection(), + actualCorrections.additionalCorrection(), + 0.00001f); + assertEquals( + corrections.quantizedComponentSum(), actualCorrections.quantizedComponentSum()); + } + } + } + } + } + + @Override + protected boolean supportsFloatVectorFallback() { + return true; + } + + @Override + protected int getQuantizationBits() { + return encoding.getBits(); + } + + /** Simulates empty raw vectors by modifying index files. */ + @Override + protected void simulateEmptyRawVectors(Directory dir) throws Exception { + final String[] indexFiles = dir.listAll(); + final String RAW_VECTOR_EXTENSION = "vec"; + final String VECTOR_META_EXTENSION = "vemf"; + + for (String file : indexFiles) { + if (file.endsWith("." + RAW_VECTOR_EXTENSION)) { + replaceWithEmptyVectorFile(dir, file); + } else if (file.endsWith("." + VECTOR_META_EXTENSION)) { + updateVectorMetadataFile(dir, file); + } + } + } + + /** Replaces a raw vector file with an empty one that has valid header/footer. */ + private void replaceWithEmptyVectorFile(Directory dir, String fileName) throws Exception { + byte[] indexHeader; + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + indexHeader = CodecUtil.readIndexHeader(in); + } + dir.deleteFile(fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + // Write header + out.writeBytes(indexHeader, 0, indexHeader.length); + // Write footer (no content in between) + CodecUtil.writeFooter(out); + } + } + + // ---- data-blind (enableCentering=false) tests ---- + + private Codec datablindCodec(ScalarEncoding enc) { + return TestUtil.alwaysKnnVectorsFormat(new Lucene105ScalarQuantizedVectorsFormat(enc, false)); + } + + public void testDataBlindSearchCorrectness() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(50, 200); + int dims = random().nextInt(4, 65); + VectorSimilarityFunction sim = randomSimilarity(); + ScalarEncoding enc = ScalarEncoding.UNSIGNED_BYTE; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(enc)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + int k = Math.min(10, numVectors); + TopDocs hits = + searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + assertEquals(k, hits.totalHits.value()); + } + } + } + } + + public void testDataBlindNoRawFloatVectors() throws Exception { + String fieldName = "field"; + int dims = 8; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 20; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader leaf = getOnlyLeafReader(reader); + FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); + assertNotNull(fvv); + assertEquals(20, fvv.size()); + // In data-blind mode the float values are reconstructed from quantized data, not backed + // by raw stored floats — so the returned type must NOT be a ScalarQuantizedVectorValues. + assertFalse( + "data-blind mode must not store raw float vectors", + fvv instanceof Lucene105ScalarQuantizedVectorsReader.ScalarQuantizedVectorValues); + } + } + } + } + + public void testDataBlindMultiSegmentMerge() throws Exception { + String fieldName = "field"; + int dims = 16; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + int numPerSegment = 30; + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + // write two separate segments + for (int s = 0; s < 2; s++) { + for (int i = 0; i < numPerSegment; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + w.commit(); + } + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + assertEquals(1, reader.leaves().size()); + LeafReader leaf = reader.leaves().get(0).reader(); + FloatVectorValues fvv = leaf.getFloatVectorValues(fieldName); + assertEquals(numPerSegment * 2, fvv.size()); + + // search should still return results + IndexSearcher searcher = new IndexSearcher(reader); + int k = 10; + TopDocs hits = + searcher.search(new KnnFloatVectorQuery(fieldName, randomVector(dims), k), k); + assertEquals(k, hits.totalHits.value()); + } + } + } + } + + public void testDataBlindIncompatibleEncodingMerge() throws Exception { + String fieldName = "field"; + int dims = 16; + VectorSimilarityFunction sim = VectorSimilarityFunction.EUCLIDEAN; + + try (Directory dir1 = newDirectory(); + Directory dir2 = newDirectory()) { + // segment 1: UNSIGNED_BYTE, data-blind + IndexWriterConfig iwc1 = newIndexWriterConfig(); + iwc1.setCodec(datablindCodec(ScalarEncoding.UNSIGNED_BYTE)); + try (IndexWriter w = new IndexWriter(dir1, iwc1)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + } + + // segment 2: PACKED_NIBBLE, data-blind + IndexWriterConfig iwc2 = newIndexWriterConfig(); + iwc2.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); + try (IndexWriter w = new IndexWriter(dir2, iwc2)) { + KnnFloatVectorField field = new KnnFloatVectorField(fieldName, randomVector(dims), sim); + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + field.setVectorValue(randomVector(dims)); + doc.add(field); + w.addDocument(doc); + } + } + + // merge both into a PACKED_NIBBLE data-blind index — UNSIGNED_BYTE segment has no raw + // floats, so re-quantization to PACKED_NIBBLE is impossible: expect an error. + // SerialMergeScheduler makes merges synchronous so the exception propagates directly. + try (Directory dirMerge = newDirectory()) { + IndexWriterConfig iwcMerge = newIndexWriterConfig(); + iwcMerge.setCodec(datablindCodec(ScalarEncoding.PACKED_NIBBLE)); + iwcMerge.setMergeScheduler(new SerialMergeScheduler()); + try (IndexWriter w = new IndexWriter(dirMerge, iwcMerge)) { + w.addIndexes(dir1, dir2); + expectThrows(Exception.class, () -> w.forceMerge(1)); + } + } + } + } + + /** Updates vector metadata file to indicate zero vector length. */ + private void updateVectorMetadataFile(Directory dir, String fileName) throws Exception { + // Read original metadata + byte[] indexHeader; + int fieldNumber, vectorEncoding, vectorSimilarityFunction, dimension; + long vectorStartPos; + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + indexHeader = CodecUtil.readIndexHeader(in); + fieldNumber = in.readInt(); + vectorEncoding = in.readInt(); + vectorSimilarityFunction = in.readInt(); + vectorStartPos = in.readVLong(); + in.readVLong(); // Skip original vector length + dimension = in.readVInt(); + } + + // Create updated metadata file + dir.deleteFile(fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + // Write header + out.writeBytes(indexHeader, 0, indexHeader.length); + + // Write metadata with zero vector length + out.writeInt(fieldNumber); + out.writeInt(vectorEncoding); + out.writeInt(vectorSimilarityFunction); + out.writeVLong(vectorStartPos); + out.writeVLong(0); // Set vector length to 0 + out.writeVInt(dimension); + out.writeInt(0); + + // Write configuration + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, out, null, 0, 0, null); + + // Mark end of fields and write footer + out.writeInt(-1); + CodecUtil.writeFooter(out); + } + } +} From 153c28c008206e48ab9005e416006b3292b95a76 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 12:18:08 -0700 Subject: [PATCH 6/8] change wire format to omit center on uncentered data sets --- .../Lucene105ScalarQuantizedVectorsReader.java | 6 ++++-- .../Lucene105ScalarQuantizedVectorsWriter.java | 12 +++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java index 59c198e75b1a..ae75fcec33f8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsReader.java @@ -602,9 +602,11 @@ static FieldEntry create( () -> new IllegalStateException( "Could not get ScalarEncoding from wire number: " + wireNumber)); - centroid = new float[dimension]; - input.readFloats(centroid, 0, dimension); centroidDP = Float.intBitsToFloat(input.readInt()); + centroid = new float[dimension]; + if (centroidDP != 0f) { + input.readFloats(centroid, 0, dimension); + } } else { centroid = null; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java index c659ddc71e97..19c3c957cb97 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105ScalarQuantizedVectorsWriter.java @@ -302,12 +302,14 @@ private void writeMeta( meta.writeVInt(count); if (count > 0) { meta.writeVInt(encoding.getWireNumber()); - final ByteBuffer buffer = - ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) - .order(ByteOrder.LITTLE_ENDIAN); - buffer.asFloatBuffer().put(clusterCenter); - meta.writeBytes(buffer.array(), buffer.array().length); meta.writeInt(Float.floatToIntBits(centroidDp)); + if (centroidDp != 0f) { + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + } } OrdToDocDISIReaderConfiguration.writeStoredMeta( DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); From 27c58614be7e70b2ef7d29b2ba1abb2346cebf3e Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 21:03:56 -0700 Subject: [PATCH 7/8] propagate parameter to hnsw codec --- ...ne105HnswScalarQuantizedVectorsFormat.java | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java index 101174cf0b7d..0cb0814fdae5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene105/Lucene105HnswScalarQuantizedVectorsFormat.java @@ -156,8 +156,33 @@ public Lucene105HnswScalarQuantizedVectorsFormat( int numMergeWorkers, ExecutorService mergeExec, int tinySegmentsThreshold) { + this(encoding, true, maxConn, beamWidth, numMergeWorkers, mergeExec, tinySegmentsThreshold); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param encoding the quantization encoding used to encode the vectors + * @param enableCentering if {@code false}, no centroid is computed and raw float vectors are not + * written to disk (data-blind mode) + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + * @param tinySegmentsThreshold the threshold below which HNSW graph building is skipped + */ + public Lucene105HnswScalarQuantizedVectorsFormat( + ScalarEncoding encoding, + boolean enableCentering, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + int tinySegmentsThreshold) { super(NAME); - flatVectorsFormat = new Lucene105ScalarQuantizedVectorsFormat(encoding); + flatVectorsFormat = new Lucene105ScalarQuantizedVectorsFormat(encoding, enableCentering); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { throw new IllegalArgumentException( "maxConn must be positive and less than or equal to " From b9c9cc6d516b9c8e3a12026f2d77cc4658d4ce56 Mon Sep 17 00:00:00 2001 From: Trevor McCulloch Date: Sun, 3 May 2026 22:06:03 -0700 Subject: [PATCH 8/8] ch-ch-ch-changes --- lucene/CHANGES.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8ef68343e5fb..53926fc8b414 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -289,6 +289,10 @@ New Features or mixed) via log-odds fusion with softplus gating and sqrt(n) confidence scaling. (Jaepil Jeong) +# GITHUB#16029: Scalar quantization option to disable centering and writing of float vectors. This + reduces vector storage costs by 4x or more but also reduces quantization accuracy. + (Trevor McCulloch) + Improvements --------------------- * GITHUB#15823: Implement method to add all stream elements into a PriorityQueue.