Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ New Features

* GITHUB#14178: Add a Faiss-based vector format in the sandbox module. (Kaival Parikh)

* GITHUB#16067: Add byte vector support to FaissKnnVectorsFormat. (Prithvi S)

* GITHUB#15508: Use native vectorization in Lucene. (Ankur Goel, Shubham Chaudhary, Dawid Weiss)

* GITHUB#15818: Add BM25 k3 query-term frequency saturation to BM25Similarity. (Sagar Upadhyaya)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector
throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta);
}
IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length);
indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput));
FieldInfo fi = state.fieldInfos.fieldInfo(fieldMeta.name);
indexMap.put(
fieldMeta.name,
FaissLibrary.INSTANCE.readIndex(
indexInput, fi.getVectorSimilarityFunction(), fi.getVectorEncoding()));
}
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, this);
Expand Down Expand Up @@ -158,10 +162,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
}

@Override
public ByteVectorValues getByteVectorValues(String field) {
// TODO: Support using SQ8 quantization, see:
// - https://github.com/opensearch-project/k-NN/pull/2425
throw new UnsupportedOperationException("Byte vectors not supported");
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return rawVectorsReader.getByteVectorValues(field);
}

@Override
Expand All @@ -176,9 +178,11 @@ public void search(
@Override
public void search(
String field, byte[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) {
// TODO: Support using SQ8 quantization, see:
// - https://github.com/opensearch-project/k-NN/pull/2425
throw new UnsupportedOperationException("Byte vectors not supported");
float[] floatVector = new float[vector.length];
for (int i = 0; i < vector.length; i++) {
floatVector[i] = vector[i];
}
search(field, floatVector, knnCollector, acceptDocs);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
Expand Down Expand Up @@ -100,10 +101,11 @@ public FaissKnnVectorsWriter(
public IORunnable mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
rawVectorsWriter.mergeOneFlatVectorField(fieldInfo, mergeState);
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
// TODO: Support using SQ8 quantization, see:
// - https://github.com/opensearch-project/k-NN/pull/2425
throw new UnsupportedOperationException("Byte vectors not supported");
case BYTE -> {
ByteVectorValues merged =
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
writeFloatField(fieldInfo, new ByteToFloatVectorValues(merged), doc -> doc);
}
case FLOAT32 -> {
FloatVectorValues merged =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
Expand All @@ -126,10 +128,20 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (Map.Entry<FieldInfo, FlatFieldVectorsWriter<?>> entry : rawFields.entrySet()) {
FieldInfo fieldInfo = entry.getKey();
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
// TODO: Support using SQ8 quantization, see:
// - https://github.com/opensearch-project/k-NN/pull/2425
throw new UnsupportedOperationException("Byte vectors not supported");
case BYTE -> {
@SuppressWarnings("unchecked")
FlatFieldVectorsWriter<byte[]> rawWriter =
(FlatFieldVectorsWriter<byte[]>) entry.getValue();

List<byte[]> vectors = rawWriter.getVectors();
int dimension = fieldInfo.getVectorDimension();
DocIdSet docIdSet = rawWriter.getDocsWithFieldSet();

writeFloatField(
fieldInfo,
new ByteToFloatVectorValues(ByteVectorValues.fromBytes(vectors, dimension), docIdSet),
(sortMap != null) ? sortMap::oldToNew : doc -> doc);
}

case FLOAT32 -> {
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -197,6 +209,53 @@ public long ramBytesUsed() {
return rawVectorsWriter.ramBytesUsed();
}

/**
 * Read-only view that exposes a {@link ByteVectorValues} as {@link FloatVectorValues} by widening
 * each byte component to a {@code float}.
 *
 * <p>Iteration is driven by the optional {@link DocIdSet} when one is supplied; otherwise it is
 * delegated to the wrapped byte values.
 */
private static class ByteToFloatVectorValues extends FloatVectorValues {
  /** Source vectors; every lookup is delegated to this instance. */
  private final ByteVectorValues bytes;

  /** Documents that carry a vector, or {@code null} to iterate via {@link #bytes}. */
  private final DocIdSet docIdSet;

  /**
   * Wraps {@code bytes} and iterates with the wrapped values' own iterator.
   *
   * @param bytes byte vectors to expose as floats
   */
  public ByteToFloatVectorValues(ByteVectorValues bytes) {
    this(bytes, null);
  }

  /**
   * Wraps {@code bytes}, iterating over {@code docIdSet} when it is non-null.
   *
   * @param bytes byte vectors to expose as floats
   * @param docIdSet documents to iterate over, or {@code null} to use {@code bytes.iterator()}
   */
  public ByteToFloatVectorValues(ByteVectorValues bytes, DocIdSet docIdSet) {
    this.bytes = bytes;
    this.docIdSet = docIdSet;
  }

  /**
   * Returns the vector at {@code ord} as a freshly allocated {@code float[]}.
   *
   * @param ord ordinal passed through to the wrapped byte values
   * @return a new array holding each byte component widened to float
   * @throws IOException if the wrapped byte values fail to read the vector
   */
  @Override
  public float[] vectorValue(int ord) throws IOException {
    byte[] source = bytes.vectorValue(ord);
    float[] widened = new float[source.length];
    for (int i = 0; i < source.length; i++) {
      // Implicit byte -> float widening keeps the signed value unchanged.
      widened[i] = source[i];
    }
    return widened;
  }

  @Override
  public int dimension() {
    return bytes.dimension();
  }

  @Override
  public int size() {
    return bytes.size();
  }

  /** Copies the wrapped byte values; the doc id set reference is shared with the copy. */
  @Override
  public FloatVectorValues copy() throws IOException {
    return new ByteToFloatVectorValues(bytes.copy(), docIdSet);
  }

  /**
   * Returns an iterator over the doc id set when one was supplied, otherwise the wrapped byte
   * values' iterator.
   *
   * @throws java.io.UncheckedIOException if the doc id set fails to create its iterator
   */
  @Override
  public DocIndexIterator iterator() {
    if (docIdSet == null) {
      return bytes.iterator();
    }
    try {
      // DocIdSet#iterator declares a checked IOException that this override cannot propagate,
      // so it is rethrown unchecked with the original as its cause.
      return fromDISI(docIdSet.iterator());
    } catch (IOException e) {
      throw new java.io.UncheckedIOException(e);
    }
  }
}

private static class BufferedFloatVectorValues extends FloatVectorValues {
private final List<float[]> floats;
private final int dimension;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.Closeable;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
Expand Down Expand Up @@ -54,5 +55,5 @@ Index createIndex(
FloatVectorValues floatVectorValues,
IntToIntFunction oldToNewDocId);

Index readIndex(IndexInput input);
Index readIndex(IndexInput input, VectorSimilarityFunction function, VectorEncoding encoding);
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
import java.lang.invoke.MethodType;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
Expand Down Expand Up @@ -151,19 +151,6 @@ private static int functionToMetric(VectorSimilarityFunction function) {
return metric;
}

// Invert FUNCTION_TO_METRIC
private static final Map<Integer, VectorSimilarityFunction> METRIC_TO_FUNCTION =
FUNCTION_TO_METRIC.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));

private static VectorSimilarityFunction metricToFunction(int metric) {
VectorSimilarityFunction function = METRIC_TO_FUNCTION.get(metric);
if (function == null) {
throw new UnsupportedOperationException("Metric not supported: " + metric);
}
return function;
}

@Override
public FaissLibrary.Index createIndex(
String description,
Expand Down Expand Up @@ -229,7 +216,7 @@ public FaissLibrary.Index createIndex(
// Add docs to index
handleException(wrapper.faiss_Index_add_with_ids(indexPointer, size, docs, ids));

return new Index(indexPointer);
return new Index(indexPointer, function, VectorEncoding.FLOAT32);

} catch (IOException e) {
throw new UncheckedIOException(e);
Expand All @@ -241,7 +228,8 @@ public FaissLibrary.Index createIndex(
private static final int FAISS_IO_FLAG_READ_ONLY = 2;

@Override
public FaissLibrary.Index readIndex(IndexInput input) {
public FaissLibrary.Index readIndex(
IndexInput input, VectorSimilarityFunction function, VectorEncoding encoding) {
try (Arena temp = Arena.ofConfined()) {
MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input);
MemorySegment readerStub =
Expand All @@ -262,7 +250,7 @@ public FaissLibrary.Index readIndex(IndexInput input) {
customIOReaderPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY, pointer));
MemorySegment indexPointer = pointer.get(ADDRESS, 0);

return new Index(indexPointer);
return new Index(indexPointer, function, encoding);
}
}

Expand All @@ -277,29 +265,39 @@ private interface FloatToFloatFunction {
private final FloatToFloatFunction scaler;
private boolean closed;

private Index(MemorySegment indexPointer) {
private Index(
MemorySegment indexPointer, VectorSimilarityFunction function, VectorEncoding encoding) {
this.arena = Arena.ofShared();
this.indexPointer =
indexPointer
// Ensure timely cleanup
.reinterpret(arena, wrapper::faiss_Index_free);

// Get underlying function
int metricType = wrapper.faiss_Index_metric_type(indexPointer);
VectorSimilarityFunction function = metricToFunction(metricType);
int dimension = wrapper.faiss_Index_d(indexPointer);

// Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java
this.scaler =
switch (function) {
case DOT_PRODUCT ->
// distance in Faiss === dotProduct in Lucene
distance -> Math.max((1 + distance) / 2, 0);
case DOT_PRODUCT -> {
if (encoding == VectorEncoding.BYTE) {
float denom = (float) (dimension * (1 << 15));
yield distance -> 0.5f + distance / denom;
} else {
yield distance -> Math.max((1 + distance) / 2, 0);
}
}

case EUCLIDEAN ->
// distance in Faiss === squareDistance in Lucene
distance -> 1 / (1 + distance);

case COSINE, MAXIMUM_INNER_PRODUCT -> throw new AssertionError("Should not reach here");
case COSINE ->
// For COSINE, vectors are normalized so inner product == cosine similarity
distance -> Math.max((1 + distance) / 2, 0);

case MAXIMUM_INNER_PRODUCT ->
throw new UnsupportedOperationException(
"Similarity function not supported: " + function);
};

this.closed = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ int faiss_Index_metric_type(MemorySegment indexPointer) {
}
}

/** Handle for the native {@code faiss_Index_d} symbol: one address argument, {@code int} result. */
private final MethodHandle faiss_Index_d$MH =
    getHandle("faiss_Index_d", FunctionDescriptor.of(JAVA_INT, ADDRESS));

/**
 * Calls the native {@code faiss_Index_d} function for the given index pointer.
 *
 * <p>Unchecked exceptions and errors propagate unchanged; any other throwable raised by the
 * invocation is wrapped in an {@link AssertionError}.
 *
 * @param indexPointer address handed to the native function
 * @return the {@code int} produced by the native call
 */
int faiss_Index_d(MemorySegment indexPointer) {
  final int result;
  try {
    // invokeExact is signature-polymorphic: the (int) cast is part of the call-site type.
    result = (int) faiss_Index_d$MH.invokeExact(indexPointer);
  } catch (RuntimeException unchecked) {
    throw unchecked;
  } catch (Error error) {
    throw error;
  } catch (Throwable checked) {
    throw new AssertionError(checked);
  }
  return result;
}

private final MethodHandle faiss_Index_search$MH =
getHandle(
"faiss_Index_search",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.lucene.sandbox.codecs.faiss;

import static org.apache.lucene.index.VectorEncoding.BYTE;
import static org.apache.lucene.index.VectorEncoding.FLOAT32;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
Expand Down Expand Up @@ -44,7 +45,7 @@
public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
private static final String FAISS_RUN_TESTS = "tests.faiss.run";

private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32};
private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32, BYTE};
private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN};

@BeforeClass
Expand Down Expand Up @@ -92,30 +93,6 @@ public void testRecall() throws IOException {
@Ignore // does not honour visitedLimit
public void testSearchWithVisitedLimit() {}

@Override
@Ignore // does not support byte vectors
public void testByteVectorScorerIteration() {}

@Override
@Ignore // does not support byte vectors
public void testMismatchedFields() {}

@Override
@Ignore // does not support byte vectors
public void testSortedIndexBytes() {}

@Override
@Ignore // does not support byte vectors
public void testRandomBytes() {}

@Override
@Ignore // does not support byte vectors
public void testEmptyByteVectorData() {}

@Override
@Ignore // does not support byte vectors
public void testMergingWithDifferentByteKnnFields() {}

@Monster("Uses large amount of heap and RAM")
public void testLargeVectorData() throws IOException {
KnnVectorsFormat format =
Expand Down
Loading