diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/MetricType.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/MetricType.java deleted file mode 100644 index fdbdc88bf..000000000 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/MetricType.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.quarkiverse.langchain4j.redis; - -/** - * Similarity metric used by Redis - */ -public enum MetricType { - - /** - * cosine similarity - */ - COSINE, - - /** - * inner product - */ - IP, - - /** - * euclidean distance - */ - L2 -} diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java index ab26a65d5..77614ec85 100644 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java +++ b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java @@ -5,8 +5,6 @@ import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,13 +27,13 @@ import io.quarkus.redis.datasource.ReactiveRedisDataSource; import io.quarkus.redis.datasource.json.ReactiveJsonCommands; import io.quarkus.redis.datasource.keys.KeyScanArgs; +import io.quarkus.redis.datasource.search.CreateArgs; import io.quarkus.redis.datasource.search.Document; import io.quarkus.redis.datasource.search.QueryArgs; import io.quarkus.redis.datasource.search.SearchQueryResponse; import io.smallrye.mutiny.Uni; import io.vertx.mutiny.redis.client.Command; import io.vertx.mutiny.redis.client.Request; -import io.vertx.mutiny.redis.client.Response; public class RedisEmbeddingStore implements EmbeddingStore { @@ -67,19 +65,12 @@ private void createIndexIfDoesNotExist() { } }).await().indefinitely(); if (!indexes.contains(schema.getIndexName())) { - // TODO: rewrite to use the typesafe data source API - Request request = Request.cmd(Command.FT_CREATE) - .arg(schema.getIndexName()) - .arg("ON") - .arg("JSON") - .arg("PREFIX") - .arg("1") - .arg(schema.getPrefix()) - .arg("SCHEMA"); - schema.defineFields(request); - LOG.debug( - "Creating index with command: " + request.toString().replaceAll("\r\n", " ")); - ds.getRedis().send(request).await().indefinitely(); + CreateArgs indexCreateArgs = new CreateArgs() + .onJson() + .prefixes(schema.getPrefix()); + schema.defineFields(indexCreateArgs); + LOG.debug("Creating Redis index " + schema.getIndexName()); + ds.search().ftCreate(schema.getIndexName(), indexCreateArgs).await().indefinitely(); } else { LOG.debug("Index in Redis already exists: " + schema.getIndexName()); } @@ -160,8 +151,6 @@ public List> findRelevant(Embedding referenceEmbeddi .param("BLOB", referenceEmbedding.vector()); Uni search = ds.search() .ftSearch(schema.getIndexName(), query, args); - System.out.println("ARGS = " + args.toArgs()); - System.out.println("query = " + query); SearchQueryResponse response = search.await().indefinitely(); return response.documents().stream().map(this::extractEmbeddingMatch) .filter(embeddingMatch -> embeddingMatch.score() >= minScore) diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/VectorAlgorithm.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/VectorAlgorithm.java deleted file mode 100644 index b238e5ba4..000000000 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/VectorAlgorithm.java +++ /dev/null @@ -1,6 +0,0 @@ -package io.quarkiverse.langchain4j.redis; - -public enum VectorAlgorithm { - FLAT, - HNSW -} diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreConfig.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreConfig.java index e31c44b00..b4bfa0cc4 100644 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreConfig.java +++ b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreConfig.java @@ -5,8 +5,8 @@ import java.util.List; import java.util.Optional; -import io.quarkiverse.langchain4j.redis.MetricType; -import io.quarkiverse.langchain4j.redis.VectorAlgorithm; +import io.quarkus.redis.datasource.search.DistanceMetric; +import io.quarkus.redis.datasource.search.VectorAlgorithm; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; @@ -52,7 +52,7 @@ public interface RedisEmbeddingStoreConfig { * Metric used to compute the distance between two vectors. */ @WithDefault("COSINE") - MetricType metricType(); + DistanceMetric distanceMetric(); /** * Name of the key that will be used to store the embedding vector. diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreRecorder.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreRecorder.java index cc7287a8f..27d3d6300 100644 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreRecorder.java +++ b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisEmbeddingStoreRecorder.java @@ -40,7 +40,7 @@ public RedisEmbeddingStore apply(SyntheticCreationalContext .metadataFields(config.metadataFields().orElse(Collections.emptyList())) .vectorAlgorithm(config.vectorAlgorithm()) .dimension(config.dimension()) - .metricType(config.metricType()) + .metricType(config.distanceMetric()) .build(); builder.schema(schema); diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisSchema.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisSchema.java index 98f61bf6b..9154981fd 100644 --- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisSchema.java +++ b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/runtime/RedisSchema.java @@ -2,9 +2,12 @@ import java.util.List; -import io.quarkiverse.langchain4j.redis.MetricType; -import io.quarkiverse.langchain4j.redis.VectorAlgorithm; -import io.vertx.mutiny.redis.client.Request; +import io.quarkus.redis.datasource.search.CreateArgs; +import io.quarkus.redis.datasource.search.DistanceMetric; +import io.quarkus.redis.datasource.search.FieldOptions; +import io.quarkus.redis.datasource.search.FieldType; +import io.quarkus.redis.datasource.search.VectorAlgorithm; +import io.quarkus.redis.datasource.search.VectorType; public class RedisSchema { @@ -15,7 +18,7 @@ public class RedisSchema { private List metadataFields; private VectorAlgorithm vectorAlgorithm; private Long dimension; - private MetricType metricType; + private DistanceMetric distanceMetric; private static final String JSON_PATH_PREFIX = "$."; public RedisSchema(String indexName, @@ -25,7 +28,7 @@ public RedisSchema(String indexName, List metadataFields, VectorAlgorithm vectorAlgorithm, Long dimension, - MetricType metricType) { + DistanceMetric distanceMetric) { this.indexName = indexName; this.prefix = prefix; this.vectorFieldName = vectorFieldName; @@ -33,7 +36,7 @@ public RedisSchema(String indexName, this.metadataFields = metadataFields; this.vectorAlgorithm = vectorAlgorithm; this.dimension = dimension; - this.metricType = metricType; + this.distanceMetric = distanceMetric; } public String getIndexName() { @@ -64,51 +67,34 @@ public Long getDimension() { return dimension; } - public MetricType getMetricType() { - return metricType; + public DistanceMetric getDistanceMetric() { + return distanceMetric; } - public void defineFields(Request args) { + public void defineFields(CreateArgs args) { defineTextField(args); defineVectorField(args); defineMetadataFields(args); } - private void defineMetadataFields(Request args) { + private void defineMetadataFields(CreateArgs args) { for (String metadataField : metadataFields) { - args.arg(JSON_PATH_PREFIX + metadataField); - args.arg("AS"); - args.arg(metadataField); - args.arg("TEXT"); - args.arg("WEIGHT"); - args.arg("1.0"); + args.indexedField(JSON_PATH_PREFIX + metadataField, metadataField, FieldType.TEXT, new FieldOptions().weight(1.0)); } } - private void defineTextField(Request args) { - args.arg(JSON_PATH_PREFIX + scalarFieldName); - args.arg("AS"); - args.arg(scalarFieldName); - args.arg("TEXT"); - args.arg("WEIGHT"); - args.arg("1.0"); + private void defineTextField(CreateArgs args) { + args.indexedField(JSON_PATH_PREFIX + scalarFieldName, scalarFieldName, FieldType.TEXT, new FieldOptions().weight(1.0)); } - private void defineVectorField(Request args) { - args.arg(JSON_PATH_PREFIX + vectorFieldName); - args.arg("AS"); - args.arg(vectorFieldName); - args.arg("VECTOR"); - args.arg(vectorAlgorithm.name()); - args.arg("8"); - args.arg("DIM"); - args.arg(dimension); - args.arg("DISTANCE_METRIC"); - args.arg(metricType.name()); - args.arg("TYPE"); - args.arg("FLOAT32"); - args.arg("INITIAL_CAP"); - args.arg("5"); + private void defineVectorField(CreateArgs args) { + args.indexedField(JSON_PATH_PREFIX + vectorFieldName, + vectorFieldName, + FieldType.VECTOR, new FieldOptions() + .vectorAlgorithm(vectorAlgorithm) + .vectorType(VectorType.FLOAT32) + .dimension(dimension.intValue()) + .distanceMetric(distanceMetric)); } public static class Builder { @@ -119,7 +105,7 @@ public static class Builder { private List metadataFields; private VectorAlgorithm vectorAlgorithm; private Long dimension; - private MetricType metricType; + private DistanceMetric metricType; public Builder indexName(String indexName) { this.indexName = indexName; @@ -156,7 +142,7 @@ public Builder dimension(Long dimension) { return this; } - public Builder metricType(MetricType metricType) { + public Builder metricType(DistanceMetric metricType) { this.metricType = metricType; return this; }