From 36f6a58cdfe473a281ce7f1e2fece4b981b9cd39 Mon Sep 17 00:00:00 2001
From: Jan Martiska
Date: Tue, 7 Nov 2023 12:37:37 +0100
Subject: [PATCH] Use the DataSource API for index searching in Redis

---
 .../redis/RedisEmbeddingStore.java            | 81 +++++++-----------
 1 file changed, 31 insertions(+), 50 deletions(-)

diff --git a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java
index 375acc6ce..ab26a65d5 100644
--- a/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java
+++ b/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java
@@ -12,7 +12,6 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
 
 import org.jboss.logging.Logger;
 
@@ -30,6 +29,9 @@
 import io.quarkus.redis.datasource.ReactiveRedisDataSource;
 import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
 import io.quarkus.redis.datasource.keys.KeyScanArgs;
+import io.quarkus.redis.datasource.search.Document;
+import io.quarkus.redis.datasource.search.QueryArgs;
+import io.quarkus.redis.datasource.search.SearchQueryResponse;
 import io.smallrye.mutiny.Uni;
 import io.vertx.mutiny.redis.client.Command;
 import io.vertx.mutiny.redis.client.Request;
@@ -152,63 +154,28 @@ public List> findRelevant(Embedding referenceEmbeddi
             double minScore) {
         String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
         String query = format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME);
-        // TODO: rewrite to the data source api, but we need a new
-        // method QueryArgs.param(String, byte[]) to get it working
-
-        // QueryArgs args = new QueryArgs()
-        //         .sortByAscending(SCORE_FIELD_NAME)
-        //         .param("DIALECT", "2")
-        //         .param("BLOB", toByteArray(referenceEmbedding.vector()));
-        // Uni search = ds.search()
-        //         .ftSearch(schema.getIndexName(), query, args);
-        // SearchQueryResponse response = search.await().indefinitely();
-        Request request = Request.cmd(Command.FT_SEARCH)
-                .arg(schema.getIndexName())
-                .arg(query)
-                .arg("PARAMS")
-                .arg("2")
-                .arg("BLOB")
-                .arg(toByteArray(referenceEmbedding.vector()))
-                .arg("DIALECT")
-                .arg("2");
-        Response response = ds.getRedis().send(request).await().indefinitely();
-        return StreamSupport.stream(response.get("results").spliterator(), false)
-                .map(this::toEmbeddingMatch)
+        QueryArgs args = new QueryArgs()
+                .sortByAscending(SCORE_FIELD_NAME)
+                .param("DIALECT", "2")
+                .param("BLOB", referenceEmbedding.vector());
+        Uni search = ds.search()
+                .ftSearch(schema.getIndexName(), query, args);
+        SearchQueryResponse response = search.await().indefinitely();
+        return response.documents().stream().map(this::extractEmbeddingMatch)
                 .filter(embeddingMatch -> embeddingMatch.score() >= minScore)
                 .collect(toList());
     }
 
-    /**
-     * Deletes all keys with the prefix that is used by this embedding store.
-     */
-    public void deleteAll() {
-        KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
-        Set keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
-        if (!keysToDelete.isEmpty()) {
-            Request command = Request.cmd(Command.DEL);
-            keysToDelete.forEach(command::arg);
-            ds.getRedis().send(command).await().indefinitely();
-            LOG.debug("Deleted " + keysToDelete.size() + " keys");
-        }
-    }
-
-    public static byte[] toByteArray(float[] input) {
-        byte[] bytes = new byte[Float.BYTES * input.length];
-        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
-        return bytes;
-    }
-
-    private EmbeddingMatch toEmbeddingMatch(Response response) {
-        String document = response.get(EXTRA_ATTRIBUTES).get("$").toString();
+    private EmbeddingMatch extractEmbeddingMatch(Document document) {
         try {
-            JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(document);
+            JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
+                    .readTree(document.property("$").asString());
             JsonNode embedded = jsonNode.get(schema.getScalarFieldName());
             Embedding embedding = new Embedding(
                     Json.fromJson(jsonNode.get(schema.getVectorFieldName()).toString(), float[].class));
-            double score = (2 - response.get(EXTRA_ATTRIBUTES).get(SCORE_FIELD_NAME).toDouble()) / 2;
-            String id = response.get(ID).toString().substring(schema.getPrefix().length());
-            List metadataFields = schema.getMetadataFields();
-            Map metadata = metadataFields.stream()
+            double score = (2 - document.property(SCORE_FIELD_NAME).asDouble()) / 2;
+            String id = document.key().substring(schema.getPrefix().length());
+            Map metadata = schema.getMetadataFields().stream()
                     .filter(jsonNode::has)
                     .collect(Collectors.toMap(metadataFieldName -> metadataFieldName,
                             (name) -> jsonNode.get(name).asText()));
@@ -220,6 +187,20 @@ private EmbeddingMatch toEmbeddingMatch(Response response) {
 
     }
 
+    /**
+     * Deletes all keys with the prefix that is used by this embedding store.
+     */
+    public void deleteAll() {
+        KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
+        Set keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
+        if (!keysToDelete.isEmpty()) {
+            Request command = Request.cmd(Command.DEL);
+            keysToDelete.forEach(command::arg);
+            ds.getRedis().send(command).await().indefinitely();
+            LOG.debug("Deleted " + keysToDelete.size() + " keys");
+        }
+    }
+
     public static class Builder {
 
         private ReactiveRedisDataSource redisClient;