Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Use the Lucene Distance Calculation Function in Script Scoring for doing exact search #1287

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.11...2.x)
### Features
* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1287](https://github.com/opensearch-project/k-NN/pull/1287)
* Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182)
### Enhancements
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Objects;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import org.apache.lucene.util.VectorUtil;

public class KNNScoringUtil {
private static Logger logger = LogManager.getLogger(KNNScoringUtil.class);
Expand Down Expand Up @@ -48,13 +49,7 @@
* @return L2 score
*/
public static float l2Squared(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float squaredDistance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
squaredDistance += diff * diff;
}
return squaredDistance;
return VectorUtil.squareDistance(queryVector, inputVector);
TrungBui59 marked this conversation as resolved.
Show resolved Hide resolved
}

private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
Expand Down Expand Up @@ -101,19 +96,16 @@
* @return cosine score
*/
public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
requireEqualDimension(queryVector, inputVector);
float dotProduct = 0.0f;
float normInputVector = 0.0f;
for (int i = 0; i < queryVector.length; i++) {
dotProduct += queryVector[i] * inputVector[i];
normInputVector += inputVector[i] * inputVector[i];
}
float normalizedProduct = normQueryVector * normInputVector;
if (normalizedProduct == 0) {
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
return 0.0f;
}
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
return (float) (VectorUtil.dotProduct(queryVector, inputVector) / (Math.sqrt(normalizedProduct)));
}

/**
Expand Down Expand Up @@ -150,20 +142,28 @@
*/
public static float cosinesimil(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float dotProduct = 0.0f;
float normQueryVector = 0.0f;
float normInputVector = 0.0f;
for (int i = 0; i < queryVector.length; i++) {
dotProduct += queryVector[i] * inputVector[i];
normQueryVector += queryVector[i] * queryVector[i];
normInputVector += inputVector[i] * inputVector[i];
int numZeroInInput = 0;
int numZeroInQuery = 0;
float cosine = 0.0f;
for (int i = 0; i < inputVector.length; i++) {
if (inputVector[i] == 0) {
numZeroInInput++;
}

if (queryVector[i] == 0) {
numZeroInQuery++;
}
}
float normalizedProduct = normQueryVector * normInputVector;
if (normalizedProduct == 0) {
if (numZeroInInput == inputVector.length || numZeroInQuery == queryVector.length) {
return cosine;
}
try {
cosine = VectorUtil.cosine(queryVector, inputVector);
} catch (Exception e) {

Check warning on line 162 in src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java#L162

Added line #L162 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid catching all exceptions here or catch a more specific exception?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 Sure I will change it

logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
return 0.0f;
}
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
return cosine;
}

/**
Expand Down Expand Up @@ -217,7 +217,6 @@
* @return L1 score
*/
public static float l1Norm(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
Expand Down Expand Up @@ -255,7 +254,6 @@
* @return L-inf score
*/
public static float lInfNorm(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
Expand Down Expand Up @@ -293,12 +291,7 @@
* @return dot product score
*/
public static float innerProduct(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
distance += queryVector[i] * inputVector[i];
}
return distance;
return VectorUtil.dotProduct(queryVector, inputVector);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.opensearch.knn.plugin.script;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
Expand All @@ -19,11 +22,12 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.store.Directory;

import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.plugin.script.KNNScoringUtilTests.TestKNNScriptDocValues;

public class KNNScoringUtilTests extends KNNTestCase {

Expand Down
Loading