diff --git a/Examples/pinecone/PineconeExample.java b/Examples/pinecone/PineconeExample.java index 3563e84f9..bdc92396b 100644 --- a/Examples/pinecone/PineconeExample.java +++ b/Examples/pinecone/PineconeExample.java @@ -34,6 +34,7 @@ public class PineconeExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static final String PINECONE_AUTH_KEY = ""; private static final String PINECONE_QUERY_API = ""; private static final String PINECONE_UPSERT_API = ""; @@ -41,6 +42,7 @@ public class PineconeExample { private static OpenAiEndpoint ada002Embedding; private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiEndpoint gpt3StreamEndpoint; private static PineconeEndpoint upsertPineconeEndpoint; private static PineconeEndpoint queryPineconeEndpoint; @@ -64,7 +66,7 @@ public static void main(String[] args) { // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port", ""); + properties.setProperty("redis.port", "12285"); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -96,19 +98,40 @@ public static void main(String[] args) { 0.85, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + gpt3StreamEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + true, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + upsertPineconeEndpoint = new PineconeEndpoint( PINECONE_UPSERT_API, PINECONE_AUTH_KEY, + "machine-learning", // Passing namespace; read more on Pinecone documentation. 
You can + // pass empty string new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); queryPineconeEndpoint = new PineconeEndpoint( - PINECONE_QUERY_API, PINECONE_AUTH_KEY, new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + PINECONE_QUERY_API, + PINECONE_AUTH_KEY, + "machine-learning", // Passing namespace; read more on Pinecone documentation. You can + // pass empty string + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); deletePineconeEndpoint = new PineconeEndpoint( - PINECONE_DELETE, PINECONE_AUTH_KEY, new FixedDelay(4, 5, TimeUnit.SECONDS)); + PINECONE_DELETE, + PINECONE_AUTH_KEY, + "machine-learning", // Passing namespace; read more on Pinecone documentation. You can + // pass empty string + new FixedDelay(4, 5, TimeUnit.SECONDS)); contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); @@ -158,13 +181,8 @@ public class PineconeController { // Namespace is optional (if not provided, it will be using Empty String "") @PostMapping("/pinecone/upsert") // /v1/examples/openai/upsert?namespace=machine-learning public void upsertPinecone(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - // Configure Pinecone - upsertPineconeEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); /** @@ -181,13 +199,9 @@ public void upsertPinecone(ArkRequest arkRequest) throws IOException { @PostMapping(value = "/pinecone/query") public ArkResponse query(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - // Configure Pinecone - queryPineconeEndpoint.setNamespace(namespace); - // Step 1: Chain ==> Get Embeddings From Input & Then Query To Pinecone EdgeChain embeddingsChain = new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); @@ -216,15 
+230,8 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure Pinecone - queryPineconeEndpoint.setNamespace(namespace); - - // Configure GPT3endpoint - gpt3Endpoint.setStream(stream); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -262,16 +269,16 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "PineconeChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "PineconeChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -288,8 +295,13 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion(promptChain.get(), "PineconeChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response - and once GPT chain indicates that it is finished, we will save the following into Postgres + and once GPT chain indicates that it is finished, we will save the following into Redis Query(What is the collect stage for data maturity) + OpenAiResponse + Prev.
ChatHistory */ @@ -318,8 +330,6 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // Namespace is optional (if not provided, it will be using Empty String "") @DeleteMapping("/pinecone/deleteAll") public ArkResponse deletePinecone(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); - deletePineconeEndpoint.setNamespace(namespace); return new EdgeChain<>(deletePineconeEndpoint.deleteAll()).getArkResponse(); } diff --git a/Examples/postgresql/PostgreSQLExample.java b/Examples/postgresql/PostgreSQLExample.java index a966c1eef..0759abd07 100644 --- a/Examples/postgresql/PostgreSQLExample.java +++ b/Examples/postgresql/PostgreSQLExample.java @@ -9,6 +9,7 @@ import com.edgechain.lib.endpoint.impl.*; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -39,6 +40,7 @@ public class PostgreSQLExample { private static OpenAiEndpoint ada002Embedding; private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; @@ -85,8 +87,21 @@ public static void main(String[] args) { 0.85, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + gpt3StreamEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + true, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + + // Defining tablename and namespace... 
postgresEndpoint = - new PostgresEndpoint("spring_vectors", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + new PostgresEndpoint( + "pg_vectors", "machine-learning", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -137,19 +152,22 @@ public class PostgreSQLController { */ @PostMapping("/postgres/upsert") public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); String filename = arkRequest.getMultiPart("file").getSubmittedFileName(); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - postgresEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); PostgresRetrieval retrieval = - new PostgresRetrieval(arr, ada002Embedding, postgresEndpoint, 1536, filename, arkRequest); + new PostgresRetrieval( + arr, + ada002Embedding, + postgresEndpoint, + 1536, + filename, + PostgresLanguage.ENGLISH, + arkRequest); - // retrieval.setBatchSize(100); // Modifying batchSize....(Default is 50) + // retrieval.setBatchSize(50); // Modifying batchSize....(Default is 30) // Getting ids from upsertion... Internally, it automatically parallelizes the operation... 
List ids = retrieval.upsert(); @@ -162,12 +180,9 @@ public void upsert(ArkRequest arkRequest) throws IOException { @PostMapping(value = "/postgres/query") public ArkResponse query(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - postgresEndpoint.setNamespace(namespace); - // Chain 1==> Get Embeddings From Input & Then Query To PostgreSQL EdgeChain embeddingsChain = new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); @@ -195,15 +210,9 @@ public ArkResponse chat(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure PostgresEndpoint - postgresEndpoint.setNamespace(namespace); - - gpt3Endpoint.setStream(stream); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -242,16 +251,16 @@ public ArkResponse chat(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -268,6 +277,12 @@ public ArkResponse chat(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "PostgresChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Postgres Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Examples/redis/RedisExample.java b/Examples/redis/RedisExample.java index f3cb1bf9e..a9bfdce0c 100644 --- a/Examples/redis/RedisExample.java +++ b/Examples/redis/RedisExample.java @@ -37,6 +37,9 @@ public class RedisExample { private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID private static OpenAiEndpoint ada002Embedding; private static OpenAiEndpoint gpt3Endpoint; + + private static OpenAiEndpoint gpt3StreamEndpoint; + private static RedisEndpoint redisEndpoint; private static RedisHistoryContextEndpoint contextEndpoint; @@ -62,7 +65,7 @@ public static void main(String[] args) { // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port", "12885"); + properties.setProperty("redis.port", "12285"); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -88,8 +91,20 @@ public static void main(String[] args) { 0.85, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + gpt3StreamEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + 
true, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + redisEndpoint = - new RedisEndpoint("vector_index", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + new RedisEndpoint( + "vector_index", "machine-learning", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); } @@ -124,20 +139,15 @@ public class RedisController { /********************** REDIS WITH OPENAI ****************************/ // Namespace is optional (if not provided, it will be using namespace will be "knowledge") + /** + * Both IndexName & namespace are integral for upsert & performing similarity search; If you are + * creating different namespace; recommended to use different index_name because filtering is + * done by index_name * + */ @PostMapping("/redis/upsert") // /v1/examples/openai/upsert?namespace=machine-learning public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - /** - * Both IndexName & namespace are integral for upsert & performing similarity search; If you - * are creating different namespace; recommended to use different index_name because filtering - * is done by index_name * - */ - // Configure RedisEndpoint - redisEndpoint.setNamespace(namespace); - /** * We have two implementation for Read By Sentence: a) readBySentence(LangType, Your File) * EdgeChains sdk has predefined support to chunk by sentences w.r.t to 5 languages (english, @@ -166,12 +176,9 @@ public void upsert(ArkRequest arkRequest) throws IOException { @PostMapping(value = "/redis/similarity-search") public ArkResponse similaritySearch(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - redisEndpoint.setNamespace(namespace); - // Chain 1 
==> Generate Embeddings Using Ada002 EdgeChain ada002Chain = new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); @@ -186,12 +193,9 @@ public ArkResponse similaritySearch(ArkRequest arkRequest) { @PostMapping(value = "/redis/query") public ArkResponse queryRedis(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - redisEndpoint.setNamespace(namespace); - // Chain 1==> Get Embeddings From Input & Then Query To Redis EdgeChain embeddingsChain = new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); @@ -213,14 +217,8 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // configure GPT3Endpoint - gpt3Endpoint.setStream(stream); - - redisEndpoint.setNamespace(namespace); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -258,16 +256,16 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -284,6 +282,11 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Redis Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Examples/supabase-miniLM/SupabaseMiniLMExample.java b/Examples/supabase-miniLM/SupabaseMiniLMExample.java index 0bc2d8830..de0135ffa 100644 --- a/Examples/supabase-miniLM/SupabaseMiniLMExample.java +++ b/Examples/supabase-miniLM/SupabaseMiniLMExample.java @@ -7,6 +7,7 @@ import com.edgechain.lib.endpoint.impl.*; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -39,7 +40,9 @@ public class SupabaseMiniLMExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; @@ -59,6 +62,9 @@ public static void main(String[] args) { properties.setProperty("supabase.url", ""); 
properties.setProperty("supabase.annon.key", ""); + // For JWT decode + properties.setProperty("jwt.secret", ""); + // Adding Cors ==> You can configure multiple cors w.r.t your urls.; properties.setProperty("cors.origins", "http://localhost:4200"); @@ -68,12 +74,9 @@ public static void main(String[] args) { // For DB config properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", "postgres"); + properties.setProperty("postgres.db.username", ""); properties.setProperty("postgres.db.password", ""); - // For JWT decode - properties.setProperty("jwt.secret", ""); - new SpringApplicationBuilder(SupabaseMiniLMExample.class).properties(properties).run(args); gpt3Endpoint = @@ -86,6 +89,17 @@ public static void main(String[] args) { 0.85, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + gpt3StreamEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + true, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + // Creating MiniLM Endpoint // When endpoint.embeddings() is called; it will look for the model; if not available, it will // download on fly. 
@@ -96,7 +110,8 @@ public static void main(String[] args) { // Creating PostgresEndpoint ==> We create a new table because miniLM supports 384 dimensional // vectors; postgresEndpoint = - new PostgresEndpoint("minilm_vectors", new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); + new PostgresEndpoint( + "minilm_vectors", "minilm-ns", new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -149,19 +164,22 @@ public class SupabaseController { @PostMapping("/miniLM/upsert") @PreAuthorize("hasAnyAuthority('authenticated')") public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); String filename = arkRequest.getMultiPart("file").getSubmittedFileName(); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - postgresEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); PostgresRetrieval retrieval = - new PostgresRetrieval(arr, miniLMEndpoint, postgresEndpoint, 1536, filename, arkRequest); + new PostgresRetrieval( + arr, + miniLMEndpoint, + postgresEndpoint, + 384, + filename, + PostgresLanguage.ENGLISH, + arkRequest); - // retrieval.setBatchSize(100); // Modifying batchSize.... + // retrieval.setBatchSize(50); // Modifying batchSize.... // Getting ids from upsertion... Internally, it automatically parallelizes the operation... 
List ids = retrieval.upsert(); @@ -175,12 +193,9 @@ public void upsert(ArkRequest arkRequest) throws IOException { @PreAuthorize("hasAnyAuthority('authenticated')") public ArkResponse queryPostgres(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - postgresEndpoint.setNamespace(namespace); - // Chain 1==> Get Embeddings From Input using MiniLM & Then Query To PostgreSQL EdgeChain embeddingsChain = new EdgeChain<>(miniLMEndpoint.embeddings(query, arkRequest)); @@ -205,15 +220,9 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure PostgresEndpoint - postgresEndpoint.setNamespace(namespace); - - gpt3Endpoint.setStream(stream); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -252,17 +261,17 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion( - promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion( + promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -279,6 +288,12 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Postgres Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Examples/wiki/WikiExample.java b/Examples/wiki/WikiExample.java index b5f169bc9..3bebb8dbd 100644 --- a/Examples/wiki/WikiExample.java +++ b/Examples/wiki/WikiExample.java @@ -27,11 +27,14 @@ @SpringBootApplication public class WikiExample { - private static final String OPENAI_AUTH_KEY = ""; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID /* Step 3: Create OpenAiEndpoint to communicate with OpenAiServices; */ private static OpenAiEndpoint gpt3Endpoint; + + private static OpenAiEndpoint gpt3StreamEndpoint; + private static WikiEndpoint wikiEndpoint; // There is a 70% chance that file1 is executed; 30% chance file2 is executed.... 
@@ -66,7 +69,17 @@ public static void main(String[] args) { "gpt-3.5-turbo", "user", 0.7, - false, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + + gpt3StreamEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.7, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); } @@ -85,20 +98,11 @@ public ArkResponse wikiSummary(ArkRequest arkRequest) { String query = arkRequest.getQueryParam("query"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure GPT4Endpoint - gpt3Endpoint.setStream(stream); - // Chain 1 ==> WikiChain EdgeChain wikiChain = new EdgeChain<>(wikiEndpoint.getPageContent(query)); // Chain 2 ===> Creating Prompt Chain & Return ChatCompletion EdgeChain promptChain = wikiChain.transform(this::fn); - - // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable - EdgeChain openAiChain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "WikiChain", loader, arkRequest)); - /** * The best part is flexibility with just one method EdgeChainsSDK will return response either * in json or stream; The real magic happens here. Streaming happens only if your logic allows @@ -107,8 +111,25 @@ public ArkResponse wikiSummary(ArkRequest arkRequest) { // Note: When you call getArkResponse() or getArkStreamResponse() ==> Only then your streams // are executed... 
- if (stream) return openAiChain.getArkStreamResponse(); - else return openAiChain.getArkResponse(); + if (stream) { + + // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable + EdgeChain openAiChain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "WikiChain", loader, arkRequest)); + + return openAiChain.getArkStreamResponse(); + + } else { + + // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable + EdgeChain openAiChain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "WikiChain", loader, arkRequest)); + + return openAiChain.getArkResponse(); + } } private String fn(WikiResponse wiki) { diff --git a/FlySpring/edgechain-app/.gitignore b/FlySpring/edgechain-app/.gitignore index 848775e32..3f969a6bf 100644 --- a/FlySpring/edgechain-app/.gitignore +++ b/FlySpring/edgechain-app/.gitignore @@ -45,3 +45,4 @@ build/ /src/main/java/com/edgechain/HydeExample.java /model/ +/src/main/java/com/edgechain/SupabaseMiniLMExample.java diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java index 5d9401c8d..8f89a7717 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java @@ -7,6 +7,7 @@ import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; import com.edgechain.lib.endpoint.impl.PostgresEndpoint; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Completable; @@ -30,6 +31,8 @@ public class PostgresRetrieval { private final String filename; + private final PostgresLanguage postgresLanguage; + private final ArkRequest arkRequest; 
private final PostgresEndpoint postgresEndpoint; @@ -47,12 +50,14 @@ public PostgresRetrieval( PostgresDistanceMetric metric, int lists, String filename, + PostgresLanguage postgresLanguage, ArkRequest arkRequest) { this.arr = arr; this.filename = filename; - this.arkRequest = arkRequest; this.postgresEndpoint = postgresEndpoint; this.embeddingEndpoint = embeddingEndpoint; + this.postgresLanguage = postgresLanguage; + this.arkRequest = arkRequest; this.dimensions = dimensions; this.metric = metric; @@ -72,10 +77,12 @@ public PostgresRetrieval( PostgresEndpoint postgresEndpoint, int dimensions, String filename, + PostgresLanguage postgresLanguage, ArkRequest arkRequest) { this.arr = arr; this.filename = filename; this.arkRequest = arkRequest; + this.postgresLanguage = postgresLanguage; this.postgresEndpoint = postgresEndpoint; this.embeddingEndpoint = embeddingEndpoint; @@ -129,15 +136,15 @@ private void upsertAndCollectIds( } private List executeBatchUpsert(List wordEmbeddingsList) { - return this.postgresEndpoint.upsert(wordEmbeddingsList, filename).stream() + return this.postgresEndpoint.upsert(wordEmbeddingsList, filename, postgresLanguage).stream() .map(StringResponse::getResponse) .collect(Collectors.toList()); } - public List insertMetadata() { + public List insertMetadata(String metadataTableName) { // Create Table... - this.postgresEndpoint.createMetadataTable(); + this.postgresEndpoint.createMetadataTable(metadataTableName); ConcurrentLinkedQueue uuidQueue = new ConcurrentLinkedQueue<>(); @@ -154,10 +161,11 @@ public List insertMetadata() { return new ArrayList<>(uuidQueue); } - public StringResponse insertOneMetadata(String metadata, String documentDate) { + public StringResponse insertOneMetadata( + String metadataTableName, String metadata, String documentDate) { // Create Table... 
- this.postgresEndpoint.createMetadataTable(); - return this.postgresEndpoint.insertMetadata(metadata, documentDate); + this.postgresEndpoint.createMetadataTable(metadataTableName); + return this.postgresEndpoint.insertMetadata(metadataTableName, metadata, documentDate); } private void insertMetadataAndCollectIds( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java index 166bc7b6f..bae87f46f 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java @@ -27,13 +27,8 @@ private PostgresEndpoint getInstance() { @DeleteMapping("/deleteAll") public StringResponse deletePostgres(ArkRequest arkRequest) { - String table = arkRequest.getQueryParam("table"); String namespace = arkRequest.getQueryParam("namespace"); - - getInstance().setTableName(table); - getInstance().setNamespace(namespace); - - return getInstance().deleteAll(); + return getInstance().deleteAll(arkRequest.getQueryParam("table"), namespace); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java index bac2cb873..a26b5034e 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java @@ -81,16 +81,13 @@ public BgeSmallEndpoint(RetryPolicy retryPolicy, String modelUrl, String tokeniz @Override public Observable embeddings(String input, ArkRequest arkRequest) { - - final String str = input.replaceAll("'", ""); - - setRawText(str); + setRawText(input); if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); else this.callIdentifier = "URI 
wasn't provided"; return Observable.fromSingle( - bgeSmallService.embeddings(this).map(m -> new WordEmbeddings(str, m.getEmbedding()))); + bgeSmallService.embeddings(this).map(m -> new WordEmbeddings(input, m.getEmbedding()))); } private void downloadFile(String urlStr, String path) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java index 6c6648b95..8876cbfff 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java @@ -46,10 +46,7 @@ public MiniLMEndpoint(RetryPolicy retryPolicy, MiniLMModel miniLMModel) { @Override public Observable embeddings(String input, ArkRequest arkRequest) { - - final String str = input.replaceAll("'", ""); - - setRawText(str); + setRawText(input); if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); else this.callIdentifier = "URI wasn't provided"; @@ -59,6 +56,6 @@ public Observable embeddings(String input, ArkRequest arkRequest } return Observable.fromSingle( - miniLMService.embeddings(this).map(m -> new WordEmbeddings(str, m.getEmbedding()))); + miniLMService.embeddings(this).map(m -> new WordEmbeddings(input, m.getEmbedding()))); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java index 82e7ed15d..831a95636 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java @@ -300,11 +300,7 @@ public Observable chatCompletion( @Override public Observable embeddings(String input, ArkRequest arkRequest) { - // ?this.input = input; // set Input - - 
final String str = input.replaceAll("'", ""); - - setRawText(str); + setRawText(input); if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); else this.callIdentifier = "URI wasn't provided"; @@ -314,7 +310,7 @@ public Observable embeddings(String input, ArkRequest arkRequest .embeddings(this) .map( embeddingResponse -> - new WordEmbeddings(str, embeddingResponse.getData().get(0).getEmbedding()))); + new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); } private Observable chatCompletion(ArkRequest arkRequest) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java index e5c7a4bd8..1f573d236 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java @@ -27,10 +27,6 @@ public class PineconeEndpoint extends Endpoint { public PineconeEndpoint() {} - public PineconeEndpoint(String namespace) { - this.namespace = namespace; - } - public PineconeEndpoint(String url, String apiKey) { super(url, apiKey); } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java index 5d64f978e..6b1e3c960 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java @@ -3,7 +3,10 @@ import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import com.edgechain.lib.index.domain.RRFWeight; +import com.edgechain.lib.index.enums.OrderRRFBy; import 
com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.retrofit.PostgresService; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; @@ -44,6 +47,19 @@ public class PostgresEndpoint extends Endpoint { private List metadataList; private String documentDate; + /** RRF * */ + private int upperLimit; + + private RRFWeight textWeight; + + private RRFWeight similarityWeight; + private RRFWeight dateWeight; + + private OrderRRFBy orderRRFBy; + private String searchQuery; + + private PostgresLanguage postgresLanguage; + public PostgresEndpoint() {} public PostgresEndpoint(RetryPolicy retryPolicy) { @@ -54,6 +70,11 @@ public PostgresEndpoint(String tableName) { this.tableName = tableName; } + public PostgresEndpoint(String tableName, String namespace) { + this.tableName = tableName; + this.namespace = namespace; + } + public PostgresEndpoint(String tableName, List metadataTableNames) { this.tableName = tableName; this.metadataTableNames = metadataTableNames; @@ -64,6 +85,12 @@ public PostgresEndpoint(String tableName, RetryPolicy retryPolicy) { this.tableName = tableName; } + public PostgresEndpoint(String tableName, String namespace, RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.namespace = namespace; + } + public String getTableName() { return tableName; } @@ -166,6 +193,34 @@ public String getDocumentDate() { return documentDate; } + public RRFWeight getTextWeight() { + return textWeight; + } + + public RRFWeight getSimilarityWeight() { + return similarityWeight; + } + + public RRFWeight getDateWeight() { + return dateWeight; + } + + public OrderRRFBy getOrderRRFBy() { + return orderRRFBy; + } + + public String getSearchQuery() { + return searchQuery; + } + + public int getUpperLimit() { + return upperLimit; + } + + public PostgresLanguage getPostgresLanguage() { + return postgresLanguage; + 
} + public StringResponse upsert( WordEmbeddings wordEmbeddings, String filename, @@ -185,19 +240,24 @@ public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, return this.postgresService.createTable(this).blockingGet(); } - public StringResponse createMetadataTable() { + public StringResponse createMetadataTable(String metadataTableName) { + this.metadataTableNames = List.of(metadataTableName); return this.postgresService.createMetadataTable(this).blockingGet(); } - public List upsert(List wordEmbeddingsList, String filename) { + public List upsert( + List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { this.wordEmbeddingsList = wordEmbeddingsList; this.filename = filename; + this.postgresLanguage = postgresLanguage; return this.postgresService.batchUpsert(this).blockingGet(); } - public StringResponse insertMetadata(String metadata, String documentDate) { + public StringResponse insertMetadata( + String metadataTableName, String metadata, String documentDate) { this.metadata = metadata; this.documentDate = documentDate; + this.metadataTableNames = List.of(metadataTableName); return this.postgresService.insertMetadata(this).blockingGet(); } @@ -206,15 +266,17 @@ public List batchInsertMetadata(List metadataList) { return this.postgresService.batchInsertMetadata(this).blockingGet(); } - public StringResponse insertIntoJoinTable(String id, String metadataId) { + public StringResponse insertIntoJoinTable( + String metadataTableName, String id, String metadataId) { this.id = id; this.metadataId = metadataId; + this.metadataTableNames = List.of(metadataTableName); return this.postgresService.insertIntoJoinTable(this).blockingGet(); } public Observable> query( WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) { - this.wordEmbedding = wordEmbeddings; + this.wordEmbeddingsList = List.of(wordEmbeddings); this.topK = topK; this.metric = metric; this.probes = 1; @@ -223,15 +285,59 @@ public Observable> 
query( public Observable> query( WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK, int probes) { - this.wordEmbedding = wordEmbeddings; + this.wordEmbeddingsList = List.of(wordEmbeddings); this.topK = topK; this.metric = metric; this.probes = probes; return Observable.fromSingle(this.postgresService.query(this)); } + public Observable> query( + List wordEmbeddingsList, + PostgresDistanceMetric metric, + int topK, + int probes) { + this.wordEmbeddingsList = wordEmbeddingsList; + this.metric = metric; + this.probes = probes; + this.topK = topK; + return Observable.fromSingle(this.postgresService.query(this)); + } + + public Observable> queryRRF( + String metadataTable, + WordEmbeddings wordEmbedding, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + OrderRRFBy orderRRFBy, + String searchQuery, + PostgresLanguage postgresLanguage, + int probes, + PostgresDistanceMetric metric, + int upperLimit, + int topK) { + this.metadataTableNames = List.of(metadataTable); + this.wordEmbedding = wordEmbedding; + this.textWeight = textWeight; + this.similarityWeight = similarityWeight; + this.dateWeight = dateWeight; + this.orderRRFBy = orderRRFBy; + this.searchQuery = searchQuery; + this.postgresLanguage = postgresLanguage; + this.probes = probes; + this.metric = metric; + this.upperLimit = upperLimit; + this.topK = topK; + return Observable.fromSingle(this.postgresService.queryRRF(this)); + } + public Observable> queryWithMetadata( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) { + List metadataTableNames, + WordEmbeddings wordEmbeddings, + PostgresDistanceMetric metric, + int topK) { + this.metadataTableNames = metadataTableNames; this.wordEmbedding = wordEmbeddings; this.topK = topK; this.metric = metric; @@ -240,7 +346,12 @@ public Observable> queryWithMetadata( } public Observable> queryWithMetadata( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK, int probes) { + List 
metadataTableNames, + WordEmbeddings wordEmbeddings, + PostgresDistanceMetric metric, + int topK, + int probes) { + this.metadataTableNames = metadataTableNames; this.wordEmbedding = wordEmbeddings; this.topK = topK; this.metric = metric; @@ -259,7 +370,9 @@ public Observable> getAllChunks(String tableName, S return Observable.fromSingle(this.postgresService.getAllChunks(this)); } - public StringResponse deleteAll() { + public StringResponse deleteAll(String tableName, String namespace) { + this.tableName = tableName; + this.namespace = namespace; return this.postgresService.deleteAll(this).blockingGet(); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java index 87b09e7db..6d714a38d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java @@ -1,6 +1,7 @@ package com.edgechain.lib.index.client.impl; import com.edgechain.lib.configuration.context.ApplicationContextHolder; +import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.impl.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.repositories.PostgresClientMetadataRepository; @@ -10,6 +11,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.reactivex.rxjava3.core.Observable; + +import java.math.BigDecimal; +import java.sql.Date; import java.sql.Timestamp; import java.util.ArrayList; import java.util.HashMap; @@ -18,6 +22,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; + import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,7 +85,8 @@ public EdgeChain> batchUpsert(PostgresEndpoint 
postgresEndp postgresEndpoint.getTableName(), postgresEndpoint.getWordEmbeddingsList(), postgresEndpoint.getFilename(), - getNamespace(postgresEndpoint)); + getNamespace(postgresEndpoint), + postgresEndpoint.getPostgresLanguage()); List stringResponseList = strings.stream().map(StringResponse::new).toList(); @@ -108,7 +114,8 @@ public EdgeChain upsert(PostgresEndpoint postgresEndpoint) { postgresEndpoint.getTableName(), postgresEndpoint.getWordEmbedding(), postgresEndpoint.getFilename(), - getNamespace(postgresEndpoint)); + getNamespace(postgresEndpoint), + postgresEndpoint.getPostgresLanguage()); emitter.onNext(new StringResponse(embeddingId)); emitter.onComplete(); @@ -131,6 +138,7 @@ public EdgeChain insertMetadata(PostgresEndpoint postgresEndpoin String metadataId = this.metadataRepository.insertMetadata( + postgresEndpoint.getTableName(), postgresEndpoint.getMetadataTableNames().get(0), input, postgresEndpoint.getDocumentDate()); @@ -155,6 +163,7 @@ public EdgeChain> batchInsertMetadata(PostgresEndpoint post // Insert metadata List strings = this.metadataRepository.batchInsertMetadata( + postgresEndpoint.getTableName(), postgresEndpoint.getMetadataTableNames().get(0), postgresEndpoint.getMetadataList()); @@ -196,24 +205,114 @@ public EdgeChain> query(PostgresEndpoint postgresEn emitter -> { try { List wordEmbeddingsList = new ArrayList<>(); + + List> embeddings = + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); + List> rows = this.repository.query( postgresEndpoint.getTableName(), getNamespace(postgresEndpoint), postgresEndpoint.getProbes(), postgresEndpoint.getMetric(), - postgresEndpoint.getWordEmbedding().getValues(), + embeddings, postgresEndpoint.getTopK()); for (Map row : rows) { PostgresWordEmbeddings val = new PostgresWordEmbeddings(); - val.setId(row.get("id").toString()); - val.setRawText((String) row.get("raw_text")); - val.setFilename((String) row.get("filename")); - val.setTimestamp(((Timestamp) 
row.get("timestamp")).toLocalDateTime()); - val.setNamespace((String) row.get("namespace")); - val.setScore((Double) row.get("score")); + val.setId(Objects.nonNull(row.get("id")) ? row.get("id").toString() : null); + val.setRawText( + Objects.nonNull(row.get("raw_text")) ? (String) row.get("raw_text") : null); + val.setFilename( + Objects.nonNull(row.get("filename")) ? (String) row.get("filename") : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) ? (String) row.get("namespace") : null); + + val.setScore( + Objects.nonNull(row.get("score")) ? (Double) row.get("score") : null); + + PGobject pgObject = (PGobject) row.get("embedding"); + String jsonString = pgObject.getValue(); + List values = objectMapper.readerFor(FLOAT_TYPE_REF).readValue(jsonString); + val.setValues(values); + + wordEmbeddingsList.add(val); + } + emitter.onNext(wordEmbeddingsList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain> queryRRF(PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + List wordEmbeddingsList = new ArrayList<>(); + List> rows = + this.repository.queryRRF( + postgresEndpoint.getTableName(), + getNamespace(postgresEndpoint), + postgresEndpoint.getMetadataTableNames().get(0), + postgresEndpoint.getWordEmbedding().getValues(), + postgresEndpoint.getTextWeight(), + postgresEndpoint.getSimilarityWeight(), + postgresEndpoint.getDateWeight(), + postgresEndpoint.getSearchQuery(), + postgresEndpoint.getPostgresLanguage(), + postgresEndpoint.getProbes(), + postgresEndpoint.getMetric(), + postgresEndpoint.getUpperLimit(), + postgresEndpoint.getTopK(), + postgresEndpoint.getOrderRRFBy()); + + for (Map row : rows) { + + PostgresWordEmbeddings val = new PostgresWordEmbeddings(); + 
val.setId(Objects.nonNull(row.get("id")) ? row.get("id").toString() : null); + val.setRawText( + Objects.nonNull(row.get("raw_text")) ? (String) row.get("raw_text") : null); + + val.setFilename( + Objects.nonNull(row.get("filename")) ? (String) row.get("filename") : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) ? (String) row.get("namespace") : null); + + BigDecimal bigDecimal = + Objects.nonNull(row.get("rrf_score")) + ? (BigDecimal) row.get("rrf_score") + : null; + val.setScore(Objects.nonNull(bigDecimal) ? bigDecimal.doubleValue() : null); + + if (postgresEndpoint.getMetadataTableNames().get(0).contains("title")) { + val.setTitleMetadata( + Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null); + } else { + val.setMetadata( + Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null); + } + Date documentDate = + Objects.nonNull(row.get("document_date")) + ? (Date) row.get("document_date") + : null; + val.setDocumentDate(Objects.nonNull(documentDate) ? documentDate.toString() : null); wordEmbeddingsList.add(val); } @@ -264,20 +363,34 @@ public EdgeChain> queryWithMetadata( Set contextChunkIds = new HashSet<>(); for (Map row : rows) { String metadataId = row.get("metadata_id").toString(); - if (!metadataTableName.contains("_title_metadata") + if (!metadataTableName.contains("title_metadata") && contextChunkIds.contains(metadataId)) continue; PostgresWordEmbeddings val = new PostgresWordEmbeddings(); - final String idStr = row.get("id").toString(); + final String idStr = + Objects.nonNull(row.get("id")) ? 
row.get("id").toString() : null; val.setId(idStr); - val.setRawText((String) row.get("raw_text")); - val.setFilename((String) row.get("filename")); - val.setTimestamp(((Timestamp) row.get("timestamp")).toLocalDateTime()); - val.setNamespace((String) row.get("namespace")); - val.setScore((Double) row.get("score")); + val.setRawText( + Objects.nonNull(row.get("raw_text")) + ? (String) row.get("raw_text") + : null); + val.setFilename( + Objects.nonNull(row.get("filename")) + ? (String) row.get("filename") + : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) + ? (String) row.get("namespace") + : null); + val.setScore( + Objects.nonNull(row.get("score")) ? (Double) row.get("score") : null); // Add metadata fields in response - if (metadataTableName.contains("_title_metadata")) { + if (metadataTableName.contains("title_metadata")) { titleMetadataMap.put(idStr, (String) row.get("metadata")); dateMetadataMap.put(idStr, (String) row.get("document_date")); @@ -355,6 +468,7 @@ public EdgeChain> getSimilarMetadataChunk( List wordEmbeddingsList = new ArrayList<>(); List> rows = this.metadataRepository.getSimilarMetadataChunk( + postgresEndpoint.getTableName(), postgresEndpoint.getMetadataTableNames().get(0), postgresEndpoint.getEmbeddingChunk()); for (Map row : rows) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java new file mode 100644 index 000000000..19b6f6974 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java @@ -0,0 +1,45 @@ +package com.edgechain.lib.index.domain; + +import com.edgechain.lib.index.enums.BaseWeight; + +import java.util.StringJoiner; + +public class RRFWeight { + + private BaseWeight baseWeight = BaseWeight.W1_0; + private double 
fineTuneWeight = 0.5; + + public RRFWeight() {} + + public RRFWeight(BaseWeight baseWeight, double fineTuneWeight) { + this.baseWeight = baseWeight; + this.fineTuneWeight = fineTuneWeight; + + if (fineTuneWeight < 0 || fineTuneWeight > 1.0) + throw new IllegalArgumentException("Weights must be between 0 and 1."); + } + + public void setBaseWeight(BaseWeight baseWeight) { + this.baseWeight = baseWeight; + } + + public void setFineTuneWeight(double fineTuneWeight) { + this.fineTuneWeight = fineTuneWeight; + } + + public BaseWeight getBaseWeight() { + return baseWeight; + } + + public double getFineTuneWeight() { + return fineTuneWeight; + } + + @Override + public String toString() { + return new StringJoiner(", ", RRFWeight.class.getSimpleName() + "[", "]") + .add("baseWeight=" + baseWeight) + .add("fineTuneWeight=" + fineTuneWeight) + .toString(); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java new file mode 100644 index 000000000..e998f8a09 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java @@ -0,0 +1,32 @@ +package com.edgechain.lib.index.enums; + +public enum BaseWeight { + W1_0(1.0), + W1_25(1.25), + W1_5(1.5), + W1_75(1.75), + W2_0(2.0), + W2_25(2.25), + W2_5(2.5), + W2_75(2.75), + W3_0(3.0); + + private final double value; + + BaseWeight(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + public static BaseWeight fromDouble(double value) { + for (BaseWeight baseWeight : BaseWeight.values()) { + if (baseWeight.getValue() == value) { + return baseWeight; + } + } + throw new IllegalArgumentException("Invalid BaseWeight value: " + value); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java new 
file mode 100644 index 000000000..db857d0e2 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java @@ -0,0 +1,23 @@ +package com.edgechain.lib.index.enums; + +public enum OrderRRFBy { + DEFAULT, // Preferred Way; ordered by rrf_score; (relevance over freshness) + TEXT_RANK, // First Ordered By Text_Rank; then ordered by rrf_score (text_rank preferred, then + // relevance) + SIMILARITY, // First Ordered by Similarity; then ordered by rrf_score; (similarity preferred, then + // relevance) + DATE_RANK; // First Ordered by date_rank; then ordered by rrf_score; (freshness preferred, then + + // relevance) + + public static OrderRRFBy fromString(String value) { + if (value != null) { + for (OrderRRFBy orderRRFBy : OrderRRFBy.values()) { + if (orderRRFBy.name().equalsIgnoreCase(value)) { + return orderRRFBy; + } + } + } + throw new IllegalArgumentException("Invalid OrderRRFBy value: " + value); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java new file mode 100644 index 000000000..01f7e82df --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java @@ -0,0 +1,43 @@ +package com.edgechain.lib.index.enums; + +public enum PostgresLanguage { + SIMPLE("simple"), + ARABIC("arabic"), + ARMENIAN("armenian"), + BASQUE("basque"), + CATALAN("catalan"), + DANISH("danish"), + DUTCH("dutch"), + ENGLISH("english"), + FINNISH("finnish"), + FRENCH("french"), + GERMAN("german"), + GREEK("greek"), + HINDI("hindi"), + HUNGARIAN("hungarian"), + INDONESIAN("indonesian"), + IRISH("irish"), + ITALIAN("italian"), + LITHUANIAN("lithuanian"), + NEPALI("nepali"), + NORWEGIAN("norwegian"), + PORTUGUESE("portuguese"), + ROMANIAN("romanian"), + RUSSIAN("russian"), + SERBIAN("serbian"), + SPANISH("spanish"), + SWEDISH("swedish"), + TAMIL("tamil"), + 
TURKISH("turkish"), + YIDDISH("yiddish"); + + private final String value; + + PostgresLanguage(String value) { + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java index 7b677467a..215715923 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java @@ -24,33 +24,43 @@ public void createTable(PostgresEndpoint postgresEndpoint) { String.format( "CREATE TABLE IF NOT EXISTS %s (metadata_id UUID PRIMARY KEY, metadata TEXT NOT NULL," + " document_date DATE);", - metadataTable)); + postgresEndpoint.getTableName() + "_" + metadataTable)); // Create a JOIN table jdbcTemplate.execute( String.format( - "CREATE TABLE IF NOT EXISTS %s (id UUID, metadata_id UUID, " - + "FOREIGN KEY (id) REFERENCES %s(id), " - + "FOREIGN KEY (metadata_id) REFERENCES %s(metadata_id), " + "CREATE TABLE IF NOT EXISTS %s (id UUID UNIQUE NOT NULL, metadata_id UUID NOT NULL, " + + "FOREIGN KEY (id) REFERENCES %s(id) ON DELETE CASCADE, " + + "FOREIGN KEY (metadata_id) REFERENCES %s(metadata_id) ON DELETE CASCADE, " + "PRIMARY KEY (id, metadata_id));", postgresEndpoint.getTableName() + "_join_" + metadataTable, postgresEndpoint.getTableName(), - metadataTable)); + postgresEndpoint.getTableName() + "_" + metadataTable)); + + jdbcTemplate.execute( + String.format( + "CREATE INDEX IF NOT EXISTS idx_%s ON %s (metadata_id);", + postgresEndpoint.getTableName() + "_join_" + metadataTable, + postgresEndpoint.getTableName() + "_join_" + metadataTable)); } @Transactional - public List batchInsertMetadata(String metadataTableName, List metadataList) { + public List 
batchInsertMetadata( + String table, String metadataTableName, List metadataList) { Set uuidSet = new HashSet<>(); for (int i = 0; i < metadataList.size(); i++) { + + String metadata = metadataList.get(i).replace("'", ""); + UUID metadataId = jdbcTemplate.queryForObject( String.format( - "INSERT INTO %s (metadata_id, metadata) VALUES ('%s', '%s') RETURNING" - + " metadata_id;", - metadataTableName, UuidCreator.getTimeOrderedEpoch(), metadataList.get(i)), - UUID.class); + "INSERT INTO %s (metadata_id, metadata) VALUES ('%s', ?) RETURNING metadata_id;", + table.concat("_").concat(metadataTableName), UuidCreator.getTimeOrderedEpoch()), + UUID.class, + metadata); if (metadataId != null) { uuidSet.add(metadataId.toString()); @@ -61,15 +71,22 @@ public List batchInsertMetadata(String metadataTableName, List m } @Transactional - public String insertMetadata(String metadataTableName, String metadata, String documentDate) { - - UUID uuid = UuidCreator.getTimeOrderedEpoch(); - jdbcTemplate.update( - String.format( - "INSERT INTO %s (metadata_id, metadata, document_date) VALUES ('%s', '%s'," - + " TO_DATE(NULLIF('%s', ''), 'Month DD, YYYY'));", - metadataTableName, uuid, metadata, documentDate)); - return uuid.toString(); + public String insertMetadata( + String table, String metadataTableName, String metadata, String documentDate) { + + metadata = metadata.replace("'", ""); + + UUID metadataId = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (metadata_id, metadata, document_date) VALUES ('%s', ?," + + " TO_DATE(NULLIF(?, ''), 'Month DD, YYYY')) RETURNING metadata_id;", + table.concat("_").concat(metadataTableName), UuidCreator.getTimeOrderedEpoch()), + UUID.class, + metadata, + documentDate); + + return Objects.requireNonNull(metadataId).toString(); } @Transactional @@ -80,7 +97,8 @@ public void insertIntoJoinTable(PostgresEndpoint postgresEndpoint) { + postgresEndpoint.getMetadataTableNames().get(0); jdbcTemplate.execute( String.format( - "INSERT INTO %s 
(id, metadata_id) VALUES ('%s', '%s');", + "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET" + + " metadata_id = EXCLUDED.metadata_id;", joinTableName, UUID.fromString(postgresEndpoint.getId()), UUID.fromString(postgresEndpoint.getMetadataId()))); @@ -112,7 +130,7 @@ public List> queryWithMetadata( embeddings, tableName, joinTable, - metadataTableName, + tableName.concat("_").concat(metadataTableName), namespace, PostgresDistanceMetric.getDistanceMetric(metric), embeddings, @@ -129,7 +147,7 @@ public List> queryWithMetadata( embeddings, tableName, joinTable, - metadataTableName, + tableName.concat("_").concat(metadataTableName), namespace, PostgresDistanceMetric.getDistanceMetric(metric), embeddings, @@ -145,7 +163,7 @@ public List> queryWithMetadata( embeddings, tableName, joinTable, - metadataTableName, + tableName.concat("_").concat(metadataTableName), namespace, PostgresDistanceMetric.getDistanceMetric(metric), embeddings, @@ -156,11 +174,13 @@ public List> queryWithMetadata( // Full-text search @Transactional(readOnly = true, propagation = Propagation.REQUIRED) public List> getSimilarMetadataChunk( - String metadataTableName, String embeddingChunk) { + String table, String metadataTableName, String embeddingChunk) { // Remove special characters and replace with a space String cleanEmbeddingChunk = embeddingChunk.replaceAll("[^a-zA-Z0-9\\s]", " ").replaceAll("\\s+", " ").trim(); + String tableName = table.concat("_").concat(metadataTableName); + // Split the embeddingChunk into words and join them with the '|' (OR) operator String tsquery = String.join(" | ", cleanEmbeddingChunk.split("\\s+")); return jdbcTemplate.queryForList( @@ -168,6 +188,6 @@ public List> getSimilarMetadataChunk( "SELECT *, ts_rank(to_tsvector(%s.metadata), query) as rank_metadata " + "FROM %s, to_tsvector(%s.metadata) document, to_tsquery('%s') query " + "WHERE query @@ document ORDER BY rank_metadata DESC", - metadataTableName, metadataTableName, 
metadataTableName, tsquery)); + tableName, tableName, tableName, tsquery)); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java index 247c0a3a3..2c8293dc3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java @@ -2,7 +2,10 @@ import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.index.domain.RRFWeight; +import com.edgechain.lib.index.enums.OrderRRFBy; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.utils.FloatUtils; import com.github.f4b6a3.uuid.UuidCreator; import org.springframework.beans.factory.annotation.Autowired; @@ -23,6 +26,7 @@ public class PostgresClientRepository { public void createTable(PostgresEndpoint postgresEndpoint) { jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector;"); + jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;"); String checkTableQuery = String.format( @@ -51,16 +55,22 @@ public void createTable(PostgresEndpoint postgresEndpoint) { + " (lists = %s);", indexName, postgresEndpoint.getTableName(), vectorOps, postgresEndpoint.getLists()); + String tsvIndexQuery = + String.format( + "CREATE INDEX IF NOT EXISTS %s ON %s USING GIN(tsv);", + postgresEndpoint.getTableName().concat("_tsv_idx"), postgresEndpoint.getTableName()); + if (tableExists == 0) { jdbcTemplate.execute( String.format( "CREATE TABLE IF NOT EXISTS %s (id UUID PRIMARY KEY, " + " raw_text TEXT NOT NULL UNIQUE, embedding vector(%s), timestamp" - + " TIMESTAMP NOT NULL, namespace TEXT, filename VARCHAR(255) );", + + " TIMESTAMP NOT NULL, 
namespace TEXT, filename VARCHAR(255), tsv TSVECTOR);", postgresEndpoint.getTableName(), postgresEndpoint.getDimensions())); jdbcTemplate.execute(indexQuery); + jdbcTemplate.execute(tsvIndexQuery); } else { @@ -69,11 +79,11 @@ public void createTable(PostgresEndpoint postgresEndpoint) { "SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s';", postgresEndpoint.getTableName(), indexName); - int indexExists = jdbcTemplate.queryForObject(checkIndexQuery, Integer.class); + Integer indexExists = jdbcTemplate.queryForObject(checkIndexQuery, Integer.class); - if (indexExists != 1) + if (indexExists != null && indexExists != 1) throw new RuntimeException( - "No index is specifed therefore use the following SQL:\n" + indexQuery); + "No index is specified therefore use the following SQL:\n" + indexQuery); } } @@ -82,7 +92,8 @@ public List batchUpsertEmbeddings( String tableName, List wordEmbeddingsList, String filename, - String namespace) { + String namespace, + PostgresLanguage language) { Set uuidSet = new HashSet<>(); @@ -92,21 +103,25 @@ public List batchUpsertEmbeddings( if (wordEmbeddings != null && wordEmbeddings.getValues() != null) { float[] floatArray = FloatUtils.toFloatArray(wordEmbeddings.getValues()); + String rawText = wordEmbeddings.getId().replace("'", ""); UUID id = jdbcTemplate.queryForObject( String.format( - "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename)" - + " VALUES ('%s', '%s', '%s', '%s', '%s', '%s') ON CONFLICT (raw_text) DO" - + " UPDATE SET embedding = EXCLUDED.embedding RETURNING id;", + "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename, tsv)" + + " VALUES ('%s', ?, '%s', '%s', '%s', '%s', TO_TSVECTOR('%s', '%s')) ON" + + " CONFLICT (raw_text) DO UPDATE SET embedding = EXCLUDED.embedding" + + " RETURNING id;", tableName, UuidCreator.getTimeOrderedEpoch(), - wordEmbeddings.getId(), Arrays.toString(floatArray), LocalDateTime.now(), namespace, - filename), - UUID.class); + 
filename, + language.getValue(), + rawText), + UUID.class, + rawText); if (id != null) { uuidSet.add(id.toString()); @@ -119,24 +134,34 @@ public List batchUpsertEmbeddings( @Transactional public String upsertEmbeddings( - String tableName, WordEmbeddings wordEmbeddings, String filename, String namespace) { - - UUID uuid = UuidCreator.getTimeOrderedEpoch(); - - jdbcTemplate.update( - String.format( - "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename) VALUES ('%s'," - + " '%s', '%s', '%s', '%s', '%s') ON CONFLICT (raw_text) DO UPDATE SET embedding =" - + " EXCLUDED.embedding;", - tableName, - uuid, - wordEmbeddings.getId(), - Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())), - LocalDateTime.now(), - namespace, - filename)); - - return uuid.toString(); + String tableName, + WordEmbeddings wordEmbeddings, + String filename, + String namespace, + PostgresLanguage language) { + + float[] floatArray = FloatUtils.toFloatArray(wordEmbeddings.getValues()); + String rawText = wordEmbeddings.getId().replace("'", ""); + + UUID uuid = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename, tsv)" + + " VALUES ('%s', ?, '%s', '%s', '%s', '%s', TO_TSVECTOR('%s', '%s')) ON" + + " CONFLICT (raw_text) DO UPDATE SET embedding = EXCLUDED.embedding RETURNING" + + " id;", + tableName, + UuidCreator.getTimeOrderedEpoch(), + Arrays.toString(floatArray), + LocalDateTime.now(), + namespace, + filename, + language.getValue(), + rawText), + UUID.class, + rawText); + + return Objects.requireNonNull(uuid).toString(); } @Transactional(readOnly = true, propagation = Propagation.REQUIRED) @@ -145,54 +170,190 @@ public List> query( String namespace, int probes, PostgresDistanceMetric metric, - List values, + List> values, int topK) { - String embeddings = Arrays.toString(FloatUtils.toFloatArray(values)); - jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); - if 
(metric.equals(PostgresDistanceMetric.IP)) { - return jdbcTemplate.queryForList( - String.format( - "SELECT id, raw_text, namespace, filename, timestamp, ( embedding <#>" - + " '%s') * -1 AS score FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s'" - + " LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - embeddings, - topK)); - - } else if (metric.equals(PostgresDistanceMetric.COSINE)) { + StringBuilder query = new StringBuilder(); + + for (int i = 0; i < values.size(); i++) { + + String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); + + query.append("(").append("SELECT id, raw_text, embedding, namespace, filename, timestamp,"); + + switch (metric) { + case COSINE -> query + .append(String.format("1 - (embedding <=> '%s') AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + case IP -> query + .append(String.format("(embedding <#> '%s') * -1 AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + case L2 -> query + .append(String.format("embedding <-> '%s' AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); + } + query.append(topK).append(")"); + if (i < values.size() - 1) { + query.append(" UNION ALL ").append("\n"); + } + } + + if (values.size() > 1) { return 
jdbcTemplate.queryForList( - String.format( - "SELECT id, raw_text, namespace, filename, timestamp, 1 - ( embedding" - + " <=> '%s') AS score FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s'" - + " LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - embeddings, - topK)); + String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query)); } else { - return jdbcTemplate.queryForList( - String.format( - "SELECT id, raw_text, namespace, filename, timestamp, (embedding <->" - + " '%s') AS score FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' ASC" - + " LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - embeddings, - topK)); + return jdbcTemplate.queryForList(query.toString()); } } + public List> queryRRF( + String tableName, + String namespace, + String metadataTableName, + List values, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + String searchQuery, + PostgresLanguage language, + int probes, + PostgresDistanceMetric metric, + int upperLimit, + int topK, + OrderRRFBy orderRRFBy) { + + jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); + + String embeddings = Arrays.toString(FloatUtils.toFloatArray(values)); + + StringBuilder query = new StringBuilder(); + query + .append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", + textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", + similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", + dateWeight.getBaseWeight().getValue(), 
dateWeight.getFineTuneWeight())) + .append("FROM ( ") + .append( + "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," + + " svtm.document_date, svtm.metadata, ") + .append( + String.format( + "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", + language.getValue(), searchQuery)); + + switch (metric) { + case COSINE -> query.append( + String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); + case IP -> query.append( + String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); + case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings)); + default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); + } + + query + .append("CASE ") + .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling + .append( + "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" + + " svtm.document_date) ") + .append("END AS date_rank ") + .append("FROM ") + .append( + String.format( + "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE" + + " namespace = '%s'", + tableName, namespace)); + + switch (metric) { + case COSINE -> query + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(upperLimit); + case IP -> query + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(upperLimit); + case L2 -> query + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(upperLimit); + default -> throw new IllegalArgumentException("Invalid metric: " + metric); + } + query + .append(")") + .append(" sv ") + .append("JOIN ") + .append(tableName.concat("_join_").concat(metadataTableName)) + .append(" jtm ON sv.id = jtm.id ") + .append("JOIN ") + .append(tableName.concat("_").concat(metadataTableName)) + .append(" svtm ON 
jtm.metadata_id = svtm.metadata_id ") + .append(") subquery "); + + switch (orderRRFBy) { + case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC"); + case SIMILARITY -> query.append("ORDER BY similarity DESC, rrf_score DESC"); + case DATE_RANK -> query.append("ORDER BY date_rank DESC, rrf_score DESC"); + case DEFAULT -> query.append("ORDER BY rrf_score DESC"); + default -> throw new IllegalArgumentException("Invalid orderRRFBy value"); + } + + query.append(" LIMIT ").append(topK).append(";"); + return jdbcTemplate.queryForList(query.toString()); + } + @Transactional(readOnly = true) public List> getAllChunks(PostgresEndpoint endpoint) { return jdbcTemplate.queryForList( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java index cc3cff381..875bba38c 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java @@ -107,11 +107,10 @@ public JSONObject getBody() { while ((line = reader.readLine()) != null) { jsonContent.append(line); } + return new JSONObject(jsonContent.toString()); } catch (IOException e) { throw new RuntimeException(e); } - - return new JSONObject(jsonContent.toString()); } public Cookie[] getCookies() { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java index ff7d76ee7..e548e7803 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java @@ -36,6 +36,9 @@ public interface PostgresService { @POST(value = "index/postgres/query") Single> query(@Body PostgresEndpoint postgresEndpoint); + @POST(value = "index/postgres/query-rrf") + 
Single> queryRRF(@Body PostgresEndpoint postgresEndpoint); + @POST(value = "index/postgres/metadata/query") Single> queryWithMetadata(@Body PostgresEndpoint postgresEndpoint); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java index 5bc927061..fdab6cef1 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java @@ -65,6 +65,12 @@ public Single> query( return this.postgresClient.query(postgresEndpoint).toSingle(); } + @PostMapping("/query-rrf") + public Single> queryRRF( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); + } + @PostMapping("/metadata/query") public Single> queryWithMetadata( @RequestBody PostgresEndpoint postgresEndpoint) { diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java index abf8375a4..2c11a420e 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java @@ -4,6 +4,7 @@ import com.edgechain.lib.endpoint.impl.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.edgechain.testutil.PostgresTestContainer; @@ -11,6 +12,8 @@ import com.zaxxer.hikari.HikariConfig; import 
java.util.List; import java.util.stream.Collectors; + +import io.reactivex.rxjava3.observers.TestObserver; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -61,28 +64,26 @@ void allMethods() { createTable(); createMetadataTable(); + deleteAll(); // check delete before we get foreign keys String uuid1 = upsert(); batchUpsert(); + query_noMeta(); String uuid2 = insertMetadata(); + batchInsertMetadata(); insertIntoJoinTable(uuid1, uuid2); - query_meta(); + query_meta(); getChunks(); getSimilarChunks(); } private void createTable() { - createTable_metric(PostgresDistanceMetric.L2, "testtableL2"); - createTable_metric(PostgresDistanceMetric.COSINE, "testtableCOS"); - createTable_metric(null, "testtable"); - - // create table again - createTable_metric(null, "testtable"); + createTable_metric(PostgresDistanceMetric.COSINE, "t_embedding"); } private void createTable_metric(PostgresDistanceMetric metric, String tableName) { @@ -92,65 +93,71 @@ private void createTable_metric(PostgresDistanceMetric metric, String tableName) when(mockPe.getDimensions()).thenReturn(2); when(mockPe.getMetric()).thenReturn(metric); - final Data data = new Data(); - EdgeChain result = service.createTable(mockPe); - result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); - if (data.error != null) { - fail("createTable failed", data.error); + TestObserver test = service.createTable(mockPe).getObservable().test(); + + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - LOGGER.info("createTable (metric={}) response: '{}'", metric, data.val); + test.assertNoErrors(); + LOGGER.info("createTable (metric={}) response: '{}'", metric, tableName); } private void createMetadataTable() { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); - when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", 
"catmeta")); - - final Data data = new Data(); - EdgeChain result = service.createMetadataTable(mockPe); - result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); - if (data.error != null) { - fail("createMetadataTable failed", data.error); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + + TestObserver test = service.createMetadataTable(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - LOGGER.info("createMetadataTable response: '{}'", data.val); + LOGGER.info("createMetadataTable response: '{}'", test.values().get(0).getResponse()); } private String upsert() { WordEmbeddings we = new WordEmbeddings(); we.setId("WE1"); - we.setScore("101"); + we.setScore("0.86914713"); we.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); when(mockPe.getWordEmbedding()).thenReturn(we); when(mockPe.getFilename()).thenReturn("readme.pdf"); when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getPostgresLanguage()).thenReturn(PostgresLanguage.ENGLISH); - final Data data = new Data(); - EdgeChain result = service.upsert(mockPe); - result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); - if (data.error != null) { - fail("upsert failed", data.error); + TestObserver test = service.upsert(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - LOGGER.info("upsert response: '{}'", data.val); - return data.val; + + test.assertNoErrors(); + + return test.values().get(0).getResponse(); } private String insertMetadata() { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - 
when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", "catmeta")); - when(mockPe.getMetadata()).thenReturn("''duck''"); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getMetadata()).thenReturn("This is a sample text"); when(mockPe.getDocumentDate()).thenReturn("November 11, 2015"); - final Data data = new Data(); - EdgeChain result = service.insertMetadata(mockPe); - result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); - if (data.error != null) { - fail("insertMetadata failed", data.error); + TestObserver test = service.insertMetadata(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - LOGGER.info("insertMetadata response: '{}'", data.val); - return data.val; + test.assertNoErrors(); + return test.values().get(0).getResponse(); } private void batchUpsert() { @@ -165,10 +172,11 @@ private void batchUpsert() { we2.setValues(List.of(0.75f, 0.9f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); when(mockPe.getWordEmbeddingsList()).thenReturn(List.of(we1, we2)); when(mockPe.getFilename()).thenReturn("readme.pdf"); when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getPostgresLanguage()).thenReturn(PostgresLanguage.ENGLISH); final Data data = new Data(); EdgeChain> result = service.batchUpsert(mockPe); @@ -185,8 +193,9 @@ private void batchUpsert() { private void batchInsertMetadata() { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", "catmeta")); - when(mockPe.getMetadataList()).thenReturn(List.of("cow", "horse")); + when(mockPe.getTableName()).thenReturn("t_embedding"); + 
when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getMetadataList()).thenReturn(List.of("text1", "text2")); final Data data = new Data(); EdgeChain> result = service.batchInsertMetadata(mockPe); @@ -203,19 +212,20 @@ private void batchInsertMetadata() { private void insertIntoJoinTable(String uuid1, String uuid2) { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); - when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", "catmeta")); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); when(mockPe.getId()).thenReturn(uuid1); when(mockPe.getMetadataId()).thenReturn(uuid2); - final Data data = new Data(); + TestObserver test = service.insertIntoJoinTable(mockPe).getObservable().test(); - EdgeChain result = service.insertIntoJoinTable(mockPe); - result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); - if (data.error != null) { - fail("insertIntoJoinTable failed", data.error); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - LOGGER.info("insertIntoJoinTable response: '{}'", data.val); + + test.assertNoErrors(); } private void deleteAll() { @@ -226,7 +236,7 @@ private void deleteAll() { private void deleteAll_namespace(String namespace, String expected) { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); when(mockPe.getNamespace()).thenReturn(namespace); final Data data = new Data(); @@ -252,12 +262,12 @@ private void query_noMeta_metric(PostgresDistanceMetric metric) { we1.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); 
when(mockPe.getNamespace()).thenReturn("testns"); when(mockPe.getProbes()).thenReturn(5); when(mockPe.getMetric()).thenReturn(metric); - when(mockPe.getWordEmbedding()).thenReturn(we1); - when(mockPe.getTopK()).thenReturn(1000); + when(mockPe.getWordEmbeddingsList()).thenReturn(List.of(we1)); + when(mockPe.getTopK()).thenReturn(5); when(mockPe.getMetadataTableNames()).thenReturn(null); final Data data = new Data(); @@ -289,13 +299,13 @@ private void query_meta_metric(PostgresDistanceMetric metric) { we1.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); when(mockPe.getNamespace()).thenReturn("testns"); - when(mockPe.getProbes()).thenReturn(5); + when(mockPe.getProbes()).thenReturn(20); when(mockPe.getMetric()).thenReturn(metric); when(mockPe.getWordEmbedding()).thenReturn(we1); - when(mockPe.getTopK()).thenReturn(1000); - when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", "catmeta")); + when(mockPe.getTopK()).thenReturn(5); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); final Data data = new Data(); EdgeChain> result = service.queryWithMetadata(mockPe); @@ -315,7 +325,7 @@ private void query_meta_metric(PostgresDistanceMetric metric) { private void getChunks() { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getTableName()).thenReturn("testtable"); + when(mockPe.getTableName()).thenReturn("t_embedding"); when(mockPe.getFilename()).thenReturn("readme.pdf"); final Data data = new Data(); @@ -347,7 +357,8 @@ private void getChunks() { private void getSimilarChunks() { PostgresEndpoint mockPe = mock(PostgresEndpoint.class); - when(mockPe.getMetadataTableNames()).thenReturn(List.of("dogmeta", "catmeta")); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); 
when(mockPe.getEmbeddingChunk()).thenReturn("how to test this"); final Data data = new Data(); diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java index de1bf36c2..858bb9574 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java @@ -56,7 +56,7 @@ public void testCreateTable_NonEmptyMetadataTableNames() { repository.createTable(postgresEndpoint); // Assert - verify(jdbcTemplate, times(2)).execute(sqlQueryCaptor.capture()); + verify(jdbcTemplate, times(3)).execute(sqlQueryCaptor.capture()); } @Test @@ -71,21 +71,31 @@ public void testCreateTable_EmptyMetadataTableNames() { } @Test - @DisplayName("Insert metadata should return metadata id after getting inserted") - public void testInsertMetadata_ReturnsMetadataId() { + @DisplayName("Insert metadata must throw NullPointerException when metadata ID is null") + public void testInsertMetadata_ThrowsNullPointerException() { + // Arrange + String tablename = "table"; String metadataTableName = "metadata_table"; String metadata = "example_metadata"; String documentDate = "Aug 01, 2023"; - // Act - repository.insertMetadata(metadataTableName, metadata, documentDate); + // Mock jdbcTemplate.queryForObject to return null + when(jdbcTemplate.queryForObject(anyString(), eq(UUID.class), any(Object[].class))) + .thenReturn(null); - // Assert - verify(jdbcTemplate, times(1)).update(sqlQueryCaptor.capture()); + // Act and Assert + assertThrows( + NullPointerException.class, + () -> { + repository.insertMetadata(tablename, metadataTableName, metadata, documentDate); + }); + + // Verify that jdbcTemplate.queryForObject was called with the correct SQL query and arguments + verify(jdbcTemplate, times(1)) 
+ .queryForObject(sqlQueryCaptor.capture(), eq(UUID.class), any(Object[].class)); } - // @Test @DisplayName("Insert entry into the join table") public void testInsertIntoJoinTable() { @@ -112,7 +122,8 @@ public void testInsertIntoJoinTable() { String capturedQuery = sqlQueryCaptor.getValue(); String expectedQuery = String.format( - "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s');", + "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET" + + " metadata_id = EXCLUDED.metadata_id;", joinTable, postgresEndpoint.getId(), postgresEndpoint.getMetadataId()); assertEquals(expectedQuery, capturedQuery); }