diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 80ba771596a9..54fbdcb0747f 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -107,6 +107,8 @@ New Features * GITHUB#14178: Add a Faiss-based vector format in the sandbox module. (Kaival Parikh) +* GITHUB#16067: Add byte vector support to FaissKnnVectorsFormat. (Prithvi S) + * GITHUB#15508: Use native vectorization in Lucene. (Ankur Goel, Shubham Chaudhary, Dawid Weiss) * GITHUB#15818: Add BM25 k3 query-term frequency saturation to BM25Similarity. (Sagar Upadhyaya) diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java index 02275e5b5963..591fd3194a22 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java @@ -119,7 +119,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta); } IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length); - indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput)); + FieldInfo fi = state.fieldInfos.fieldInfo(fieldMeta.name); + indexMap.put( + fieldMeta.name, + FaissLibrary.INSTANCE.readIndex( + indexInput, fi.getVectorSimilarityFunction(), fi.getVectorEncoding())); } } catch (Throwable t) { IOUtils.closeWhileSuppressingExceptions(t, this); @@ -158,10 +162,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { } @Override - public ByteVectorValues getByteVectorValues(String field) { - // TODO: Support using SQ8 quantization, see: - // - https://github.com/opensearch-project/k-NN/pull/2425 - throw new UnsupportedOperationException("Byte vectors not supported"); + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); } @Override @@ -176,9 +178,11 @@ public void search( @Override public void search( String field, byte[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) { - // TODO: Support using SQ8 quantization, see: - // - https://github.com/opensearch-project/k-NN/pull/2425 - throw new UnsupportedOperationException("Byte vectors not supported"); + float[] floatVector = new float[vector.length]; + for (int i = 0; i < vector.length; i++) { + floatVector[i] = vector[i]; + } + search(field, floatVector, knnCollector, acceptDocs); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java index c269653c490d..09fd15940168 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java @@ -31,6 +31,7 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; @@ -100,10 +101,11 @@ public FaissKnnVectorsWriter( public IORunnable mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { rawVectorsWriter.mergeOneFlatVectorField(fieldInfo, mergeState); switch (fieldInfo.getVectorEncoding()) { - case BYTE -> - // TODO: Support using SQ8 quantization, see: - // - https://github.com/opensearch-project/k-NN/pull/2425 - throw new UnsupportedOperationException("Byte vectors not supported"); + case BYTE -> { + ByteVectorValues merged = + KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + writeFloatField(fieldInfo, new ByteToFloatVectorValues(merged), doc -> doc); + } case FLOAT32 -> { FloatVectorValues merged = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); @@ -126,10 +128,20 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { for (Map.Entry> entry : rawFields.entrySet()) { FieldInfo fieldInfo = entry.getKey(); switch (fieldInfo.getVectorEncoding()) { - case BYTE -> - // TODO: Support using SQ8 quantization, see: - // - https://github.com/opensearch-project/k-NN/pull/2425 - throw new UnsupportedOperationException("Byte vectors not supported"); + case BYTE -> { + @SuppressWarnings("unchecked") + FlatFieldVectorsWriter rawWriter = + (FlatFieldVectorsWriter) entry.getValue(); + + List vectors = rawWriter.getVectors(); + int dimension = fieldInfo.getVectorDimension(); + DocIdSet docIdSet = rawWriter.getDocsWithFieldSet(); + + writeFloatField( + fieldInfo, + new ByteToFloatVectorValues(ByteVectorValues.fromBytes(vectors, dimension), docIdSet), + (sortMap != null) ? sortMap::oldToNew : doc -> doc); + } case FLOAT32 -> { @SuppressWarnings("unchecked") @@ -197,6 +209,53 @@ public long ramBytesUsed() { return rawVectorsWriter.ramBytesUsed(); } + private static class ByteToFloatVectorValues extends FloatVectorValues { + private final ByteVectorValues bytes; + private final DocIdSet docIdSet; + + public ByteToFloatVectorValues(ByteVectorValues bytes) { + this(bytes, null); + } + + public ByteToFloatVectorValues(ByteVectorValues bytes, DocIdSet docIdSet) { + this.bytes = bytes; + this.docIdSet = docIdSet; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + byte[] b = bytes.vectorValue(ord); + float[] f = new float[b.length]; + for (int i = 0; i < b.length; i++) { + f[i] = b[i]; + } + return f; + } + + @Override + public int dimension() { + return bytes.dimension(); + } + + @Override + public int size() { + return bytes.size(); + } + + @Override + public FloatVectorValues copy() throws IOException { + return new ByteToFloatVectorValues(bytes.copy(), docIdSet); + } + + @Override + public DocIndexIterator iterator() { + if (docIdSet != null) { + return fromDISI(docIdSet.iterator()); + } + return bytes.iterator(); + } + } + private static class BufferedFloatVectorValues extends FloatVectorValues { private final List floats; private final int dimension; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java index ad0715bcb333..17049cef1d8d 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java @@ -18,6 +18,7 @@ import java.io.Closeable; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; @@ -54,5 +55,5 @@ Index createIndex( FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId); - Index readIndex(IndexInput input); + Index readIndex(IndexInput input, VectorSimilarityFunction function, VectorEncoding encoding); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java index 08ec8264131c..cdde6ba1b62a 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java @@ -36,9 +36,9 @@ import java.lang.invoke.MethodType; import java.nio.ByteOrder; import java.util.Map; -import java.util.stream.Collectors; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; @@ -151,19 +151,6 @@ private static int functionToMetric(VectorSimilarityFunction function) { return metric; } - // Invert FUNCTION_TO_METRIC - private static final Map METRIC_TO_FUNCTION = - FUNCTION_TO_METRIC.entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); - - private static VectorSimilarityFunction metricToFunction(int metric) { - VectorSimilarityFunction function = METRIC_TO_FUNCTION.get(metric); - if (function == null) { - throw new UnsupportedOperationException("Metric not supported: " + metric); - } - return function; - } - @Override public FaissLibrary.Index createIndex( String description, @@ -229,7 +216,7 @@ public FaissLibrary.Index createIndex( // Add docs to index handleException(wrapper.faiss_Index_add_with_ids(indexPointer, size, docs, ids)); - return new Index(indexPointer); + return new Index(indexPointer, function, VectorEncoding.FLOAT32); } catch (IOException e) { throw new UncheckedIOException(e); @@ -241,7 +228,8 @@ public FaissLibrary.Index createIndex( private static final int FAISS_IO_FLAG_READ_ONLY = 2; @Override - public FaissLibrary.Index readIndex(IndexInput input) { + public FaissLibrary.Index readIndex( + IndexInput input, VectorSimilarityFunction function, VectorEncoding encoding) { try (Arena temp = Arena.ofConfined()) { MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); MemorySegment readerStub = @@ -262,7 +250,7 @@ public FaissLibrary.Index readIndex(IndexInput input) { customIOReaderPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY, pointer)); MemorySegment indexPointer = pointer.get(ADDRESS, 0); - return new Index(indexPointer); + return new Index(indexPointer, function, encoding); } } @@ -277,29 +265,39 @@ private interface FloatToFloatFunction { private final FloatToFloatFunction scaler; private boolean closed; - private Index(MemorySegment indexPointer) { + private Index( + MemorySegment indexPointer, VectorSimilarityFunction function, VectorEncoding encoding) { this.arena = Arena.ofShared(); this.indexPointer = indexPointer // Ensure timely cleanup .reinterpret(arena, wrapper::faiss_Index_free); - // Get underlying function - int metricType = wrapper.faiss_Index_metric_type(indexPointer); - VectorSimilarityFunction function = metricToFunction(metricType); + int dimension = wrapper.faiss_Index_d(indexPointer); // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java this.scaler = switch (function) { - case DOT_PRODUCT -> - // distance in Faiss === dotProduct in Lucene - distance -> Math.max((1 + distance) / 2, 0); + case DOT_PRODUCT -> { + if (encoding == VectorEncoding.BYTE) { + float denom = (float) (dimension * (1 << 15)); + yield distance -> 0.5f + distance / denom; + } else { + yield distance -> Math.max((1 + distance) / 2, 0); + } + } case EUCLIDEAN -> // distance in Faiss === squareDistance in Lucene distance -> 1 / (1 + distance); - case COSINE, MAXIMUM_INNER_PRODUCT -> throw new AssertionError("Should not reach here"); + case COSINE -> + // For COSINE, vectors are normalized so inner product == cosine similarity + distance -> Math.max((1 + distance) / 2, 0); + + case MAXIMUM_INNER_PRODUCT -> + throw new UnsupportedOperationException( + "Similarity function not supported: " + function); }; this.closed = false; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java index 575ebf953224..423f685a3e1b 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java @@ -210,6 +210,19 @@ int faiss_Index_metric_type(MemorySegment indexPointer) { } } + private final MethodHandle faiss_Index_d$MH = + getHandle("faiss_Index_d", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_Index_d(MemorySegment indexPointer) { + try { + return (int) faiss_Index_d$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + private final MethodHandle faiss_Index_search$MH = getHandle( "faiss_Index_search", diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java index bd4bcf9c70fa..e027959ce158 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.sandbox.codecs.faiss; +import static org.apache.lucene.index.VectorEncoding.BYTE; import static org.apache.lucene.index.VectorEncoding.FLOAT32; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; @@ -44,7 +45,7 @@ public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { private static final String FAISS_RUN_TESTS = "tests.faiss.run"; - private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32}; + private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32, BYTE}; private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN}; @BeforeClass @@ -92,30 +93,6 @@ public void testRecall() throws IOException { @Ignore // does not honour visitedLimit public void testSearchWithVisitedLimit() {} - @Override - @Ignore // does not support byte vectors - public void testByteVectorScorerIteration() {} - - @Override - @Ignore // does not support byte vectors - public void testMismatchedFields() {} - - @Override - @Ignore // does not support byte vectors - public void testSortedIndexBytes() {} - - @Override - @Ignore // does not support byte vectors - public void testRandomBytes() {} - - @Override - @Ignore // does not support byte vectors - public void testEmptyByteVectorData() {} - - @Override - @Ignore // does not support byte vectors - public void testMergingWithDifferentByteKnnFields() {} - @Monster("Uses large amount of heap and RAM") public void testLargeVectorData() throws IOException { KnnVectorsFormat format =