Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make use of new stuff in the Redis DataSource API #16

Merged
merged 3 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
<maven.compiler.release>17</maven.compiler.release>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<quarkus.version>3.5.3</quarkus.version>
<quarkus.version>3.6.0</quarkus.version>
<langchain4j.version>0.24.0</langchain4j.version>
<quarkus-poi.version>2.0.4</quarkus-poi.version>
<assertj.version>3.24.2</assertj.version>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.jboss.logging.Logger;

Expand All @@ -30,10 +27,13 @@
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
import io.quarkus.redis.datasource.keys.KeyScanArgs;
import io.quarkus.redis.datasource.search.CreateArgs;
import io.quarkus.redis.datasource.search.Document;
import io.quarkus.redis.datasource.search.QueryArgs;
import io.quarkus.redis.datasource.search.SearchQueryResponse;
import io.smallrye.mutiny.Uni;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Request;
import io.vertx.mutiny.redis.client.Response;

public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {

Expand Down Expand Up @@ -65,19 +65,12 @@ private void createIndexIfDoesNotExist() {
}
}).await().indefinitely();
if (!indexes.contains(schema.getIndexName())) {
// TODO: rewrite to use the typesafe data source API
Request request = Request.cmd(Command.FT_CREATE)
.arg(schema.getIndexName())
.arg("ON")
.arg("JSON")
.arg("PREFIX")
.arg("1")
.arg(schema.getPrefix())
.arg("SCHEMA");
schema.defineFields(request);
LOG.debug(
"Creating index with command: " + request.toString().replaceAll("\r\n", " "));
ds.getRedis().send(request).await().indefinitely();
CreateArgs indexCreateArgs = new CreateArgs()
.onJson()
.prefixes(schema.getPrefix());
schema.defineFields(indexCreateArgs);
LOG.debug("Creating Redis index " + schema.getIndexName());
ds.search().ftCreate(schema.getIndexName(), indexCreateArgs).await().indefinitely();
} else {
LOG.debug("Index in Redis already exists: " + schema.getIndexName());
}
Expand Down Expand Up @@ -152,63 +145,28 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
double minScore) {
String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
String query = format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME);
// TODO: rewrite to the data source api, but we need a new
// method QueryArgs.param(String, byte[]) to get it working

// QueryArgs args = new QueryArgs()
// .sortByAscending(SCORE_FIELD_NAME)
// .param("DIALECT", "2")
// .param("BLOB", toByteArray(referenceEmbedding.vector()));
// Uni<SearchQueryResponse> search = ds.search()
// .ftSearch(schema.getIndexName(), query, args);
// SearchQueryResponse response = search.await().indefinitely();
Request request = Request.cmd(Command.FT_SEARCH)
.arg(schema.getIndexName())
.arg(query)
.arg("PARAMS")
.arg("2")
.arg("BLOB")
.arg(toByteArray(referenceEmbedding.vector()))
.arg("DIALECT")
.arg("2");
Response response = ds.getRedis().send(request).await().indefinitely();
return StreamSupport.stream(response.get("results").spliterator(), false)
.map(this::toEmbeddingMatch)
QueryArgs args = new QueryArgs()
.sortByAscending(SCORE_FIELD_NAME)
.param("DIALECT", "2")
.param("BLOB", referenceEmbedding.vector());
Uni<SearchQueryResponse> search = ds.search()
.ftSearch(schema.getIndexName(), query, args);
SearchQueryResponse response = search.await().indefinitely();
return response.documents().stream().map(this::extractEmbeddingMatch)
.filter(embeddingMatch -> embeddingMatch.score() >= minScore)
.collect(toList());
}

/**
* Deletes all keys with the prefix that is used by this embedding store.
*/
public void deleteAll() {
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
if (!keysToDelete.isEmpty()) {
Request command = Request.cmd(Command.DEL);
keysToDelete.forEach(command::arg);
ds.getRedis().send(command).await().indefinitely();
LOG.debug("Deleted " + keysToDelete.size() + " keys");
}
}

public static byte[] toByteArray(float[] input) {
byte[] bytes = new byte[Float.BYTES * input.length];
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
return bytes;
}

private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {
String document = response.get(EXTRA_ATTRIBUTES).get("$").toString();
private EmbeddingMatch<TextSegment> extractEmbeddingMatch(Document document) {
try {
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(document);
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
.readTree(document.property("$").asString());
JsonNode embedded = jsonNode.get(schema.getScalarFieldName());
Embedding embedding = new Embedding(
Json.fromJson(jsonNode.get(schema.getVectorFieldName()).toString(), float[].class));
double score = (2 - response.get(EXTRA_ATTRIBUTES).get(SCORE_FIELD_NAME).toDouble()) / 2;
String id = response.get(ID).toString().substring(schema.getPrefix().length());
List<String> metadataFields = schema.getMetadataFields();
Map<String, String> metadata = metadataFields.stream()
double score = (2 - document.property(SCORE_FIELD_NAME).asDouble()) / 2;
String id = document.key().substring(schema.getPrefix().length());
Map<String, String> metadata = schema.getMetadataFields().stream()
.filter(jsonNode::has)
.collect(Collectors.toMap(metadataFieldName -> metadataFieldName,
(name) -> jsonNode.get(name).asText()));
Expand All @@ -220,6 +178,20 @@ private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {

}

/**
* Deletes all keys with the prefix that is used by this embedding store.
*/
public void deleteAll() {
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
if (!keysToDelete.isEmpty()) {
Request command = Request.cmd(Command.DEL);
keysToDelete.forEach(command::arg);
ds.getRedis().send(command).await().indefinitely();
LOG.debug("Deleted " + keysToDelete.size() + " keys");
}
}

public static class Builder {

private ReactiveRedisDataSource redisClient;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import java.util.List;
import java.util.Optional;

import io.quarkiverse.langchain4j.redis.MetricType;
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
import io.quarkus.redis.datasource.search.DistanceMetric;
import io.quarkus.redis.datasource.search.VectorAlgorithm;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -51,7 +51,7 @@ public interface RedisEmbeddingStoreConfig {
* Metric used to compute the distance between two vectors.
*/
@WithDefault("COSINE")
MetricType metricType();
DistanceMetric distanceMetric();

/**
* Name of the key that will be used to store the embedding vector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public RedisEmbeddingStore apply(SyntheticCreationalContext<RedisEmbeddingStore>
.metadataFields(config.metadataFields().orElse(Collections.emptyList()))
.vectorAlgorithm(config.vectorAlgorithm())
.dimension(config.dimension())
.metricType(config.metricType())
.metricType(config.distanceMetric())
.build();
builder.schema(schema);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import java.util.List;

import io.quarkiverse.langchain4j.redis.MetricType;
import io.quarkiverse.langchain4j.redis.VectorAlgorithm;
import io.vertx.mutiny.redis.client.Request;
import io.quarkus.redis.datasource.search.CreateArgs;
import io.quarkus.redis.datasource.search.DistanceMetric;
import io.quarkus.redis.datasource.search.FieldOptions;
import io.quarkus.redis.datasource.search.FieldType;
import io.quarkus.redis.datasource.search.VectorAlgorithm;
import io.quarkus.redis.datasource.search.VectorType;

public class RedisSchema {

Expand All @@ -15,7 +18,7 @@ public class RedisSchema {
private List<String> metadataFields;
private VectorAlgorithm vectorAlgorithm;
private Long dimension;
private MetricType metricType;
private DistanceMetric distanceMetric;
private static final String JSON_PATH_PREFIX = "$.";

public RedisSchema(String indexName,
Expand All @@ -25,15 +28,15 @@ public RedisSchema(String indexName,
List<String> metadataFields,
VectorAlgorithm vectorAlgorithm,
Long dimension,
MetricType metricType) {
DistanceMetric distanceMetric) {
this.indexName = indexName;
this.prefix = prefix;
this.vectorFieldName = vectorFieldName;
this.scalarFieldName = scalarFieldName;
this.metadataFields = metadataFields;
this.vectorAlgorithm = vectorAlgorithm;
this.dimension = dimension;
this.metricType = metricType;
this.distanceMetric = distanceMetric;
}

public String getIndexName() {
Expand Down Expand Up @@ -64,51 +67,34 @@ public Long getDimension() {
return dimension;
}

public MetricType getMetricType() {
return metricType;
public DistanceMetric getDistanceMetric() {
return distanceMetric;
}

public void defineFields(Request args) {
public void defineFields(CreateArgs args) {
defineTextField(args);
defineVectorField(args);
defineMetadataFields(args);
}

private void defineMetadataFields(Request args) {
private void defineMetadataFields(CreateArgs args) {
for (String metadataField : metadataFields) {
args.arg(JSON_PATH_PREFIX + metadataField);
args.arg("AS");
args.arg(metadataField);
args.arg("TEXT");
args.arg("WEIGHT");
args.arg("1.0");
args.indexedField(JSON_PATH_PREFIX + metadataField, metadataField, FieldType.TEXT, new FieldOptions().weight(1.0));
}
}

private void defineTextField(Request args) {
args.arg(JSON_PATH_PREFIX + scalarFieldName);
args.arg("AS");
args.arg(scalarFieldName);
args.arg("TEXT");
args.arg("WEIGHT");
args.arg("1.0");
private void defineTextField(CreateArgs args) {
args.indexedField(JSON_PATH_PREFIX + scalarFieldName, scalarFieldName, FieldType.TEXT, new FieldOptions().weight(1.0));
}

private void defineVectorField(Request args) {
args.arg(JSON_PATH_PREFIX + vectorFieldName);
args.arg("AS");
args.arg(vectorFieldName);
args.arg("VECTOR");
args.arg(vectorAlgorithm.name());
args.arg("8");
args.arg("DIM");
args.arg(dimension);
args.arg("DISTANCE_METRIC");
args.arg(metricType.name());
args.arg("TYPE");
args.arg("FLOAT32");
args.arg("INITIAL_CAP");
args.arg("5");
private void defineVectorField(CreateArgs args) {
args.indexedField(JSON_PATH_PREFIX + vectorFieldName,
vectorFieldName,
FieldType.VECTOR, new FieldOptions()
.vectorAlgorithm(vectorAlgorithm)
.vectorType(VectorType.FLOAT32)
.dimension(dimension.intValue())
.distanceMetric(distanceMetric));
}

public static class Builder {
Expand All @@ -119,7 +105,7 @@ public static class Builder {
private List<String> metadataFields;
private VectorAlgorithm vectorAlgorithm;
private Long dimension;
private MetricType metricType;
private DistanceMetric metricType;

public Builder indexName(String indexName) {
this.indexName = indexName;
Expand Down Expand Up @@ -160,7 +146,7 @@ public Builder dimension(Long dimension) {
return this;
}

public Builder metricType(MetricType metricType) {
public Builder metricType(DistanceMetric metricType) {
this.metricType = metricType;
return this;
}
Expand Down
Loading