Skip to content

Commit

Permalink
Query RRF With 'n' Embeddings (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmadHanif01 authored Sep 26, 2023
1 parent 21bf6d4 commit d50d5ed
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ public class PostgresEndpoint extends Endpoint {
private String documentDate;

/** RRF * */
private int upperLimit;

private RRFWeight textWeight;

private RRFWeight similarityWeight;
Expand Down Expand Up @@ -216,9 +214,6 @@ public String getSearchQuery() {
return searchQuery;
}

public int getUpperLimit() {
return upperLimit;
}

public PostgresLanguage getPostgresLanguage() {
return postgresLanguage;
Expand Down Expand Up @@ -321,7 +316,7 @@ public Observable<List<PostgresWordEmbeddings>> query(

public Observable<List<PostgresWordEmbeddings>> queryRRF(
String metadataTable,
WordEmbeddings wordEmbedding,
List<WordEmbeddings> wordEmbeddingsList,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
Expand All @@ -330,10 +325,9 @@ public Observable<List<PostgresWordEmbeddings>> queryRRF(
PostgresLanguage postgresLanguage,
int probes,
PostgresDistanceMetric metric,
int upperLimit,
int topK) {
this.metadataTableNames = List.of(metadataTable);
this.wordEmbedding = wordEmbedding;
this.wordEmbeddingsList = wordEmbeddingsList;
this.textWeight = textWeight;
this.similarityWeight = similarityWeight;
this.dateWeight = dateWeight;
Expand All @@ -342,7 +336,6 @@ public Observable<List<PostgresWordEmbeddings>> queryRRF(
this.postgresLanguage = postgresLanguage;
this.probes = probes;
this.metric = metric;
this.upperLimit = upperLimit;
this.topK = topK;
return Observable.fromSingle(this.postgresService.queryRRF(this));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,20 +283,24 @@ public EdgeChain<List<PostgresWordEmbeddings>> queryRRF(PostgresEndpoint postgre
emitter -> {
try {
List<PostgresWordEmbeddings> wordEmbeddingsList = new ArrayList<>();
List<List<Float>> embeddings =
postgresEndpoint.getWordEmbeddingsList().stream()
.map(WordEmbeddings::getValues)
.toList();

List<Map<String, Object>> rows =
this.repository.queryRRF(
postgresEndpoint.getTableName(),
getNamespace(postgresEndpoint),
postgresEndpoint.getMetadataTableNames().get(0),
postgresEndpoint.getWordEmbedding().getValues(),
embeddings,
postgresEndpoint.getTextWeight(),
postgresEndpoint.getSimilarityWeight(),
postgresEndpoint.getDateWeight(),
postgresEndpoint.getSearchQuery(),
postgresEndpoint.getPostgresLanguage(),
postgresEndpoint.getProbes(),
postgresEndpoint.getMetric(),
postgresEndpoint.getUpperLimit(),
postgresEndpoint.getTopK(),
postgresEndpoint.getOrderRRFBy());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,117 +241,129 @@ public List<Map<String, Object>> query(
}

public List<Map<String, Object>> queryRRF(
String tableName,
String namespace,
String metadataTableName,
List<Float> values,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
String searchQuery,
PostgresLanguage language,
int probes,
PostgresDistanceMetric metric,
int upperLimit,
int topK,
OrderRRFBy orderRRFBy) {
String tableName,
String namespace,
String metadataTableName,
List<List<Float>> values,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
String searchQuery,
PostgresLanguage language,
int probes,
PostgresDistanceMetric metric,
int topK,
OrderRRFBy orderRRFBy) {

jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes));

String embeddings = Arrays.toString(FloatUtils.toFloatArray(values));

StringBuilder query = new StringBuilder();
query
.append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n")
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n",
textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n",
similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n",
dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight()))
.append("FROM ( ")
.append(
"SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp,"
+ " svtm.document_date, svtm.metadata, ")
.append(
String.format(
"ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ",
language.getValue(), searchQuery));

switch (metric) {
case COSINE -> query.append(
String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings));
case IP -> query.append(
String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings));
case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings));
default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric);
}

query
.append("CASE ")
.append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling
.append(
"ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM"
+ " svtm.document_date) ")
.append("END AS date_rank ")
.append("FROM ")
.append(
String.format(
"(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE"
+ " namespace = '%s'",
tableName, namespace));

switch (metric) {
case COSINE -> query
.append(" ORDER BY embedding <=> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(upperLimit);
case IP -> query
.append(" ORDER BY embedding <#> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(upperLimit);
case L2 -> query
.append(" ORDER BY embedding <-> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(upperLimit);
default -> throw new IllegalArgumentException("Invalid metric: " + metric);
}
query
.append(")")
.append(" sv ")
.append("JOIN ")
.append(tableName.concat("_join_").concat(metadataTableName))
.append(" jtm ON sv.id = jtm.id ")
.append("JOIN ")
.append(tableName.concat("_").concat(metadataTableName))
.append(" svtm ON jtm.metadata_id = svtm.metadata_id ")
.append(") subquery ");

switch (orderRRFBy) {
case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC");
case SIMILARITY -> query.append("ORDER BY similarity DESC, rrf_score DESC");
case DATE_RANK -> query.append("ORDER BY date_rank DESC, rrf_score DESC");
case DEFAULT -> query.append("ORDER BY rrf_score DESC");
default -> throw new IllegalArgumentException("Invalid orderRRFBy value");
for(int i = 0; i < values.size(); i++) {
String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i)));

query .append("(")
.append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n")
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n",
textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n",
similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n",
dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight()))
.append("FROM ( ")
.append(
"SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp,"
+ " svtm.document_date, svtm.metadata, ")
.append(
String.format(
"ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ",
language.getValue(), searchQuery));

switch (metric) {
case COSINE -> query.append(
String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings));
case IP -> query.append(
String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings));
case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings));
default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric);
}

query
.append("CASE ")
.append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling
.append(
"ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM"
+ " svtm.document_date) ")
.append("END AS date_rank ")
.append("FROM ")
.append(
String.format(
"(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE"
+ " namespace = '%s'",
tableName, namespace));

switch (metric) {
case COSINE -> query
.append(" ORDER BY embedding <=> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
case IP -> query
.append(" ORDER BY embedding <#> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
case L2 -> query
.append(" ORDER BY embedding <-> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
default -> throw new IllegalArgumentException("Invalid metric: " + metric);
}
query
.append(")")
.append(" sv ")
.append("JOIN ")
.append(tableName.concat("_join_").concat(metadataTableName))
.append(" jtm ON sv.id = jtm.id ")
.append("JOIN ")
.append(tableName.concat("_").concat(metadataTableName))
.append(" svtm ON jtm.metadata_id = svtm.metadata_id ")
.append(") subquery ");

switch (orderRRFBy) {
case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC");
case SIMILARITY -> query.append("ORDER BY similarity DESC, rrf_score DESC");
case DATE_RANK -> query.append("ORDER BY date_rank DESC, rrf_score DESC");
case DEFAULT -> query.append("ORDER BY rrf_score DESC");
default -> throw new IllegalArgumentException("Invalid orderRRFBy value");
}

query.append(" LIMIT ").append(topK).append(")");
if (i < values.size() - 1) {
query.append(" UNION ALL ").append("\n");
}

}

query.append(" LIMIT ").append(topK).append(";");
return jdbcTemplate.queryForList(query.toString());
if (values.size() > 1) {
return jdbcTemplate.queryForList(
String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query));
} else {
return jdbcTemplate.queryForList(query.toString());
}
}

@Transactional(readOnly = true)
Expand Down

0 comments on commit d50d5ed

Please sign in to comment.