Skip to content

Commit

Permalink
Use constant for min_score and max_distance
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 24, 2024
1 parent 44bd037 commit 5931e82
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,5 @@ public class KNNConstants {

public static final Float DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO = 0.95f;
public static final String MIN_SCORE = "min_score";
public static final String MAX_DISTANCE = "max_distance";
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
Expand All @@ -52,8 +54,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand Down
6 changes: 4 additions & 2 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand Down Expand Up @@ -1733,9 +1735,9 @@ private List<List<KNNResult>> validateRadiusSearchResults(
queryBuilder.startObject(fieldName);
queryBuilder.field("vector", queryVector);
if (distanceThreshold != null) {
queryBuilder.field("max_distance", distanceThreshold);
queryBuilder.field(MAX_DISTANCE, distanceThreshold);
} else if (scoreThreshold != null) {
queryBuilder.field("min_score", scoreThreshold);
queryBuilder.field(MIN_SCORE, scoreThreshold);
} else {
throw new IllegalArgumentException("Invalid threshold");
}
Expand Down
6 changes: 4 additions & 2 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

Expand Down Expand Up @@ -629,9 +631,9 @@ private void validateRadiusSearchResults(
builder.startObject(FIELD_NAME);
builder.field("vector", searchVectors[i]);
if (distanceThreshold != null) {
builder.field("max_distance", distanceThreshold);
builder.field(MAX_DISTANCE, distanceThreshold);
} else if (scoreThreshold != null) {
builder.field("min_score", scoreThreshold);
builder.field(MIN_SCORE, scoreThreshold);
} else {
throw new IllegalArgumentException("Either distance or score must be provided");
}
Expand Down

0 comments on commit 5931e82

Please sign in to comment.