From 4cdea9c2bf597497c435b15c6ba99fa728ffff38 Mon Sep 17 00:00:00 2001 From: github-actions <> Date: Sat, 30 Sep 2023 12:39:37 +0000 Subject: [PATCH] Google Java Format --- Examples/airtable/AirtableExample.java | 23 +- Examples/pinecone/PineconeExample.java | 8 +- Examples/postgresql/PostgreSQLExample.java | 28 +- .../react-chain/ReactChainApplication.java | 1 - Examples/redis/RedisExample.java | 33 +- .../SupabaseMiniLMExample.java | 26 +- Examples/zapier/ZapierExample.java | 118 +- .../com/edgechain/EdgeChainApplication.java | 1 - .../lib/chains/PineconeRetrieval.java | 12 +- .../lib/chains/PostgresRetrieval.java | 6 +- .../edgechain/lib/chains/RedisRetrieval.java | 7 +- .../impl/embeddings/BgeSmallEndpoint.java | 1 - .../impl/embeddings/EmbeddingEndpoint.java | 15 +- .../impl/embeddings/MiniLMEndpoint.java | 3 - .../embeddings/OpenAiEmbeddingEndpoint.java | 84 +- .../endpoint/impl/index/PineconeEndpoint.java | 20 +- .../endpoint/impl/index/PostgresEndpoint.java | 1011 +++++++++-------- .../endpoint/impl/index/RedisEndpoint.java | 13 +- .../impl/integration/AirtableEndpoint.java | 324 +++--- .../endpoint/impl/llm/OpenAiChatEndpoint.java | 12 +- .../lib/index/client/impl/PostgresClient.java | 8 +- .../index/domain/PostgresWordEmbeddings.java | 1 - .../PostgresClientRepository.java | 172 +-- .../airtable/query/AirtableQueryBuilder.java | 233 ++-- .../integration/airtable/query/SortOrder.java | 34 +- .../edgechain/lib/jsonnet/JsonnetLoader.java | 16 +- .../lib/openai/client/OpenAiClient.java | 252 ++-- .../lib/retrofit/AirtableService.java | 24 +- .../retrofit/client/OpenAiStreamService.java | 2 - .../transformer/observable/EdgeChain.java | 8 +- .../edgechain/lib/utils/ContextReorder.java | 43 +- .../controllers/index/PineconeController.java | 43 +- .../controllers/index/PostgresController.java | 168 ++- .../controllers/index/RedisController.java | 54 +- .../integration/AirtableController.java | 46 +- .../controllers/openai/OpenAiController.java | 495 ++++---- .../edgechain/EdgeChainApplicationTest.java | 2 - .../endpoint/impl/BgeSmallEndpointTest.java | 4 +- .../pinecone/PineconeClientTest.java | 10 +- .../edgechain/testutil/TestConfigSupport.java | 2 +- .../com/edgechain/wiki/WikiClientTest.java | 2 +- 41 files changed, 1684 insertions(+), 1681 deletions(-) diff --git a/Examples/airtable/AirtableExample.java b/Examples/airtable/AirtableExample.java index 2a24d9a83..6cd5c2cdf 100644 --- a/Examples/airtable/AirtableExample.java +++ b/Examples/airtable/AirtableExample.java @@ -17,11 +17,12 @@ import java.util.Properties; /** - * For the purpose, of this example, create a simple table using AirTable i.e, "Speakers" - * Following are some basic fields: "speaker_name", "designation", "organization", "biography", "speaker_photo", "rating - * To get, your BASE_ID of specific database; use the following API https://api.airtable.com/v0/meta/bases - * --header 'Authorization: Bearer YOUR_PERSONAL_ACCESS_TOKEN' - * You can create, complex tables using Airtable; also define relationships b/w tables via lookups. + * For the purpose, of this example, create a simple table using AirTable i.e, "Speakers" Following + * are some basic fields: "speaker_name", "designation", "organization", "biography", + * "speaker_photo", "rating To get, your BASE_ID of specific database; use the following API + * https://api.airtable.com/v0/meta/bases --header 'Authorization: Bearer + * YOUR_PERSONAL_ACCESS_TOKEN' You can create, complex tables using Airtable; also define + * relationships b/w tables via lookups. 
*/ @SpringBootApplication public class AirtableExample { @@ -123,11 +124,11 @@ public ArkResponse update(ArkRequest arkRequest) { return new EdgeChain<>(airtableEndpoint.update("Speakers", record)).getArkResponse(); } - @DeleteMapping("/delete") - public ArkResponse delete(ArkRequest arkRequest) { - JSONObject body = arkRequest.getBody(); - String id = body.getString("id"); - return new EdgeChain<>(airtableEndpoint.delete("Speakers", id)).getArkResponse(); - } + @DeleteMapping("/delete") + public ArkResponse delete(ArkRequest arkRequest) { + JSONObject body = arkRequest.getBody(); + String id = body.getString("id"); + return new EdgeChain<>(airtableEndpoint.delete("Speakers", id)).getArkResponse(); + } } } diff --git a/Examples/pinecone/PineconeExample.java b/Examples/pinecone/PineconeExample.java index 2cf8b9725..fda4e6163 100644 --- a/Examples/pinecone/PineconeExample.java +++ b/Examples/pinecone/PineconeExample.java @@ -61,7 +61,7 @@ public static void main(String[] args) { // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port",""); + properties.setProperty("redis.port", ""); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -72,12 +72,11 @@ public static void main(String[] args) { properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); - new SpringApplicationBuilder(PineconeExample.class).properties(properties).run(args); gpt3Endpoint = new OpenAiChatEndpoint( - OPENAI_CHAT_COMPLETION_API, + OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, "gpt-3.5-turbo", @@ -96,7 +95,8 @@ public static void main(String[] args) { true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - OpenAiEmbeddingEndpoint ada002 = new OpenAiEmbeddingEndpoint( + OpenAiEmbeddingEndpoint ada002 = + new OpenAiEmbeddingEndpoint( OPENAI_EMBEDDINGS_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, diff --git a/Examples/postgresql/PostgreSQLExample.java b/Examples/postgresql/PostgreSQLExample.java index 28febf9ca..7b1782e3c 100644 --- a/Examples/postgresql/PostgreSQLExample.java +++ b/Examples/postgresql/PostgreSQLExample.java @@ -90,7 +90,8 @@ public static void main(String[] args) { true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - OpenAiEmbeddingEndpoint adaEmbedding = new OpenAiEmbeddingEndpoint( + OpenAiEmbeddingEndpoint adaEmbedding = + new OpenAiEmbeddingEndpoint( OPENAI_EMBEDDINGS_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -100,7 +101,10 @@ public static void main(String[] args) { // Defining tablename and namespace... 
postgresEndpoint = new PostgresEndpoint( - "pg_vectors", "machine-learning", adaEmbedding, new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + "pg_vectors", + "machine-learning", + adaEmbedding, + new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -159,12 +163,7 @@ public void upsert(ArkRequest arkRequest) throws IOException { PostgresRetrieval retrieval = new PostgresRetrieval( - arr, - postgresEndpoint, - 1536, - filename, - PostgresLanguage.ENGLISH, - arkRequest); + arr, postgresEndpoint, 1536, filename, PostgresLanguage.ENGLISH, arkRequest); // retrieval.setBatchSize(50); // Modifying batchSize....(Default is 30) @@ -186,12 +185,7 @@ public ArkResponse query(ArkRequest arkRequest) { EdgeChain> queryChain = new EdgeChain<>( postgresEndpoint.query( - List.of(query), - PostgresDistanceMetric.COSINE, - topK, - topK, - 10, - arkRequest)); + List.of(query), PostgresDistanceMetric.COSINE, topK, topK, 10, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -227,7 +221,8 @@ public ArkResponse chat(ArkRequest arkRequest) { EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(List.of(query), PostgresDistanceMetric.COSINE, topK, topK, arkRequest)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.COSINE, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List // let's say topK=5; then we concatenate List into a string using String.join method @@ -235,7 +230,8 @@ public ArkResponse chat(ArkRequest arkRequest) { new EdgeChain<>(postgresChain) .transform( postgresResponse -> { - List postgresWordEmbeddingsList = postgresResponse.get(); + List postgresWordEmbeddingsList = + postgresResponse.get(); List queryList = new ArrayList<>(); postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); diff --git a/Examples/react-chain/ReactChainApplication.java b/Examples/react-chain/ReactChainApplication.java index 60c159907..d3e612155 100644 --- a/Examples/react-chain/ReactChainApplication.java +++ b/Examples/react-chain/ReactChainApplication.java @@ -1,6 +1,5 @@ package com.edgechain; - import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; diff --git a/Examples/redis/RedisExample.java b/Examples/redis/RedisExample.java index 2185a74a7..919b4cfa7 100644 --- a/Examples/redis/RedisExample.java +++ b/Examples/redis/RedisExample.java @@ -61,15 +61,13 @@ public static void main(String[] args) { // If you want to use PostgreSQL only; then just provide dbHost, dbUsername & dbPassword. // If you haven't specified PostgreSQL, then logs won't be stored. 
- properties.setProperty( - "postgres.db.host", ""); + properties.setProperty("postgres.db.host", ""); properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); - // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port","12285"); + properties.setProperty("redis.port", "12285"); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -98,16 +96,19 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); OpenAiEmbeddingEndpoint ada002Endpoint = - new OpenAiEmbeddingEndpoint( - OPENAI_EMBEDDINGS_API, - OPENAI_AUTH_KEY, - OPENAI_ORG_ID, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); redisEndpoint = new RedisEndpoint( - "vector_index", "machine-learning", ada002Endpoint, new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + "vector_index", + "machine-learning", + ada002Endpoint, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); @@ -166,8 +167,7 @@ public void upsert(ArkRequest arkRequest) throws IOException { * asynchronously...; */ RedisRetrieval retrieval = - new RedisRetrieval( - arr, redisEndpoint, 1536, RedisDistanceMetric.COSINE, arkRequest); + new RedisRetrieval(arr, redisEndpoint, 1536, RedisDistanceMetric.COSINE, arkRequest); retrieval.upsert(); } @@ -184,7 +184,8 @@ public ArkResponse similaritySearch(ArkRequest arkRequest) { int topK = arkRequest.getIntQueryParam("topK"); // Chain 1 ==> Pass those embeddings to Redis & Return Score/values (Similarity search) - EdgeChain> redisQueries = new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); + EdgeChain> redisQueries = + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); return redisQueries.getArkResponse(); } @@ -196,7 +197,8 @@ public ArkResponse queryRedis(ArkRequest arkRequest) { int topK = arkRequest.getIntQueryParam("topK"); // Chain 1 ==> Query Embeddings from Redis - EdgeChain> queryChain = new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); + EdgeChain> queryChain = + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -227,7 +229,6 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1==> Query Embeddings from Redis & Then concatenate it (preparing for prompt) EdgeChain> redisChain = diff --git a/Examples/supabase-miniLM/SupabaseMiniLMExample.java b/Examples/supabase-miniLM/SupabaseMiniLMExample.java index 2a227d7a2..dba35663e 100644 --- a/Examples/supabase-miniLM/SupabaseMiniLMExample.java +++ b/Examples/supabase-miniLM/SupabaseMiniLMExample.java @@ -49,9 +49,9 @@ public class SupabaseMiniLMExample { private static PostgreSQLHistoryContextEndpoint contextEndpoint; private JsonnetLoader queryLoader = - new FileJsonnetLoader("./supabase-miniLM/postgres-query.jsonnet"); + new FileJsonnetLoader("./supabase-miniLM/postgres-query.jsonnet"); private JsonnetLoader chatLoader = - new FileJsonnetLoader("./supabase-miniLM/postgres-chat.jsonnet"); + new 
FileJsonnetLoader("./supabase-miniLM/postgres-chat.jsonnet"); public static void main(String[] args) { @@ -77,7 +77,6 @@ public static void main(String[] args) { properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); - new SpringApplicationBuilder(SupabaseMiniLMExample.class).properties(properties).run(args); gpt3Endpoint = @@ -112,7 +111,10 @@ public static void main(String[] args) { // vectors; postgresEndpoint = new PostgresEndpoint( - "minilm_vectors", "minilm-ns", miniLMEndpoint, new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); + "minilm_vectors", + "minilm-ns", + miniLMEndpoint, + new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -172,12 +174,7 @@ public void upsert(ArkRequest arkRequest) throws IOException { PostgresRetrieval retrieval = new PostgresRetrieval( - arr, - postgresEndpoint, - 384, - filename, - PostgresLanguage.ENGLISH, - arkRequest); + arr, postgresEndpoint, 384, filename, PostgresLanguage.ENGLISH, arkRequest); // retrieval.setBatchSize(50); // Modifying batchSize.... @@ -199,7 +196,8 @@ public ArkResponse queryPostgres(ArkRequest arkRequest) { // Chain 2 ==> Query Embeddings from PostgreSQL EdgeChain> queryChain = new EdgeChain<>( - postgresEndpoint.query(List.of(query), PostgresDistanceMetric.L2, topK, topK,arkRequest)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -237,14 +235,16 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List EdgeChain queryChain = new EdgeChain<>(postgresChain) .transform( postgresResponse -> { - List postgresWordEmbeddingsList = postgresResponse.get(); + List postgresWordEmbeddingsList = + postgresResponse.get(); List queryList = new ArrayList<>(); postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); diff --git a/Examples/zapier/ZapierExample.java b/Examples/zapier/ZapierExample.java index 29c6f3265..f50c93f40 100644 --- a/Examples/zapier/ZapierExample.java +++ b/Examples/zapier/ZapierExample.java @@ -1,6 +1,6 @@ package com.edgechain; -//DEPS com.amazonaws:aws-java-sdk-s3:1.12.554 +// DEPS com.amazonaws:aws-java-sdk-s3:1.12.554 import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicAWSCredentials; @@ -36,7 +36,6 @@ import reactor.util.retry.Retry; import java.io.*; import java.nio.charset.StandardCharsets; -import java.security.GeneralSecurityException; import java.time.Duration; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; @@ -44,40 +43,35 @@ import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; /** - * Objective (1): - * 1. Use Zapier Webhook to pass list of urls (in our example, we have used Wikipedia Urls") ==> Trigger - * 2. Action: Use 'Web Page Parser by Zapier' to extract entire content (which is later used to upsert in Pinecone) from url - * 3. Action: Stringify the JSON - * 4. 
Action: Save each json (parsed content from each URL) to Amazon S3 - * (This process is entirely automated by Zapier). You would need to create Zap for it. Zapier Hook is used to trigger the ETL process, - * (Parallelize Hook Requests)... - * ========================================================== - * 5. Then, we extract each file from Amazon S3 - * 6. Upsert the normalized content to Pinecone with a chunkSize of 512... + * Objective (1): 1. Use Zapier Webhook to pass list of urls (in our example, we have used Wikipedia + * Urls") ==> Trigger 2. Action: Use 'Web Page Parser by Zapier' to extract entire content (which is + * later used to upsert in Pinecone) from url 3. Action: Stringify the JSON 4. Action: Save each + * json (parsed content from each URL) to Amazon S3 (This process is entirely automated by Zapier). + * You would need to create Zap for it. Zapier Hook is used to trigger the ETL process, (Parallelize + * Hook Requests)... ========================================================== 5. Then, we extract + * each file from Amazon S3 6. Upsert the normalized content to Pinecone with a chunkSize of 512... * You can choose any file storage S3, Dropbox, Google Drive etc.... */ /** - * Objective (2): Extracting PDF via PDF4Me. - * Create a Zapier by using the following steps: - * 1. Trigger ==> Integrate Google Drive Folder; when new file is added (it's not instant; it's scheduled internally by Zapier.) - * (You can also trigger it by Webhook as well) - * 2. Action ==> Extract text from PDF using PDF4Me (Free plan allows 20 API calls) - * 3. Action ==> Use ZapierByCode to stringify the json response from PDF4Me - * 4. Action ==> Save it to Amazon S3 - * Now, from EdgeChains SDK we extract the files from S3 & upsert it to Pinecone via Chunk Size 512. + * Objective (2): Extracting PDF via PDF4Me. Create a Zapier by using the following steps: 1. + * Trigger ==> Integrate Google Drive Folder; when new file is added (it's not instant; it's + * scheduled internally by Zapier.) (You can also trigger it by Webhook as well) 2. Action ==> + * Extract text from PDF using PDF4Me (Free plan allows 20 API calls) 3. Action ==> Use ZapierByCode + * to stringify the json response from PDF4Me 4. Action ==> Save it to Amazon S3 Now, from + * EdgeChains SDK we extract the files from S3 & upsert it to Pinecone via Chunk Size 512. */ @SpringBootApplication public class ZapierExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - private static final String ZAPIER_HOOK_URL = ""; // E.g. https://hooks.zapier.com/hooks/catch/18785910/2ia657b + private static final String ZAPIER_HOOK_URL = + ""; // E.g. 
https://hooks.zapier.com/hooks/catch/18785910/2ia657b private static final String PINECONE_AUTH_KEY = ""; private static final String PINECONE_API = ""; - private static PineconeEndpoint pineconeEndpoint; public static void main(String[] args) { @@ -86,7 +80,8 @@ public static void main(String[] args) { new SpringApplicationBuilder(ZapierExample.class).run(args); - OpenAiEmbeddingEndpoint adaEmbedding = new OpenAiEmbeddingEndpoint( + OpenAiEmbeddingEndpoint adaEmbedding = + new OpenAiEmbeddingEndpoint( OPENAI_EMBEDDINGS_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -94,44 +89,41 @@ public static void main(String[] args) { new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); pineconeEndpoint = - new PineconeEndpoint( - PINECONE_API, - PINECONE_AUTH_KEY, - adaEmbedding, - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + new PineconeEndpoint( + PINECONE_API, + PINECONE_AUTH_KEY, + adaEmbedding, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); } - @Bean public AmazonS3 s3Client() { String accessKey = ""; // YOUR_AWS_S3_ACCESS_KEY - String secretKey = ""; // YOUR_AWS_S3_SECRET_KEY + String secretKey = ""; // YOUR_AWS_S3_SECRET_KEY BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey); - return AmazonS3ClientBuilder - .standard() - .withRegion(Regions.fromName("us-east-1")) - .withCredentials(new AWSStaticCredentialsProvider(awsCredentials)) - .build(); + return AmazonS3ClientBuilder.standard() + .withRegion(Regions.fromName("us-east-1")) + .withCredentials(new AWSStaticCredentialsProvider(awsCredentials)) + .build(); } @RestController public class ZapierController { - @Autowired - private AmazonS3 s3Client; + @Autowired private AmazonS3 s3Client; // List of wiki urls triggered in parallel via WebHook. They are automatically parsed, // transformed to JSON and stored in AWS S3. 
/* - Examples: - "https://en.wikipedia.org/wiki/The_Weather_Company", - "https://en.wikipedia.org/wiki/Microsoft_Bing", - "https://en.wikipedia.org/wiki/Quora", - "https://en.wikipedia.org/wiki/Steve_Jobs", - "https://en.wikipedia.org/wiki/Michael_Jordan", - */ + Examples: + "https://en.wikipedia.org/wiki/The_Weather_Company", + "https://en.wikipedia.org/wiki/Microsoft_Bing", + "https://en.wikipedia.org/wiki/Quora", + "https://en.wikipedia.org/wiki/Steve_Jobs", + "https://en.wikipedia.org/wiki/Michael_Jordan", + */ @PostMapping("/etl") public void performETL(ArkRequest arkRequest) { @@ -153,13 +145,13 @@ public void performETL(ArkRequest arkRequest) { @PostMapping("/upsert-urls") public void upsertParsedURLs(ArkRequest arkRequest) throws IOException { - String namespace = arkRequest.getQueryParam("namespace"); - JSONObject body = arkRequest.getBody(); - String bucketName = body.getString("bucketName"); + String namespace = arkRequest.getQueryParam("namespace"); + JSONObject body = arkRequest.getBody(); + String bucketName = body.getString("bucketName"); - // Get all the files from S3 bucket - ListObjectsV2Request listObjectsRequest = new ListObjectsV2Request() - .withBucketName(bucketName); + // Get all the files from S3 bucket + ListObjectsV2Request listObjectsRequest = + new ListObjectsV2Request().withBucketName(bucketName); ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); @@ -179,26 +171,26 @@ public void upsertParsedURLs(ArkRequest arkRequest) throws IOException { System.out.println("Title: " + jsonObject.getString("title")); // Barack Obama System.out.println("Author: " + jsonObject.getString("author")); // Wikipedia Contributers System.out.println("Word Count: " + jsonObject.get("word_count")); // 23077 - System.out.println("Date Published: " + jsonObject.getString("date_published")); // Publish date + System.out.println( + "Date Published: " + jsonObject.getString("date_published")); // Publish date // Now, we extract content and Chunk it by 512 size; then upsert it to Pinecone // Normalize the extracted content.... 
- String normalizedText = Unidecode.decode(jsonObject.getString("content")).replaceAll("[\t\n\r]+", " "); + String normalizedText = + Unidecode.decode(jsonObject.getString("content")).replaceAll("[\t\n\r]+", " "); Chunker chunker = new Chunker(normalizedText); String[] arr = chunker.byChunkSize(512); // Upsert to Pinecone: PineconeRetrieval retrieval = - new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); retrieval.upsert(); - System.out.println("File is parsed: " + key); // For Logging - + System.out.println("File is parsed: " + key); // For Logging } } - } @PostMapping("/upsert-pdfs") @@ -209,8 +201,8 @@ public void upsertPDFs(ArkRequest arkRequest) throws IOException { String bucketName = body.getString("bucketName"); // Get all the files from S3 bucket - ListObjectsV2Request listObjectsRequest = new ListObjectsV2Request() - .withBucketName(bucketName); + ListObjectsV2Request listObjectsRequest = + new ListObjectsV2Request().withBucketName(bucketName); ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); @@ -232,21 +224,20 @@ public void upsertPDFs(ArkRequest arkRequest) throws IOException { System.out.println("Filename: " + jsonObject.getString("filename")); // abcd.pdf System.out.println("Extension: " + jsonObject.getString("file_extension")); // pdf - String normalizedText = Unidecode.decode(jsonObject.getString("text")).replaceAll("[\t\n\r]+", " "); + String normalizedText = + Unidecode.decode(jsonObject.getString("text")).replaceAll("[\t\n\r]+", " "); Chunker chunker = new Chunker(normalizedText); String[] arr = chunker.byChunkSize(512); // Upsert to Pinecone: PineconeRetrieval retrieval = - new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); retrieval.upsert(); - System.out.println("File is parsed: " + key); // For Logging - + System.out.println("File is parsed: " + key); // For Logging } } - } @DeleteMapping("/pinecone/deleteAll") @@ -271,6 +262,5 @@ private void zapWebHook(String url) { .retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(20))) // Using Fixed Delay.. 
.block(); } - } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java index efd348e8b..23714c735 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java @@ -9,7 +9,6 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.context.annotation.Bean; -import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; @SpringBootApplication diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java index 76ff81560..e2d19d6df 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java @@ -24,16 +24,12 @@ public class PineconeRetrieval { private int batchSize = 30; public PineconeRetrieval( - String[] arr, - PineconeEndpoint pineconeEndpoint, - String namespace, - ArkRequest arkRequest) { + String[] arr, PineconeEndpoint pineconeEndpoint, String namespace, ArkRequest arkRequest) { this.pineconeEndpoint = pineconeEndpoint; this.arkRequest = arkRequest; this.arr = arr; this.namespace = namespace; - Logger logger = LoggerFactory.getLogger(getClass()); if (pineconeEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); @@ -62,7 +58,11 @@ public void upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return pineconeEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); + return pineconeEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); } private void executeBatchUpsert(List wordEmbeddingsList) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java index 75a4b0769..48f82c6fe 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java @@ -118,7 +118,11 @@ public List upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return postgresEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); + return postgresEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); } private void upsertAndCollectIds( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java index f9c230a15..35f5d1732 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java @@ -3,7 +3,6 @@ import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import 
com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; @@ -67,7 +66,11 @@ public void upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return redisEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); + return redisEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); } private void executeBatchUpsert(List wordEmbeddingsList) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java index 630d85d10..7d115898e 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java @@ -1,6 +1,5 @@ package com.edgechain.lib.endpoint.impl.embeddings; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.BgeSmallService; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java index 7d9cf991d..609332c4b 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java @@ -2,7 +2,6 @@ import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.RetryPolicy; import com.fasterxml.jackson.annotation.JsonSubTypes; @@ -11,15 +10,11 @@ import java.io.Serializable; -@JsonTypeInfo( - use = JsonTypeInfo.Id.NAME, - include = JsonTypeInfo.As.PROPERTY, - property = "type" -) +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") @JsonSubTypes({ - @JsonSubTypes.Type(value = OpenAiEmbeddingEndpoint.class, name = "type1"), - @JsonSubTypes.Type(value = MiniLMEndpoint.class, name = "type2"), - @JsonSubTypes.Type(value = BgeSmallEndpoint.class, name = "type3"), + @JsonSubTypes.Type(value = OpenAiEmbeddingEndpoint.class, name = "type1"), + @JsonSubTypes.Type(value = MiniLMEndpoint.class, name = "type2"), + @JsonSubTypes.Type(value = BgeSmallEndpoint.class, name = "type3"), }) public abstract class EmbeddingEndpoint extends Endpoint implements Serializable { @@ -28,6 +23,7 @@ public abstract class EmbeddingEndpoint extends Endpoint implements Serializable private String rawText; public EmbeddingEndpoint() {} + public EmbeddingEndpoint(RetryPolicy retryPolicy) { super(retryPolicy); } @@ -65,5 +61,4 @@ public String getRawText() { public String getCallIdentifier() { return callIdentifier; } - } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java 
b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java index 878a89af3..840cfac2c 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java @@ -1,6 +1,5 @@ package com.edgechain.lib.endpoint.impl.embeddings; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; import com.edgechain.lib.request.ArkRequest; @@ -11,8 +10,6 @@ import io.reactivex.rxjava3.core.Observable; import org.modelmapper.ModelMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import retrofit2.Retrofit; public class MiniLMEndpoint extends EmbeddingEndpoint { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java index cb3ed95ed..6dc9dea4f 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java @@ -1,6 +1,5 @@ package com.edgechain.lib.endpoint.impl.embeddings; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.OpenAiService; @@ -14,58 +13,59 @@ public class OpenAiEmbeddingEndpoint extends EmbeddingEndpoint { - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final OpenAiService openAiService = retrofit.create(OpenAiService.class); + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final OpenAiService openAiService = retrofit.create(OpenAiService.class); - private ModelMapper modelMapper = new ModelMapper(); + private ModelMapper modelMapper = new ModelMapper(); - private String orgId; - private String model; + private String orgId; + private String model; - public OpenAiEmbeddingEndpoint() {} + public OpenAiEmbeddingEndpoint() {} - public OpenAiEmbeddingEndpoint(String url, String apiKey, String orgId, String model) { - super(url, apiKey); - this.orgId = orgId; - this.model = model; - } + public OpenAiEmbeddingEndpoint(String url, String apiKey, String orgId, String model) { + super(url, apiKey); + this.orgId = orgId; + this.model = model; + } - public OpenAiEmbeddingEndpoint(String url, String apiKey,String orgId, String model, RetryPolicy retryPolicy ) { - super(url, apiKey, retryPolicy); - this.orgId = orgId; - this.model = model; - } + public OpenAiEmbeddingEndpoint( + String url, String apiKey, String orgId, String model, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.orgId = orgId; + this.model = model; + } - public String getModel() { - return model; - } + public String getModel() { + return model; + } - public String getOrgId() { - return orgId; - } + public String getOrgId() { + return orgId; + } - public void setOrgId(String orgId) { - this.orgId = orgId; - } + public void setOrgId(String orgId) { + this.orgId = orgId; + } - public void setModel(String model) { - this.model = model; - } + public void setModel(String model) { + this.model = model; + } - @Override - 
public Observable embeddings(String input, ArkRequest arkRequest) { + @Override + public Observable embeddings(String input, ArkRequest arkRequest) { - OpenAiEmbeddingEndpoint mapper = modelMapper.map(this, OpenAiEmbeddingEndpoint.class); - mapper.setRawText(input); + OpenAiEmbeddingEndpoint mapper = modelMapper.map(this, OpenAiEmbeddingEndpoint.class); + mapper.setRawText(input); - if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); - else mapper.setCallIdentifier("URI wasn't provided"); + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); - return Observable.fromSingle( - openAiService - .embeddings(mapper) - .map( - embeddingResponse -> - new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); - } + return Observable.fromSingle( + openAiService + .embeddings(mapper) + .map( + embeddingResponse -> + new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java index ecd86ad9f..366a06992 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java @@ -39,27 +39,33 @@ public class PineconeEndpoint extends Endpoint { public PineconeEndpoint() {} - public PineconeEndpoint(String url, String apiKey, EmbeddingEndpoint embeddingEndpoint) { super(url, apiKey); this.originalUrl = url; this.embeddingEndpoint = embeddingEndpoint; } - public PineconeEndpoint(String url, String apiKey, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + public PineconeEndpoint( + String url, String apiKey, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { super(url, apiKey, retryPolicy); this.embeddingEndpoint = embeddingEndpoint; this.originalUrl = url; } - public PineconeEndpoint(String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint) { + public PineconeEndpoint( + String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint) { super(url, apiKey); this.originalUrl = url; this.namespace = namespace; this.embeddingEndpoint = embeddingEndpoint; } - public PineconeEndpoint(String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + public PineconeEndpoint( + String url, + String apiKey, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { super(url, apiKey, retryPolicy); this.originalUrl = url; this.namespace = namespace; @@ -134,8 +140,10 @@ public StringResponse batchUpsert(List wordEmbeddingsList, Strin return this.pineconeService.batchUpsert(mapper).blockingGet(); } - public Observable> query(String query, String namespace, int topK, ArkRequest arkRequest) { - WordEmbeddings wordEmbeddings = new EdgeChain<>(getEmbeddingEndpoint().embeddings(query, arkRequest)).get(); + public Observable> query( + String query, String namespace, int topK, ArkRequest arkRequest) { + WordEmbeddings wordEmbeddings = + new EdgeChain<>(getEmbeddingEndpoint().embeddings(query, arkRequest)).get(); PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); mapper.setWordEmbedding(wordEmbeddings); diff --git 
a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java index 0a1805557..8fcee3b52 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java @@ -1,6 +1,5 @@ package com.edgechain.lib.endpoint.impl.index; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; @@ -24,536 +23,542 @@ public class PostgresEndpoint extends Endpoint { - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final PostgresService postgresService = retrofit.create(PostgresService.class); - private ModelMapper modelMapper = new ModelMapper(); - private String tableName; - private int lists; + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final PostgresService postgresService = retrofit.create(PostgresService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String tableName; + private int lists; - private String id; - private String namespace; + private String id; + private String namespace; - private String filename; + private String filename; - // Getters - private WordEmbeddings wordEmbedding; + // Getters + private WordEmbeddings wordEmbedding; - private List wordEmbeddingsList; + private List wordEmbeddingsList; - private PostgresDistanceMetric metric; - private int dimensions; - private int topK; - private int upperLimit; + private PostgresDistanceMetric metric; + private int dimensions; + private int topK; + private int upperLimit; - private int probes; - private String embeddingChunk; + private int probes; + private String embeddingChunk; - // Fields for metadata table - private List metadataTableNames; - private String metadata; - private String metadataId; - private List metadataList; - private String documentDate; + // Fields for metadata table + private List metadataTableNames; + private String metadata; + private String metadataId; + private List metadataList; + private String documentDate; - /** - * RRF * - */ - private RRFWeight textWeight; + /** RRF * */ + private RRFWeight textWeight; - private RRFWeight similarityWeight; - private RRFWeight dateWeight; + private RRFWeight similarityWeight; + private RRFWeight dateWeight; - private OrderRRFBy orderRRFBy; - private String searchQuery; + private OrderRRFBy orderRRFBy; + private String searchQuery; - private PostgresLanguage postgresLanguage; + private PostgresLanguage postgresLanguage; - // Join Table - private List idList; + // Join Table + private List idList; - private EmbeddingEndpoint embeddingEndpoint; + private EmbeddingEndpoint embeddingEndpoint; - public PostgresEndpoint() { - } + public PostgresEndpoint() {} public PostgresEndpoint(RetryPolicy retryPolicy) { super(retryPolicy); } public PostgresEndpoint(String tableName, EmbeddingEndpoint embeddingEndpoint) { - this.tableName = tableName; - this.embeddingEndpoint = embeddingEndpoint; - } - - public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint) { - this.tableName = tableName; - this.namespace = namespace; - this.embeddingEndpoint = embeddingEndpoint; - } - - public PostgresEndpoint(String 
tableName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { - super(retryPolicy); - this.tableName = tableName; - this.embeddingEndpoint = embeddingEndpoint; - } - - public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { - super(retryPolicy); - this.tableName = tableName; - this.namespace = namespace; - this.embeddingEndpoint = embeddingEndpoint; - } - - - public String getTableName() { - return tableName; - } - - public String getNamespace() { - return namespace; - } - - public EmbeddingEndpoint getEmbeddingEndpoint() { - return embeddingEndpoint; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - public void setMetadata(String metadata) { - this.metadata = metadata; - } - - public void setId(String id) { - this.id = id; - } - - public void setMetadataId(String metadataId) { - this.metadataId = metadataId; - } - - public void setMetadataList(List metadataList) { - this.metadataList = metadataList; - } - - public void setEmbeddingChunk(String embeddingChunk) { - this.embeddingChunk = embeddingChunk; - } - - public int getUpperLimit() { - return upperLimit; - } - - public void setUpperLimit(int upperLimit) { - this.upperLimit = upperLimit; - } - - private void setLists(int lists) { - this.lists = lists; - } - - private void setFilename(String filename) { - this.filename = filename; - } - - private void setWordEmbedding(WordEmbeddings wordEmbedding) { - this.wordEmbedding = wordEmbedding; - } - - private void setWordEmbeddingsList(List wordEmbeddingsList) { - this.wordEmbeddingsList = wordEmbeddingsList; - } - - private void setMetric(PostgresDistanceMetric metric) { - this.metric = metric; - } - - private void setDimensions(int dimensions) { - this.dimensions = dimensions; - } - - private void setTopK(int topK) { - this.topK = topK; - } - - private void setProbes(int probes) { - this.probes = probes; - } - - private void setMetadataTableNames(List metadataTableNames) { - this.metadataTableNames = metadataTableNames; - } - - private void setDocumentDate(String documentDate) { - this.documentDate = documentDate; - } - - private void setTextWeight(RRFWeight textWeight) { - this.textWeight = textWeight; - } + this.tableName = tableName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint) { + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint( + String tableName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint( + String tableName, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public String getTableName() { + return tableName; + } + + public String getNamespace() { + return namespace; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + public void setMetadata(String metadata) { + 
this.metadata = metadata; + } + + public void setId(String id) { + this.id = id; + } + + public void setMetadataId(String metadataId) { + this.metadataId = metadataId; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + + public void setEmbeddingChunk(String embeddingChunk) { + this.embeddingChunk = embeddingChunk; + } + + public int getUpperLimit() { + return upperLimit; + } + + public void setUpperLimit(int upperLimit) { + this.upperLimit = upperLimit; + } + + private void setLists(int lists) { + this.lists = lists; + } + + private void setFilename(String filename) { + this.filename = filename; + } - private void setSimilarityWeight(RRFWeight similarityWeight) { - this.similarityWeight = similarityWeight; - } + private void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + private void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + private void setMetric(PostgresDistanceMetric metric) { + this.metric = metric; + } - private void setDateWeight(RRFWeight dateWeight) { - this.dateWeight = dateWeight; - } + private void setDimensions(int dimensions) { + this.dimensions = dimensions; + } - private void setOrderRRFBy(OrderRRFBy orderRRFBy) { - this.orderRRFBy = orderRRFBy; - } + private void setTopK(int topK) { + this.topK = topK; + } - private void setSearchQuery(String searchQuery) { - this.searchQuery = searchQuery; - } + private void setProbes(int probes) { + this.probes = probes; + } + + private void setMetadataTableNames(List metadataTableNames) { + this.metadataTableNames = metadataTableNames; + } - private void setPostgresLanguage(PostgresLanguage postgresLanguage) { - this.postgresLanguage = postgresLanguage; - } + private void setDocumentDate(String documentDate) { + this.documentDate = documentDate; + } + + private void setTextWeight(RRFWeight textWeight) { + this.textWeight = textWeight; + } + + private void setSimilarityWeight(RRFWeight similarityWeight) { + this.similarityWeight = similarityWeight; + } + + private void setDateWeight(RRFWeight dateWeight) { + this.dateWeight = dateWeight; + } + + private void setOrderRRFBy(OrderRRFBy orderRRFBy) { + this.orderRRFBy = orderRRFBy; + } - private void setIdList(List idList) { - this.idList = idList; - } + private void setSearchQuery(String searchQuery) { + this.searchQuery = searchQuery; + } + + private void setPostgresLanguage(PostgresLanguage postgresLanguage) { + this.postgresLanguage = postgresLanguage; + } + + private void setIdList(List idList) { + this.idList = idList; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + // Getters + + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public int getDimensions() { + return dimensions; + } + + public int getTopK() { + return topK; + } + + public String getFilename() { + return filename; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } - public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { - this.embeddingEndpoint = embeddingEndpoint; - } -// Getters - - public WordEmbeddings getWordEmbedding() { - return wordEmbedding; - } - - public int getDimensions() { - return dimensions; - } - - public int getTopK() { - return topK; - } - - public String getFilename() { - return filename; - } - - public List getWordEmbeddingsList() { - return wordEmbeddingsList; - } - - public 
PostgresDistanceMetric getMetric() { - return metric; - } - - public int getLists() { - return lists; - } - - public int getProbes() { - return probes; - } - - public List getMetadataTableNames() { - return metadataTableNames; - } - - public String getMetadata() { - return metadata; - } - - public String getMetadataId() { - return metadataId; - } - - public String getId() { - return id; - } - - public List getMetadataList() { - return metadataList; - } - - public String getEmbeddingChunk() { - return embeddingChunk; - } - - public String getDocumentDate() { - return documentDate; - } - - public RRFWeight getTextWeight() { - return textWeight; - } - - public RRFWeight getSimilarityWeight() { - return similarityWeight; - } - - public RRFWeight getDateWeight() { - return dateWeight; - } - - public OrderRRFBy getOrderRRFBy() { - return orderRRFBy; - } - - public String getSearchQuery() { - return searchQuery; - } - - - public PostgresLanguage getPostgresLanguage() { - return postgresLanguage; - } - - public List getIdList() { - return idList; - } - - public StringResponse upsert( - WordEmbeddings wordEmbeddings, - String filename, - int dimension, - PostgresDistanceMetric metric) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setWordEmbedding(wordEmbeddings); - mapper.setFilename(filename); - mapper.setDimensions(dimension); - mapper.setMetric(metric); - return this.postgresService.upsert(mapper).blockingGet(); - } - - public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, int lists) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setDimensions(dimensions); - mapper.setMetric(metric); - mapper.setLists(lists); - - return this.postgresService.createTable(mapper).blockingGet(); - } - - public StringResponse createMetadataTable(String metadataTableName) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setMetadataTableNames(List.of(metadataTableName)); - return this.postgresService.createMetadataTable(mapper).blockingGet(); - } - - public List upsert( - List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setWordEmbeddingsList(wordEmbeddingsList); - mapper.setFilename(filename); - mapper.setPostgresLanguage(postgresLanguage); - - return this.postgresService.batchUpsert(mapper).blockingGet(); - } - - public StringResponse insertMetadata( - String metadataTableName, String metadata, String documentDate) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setMetadata(metadata); - mapper.setDocumentDate(documentDate); - mapper.setMetadataTableNames(List.of(metadataTableName)); - return this.postgresService.insertMetadata(mapper).blockingGet(); - } - - public List batchInsertMetadata(List metadataList) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setMetadataList(metadataList); - return this.postgresService.batchInsertMetadata(mapper).blockingGet(); - } - - public StringResponse insertIntoJoinTable( - String metadataTableName, String id, String metadataId) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setId(id); - mapper.setMetadataId(metadataId); - mapper.setMetadataTableNames(List.of(metadataTableName)); - - return this.postgresService.insertIntoJoinTable(mapper).blockingGet(); - } - - public StringResponse batchInsertIntoJoinTable( 
- String metadataTableName, List idList, String metadataId) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setIdList(idList); - mapper.setMetadataId(metadataId); - mapper.setMetadataTableNames(List.of(metadataTableName)); - return this.postgresService.batchInsertIntoJoinTable(mapper).blockingGet(); - } - - public Observable> query( - List inputList, PostgresDistanceMetric metric, int topK, int upperLimit, ArkRequest arkRequest) { - - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - - List endpointEmbeddingList = - Observable.fromIterable(inputList) - .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + public PostgresDistanceMetric getMetric() { + return metric; + } + + public int getLists() { + return lists; + } + + public int getProbes() { + return probes; + } + + public List getMetadataTableNames() { + return metadataTableNames; + } + + public String getMetadata() { + return metadata; + } + + public String getMetadataId() { + return metadataId; + } + + public String getId() { + return id; + } + + public List getMetadataList() { + return metadataList; + } + + public String getEmbeddingChunk() { + return embeddingChunk; + } + + public String getDocumentDate() { + return documentDate; + } + + public RRFWeight getTextWeight() { + return textWeight; + } + + public RRFWeight getSimilarityWeight() { + return similarityWeight; + } + + public RRFWeight getDateWeight() { + return dateWeight; + } + + public OrderRRFBy getOrderRRFBy() { + return orderRRFBy; + } + + public String getSearchQuery() { + return searchQuery; + } + + public PostgresLanguage getPostgresLanguage() { + return postgresLanguage; + } + + public List getIdList() { + return idList; + } + + public StringResponse upsert( + WordEmbeddings wordEmbeddings, + String filename, + int dimension, + PostgresDistanceMetric metric) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setFilename(filename); + mapper.setDimensions(dimension); + mapper.setMetric(metric); + return this.postgresService.upsert(mapper).blockingGet(); + } + + public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, int lists) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setDimensions(dimensions); + mapper.setMetric(metric); + mapper.setLists(lists); + + return this.postgresService.createTable(mapper).blockingGet(); + } + + public StringResponse createMetadataTable(String metadataTableName) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.createMetadataTable(mapper).blockingGet(); + } + + public List upsert( + List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + mapper.setFilename(filename); + mapper.setPostgresLanguage(postgresLanguage); + + return this.postgresService.batchUpsert(mapper).blockingGet(); + } + + public StringResponse insertMetadata( + String metadataTableName, String metadata, String documentDate) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadata(metadata); + mapper.setDocumentDate(documentDate); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return 
this.postgresService.insertMetadata(mapper).blockingGet(); + } + + public List batchInsertMetadata(List metadataList) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataList(metadataList); + return this.postgresService.batchInsertMetadata(mapper).blockingGet(); + } + + public StringResponse insertIntoJoinTable( + String metadataTableName, String id, String metadataId) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setId(id); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + + return this.postgresService.insertIntoJoinTable(mapper).blockingGet(); + } + + public StringResponse batchInsertIntoJoinTable( + String metadataTableName, List idList, String metadataId) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setIdList(idList); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.batchInsertIntoJoinTable(mapper).blockingGet(); + } + + public Observable> query( + List inputList, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + ArkRequest arkRequest) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) .flatMap( - bufferedList -> - Observable.fromIterable(bufferedList) - .flatMap( - res -> - Observable.fromCallable( - () -> - new EdgeChain<>( - embeddingEndpoint.embeddings(res, arkRequest)) - .get()) - .subscribeOn(Schedulers.io()))) - .toList() - .blockingGet(); - - mapper.setWordEmbeddingsList(endpointEmbeddingList); - mapper.setTopK(topK); - mapper.setUpperLimit(upperLimit); - mapper.setMetric(metric); - mapper.setProbes(1); - return Observable.fromSingle(this.postgresService.query(mapper)); - } - - public Observable> query( - List inputList, - PostgresDistanceMetric metric, - int topK, - int upperLimit, - int probes, - ArkRequest arkRequest) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - - List endpointEmbeddingList = - Observable.fromIterable(inputList) - .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> query( + List inputList, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + int probes, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? 
inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) .flatMap( - bufferedList -> - Observable.fromIterable(bufferedList) - .flatMap( - res -> - Observable.fromCallable( - () -> - new EdgeChain<>( - embeddingEndpoint.embeddings(res, arkRequest)) - .get()) - .subscribeOn(Schedulers.io()))) - .toList() - .blockingGet(); - - mapper.setWordEmbeddingsList(endpointEmbeddingList); - mapper.setMetric(metric); - mapper.setProbes(probes); - mapper.setTopK(topK); - mapper.setUpperLimit(upperLimit); - return Observable.fromSingle(this.postgresService.query(mapper)); - } - - public Observable> queryRRF( - String metadataTable, - List inputList, - RRFWeight textWeight, - RRFWeight similarityWeight, - RRFWeight dateWeight, - OrderRRFBy orderRRFBy, - String searchQuery, - PostgresLanguage postgresLanguage, - int probes, - PostgresDistanceMetric metric, - int topK, - int upperLimit, - ArkRequest arkRequest) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setMetadataTableNames(List.of(metadataTable)); - - List endpointEmbeddingList = - Observable.fromIterable(inputList) - .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setMetric(metric); + mapper.setProbes(probes); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> queryRRF( + String metadataTable, + List inputList, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + OrderRRFBy orderRRFBy, + String searchQuery, + PostgresLanguage postgresLanguage, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTable)); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? 
inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) .flatMap( - bufferedList -> - Observable.fromIterable(bufferedList) - .flatMap( - res -> - Observable.fromCallable( - () -> - new EdgeChain<>( - embeddingEndpoint.embeddings(res, arkRequest)) - .get()) - .subscribeOn(Schedulers.io()))) - .toList() - .blockingGet(); - - mapper.setWordEmbeddingsList(endpointEmbeddingList); - mapper.setTextWeight(textWeight); - mapper.setSimilarityWeight(similarityWeight); - mapper.setDateWeight(dateWeight); - mapper.setOrderRRFBy(orderRRFBy); - mapper.setSearchQuery(searchQuery); - mapper.setPostgresLanguage(postgresLanguage); - mapper.setProbes(probes); - mapper.setMetric(metric); - mapper.setTopK(topK); - mapper.setUpperLimit(upperLimit); - return Observable.fromSingle(this.postgresService.queryRRF(mapper)); - } - - public Observable> queryWithMetadata( - List metadataTableNames, - WordEmbeddings wordEmbeddings, - PostgresDistanceMetric metric, - int topK) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setMetadataTableNames(metadataTableNames); - mapper.setWordEmbedding(wordEmbeddings); - mapper.setTopK(topK); - mapper.setMetric(metric); - mapper.setProbes(1); - return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); - } - - public Observable> queryWithMetadata( - List metadataTableNames, - String input, - PostgresDistanceMetric metric, - int topK, - ArkRequest arkRequest) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - - WordEmbeddings wordEmbeddings = new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); - - mapper.setMetadataTableNames(metadataTableNames); - mapper.setWordEmbedding(wordEmbeddings); - mapper.setTopK(topK); - mapper.setMetric(metric); - mapper.setProbes(1); - return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); - } - - public Observable> getSimilarMetadataChunk(String embeddingChunk) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setEmbeddingChunk(embeddingChunk); - return Observable.fromSingle(this.postgresService.getSimilarMetadataChunk(mapper)); - } - - public Observable> getAllChunks(String tableName, String filename) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setTableName(tableName); - mapper.setFilename(filename); - - return Observable.fromSingle(this.postgresService.getAllChunks(this)); - } - - public StringResponse deleteAll(String tableName, String namespace) { - PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); - mapper.setTableName(tableName); - mapper.setNamespace(namespace); - return this.postgresService.deleteAll(mapper).blockingGet(); - } + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTextWeight(textWeight); + mapper.setSimilarityWeight(similarityWeight); + mapper.setDateWeight(dateWeight); + mapper.setOrderRRFBy(orderRRFBy); + mapper.setSearchQuery(searchQuery); + mapper.setPostgresLanguage(postgresLanguage); + mapper.setProbes(probes); + mapper.setMetric(metric); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.queryRRF(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + WordEmbeddings 
wordEmbeddings, + PostgresDistanceMetric metric, + int topK) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + String input, + PostgresDistanceMetric metric, + int topK, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + WordEmbeddings wordEmbeddings = + new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); + + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> getSimilarMetadataChunk(String embeddingChunk) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setEmbeddingChunk(embeddingChunk); + return Observable.fromSingle(this.postgresService.getSimilarMetadataChunk(mapper)); + } + + public Observable> getAllChunks(String tableName, String filename) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setFilename(filename); + + return Observable.fromSingle(this.postgresService.getAllChunks(this)); + } + + public StringResponse deleteAll(String tableName, String namespace) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setNamespace(namespace); + return this.postgresService.deleteAll(mapper).blockingGet(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java index 2cead0c22..dbc8941f8 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java @@ -1,6 +1,5 @@ package com.edgechain.lib.endpoint.impl.index; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; @@ -51,7 +50,8 @@ public RedisEndpoint(String indexName, EmbeddingEndpoint embeddingEndpoint) { this.embeddingEndpoint = embeddingEndpoint; } - public RedisEndpoint(String indexName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + public RedisEndpoint( + String indexName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { super(retryPolicy); this.indexName = indexName; this.embeddingEndpoint = embeddingEndpoint; @@ -63,7 +63,11 @@ public RedisEndpoint(String indexName, String namespace, EmbeddingEndpoint embed this.embeddingEndpoint = embeddingEndpoint; } - public RedisEndpoint(String indexName, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + public RedisEndpoint( + String indexName, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { super(retryPolicy); this.indexName = indexName; this.namespace = namespace; @@ -169,7 +173,8 @@ public 
StringResponse upsert(WordEmbeddings wordEmbedding) { public Observable> query(String input, int topK, ArkRequest arkRequest) { - WordEmbeddings wordEmbedding = new EdgeChain<>(embeddingEndpoint.embeddings(input,arkRequest)).get(); + WordEmbeddings wordEmbedding = + new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); mapper.setTopK(topK); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java index a68ca723c..1ca7f4943 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java @@ -14,163 +14,169 @@ public class AirtableEndpoint extends Endpoint { - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final AirtableService airtableService = retrofit.create(AirtableService.class); - - private ModelMapper modelMapper = new ModelMapper(); - - private String baseId; - private String apiKey; - - private List ids; - private String tableName; - private List airtableRecordList; - private boolean typecast = false; - - private AirtableQueryBuilder airtableQueryBuilder; - - public AirtableEndpoint() {} - - public AirtableEndpoint(String baseId, String apiKey) { - this.baseId = baseId; - this.apiKey = apiKey; - } - - public void setIds(List ids) { - this.ids = ids; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public void setAirtableRecordList(List airtableRecordList) { - this.airtableRecordList = airtableRecordList; - } - - public void setTypecast(boolean typecast) { - this.typecast = typecast; - } - - public void setAirtableQueryBuilder(AirtableQueryBuilder airtableQueryBuilder) { - this.airtableQueryBuilder = airtableQueryBuilder; - } - - public void setBaseId(String baseId) { - this.baseId = baseId; - } - - @Override - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - - public String getBaseId() { - return baseId; - } - - public String getApiKey() { - return apiKey; - } - - public String getTableName() { - return tableName; - } - - public List getIds() { - return ids; - } - - public List getAirtableRecordList() { - return airtableRecordList; - } - - public boolean isTypecast() { - return typecast; - } - - public AirtableQueryBuilder getAirtableQueryBuilder() { - return airtableQueryBuilder; - } - - public Observable> findAll(String tableName, AirtableQueryBuilder builder){ - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class);; - mapper.setTableName(tableName); - mapper.setAirtableQueryBuilder(builder); - return Observable.fromSingle(this.airtableService.findAll(mapper)); - } - public Observable> findAll(String tableName){ - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setTableName(tableName); - mapper.setAirtableQueryBuilder(new AirtableQueryBuilder()); - - return Observable.fromSingle(this.airtableService.findAll(mapper)); - } - - public Observable findById(String tableName, String id){ - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setTableName(tableName); - mapper.setIds(List.of(id)); - return Observable.fromSingle(this.airtableService.findById(mapper)); - } - - public Observable> create(String 
tableName, List airtableRecordList) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(airtableRecordList); - mapper.setTableName(tableName); - return Observable.fromSingle(this.airtableService.create(mapper)); - } - - public Observable> create(String tableName, List airtableRecordList, boolean typecast) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(airtableRecordList); - mapper.setTableName(tableName); - mapper.setTypecast(typecast); - return Observable.fromSingle(this.airtableService.create(mapper)); - } - - public Observable> create(String tableName, AirtableRecord airtableRecord) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(List.of(airtableRecord)); - mapper.setTableName(tableName); - return Observable.fromSingle(this.airtableService.create(mapper)); - } - - public Observable> update(String tableName, List airtableRecordList) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(airtableRecordList); - mapper.setTableName(tableName); - - return Observable.fromSingle(this.airtableService.update(mapper)); - } - - public Observable> update(String tableName, List airtableRecordList, boolean typecast) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(airtableRecordList); - mapper.setTableName(tableName); - mapper.setTypecast(typecast); - - return Observable.fromSingle(this.airtableService.update(mapper)); - } - - public Observable> update(String tableName, AirtableRecord airtableRecord) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setAirtableRecordList(List.of(airtableRecord)); - mapper.setTableName(tableName); - return Observable.fromSingle(this.airtableService.update(mapper)); - } - - public Observable> delete(String tableName, List ids) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setIds(ids); - mapper.setTableName(tableName); - return Observable.fromSingle(this.airtableService.delete(mapper)); - } - - public Observable> delete(String tableName, String id) { - AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); - mapper.setIds(List.of(id)); - mapper.setTableName(tableName); - return Observable.fromSingle(this.airtableService.delete(mapper)); - } + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final AirtableService airtableService = retrofit.create(AirtableService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String baseId; + private String apiKey; + + private List ids; + private String tableName; + private List airtableRecordList; + private boolean typecast = false; + + private AirtableQueryBuilder airtableQueryBuilder; + + public AirtableEndpoint() {} + + public AirtableEndpoint(String baseId, String apiKey) { + this.baseId = baseId; + this.apiKey = apiKey; + } + + public void setIds(List ids) { + this.ids = ids; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setAirtableRecordList(List airtableRecordList) { + this.airtableRecordList = airtableRecordList; + } + + public void setTypecast(boolean typecast) { + this.typecast = typecast; + } + + public void setAirtableQueryBuilder(AirtableQueryBuilder airtableQueryBuilder) { + this.airtableQueryBuilder = airtableQueryBuilder; + } + 
+ public void setBaseId(String baseId) { + this.baseId = baseId; + } + + @Override + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseId() { + return baseId; + } + + public String getApiKey() { + return apiKey; + } + + public String getTableName() { + return tableName; + } + + public List getIds() { + return ids; + } + + public List getAirtableRecordList() { + return airtableRecordList; + } + + public boolean isTypecast() { + return typecast; + } + + public AirtableQueryBuilder getAirtableQueryBuilder() { + return airtableQueryBuilder; + } + + public Observable> findAll(String tableName, AirtableQueryBuilder builder) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + ; + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(builder); + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + + public Observable> findAll(String tableName) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(new AirtableQueryBuilder()); + + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + + public Observable findById(String tableName, String id) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setIds(List.of(id)); + return Observable.fromSingle(this.airtableService.findById(mapper)); + } + + public Observable> create( + String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create( + String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> update( + String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update( + String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> delete(String tableName, List ids) { + AirtableEndpoint mapper = 
modelMapper.map(this, AirtableEndpoint.class); + mapper.setIds(ids); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } + + public Observable> delete(String tableName, String id) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setIds(List.of(id)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java index 50350f4e8..19d571ef8 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java @@ -56,7 +56,6 @@ public class OpenAiChatEndpoint extends Endpoint { public OpenAiChatEndpoint() {} - public OpenAiChatEndpoint(String url, String apiKey, String model) { super(url, apiKey, null); this.model = model; @@ -283,7 +282,7 @@ public Observable chatCompletion( mapper.setChatMessages(List.of(new ChatMessage(this.role, input))); mapper.setChainName(chainName); - return chatCompletion(mapper,arkRequest); + return chatCompletion(mapper, arkRequest); } public Observable chatCompletion( @@ -294,7 +293,7 @@ public Observable chatCompletion( mapper.setChainName(chainName); mapper.setJsonnetLoader(loader); - return chatCompletion(mapper,arkRequest); + return chatCompletion(mapper, arkRequest); } public Observable chatCompletion( @@ -302,7 +301,7 @@ public Observable chatCompletion( OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); mapper.setChatMessages(chatMessages); mapper.setChainName(chainName); - return chatCompletion(mapper,arkRequest); + return chatCompletion(mapper, arkRequest); } public Observable chatCompletion( @@ -319,7 +318,8 @@ public Observable chatCompletion( return chatCompletion(mapper, arkRequest); } - private Observable chatCompletion(OpenAiChatEndpoint mapper, ArkRequest arkRequest) { + private Observable chatCompletion( + OpenAiChatEndpoint mapper, ArkRequest arkRequest) { if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); else mapper.setCallIdentifier("URI wasn't provided"); @@ -341,7 +341,7 @@ public Observable completion(String input, ArkRequest arkReq if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); else this.callIdentifier = "URI wasn't provided"; - this.input = input; + this.input = input; return Observable.fromSingle(this.openAiService.completion(this)); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java index 815834b7c..073ba6769 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java @@ -240,7 +240,7 @@ public EdgeChain> query(PostgresEndpoint postgresEn postgresEndpoint.getMetric(), embeddings, postgresEndpoint.getTopK(), - postgresEndpoint.getUpperLimit()); + postgresEndpoint.getUpperLimit()); for (Map row : rows) { @@ -285,9 +285,9 @@ public EdgeChain> queryRRF(PostgresEndpoint postgre try { List wordEmbeddingsList = new ArrayList<>(); List> embeddings = - 
postgresEndpoint.getWordEmbeddingsList().stream() - .map(WordEmbeddings::getValues) - .toList(); + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); List> rows = this.repository.queryRRF( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java index 080b176fe..cd2263fe8 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java @@ -6,7 +6,6 @@ import java.time.LocalDateTime; import java.util.List; -import java.util.StringJoiner; public class PostgresWordEmbeddings implements ArkObject { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java index c8b08b0dc..a626c5741 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java @@ -234,8 +234,10 @@ public List> query( } if (values.size() > 1) { - return jdbcTemplate.queryForList(String.format( - "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER BY score DESC LIMIT %s;", + return jdbcTemplate.queryForList( + String.format( + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER" + + " BY score DESC LIMIT %s;", query, upperLimit)); } else { return jdbcTemplate.queryForList(query.toString()); @@ -243,108 +245,111 @@ public List> query( } public List> queryRRF( - String tableName, - String namespace, - String metadataTableName, - List> values, - RRFWeight textWeight, - RRFWeight similarityWeight, - RRFWeight dateWeight, - String searchQuery, - PostgresLanguage language, - int probes, - PostgresDistanceMetric metric, - int topK, - int upperLimit, - OrderRRFBy orderRRFBy) { + String tableName, + String namespace, + String metadataTableName, + List> values, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + String searchQuery, + PostgresLanguage language, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + OrderRRFBy orderRRFBy) { jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); StringBuilder query = new StringBuilder(); - for(int i = 0; i < values.size(); i++) { + for (int i = 0; i < values.size(); i++) { String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); - query .append("(") - .append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", - textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", - similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", - dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) - .append("FROM ( ") - .append( - "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," - + " 
svtm.document_date, svtm.metadata, ") - .append( - String.format( - "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", - language.getValue(), searchQuery)); + query + .append("(") + .append( + "SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", + textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", + similarityWeight.getBaseWeight().getValue(), + similarityWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", + dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) + .append("FROM ( ") + .append( + "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," + + " svtm.document_date, svtm.metadata, ") + .append( + String.format( + "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", + language.getValue(), searchQuery)); switch (metric) { case COSINE -> query.append( - String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); + String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); case IP -> query.append( - String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); + String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings)); default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); } query - .append("CASE ") - .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling - .append( - "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" - + " svtm.document_date) ") - .append("END AS date_rank ") - .append("FROM ") - .append( - String.format( - "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE" - + " namespace = '%s'", - tableName, namespace)); + .append("CASE ") + .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling + .append( + "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" + + " svtm.document_date) ") + .append("END AS date_rank ") + .append("FROM ") + .append( + String.format( + "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s" + + " WHERE namespace = '%s'", + tableName, namespace)); switch (metric) { case COSINE -> query - .append(" ORDER BY embedding <=> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case IP -> query - .append(" ORDER BY embedding <#> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case L2 -> query - .append(" ORDER BY embedding <-> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); default -> throw new IllegalArgumentException("Invalid metric: " + metric); } query - .append(")") - .append(" sv ") - .append("JOIN ") - 
.append(tableName.concat("_join_").concat(metadataTableName)) - .append(" jtm ON sv.id = jtm.id ") - .append("JOIN ") - .append(tableName.concat("_").concat(metadataTableName)) - .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") - .append(") subquery "); + .append(")") + .append(" sv ") + .append("JOIN ") + .append(tableName.concat("_join_").concat(metadataTableName)) + .append(" jtm ON sv.id = jtm.id ") + .append("JOIN ") + .append(tableName.concat("_").concat(metadataTableName)) + .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") + .append(") subquery "); switch (orderRRFBy) { case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC"); @@ -358,12 +363,13 @@ public List> queryRRF( if (i < values.size() - 1) { query.append(" UNION ALL ").append("\n"); } - } if (values.size() > 1) { - return jdbcTemplate.queryForList(String.format( - "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER BY rrf_score DESC LIMIT %s;", + return jdbcTemplate.queryForList( + String.format( + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER" + + " BY rrf_score DESC LIMIT %s;", query, upperLimit)); } else { return jdbcTemplate.queryForList(query.toString()); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java index 377660442..7e6945568 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java @@ -10,118 +10,121 @@ import java.util.Locale; public class AirtableQueryBuilder { - private String offset; - private List fields; - private String filterByFormula; - private int maxRecords = 100; - private int pageSize = 100; - private String sortField; - private String sortDirection; - private String view; - private String cellFormat; - private String timeZone; - private String userLocale; - - public void offset(String offset) { - this.offset = offset; - } - - public void fields(String... fields) { - this.fields = Arrays.asList(fields); - } - - public void filterByFormula(String formula) { - this.filterByFormula = formula; - } - - public void filterByFormula(AirtableFunction function, AirtableFormula.Object... objects){ - this.filterByFormula = function.apply(objects); - } - - public void filterByFormula(AirtableOperator operator, AirtableFormula.Object left, AirtableFormula.Object right, AirtableFormula.Object... 
others) { - this.filterByFormula = operator.apply(left, right, others); - } - - public void maxRecords(int maxRecords) { - this.maxRecords = maxRecords; - } - - public void pageSize(int pageSize) { - this.pageSize = pageSize; - } - - public void sort(String field, String direction) { - this.sortField = field; - this.sortDirection = direction; - } - - public void view(String view) { - this.view = view; - } - - public void cellFormat(String cellFormat) { - this.cellFormat = cellFormat; - } - - public void timeZone(String timeZone) { - this.timeZone = timeZone; - } - - public void timeZone(ZoneId zoneId) { - this.timeZone = zoneId.getId(); - } - - public void userLocale(String userLocale){ - this.userLocale = userLocale; - } - - public void userLocale(Locale locale) { - this.userLocale = locale.toLanguageTag().toLowerCase(); - } - - // Getters for QuerySpec fields - public String getOffset() { - return offset; - } - - public List getFields() { - return fields; - } - - public String getFilterByFormula() { - return filterByFormula; - } - - public int getMaxRecords() { - return maxRecords; - } - - public int getPageSize() { - return pageSize; - } - - public String getSortField() { - return sortField; - } - - public String getSortDirection() { - return sortDirection; - } - - public String getView() { - return view; - } - - public String getCellFormat() { - return cellFormat; - } - - public String getTimeZone() { - return timeZone; - } - - public String getUserLocale() { - return userLocale; - } - -} \ No newline at end of file + private String offset; + private List fields; + private String filterByFormula; + private int maxRecords = 100; + private int pageSize = 100; + private String sortField; + private String sortDirection; + private String view; + private String cellFormat; + private String timeZone; + private String userLocale; + + public void offset(String offset) { + this.offset = offset; + } + + public void fields(String... fields) { + this.fields = Arrays.asList(fields); + } + + public void filterByFormula(String formula) { + this.filterByFormula = formula; + } + + public void filterByFormula(AirtableFunction function, AirtableFormula.Object... objects) { + this.filterByFormula = function.apply(objects); + } + + public void filterByFormula( + AirtableOperator operator, + AirtableFormula.Object left, + AirtableFormula.Object right, + AirtableFormula.Object... 
others) { + this.filterByFormula = operator.apply(left, right, others); + } + + public void maxRecords(int maxRecords) { + this.maxRecords = maxRecords; + } + + public void pageSize(int pageSize) { + this.pageSize = pageSize; + } + + public void sort(String field, String direction) { + this.sortField = field; + this.sortDirection = direction; + } + + public void view(String view) { + this.view = view; + } + + public void cellFormat(String cellFormat) { + this.cellFormat = cellFormat; + } + + public void timeZone(String timeZone) { + this.timeZone = timeZone; + } + + public void timeZone(ZoneId zoneId) { + this.timeZone = zoneId.getId(); + } + + public void userLocale(String userLocale) { + this.userLocale = userLocale; + } + + public void userLocale(Locale locale) { + this.userLocale = locale.toLanguageTag().toLowerCase(); + } + + // Getters for QuerySpec fields + public String getOffset() { + return offset; + } + + public List getFields() { + return fields; + } + + public String getFilterByFormula() { + return filterByFormula; + } + + public int getMaxRecords() { + return maxRecords; + } + + public int getPageSize() { + return pageSize; + } + + public String getSortField() { + return sortField; + } + + public String getSortDirection() { + return sortDirection; + } + + public String getView() { + return view; + } + + public String getCellFormat() { + return cellFormat; + } + + public String getTimeZone() { + return timeZone; + } + + public String getUserLocale() { + return userLocale; + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java index 7dadae115..9b203f024 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java @@ -1,25 +1,25 @@ package com.edgechain.lib.integration.airtable.query; public enum SortOrder { - ASC("asc"), - DESC("desc"); + ASC("asc"), + DESC("desc"); - private final String value; + private final String value; - SortOrder(String value) { - this.value = value; - } + SortOrder(String value) { + this.value = value; + } - public String getValue() { - return value; - } + public String getValue() { + return value; + } - public static SortOrder fromValue(String value) { - for (SortOrder sortOrder : SortOrder.values()) { - if (sortOrder.value.equalsIgnoreCase(value)) { - return sortOrder; - } - } - throw new IllegalArgumentException("Invalid SortOrder value: " + value); + public static SortOrder fromValue(String value) { + for (SortOrder sortOrder : SortOrder.values()) { + if (sortOrder.value.equalsIgnoreCase(value)) { + return sortOrder; + } } -} \ No newline at end of file + throw new IllegalArgumentException("Invalid SortOrder value: " + value); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java index 7dbb647d0..c9fcf20eb 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java @@ -47,7 +47,7 @@ public JsonnetLoader(int threshold, String f1, String f2) { if (threshold >= 1 && threshold < 100) { this.threshold = threshold; this.splitSize = - 
String.valueOf(threshold).concat("-").concat(String.valueOf((100 - threshold))); + String.valueOf(threshold).concat("-").concat(String.valueOf((100 - threshold))); } else throw new RuntimeException("Threshold has to be b/w 1 and 100"); } @@ -70,7 +70,7 @@ public void load(InputStream inputStream) { // Create Temp File With Unique Name String filename = - RandomStringUtils.randomAlphanumeric(12) + "_" + System.currentTimeMillis() + ".jsonnet"; + RandomStringUtils.randomAlphanumeric(12) + "_" + System.currentTimeMillis() + ".jsonnet"; File file = new File(System.getProperty("java.io.tmpdir") + File.separator + filename); BufferedReader br = new BufferedReader(new InputStreamReader(inputStream)); @@ -98,16 +98,16 @@ public void load(InputStream inputStream) { xtraArgsMap.put(entry.getKey(), entry.getValue().getVal().replaceAll(regex, "")); } else if (entry.getValue().getDataType().equals(DataType.INTEGER) - || entry.getValue().getDataType().equals(DataType.BOOLEAN)) { + || entry.getValue().getDataType().equals(DataType.BOOLEAN)) { xtraArgsMap.put(entry.getKey(), entry.getValue().getVal()); } } var res = - Transformer.builder(text) - .withLibrary(new XtraSonnetCustomFunc()) - .build() - .transform(serializeXtraArgs(xtraArgsMap)); + Transformer.builder(text) + .withLibrary(new XtraSonnetCustomFunc()) + .build() + .transform(serializeXtraArgs(xtraArgsMap)); // Get the String Output & Transform it into JsonnetSchema this.metadata = res; @@ -191,4 +191,4 @@ public int getThreshold() { public String getSplitSize() { return splitSize; } -} \ No newline at end of file +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java index 56e0975f4..fffd19dab 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java @@ -25,130 +25,130 @@ @Service public class OpenAiClient { - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final RestTemplate restTemplate = new RestTemplate(); - - public EdgeChain createChatCompletion( - ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { - - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - - logger.info("Logging ChatCompletion...."); - - // Create headers - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - logger.info(String.valueOf(entity.getBody())); - - // Send the POST request - ResponseEntity response = - restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, ChatCompletionResponse.class); - - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } - - public EdgeChain createChatCompletionStream( - ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { - - try { - logger.info("Logging ChatCompletion Stream...."); - logger.info(request.toString()); - - return new EdgeChain<>( - RxJava3Adapter.fluxToObservable( - WebClient.builder() - .build() - .post() - .uri(EndpointConstants.OPENAI_CHAT_COMPLETION_API) - 
.accept(MediaType.TEXT_EVENT_STREAM) - .headers( - httpHeaders -> { - httpHeaders.setContentType(MediaType.APPLICATION_JSON); - httpHeaders.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) - && !endpoint.getOrgId().isEmpty()) { - httpHeaders.set("OpenAI-Organization", endpoint.getOrgId()); - } - }) - .bodyValue(new ObjectMapper().writeValueAsString(request)) - .retrieve() - .bodyToFlux(ChatCompletionResponse.class)), - endpoint); - } catch (final Exception e) { - throw new RuntimeException(e); - } - } - - public EdgeChain createCompletion( - CompletionRequest request, OpenAiChatEndpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - ResponseEntity response = - this.restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, CompletionResponse.class); - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } - - public EdgeChain createEmbeddings( - OpenAiEmbeddingRequest request, OpenAiEmbeddingEndpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - ResponseEntity response = - this.restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, OpenAiEmbeddingResponse.class); - - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); + + public EdgeChain createChatCompletion( + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + logger.info("Logging ChatCompletion...."); + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + logger.info(String.valueOf(entity.getBody())); + + // Send the POST request + ResponseEntity response = + restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, ChatCompletionResponse.class); + + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createChatCompletionStream( + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { + + try { + logger.info("Logging ChatCompletion Stream...."); + logger.info(request.toString()); + + return new EdgeChain<>( + 
RxJava3Adapter.fluxToObservable( + WebClient.builder() + .build() + .post() + .uri(EndpointConstants.OPENAI_CHAT_COMPLETION_API) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers( + httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) + && !endpoint.getOrgId().isEmpty()) { + httpHeaders.set("OpenAI-Organization", endpoint.getOrgId()); + } + }) + .bodyValue(new ObjectMapper().writeValueAsString(request)) + .retrieve() + .bodyToFlux(ChatCompletionResponse.class)), + endpoint); + } catch (final Exception e) { + throw new RuntimeException(e); } -} \ No newline at end of file + } + + public EdgeChain createCompletion( + CompletionRequest request, OpenAiChatEndpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + ResponseEntity response = + this.restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, CompletionResponse.class); + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createEmbeddings( + OpenAiEmbeddingRequest request, OpenAiEmbeddingEndpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + ResponseEntity response = + this.restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, OpenAiEmbeddingResponse.class); + + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java index c3a8db88a..dc5744a9d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java @@ -12,22 +12,18 @@ public interface AirtableService { - @POST("airtable/findAll") - Single> findAll(@Body AirtableEndpoint endpoint); + @POST("airtable/findAll") + Single> findAll(@Body AirtableEndpoint endpoint); - @POST("airtable/findById") - Single findById(@Body AirtableEndpoint endpoint); + @POST("airtable/findById") + Single findById(@Body AirtableEndpoint endpoint); - @POST("airtable/create") - Single> create(@Body AirtableEndpoint endpoint); - - @POST("airtable/update") - Single> update(@Body AirtableEndpoint endpoint); - - @HTTP(method = "DELETE", path = "airtable/delete", hasBody = true) - Single> delete(@Body AirtableEndpoint endpoint); + @POST("airtable/create") + Single> create(@Body AirtableEndpoint endpoint); + @POST("airtable/update") + Single> update(@Body 
AirtableEndpoint endpoint); + @HTTP(method = "DELETE", path = "airtable/delete", hasBody = true) + Single> delete(@Body AirtableEndpoint endpoint); } - - diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java index a1366dd2f..4392bb348 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java @@ -43,6 +43,4 @@ public Observable chatCompletion(OpenAiChatEndpoint endp .retrieve() .bodyToFlux(ChatCompletionResponse.class)); } - - } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java index e5f02d3c3..7e3f42cb6 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java @@ -172,12 +172,10 @@ public Single toSingle() { if (RetryUtils.available(endpoint)) return this.observable - .subscribeOn(Schedulers.io()) - .retryWhen(endpoint.getRetryPolicy()) - .firstOrError(); - + .subscribeOn(Schedulers.io()) + .retryWhen(endpoint.getRetryPolicy()) + .firstOrError(); else return this.observable.subscribeOn(Schedulers.io()).firstOrError(); - } public Single toSingleWithoutScheduler() { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java index 4f441b9de..a3fb83282 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java @@ -11,37 +11,40 @@ @Component public class ContextReorder { - public List reorderWordEmbeddings(List wordEmbeddingsList) { + public List reorderWordEmbeddings(List wordEmbeddingsList) { - wordEmbeddingsList.sort(Comparator.comparingDouble(WordEmbeddings::getScore).reversed()); + wordEmbeddingsList.sort(Comparator.comparingDouble(WordEmbeddings::getScore).reversed()); - int mid = wordEmbeddingsList.size() / 2; + int mid = wordEmbeddingsList.size() / 2; - List modifiedList = new ArrayList<>(wordEmbeddingsList.subList(0, mid)); + List modifiedList = new ArrayList<>(wordEmbeddingsList.subList(0, mid)); - List secondHalfList = wordEmbeddingsList.subList(mid, wordEmbeddingsList.size()); - secondHalfList.sort(Comparator.comparingDouble(WordEmbeddings::getScore)); + List secondHalfList = + wordEmbeddingsList.subList(mid, wordEmbeddingsList.size()); + secondHalfList.sort(Comparator.comparingDouble(WordEmbeddings::getScore)); - modifiedList.addAll(secondHalfList); + modifiedList.addAll(secondHalfList); - return modifiedList; - } + return modifiedList; + } - public List reorderPostgresWordEmbeddings(List postgresWordEmbeddings) { + public List reorderPostgresWordEmbeddings( + List postgresWordEmbeddings) { - postgresWordEmbeddings.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore).reversed()); + postgresWordEmbeddings.sort( + Comparator.comparingDouble(PostgresWordEmbeddings::getScore).reversed()); - int mid = postgresWordEmbeddings.size() / 2; + int mid = postgresWordEmbeddings.size() / 2; - List modifiedList = new 
ArrayList<>( postgresWordEmbeddings.subList(0, mid)); + List modifiedList = + new ArrayList<>(postgresWordEmbeddings.subList(0, mid)); - List secondHalfList = postgresWordEmbeddings.subList(mid, postgresWordEmbeddings.size()); - secondHalfList.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore)); + List secondHalfList = + postgresWordEmbeddings.subList(mid, postgresWordEmbeddings.size()); + secondHalfList.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore)); - modifiedList.addAll(secondHalfList); - - return modifiedList; - } + modifiedList.addAll(secondHalfList); + return modifiedList; + } } - diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java index e488f7ed6..98d3cd9f3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java @@ -15,26 +15,25 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/pinecone") public class PineconeController { - @Autowired - private PineconeClient pineconeClient; - - @PostMapping("/upsert") - public Single upsert(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.upsert(pineconeEndpoint).toSingle(); - } - - @PostMapping("/batch-upsert") - public Single batchUpsert(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.batchUpsert(pineconeEndpoint).toSingleWithoutScheduler(); - } - - @PostMapping("/query") - public Single> query(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.query(pineconeEndpoint).toSingle(); - } - - @DeleteMapping("/deleteAll") - public Single deleteAll(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.deleteAll(pineconeEndpoint).toSingle(); - } + @Autowired private PineconeClient pineconeClient; + + @PostMapping("/upsert") + public Single upsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.upsert(pineconeEndpoint).toSingle(); + } + + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.batchUpsert(pineconeEndpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/query") + public Single> query(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.query(pineconeEndpoint).toSingle(); + } + + @DeleteMapping("/deleteAll") + public Single deleteAll(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.deleteAll(pineconeEndpoint).toSingle(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java index 5b034468d..a83f48db3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java @@ -18,89 +18,87 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/postgres") public class PostgresController { - @Autowired - @Lazy - private PostgresClient postgresClient; - - @PostMapping("/create-table") - public Single createTable(@RequestBody PostgresEndpoint postgresEndpoint) { - return 
this.postgresClient.createTable(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/create-table") - public Single createMetadataTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.createMetadataTable(postgresEndpoint).toSingle(); - } - - @PostMapping("/upsert") - public Single upsert(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.upsert(postgresEndpoint).toSingle(); - } - - @PostMapping("/batch-upsert") - public Single> batchUpsert(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.batchUpsert(postgresEndpoint).toSingleWithoutScheduler(); - } - - @PostMapping("/metadata/insert") - public Single insertMetadata(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.insertMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/batch-insert") - public Single> batchInsertMetadata( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.batchInsertMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/join/insert") - public Single insertIntoJoinTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - EdgeChain edgeChain = this.postgresClient.insertIntoJoinTable(postgresEndpoint); - return edgeChain.toSingle(); - } - - @PostMapping("/join/batch-insert") - public Single batchInsertIntoJoinTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - EdgeChain edgeChain = - this.postgresClient.batchInsertIntoJoinTable(postgresEndpoint); - return edgeChain.toSingle(); - } - - @PostMapping("/query") - public Single> query( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.query(postgresEndpoint).toSingle(); - } - - @PostMapping("/query-rrf") - public Single> queryRRF( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/query") - public Single> queryWithMetadata( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.queryWithMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/chunks") - public Single> getAllChunks( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.getAllChunks(postgresEndpoint).toSingle(); - } - - @PostMapping("/similarity-metadata") - public Single> getSimilarMetadataChunk( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.getSimilarMetadataChunk(postgresEndpoint).toSingle(); - } - - @DeleteMapping("/deleteAll") - public Single deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.deleteAll(postgresEndpoint).toSingle(); - } + @Autowired @Lazy private PostgresClient postgresClient; + + @PostMapping("/create-table") + public Single createTable(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createTable(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/create-table") + public Single createMetadataTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createMetadataTable(postgresEndpoint).toSingle(); + } + + @PostMapping("/upsert") + public Single upsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.upsert(postgresEndpoint).toSingle(); + } + + @PostMapping("/batch-upsert") + public Single> batchUpsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchUpsert(postgresEndpoint).toSingleWithoutScheduler(); + } + + 
@PostMapping("/metadata/insert") + public Single insertMetadata(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.insertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/batch-insert") + public Single> batchInsertMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchInsertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/join/insert") + public Single insertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = this.postgresClient.insertIntoJoinTable(postgresEndpoint); + return edgeChain.toSingle(); + } + + @PostMapping("/join/batch-insert") + public Single batchInsertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = + this.postgresClient.batchInsertIntoJoinTable(postgresEndpoint); + return edgeChain.toSingle(); + } + + @PostMapping("/query") + public Single> query( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.query(postgresEndpoint).toSingle(); + } + + @PostMapping("/query-rrf") + public Single> queryRRF( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/query") + public Single> queryWithMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryWithMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/chunks") + public Single> getAllChunks( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getAllChunks(postgresEndpoint).toSingle(); + } + + @PostMapping("/similarity-metadata") + public Single> getSimilarMetadataChunk( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getSimilarMetadataChunk(postgresEndpoint).toSingle(); + } + + @DeleteMapping("/deleteAll") + public Single deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.deleteAll(postgresEndpoint).toSingle(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java index 4090164b7..1d4916f38 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java @@ -17,32 +17,30 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/redis") public class RedisController { - @Autowired - @Lazy - private RedisClient redisClient; - - @PostMapping("/create-index") - public Single createIndex(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.createIndex(redisEndpoint).toSingle(); - } - - @PostMapping("/upsert") - public Single upsert(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.upsert(redisEndpoint).toSingle(); - } - - @PostMapping("/batch-upsert") - public Single batchUpsert(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.batchUpsert(redisEndpoint).toSingleWithoutScheduler(); - } - - @PostMapping("/query") - public Single> query(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.query(redisEndpoint).toSingle(); - } - - @DeleteMapping("/delete") - public Completable deleteByPattern(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.deleteByPattern(redisEndpoint).await(); - 
} + @Autowired @Lazy private RedisClient redisClient; + + @PostMapping("/create-index") + public Single createIndex(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.createIndex(redisEndpoint).toSingle(); + } + + @PostMapping("/upsert") + public Single upsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.upsert(redisEndpoint).toSingle(); + } + + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.batchUpsert(redisEndpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/query") + public Single> query(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.query(redisEndpoint).toSingle(); + } + + @DeleteMapping("/delete") + public Completable deleteByPattern(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.deleteByPattern(redisEndpoint).await(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java index 835dea9f0..741fe2362 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java @@ -15,28 +15,30 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/airtable") public class AirtableController { - @Autowired private AirtableClient airtableClient; - - @PostMapping("/findAll") - public Single> findAll(@RequestBody AirtableEndpoint endpoint) { - return airtableClient.findAll(endpoint).toSingleWithoutScheduler(); - } - @PostMapping("/findById") - public Single findById(@RequestBody AirtableEndpoint endpoint) { - return airtableClient.findById(endpoint).toSingleWithoutScheduler(); - } - @PostMapping("/create") - public Single> create(@RequestBody AirtableEndpoint endpoint) { - return airtableClient.create(endpoint).toSingleWithoutScheduler(); - } - @PostMapping("/update") - public Single> update(@RequestBody AirtableEndpoint endpoint) { - return airtableClient.update(endpoint).toSingleWithoutScheduler(); - } - @DeleteMapping("/delete") - public Single> delete(@RequestBody AirtableEndpoint endpoint) { - return airtableClient.delete(endpoint).toSingleWithoutScheduler(); - } + @Autowired private AirtableClient airtableClient; + @PostMapping("/findAll") + public Single> findAll(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findAll(endpoint).toSingleWithoutScheduler(); + } + @PostMapping("/findById") + public Single findById(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findById(endpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/create") + public Single> create(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.create(endpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/update") + public Single> update(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.update(endpoint).toSingleWithoutScheduler(); + } + + @DeleteMapping("/delete") + public Single> delete(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.delete(endpoint).toSingleWithoutScheduler(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java index 
d9fa9d1fe..6d3c6f8a1 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java @@ -43,259 +43,260 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/openai") public class OpenAiController { - @Autowired private ChatCompletionLogService chatCompletionLogService; - @Autowired private EmbeddingLogService embeddingLogService; - @Autowired private JsonnetLogService jsonnetLogService; - - @Autowired private Environment env; - @Autowired private OpenAiClient openAiClient; - - @PostMapping(value = "/chat-completion") - public Single chatCompletion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { - - ChatCompletionRequest chatCompletionRequest = - ChatCompletionRequest.builder() - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .messages(openAiEndpoint.getChatMessages()) - .stream(false) - .topP(openAiEndpoint.getTopP()) - .n(openAiEndpoint.getN()) - .stop(openAiEndpoint.getStop()) - .presencePenalty(openAiEndpoint.getPresencePenalty()) - .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) - .logitBias(openAiEndpoint.getLogitBias()) - .user(openAiEndpoint.getUser()) - .build(); - - EdgeChain edgeChain = - openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - ChatCompletionLog chatLog = new ChatCompletionLog(); - chatLog.setName(openAiEndpoint.getChainName()); - chatLog.setCreatedAt(LocalDateTime.now()); - chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); - chatLog.setModel(chatCompletionRequest.getModel()); - - chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); - chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); - chatLog.setTopP(chatCompletionRequest.getTopP()); - chatLog.setN(chatCompletionRequest.getN()); - chatLog.setTemperature(chatCompletionRequest.getTemperature()); - - return edgeChain - .doOnNext( - c -> { - chatLog.setPromptTokens(c.getUsage().getPrompt_tokens()); - chatLog.setTotalTokens(c.getUsage().getTotal_tokens()); - chatLog.setContent(c.getChoices().get(0).getMessage().getContent()); - chatLog.setType(c.getObject()); - - chatLog.setCompletedAt(LocalDateTime.now()); - - Duration duration = - Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); - chatLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatLog); - - if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) - && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { - JsonnetLog jsonnetLog = new JsonnetLog(); - jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); - jsonnetLog.setContent(c.getChoices().get(0).getMessage().getContent()); - jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); - jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); - jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); - jsonnetLog.setCreatedAt(LocalDateTime.now()); - jsonnetLog.setSelectedFile(openAiEndpoint.getJsonnetLoader().getSelectedFile()); - jsonnetLogService.saveOrUpdate(jsonnetLog); - } - }) - .toSingle(); - - } else return edgeChain.toSingle(); - } - - @PostMapping( - value = "/chat-completion-stream", - consumes = {MediaType.APPLICATION_JSON_VALUE}) - public SseEmitter chatCompletionStream(@RequestBody 
OpenAiChatEndpoint openAiEndpoint) { - - ChatCompletionRequest chatCompletionRequest = - ChatCompletionRequest.builder() - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .messages(openAiEndpoint.getChatMessages()) - .stream(true) - .topP(openAiEndpoint.getTopP()) - .n(openAiEndpoint.getN()) - .stop(openAiEndpoint.getStop()) - .presencePenalty(openAiEndpoint.getPresencePenalty()) - .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) - .logitBias(openAiEndpoint.getLogitBias()) - .user(openAiEndpoint.getUser()) - .build(); - SseEmitter emitter = new SseEmitter(); - ExecutorService executorService = Executors.newSingleThreadExecutor(); - - executorService.execute( - () -> { + @Autowired private ChatCompletionLogService chatCompletionLogService; + @Autowired private EmbeddingLogService embeddingLogService; + @Autowired private JsonnetLogService jsonnetLogService; + + @Autowired private Environment env; + @Autowired private OpenAiClient openAiClient; + + @PostMapping(value = "/chat-completion") + public Single chatCompletion( + @RequestBody OpenAiChatEndpoint openAiEndpoint) { + + ChatCompletionRequest chatCompletionRequest = + ChatCompletionRequest.builder() + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .messages(openAiEndpoint.getChatMessages()) + .stream(false) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) + .build(); + + EdgeChain edgeChain = + openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); + + return edgeChain + .doOnNext( + c -> { + chatLog.setPromptTokens(c.getUsage().getPrompt_tokens()); + chatLog.setTotalTokens(c.getUsage().getTotal_tokens()); + chatLog.setContent(c.getChoices().get(0).getMessage().getContent()); + chatLog.setType(c.getObject()); + + chatLog.setCompletedAt(LocalDateTime.now()); + + Duration duration = + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(c.getChoices().get(0).getMessage().getContent()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + 
jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile(openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); + } + }) + .toSingle(); + + } else return edgeChain.toSingle(); + } + + @PostMapping( + value = "/chat-completion-stream", + consumes = {MediaType.APPLICATION_JSON_VALUE}) + public SseEmitter chatCompletionStream(@RequestBody OpenAiChatEndpoint openAiEndpoint) { + + ChatCompletionRequest chatCompletionRequest = + ChatCompletionRequest.builder() + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .messages(openAiEndpoint.getChatMessages()) + .stream(true) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) + .build(); + SseEmitter emitter = new SseEmitter(); + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + executorService.execute( + () -> { + try { + EdgeChain edgeChain = + openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint); + + AtomInteger chunks = AtomInteger.of(0); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); + + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("<|im_start|>"); + + for (ChatMessage chatMessage : openAiEndpoint.getChatMessages()) { + stringBuilder.append(chatMessage.getContent()); + } + EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); + Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); + + chatLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); + + StringBuilder content = new StringBuilder(); + + Observable obs = edgeChain.getScheduledObservable(); + + obs.subscribe( + res -> { try { - EdgeChain edgeChain = - openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint); - - AtomInteger chunks = AtomInteger.of(0); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - ChatCompletionLog chatLog = new ChatCompletionLog(); - chatLog.setName(openAiEndpoint.getChainName()); - chatLog.setCreatedAt(LocalDateTime.now()); - chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); - chatLog.setModel(chatCompletionRequest.getModel()); - - chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); - chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); - chatLog.setTopP(chatCompletionRequest.getTopP()); - chatLog.setN(chatCompletionRequest.getN()); - chatLog.setTemperature(chatCompletionRequest.getTemperature()); - - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append("<|im_start|>"); - 
- for (ChatMessage chatMessage : openAiEndpoint.getChatMessages()) { - stringBuilder.append(chatMessage.getContent()); - } - EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); - Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); - - chatLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); - - StringBuilder content = new StringBuilder(); - - Observable obs = edgeChain.getScheduledObservable(); - - obs.subscribe( - res -> { - try { - - emitter.send(res); - - chunks.incrementAndGet(); - content.append(res.getChoices().get(0).getMessage().getContent()); - - if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { - - emitter.complete(); - chatLog.setType(res.getObject()); - chatLog.setContent(content.toString()); - chatLog.setCompletedAt(LocalDateTime.now()); - chatLog.setTotalTokens(chunks.get() + chatLog.getPromptTokens()); - - Duration duration = - Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); - chatLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatLog); - - if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) - && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { - JsonnetLog jsonnetLog = new JsonnetLog(); - jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); - jsonnetLog.setContent(content.toString()); - jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); - jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); - jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); - jsonnetLog.setCreatedAt(LocalDateTime.now()); - jsonnetLog.setSelectedFile( - openAiEndpoint.getJsonnetLoader().getSelectedFile()); - jsonnetLogService.saveOrUpdate(jsonnetLog); - } - } - - } catch (final Exception e) { - emitter.completeWithError(e); - } - }); - } else { - - Observable obs = edgeChain.getScheduledObservable(); - obs.subscribe( - res -> { - try { - emitter.send(res); - if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { - emitter.complete(); - } - - } catch (final Exception e) { - emitter.completeWithError(e); - } - }); + + emitter.send(res); + + chunks.incrementAndGet(); + content.append(res.getChoices().get(0).getMessage().getContent()); + + if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { + + emitter.complete(); + chatLog.setType(res.getObject()); + chatLog.setContent(content.toString()); + chatLog.setCompletedAt(LocalDateTime.now()); + chatLog.setTotalTokens(chunks.get() + chatLog.getPromptTokens()); + + Duration duration = + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(content.toString()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile( + openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); } + } } catch (final Exception e) { - emitter.completeWithError(e); + emitter.completeWithError(e); } - }); - - executorService.shutdown(); - return emitter; - } 
+ }); + } else { - @PostMapping("/completion") - public Single completion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { - - CompletionRequest completionRequest = - CompletionRequest.builder() - .prompt(openAiEndpoint.getInput()) - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .build(); - - EdgeChain edgeChain = - openAiClient.createCompletion(completionRequest, openAiEndpoint); + Observable obs = edgeChain.getScheduledObservable(); + obs.subscribe( + res -> { + try { + emitter.send(res); + if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { + emitter.complete(); + } - return edgeChain.toSingle(); + } catch (final Exception e) { + emitter.completeWithError(e); + } + }); + } + + } catch (final Exception e) { + emitter.completeWithError(e); + } + }); + + executorService.shutdown(); + return emitter; + } + + @PostMapping("/completion") + public Single completion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { + + CompletionRequest completionRequest = + CompletionRequest.builder() + .prompt(openAiEndpoint.getInput()) + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .build(); + + EdgeChain edgeChain = + openAiClient.createCompletion(completionRequest, openAiEndpoint); + + return edgeChain.toSingle(); + } + + @PostMapping("/embeddings") + public Single embeddings( + @RequestBody OpenAiEmbeddingEndpoint openAiEndpoint) throws SQLException { + + EdgeChain edgeChain = + openAiClient.createEmbeddings( + new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()), + openAiEndpoint); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + EmbeddingLog embeddingLog = new EmbeddingLog(); + embeddingLog.setCreatedAt(LocalDateTime.now()); + embeddingLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + embeddingLog.setModel(openAiEndpoint.getModel()); + + return edgeChain + .doOnNext( + e -> { + embeddingLog.setPromptTokens(e.getUsage().getPrompt_tokens()); + embeddingLog.setCompletedAt(LocalDateTime.now()); + embeddingLog.setTotalTokens(e.getUsage().getTotal_tokens()); + + Duration duration = + Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt()); + embeddingLog.setLatency(duration.toMillis()); + + embeddingLogService.saveOrUpdate(embeddingLog); + }) + .toSingleWithoutScheduler(); } - @PostMapping("/embeddings") - public Single embeddings(@RequestBody OpenAiEmbeddingEndpoint openAiEndpoint) - throws SQLException { - - EdgeChain edgeChain = - openAiClient.createEmbeddings( - new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()), - openAiEndpoint); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - EmbeddingLog embeddingLog = new EmbeddingLog(); - embeddingLog.setCreatedAt(LocalDateTime.now()); - embeddingLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - embeddingLog.setModel(openAiEndpoint.getModel()); - - return edgeChain - .doOnNext( - e -> { - embeddingLog.setPromptTokens(e.getUsage().getPrompt_tokens()); - embeddingLog.setCompletedAt(LocalDateTime.now()); - embeddingLog.setTotalTokens(e.getUsage().getTotal_tokens()); - - Duration duration = - Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt()); - embeddingLog.setLatency(duration.toMillis()); - - embeddingLogService.saveOrUpdate(embeddingLog); - }) - .toSingleWithoutScheduler(); - } - - return edgeChain.toSingleWithoutScheduler(); - } -} \ No newline at end of file + return edgeChain.toSingleWithoutScheduler(); + } 
+} diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java index a9db8fe88..978b58089 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java @@ -1,9 +1,7 @@ package com.edgechain; import org.junit.jupiter.api.Test; -import org.modelmapper.ModelMapper; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.annotation.Bean; @SpringBootTest class EdgeChainApplicationTest { diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java index 8fa7e4c3f..b7cae1606 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java @@ -15,7 +15,6 @@ class BgeSmallEndpointTest { @Test @DirtiesContext - void downloadFiles() { // Retrofit needs a port System.setProperty("server.port", "8888"); @@ -45,7 +44,6 @@ void downloadFiles() { ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", null); ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); - deleteFiles(); // make sure we clean up files afterwards } } @@ -59,4 +57,4 @@ private static void deleteFiles() { File tokenizerFile = new File(BgeSmallEndpoint.TOKENIZER_PATH); tokenizerFile.delete(); } -} \ No newline at end of file +} diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java index 13e86702e..f300ac6b8 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java @@ -21,18 +21,16 @@ @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) public class PineconeClientTest { - @LocalServerPort - private int port; + @LocalServerPort private int port; - @Autowired - private PineconeClient pineconeClient; + @Autowired private PineconeClient pineconeClient; private PineconeEndpoint pineconeEndpoint; @BeforeEach void setUp() { System.setProperty("server.port", String.valueOf(port)); - pineconeEndpoint = new PineconeEndpoint("https://arakoo.ai", "apiKey", "Pinecone",null); + pineconeEndpoint = new PineconeEndpoint("https://arakoo.ai", "apiKey", "Pinecone", null); } @Test @@ -106,4 +104,4 @@ public void test_GetNamespace() { String nullNamespace = pineconeClient.getNamespace(pineconeEndpoint); assertEquals("", nullNamespace); } -} \ No newline at end of file +} diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java index d1f947021..78b605896 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java @@ -56,7 +56,7 @@ public Retrofit setupRetrofit() { private ModelMapper setupModelMapper() { ModelMapper mockModelMapper = mock(ModelMapper.class); - ReflectionTestUtils.setField(ModelMapper.class,"modelMapper", 
mockModelMapper); + ReflectionTestUtils.setField(ModelMapper.class, "modelMapper", mockModelMapper); // Retrofit needs a valid port prevServerPort = System.getProperty("server.port"); System.setProperty("server.port", "8888"); diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java index 56777ee4d..ec556317b 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java @@ -64,4 +64,4 @@ void wikiControllerTest_TestWikiContentMethod_HandlesException(TestInfo testInfo ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); } } -} \ No newline at end of file +}
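
The hunks above are formatting-only edits (re-indentation, line wrapping, removal of unused imports and stray blank lines) produced by an automated google-java-format pass; the controllers and tests keep the same behavior. For anyone wanting to reproduce the same formatting on a single file locally, the sketch below shows one way to call the formatter programmatically. It is a minimal illustration, not part of this patch: it assumes the com.google.googlejavaformat:google-java-format artifact is on the classpath, and the file path in main is a hypothetical placeholder.

import com.google.googlejavaformat.java.Formatter;
import com.google.googlejavaformat.java.FormatterException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class FormatOneFile {
  public static void main(String[] args) throws IOException, FormatterException {
    // Hypothetical path; substitute any source file touched by this patch.
    Path file = Path.of("Examples/airtable/AirtableExample.java");
    String original = Files.readString(file);

    // Formatter defaults to Google Java Style, which is the style applied throughout this patch.
    String formatted = new Formatter().formatSource(original);

    // Overwrite the file in place, mirroring what the automated formatting commit does.
    Files.writeString(file, formatted);
  }
}

In CI this step is normally driven by the google-java-format CLI or a build-plugin hook rather than hand-written code; the class above only illustrates what the automated commit applies to each file in the diff.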