Skip to content

Commit

Permalink
Google Java Format
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions committed Sep 26, 2023
1 parent d50d5ed commit e4a4b02
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ public String getSearchQuery() {
return searchQuery;
}


/**
 * Returns the value currently held in this object's {@code postgresLanguage} field.
 *
 * @return the stored {@code PostgresLanguage}, exactly as the field holds it
 */
public PostgresLanguage getPostgresLanguage() {
  return this.postgresLanguage;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ public EdgeChain<List<PostgresWordEmbeddings>> queryRRF(PostgresEndpoint postgre
try {
List<PostgresWordEmbeddings> wordEmbeddingsList = new ArrayList<>();
List<List<Float>> embeddings =
postgresEndpoint.getWordEmbeddingsList().stream()
.map(WordEmbeddings::getValues)
.toList();
postgresEndpoint.getWordEmbeddingsList().stream()
.map(WordEmbeddings::getValues)
.toList();

List<Map<String, Object>> rows =
this.repository.queryRRF(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,107 +241,110 @@ public List<Map<String, Object>> query(
}

public List<Map<String, Object>> queryRRF(
String tableName,
String namespace,
String metadataTableName,
List<List<Float>> values,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
String searchQuery,
PostgresLanguage language,
int probes,
PostgresDistanceMetric metric,
int topK,
OrderRRFBy orderRRFBy) {
String tableName,
String namespace,
String metadataTableName,
List<List<Float>> values,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
String searchQuery,
PostgresLanguage language,
int probes,
PostgresDistanceMetric metric,
int topK,
OrderRRFBy orderRRFBy) {

jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes));

StringBuilder query = new StringBuilder();

for(int i = 0; i < values.size(); i++) {
for (int i = 0; i < values.size(); i++) {
String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i)));

query .append("(")
.append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n")
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n",
textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n",
similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n",
dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight()))
.append("FROM ( ")
.append(
"SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp,"
+ " svtm.document_date, svtm.metadata, ")
.append(
String.format(
"ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ",
language.getValue(), searchQuery));
query
.append("(")
.append(
"SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n")
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n",
textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n",
similarityWeight.getBaseWeight().getValue(),
similarityWeight.getFineTuneWeight()))
.append(
String.format(
"%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n",
dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight()))
.append("FROM ( ")
.append(
"SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp,"
+ " svtm.document_date, svtm.metadata, ")
.append(
String.format(
"ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ",
language.getValue(), searchQuery));

switch (metric) {
case COSINE -> query.append(
String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings));
String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings));
case IP -> query.append(
String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings));
String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings));
case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings));
default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric);
}

query
.append("CASE ")
.append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling
.append(
"ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM"
+ " svtm.document_date) ")
.append("END AS date_rank ")
.append("FROM ")
.append(
String.format(
"(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE"
+ " namespace = '%s'",
tableName, namespace));
.append("CASE ")
.append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling
.append(
"ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM"
+ " svtm.document_date) ")
.append("END AS date_rank ")
.append("FROM ")
.append(
String.format(
"(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s"
+ " WHERE namespace = '%s'",
tableName, namespace));

switch (metric) {
case COSINE -> query
.append(" ORDER BY embedding <=> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
.append(" ORDER BY embedding <=> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
case IP -> query
.append(" ORDER BY embedding <#> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
.append(" ORDER BY embedding <#> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
case L2 -> query
.append(" ORDER BY embedding <-> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
.append(" ORDER BY embedding <-> ")
.append("'")
.append(embeddings)
.append("'")
.append(" LIMIT ")
.append(topK);
default -> throw new IllegalArgumentException("Invalid metric: " + metric);
}
query
.append(")")
.append(" sv ")
.append("JOIN ")
.append(tableName.concat("_join_").concat(metadataTableName))
.append(" jtm ON sv.id = jtm.id ")
.append("JOIN ")
.append(tableName.concat("_").concat(metadataTableName))
.append(" svtm ON jtm.metadata_id = svtm.metadata_id ")
.append(") subquery ");
.append(")")
.append(" sv ")
.append("JOIN ")
.append(tableName.concat("_join_").concat(metadataTableName))
.append(" jtm ON sv.id = jtm.id ")
.append("JOIN ")
.append(tableName.concat("_").concat(metadataTableName))
.append(" svtm ON jtm.metadata_id = svtm.metadata_id ")
.append(") subquery ");

switch (orderRRFBy) {
case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC");
Expand All @@ -355,12 +358,11 @@ public List<Map<String, Object>> queryRRF(
if (i < values.size() - 1) {
query.append(" UNION ALL ").append("\n");
}

}

if (values.size() > 1) {
return jdbcTemplate.queryForList(
String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query));
String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query));
} else {
return jdbcTemplate.queryForList(query.toString());
}
Expand Down

0 comments on commit e4a4b02

Please sign in to comment.