Skip to content

Commit

Permalink
Use the DataSource API for index searching in Redis
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartisk committed Nov 14, 2023
1 parent 9c91db7 commit 36f6a58
Showing 1 changed file with 33 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.jboss.logging.Logger;

Expand All @@ -30,6 +29,9 @@
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
import io.quarkus.redis.datasource.keys.KeyScanArgs;
import io.quarkus.redis.datasource.search.Document;
import io.quarkus.redis.datasource.search.QueryArgs;
import io.quarkus.redis.datasource.search.SearchQueryResponse;
import io.smallrye.mutiny.Uni;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Request;
Expand Down Expand Up @@ -152,63 +154,30 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
double minScore) {
String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
String query = format(queryTemplate, maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME);
// TODO: rewrite to the data source api, but we need a new
// method QueryArgs.param(String, byte[]) to get it working

// QueryArgs args = new QueryArgs()
// .sortByAscending(SCORE_FIELD_NAME)
// .param("DIALECT", "2")
// .param("BLOB", toByteArray(referenceEmbedding.vector()));
// Uni<SearchQueryResponse> search = ds.search()
// .ftSearch(schema.getIndexName(), query, args);
// SearchQueryResponse response = search.await().indefinitely();
Request request = Request.cmd(Command.FT_SEARCH)
.arg(schema.getIndexName())
.arg(query)
.arg("PARAMS")
.arg("2")
.arg("BLOB")
.arg(toByteArray(referenceEmbedding.vector()))
.arg("DIALECT")
.arg("2");
Response response = ds.getRedis().send(request).await().indefinitely();
return StreamSupport.stream(response.get("results").spliterator(), false)
.map(this::toEmbeddingMatch)
QueryArgs args = new QueryArgs()
.sortByAscending(SCORE_FIELD_NAME)
.param("DIALECT", "2")
.param("BLOB", referenceEmbedding.vector());
Uni<SearchQueryResponse> search = ds.search()
.ftSearch(schema.getIndexName(), query, args);
System.out.println("ARGS = " + args.toArgs());
System.out.println("query = " + query);
SearchQueryResponse response = search.await().indefinitely();
return response.documents().stream().map(this::extractEmbeddingMatch)
.filter(embeddingMatch -> embeddingMatch.score() >= minScore)
.collect(toList());
}

/**
* Deletes all keys with the prefix that is used by this embedding store.
* <p>
* Scans for every key matching {@code schema.getPrefix() + "*"} and, when at least one key is found,
* removes all of them with a single {@code DEL} command. The method waits for both the scan and the
* delete to complete before returning.
*/
public void deleteAll() {
// Match every key that starts with this store's configured prefix.
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
// Collect the complete scan result into a set, waiting without a timeout for it to finish.
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
// Only issue DEL when something matched, so the command is never sent without key arguments.
if (!keysToDelete.isEmpty()) {
// Build one DEL request that carries every matched key as an argument.
Request command = Request.cmd(Command.DEL);
keysToDelete.forEach(command::arg);
ds.getRedis().send(command).await().indefinitely();
LOG.debug("Deleted " + keysToDelete.size() + " keys");
}
}

/**
* Serializes a float array into a byte array using little-endian byte order.
* <p>
* Each element occupies {@link Float#BYTES} bytes, so the result has length
* {@code Float.BYTES * input.length}; elements appear in the same order as in {@code input}.
*
* @param input the float values to encode
* @return a newly allocated byte array holding the little-endian encoding of {@code input}
*/
public static byte[] toByteArray(float[] input) {
byte[] bytes = new byte[Float.BYTES * input.length];
// The float view writes through to the wrapped array, so 'bytes' is filled in place.
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
return bytes;
}

private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {
String document = response.get(EXTRA_ATTRIBUTES).get("$").toString();
private EmbeddingMatch<TextSegment> extractEmbeddingMatch(Document document) {
try {
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(document);
JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
.readTree(document.property("$").asString());
JsonNode embedded = jsonNode.get(schema.getScalarFieldName());
Embedding embedding = new Embedding(
Json.fromJson(jsonNode.get(schema.getVectorFieldName()).toString(), float[].class));
double score = (2 - response.get(EXTRA_ATTRIBUTES).get(SCORE_FIELD_NAME).toDouble()) / 2;
String id = response.get(ID).toString().substring(schema.getPrefix().length());
List<String> metadataFields = schema.getMetadataFields();
Map<String, String> metadata = metadataFields.stream()
double score = (2 - document.property(SCORE_FIELD_NAME).asDouble()) / 2;
String id = document.key().substring(schema.getPrefix().length());
Map<String, String> metadata = schema.getMetadataFields().stream()
.filter(jsonNode::has)
.collect(Collectors.toMap(metadataFieldName -> metadataFieldName,
(name) -> jsonNode.get(name).asText()));
Expand All @@ -220,6 +189,20 @@ private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response response) {

}

/**
* Deletes all keys with the prefix that is used by this embedding store.
* <p>
* The keys are discovered with a scan for the pattern {@code schema.getPrefix() + "*"}. If the scan
* returns any keys, they are all deleted by one {@code DEL} command; if it returns none, no delete
* command is sent. The method returns only after the scan and, where applicable, the delete have
* completed.
*/
public void deleteAll() {
// Pattern covering every key under this store's prefix.
KeyScanArgs args = new KeyScanArgs().match(schema.getPrefix() + "*");
// Gather the full scan result as a set; this waits indefinitely for the scan to finish.
Set<String> keysToDelete = ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
// Skip the delete entirely when there is nothing to remove.
if (!keysToDelete.isEmpty()) {
// Append each matched key as an argument of a single DEL request.
Request command = Request.cmd(Command.DEL);
keysToDelete.forEach(command::arg);
ds.getRedis().send(command).await().indefinitely();
LOG.debug("Deleted " + keysToDelete.size() + " keys");
}
}

public static class Builder {

private ReactiveRedisDataSource redisClient;
Expand Down

0 comments on commit 36f6a58

Please sign in to comment.