Skip to content

Fix off-heap byte vector scoring at query time #14874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ Optimizations
* GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
* GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla)

* GITHUB#14874: Improve off-heap KNN byte vector query performance in cases where indexing and search are performed by the same process. (Kaival Parikh)

Bug Fixes
---------------------
* GITHUB#14049: Randomize KNN codec params in RandomCodec. Fixes scalar quantization div-by-zero
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer

final int vectorByteSize;
final MemorySegmentAccessInput input;
final MemorySegment query;
final byte[] query;
byte[] scratch;

/**
Expand Down Expand Up @@ -61,7 +61,7 @@ public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
super(values);
this.input = input;
this.vectorByteSize = values.getVectorByteLength();
this.query = MemorySegment.ofArray(queryVector);
this.query = queryVector;
}

final MemorySegment getSegment(int ord) throws IOException {
Expand Down Expand Up @@ -113,7 +113,7 @@ public float score(int node) throws IOException {
checkOrdinal(node);
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
return 0.5f + raw / (float) (query.length * (1 << 15));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,45 +309,99 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
// We also support 128 bit vectors, going 32 bits at a time.
// This is slower but still faster than not vectorizing at all.

/**
 * Abstraction over the two storage shapes a byte vector can have — an on-heap {@code byte[]}
 * (query vectors) and an off-heap {@code MemorySegment} (indexed vectors) — so the vectorized
 * distance kernels below can read from either side without copying into a scratch array.
 */
private interface ByteVectorLoader {
// Total number of bytes this loader can serve.
int length();

// Loads species.length() bytes starting at index as a ByteVector; index must leave
// enough room for a full species-width read.
ByteVector load(VectorSpecies<Byte> species, int index);

// Reads the single byte at index; used by the scalar tail loops after the
// vectorized portion has consumed the loop-bound prefix.
byte tail(int index);
}

/**
 * {@link ByteVectorLoader} backed by an on-heap {@code byte[]}; used for query vectors so they
 * no longer need to be wrapped in a heap {@code MemorySegment} before scoring.
 */
private record ArrayLoader(byte[] arr) implements ByteVectorLoader {
  @Override
  public int length() {
    return arr.length;
  }

  @Override
  public ByteVector load(VectorSpecies<Byte> species, int index) {
    // A full species-width read must fit entirely inside the array.
    assert index + species.length() <= length();
    return ByteVector.fromArray(species, arr, index);
  }

  @Override
  public byte tail(int index) {
    // Exclusive bound: arr[index] with index == length() would throw
    // ArrayIndexOutOfBoundsException, so `<=` was too weak here.
    assert index < length();
    return arr[index];
  }
}

/**
 * {@link ByteVectorLoader} backed by an off-heap {@code MemorySegment}; used for indexed
 * vectors read directly from the mapped index file.
 */
private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader {
  @Override
  public int length() {
    // Vectors are at most a few KB; toIntExact guards against an absurdly large segment
    // silently truncating instead of failing fast.
    return Math.toIntExact(segment.byteSize());
  }

  @Override
  public ByteVector load(VectorSpecies<Byte> species, int index) {
    // A full species-width read must fit entirely inside the segment.
    assert index + species.length() <= length();
    return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
  }

  @Override
  public byte tail(int index) {
    // Exclusive bound: a get at index == byteSize() is out of bounds, so `<=` was too weak.
    assert index < length();
    return segment.get(JAVA_BYTE, index);
  }
}

@Override
public int dotProduct(byte[] a, byte[] b) {
return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
}

/**
 * Dot product between an on-heap query vector and an off-heap stored vector.
 *
 * @param a query vector bytes
 * @param b stored vector, read directly from its memory segment
 * @return the signed integer dot product of the two vectors
 */
public static int dotProduct(byte[] a, MemorySegment b) {
  // Keep the precondition consistent with the MemorySegment/MemorySegment overload:
  // dotProductBody assumes both sides have equal length.
  assert a.length == b.byteSize();
  return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
}

/**
 * Dot product between two off-heap stored vectors of equal length.
 *
 * @param a first vector's memory segment
 * @param b second vector's memory segment
 * @return the signed integer dot product of the two vectors
 */
public static int dotProduct(MemorySegment a, MemorySegment b) {
  assert a.byteSize() == b.byteSize();
  ByteVectorLoader first = new MemorySegmentLoader(a);
  ByteVectorLoader second = new MemorySegmentLoader(b);
  return dotProductBody(first, second);
}

private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
assert a.length() == b.length();
int i = 0;
int res = 0;

// only vectorize if we'll at least enter the loop a single time
if (a.byteSize() >= 16) {
if (a.length() >= 16) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (VECTOR_BITSIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.byteSize());
i += BYTE_SPECIES.loopBound(a.length());
res += dotProductBody512(a, b, i);
} else if (VECTOR_BITSIZE == 256) {
i += BYTE_SPECIES.loopBound(a.byteSize());
i += BYTE_SPECIES.loopBound(a.length());
res += dotProductBody256(a, b, i);
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
res += dotProductBody128(a, b, i);
}
}

// scalar tail
for (; i < a.byteSize(); i++) {
res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
for (; i < a.length(); i++) {
res += a.tail(i) * b.tail(i);
}
return res;
}

/** vectorized dot product body (512 bit vectors) */
private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
private static int dotProductBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(BYTE_SPECIES, i);
ByteVector vb8 = b.load(BYTE_SPECIES, i);

// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Expand All @@ -363,11 +417,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
}

/** vectorized dot product body (256 bit vectors) */
private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
private static int dotProductBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);

// 32-bit multiply and add into accumulator
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
Expand All @@ -379,13 +433,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
}

/** vectorized dot product body (128 bit vectors) */
private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
// 4 bytes at a time (re-loading half the vector each time!)
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
// load 8 bytes
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);

// process first "half" only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
Expand Down Expand Up @@ -577,27 +631,35 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {

@Override
public float cosine(byte[] a, byte[] b) {
return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
return cosineBody(new ArrayLoader(a), new ArrayLoader(b));
}

/**
 * Cosine similarity between two off-heap stored vectors of equal length.
 *
 * @param a first vector's memory segment
 * @param b second vector's memory segment
 * @return the cosine similarity of the two vectors
 */
public static float cosine(MemorySegment a, MemorySegment b) {
  // Consistent with the dotProduct/squareDistance MemorySegment overloads:
  // cosineBody assumes both sides have equal length.
  assert a.byteSize() == b.byteSize();
  return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
}

/**
 * Cosine similarity between an on-heap query vector and an off-heap stored vector.
 *
 * @param a query vector bytes
 * @param b stored vector, read directly from its memory segment
 * @return the cosine similarity of the two vectors
 */
public static float cosine(byte[] a, MemorySegment b) {
  // cosineBody assumes both sides have equal length.
  assert a.length == b.byteSize();
  return cosineBody(new ArrayLoader(a), new MemorySegmentLoader(b));
}

private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) {
int i = 0;
int sum = 0;
int norm1 = 0;
int norm2 = 0;

// only vectorize if we'll at least enter the loop a single time
if (a.byteSize() >= 16) {
if (a.length() >= 16) {
final float[] ret;
if (VECTOR_BITSIZE >= 512) {
i += BYTE_SPECIES.loopBound((int) a.byteSize());
i += BYTE_SPECIES.loopBound(a.length());
ret = cosineBody512(a, b, i);
} else if (VECTOR_BITSIZE == 256) {
i += BYTE_SPECIES.loopBound((int) a.byteSize());
i += BYTE_SPECIES.loopBound(a.length());
ret = cosineBody256(a, b, i);
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
ret = cosineBody128(a, b, i);
}
sum += ret[0];
Expand All @@ -606,9 +668,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
}

// scalar tail
for (; i < a.byteSize(); i++) {
byte elem1 = a.get(JAVA_BYTE, i);
byte elem2 = b.get(JAVA_BYTE, i);
for (; i < a.length(); i++) {
byte elem1 = a.tail(i);
byte elem2 = b.tail(i);
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
Expand All @@ -617,13 +679,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
}

/** vectorized cosine body (512 bit vectors) */
private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
private static float[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector accSum = IntVector.zero(INT_SPECIES);
IntVector accNorm1 = IntVector.zero(INT_SPECIES);
IntVector accNorm2 = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(BYTE_SPECIES, i);
ByteVector vb8 = b.load(BYTE_SPECIES, i);

// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Expand All @@ -647,13 +709,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
}

/** vectorized cosine body (256 bit vectors) */
private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
private static float[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);

// 16-bit multiply, and add into accumulators
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
Expand All @@ -672,13 +734,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
}

/** vectorized cosine body (128 bit vectors) */
private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);

// process first half only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
Expand All @@ -700,39 +762,47 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit

@Override
public int squareDistance(byte[] a, byte[] b) {
return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
}

/**
 * Squared Euclidean distance between two off-heap stored vectors of equal length.
 *
 * @param a first vector's memory segment
 * @param b second vector's memory segment
 * @return the sum of squared element-wise differences
 */
public static int squareDistance(MemorySegment a, MemorySegment b) {
  assert a.byteSize() == b.byteSize();
  ByteVectorLoader lhs = new MemorySegmentLoader(a);
  ByteVectorLoader rhs = new MemorySegmentLoader(b);
  return squareDistanceBody(lhs, rhs);
}

/**
 * Squared Euclidean distance between an on-heap query vector and an off-heap stored vector.
 *
 * @param a query vector bytes
 * @param b stored vector, read directly from its memory segment
 * @return the sum of squared element-wise differences
 */
public static int squareDistance(byte[] a, MemorySegment b) {
  // Keep the precondition consistent with the MemorySegment/MemorySegment overload:
  // squareDistanceBody assumes both sides have equal length.
  assert a.length == b.byteSize();
  return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
}

private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
assert a.length() == b.length();
int i = 0;
int res = 0;

// only vectorize if we'll at least enter the loop a single time
if (a.byteSize() >= 16) {
if (a.length() >= 16) {
if (VECTOR_BITSIZE >= 256) {
i += BYTE_SPECIES.loopBound((int) a.byteSize());
i += BYTE_SPECIES.loopBound(a.length());
res += squareDistanceBody256(a, b, i);
} else {
i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
i += ByteVector.SPECIES_64.loopBound(a.length());
res += squareDistanceBody128(a, b, i);
}
}

// scalar tail
for (; i < a.byteSize(); i++) {
int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
for (; i < a.length(); i++) {
int diff = a.tail(i) - b.tail(i);
res += diff * diff;
}
return res;
}

/** vectorized square distance body (256+ bit vectors) */
private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
private static int squareDistanceBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(BYTE_SPECIES, i);
ByteVector vb8 = b.load(BYTE_SPECIES, i);

// 32-bit sub, multiply, and add into accumulators
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
Expand All @@ -746,14 +816,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
}

/** vectorized square distance body (128 bit vectors) */
private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
private static int squareDistanceBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
// 128-bit implementation, which must "split up" vectors due to widening conversions
// it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);

// 16-bit sub
Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);
Expand Down
Loading