From e4a4b0205efe88829246c8c5a4668afb912cf9e0 Mon Sep 17 00:00:00 2001 From: github-actions <> Date: Tue, 26 Sep 2023 10:04:07 +0000 Subject: [PATCH] Google Java Format --- .../lib/endpoint/impl/PostgresEndpoint.java | 1 - .../lib/index/client/impl/PostgresClient.java | 6 +- .../PostgresClientRepository.java | 160 +++++++++--------- 3 files changed, 84 insertions(+), 83 deletions(-) diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java index 8393e2954..bd9510e8e 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java @@ -214,7 +214,6 @@ public String getSearchQuery() { return searchQuery; } - public PostgresLanguage getPostgresLanguage() { return postgresLanguage; } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java index fe7afc78b..810c99626 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java @@ -284,9 +284,9 @@ public EdgeChain> queryRRF(PostgresEndpoint postgre try { List wordEmbeddingsList = new ArrayList<>(); List> embeddings = - postgresEndpoint.getWordEmbeddingsList().stream() - .map(WordEmbeddings::getValues) - .toList(); + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); List> rows = this.repository.queryRRF( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java index 2217bd56e..7ff270b86 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java @@ -241,107 +241,110 @@ public List> query( } public List> queryRRF( - String tableName, - String namespace, - String metadataTableName, - List> values, - RRFWeight textWeight, - RRFWeight similarityWeight, - RRFWeight dateWeight, - String searchQuery, - PostgresLanguage language, - int probes, - PostgresDistanceMetric metric, - int topK, - OrderRRFBy orderRRFBy) { + String tableName, + String namespace, + String metadataTableName, + List> values, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + String searchQuery, + PostgresLanguage language, + int probes, + PostgresDistanceMetric metric, + int topK, + OrderRRFBy orderRRFBy) { jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); StringBuilder query = new StringBuilder(); - for(int i = 0; i < values.size(); i++) { + for (int i = 0; i < values.size(); i++) { String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); - query .append("(") - .append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", - textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", - similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", - dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) - .append("FROM ( ") - .append( - "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," - + " svtm.document_date, svtm.metadata, ") - .append( - String.format( - "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", - language.getValue(), searchQuery)); + query + .append("(") + .append( + "SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", + textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", + similarityWeight.getBaseWeight().getValue(), + similarityWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", + dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) + .append("FROM ( ") + .append( + "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," + + " svtm.document_date, svtm.metadata, ") + .append( + String.format( + "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", + language.getValue(), searchQuery)); switch (metric) { case COSINE -> query.append( - String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); + String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); case IP -> query.append( - String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); + String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings)); default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); } query - .append("CASE ") - .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling - .append( - "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" - + " svtm.document_date) ") - .append("END AS date_rank ") - .append("FROM ") - .append( - String.format( - "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE" - + " namespace = '%s'", - tableName, namespace)); + .append("CASE ") + .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling + .append( + "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" + + " svtm.document_date) ") + .append("END AS date_rank ") + .append("FROM ") + .append( + String.format( + "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s" + + " WHERE namespace = '%s'", + tableName, namespace)); switch (metric) { case COSINE -> query - .append(" ORDER BY embedding <=> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case IP -> query - .append(" ORDER BY embedding <#> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case L2 -> query - .append(" ORDER BY embedding <-> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); default -> throw new IllegalArgumentException("Invalid metric: " + metric); } query - .append(")") - .append(" sv ") - .append("JOIN ") - .append(tableName.concat("_join_").concat(metadataTableName)) - .append(" jtm ON sv.id = jtm.id ") - .append("JOIN ") - .append(tableName.concat("_").concat(metadataTableName)) - .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") - .append(") subquery "); + .append(")") + .append(" sv ") + .append("JOIN ") + .append(tableName.concat("_join_").concat(metadataTableName)) + .append(" jtm ON sv.id = jtm.id ") + .append("JOIN ") + .append(tableName.concat("_").concat(metadataTableName)) + .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") + .append(") subquery "); switch (orderRRFBy) { case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC"); @@ -355,12 +358,11 @@ public List> queryRRF( if (i < values.size() - 1) { query.append(" UNION ALL ").append("\n"); } - } if (values.size() > 1) { return jdbcTemplate.queryForList( - String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query)); + String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query)); } else { return jdbcTemplate.queryForList(query.toString()); }