Skip to content

Commit

Permalink
Use the new vector-related enums in Quarkus and the DS api for creati…
Browse files Browse the repository at this point in the history
…ng the index
  • Loading branch information
jmartisk committed Nov 14, 2023
1 parent 36f6a58 commit db4ab37
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 90 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -29,13 +27,13 @@
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
import io.quarkus.redis.datasource.keys.KeyScanArgs;
import io.quarkus.redis.datasource.search.CreateArgs;
import io.quarkus.redis.datasource.search.Document;
import io.quarkus.redis.datasource.search.QueryArgs;
import io.quarkus.redis.datasource.search.SearchQueryResponse;
import io.smallrye.mutiny.Uni;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Request;
import io.vertx.mutiny.redis.client.Response;

public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {

Expand Down Expand Up @@ -67,19 +65,12 @@ private void createIndexIfDoesNotExist() {
}
}).await().indefinitely();
if (!indexes.contains(schema.getIndexName())) {
// TODO: rewrite to use the typesafe data source API
Request request = Request.cmd(Command.FT_CREATE)
.arg(schema.getIndexName())
.arg("ON")
.arg("JSON")
.arg("PREFIX")
.arg("1")
.arg(schema.getPrefix())
.arg("SCHEMA");
schema.defineFields(request);
LOG.debug(
"Creating index with command: " + request.toString().replaceAll("\r\n", " "));
ds.getRedis().send(request).await().indefinitely();
CreateArgs indexCreateArgs = new CreateArgs()
.onJson()
.prefixes(schema.getPrefix());
schema.defineFields(indexCreateArgs);
LOG.debug("Creating Redis index " + schema.getIndexName());
ds.search().ftCreate(schema.getIndexName(), indexCreateArgs).await().indefinitely();
} else {
LOG.debug("Index in Redis already exists: " + schema.getIndexName());
}
Expand Down Expand Up @@ -160,8 +151,6 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
.param("BLOB", referenceEmbedding.vector());
Uni<SearchQueryResponse> search = ds.search()
.ftSearch(schema.getIndexName(), query, args);
System.out.println("ARGS = " + args.toArgs());
System.out.println("query = " + query);
SearchQueryResponse response = search.await().indefinitely();
return response.documents().stream().map(this::extractEmbeddingMatch)
.filter(embeddingMatch -> embeddingMatch.score() >= minScore)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import java.util.List;
import java.util.Optional;

import io.quarkiverse.langchain4j.redis.MetricType;
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
import io.quarkus.redis.datasource.search.DistanceMetric;
import io.quarkus.redis.datasource.search.VectorAlgorithm;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -52,7 +52,7 @@ public interface RedisEmbeddingStoreConfig {
* Metric used to compute the distance between two vectors.
*/
@WithDefault("COSINE")
MetricType metricType();
DistanceMetric distanceMetric();

/**
* Name of the key that will be used to store the embedding vector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public RedisEmbeddingStore apply(SyntheticCreationalContext<RedisEmbeddingStore>
.metadataFields(config.metadataFields().orElse(Collections.emptyList()))
.vectorAlgorithm(config.vectorAlgorithm())
.dimension(config.dimension())
.metricType(config.metricType())
.metricType(config.distanceMetric())
.build();
builder.schema(schema);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import java.util.List;

import io.quarkiverse.langchain4j.redis.MetricType;
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
import io.vertx.mutiny.redis.client.Request;
import io.quarkus.redis.datasource.search.CreateArgs;
import io.quarkus.redis.datasource.search.DistanceMetric;
import io.quarkus.redis.datasource.search.FieldOptions;
import io.quarkus.redis.datasource.search.FieldType;
import io.quarkus.redis.datasource.search.VectorAlgorithm;
import io.quarkus.redis.datasource.search.VectorType;

public class RedisSchema {

Expand All @@ -15,7 +18,7 @@ public class RedisSchema {
private List<String> metadataFields;
private VectorAlgorithm vectorAlgorithm;
private Long dimension;
private MetricType metricType;
private DistanceMetric distanceMetric;
private static final String JSON_PATH_PREFIX = "$.";

public RedisSchema(String indexName,
Expand All @@ -25,15 +28,15 @@ public RedisSchema(String indexName,
List<String> metadataFields,
VectorAlgorithm vectorAlgorithm,
Long dimension,
MetricType metricType) {
DistanceMetric distanceMetric) {
this.indexName = indexName;
this.prefix = prefix;
this.vectorFieldName = vectorFieldName;
this.scalarFieldName = scalarFieldName;
this.metadataFields = metadataFields;
this.vectorAlgorithm = vectorAlgorithm;
this.dimension = dimension;
this.metricType = metricType;
this.distanceMetric = distanceMetric;
}

public String getIndexName() {
Expand Down Expand Up @@ -64,51 +67,34 @@ public Long getDimension() {
return dimension;
}

public MetricType getMetricType() {
return metricType;
public DistanceMetric getDistanceMetric() {
return distanceMetric;
}

public void defineFields(Request args) {
public void defineFields(CreateArgs args) {
defineTextField(args);
defineVectorField(args);
defineMetadataFields(args);
}

private void defineMetadataFields(Request args) {
private void defineMetadataFields(CreateArgs args) {
for (String metadataField : metadataFields) {
args.arg(JSON_PATH_PREFIX + metadataField);
args.arg("AS");
args.arg(metadataField);
args.arg("TEXT");
args.arg("WEIGHT");
args.arg("1.0");
args.indexedField(JSON_PATH_PREFIX + metadataField, metadataField, FieldType.TEXT, new FieldOptions().weight(1.0));
}
}

private void defineTextField(Request args) {
args.arg(JSON_PATH_PREFIX + scalarFieldName);
args.arg("AS");
args.arg(scalarFieldName);
args.arg("TEXT");
args.arg("WEIGHT");
args.arg("1.0");
private void defineTextField(CreateArgs args) {
args.indexedField(JSON_PATH_PREFIX + scalarFieldName, scalarFieldName, FieldType.TEXT, new FieldOptions().weight(1.0));
}

private void defineVectorField(Request args) {
args.arg(JSON_PATH_PREFIX + vectorFieldName);
args.arg("AS");
args.arg(vectorFieldName);
args.arg("VECTOR");
args.arg(vectorAlgorithm.name());
args.arg("8");
args.arg("DIM");
args.arg(dimension);
args.arg("DISTANCE_METRIC");
args.arg(metricType.name());
args.arg("TYPE");
args.arg("FLOAT32");
args.arg("INITIAL_CAP");
args.arg("5");
private void defineVectorField(CreateArgs args) {
args.indexedField(JSON_PATH_PREFIX + vectorFieldName,
vectorFieldName,
FieldType.VECTOR, new FieldOptions()
.vectorAlgorithm(vectorAlgorithm)
.vectorType(VectorType.FLOAT32)
.dimension(dimension.intValue())
.distanceMetric(distanceMetric));
}

public static class Builder {
Expand All @@ -119,7 +105,7 @@ public static class Builder {
private List<String> metadataFields;
private VectorAlgorithm vectorAlgorithm;
private Long dimension;
private MetricType metricType;
private DistanceMetric metricType;

public Builder indexName(String indexName) {
this.indexName = indexName;
Expand Down Expand Up @@ -156,7 +142,7 @@ public Builder dimension(Long dimension) {
return this;
}

public Builder metricType(MetricType metricType) {
public Builder metricType(DistanceMetric metricType) {
this.metricType = metricType;
return this;
}
Expand Down

0 comments on commit db4ab37

Please sign in to comment.