diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f9aa93a84fdb..61d855556eb8 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -265,6 +265,10 @@ Other API Changes --------------------- +* GITHUB#16053: Add an (optional) field-writer creation strategy to Lucene99FlatVectorsWriter. + This decouples the external interface and the internal representation of FlatFieldVectorsWriters created + by this FlatVectorsWriter, to allow different formats to change how vectors are stored in memory. (Lorenzo Dematte) + * GITHUB#15663: Allow subclasses of NumericComparator to implement their own CompetitiveDISIBuilder subtypes. (Alan Woodward) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 20565e0f17f7..4afe41507316 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -43,6 +43,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IOFunction; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; @@ -58,14 +59,48 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter { private final SegmentWriteState segmentWriteState; private final IndexOutput meta, vectorData; + private final IOFunction> fieldWriterFactory; - private final List> fields = new ArrayList<>(); + private record FieldData(FlatFieldVectorsWriter fieldWriter, FieldInfo fieldInfo) {} + + private final List fields = new ArrayList<>(); private boolean finished; + /** + * Constructs a writer that uses the default factory to build per-field vector storage. This + * default factory creates instances of {@link FlatFieldVectorsWriter} that store vector data as a + * List of on-heap arrays, one per-vector (see {@code DefaultFieldWriter}). + * + * @param state the segment write state + * @param scorer the flat vectors scorer used to score vectors at index-build time + */ public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException { + this(state, scorer, DefaultFieldWriter::create); + } + + /** + * Constructs a writer that uses the supplied {@code strategyFactory} to build per-field vector + * storage. The factory is consulted on every {@link #addField(FieldInfo)} call and returns a + * user-defined {@link FlatFieldVectorsWriter}. + * + *

Note: the strategy is only consulted during indexing (i.e. via {@link #addField(FieldInfo)} + * and the subsequent {@link #flush}). Merges write directly to the new segment via {@link + * #mergeOneFlatVectorField} and do not go through the strategy. + * + * @param state the segment write state + * @param scorer the flat vectors scorer used to score vectors at index-build time + * @param strategyFactory the per-field storage factory; receives the {@link FieldInfo} for the + * field being added and returns the {@link FlatFieldVectorsWriter} that will back it + */ + public Lucene99FlatVectorsWriter( + SegmentWriteState state, + FlatVectorsScorer scorer, + IOFunction> strategyFactory) + throws IOException { super(scorer); segmentWriteState = state; + fieldWriterFactory = strategyFactory; String metaFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION); @@ -100,20 +135,20 @@ public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scor @Override public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - FieldWriter newField = FieldWriter.create(fieldInfo); - fields.add(newField); - return newField; + var newFieldWriter = fieldWriterFactory.apply(fieldInfo); + fields.add(new FieldData(newFieldWriter, fieldInfo)); + return newFieldWriter; } @Override public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { - for (FieldWriter field : fields) { + for (var field : fields) { if (sortMap == null) { - writeField(field, maxDoc); + writeField(field.fieldWriter(), field.fieldInfo(), maxDoc); } else { - writeSortingField(field, maxDoc, sortMap); + writeSortingField(field.fieldWriter(), field.fieldInfo(), maxDoc, sortMap); } - field.finish(); + field.fieldWriter().finish(); } } @@ -136,8 +171,8 @@ public void finish() throws IOException { @Override public long ramBytesUsed() { long total = SHALLOW_RAM_BYTES_USED; - for (FieldWriter field : fields) { - total += field.ramBytesUsed(); + for (var field : fields) { + total += field.fieldWriter().ramBytesUsed(); } return total; } @@ -150,69 +185,76 @@ private static long alignOutput(IndexOutput output, VectorEncoding encoding) thr }); } - private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + private void writeField(FlatFieldVectorsWriter fieldWriter, FieldInfo fieldInfo, int maxDoc) + throws IOException { // write vector values - VectorEncoding encoding = fieldData.fieldInfo.getVectorEncoding(); + VectorEncoding encoding = fieldInfo.getVectorEncoding(); long vectorDataOffset = alignOutput(vectorData, encoding); switch (encoding) { - case BYTE -> writeByteVectors(fieldData); - case FLOAT32 -> writeFloat32Vectors(fieldData); + case BYTE -> writeByteVectors(fieldWriter); + case FLOAT32 -> writeFloat32Vectors(fieldWriter, fieldInfo); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta( - fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldWriter.getDocsWithFieldSet()); } - private void writeFloat32Vectors(FieldWriter fieldData) throws IOException { + private void writeFloat32Vectors(FlatFieldVectorsWriter fieldWriter, FieldInfo fieldInfo) + throws IOException { final ByteBuffer buffer = - ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (Object v : fieldData.vectors) { + ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldWriter.getVectors()) { buffer.asFloatBuffer().put((float[]) v); vectorData.writeBytes(buffer.array(), buffer.array().length); } } - private void writeByteVectors(FieldWriter fieldData) throws IOException { - for (Object v : fieldData.vectors) { + private void writeByteVectors(FlatFieldVectorsWriter fieldWriter) throws IOException { + for (Object v : fieldWriter.getVectors()) { byte[] vector = (byte[]) v; vectorData.writeBytes(vector, vector.length); } } - private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) + private void writeSortingField( + FlatFieldVectorsWriter fieldWriter, FieldInfo fieldInfo, int maxDoc, Sorter.DocMap sortMap) throws IOException { - final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + var docsWithFieldSet = fieldWriter.getDocsWithFieldSet(); + final int[] ordMap = new int[docsWithFieldSet.cardinality()]; // new ord to old ord DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); - mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + mapOldOrdToNewOrd(docsWithFieldSet, sortMap, null, ordMap, newDocsWithField); // write vector values - VectorEncoding encoding = fieldData.fieldInfo.getVectorEncoding(); + VectorEncoding encoding = fieldInfo.getVectorEncoding(); long vectorDataOffset = alignOutput(vectorData, encoding); switch (encoding) { - case BYTE -> writeSortedByteVectors(fieldData, ordMap); - case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); + case BYTE -> writeSortedByteVectors(fieldWriter, ordMap); + case FLOAT32 -> writeSortedFloat32Vectors(fieldWriter, fieldInfo, ordMap); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; - writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + writeMeta(fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); } - private void writeSortedFloat32Vectors(FieldWriter fieldData, int[] ordMap) - throws IOException { + private void writeSortedFloat32Vectors( + FlatFieldVectorsWriter fieldWriter, FieldInfo fieldInfo, int[] ordMap) throws IOException { final ByteBuffer buffer = - ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); for (int ordinal : ordMap) { - float[] vector = (float[]) fieldData.vectors.get(ordinal); + float[] vector = (float[]) fieldWriter.getVectors().get(ordinal); buffer.asFloatBuffer().put(vector); vectorData.writeBytes(buffer.array(), buffer.array().length); } } - private void writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { + private void writeSortedByteVectors(FlatFieldVectorsWriter fieldWriter, int[] ordMap) + throws IOException { for (int ordinal : ordMap) { - byte[] vector = (byte[]) fieldData.vectors.get(ordinal); + byte[] vector = (byte[]) fieldWriter.getVectors().get(ordinal); vectorData.writeBytes(vector, vector.length); } } @@ -221,7 +263,7 @@ private void writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) thro public void mergeOneFlatVectorField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { // Since we know we will not be searching for additional indexing, we can just write the - // the vectors directly to the new segment. + // vectors directly to the new segment. VectorEncoding encoding = fieldInfo.getVectorEncoding(); long vectorDataOffset = alignOutput(vectorData, encoding); // No need to use temporary file as we don't have to re-open for reading @@ -310,29 +352,34 @@ public void close() throws IOException { IOUtils.close(meta, vectorData); } - private abstract static class FieldWriter extends FlatFieldVectorsWriter { + /** + * Default {@link FlatFieldVectorsWriter} implementation: stores vectors on-heap in an {@link + * ArrayList}, copying each value via {@link #copyValue} on {@link #addValue}. This is the + * implementation used when {@link Lucene99FlatVectorsWriter} is constructed without a strategy + * factory. + */ + private abstract static class DefaultFieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_RAM_BYTES_USED = - RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + RamUsageEstimator.shallowSizeOfInstance(DefaultFieldWriter.class); private final FieldInfo fieldInfo; - private final int dim; private final DocsWithFieldSet docsWithField; private final List vectors; private boolean finished; private int lastDocID = -1; - static FieldWriter create(FieldInfo fieldInfo) { + private static FlatFieldVectorsWriter create(FieldInfo fieldInfo) { int dim = fieldInfo.getVectorDimension(); return switch (fieldInfo.getVectorEncoding()) { case BYTE -> - new Lucene99FlatVectorsWriter.FieldWriter(fieldInfo) { + new DefaultFieldWriter(fieldInfo) { @Override public byte[] copyValue(byte[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; case FLOAT32 -> - new Lucene99FlatVectorsWriter.FieldWriter(fieldInfo) { + new DefaultFieldWriter(fieldInfo) { @Override public float[] copyValue(float[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); @@ -341,10 +388,9 @@ public float[] copyValue(float[] value) { }; } - FieldWriter(FieldInfo fieldInfo) { + DefaultFieldWriter(FieldInfo fieldInfo) { super(); this.fieldInfo = fieldInfo; - this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestKnnVectorsFormatCustomWriter.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestKnnVectorsFormatCustomWriter.java new file mode 100644 index 000000000000..51c37bbc04ca --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestKnnVectorsFormatCustomWriter.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.RamUsageEstimator; + +/** + * Exercises the {@code strategyFactory} constructor of {@link Lucene99FlatVectorsWriter} against + * the full {@link BaseKnnVectorsFormatTestCase} suite, using a paged storage strategy as a concrete + * example of a non-default {@link FlatFieldVectorsWriter}. + * + *

The strategy stores vectors in a {@code List} of fixed-size pages instead of the + * default {@code ArrayList} / {@code ArrayList}, and exposes them back through + * {@link FlatFieldVectorsWriter#getVectors()} via an {@link AbstractList} adapter that materializes + * a heap array per access. The on-disk format produced is identical to the default configuration — + * only the in-memory accumulation differs. + */ +public class TestKnnVectorsFormatCustomWriter extends BaseKnnVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(new PagedHnswVectorsFormat()); + } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + + /** + * A {@link KnnVectorsFormat} that produces an HNSW writer whose underlying flat vector writer + * uses the paged storage strategy. Reads are delegated to a standard {@link + * Lucene99HnswVectorsFormat} — the on-disk format is unchanged. + */ + private static final class PagedHnswVectorsFormat extends KnnVectorsFormat { + + private static final FlatVectorsScorer SCORER = + FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); + private static final FlatVectorsFormat FLAT_FORMAT = new Lucene99FlatVectorsFormat(SCORER); + private static final Lucene99HnswVectorsFormat READ_FORMAT = new Lucene99HnswVectorsFormat(); + + PagedHnswVectorsFormat() { + // Reuse the registered name so segments written by this format can be reopened by the + // standard Lucene99HnswVectorsFormat via SPI — the on-disk bytes are identical. + super("Lucene99HnswVectorsFormat"); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + // Build a flat writer with our paged strategy, and wrap it in the standard HNSW writer. + var flatWriter = + new Lucene99FlatVectorsWriter(state, SCORER, PagedFieldVectorsWriter::create); + return new Lucene99HnswVectorsWriter( + state, + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + FLAT_FORMAT, + flatWriter, + Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER, + null); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return READ_FORMAT.fieldsReader(state); + } + + @Override + public int getMaxDimensions(String fieldName) { + return DEFAULT_MAX_DIMENSIONS; + } + } + + /** + * A {@link FlatFieldVectorsWriter} that accumulates vectors into a list of fixed-size {@link + * ByteBuffer} pages, each holding {@link #VECTORS_PER_PAGE} vectors. + */ + private abstract static class PagedFieldVectorsWriter extends FlatFieldVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(PagedFieldVectorsWriter.class); + + /** Fixed at 10 to keep the test deliberately simple. */ + private static final int VECTORS_PER_PAGE = 10; + + private final FieldInfo fieldInfo; + private final int bytesPerVector; + private final DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + private final List pages = new ArrayList<>(); + private int size; + private int lastDocID = -1; + private boolean finished; + + static PagedFieldVectorsWriter create(FieldInfo fi) { + return switch (fi.getVectorEncoding()) { + case FLOAT32 -> new FloatPaged(fi); + case BYTE -> new BytePaged(fi); + }; + } + + PagedFieldVectorsWriter(FieldInfo fi, int bytesPerVector) { + this.fieldInfo = fi; + this.bytesPerVector = bytesPerVector; + } + + /** Writes one vector's worth of bytes into {@code buf} at its current position. */ + abstract void serialize(ByteBuffer buf, T vector); + + /** Reads one vector's worth of bytes from {@code buf} at its current position. */ + abstract T deserialize(ByteBuffer buf); + + /** Returns a fresh, positioned view of the byte region holding vector {@code ord}. */ + private ByteBuffer bufferFor(int ord) { + ByteBuffer page = pages.get(ord / VECTORS_PER_PAGE); + int offset = (ord % VECTORS_PER_PAGE) * bytesPerVector; + return page.slice(offset, bytesPerVector).order(ByteOrder.LITTLE_ENDIAN); + } + + @Override + public final void addValue(int docID, T vectorValue) { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + assert docID > lastDocID; + lastDocID = docID; + + if (size % VECTORS_PER_PAGE == 0) { + pages.add( + ByteBuffer.allocate(VECTORS_PER_PAGE * bytesPerVector).order(ByteOrder.LITTLE_ENDIAN)); + } + serialize(bufferFor(size), vectorValue); + docsWithField.add(docID); + size++; + } + + /** + * Unsupported: this writer owns its storage and copies vector bytes directly in {@link + * #addValue} via {@link #serialize(ByteBuffer, Object)}. {@code copyValue} is never invoked. + */ + @Override + public final T copyValue(T vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public final List getVectors() { + return new AbstractList<>() { + @Override + public T get(int index) { + return deserialize(bufferFor(index)); + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public final DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public final void finish() { + finished = true; + } + + @Override + public final boolean isFinished() { + return finished; + } + + @Override + public final long ramBytesUsed() { + return SHALLOW_RAM_BYTES_USED + + docsWithField.ramBytesUsed() + + (long) pages.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + VECTORS_PER_PAGE * bytesPerVector); + } + + private static final class FloatPaged extends PagedFieldVectorsWriter { + private final int dim; + + FloatPaged(FieldInfo fi) { + super(fi, fi.getVectorDimension() * Float.BYTES); + this.dim = fi.getVectorDimension(); + } + + @Override + void serialize(ByteBuffer buf, float[] vector) { + buf.asFloatBuffer().put(vector); + } + + @Override + float[] deserialize(ByteBuffer buf) { + float[] out = new float[dim]; + buf.asFloatBuffer().get(out); + return out; + } + } + + private static final class BytePaged extends PagedFieldVectorsWriter { + private final int dim; + + BytePaged(FieldInfo fi) { + super(fi, fi.getVectorDimension()); + this.dim = fi.getVectorDimension(); + } + + @Override + void serialize(ByteBuffer buf, byte[] vector) { + buf.put(vector); + } + + @Override + byte[] deserialize(ByteBuffer buf) { + byte[] out = new byte[dim]; + buf.get(out); + return out; + } + } + } +}