From 26de2a5223b45c5527b77381a86b81886b6494ec Mon Sep 17 00:00:00 2001 From: EmadHanif Date: Sat, 30 Sep 2023 16:43:13 +0500 Subject: [PATCH] Embedding <> Database | Airtable Integration | Context Reorder | Restructuring & Fixes. --- Examples/airtable/AirtableExample.java | 133 +++++ .../code-interpreter/CodeInterpreter.java | 6 +- Examples/json/JsonFormat.java | 10 +- Examples/pinecone/PineconeExample.java | 114 ++-- Examples/postgresql/PostgreSQLExample.java | 62 +- .../react-chain/ReactChainApplication.java | 7 +- Examples/redis/RedisExample.java | 75 ++- .../SupabaseMiniLMExample.java | 45 +- Examples/wiki/WikiExample.java | 12 +- Examples/zapier/ZapierExample.java | 276 +++++++++ FlySpring/edgechain-app/pom.xml | 6 + .../com/edgechain/EdgeChainApplication.java | 1 - .../lib/chains/PineconeRetrieval.java | 28 +- .../lib/chains/PostgresRetrieval.java | 32 +- .../edgechain/lib/chains/RedisRetrieval.java | 21 +- .../impl/PostgreSQLHistoryContextClient.java | 2 +- .../impl/RedisHistoryContextClient.java | 2 +- .../PgHistoryContextController.java | 2 +- .../lib/controllers/PostgresController.java | 2 +- .../lib/controllers/RedisController.java | 2 +- .../RedisHistoryContextController.java | 2 +- .../lib/controllers/SupabaseController.java | 2 +- .../lib/embeddings/WordEmbeddings.java | 10 +- .../embeddings/bgeSmall/BgeSmallClient.java | 2 +- .../lib/embeddings/miniLLM/MiniLMClient.java | 2 +- .../lib/endpoint/EmbeddingEndpoint.java | 43 -- .../com/edgechain/lib/endpoint/Endpoint.java | 8 + .../lib/endpoint/impl/PineconeEndpoint.java | 97 --- .../lib/endpoint/impl/PostgresEndpoint.java | 385 ------------ .../lib/endpoint/impl/RedisEndpoint.java | 145 ----- .../PostgreSQLHistoryContextEndpoint.java | 2 +- .../RedisHistoryContextEndpoint.java | 2 +- .../{ => embeddings}/BgeSmallEndpoint.java | 31 +- .../impl/embeddings/EmbeddingEndpoint.java | 69 +++ .../impl/{ => embeddings}/MiniLMEndpoint.java | 30 +- .../embeddings/OpenAiEmbeddingEndpoint.java | 71 +++ .../endpoint/impl/index/PineconeEndpoint.java | 154 +++++ .../endpoint/impl/index/PostgresEndpoint.java | 559 ++++++++++++++++++ .../endpoint/impl/index/RedisEndpoint.java | 185 ++++++ .../impl/integration/AirtableEndpoint.java | 176 ++++++ .../OpenAiChatEndpoint.java} | 132 +++-- .../impl/{ => supabase}/SupabaseEndpoint.java | 2 +- .../impl/{ => wiki}/WikiEndpoint.java | 10 +- .../lib/index/client/impl/PineconeClient.java | 2 +- .../lib/index/client/impl/PostgresClient.java | 12 +- .../lib/index/client/impl/RedisClient.java | 5 +- .../index/domain/PostgresWordEmbeddings.java | 1 + .../PostgresClientMetadataRepository.java | 2 +- .../PostgresClientRepository.java | 174 +++--- .../airtable/client/AirtableClient.java | 188 ++++++ .../airtable/query/AirtableQueryBuilder.java | 127 ++++ .../integration/airtable/query/SortOrder.java | 25 + .../edgechain/lib/jsonnet/JsonnetLoader.java | 16 +- .../lib/jsonnet/XtraSonnetCustomFunc.java | 2 +- .../lib/openai/client/OpenAiClient.java | 255 ++++---- .../providers/OpenAiCompletionProvider.java | 6 +- .../lib/retrofit/AirtableService.java | 33 ++ .../lib/retrofit/BgeSmallService.java | 2 +- .../edgechain/lib/retrofit/MiniLMService.java | 2 +- .../edgechain/lib/retrofit/OpenAiService.java | 11 +- .../lib/retrofit/PineconeService.java | 2 +- .../retrofit/PostgreSQLContextService.java | 2 +- .../lib/retrofit/PostgresService.java | 2 +- .../lib/retrofit/RedisContextService.java | 2 +- .../edgechain/lib/retrofit/RedisService.java | 2 +- .../edgechain/lib/retrofit/WikiService.java | 2 +- .../retrofit/client/OpenAiStreamService.java | 8 +- .../client/RetrofitClientInstance.java | 1 - .../transformer/observable/EdgeChain.java | 8 +- .../edgechain/lib/utils/ContextReorder.java | 47 ++ .../edgechain/lib/wiki/client/WikiClient.java | 2 +- .../bgeSmall/BgeSmallController.java | 2 +- .../PostgreSQLHistoryContextController.java | 2 +- .../RedisHistoryContextController.java | 2 +- .../controllers/index/PineconeController.java | 37 +- .../controllers/index/PostgresController.java | 170 +++--- .../controllers/index/RedisController.java | 56 +- .../integration/AirtableController.java | 42 ++ .../controllers/miniLM/MiniLMController.java | 2 +- .../controllers/openai/OpenAiController.java | 497 ++++++++-------- .../controllers/wiki/WikiController.java | 2 +- .../edgechain/EdgeChainApplicationTest.java | 2 + .../endpoint/impl/BgeSmallEndpointTest.java | 6 +- .../index/client/impl/PostgresClientTest.java | 14 +- .../edgechain/openai/OpenAiClientTest.java | 14 +- .../pinecone/PineconeClientTest.java | 6 +- .../PostgresClientMetadataRepositoryTest.java | 2 +- .../edgechain/testutil/TestConfigSupport.java | 10 + .../com/edgechain/wiki/WikiClientTest.java | 4 +- 89 files changed, 3092 insertions(+), 1676 deletions(-) create mode 100644 Examples/airtable/AirtableExample.java create mode 100644 Examples/zapier/ZapierExample.java delete mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/EmbeddingEndpoint.java delete mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java delete mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java delete mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => context}/PostgreSQLHistoryContextEndpoint.java (96%) rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => context}/RedisHistoryContextEndpoint.java (96%) rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => embeddings}/BgeSmallEndpoint.java (76%) create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => embeddings}/MiniLMEndpoint.java (65%) create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{OpenAiEndpoint.java => llm/OpenAiChatEndpoint.java} (73%) rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => supabase}/SupabaseEndpoint.java (96%) rename FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/{ => wiki}/WikiEndpoint.java (78%) create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java create mode 100644 FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java diff --git a/Examples/airtable/AirtableExample.java b/Examples/airtable/AirtableExample.java new file mode 100644 index 000000000..2a24d9a83 --- /dev/null +++ b/Examples/airtable/AirtableExample.java @@ -0,0 +1,133 @@ +package com.edgechain; + +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder; +import com.edgechain.lib.integration.airtable.query.SortOrder; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.ArkResponse; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import dev.fuxing.airtable.AirtableRecord; +import dev.fuxing.airtable.formula.AirtableFormula; +import dev.fuxing.airtable.formula.LogicalOperator; +import org.json.JSONObject; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.builder.SpringApplicationBuilder; +import org.springframework.web.bind.annotation.*; + +import java.util.Properties; + +/** + * For the purpose, of this example, create a simple table using AirTable i.e, "Speakers" + * Following are some basic fields: "speaker_name", "designation", "organization", "biography", "speaker_photo", "rating + * To get, your BASE_ID of specific database; use the following API https://api.airtable.com/v0/meta/bases + * --header 'Authorization: Bearer YOUR_PERSONAL_ACCESS_TOKEN' + * You can create, complex tables using Airtable; also define relationships b/w tables via lookups. + */ +@SpringBootApplication +public class AirtableExample { + private static final String AIRTABLE_API_KEY = ""; + private static final String AIRTABLE_BASE_ID = ""; + + private static AirtableEndpoint airtableEndpoint; + + public static void main(String[] args) { + + System.setProperty("server.port", "8080"); + + Properties properties = new Properties(); + properties.setProperty("cors.origins", "http://localhost:4200"); + + new SpringApplicationBuilder(AirtableExample.class).properties(properties).run(args); + + airtableEndpoint = new AirtableEndpoint(AIRTABLE_BASE_ID, AIRTABLE_API_KEY); + } + + @RestController + @RequestMapping("/airtable") + public class AirtableController { + + @GetMapping("/findAll") + public ArkResponse findAll(ArkRequest arkRequest) { + + int pageSize = arkRequest.getIntQueryParam("pageSize"); + String sortSpeakerName = arkRequest.getQueryParam("sortName"); + String offset = arkRequest.getQueryParam("offset"); + + AirtableQueryBuilder queryBuilder = new AirtableQueryBuilder(); + queryBuilder.pageSize(pageSize); // pageSize --> no. of records in each request + queryBuilder.sort("speaker_name", SortOrder.fromValue(sortSpeakerName).getValue()); + queryBuilder.offset(offset); // move to next page by passing offset returned in response; + + // Return only those speakers which have rating Greater Than Eq to 3; + queryBuilder.filterByFormula( + LogicalOperator.GTE, + AirtableFormula.Object.field("rating"), + AirtableFormula.Object.value(3)); + + return new EdgeChain<>(airtableEndpoint.findAll("Speakers", queryBuilder)).getArkResponse(); + } + + @GetMapping("/find") + public ArkResponse findById(ArkRequest arkRequest) { + String id = arkRequest.getQueryParam("id"); + return new EdgeChain<>(airtableEndpoint.findById("Speakers", id)).getArkResponse(); + } + + @PostMapping("/create") + public ArkResponse create(ArkRequest arkRequest) { + + JSONObject body = arkRequest.getBody(); + String speakerName = body.getString("name"); + String designation = body.getString("designation"); + int rating = body.getInt("rating"); + String organization = body.getString("organization"); + String biography = body.getString("biography"); + + // Airtable API doesn't allow to upload blob files directly; therefore, you would require to + // upload it + // to some cloud storage i.e, S3 and then set the URL in Airtable. + + AirtableRecord record = new AirtableRecord(); + record.putField("speaker_name", speakerName); + record.putField("designation", designation); + record.putField("rating", rating); + record.putField("organization", organization); + record.putField("biography", biography); + + return new EdgeChain<>(airtableEndpoint.create("Speakers", record)).getArkResponse(); + } + + @PostMapping("/update") + public ArkResponse update(ArkRequest arkRequest) { + + JSONObject body = arkRequest.getBody(); + String id = body.getString("id"); + String speakerName = body.getString("name"); + String designation = body.getString("designation"); + int rating = body.getInt("rating"); + String organization = body.getString("organization"); + String biography = body.getString("biography"); + + // Airtable API doesn't allow to upload blob files directly; therefore, you would require to + // upload it + // to some cloud storage i.e, S3 and then set the URL in Airtable. + + AirtableRecord record = new AirtableRecord(); + record.setId(id); + record.putField("speaker_name", speakerName); + record.putField("designation", designation); + record.putField("rating", rating); + record.putField("organization", organization); + record.putField("biography", biography); + + return new EdgeChain<>(airtableEndpoint.update("Speakers", record)).getArkResponse(); + } + + @DeleteMapping("/delete") + public ArkResponse delete(ArkRequest arkRequest) { + JSONObject body = arkRequest.getBody(); + String id = body.getString("id"); + return new EdgeChain<>(airtableEndpoint.delete("Speakers", id)).getArkResponse(); + } + } +} diff --git a/Examples/code-interpreter/CodeInterpreter.java b/Examples/code-interpreter/CodeInterpreter.java index 084876a39..cce993b95 100644 --- a/Examples/code-interpreter/CodeInterpreter.java +++ b/Examples/code-interpreter/CodeInterpreter.java @@ -4,6 +4,7 @@ import java.util.concurrent.TimeUnit; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import org.json.JSONException; import org.json.JSONObject; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -13,7 +14,6 @@ import org.springframework.web.bind.annotation.RestController; import com.edgechain.lib.codeInterpreter.Eval; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -28,7 +28,7 @@ public class CodeInterpreter { private static final String OPENAI_AUTH_KEY = ""; - private static OpenAiEndpoint userChatEndpoint; + private static OpenAiChatEndpoint userChatEndpoint; private static final ObjectMapper objectMapper = new ObjectMapper(); private static JsonnetLoader loader = new FileJsonnetLoader("./code-interpreter/code-interpreter.jsonnet"); @@ -48,7 +48,7 @@ public double interpret(ArkRequest arkRequest) throws JSONException { JSONObject json = arkRequest.getBody(); userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, "gpt-3.5-turbo", diff --git a/Examples/json/JsonFormat.java b/Examples/json/JsonFormat.java index d1d7f5aaf..2e2423138 100644 --- a/Examples/json/JsonFormat.java +++ b/Examples/json/JsonFormat.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import org.json.JSONObject; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; @@ -17,7 +18,6 @@ import org.springframework.web.client.RestTemplate; import com.edgechain.lib.constants.EndpointConstants; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; import com.edgechain.lib.jsonFormat.request.FunctionRequest; import com.edgechain.lib.jsonFormat.request.Message; import com.edgechain.lib.jsonFormat.request.OpenApiFunctionRequest; @@ -41,7 +41,7 @@ public class JsonFormat { // need only for situation endpoint private static final String OPENAI_ORG_ID = ""; - private static OpenAiEndpoint userChatEndpoint; + private static OpenAiChatEndpoint userChatEndpoint; private static JsonnetLoader loader = new FileJsonnetLoader("./json/json-format.jsonnet"); private static JsonnetLoader functionLoader = new FileJsonnetLoader("./json/function.jsonnet"); private static final ObjectMapper objectMapper = new ObjectMapper(); @@ -84,7 +84,7 @@ public class ExampleController { public String extract(ArkRequest arkRequest) { userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, "gpt-3.5-turbo", @@ -129,8 +129,8 @@ public String situation(ArkRequest arkRequest) { JSONObject json = arkRequest.getBody(); - OpenAiEndpoint userChat = - new OpenAiEndpoint( + OpenAiChatEndpoint userChat = + new OpenAiChatEndpoint( EndpointConstants.OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, diff --git a/Examples/pinecone/PineconeExample.java b/Examples/pinecone/PineconeExample.java index bdc92396b..2cf8b9725 100644 --- a/Examples/pinecone/PineconeExample.java +++ b/Examples/pinecone/PineconeExample.java @@ -6,9 +6,11 @@ import com.edgechain.lib.chains.PineconeRetrieval; import com.edgechain.lib.context.domain.HistoryContext; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; + +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -18,7 +20,6 @@ import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.response.ArkResponse; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; -import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import java.io.IOException; import java.io.InputStream; @@ -36,21 +37,15 @@ public class PineconeExample { private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID private static final String PINECONE_AUTH_KEY = ""; - private static final String PINECONE_QUERY_API = ""; - private static final String PINECONE_UPSERT_API = ""; - private static final String PINECONE_DELETE = ""; - - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; - private static OpenAiEndpoint gpt3StreamEndpoint; + private static final String PINECONE_API = ""; // Only API + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; - private static PineconeEndpoint upsertPineconeEndpoint; - private static PineconeEndpoint queryPineconeEndpoint; - - private static PineconeEndpoint deletePineconeEndpoint; + private static PineconeEndpoint pineconeEndpoint; private static RedisHistoryContextEndpoint contextEndpoint; + // It's recommended to perform localized instantiation for thread-safe approach. private JsonnetLoader queryLoader = new FileJsonnetLoader("./pinecone/pinecone-query.jsonnet"); private JsonnetLoader chatLoader = new FileJsonnetLoader("./pinecone/pinecone-chat.jsonnet"); @@ -66,7 +61,7 @@ public static void main(String[] args) { // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port", "12285"); + properties.setProperty("redis.port",""); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -74,23 +69,15 @@ public static void main(String[] args) { // If you want to use PostgreSQL only; then just provide dbHost, dbUsername & dbPassword. // If you haven't specified PostgreSQL, then logs won't be stored. properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", ""); + properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); - new SpringApplicationBuilder(PineconeExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... - ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, - OPENAI_AUTH_KEY, - OPENAI_ORG_ID, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + new SpringApplicationBuilder(PineconeExample.class).properties(properties).run(args); gpt3Endpoint = - new OpenAiEndpoint( - OPENAI_CHAT_COMPLETION_API, + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, "gpt-3.5-turbo", @@ -99,40 +86,30 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); gpt3StreamEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, "gpt-3.5-turbo", "user", - 0.85, + 0.7, true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - upsertPineconeEndpoint = - new PineconeEndpoint( - PINECONE_UPSERT_API, - PINECONE_AUTH_KEY, - "machine-learning", // Passing namespace; read more on Pinecone documentation. You can - // pass empty string + OpenAiEmbeddingEndpoint ada002 = new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - queryPineconeEndpoint = + pineconeEndpoint = new PineconeEndpoint( - PINECONE_QUERY_API, + PINECONE_API, PINECONE_AUTH_KEY, - "machine-learning", // Passing namespace; read more on Pinecone documentation. You can - // pass empty string + ada002, new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - deletePineconeEndpoint = - new PineconeEndpoint( - PINECONE_DELETE, - PINECONE_AUTH_KEY, - "machine-learning", // Passing namespace; read more on Pinecone documentation. You can - // pass empty string - new FixedDelay(4, 5, TimeUnit.SECONDS)); - contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); } @@ -181,17 +158,11 @@ public class PineconeController { // Namespace is optional (if not provided, it will be using Empty String "") @PostMapping("/pinecone/upsert") // /v1/examples/openai/upsert?namespace=machine-learning public void upsertPinecone(ArkRequest arkRequest) throws IOException { + String namespace = arkRequest.getQueryParam("namespace"); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - String[] arr = pdfReader.readByChunkSize(file, 512); - - /** - * Retrieval Class is basically used to generate embeddings & upsert it to VectorDB; If OpenAI - * Embedding Endpoint is not provided; then Doc2Vec constructor is used If the model is not - * provided, then it will emit an error - */ PineconeRetrieval retrieval = - new PineconeRetrieval(arr, ada002Embedding, upsertPineconeEndpoint, arkRequest); + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); retrieval.upsert(); } @@ -201,17 +172,13 @@ public ArkResponse query(ArkRequest arkRequest) { String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); + String namespace = arkRequest.getQueryParam("namespace"); - // Step 1: Chain ==> Get Embeddings From Input & Then Query To Pinecone - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Step 2: Chain ==> Query Embeddings from Pinecone + // Chain 1 ==> Query Embeddings from Pinecone EdgeChain> queryChain = - new EdgeChain<>(queryPineconeEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(pineconeEndpoint.query(query, namespace, topK, arkRequest)); - // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to - // OpenAI + // Chain ===> Our queryFn passes takes list and passes each response with base prompt EdgeChain> gpt3Chain = queryChain.transform(wordEmbeddings -> queryFn(wordEmbeddings, arkRequest)); @@ -231,6 +198,7 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); boolean stream = arkRequest.getBooleanHeader("stream"); + String namespace = arkRequest.getQueryParam("namespace"); // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -246,22 +214,19 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from Pinecone & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method + // Chain 1 ==> Query Embeddings from Pinecone & Then concatenate it (preparing for prompt) EdgeChain> pineconeChain = - new EdgeChain<>(queryPineconeEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(pineconeEndpoint.query(query, namespace, topK, arkRequest)); - // Chain 3 ===> Transform String of Queries into List + // Chain 2 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(pineconeChain) .transform( pineconeResponse -> { + List wordEmbeddings = pineconeResponse.get(); List queryList = new ArrayList<>(); - pineconeResponse.get().forEach(q -> queryList.add(q.getId())); + wordEmbeddings.forEach(q -> queryList.add(q.getId())); return String.join("\n", queryList); }); @@ -330,7 +295,8 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // Namespace is optional (if not provided, it will be using Empty String "") @DeleteMapping("/pinecone/deleteAll") public ArkResponse deletePinecone(ArkRequest arkRequest) { - return new EdgeChain<>(deletePineconeEndpoint.deleteAll()).getArkResponse(); + String namespace = arkRequest.getQueryParam("namespace"); + return new EdgeChain<>(pineconeEndpoint.deleteAll(namespace)).getArkResponse(); } public List queryFn( diff --git a/Examples/postgresql/PostgreSQLExample.java b/Examples/postgresql/PostgreSQLExample.java index 0759abd07..28febf9ca 100644 --- a/Examples/postgresql/PostgreSQLExample.java +++ b/Examples/postgresql/PostgreSQLExample.java @@ -5,8 +5,10 @@ import com.edgechain.lib.chains.PostgresRetrieval; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.index.enums.PostgresLanguage; @@ -37,13 +39,12 @@ public class PostgreSQLExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; - private static OpenAiEndpoint gpt3StreamEndpoint; + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; + // For thread safe, instantitate it in methods... private JsonnetLoader queryLoader = new FileJsonnetLoader("./postgres/postgres-query.jsonnet"); private JsonnetLoader chatLoader = new FileJsonnetLoader("./postgres/postgres-chat.jsonnet"); @@ -68,17 +69,8 @@ public static void main(String[] args) { new SpringApplicationBuilder(PostgreSQLExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... - ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, - OPENAI_AUTH_KEY, - OPENAI_ORG_ID, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - gpt3Endpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -88,7 +80,7 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); gpt3StreamEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -98,10 +90,18 @@ public static void main(String[] args) { true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + OpenAiEmbeddingEndpoint adaEmbedding = new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + // Defining tablename and namespace... postgresEndpoint = new PostgresEndpoint( - "pg_vectors", "machine-learning", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + "pg_vectors", "machine-learning", adaEmbedding, new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -160,7 +160,6 @@ public void upsert(ArkRequest arkRequest) throws IOException { PostgresRetrieval retrieval = new PostgresRetrieval( arr, - ada002Embedding, postgresEndpoint, 1536, filename, @@ -183,18 +182,16 @@ public ArkResponse query(ArkRequest arkRequest) { String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - // Chain 1==> Get Embeddings From Input & Then Query To PostgreSQL - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from PostgreSQL + // Chain 1==> Query Embeddings from PostgreSQL EdgeChain> queryChain = new EdgeChain<>( postgresEndpoint.query( - embeddingsChain.get(), + List.of(query), PostgresDistanceMetric.COSINE, topK, - 10)); // defining probes + topK, + 10, + arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -226,24 +223,21 @@ public ArkResponse chat(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - // Chain 2 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method + EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.COSINE, topK)); + postgresEndpoint.query(List.of(query), PostgresDistanceMetric.COSINE, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(postgresChain) .transform( postgresResponse -> { + List postgresWordEmbeddingsList = postgresResponse.get(); List queryList = new ArrayList<>(); - postgresResponse.get().forEach(q -> queryList.add(q.getRawText())); + postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); }); diff --git a/Examples/react-chain/ReactChainApplication.java b/Examples/react-chain/ReactChainApplication.java index 01a21ea6b..60c159907 100644 --- a/Examples/react-chain/ReactChainApplication.java +++ b/Examples/react-chain/ReactChainApplication.java @@ -1,6 +1,7 @@ package com.edgechain; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; + +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -24,7 +25,7 @@ public class ReactChainApplication { private static final String OPENAI_AUTH_KEY = ""; private static final String OPENAI_ORG_ID = ""; - private static OpenAiEndpoint userChatEndpoint; + private static OpenAiChatEndpoint userChatEndpoint; private static JsonnetLoader loader = new FileJsonnetLoader("./react-chain/react-chain.jsonnet"); public static void main(String[] args) { @@ -47,7 +48,7 @@ public static void main(String[] args) { new SpringApplicationBuilder(ReactChainApplication.class).properties(properties).run(args); userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, diff --git a/Examples/redis/RedisExample.java b/Examples/redis/RedisExample.java index a9bfdce0c..2185a74a7 100644 --- a/Examples/redis/RedisExample.java +++ b/Examples/redis/RedisExample.java @@ -7,7 +7,10 @@ import com.edgechain.lib.chunk.enums.LangType; import com.edgechain.lib.context.domain.HistoryContext; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; @@ -35,10 +38,9 @@ public class RedisExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3Endpoint; - private static OpenAiEndpoint gpt3StreamEndpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static RedisEndpoint redisEndpoint; private static RedisHistoryContextEndpoint contextEndpoint; @@ -59,30 +61,23 @@ public static void main(String[] args) { // If you want to use PostgreSQL only; then just provide dbHost, dbUsername & dbPassword. // If you haven't specified PostgreSQL, then logs won't be stored. - properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", ""); + properties.setProperty( + "postgres.db.host", ""); + properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); + // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port", "12285"); + properties.setProperty("redis.port","12285"); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); new SpringApplicationBuilder(RedisExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... - ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, - OPENAI_AUTH_KEY, - OPENAI_ORG_ID, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - gpt3Endpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -92,7 +87,7 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); gpt3StreamEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -102,9 +97,18 @@ public static void main(String[] args) { true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + OpenAiEmbeddingEndpoint ada002Endpoint = + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + redisEndpoint = new RedisEndpoint( - "vector_index", "machine-learning", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + "vector_index", "machine-learning", ada002Endpoint, new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); } @@ -163,7 +167,7 @@ public void upsert(ArkRequest arkRequest) throws IOException { */ RedisRetrieval retrieval = new RedisRetrieval( - arr, ada002Embedding, redisEndpoint, 1536, RedisDistanceMetric.COSINE, arkRequest); + arr, redisEndpoint, 1536, RedisDistanceMetric.COSINE, arkRequest); retrieval.upsert(); } @@ -179,13 +183,8 @@ public ArkResponse similaritySearch(ArkRequest arkRequest) { String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - // Chain 1 ==> Generate Embeddings Using Ada002 - EdgeChain ada002Chain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Pass those embeddings to Redis & Return Score/values (Similarity search) - EdgeChain> redisQueries = - new EdgeChain<>(redisEndpoint.query(ada002Chain.get(), topK)); + // Chain 1 ==> Pass those embeddings to Redis & Return Score/values (Similarity search) + EdgeChain> redisQueries = new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); return redisQueries.getArkResponse(); } @@ -196,13 +195,8 @@ public ArkResponse queryRedis(ArkRequest arkRequest) { String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - // Chain 1==> Get Embeddings From Input & Then Query To Redis - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from Redis - EdgeChain> queryChain = - new EdgeChain<>(redisEndpoint.query(embeddingsChain.get(), topK)); + // Chain 1 ==> Query Embeddings from Redis + EdgeChain> queryChain = new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -233,22 +227,21 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - // Chain 2 ==> Query Embeddings from Redis & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method + // Chain 1==> Query Embeddings from Redis & Then concatenate it (preparing for prompt) + EdgeChain> redisChain = - new EdgeChain<>(redisEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(redisChain) .transform( redisResponse -> { + List wordEmbeddings = redisResponse.get(); List queryList = new ArrayList<>(); - redisResponse.get().forEach(q -> queryList.add(q.getId())); + wordEmbeddings.forEach(q -> queryList.add(q.getId())); return String.join("\n", queryList); }); diff --git a/Examples/supabase-miniLM/SupabaseMiniLMExample.java b/Examples/supabase-miniLM/SupabaseMiniLMExample.java index de0135ffa..2a227d7a2 100644 --- a/Examples/supabase-miniLM/SupabaseMiniLMExample.java +++ b/Examples/supabase-miniLM/SupabaseMiniLMExample.java @@ -2,9 +2,11 @@ import com.edgechain.lib.chains.PostgresRetrieval; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.index.enums.PostgresLanguage; @@ -41,17 +43,15 @@ public class SupabaseMiniLMExample { private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - private static OpenAiEndpoint gpt3Endpoint; - private static OpenAiEndpoint gpt3StreamEndpoint; + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; - private static MiniLMEndpoint miniLMEndpoint; - private JsonnetLoader queryLoader = - new FileJsonnetLoader("./supabase-miniLM/postgres-query.jsonnet"); + new FileJsonnetLoader("./supabase-miniLM/postgres-query.jsonnet"); private JsonnetLoader chatLoader = - new FileJsonnetLoader("./supabase-miniLM/postgres-chat.jsonnet"); + new FileJsonnetLoader("./supabase-miniLM/postgres-chat.jsonnet"); public static void main(String[] args) { @@ -74,13 +74,14 @@ public static void main(String[] args) { // For DB config properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", ""); + properties.setProperty("postgres.db.username", "postgres"); properties.setProperty("postgres.db.password", ""); + new SpringApplicationBuilder(SupabaseMiniLMExample.class).properties(properties).run(args); gpt3Endpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -90,7 +91,7 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); gpt3StreamEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -105,13 +106,13 @@ public static void main(String[] args) { // download on fly. // All the requests will wait until the model is download & loaded once into the application... // As you can see, the model is not download; so it will download on fly... - miniLMEndpoint = new MiniLMEndpoint(MiniLMModel.ALL_MINILM_L12_V2); + MiniLMEndpoint miniLMEndpoint = new MiniLMEndpoint(MiniLMModel.ALL_MINILM_L12_V2); // Creating PostgresEndpoint ==> We create a new table because miniLM supports 384 dimensional // vectors; postgresEndpoint = new PostgresEndpoint( - "minilm_vectors", "minilm-ns", new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); + "minilm_vectors", "minilm-ns", miniLMEndpoint, new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -172,7 +173,6 @@ public void upsert(ArkRequest arkRequest) throws IOException { PostgresRetrieval retrieval = new PostgresRetrieval( arr, - miniLMEndpoint, postgresEndpoint, 384, filename, @@ -196,14 +196,10 @@ public ArkResponse queryPostgres(ArkRequest arkRequest) { String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - // Chain 1==> Get Embeddings From Input using MiniLM & Then Query To PostgreSQL - EdgeChain embeddingsChain = - new EdgeChain<>(miniLMEndpoint.embeddings(query, arkRequest)); - // Chain 2 ==> Query Embeddings from PostgreSQL EdgeChain> queryChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.L2, topK)); + postgresEndpoint.query(List.of(query), PostgresDistanceMetric.L2, topK, topK,arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -237,23 +233,20 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input using MiniLM - EdgeChain embeddingsChain = - new EdgeChain<>(miniLMEndpoint.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) + // Chain 1 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.L2, topK)); + postgresEndpoint.query(List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List EdgeChain queryChain = new EdgeChain<>(postgresChain) .transform( postgresResponse -> { + List postgresWordEmbeddingsList = postgresResponse.get(); List queryList = new ArrayList<>(); - postgresResponse.get().forEach(q -> queryList.add(q.getRawText())); + postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); }); diff --git a/Examples/wiki/WikiExample.java b/Examples/wiki/WikiExample.java index 3bebb8dbd..feedb984a 100644 --- a/Examples/wiki/WikiExample.java +++ b/Examples/wiki/WikiExample.java @@ -1,7 +1,7 @@ package com.edgechain; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -31,9 +31,9 @@ public class WikiExample { private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID /* Step 3: Create OpenAiEndpoint to communicate with OpenAiServices; */ - private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3Endpoint; - private static OpenAiEndpoint gpt3StreamEndpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static WikiEndpoint wikiEndpoint; @@ -62,7 +62,7 @@ public static void main(String[] args) { wikiEndpoint = new WikiEndpoint(); gpt3Endpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, @@ -72,7 +72,7 @@ public static void main(String[] args) { new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); gpt3StreamEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, diff --git a/Examples/zapier/ZapierExample.java b/Examples/zapier/ZapierExample.java new file mode 100644 index 000000000..29c6f3265 --- /dev/null +++ b/Examples/zapier/ZapierExample.java @@ -0,0 +1,276 @@ +package com.edgechain; + +//DEPS com.amazonaws:aws-java-sdk-s3:1.12.554 + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; +import com.amazonaws.services.s3.model.ListObjectsV2Request; +import com.amazonaws.services.s3.model.ListObjectsV2Result; +import com.amazonaws.services.s3.model.S3Object; +import com.amazonaws.services.s3.model.S3ObjectSummary; +import com.edgechain.lib.chains.PineconeRetrieval; +import com.edgechain.lib.chunk.Chunker; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.ArkResponse; +import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import me.xuender.unidecode.Unidecode; +import org.apache.commons.io.IOUtils; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.builder.SpringApplicationBuilder; +import org.springframework.context.annotation.Bean; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.util.retry.Retry; +import java.io.*; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; + +/** + * Objective (1): + * 1. Use Zapier Webhook to pass list of urls (in our example, we have used Wikipedia Urls") ==> Trigger + * 2. Action: Use 'Web Page Parser by Zapier' to extract entire content (which is later used to upsert in Pinecone) from url + * 3. Action: Stringify the JSON + * 4. Action: Save each json (parsed content from each URL) to Amazon S3 + * (This process is entirely automated by Zapier). You would need to create Zap for it. Zapier Hook is used to trigger the ETL process, + * (Parallelize Hook Requests)... + * ========================================================== + * 5. Then, we extract each file from Amazon S3 + * 6. Upsert the normalized content to Pinecone with a chunkSize of 512... + * You can choose any file storage S3, Dropbox, Google Drive etc.... + */ + +/** + * Objective (2): Extracting PDF via PDF4Me. + * Create a Zapier by using the following steps: + * 1. Trigger ==> Integrate Google Drive Folder; when new file is added (it's not instant; it's scheduled internally by Zapier.) + * (You can also trigger it by Webhook as well) + * 2. Action ==> Extract text from PDF using PDF4Me (Free plan allows 20 API calls) + * 3. Action ==> Use ZapierByCode to stringify the json response from PDF4Me + * 4. Action ==> Save it to Amazon S3 + * Now, from EdgeChains SDK we extract the files from S3 & upsert it to Pinecone via Chunk Size 512. + */ +@SpringBootApplication +public class ZapierExample { + + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static final String ZAPIER_HOOK_URL = ""; // E.g. https://hooks.zapier.com/hooks/catch/18785910/2ia657b + + private static final String PINECONE_AUTH_KEY = ""; + private static final String PINECONE_API = ""; + + + private static PineconeEndpoint pineconeEndpoint; + + public static void main(String[] args) { + + System.setProperty("server.port", "8080"); + + new SpringApplicationBuilder(ZapierExample.class).run(args); + + OpenAiEmbeddingEndpoint adaEmbedding = new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + + pineconeEndpoint = + new PineconeEndpoint( + PINECONE_API, + PINECONE_AUTH_KEY, + adaEmbedding, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + } + + + @Bean + public AmazonS3 s3Client() { + + String accessKey = ""; // YOUR_AWS_S3_ACCESS_KEY + String secretKey = ""; // YOUR_AWS_S3_SECRET_KEY + + BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey); + return AmazonS3ClientBuilder + .standard() + .withRegion(Regions.fromName("us-east-1")) + .withCredentials(new AWSStaticCredentialsProvider(awsCredentials)) + .build(); + } + + @RestController + public class ZapierController { + + @Autowired + private AmazonS3 s3Client; + + // List of wiki urls triggered in parallel via WebHook. They are automatically parsed, + // transformed to JSON and stored in AWS S3. + /* + Examples: + "https://en.wikipedia.org/wiki/The_Weather_Company", + "https://en.wikipedia.org/wiki/Microsoft_Bing", + "https://en.wikipedia.org/wiki/Quora", + "https://en.wikipedia.org/wiki/Steve_Jobs", + "https://en.wikipedia.org/wiki/Michael_Jordan", + */ + @PostMapping("/etl") + public void performETL(ArkRequest arkRequest) { + + JSONObject body = arkRequest.getBody(); + JSONArray jsonArray = body.getJSONArray("urls"); + + IntStream.range(0, jsonArray.length()) + .parallel() + .forEach( + index -> { + String url = jsonArray.getString(index); + // For Logging + System.out.printf("Url %s: %s\n", index, url); + + // Trigger Zapier WebHook + this.zapWebHook(url); + }); + } + + @PostMapping("/upsert-urls") + public void upsertParsedURLs(ArkRequest arkRequest) throws IOException { + String namespace = arkRequest.getQueryParam("namespace"); + JSONObject body = arkRequest.getBody(); + String bucketName = body.getString("bucketName"); + + // Get all the files from S3 bucket + ListObjectsV2Request listObjectsRequest = new ListObjectsV2Request() + .withBucketName(bucketName); + + ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); + + for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { + String key = objectSummary.getKey(); + + if (key.endsWith(".txt")) { + + S3Object object = s3Client.getObject(bucketName, key); + InputStream objectData = object.getObjectContent(); + String content = IOUtils.toString(objectData, StandardCharsets.UTF_8); + + JSONObject jsonObject = new JSONObject(content); + + // These fields are specified from Zapier Action... + System.out.println("Domain: " + jsonObject.getString("domain")); // en.wikipedia.org + System.out.println("Title: " + jsonObject.getString("title")); // Barack Obama + System.out.println("Author: " + jsonObject.getString("author")); // Wikipedia Contributers + System.out.println("Word Count: " + jsonObject.get("word_count")); // 23077 + System.out.println("Date Published: " + jsonObject.getString("date_published")); // Publish date + + // Now, we extract content and Chunk it by 512 size; then upsert it to Pinecone + // Normalize the extracted content.... + + String normalizedText = Unidecode.decode(jsonObject.getString("content")).replaceAll("[\t\n\r]+", " "); + Chunker chunker = new Chunker(normalizedText); + String[] arr = chunker.byChunkSize(512); + + // Upsert to Pinecone: + PineconeRetrieval retrieval = + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + + retrieval.upsert(); + + System.out.println("File is parsed: " + key); // For Logging + + } + } + + } + + @PostMapping("/upsert-pdfs") + public void upsertPDFs(ArkRequest arkRequest) throws IOException { + String namespace = arkRequest.getQueryParam("namespace"); + JSONObject body = arkRequest.getBody(); + + String bucketName = body.getString("bucketName"); + + // Get all the files from S3 bucket + ListObjectsV2Request listObjectsRequest = new ListObjectsV2Request() + .withBucketName(bucketName); + + ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); + + for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { + String key = objectSummary.getKey(); + + if (key.endsWith("-pdf.txt")) { + + S3Object object = s3Client.getObject(bucketName, key); + InputStream objectData = object.getObjectContent(); + String content = IOUtils.toString(objectData, StandardCharsets.UTF_8); + + JSONObject jsonObject = new JSONObject(content); + + // These fields are specified from Zapier Action... + // Now, we extract content and Chunk it by 512 size; then upsert it to Pinecone + // Normalize the extracted content.... + + System.out.println("Filename: " + jsonObject.getString("filename")); // abcd.pdf + System.out.println("Extension: " + jsonObject.getString("file_extension")); // pdf + + String normalizedText = Unidecode.decode(jsonObject.getString("text")).replaceAll("[\t\n\r]+", " "); + Chunker chunker = new Chunker(normalizedText); + String[] arr = chunker.byChunkSize(512); + + // Upsert to Pinecone: + PineconeRetrieval retrieval = + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + + retrieval.upsert(); + + System.out.println("File is parsed: " + key); // For Logging + + } + } + + } + + @DeleteMapping("/pinecone/deleteAll") + public ArkResponse deletePinecone(ArkRequest arkRequest) { + String namespace = arkRequest.getQueryParam("namespace"); + return new EdgeChain<>(pineconeEndpoint.deleteAll(namespace)).getArkResponse(); + } + + private void zapWebHook(String url) { + + WebClient webClient = WebClient.builder().baseUrl(ZAPIER_HOOK_URL).build(); + + JSONObject json = new JSONObject(); + json.put("url", url); + + webClient + .post() + .contentType(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromValue(json.toString())) + .retrieve() + .bodyToMono(String.class) + .retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(20))) // Using Fixed Delay.. + .block(); + } + + } +} diff --git a/FlySpring/edgechain-app/pom.xml b/FlySpring/edgechain-app/pom.xml index a2c82a0f3..1d66b78e2 100644 --- a/FlySpring/edgechain-app/pom.xml +++ b/FlySpring/edgechain-app/pom.xml @@ -68,6 +68,12 @@ compile + + dev.fuxing + airtable-api + 0.3.2 + + javax.validation validation-api diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java index ca2697cea..efd348e8b 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java @@ -13,7 +13,6 @@ import org.springframework.web.servlet.handler.HandlerMappingIntrospector; @SpringBootApplication -@EnableScheduling public class EdgeChainApplication { private static final Logger logger = LoggerFactory.getLogger(EdgeChainApplication.class); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java index d3b545790..76ff81560 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java @@ -1,11 +1,10 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.request.ArkRequest; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Observable; @@ -20,28 +19,27 @@ public class PineconeRetrieval { private final PineconeEndpoint pineconeEndpoint; private final ArkRequest arkRequest; - private final EmbeddingEndpoint embeddingEndpoint; - private final String[] arr; - + private String namespace; private int batchSize = 30; public PineconeRetrieval( String[] arr, - EmbeddingEndpoint embeddingEndpoint, PineconeEndpoint pineconeEndpoint, + String namespace, ArkRequest arkRequest) { this.pineconeEndpoint = pineconeEndpoint; - this.embeddingEndpoint = embeddingEndpoint; this.arkRequest = arkRequest; this.arr = arr; + this.namespace = namespace; + Logger logger = LoggerFactory.getLogger(getClass()); - if (embeddingEndpoint instanceof OpenAiEndpoint openAiEndpoint) + if (pineconeEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (embeddingEndpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (pineconeEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); - else if (embeddingEndpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) + else if (pineconeEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } @@ -64,11 +62,11 @@ public void upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return embeddingEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); + return pineconeEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); } private void executeBatchUpsert(List wordEmbeddingsList) { - pineconeEndpoint.batchUpsert(wordEmbeddingsList); + pineconeEndpoint.batchUpsert(wordEmbeddingsList, this.namespace); } public int getBatchSize() { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java index 8f89a7717..75a4b0769 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java @@ -1,11 +1,10 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.request.ArkRequest; @@ -36,15 +35,12 @@ public class PostgresRetrieval { private final ArkRequest arkRequest; private final PostgresEndpoint postgresEndpoint; - private final EmbeddingEndpoint embeddingEndpoint; - private final int dimensions; private final PostgresDistanceMetric metric; private final int lists; public PostgresRetrieval( String[] arr, - EmbeddingEndpoint embeddingEndpoint, PostgresEndpoint postgresEndpoint, int dimensions, PostgresDistanceMetric metric, @@ -55,7 +51,6 @@ public PostgresRetrieval( this.arr = arr; this.filename = filename; this.postgresEndpoint = postgresEndpoint; - this.embeddingEndpoint = embeddingEndpoint; this.postgresLanguage = postgresLanguage; this.arkRequest = arkRequest; @@ -63,17 +58,16 @@ public PostgresRetrieval( this.metric = metric; this.lists = lists; - if (embeddingEndpoint instanceof OpenAiEndpoint openAiEndpoint) + if (postgresEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (embeddingEndpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); - else if (embeddingEndpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } public PostgresRetrieval( String[] arr, - EmbeddingEndpoint embeddingEndpoint, PostgresEndpoint postgresEndpoint, int dimensions, String filename, @@ -81,20 +75,18 @@ public PostgresRetrieval( ArkRequest arkRequest) { this.arr = arr; this.filename = filename; - this.arkRequest = arkRequest; this.postgresLanguage = postgresLanguage; this.postgresEndpoint = postgresEndpoint; - this.embeddingEndpoint = embeddingEndpoint; - this.dimensions = dimensions; this.metric = PostgresDistanceMetric.COSINE; this.lists = 1000; + this.arkRequest = arkRequest; - if (embeddingEndpoint instanceof OpenAiEndpoint openAiEndpoint) + if (postgresEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (embeddingEndpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); - else if (embeddingEndpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } @@ -126,7 +118,7 @@ public List upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return embeddingEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); + return postgresEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); } private void upsertAndCollectIds( diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java index 26eea4ac7..f9c230a15 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java @@ -1,11 +1,11 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.request.ArkRequest; import io.reactivex.rxjava3.core.Completable; @@ -19,7 +19,6 @@ public class RedisRetrieval { private final RedisEndpoint redisEndpoint; private final ArkRequest arkRequest; - private final EmbeddingEndpoint embeddingEndpoint; private final String[] arr; private final int dimension; private final RedisDistanceMetric metric; @@ -27,24 +26,22 @@ public class RedisRetrieval { public RedisRetrieval( String[] arr, - EmbeddingEndpoint embeddingEndpoint, RedisEndpoint redisEndpoint, int dimension, RedisDistanceMetric metric, ArkRequest arkRequest) { this.redisEndpoint = redisEndpoint; - this.embeddingEndpoint = embeddingEndpoint; this.dimension = dimension; this.metric = metric; this.arkRequest = arkRequest; this.arr = arr; Logger logger = LoggerFactory.getLogger(getClass()); - if (embeddingEndpoint instanceof OpenAiEndpoint openAiEndpoint) + if (redisEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (embeddingEndpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (redisEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); - else if (embeddingEndpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) + else if (redisEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } @@ -70,7 +67,7 @@ public void upsert() { } private WordEmbeddings generateEmbeddings(String input) { - return embeddingEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); + return redisEndpoint.getEmbeddingEndpoint().embeddings(input, arkRequest).firstOrError().blockingGet(); } private void executeBatchUpsert(List wordEmbeddingsList) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java index 8770f9db3..e5ab71034 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java @@ -3,7 +3,7 @@ import com.edgechain.lib.context.client.HistoryContextClient; import com.edgechain.lib.context.client.repositories.PostgreSQLHistoryContextRepository; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; import java.time.LocalDateTime; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java index 79af3fbfd..c12006cf2 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java @@ -2,7 +2,7 @@ import com.edgechain.lib.context.client.HistoryContextClient; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; import java.time.LocalDateTime; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java index 1a44017c9..470ec6de3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java @@ -1,7 +1,7 @@ package com.edgechain.lib.controllers; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import org.json.JSONObject; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java index bae87f46f..0352f04a0 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java index 950855fc2..42658fddb 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; import org.springframework.web.bind.annotation.DeleteMapping; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java index 9f956768a..d9cc61f0a 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java @@ -1,7 +1,7 @@ package com.edgechain.lib.controllers; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import org.json.JSONObject; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java index 557cdb394..f81cd3d1a 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.SupabaseEndpoint; +import com.edgechain.lib.endpoint.impl.supabase.SupabaseEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.supabase.response.AuthenticatedResponse; import com.edgechain.lib.supabase.response.SupabaseUser; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java index a08949103..46b5f32f9 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java @@ -12,7 +12,7 @@ public class WordEmbeddings implements ArkObject, Serializable { private static final long serialVersionUID = 2210956496609994219L; private String id; private List values; - private String score; + private Double score; public WordEmbeddings() {} @@ -26,13 +26,13 @@ public WordEmbeddings(String id, List values) { this.values = values; } - public WordEmbeddings(String id, List values, String score) { + public WordEmbeddings(String id, List values, Double score) { this.id = id; this.values = values; this.score = score; } - public WordEmbeddings(String id, String score) { + public WordEmbeddings(String id, Double score) { this.id = id; this.score = score; } @@ -49,7 +49,7 @@ public void setValues(List values) { this.values = values; } - public String getScore() { + public Double getScore() { return score; } @@ -57,7 +57,7 @@ public void setId(String id) { this.id = id; } - public void setScore(String score) { + public void setScore(Double score) { this.score = score; } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java index e3e6e68ef..6cc92c672 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java @@ -16,7 +16,7 @@ import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; import java.io.IOException; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java index c64f0bc38..b12ec6af0 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java @@ -9,7 +9,7 @@ import ai.djl.training.util.ProgressBar; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; import org.springframework.stereotype.Service; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/EmbeddingEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/EmbeddingEndpoint.java deleted file mode 100644 index fb8b8012e..000000000 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/EmbeddingEndpoint.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.edgechain.lib.endpoint; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.request.ArkRequest; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; - -public abstract class EmbeddingEndpoint extends Endpoint { - - private String rawText; - - public EmbeddingEndpoint() {} - - public EmbeddingEndpoint(RetryPolicy retryPolicy) { - super(retryPolicy); - } - - public EmbeddingEndpoint(String url) { - super(url); - } - - public EmbeddingEndpoint(String url, RetryPolicy retryPolicy) { - super(url, retryPolicy); - } - - public EmbeddingEndpoint(String url, String apiKey) { - super(url, apiKey); - } - - public EmbeddingEndpoint(String url, String apiKey, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - } - - public abstract Observable embeddings(String input, ArkRequest arkRequest); - - public void setRawText(String rawText) { - this.rawText = rawText; - } - - public String getRawText() { - return rawText; - } -} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java index 5128f9894..2bdc54901 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java @@ -36,6 +36,14 @@ public Endpoint(String url, String apiKey, RetryPolicy retryPolicy) { this.retryPolicy = retryPolicy; } + public void setUrl(String url) { + this.url = url; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + public String getApiKey() { return this.apiKey; } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java deleted file mode 100644 index 1f573d236..000000000 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java +++ /dev/null @@ -1,97 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.retrofit.PineconeService; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; - -import java.util.List; - -public class PineconeEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final PineconeService pineconeService = retrofit.create(PineconeService.class); - - private String namespace; - - // Getters; - private WordEmbeddings wordEmbedding; - - private List wordEmbeddingsList; - - private int topK; - - public PineconeEndpoint() {} - - public PineconeEndpoint(String url, String apiKey) { - super(url, apiKey); - } - - public PineconeEndpoint(String url, String apiKey, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - } - - public PineconeEndpoint(String url, String apiKey, String namespace) { - super(url, apiKey); - this.namespace = namespace; - } - - public PineconeEndpoint(String url, String apiKey, String namespace, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - this.namespace = namespace; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - // Getters - - public WordEmbeddings getWordEmbedding() { - return wordEmbedding; - } - - public List getWordEmbeddingsList() { - return wordEmbeddingsList; - } - - public void setWordEmbeddings(WordEmbeddings wordEmbedding) { - this.wordEmbedding = wordEmbedding; - } - - public int getTopK() { - return topK; - } - - public void setTopK(int topK) { - this.topK = topK; - } - - public StringResponse upsert(WordEmbeddings wordEmbeddings) { - this.wordEmbedding = wordEmbeddings; - return this.pineconeService.upsert(this).blockingGet(); - } - - public StringResponse batchUpsert(List wordEmbeddingsList) { - this.wordEmbeddingsList = wordEmbeddingsList; - return this.pineconeService.batchUpsert(this).blockingGet(); - } - - public Observable> query(WordEmbeddings wordEmbeddings, int topK) { - this.wordEmbedding = wordEmbeddings; - this.topK = topK; - return Observable.fromSingle(this.pineconeService.query(this)); - } - - public StringResponse deleteAll() { - return this.pineconeService.deleteAll(this).blockingGet(); - } -} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java deleted file mode 100644 index bd9510e8e..000000000 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java +++ /dev/null @@ -1,385 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.index.domain.PostgresWordEmbeddings; -import com.edgechain.lib.index.domain.RRFWeight; -import com.edgechain.lib.index.enums.OrderRRFBy; -import com.edgechain.lib.index.enums.PostgresDistanceMetric; -import com.edgechain.lib.index.enums.PostgresLanguage; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.PostgresService; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; - -import java.util.List; - -public class PostgresEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final PostgresService postgresService = retrofit.create(PostgresService.class); - private String tableName; - private int lists; - - private String id; - private String namespace; - - private String filename; - - // Getters - private WordEmbeddings wordEmbedding; - - private List wordEmbeddingsList; - - private PostgresDistanceMetric metric; - private int dimensions; - private int topK; - - private int probes; - private String embeddingChunk; - - // Fields for metadata table - private List metadataTableNames; - private String metadata; - private String metadataId; - private List metadataList; - private String documentDate; - - /** RRF * */ - private RRFWeight textWeight; - - private RRFWeight similarityWeight; - private RRFWeight dateWeight; - - private OrderRRFBy orderRRFBy; - private String searchQuery; - - private PostgresLanguage postgresLanguage; - - // Join Table - private List idList; - - public PostgresEndpoint() {} - - public PostgresEndpoint(RetryPolicy retryPolicy) { - super(retryPolicy); - } - - public PostgresEndpoint(String tableName) { - this.tableName = tableName; - } - - public PostgresEndpoint(String tableName, String namespace) { - this.tableName = tableName; - this.namespace = namespace; - } - - public PostgresEndpoint(String tableName, List metadataTableNames) { - this.tableName = tableName; - this.metadataTableNames = metadataTableNames; - } - - public PostgresEndpoint(String tableName, RetryPolicy retryPolicy) { - super(retryPolicy); - this.tableName = tableName; - } - - public PostgresEndpoint(String tableName, String namespace, RetryPolicy retryPolicy) { - super(retryPolicy); - this.tableName = tableName; - this.namespace = namespace; - } - - public String getTableName() { - return tableName; - } - - public String getNamespace() { - return namespace; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - public void setMetadataTableName(List metadataTableNames) { - this.metadataTableNames = metadataTableNames; - } - - public void setMetadata(String metadata) { - this.metadata = metadata; - } - - public void setId(String id) { - this.id = id; - } - - public void setMetadataId(String metadataId) { - this.metadataId = metadataId; - } - - public void setMetadataList(List metadataList) { - this.metadataList = metadataList; - } - - public void setEmbeddingChunk(String embeddingChunk) { - this.embeddingChunk = embeddingChunk; - } - - // Getters - - public WordEmbeddings getWordEmbedding() { - return wordEmbedding; - } - - public int getDimensions() { - return dimensions; - } - - public int getTopK() { - return topK; - } - - public String getFilename() { - return filename; - } - - public List getWordEmbeddingsList() { - return wordEmbeddingsList; - } - - public PostgresDistanceMetric getMetric() { - return metric; - } - - public int getLists() { - return lists; - } - - public int getProbes() { - return probes; - } - - public List getMetadataTableNames() { - return metadataTableNames; - } - - public String getMetadata() { - return metadata; - } - - public String getMetadataId() { - return metadataId; - } - - public String getId() { - return id; - } - - public List getMetadataList() { - return metadataList; - } - - public String getEmbeddingChunk() { - return embeddingChunk; - } - - public String getDocumentDate() { - return documentDate; - } - - public RRFWeight getTextWeight() { - return textWeight; - } - - public RRFWeight getSimilarityWeight() { - return similarityWeight; - } - - public RRFWeight getDateWeight() { - return dateWeight; - } - - public OrderRRFBy getOrderRRFBy() { - return orderRRFBy; - } - - public String getSearchQuery() { - return searchQuery; - } - - public PostgresLanguage getPostgresLanguage() { - return postgresLanguage; - } - - public List getIdList() { - return idList; - } - - public StringResponse upsert( - WordEmbeddings wordEmbeddings, - String filename, - int dimension, - PostgresDistanceMetric metric) { - this.wordEmbedding = wordEmbeddings; - this.dimensions = dimension; - this.filename = filename; - this.metric = metric; - return this.postgresService.upsert(this).blockingGet(); - } - - public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, int lists) { - this.dimensions = dimensions; - this.metric = metric; - this.lists = lists; - return this.postgresService.createTable(this).blockingGet(); - } - - public StringResponse createMetadataTable(String metadataTableName) { - this.metadataTableNames = List.of(metadataTableName); - return this.postgresService.createMetadataTable(this).blockingGet(); - } - - public List upsert( - List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { - this.wordEmbeddingsList = wordEmbeddingsList; - this.filename = filename; - this.postgresLanguage = postgresLanguage; - return this.postgresService.batchUpsert(this).blockingGet(); - } - - public StringResponse insertMetadata( - String metadataTableName, String metadata, String documentDate) { - this.metadata = metadata; - this.documentDate = documentDate; - this.metadataTableNames = List.of(metadataTableName); - return this.postgresService.insertMetadata(this).blockingGet(); - } - - public List batchInsertMetadata(List metadataList) { - this.metadataList = metadataList; - return this.postgresService.batchInsertMetadata(this).blockingGet(); - } - - public StringResponse insertIntoJoinTable( - String metadataTableName, String id, String metadataId) { - this.id = id; - this.metadataId = metadataId; - this.metadataTableNames = List.of(metadataTableName); - return this.postgresService.insertIntoJoinTable(this).blockingGet(); - } - - public StringResponse batchInsertIntoJoinTable( - String metadataTableName, List idList, String metadataId) { - this.idList = idList; - this.metadataId = metadataId; - this.metadataTableNames = List.of(metadataTableName); - return this.postgresService.batchInsertIntoJoinTable(this).blockingGet(); - } - - public Observable> query( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) { - this.wordEmbeddingsList = List.of(wordEmbeddings); - this.topK = topK; - this.metric = metric; - this.probes = 1; - return Observable.fromSingle(this.postgresService.query(this)); - } - - public Observable> query( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK, int probes) { - this.wordEmbeddingsList = List.of(wordEmbeddings); - this.topK = topK; - this.metric = metric; - this.probes = probes; - return Observable.fromSingle(this.postgresService.query(this)); - } - - public Observable> query( - List wordEmbeddingsList, - PostgresDistanceMetric metric, - int topK, - int probes) { - this.wordEmbeddingsList = wordEmbeddingsList; - this.metric = metric; - this.probes = probes; - this.topK = topK; - return Observable.fromSingle(this.postgresService.query(this)); - } - - public Observable> queryRRF( - String metadataTable, - List wordEmbeddingsList, - RRFWeight textWeight, - RRFWeight similarityWeight, - RRFWeight dateWeight, - OrderRRFBy orderRRFBy, - String searchQuery, - PostgresLanguage postgresLanguage, - int probes, - PostgresDistanceMetric metric, - int topK) { - this.metadataTableNames = List.of(metadataTable); - this.wordEmbeddingsList = wordEmbeddingsList; - this.textWeight = textWeight; - this.similarityWeight = similarityWeight; - this.dateWeight = dateWeight; - this.orderRRFBy = orderRRFBy; - this.searchQuery = searchQuery; - this.postgresLanguage = postgresLanguage; - this.probes = probes; - this.metric = metric; - this.topK = topK; - return Observable.fromSingle(this.postgresService.queryRRF(this)); - } - - public Observable> queryWithMetadata( - List metadataTableNames, - WordEmbeddings wordEmbeddings, - PostgresDistanceMetric metric, - int topK) { - this.metadataTableNames = metadataTableNames; - this.wordEmbedding = wordEmbeddings; - this.topK = topK; - this.metric = metric; - this.probes = 1; - return Observable.fromSingle(this.postgresService.queryWithMetadata(this)); - } - - public Observable> queryWithMetadata( - List metadataTableNames, - WordEmbeddings wordEmbeddings, - PostgresDistanceMetric metric, - int topK, - int probes) { - this.metadataTableNames = metadataTableNames; - this.wordEmbedding = wordEmbeddings; - this.topK = topK; - this.metric = metric; - this.probes = probes; - return Observable.fromSingle(this.postgresService.queryWithMetadata(this)); - } - - public Observable> getSimilarMetadataChunk(String embeddingChunk) { - this.embeddingChunk = embeddingChunk; - return Observable.fromSingle(this.postgresService.getSimilarMetadataChunk(this)); - } - - public Observable> getAllChunks(String tableName, String filename) { - this.tableName = tableName; - this.filename = filename; - return Observable.fromSingle(this.postgresService.getAllChunks(this)); - } - - public StringResponse deleteAll(String tableName, String namespace) { - this.tableName = tableName; - this.namespace = namespace; - return this.postgresService.deleteAll(this).blockingGet(); - } -} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java deleted file mode 100644 index 57fcab9d5..000000000 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java +++ /dev/null @@ -1,145 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.retrofit.RedisService; -import com.edgechain.lib.index.enums.RedisDistanceMetric; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; -import java.util.List; - -public class RedisEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final RedisService redisService = retrofit.create(RedisService.class); - - private String indexName; - private String namespace; - - // Getters; - private WordEmbeddings wordEmbedding; - private List wordEmbeddingsList; - - private int dimensions; - - private RedisDistanceMetric metric; - - private int topK; - - private String pattern; - - public RedisEndpoint() {} - - public RedisEndpoint(RetryPolicy retryPolicy) { - super(retryPolicy); - } - - public RedisEndpoint(String indexName) { - this.indexName = indexName; - } - - public RedisEndpoint(String indexName, RetryPolicy retryPolicy) { - super(retryPolicy); - this.indexName = indexName; - } - - public RedisEndpoint(String indexName, String namespace) { - this.indexName = indexName; - this.namespace = namespace; - } - - public RedisEndpoint(String indexName, String namespace, RetryPolicy retryPolicy) { - super(retryPolicy); - this.indexName = indexName; - this.namespace = namespace; - } - - public String getIndexName() { - return indexName; - } - - public void setIndexName(String indexName) { - this.indexName = indexName; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - // Getters - public WordEmbeddings getWordEmbedding() { - return wordEmbedding; - } - - public void setWordEmbeddings(WordEmbeddings wordEmbedding) { - this.wordEmbedding = wordEmbedding; - } - - public int getDimensions() { - return dimensions; - } - - public void setDimensions(int dimensions) { - this.dimensions = dimensions; - } - - public RedisDistanceMetric getMetric() { - return metric; - } - - public List getWordEmbeddingsList() { - return wordEmbeddingsList; - } - - public void setMetric(RedisDistanceMetric metric) { - this.metric = metric; - } - - public int getTopK() { - return topK; - } - - public void setTopK(int topK) { - this.topK = topK; - } - - public String getPattern() { - return pattern; - } - - // Convenience Methods - public StringResponse createIndex(String namespace, int dimension, RedisDistanceMetric metric) { - this.dimensions = dimension; - this.metric = metric; - this.namespace = namespace; - return this.redisService.createIndex(this).blockingGet(); - } - - public void batchUpsert(List wordEmbeddingsList) { - this.wordEmbeddingsList = wordEmbeddingsList; - this.redisService.batchUpsert(this).ignoreElement().blockingAwait(); - } - - public StringResponse upsert(WordEmbeddings wordEmbeddings) { - this.wordEmbedding = wordEmbeddings; - return this.redisService.upsert(this).blockingGet(); - } - - public Observable> query(WordEmbeddings embeddings, int topK) { - this.topK = topK; - this.wordEmbedding = embeddings; - return Observable.fromSingle(this.redisService.query(this)); - } - - public void delete(String patternName) { - this.pattern = patternName; - this.redisService.deleteByPattern(this).blockingAwait(); - } -} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java similarity index 96% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java index 2fb095412..c99a331da 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.context; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java similarity index 96% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java index 1a420722c..9dcfcfb75 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.context; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java similarity index 76% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java index a26b5034e..630d85d10 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java @@ -1,7 +1,7 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.embeddings; +import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.BgeSmallService; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; @@ -15,6 +15,8 @@ import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.util.Objects; + +import org.modelmapper.ModelMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,14 +27,14 @@ public class BgeSmallEndpoint extends EmbeddingEndpoint { private final BgeSmallService bgeSmallService = RetrofitClientInstance.getInstance().create(BgeSmallService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String modelUrl; private String tokenizerUrl; - private String callIdentifier; - public static final String MODEL_FOLDER = "./model"; - static final String MODEL_PATH = MODEL_FOLDER + "/model.onnx"; - static final String TOKENIZER_PATH = MODEL_FOLDER + "/tokenizer.json"; + public static final String MODEL_PATH = MODEL_FOLDER + "/model.onnx"; + public static final String TOKENIZER_PATH = MODEL_FOLDER + "/tokenizer.json"; public BgeSmallEndpoint() {} @@ -69,8 +71,12 @@ public String getTokenizerUrl() { return tokenizerUrl; } - public String getCallIdentifier() { - return callIdentifier; + public void setModelUrl(String modelUrl) { + this.modelUrl = modelUrl; + } + + public void setTokenizerUrl(String tokenizerUrl) { + this.tokenizerUrl = tokenizerUrl; } public BgeSmallEndpoint(RetryPolicy retryPolicy, String modelUrl, String tokenizerUrl) { @@ -81,13 +87,14 @@ public BgeSmallEndpoint(RetryPolicy retryPolicy, String modelUrl, String tokeniz @Override public Observable embeddings(String input, ArkRequest arkRequest) { - setRawText(input); + BgeSmallEndpoint mapper = modelMapper.map(this, BgeSmallEndpoint.class); + mapper.setRawText(input); - if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); - else this.callIdentifier = "URI wasn't provided"; + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); return Observable.fromSingle( - bgeSmallService.embeddings(this).map(m -> new WordEmbeddings(input, m.getEmbedding()))); + bgeSmallService.embeddings(mapper).map(m -> new WordEmbeddings(input, m.getEmbedding()))); } private void downloadFile(String urlStr, String path) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java new file mode 100644 index 000000000..7d9cf991d --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java @@ -0,0 +1,69 @@ +package com.edgechain.lib.endpoint.impl.embeddings; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.reactivex.rxjava3.core.Observable; + +import java.io.Serializable; + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type" +) +@JsonSubTypes({ + @JsonSubTypes.Type(value = OpenAiEmbeddingEndpoint.class, name = "type1"), + @JsonSubTypes.Type(value = MiniLMEndpoint.class, name = "type2"), + @JsonSubTypes.Type(value = BgeSmallEndpoint.class, name = "type3"), +}) +public abstract class EmbeddingEndpoint extends Endpoint implements Serializable { + + private static final long serialVersionUID = 4201794264326630184L; + private String callIdentifier; + private String rawText; + + public EmbeddingEndpoint() {} + public EmbeddingEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public EmbeddingEndpoint(String url) { + super(url); + } + + public EmbeddingEndpoint(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + } + + public EmbeddingEndpoint(String url, String apiKey) { + super(url, apiKey); + } + + public EmbeddingEndpoint(String url, String apiKey, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + } + + public abstract Observable embeddings(String input, ArkRequest arkRequest); + + public void setRawText(String rawText) { + this.rawText = rawText; + } + + public void setCallIdentifier(String callIdentifier) { + this.callIdentifier = callIdentifier; + } + + public String getRawText() { + return rawText; + } + + public String getCallIdentifier() { + return callIdentifier; + } + +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java similarity index 65% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java index 8876cbfff..878a89af3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java @@ -1,8 +1,8 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.embeddings; +import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.MiniLMService; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; @@ -10,33 +10,31 @@ import java.util.Objects; import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import retrofit2.Retrofit; public class MiniLMEndpoint extends EmbeddingEndpoint { - private Logger logger = LoggerFactory.getLogger(MiniLMEndpoint.class); - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final MiniLMService miniLMService = retrofit.create(MiniLMService.class); + private ModelMapper modelMapper = new ModelMapper(); private MiniLMModel miniLMModel; - private String callIdentifier; - public MiniLMEndpoint() {} public MiniLMEndpoint(MiniLMModel miniLMModel) { this.miniLMModel = miniLMModel; } - public MiniLMModel getMiniLMModel() { - return miniLMModel; + public void setMiniLMModel(MiniLMModel miniLMModel) { + this.miniLMModel = miniLMModel; } - public String getCallIdentifier() { - return callIdentifier; + public MiniLMModel getMiniLMModel() { + return miniLMModel; } public MiniLMEndpoint(RetryPolicy retryPolicy, MiniLMModel miniLMModel) { @@ -46,16 +44,14 @@ public MiniLMEndpoint(RetryPolicy retryPolicy, MiniLMModel miniLMModel) { @Override public Observable embeddings(String input, ArkRequest arkRequest) { - setRawText(input); - if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); - else this.callIdentifier = "URI wasn't provided"; + MiniLMEndpoint mapper = modelMapper.map(this, MiniLMEndpoint.class); + mapper.setRawText(input); - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); return Observable.fromSingle( - miniLMService.embeddings(this).map(m -> new WordEmbeddings(input, m.getEmbedding()))); + miniLMService.embeddings(mapper).map(m -> new WordEmbeddings(input, m.getEmbedding()))); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java new file mode 100644 index 000000000..cb3ed95ed --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java @@ -0,0 +1,71 @@ +package com.edgechain.lib.endpoint.impl.embeddings; + +import com.edgechain.lib.configuration.context.ApplicationContextHolder; +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.OpenAiService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.Objects; + +public class OpenAiEmbeddingEndpoint extends EmbeddingEndpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final OpenAiService openAiService = retrofit.create(OpenAiService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String orgId; + private String model; + + public OpenAiEmbeddingEndpoint() {} + + public OpenAiEmbeddingEndpoint(String url, String apiKey, String orgId, String model) { + super(url, apiKey); + this.orgId = orgId; + this.model = model; + } + + public OpenAiEmbeddingEndpoint(String url, String apiKey,String orgId, String model, RetryPolicy retryPolicy ) { + super(url, apiKey, retryPolicy); + this.orgId = orgId; + this.model = model; + } + + public String getModel() { + return model; + } + + public String getOrgId() { + return orgId; + } + + public void setOrgId(String orgId) { + this.orgId = orgId; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Observable embeddings(String input, ArkRequest arkRequest) { + + OpenAiEmbeddingEndpoint mapper = modelMapper.map(this, OpenAiEmbeddingEndpoint.class); + mapper.setRawText(input); + + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); + + return Observable.fromSingle( + openAiService + .embeddings(mapper) + .map( + embeddingResponse -> + new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java new file mode 100644 index 000000000..ecd86ad9f --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java @@ -0,0 +1,154 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.PineconeService; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; + +public class PineconeEndpoint extends Endpoint { + + private static final String QUERY_API = "/query"; + private static final String UPSERT_API = "/vectors/upsert"; + private static final String DELETE_API = "/vectors/delete"; + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final PineconeService pineconeService = retrofit.create(PineconeService.class); + private ModelMapper modelMapper = new ModelMapper(); + + private String originalUrl; + private String namespace; + + // Getters; + private WordEmbeddings wordEmbedding; + + private List wordEmbeddingsList; + + private int topK; + + private EmbeddingEndpoint embeddingEndpoint; + + public PineconeEndpoint() {} + + + public PineconeEndpoint(String url, String apiKey, EmbeddingEndpoint embeddingEndpoint) { + super(url, apiKey); + this.originalUrl = url; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PineconeEndpoint(String url, String apiKey, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.embeddingEndpoint = embeddingEndpoint; + this.originalUrl = url; + } + + public PineconeEndpoint(String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint) { + super(url, apiKey); + this.originalUrl = url; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PineconeEndpoint(String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.originalUrl = url; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + // Getters + + public void setOriginalUrl(String originalUrl) { + this.originalUrl = originalUrl; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + private void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + private void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + public int getTopK() { + return topK; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public String getOriginalUrl() { + return originalUrl; + } + + private void setTopK(int topK) { + this.topK = topK; + } + + public StringResponse upsert(WordEmbeddings wordEmbedding, String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbedding(wordEmbedding); + mapper.setUrl(mapper.getOriginalUrl().concat(UPSERT_API)); + mapper.setNamespace(namespace); + + return this.pineconeService.upsert(mapper).blockingGet(); + } + + public StringResponse batchUpsert(List wordEmbeddingsList, String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + mapper.setUrl(mapper.getOriginalUrl().concat(UPSERT_API)); + mapper.setNamespace(namespace); + + return this.pineconeService.batchUpsert(mapper).blockingGet(); + } + + public Observable> query(String query, String namespace, int topK, ArkRequest arkRequest) { + WordEmbeddings wordEmbeddings = new EdgeChain<>(getEmbeddingEndpoint().embeddings(query, arkRequest)).get(); + + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setUrl(mapper.getOriginalUrl().concat(QUERY_API)); + mapper.setNamespace(namespace); + mapper.setTopK(topK); + return Observable.fromSingle(this.pineconeService.query(mapper)); + } + + public StringResponse deleteAll(String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setUrl(mapper.getOriginalUrl().concat(DELETE_API)); + mapper.setNamespace(namespace); + return this.pineconeService.deleteAll(mapper).blockingGet(); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java new file mode 100644 index 000000000..0a1805557 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java @@ -0,0 +1,559 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.configuration.context.ApplicationContextHolder; +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import com.edgechain.lib.index.domain.RRFWeight; +import com.edgechain.lib.index.enums.OrderRRFBy; +import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.retrofit.PostgresService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.schedulers.Schedulers; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; + +public class PostgresEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final PostgresService postgresService = retrofit.create(PostgresService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String tableName; + private int lists; + + private String id; + private String namespace; + + private String filename; + + // Getters + private WordEmbeddings wordEmbedding; + + private List wordEmbeddingsList; + + private PostgresDistanceMetric metric; + private int dimensions; + private int topK; + private int upperLimit; + + private int probes; + private String embeddingChunk; + + // Fields for metadata table + private List metadataTableNames; + private String metadata; + private String metadataId; + private List metadataList; + private String documentDate; + + /** + * RRF * + */ + private RRFWeight textWeight; + + private RRFWeight similarityWeight; + private RRFWeight dateWeight; + + private OrderRRFBy orderRRFBy; + private String searchQuery; + + private PostgresLanguage postgresLanguage; + + // Join Table + private List idList; + + private EmbeddingEndpoint embeddingEndpoint; + + public PostgresEndpoint() { + } + + public PostgresEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public PostgresEndpoint(String tableName, EmbeddingEndpoint embeddingEndpoint) { + this.tableName = tableName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint) { + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint(String tableName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + + public String getTableName() { + return tableName; + } + + public String getNamespace() { + return namespace; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + public void setMetadata(String metadata) { + this.metadata = metadata; + } + + public void setId(String id) { + this.id = id; + } + + public void setMetadataId(String metadataId) { + this.metadataId = metadataId; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + + public void setEmbeddingChunk(String embeddingChunk) { + this.embeddingChunk = embeddingChunk; + } + + public int getUpperLimit() { + return upperLimit; + } + + public void setUpperLimit(int upperLimit) { + this.upperLimit = upperLimit; + } + + private void setLists(int lists) { + this.lists = lists; + } + + private void setFilename(String filename) { + this.filename = filename; + } + + private void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + private void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + private void setMetric(PostgresDistanceMetric metric) { + this.metric = metric; + } + + private void setDimensions(int dimensions) { + this.dimensions = dimensions; + } + + private void setTopK(int topK) { + this.topK = topK; + } + + private void setProbes(int probes) { + this.probes = probes; + } + + private void setMetadataTableNames(List metadataTableNames) { + this.metadataTableNames = metadataTableNames; + } + + private void setDocumentDate(String documentDate) { + this.documentDate = documentDate; + } + + private void setTextWeight(RRFWeight textWeight) { + this.textWeight = textWeight; + } + + private void setSimilarityWeight(RRFWeight similarityWeight) { + this.similarityWeight = similarityWeight; + } + + private void setDateWeight(RRFWeight dateWeight) { + this.dateWeight = dateWeight; + } + + private void setOrderRRFBy(OrderRRFBy orderRRFBy) { + this.orderRRFBy = orderRRFBy; + } + + private void setSearchQuery(String searchQuery) { + this.searchQuery = searchQuery; + } + + private void setPostgresLanguage(PostgresLanguage postgresLanguage) { + this.postgresLanguage = postgresLanguage; + } + + private void setIdList(List idList) { + this.idList = idList; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } +// Getters + + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public int getDimensions() { + return dimensions; + } + + public int getTopK() { + return topK; + } + + public String getFilename() { + return filename; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + public PostgresDistanceMetric getMetric() { + return metric; + } + + public int getLists() { + return lists; + } + + public int getProbes() { + return probes; + } + + public List getMetadataTableNames() { + return metadataTableNames; + } + + public String getMetadata() { + return metadata; + } + + public String getMetadataId() { + return metadataId; + } + + public String getId() { + return id; + } + + public List getMetadataList() { + return metadataList; + } + + public String getEmbeddingChunk() { + return embeddingChunk; + } + + public String getDocumentDate() { + return documentDate; + } + + public RRFWeight getTextWeight() { + return textWeight; + } + + public RRFWeight getSimilarityWeight() { + return similarityWeight; + } + + public RRFWeight getDateWeight() { + return dateWeight; + } + + public OrderRRFBy getOrderRRFBy() { + return orderRRFBy; + } + + public String getSearchQuery() { + return searchQuery; + } + + + public PostgresLanguage getPostgresLanguage() { + return postgresLanguage; + } + + public List getIdList() { + return idList; + } + + public StringResponse upsert( + WordEmbeddings wordEmbeddings, + String filename, + int dimension, + PostgresDistanceMetric metric) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setFilename(filename); + mapper.setDimensions(dimension); + mapper.setMetric(metric); + return this.postgresService.upsert(mapper).blockingGet(); + } + + public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, int lists) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setDimensions(dimensions); + mapper.setMetric(metric); + mapper.setLists(lists); + + return this.postgresService.createTable(mapper).blockingGet(); + } + + public StringResponse createMetadataTable(String metadataTableName) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.createMetadataTable(mapper).blockingGet(); + } + + public List upsert( + List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + mapper.setFilename(filename); + mapper.setPostgresLanguage(postgresLanguage); + + return this.postgresService.batchUpsert(mapper).blockingGet(); + } + + public StringResponse insertMetadata( + String metadataTableName, String metadata, String documentDate) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadata(metadata); + mapper.setDocumentDate(documentDate); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.insertMetadata(mapper).blockingGet(); + } + + public List batchInsertMetadata(List metadataList) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataList(metadataList); + return this.postgresService.batchInsertMetadata(mapper).blockingGet(); + } + + public StringResponse insertIntoJoinTable( + String metadataTableName, String id, String metadataId) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setId(id); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + + return this.postgresService.insertIntoJoinTable(mapper).blockingGet(); + } + + public StringResponse batchInsertIntoJoinTable( + String metadataTableName, List idList, String metadataId) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setIdList(idList); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.batchInsertIntoJoinTable(mapper).blockingGet(); + } + + public Observable> query( + List inputList, PostgresDistanceMetric metric, int topK, int upperLimit, ArkRequest arkRequest) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> query( + List inputList, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + int probes, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setMetric(metric); + mapper.setProbes(probes); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> queryRRF( + String metadataTable, + List inputList, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + OrderRRFBy orderRRFBy, + String searchQuery, + PostgresLanguage postgresLanguage, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTable)); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTextWeight(textWeight); + mapper.setSimilarityWeight(similarityWeight); + mapper.setDateWeight(dateWeight); + mapper.setOrderRRFBy(orderRRFBy); + mapper.setSearchQuery(searchQuery); + mapper.setPostgresLanguage(postgresLanguage); + mapper.setProbes(probes); + mapper.setMetric(metric); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.queryRRF(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + WordEmbeddings wordEmbeddings, + PostgresDistanceMetric metric, + int topK) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + String input, + PostgresDistanceMetric metric, + int topK, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + WordEmbeddings wordEmbeddings = new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); + + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> getSimilarMetadataChunk(String embeddingChunk) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setEmbeddingChunk(embeddingChunk); + return Observable.fromSingle(this.postgresService.getSimilarMetadataChunk(mapper)); + } + + public Observable> getAllChunks(String tableName, String filename) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setFilename(filename); + + return Observable.fromSingle(this.postgresService.getAllChunks(this)); + } + + public StringResponse deleteAll(String tableName, String namespace) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setNamespace(namespace); + return this.postgresService.deleteAll(mapper).blockingGet(); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java new file mode 100644 index 000000000..2cead0c22 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java @@ -0,0 +1,185 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.configuration.context.ApplicationContextHolder; +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.RedisService; +import com.edgechain.lib.index.enums.RedisDistanceMetric; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; +import java.util.List; + +public class RedisEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final RedisService redisService = retrofit.create(RedisService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String indexName; + private String namespace; + + // Getters; + private WordEmbeddings wordEmbedding; + private List wordEmbeddingsList; + + private int dimensions; + + private RedisDistanceMetric metric; + + private int topK; + + private String pattern; + + private EmbeddingEndpoint embeddingEndpoint; + + public RedisEndpoint() {} + + public RedisEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public RedisEndpoint(String indexName, EmbeddingEndpoint embeddingEndpoint) { + this.indexName = indexName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint(String indexName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.indexName = indexName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint(String indexName, String namespace, EmbeddingEndpoint embeddingEndpoint) { + this.indexName = indexName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint(String indexName, String namespace, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.indexName = indexName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + public void setPattern(String pattern) { + this.pattern = pattern; + } + + public String getIndexName() { + return indexName; + } + + public void setIndexName(String indexName) { + this.indexName = indexName; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + // Getters + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + public int getDimensions() { + return dimensions; + } + + public void setDimensions(int dimensions) { + this.dimensions = dimensions; + } + + public void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + public RedisDistanceMetric getMetric() { + return metric; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + public void setMetric(RedisDistanceMetric metric) { + this.metric = metric; + } + + public int getTopK() { + return topK; + } + + public void setTopK(int topK) { + this.topK = topK; + } + + public String getPattern() { + return pattern; + } + + // Convenience Methods + public StringResponse createIndex(String namespace, int dimension, RedisDistanceMetric metric) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setDimensions(dimension); + mapper.setMetric(metric); + mapper.setNamespace(namespace); + + return this.redisService.createIndex(mapper).blockingGet(); + } + + public void batchUpsert(List wordEmbeddingsList) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + + this.redisService.batchUpsert(mapper).ignoreElement().blockingAwait(); + } + + public StringResponse upsert(WordEmbeddings wordEmbedding) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setWordEmbedding(wordEmbedding); + + return this.redisService.upsert(mapper).blockingGet(); + } + + public Observable> query(String input, int topK, ArkRequest arkRequest) { + + WordEmbeddings wordEmbedding = new EdgeChain<>(embeddingEndpoint.embeddings(input,arkRequest)).get(); + + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setTopK(topK); + mapper.setWordEmbedding(wordEmbedding); + return Observable.fromSingle(this.redisService.query(mapper)); + } + + public void delete(String patternName) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setPattern(patternName); + this.redisService.deleteByPattern(mapper).blockingAwait(); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java new file mode 100644 index 000000000..a68ca723c --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java @@ -0,0 +1,176 @@ +package com.edgechain.lib.endpoint.impl.integration; + +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder; +import com.edgechain.lib.retrofit.AirtableService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; +import java.util.Map; + +public class AirtableEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final AirtableService airtableService = retrofit.create(AirtableService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String baseId; + private String apiKey; + + private List ids; + private String tableName; + private List airtableRecordList; + private boolean typecast = false; + + private AirtableQueryBuilder airtableQueryBuilder; + + public AirtableEndpoint() {} + + public AirtableEndpoint(String baseId, String apiKey) { + this.baseId = baseId; + this.apiKey = apiKey; + } + + public void setIds(List ids) { + this.ids = ids; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setAirtableRecordList(List airtableRecordList) { + this.airtableRecordList = airtableRecordList; + } + + public void setTypecast(boolean typecast) { + this.typecast = typecast; + } + + public void setAirtableQueryBuilder(AirtableQueryBuilder airtableQueryBuilder) { + this.airtableQueryBuilder = airtableQueryBuilder; + } + + public void setBaseId(String baseId) { + this.baseId = baseId; + } + + @Override + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseId() { + return baseId; + } + + public String getApiKey() { + return apiKey; + } + + public String getTableName() { + return tableName; + } + + public List getIds() { + return ids; + } + + public List getAirtableRecordList() { + return airtableRecordList; + } + + public boolean isTypecast() { + return typecast; + } + + public AirtableQueryBuilder getAirtableQueryBuilder() { + return airtableQueryBuilder; + } + + public Observable> findAll(String tableName, AirtableQueryBuilder builder){ + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class);; + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(builder); + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + public Observable> findAll(String tableName){ + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(new AirtableQueryBuilder()); + + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + + public Observable findById(String tableName, String id){ + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setIds(List.of(id)); + return Observable.fromSingle(this.airtableService.findById(mapper)); + } + + public Observable> create(String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create(String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> update(String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update(String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> delete(String tableName, List ids) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setIds(ids); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } + + public Observable> delete(String tableName, String id) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setIds(List.of(id)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java similarity index 73% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java index 831a95636..50350f4e8 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java @@ -1,10 +1,10 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.llm; import com.edgechain.lib.configuration.context.ApplicationContextHolder; -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.EmbeddingEndpoint; +import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.openai.request.ChatMessage; +import com.edgechain.lib.openai.response.CompletionResponse; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.client.OpenAiStreamService; import com.edgechain.lib.retrofit.OpenAiService; @@ -12,13 +12,14 @@ import com.edgechain.lib.retrofit.client.RetrofitClientInstance; import com.edgechain.lib.rxjava.retry.RetryPolicy; import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; import retrofit2.Retrofit; import java.util.List; import java.util.Map; import java.util.Objects; -public class OpenAiEndpoint extends EmbeddingEndpoint { +public class OpenAiChatEndpoint extends Endpoint { private final OpenAiStreamService openAiStreamService = ApplicationContextHolder.getContext().getBean(OpenAiStreamService.class); @@ -26,6 +27,8 @@ public class OpenAiEndpoint extends EmbeddingEndpoint { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final OpenAiService openAiService = retrofit.create(OpenAiService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String orgId; private String model; @@ -42,6 +45,8 @@ public class OpenAiEndpoint extends EmbeddingEndpoint { private String role; + private String input; + /** Log fields * */ private String chainName; @@ -49,22 +54,23 @@ public class OpenAiEndpoint extends EmbeddingEndpoint { private JsonnetLoader jsonnetLoader; - public OpenAiEndpoint() {} + public OpenAiChatEndpoint() {} - public OpenAiEndpoint(String url, String apiKey, String model) { + + public OpenAiChatEndpoint(String url, String apiKey, String model) { super(url, apiKey, null); this.model = model; } // For Embeddings.... - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, String model, RetryPolicy retryPolicy) { super(url, apiKey, retryPolicy); this.orgId = orgId; this.model = model; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String model, @@ -77,7 +83,7 @@ public OpenAiEndpoint( this.temperature = temperature; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String model, String role, Double temperature, Boolean stream) { super(url, apiKey, null); this.model = model; @@ -86,7 +92,7 @@ public OpenAiEndpoint( this.stream = stream; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, @@ -101,22 +107,7 @@ public OpenAiEndpoint( this.orgId = orgId; } - public OpenAiEndpoint( - String url, - String apiKey, - String model, - String role, - Double temperature, - Boolean stream, - RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - this.model = model; - this.role = role; - this.temperature = temperature; - this.stream = stream; - } - - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, @@ -132,7 +123,7 @@ public OpenAiEndpoint( this.stream = stream; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, @@ -149,6 +140,18 @@ public OpenAiEndpoint( this.stream = stream; } + public String getRole() { + return role; + } + + public void setRole(String role) { + this.role = role; + } + + public void setInput(String input) { + this.input = input; + } + public String getModel() { return model; } @@ -201,6 +204,10 @@ public List getStop() { return stop; } + public String getInput() { + return input; + } + public void setStop(List stop) { this.stop = stop; } @@ -245,6 +252,10 @@ public List getChatMessages() { return chatMessages; } + public void setChatMessages(List chatMessages) { + this.chatMessages = chatMessages; + } + public void setJsonnetLoader(JsonnetLoader jsonnetLoader) { this.jsonnetLoader = jsonnetLoader; } @@ -267,24 +278,31 @@ public String getCallIdentifier() { public Observable chatCompletion( String input, String chainName, ArkRequest arkRequest) { - this.chatMessages = List.of(new ChatMessage(this.role, input)); - this.chainName = chainName; - return chatCompletion(arkRequest); + + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(List.of(new ChatMessage(this.role, input))); + mapper.setChainName(chainName); + + return chatCompletion(mapper,arkRequest); } public Observable chatCompletion( String input, String chainName, JsonnetLoader loader, ArkRequest arkRequest) { - this.chatMessages = List.of(new ChatMessage(this.role, input)); - this.chainName = chainName; - this.jsonnetLoader = loader; - return chatCompletion(arkRequest); + + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(List.of(new ChatMessage(this.role, input))); + mapper.setChainName(chainName); + mapper.setJsonnetLoader(loader); + + return chatCompletion(mapper,arkRequest); } public Observable chatCompletion( List chatMessages, String chainName, ArkRequest arkRequest) { - this.chainName = chainName; - this.chatMessages = chatMessages; - return chatCompletion(arkRequest); + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(chatMessages); + mapper.setChainName(chainName); + return chatCompletion(mapper,arkRequest); } public Observable chatCompletion( @@ -292,35 +310,23 @@ public Observable chatCompletion( String chainName, JsonnetLoader loader, ArkRequest arkRequest) { - this.chainName = chainName; - this.chatMessages = chatMessages; - this.jsonnetLoader = loader; - return chatCompletion(arkRequest); - } - - @Override - public Observable embeddings(String input, ArkRequest arkRequest) { - setRawText(input); - if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); - else this.callIdentifier = "URI wasn't provided"; + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(chatMessages); + mapper.setChainName(chainName); + mapper.setJsonnetLoader(loader); - return Observable.fromSingle( - openAiService - .embeddings(this) - .map( - embeddingResponse -> - new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); + return chatCompletion(mapper, arkRequest); } - private Observable chatCompletion(ArkRequest arkRequest) { + private Observable chatCompletion(OpenAiChatEndpoint mapper, ArkRequest arkRequest) { - if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); - else this.callIdentifier = "URI wasn't provided"; + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); if (Objects.nonNull(getStream()) && getStream()) return this.openAiStreamService - .chatCompletion(this) + .chatCompletion(mapper) .map( chatResponse -> { if (!Objects.isNull(chatResponse.getChoices().get(0).getFinishReason())) { @@ -328,6 +334,14 @@ private Observable chatCompletion(ArkRequest arkRequest) return chatResponse; } else return chatResponse; }); - else return Observable.fromSingle(this.openAiService.chatCompletion(this)); + else return Observable.fromSingle(this.openAiService.chatCompletion(mapper)); + } + + public Observable completion(String input, ArkRequest arkRequest) { + if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); + else this.callIdentifier = "URI wasn't provided"; + + this.input = input; + return Observable.fromSingle(this.openAiService.completion(this)); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java similarity index 96% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java index 7ff111733..22acc7921 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.supabase; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.retrofit.SupabaseService; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java similarity index 78% rename from FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java rename to FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java index a3f69d79c..f03cf381d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.wiki; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.retrofit.WikiService; @@ -6,6 +6,7 @@ import com.edgechain.lib.rxjava.retry.RetryPolicy; import com.edgechain.lib.wiki.response.WikiResponse; import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; import retrofit2.Retrofit; public class WikiEndpoint extends Endpoint { @@ -13,6 +14,8 @@ public class WikiEndpoint extends Endpoint { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final WikiService wikiService = retrofit.create(WikiService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String input; public WikiEndpoint() {} @@ -30,7 +33,8 @@ public void setInput(String input) { } public Observable getPageContent(String input) { - this.input = input; - return Observable.fromSingle(this.wikiService.getPageContent(this)); + WikiEndpoint mapper = modelMapper.map(this, WikiEndpoint.class); + mapper.setInput(input); + return Observable.fromSingle(this.wikiService.getPageContent(mapper)); } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java index c5b98a50b..6b5d84f75 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java @@ -1,6 +1,6 @@ package com.edgechain.lib.index.client.impl; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.index.request.pinecone.PineconeUpsert; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.response.StringResponse; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java index 810c99626..815834b7c 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java @@ -2,7 +2,7 @@ import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.repositories.PostgresClientMetadataRepository; import com.edgechain.lib.index.repositories.PostgresClientRepository; @@ -239,7 +239,8 @@ public EdgeChain> query(PostgresEndpoint postgresEn postgresEndpoint.getProbes(), postgresEndpoint.getMetric(), embeddings, - postgresEndpoint.getTopK()); + postgresEndpoint.getTopK(), + postgresEndpoint.getUpperLimit()); for (Map row : rows) { @@ -284,9 +285,9 @@ public EdgeChain> queryRRF(PostgresEndpoint postgre try { List wordEmbeddingsList = new ArrayList<>(); List> embeddings = - postgresEndpoint.getWordEmbeddingsList().stream() - .map(WordEmbeddings::getValues) - .toList(); + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); List> rows = this.repository.queryRRF( @@ -302,6 +303,7 @@ public EdgeChain> queryRRF(PostgresEndpoint postgre postgresEndpoint.getProbes(), postgresEndpoint.getMetric(), postgresEndpoint.getTopK(), + postgresEndpoint.getUpperLimit(), postgresEndpoint.getOrderRRFBy()); for (Map row : rows) { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java index a8cc856fd..c436e1033 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java @@ -1,7 +1,7 @@ package com.edgechain.lib.index.client.impl; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.index.responses.RedisDocument; import com.edgechain.lib.index.responses.RedisProperty; @@ -142,8 +142,7 @@ public EdgeChain> query(RedisEndpoint endpoint) { ArrayList properties = iterator.next().getProperties(); words2VecList.add( new WordEmbeddings( - properties.get(1).getId(), - String.valueOf(properties.get(0).get__values_score()))); + properties.get(1).getId(), properties.get(0).get__values_score())); } emitter.onNext(words2VecList); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java index cd2263fe8..080b176fe 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java @@ -6,6 +6,7 @@ import java.time.LocalDateTime; import java.util.List; +import java.util.StringJoiner; public class PostgresWordEmbeddings implements ArkObject { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java index 38bb4e456..5f340aed3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java @@ -1,6 +1,6 @@ package com.edgechain.lib.index.repositories; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.utils.FloatUtils; import com.github.f4b6a3.uuid.UuidCreator; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java index 7ff270b86..c8b08b0dc 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java @@ -1,7 +1,7 @@ package com.edgechain.lib.index.repositories; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.RRFWeight; import com.edgechain.lib.index.enums.OrderRRFBy; import com.edgechain.lib.index.enums.PostgresDistanceMetric; @@ -171,7 +171,8 @@ public List> query( int probes, PostgresDistanceMetric metric, List> values, - int topK) { + int topK, + int upperLimit) { jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); @@ -233,118 +234,117 @@ public List> query( } if (values.size() > 1) { - return jdbcTemplate.queryForList( - String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query)); + return jdbcTemplate.queryForList(String.format( + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER BY score DESC LIMIT %s;", + query, upperLimit)); } else { return jdbcTemplate.queryForList(query.toString()); } } public List> queryRRF( - String tableName, - String namespace, - String metadataTableName, - List> values, - RRFWeight textWeight, - RRFWeight similarityWeight, - RRFWeight dateWeight, - String searchQuery, - PostgresLanguage language, - int probes, - PostgresDistanceMetric metric, - int topK, - OrderRRFBy orderRRFBy) { + String tableName, + String namespace, + String metadataTableName, + List> values, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + String searchQuery, + PostgresLanguage language, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + OrderRRFBy orderRRFBy) { jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); StringBuilder query = new StringBuilder(); - for (int i = 0; i < values.size(); i++) { + for(int i = 0; i < values.size(); i++) { String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); - query - .append("(") - .append( - "SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", - textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", - similarityWeight.getBaseWeight().getValue(), - similarityWeight.getFineTuneWeight())) - .append( - String.format( - "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", - dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) - .append("FROM ( ") - .append( - "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," - + " svtm.document_date, svtm.metadata, ") - .append( - String.format( - "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", - language.getValue(), searchQuery)); + query .append("(") + .append("SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", + textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", + similarityWeight.getBaseWeight().getValue(), similarityWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", + dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) + .append("FROM ( ") + .append( + "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," + + " svtm.document_date, svtm.metadata, ") + .append( + String.format( + "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", + language.getValue(), searchQuery)); switch (metric) { case COSINE -> query.append( - String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); + String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); case IP -> query.append( - String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); + String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings)); default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); } query - .append("CASE ") - .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling - .append( - "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" - + " svtm.document_date) ") - .append("END AS date_rank ") - .append("FROM ") - .append( - String.format( - "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s" - + " WHERE namespace = '%s'", - tableName, namespace)); + .append("CASE ") + .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling + .append( + "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" + + " svtm.document_date) ") + .append("END AS date_rank ") + .append("FROM ") + .append( + String.format( + "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s WHERE" + + " namespace = '%s'", + tableName, namespace)); switch (metric) { case COSINE -> query - .append(" ORDER BY embedding <=> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case IP -> query - .append(" ORDER BY embedding <#> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); case L2 -> query - .append(" ORDER BY embedding <-> ") - .append("'") - .append(embeddings) - .append("'") - .append(" LIMIT ") - .append(topK); + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); default -> throw new IllegalArgumentException("Invalid metric: " + metric); } query - .append(")") - .append(" sv ") - .append("JOIN ") - .append(tableName.concat("_join_").concat(metadataTableName)) - .append(" jtm ON sv.id = jtm.id ") - .append("JOIN ") - .append(tableName.concat("_").concat(metadataTableName)) - .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") - .append(") subquery "); + .append(")") + .append(" sv ") + .append("JOIN ") + .append(tableName.concat("_join_").concat(metadataTableName)) + .append(" jtm ON sv.id = jtm.id ") + .append("JOIN ") + .append(tableName.concat("_").concat(metadataTableName)) + .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") + .append(") subquery "); switch (orderRRFBy) { case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC"); @@ -358,11 +358,13 @@ public List> queryRRF( if (i < values.size() - 1) { query.append(" UNION ALL ").append("\n"); } + } if (values.size() > 1) { - return jdbcTemplate.queryForList( - String.format("SELECT DISTINCT ON (result.id) *\n" + "FROM ( %s ) result;", query)); + return jdbcTemplate.queryForList(String.format( + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER BY rrf_score DESC LIMIT %s;", + query, upperLimit)); } else { return jdbcTemplate.queryForList(query.toString()); } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java new file mode 100644 index 000000000..166d7c320 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java @@ -0,0 +1,188 @@ +package com.edgechain.lib.integration.airtable.client; + +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import dev.fuxing.airtable.AirtableApi; +import dev.fuxing.airtable.AirtableRecord; +import dev.fuxing.airtable.AirtableTable; +import io.reactivex.rxjava3.core.Observable; +import org.springframework.stereotype.Service; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@Service +public class AirtableClient { + + public EdgeChain> findAll(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + AirtableQueryBuilder airtableQueryBuilder = endpoint.getAirtableQueryBuilder(); + + AirtableTable.PaginationList list = + table.list( + querySpec -> { + int maxRecords = airtableQueryBuilder.getMaxRecords(); + int pageSize = airtableQueryBuilder.getPageSize(); + String sortField = airtableQueryBuilder.getSortField(); + String sortDirection = airtableQueryBuilder.getSortDirection(); + String offset = airtableQueryBuilder.getOffset(); + List fields = airtableQueryBuilder.getFields(); + String filterByFormula = airtableQueryBuilder.getFilterByFormula(); + String view = airtableQueryBuilder.getView(); + String cellFormat = airtableQueryBuilder.getCellFormat(); + String timeZone = airtableQueryBuilder.getTimeZone(); + String userLocale = airtableQueryBuilder.getUserLocale(); + + querySpec.maxRecords(maxRecords); + querySpec.pageSize(pageSize); + + if (sortField != null && sortDirection != null) { + querySpec.sort(sortField, sortDirection); + } + + if (offset != null) { + querySpec.offset(offset); + } + + if (fields != null) { + querySpec.fields(fields); + } + + if (filterByFormula != null) { + querySpec.filterByFormula(filterByFormula); + } + + if (view != null) { + querySpec.view(view); + } + + if (cellFormat != null) { + querySpec.cellFormat(cellFormat); + } + + if (timeZone != null) { + querySpec.timeZone(timeZone); + } + + if (userLocale != null) { + querySpec.userLocale(userLocale); + } + }); + + Map mapper = new HashMap<>(); + mapper.put("data", list); + mapper.put("offset", list.getOffset()); + + emitter.onNext(mapper); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain findById(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + String id = endpoint.getIds().get(0); + + if (Objects.isNull(id) || id.isEmpty()) + throw new RuntimeException("Id cannot be null"); + + AirtableRecord record = table.get(id); + + if (Objects.isNull(record)) + throw new RuntimeException("Id: " + id + " does not exist"); + + emitter.onNext(record); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain> create(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List airtableRecordList = + table.post(endpoint.getAirtableRecordList(), endpoint.isTypecast()); + + emitter.onNext(airtableRecordList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain> update(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List airtableRecordList = + table.put(endpoint.getAirtableRecordList(), endpoint.isTypecast()); + + emitter.onNext(airtableRecordList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain> delete(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List deleteIdsList = table.delete(endpoint.getIds()); + + emitter.onNext(deleteIdsList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java new file mode 100644 index 000000000..377660442 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java @@ -0,0 +1,127 @@ +package com.edgechain.lib.integration.airtable.query; + +import dev.fuxing.airtable.formula.AirtableFormula; +import dev.fuxing.airtable.formula.AirtableFunction; +import dev.fuxing.airtable.formula.AirtableOperator; + +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +public class AirtableQueryBuilder { + private String offset; + private List fields; + private String filterByFormula; + private int maxRecords = 100; + private int pageSize = 100; + private String sortField; + private String sortDirection; + private String view; + private String cellFormat; + private String timeZone; + private String userLocale; + + public void offset(String offset) { + this.offset = offset; + } + + public void fields(String... fields) { + this.fields = Arrays.asList(fields); + } + + public void filterByFormula(String formula) { + this.filterByFormula = formula; + } + + public void filterByFormula(AirtableFunction function, AirtableFormula.Object... objects){ + this.filterByFormula = function.apply(objects); + } + + public void filterByFormula(AirtableOperator operator, AirtableFormula.Object left, AirtableFormula.Object right, AirtableFormula.Object... others) { + this.filterByFormula = operator.apply(left, right, others); + } + + public void maxRecords(int maxRecords) { + this.maxRecords = maxRecords; + } + + public void pageSize(int pageSize) { + this.pageSize = pageSize; + } + + public void sort(String field, String direction) { + this.sortField = field; + this.sortDirection = direction; + } + + public void view(String view) { + this.view = view; + } + + public void cellFormat(String cellFormat) { + this.cellFormat = cellFormat; + } + + public void timeZone(String timeZone) { + this.timeZone = timeZone; + } + + public void timeZone(ZoneId zoneId) { + this.timeZone = zoneId.getId(); + } + + public void userLocale(String userLocale){ + this.userLocale = userLocale; + } + + public void userLocale(Locale locale) { + this.userLocale = locale.toLanguageTag().toLowerCase(); + } + + // Getters for QuerySpec fields + public String getOffset() { + return offset; + } + + public List getFields() { + return fields; + } + + public String getFilterByFormula() { + return filterByFormula; + } + + public int getMaxRecords() { + return maxRecords; + } + + public int getPageSize() { + return pageSize; + } + + public String getSortField() { + return sortField; + } + + public String getSortDirection() { + return sortDirection; + } + + public String getView() { + return view; + } + + public String getCellFormat() { + return cellFormat; + } + + public String getTimeZone() { + return timeZone; + } + + public String getUserLocale() { + return userLocale; + } + +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java new file mode 100644 index 000000000..7dadae115 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java @@ -0,0 +1,25 @@ +package com.edgechain.lib.integration.airtable.query; + +public enum SortOrder { + ASC("asc"), + DESC("desc"); + + private final String value; + + SortOrder(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static SortOrder fromValue(String value) { + for (SortOrder sortOrder : SortOrder.values()) { + if (sortOrder.value.equalsIgnoreCase(value)) { + return sortOrder; + } + } + throw new IllegalArgumentException("Invalid SortOrder value: " + value); + } +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java index c9fcf20eb..7dbb647d0 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java @@ -47,7 +47,7 @@ public JsonnetLoader(int threshold, String f1, String f2) { if (threshold >= 1 && threshold < 100) { this.threshold = threshold; this.splitSize = - String.valueOf(threshold).concat("-").concat(String.valueOf((100 - threshold))); + String.valueOf(threshold).concat("-").concat(String.valueOf((100 - threshold))); } else throw new RuntimeException("Threshold has to be b/w 1 and 100"); } @@ -70,7 +70,7 @@ public void load(InputStream inputStream) { // Create Temp File With Unique Name String filename = - RandomStringUtils.randomAlphanumeric(12) + "_" + System.currentTimeMillis() + ".jsonnet"; + RandomStringUtils.randomAlphanumeric(12) + "_" + System.currentTimeMillis() + ".jsonnet"; File file = new File(System.getProperty("java.io.tmpdir") + File.separator + filename); BufferedReader br = new BufferedReader(new InputStreamReader(inputStream)); @@ -98,16 +98,16 @@ public void load(InputStream inputStream) { xtraArgsMap.put(entry.getKey(), entry.getValue().getVal().replaceAll(regex, "")); } else if (entry.getValue().getDataType().equals(DataType.INTEGER) - || entry.getValue().getDataType().equals(DataType.BOOLEAN)) { + || entry.getValue().getDataType().equals(DataType.BOOLEAN)) { xtraArgsMap.put(entry.getKey(), entry.getValue().getVal()); } } var res = - Transformer.builder(text) - .withLibrary(new XtraSonnetCustomFunc()) - .build() - .transform(serializeXtraArgs(xtraArgsMap)); + Transformer.builder(text) + .withLibrary(new XtraSonnetCustomFunc()) + .build() + .transform(serializeXtraArgs(xtraArgsMap)); // Get the String Output & Transform it into JsonnetSchema this.metadata = res; @@ -191,4 +191,4 @@ public int getThreshold() { public String getSplitSize() { return splitSize; } -} +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java index ca3d0665b..33120fc99 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java @@ -1,6 +1,6 @@ package com.edgechain.lib.jsonnet; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.github.jam01.xtrasonnet.DataFormatService; import io.github.jam01.xtrasonnet.header.Header; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java index 3fa15aee3..56e0975f4 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java @@ -3,7 +3,8 @@ import com.edgechain.lib.constants.EndpointConstants; import com.edgechain.lib.embeddings.request.OpenAiEmbeddingRequest; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.ChatCompletionRequest; import com.edgechain.lib.openai.request.CompletionRequest; import com.edgechain.lib.openai.response.ChatCompletionResponse; @@ -24,130 +25,130 @@ @Service public class OpenAiClient { - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final RestTemplate restTemplate = new RestTemplate(); - - public EdgeChain createChatCompletion( - ChatCompletionRequest request, OpenAiEndpoint endpoint) { - - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - - logger.info("Logging ChatCompletion...."); - - // Create headers - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - logger.info(String.valueOf(entity.getBody())); - - // Send the POST request - ResponseEntity response = - restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, ChatCompletionResponse.class); - - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } - - public EdgeChain createChatCompletionStream( - ChatCompletionRequest request, OpenAiEndpoint endpoint) { - - try { - logger.info("Logging ChatCompletion Stream...."); - logger.info(request.toString()); - - return new EdgeChain<>( - RxJava3Adapter.fluxToObservable( - WebClient.builder() - .build() - .post() - .uri(EndpointConstants.OPENAI_CHAT_COMPLETION_API) - .accept(MediaType.TEXT_EVENT_STREAM) - .headers( - httpHeaders -> { - httpHeaders.setContentType(MediaType.APPLICATION_JSON); - httpHeaders.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) - && !endpoint.getOrgId().isEmpty()) { - httpHeaders.set("OpenAI-Organization", endpoint.getOrgId()); - } - }) - .bodyValue(new ObjectMapper().writeValueAsString(request)) - .retrieve() - .bodyToFlux(ChatCompletionResponse.class)), - endpoint); - } catch (final Exception e) { - throw new RuntimeException(e); + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); + + public EdgeChain createChatCompletion( + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + logger.info("Logging ChatCompletion...."); + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + logger.info(String.valueOf(entity.getBody())); + + // Send the POST request + ResponseEntity response = + restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, ChatCompletionResponse.class); + + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createChatCompletionStream( + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { + + try { + logger.info("Logging ChatCompletion Stream...."); + logger.info(request.toString()); + + return new EdgeChain<>( + RxJava3Adapter.fluxToObservable( + WebClient.builder() + .build() + .post() + .uri(EndpointConstants.OPENAI_CHAT_COMPLETION_API) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers( + httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) + && !endpoint.getOrgId().isEmpty()) { + httpHeaders.set("OpenAI-Organization", endpoint.getOrgId()); + } + }) + .bodyValue(new ObjectMapper().writeValueAsString(request)) + .retrieve() + .bodyToFlux(ChatCompletionResponse.class)), + endpoint); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + public EdgeChain createCompletion( + CompletionRequest request, OpenAiChatEndpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + ResponseEntity response = + this.restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, CompletionResponse.class); + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createEmbeddings( + OpenAiEmbeddingRequest request, OpenAiEmbeddingEndpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(endpoint.getApiKey()); + if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { + headers.set("OpenAI-Organization", endpoint.getOrgId()); + } + HttpEntity entity = new HttpEntity<>(request, headers); + + ResponseEntity response = + this.restTemplate.exchange( + endpoint.getUrl(), HttpMethod.POST, entity, OpenAiEmbeddingResponse.class); + + emitter.onNext(Objects.requireNonNull(response.getBody())); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); } - } - - public EdgeChain createCompletion( - CompletionRequest request, OpenAiEndpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - ResponseEntity response = - this.restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, CompletionResponse.class); - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } - - public EdgeChain createEmbeddings( - OpenAiEmbeddingRequest request, OpenAiEndpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setBearerAuth(endpoint.getApiKey()); - if (Objects.nonNull(endpoint.getOrgId()) && !endpoint.getOrgId().isEmpty()) { - headers.set("OpenAI-Organization", endpoint.getOrgId()); - } - HttpEntity entity = new HttpEntity<>(request, headers); - - ResponseEntity response = - this.restTemplate.exchange( - endpoint.getUrl(), HttpMethod.POST, entity, OpenAiEmbeddingResponse.class); - - emitter.onNext(Objects.requireNonNull(response.getBody())); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } -} +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java index e304a4232..52353f076 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java @@ -1,15 +1,15 @@ package com.edgechain.lib.openai.providers; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.CompletionRequest; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; /** Going to be removed * */ public class OpenAiCompletionProvider { - private final OpenAiEndpoint endpoint; + private final OpenAiChatEndpoint endpoint; - public OpenAiCompletionProvider(OpenAiEndpoint endpoint) { + public OpenAiCompletionProvider(OpenAiChatEndpoint endpoint) { this.endpoint = endpoint; } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java new file mode 100644 index 000000000..c3a8db88a --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java @@ -0,0 +1,33 @@ +package com.edgechain.lib.retrofit; + +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Single; +import retrofit2.http.Body; +import retrofit2.http.HTTP; +import retrofit2.http.POST; + +import java.util.List; +import java.util.Map; + +public interface AirtableService { + + @POST("airtable/findAll") + Single> findAll(@Body AirtableEndpoint endpoint); + + @POST("airtable/findById") + Single findById(@Body AirtableEndpoint endpoint); + + @POST("airtable/create") + Single> create(@Body AirtableEndpoint endpoint); + + @POST("airtable/update") + Single> update(@Body AirtableEndpoint endpoint); + + @HTTP(method = "DELETE", path = "airtable/delete", hasBody = true) + Single> delete(@Body AirtableEndpoint endpoint); + + +} + + diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java index 3f8302378..ad7526fe5 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java @@ -1,7 +1,7 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java index a0cfaff77..96e09ba1b 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java @@ -2,7 +2,7 @@ import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java index c1bd2e116..814b815a1 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java @@ -1,9 +1,10 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.response.ChatCompletionResponse; -import io.reactivex.rxjava3.core.Completable; +import com.edgechain.lib.openai.response.CompletionResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; @@ -11,11 +12,11 @@ public interface OpenAiService { @POST(value = "openai/chat-completion") - Single chatCompletion(@Body OpenAiEndpoint openAiEndpoint); + Single chatCompletion(@Body OpenAiChatEndpoint OpenAiChatEndpoint); @POST(value = "openai/completion") - Single completion(@Body OpenAiEndpoint openAiEndpoint); + Single completion(@Body OpenAiChatEndpoint openAiChatEndpoint); @POST(value = "openai/embeddings") - Single embeddings(@Body OpenAiEndpoint openAiEndpoint); + Single embeddings(@Body OpenAiEmbeddingEndpoint openAiEmbeddingEndpoint); } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java index bdfd3e6a9..0a74c27dc 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java @@ -1,7 +1,7 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java index b1b8325d9..d4b28f3dc 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java @@ -2,7 +2,7 @@ import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import retrofit2.http.*; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java index d46a07f67..a8701d952 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java @@ -1,6 +1,6 @@ package com.edgechain.lib.retrofit; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Single; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java index bc327e916..faa4843e0 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java @@ -3,7 +3,7 @@ import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import retrofit2.http.*; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java index ff6ce3259..b29ba59f7 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java @@ -1,7 +1,7 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java index d5b2d33bf..6b24c965f 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java @@ -1,6 +1,6 @@ package com.edgechain.lib.retrofit; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.wiki.response.WikiResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java index 9e8a93677..a1366dd2f 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java @@ -2,7 +2,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.configuration.domain.SecurityUUID; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.response.ChatCompletionResponse; import com.edgechain.lib.utils.JsonUtils; import io.reactivex.rxjava3.core.Observable; @@ -21,7 +21,7 @@ public class OpenAiStreamService { @Autowired private SecurityUUID securityUUID; - public Observable chatCompletion(OpenAiEndpoint openAiEndpoint) { + public Observable chatCompletion(OpenAiChatEndpoint endpoint) { return RxJava3Adapter.fluxToObservable( WebClient.builder() @@ -39,8 +39,10 @@ public Observable chatCompletion(OpenAiEndpoint openAiEn httpHeaders.set("stream", "true"); httpHeaders.set("Authorization", securityUUID.getAuthKey()); }) - .bodyValue(JsonUtils.convertToString(openAiEndpoint)) + .bodyValue(JsonUtils.convertToString(endpoint)) .retrieve() .bodyToFlux(ChatCompletionResponse.class)); } + + } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java index 2d3127d22..b0fa2921d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java @@ -114,7 +114,6 @@ private static JacksonConverterFactory createJacksonFactory() { objectMapper.registerModule(new Jdk8Module()); objectMapper.registerModule(new PageJacksonModule()); objectMapper.registerModule(new SortJacksonModule()); - objectMapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); return JacksonConverterFactory.create(objectMapper); } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java index 7e3f42cb6..e5f02d3c3 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java @@ -172,10 +172,12 @@ public Single toSingle() { if (RetryUtils.available(endpoint)) return this.observable - .subscribeOn(Schedulers.io()) - .retryWhen(endpoint.getRetryPolicy()) - .firstOrError(); + .subscribeOn(Schedulers.io()) + .retryWhen(endpoint.getRetryPolicy()) + .firstOrError(); + else return this.observable.subscribeOn(Schedulers.io()).firstOrError(); + } public Single toSingleWithoutScheduler() { diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java new file mode 100644 index 000000000..4f441b9de --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java @@ -0,0 +1,47 @@ +package com.edgechain.lib.utils; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +@Component +public class ContextReorder { + + public List reorderWordEmbeddings(List wordEmbeddingsList) { + + wordEmbeddingsList.sort(Comparator.comparingDouble(WordEmbeddings::getScore).reversed()); + + int mid = wordEmbeddingsList.size() / 2; + + List modifiedList = new ArrayList<>(wordEmbeddingsList.subList(0, mid)); + + List secondHalfList = wordEmbeddingsList.subList(mid, wordEmbeddingsList.size()); + secondHalfList.sort(Comparator.comparingDouble(WordEmbeddings::getScore)); + + modifiedList.addAll(secondHalfList); + + return modifiedList; + } + + public List reorderPostgresWordEmbeddings(List postgresWordEmbeddings) { + + postgresWordEmbeddings.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore).reversed()); + + int mid = postgresWordEmbeddings.size() / 2; + + List modifiedList = new ArrayList<>( postgresWordEmbeddings.subList(0, mid)); + + List secondHalfList = postgresWordEmbeddings.subList(mid, postgresWordEmbeddings.size()); + secondHalfList.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore)); + + modifiedList.addAll(secondHalfList); + + return modifiedList; + } + +} + diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java index bd647cd72..677b543c4 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java @@ -1,6 +1,6 @@ package com.edgechain.lib.wiki.client; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.edgechain.lib.wiki.response.WikiResponse; import com.fasterxml.jackson.databind.JsonNode; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java index 6d2820c49..57e23b21a 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java @@ -3,7 +3,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.bgeSmall.BgeSmallClient; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.services.EmbeddingLogService; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java index cf5d8d16b..fce8831c0 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java @@ -4,7 +4,7 @@ import com.edgechain.lib.context.client.impl.PostgreSQLHistoryContextClient; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import org.springframework.beans.factory.annotation.Autowired; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java index ce97d689a..1a4c4a402 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java @@ -4,7 +4,7 @@ import com.edgechain.lib.context.client.impl.RedisHistoryContextClient; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import org.springframework.beans.factory.annotation.Autowired; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java index 6e7b9107c..e488f7ed6 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java @@ -2,7 +2,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.index.client.impl.PineconeClient; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Single; @@ -15,25 +15,26 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/pinecone") public class PineconeController { - @Autowired private PineconeClient pineconeClient; + @Autowired + private PineconeClient pineconeClient; - @PostMapping("/upsert") - public Single upsert(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.upsert(pineconeEndpoint).toSingle(); - } + @PostMapping("/upsert") + public Single upsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.upsert(pineconeEndpoint).toSingle(); + } - @PostMapping("/batch-upsert") - public Single batchUpsert(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.batchUpsert(pineconeEndpoint).toSingleWithoutScheduler(); - } + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.batchUpsert(pineconeEndpoint).toSingleWithoutScheduler(); + } - @PostMapping("/query") - public Single> query(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.query(pineconeEndpoint).toSingle(); - } + @PostMapping("/query") + public Single> query(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.query(pineconeEndpoint).toSingle(); + } - @DeleteMapping("/deleteAll") - public Single deleteAll(@RequestBody PineconeEndpoint pineconeEndpoint) { - return pineconeClient.deleteAll(pineconeEndpoint).toSingle(); - } + @DeleteMapping("/deleteAll") + public Single deleteAll(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.deleteAll(pineconeEndpoint).toSingle(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java index babbd7a90..5b034468d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java @@ -1,7 +1,7 @@ package com.edgechain.service.controllers.index; import com.edgechain.lib.configuration.WebConfiguration; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.client.impl.PostgresClient; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.response.StringResponse; @@ -18,87 +18,89 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/postgres") public class PostgresController { - @Autowired @Lazy private PostgresClient postgresClient; - - @PostMapping("/create-table") - public Single createTable(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.createTable(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/create-table") - public Single createMetadataTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.createMetadataTable(postgresEndpoint).toSingle(); - } - - @PostMapping("/upsert") - public Single upsert(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.upsert(postgresEndpoint).toSingle(); - } - - @PostMapping("/batch-upsert") - public Single> batchUpsert(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.batchUpsert(postgresEndpoint).toSingleWithoutScheduler(); - } - - @PostMapping("/metadata/insert") - public Single insertMetadata(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.insertMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/batch-insert") - public Single> batchInsertMetadata( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.batchInsertMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/join/insert") - public Single insertIntoJoinTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - EdgeChain edgeChain = this.postgresClient.insertIntoJoinTable(postgresEndpoint); - return edgeChain.toSingle(); - } - - @PostMapping("/join/batch-insert") - public Single batchInsertIntoJoinTable( - @RequestBody PostgresEndpoint postgresEndpoint) { - EdgeChain edgeChain = - this.postgresClient.batchInsertIntoJoinTable(postgresEndpoint); - return edgeChain.toSingle(); - } - - @PostMapping("/query") - public Single> query( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.query(postgresEndpoint).toSingle(); - } - - @PostMapping("/query-rrf") - public Single> queryRRF( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); - } - - @PostMapping("/metadata/query") - public Single> queryWithMetadata( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.queryWithMetadata(postgresEndpoint).toSingle(); - } - - @PostMapping("/chunks") - public Single> getAllChunks( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.getAllChunks(postgresEndpoint).toSingle(); - } - - @PostMapping("/similarity-metadata") - public Single> getSimilarMetadataChunk( - @RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.getSimilarMetadataChunk(postgresEndpoint).toSingle(); - } - - @DeleteMapping("/deleteAll") - public Single deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) { - return this.postgresClient.deleteAll(postgresEndpoint).toSingle(); - } + @Autowired + @Lazy + private PostgresClient postgresClient; + + @PostMapping("/create-table") + public Single createTable(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createTable(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/create-table") + public Single createMetadataTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createMetadataTable(postgresEndpoint).toSingle(); + } + + @PostMapping("/upsert") + public Single upsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.upsert(postgresEndpoint).toSingle(); + } + + @PostMapping("/batch-upsert") + public Single> batchUpsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchUpsert(postgresEndpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/metadata/insert") + public Single insertMetadata(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.insertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/batch-insert") + public Single> batchInsertMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchInsertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/join/insert") + public Single insertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = this.postgresClient.insertIntoJoinTable(postgresEndpoint); + return edgeChain.toSingle(); + } + + @PostMapping("/join/batch-insert") + public Single batchInsertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = + this.postgresClient.batchInsertIntoJoinTable(postgresEndpoint); + return edgeChain.toSingle(); + } + + @PostMapping("/query") + public Single> query( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.query(postgresEndpoint).toSingle(); + } + + @PostMapping("/query-rrf") + public Single> queryRRF( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/query") + public Single> queryWithMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryWithMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/chunks") + public Single> getAllChunks( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getAllChunks(postgresEndpoint).toSingle(); + } + + @PostMapping("/similarity-metadata") + public Single> getSimilarMetadataChunk( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getSimilarMetadataChunk(postgresEndpoint).toSingle(); + } + + @DeleteMapping("/deleteAll") + public Single deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.deleteAll(postgresEndpoint).toSingle(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java index 2ca1873cd..4090164b7 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java @@ -2,7 +2,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.client.impl.RedisClient; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Completable; @@ -17,30 +17,32 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/index/redis") public class RedisController { - @Autowired @Lazy private RedisClient redisClient; - - @PostMapping("/create-index") - public Single createIndex(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.createIndex(redisEndpoint).toSingle(); - } - - @PostMapping("/upsert") - public Single upsert(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.upsert(redisEndpoint).toSingle(); - } - - @PostMapping("/batch-upsert") - public Single batchUpsert(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.batchUpsert(redisEndpoint).toSingleWithoutScheduler(); - } - - @PostMapping("/query") - public Single> query(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.query(redisEndpoint).toSingle(); - } - - @DeleteMapping("/delete") - public Completable deleteByPattern(@RequestBody RedisEndpoint redisEndpoint) { - return this.redisClient.deleteByPattern(redisEndpoint).await(); - } + @Autowired + @Lazy + private RedisClient redisClient; + + @PostMapping("/create-index") + public Single createIndex(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.createIndex(redisEndpoint).toSingle(); + } + + @PostMapping("/upsert") + public Single upsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.upsert(redisEndpoint).toSingle(); + } + + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.batchUpsert(redisEndpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/query") + public Single> query(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.query(redisEndpoint).toSingle(); + } + + @DeleteMapping("/delete") + public Completable deleteByPattern(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.deleteByPattern(redisEndpoint).await(); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java new file mode 100644 index 000000000..835dea9f0 --- /dev/null +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java @@ -0,0 +1,42 @@ +package com.edgechain.service.controllers.integration; + +import com.edgechain.lib.configuration.WebConfiguration; +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import com.edgechain.lib.integration.airtable.client.AirtableClient; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Single; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; + +import java.util.List; +import java.util.Map; + +@RestController("Service AirtableController") +@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/airtable") +public class AirtableController { + + @Autowired private AirtableClient airtableClient; + + @PostMapping("/findAll") + public Single> findAll(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findAll(endpoint).toSingleWithoutScheduler(); + } + @PostMapping("/findById") + public Single findById(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findById(endpoint).toSingleWithoutScheduler(); + } + @PostMapping("/create") + public Single> create(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.create(endpoint).toSingleWithoutScheduler(); + } + @PostMapping("/update") + public Single> update(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.update(endpoint).toSingleWithoutScheduler(); + } + @DeleteMapping("/delete") + public Single> delete(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.delete(endpoint).toSingleWithoutScheduler(); + } + + +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java index 59ce390d8..9ab3ea75e 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java @@ -3,7 +3,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.miniLLM.MiniLMClient; import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.services.EmbeddingLogService; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java index e92d97727..d9fa9d1fe 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java @@ -3,7 +3,8 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.request.OpenAiEmbeddingRequest; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.logger.entities.ChatCompletionLog; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.entities.JsonnetLog; @@ -42,259 +43,259 @@ @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/openai") public class OpenAiController { - @Autowired private ChatCompletionLogService chatCompletionLogService; - @Autowired private EmbeddingLogService embeddingLogService; - @Autowired private JsonnetLogService jsonnetLogService; - - @Autowired private Environment env; - @Autowired private OpenAiClient openAiClient; - - @PostMapping(value = "/chat-completion") - public Single chatCompletion(@RequestBody OpenAiEndpoint openAiEndpoint) { - - ChatCompletionRequest chatCompletionRequest = - ChatCompletionRequest.builder() - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .messages(openAiEndpoint.getChatMessages()) - .stream(false) - .topP(openAiEndpoint.getTopP()) - .n(openAiEndpoint.getN()) - .stop(openAiEndpoint.getStop()) - .presencePenalty(openAiEndpoint.getPresencePenalty()) - .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) - .logitBias(openAiEndpoint.getLogitBias()) - .user(openAiEndpoint.getUser()) - .build(); - - EdgeChain edgeChain = - openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - ChatCompletionLog chatLog = new ChatCompletionLog(); - chatLog.setName(openAiEndpoint.getChainName()); - chatLog.setCreatedAt(LocalDateTime.now()); - chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); - chatLog.setModel(chatCompletionRequest.getModel()); - - chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); - chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); - chatLog.setTopP(chatCompletionRequest.getTopP()); - chatLog.setN(chatCompletionRequest.getN()); - chatLog.setTemperature(chatCompletionRequest.getTemperature()); - - return edgeChain - .doOnNext( - c -> { - chatLog.setPromptTokens(c.getUsage().getPrompt_tokens()); - chatLog.setTotalTokens(c.getUsage().getTotal_tokens()); - chatLog.setContent(c.getChoices().get(0).getMessage().getContent()); - chatLog.setType(c.getObject()); - - chatLog.setCompletedAt(LocalDateTime.now()); - - Duration duration = - Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); - chatLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatLog); - - if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) - && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { - JsonnetLog jsonnetLog = new JsonnetLog(); - jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); - jsonnetLog.setContent(c.getChoices().get(0).getMessage().getContent()); - jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); - jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); - jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); - jsonnetLog.setCreatedAt(LocalDateTime.now()); - jsonnetLog.setSelectedFile(openAiEndpoint.getJsonnetLoader().getSelectedFile()); - jsonnetLogService.saveOrUpdate(jsonnetLog); - } - }) - .toSingle(); - - } else return edgeChain.toSingle(); - } - - @PostMapping( - value = "/chat-completion-stream", - consumes = {MediaType.APPLICATION_JSON_VALUE}) - public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoint) { - - ChatCompletionRequest chatCompletionRequest = - ChatCompletionRequest.builder() - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .messages(openAiEndpoint.getChatMessages()) - .stream(true) - .topP(openAiEndpoint.getTopP()) - .n(openAiEndpoint.getN()) - .stop(openAiEndpoint.getStop()) - .presencePenalty(openAiEndpoint.getPresencePenalty()) - .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) - .logitBias(openAiEndpoint.getLogitBias()) - .user(openAiEndpoint.getUser()) - .build(); - SseEmitter emitter = new SseEmitter(); - ExecutorService executorService = Executors.newSingleThreadExecutor(); - - executorService.execute( - () -> { - try { - EdgeChain edgeChain = - openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint); - - AtomInteger chunks = AtomInteger.of(0); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - ChatCompletionLog chatLog = new ChatCompletionLog(); - chatLog.setName(openAiEndpoint.getChainName()); - chatLog.setCreatedAt(LocalDateTime.now()); - chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); - chatLog.setModel(chatCompletionRequest.getModel()); - - chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); - chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); - chatLog.setTopP(chatCompletionRequest.getTopP()); - chatLog.setN(chatCompletionRequest.getN()); - chatLog.setTemperature(chatCompletionRequest.getTemperature()); - - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append("<|im_start|>"); - - for (ChatMessage chatMessage : openAiEndpoint.getChatMessages()) { - stringBuilder.append(chatMessage.getContent()); - } - EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); - Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); - - chatLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); - - StringBuilder content = new StringBuilder(); - - Observable obs = edgeChain.getScheduledObservable(); - - obs.subscribe( - res -> { - try { + @Autowired private ChatCompletionLogService chatCompletionLogService; + @Autowired private EmbeddingLogService embeddingLogService; + @Autowired private JsonnetLogService jsonnetLogService; + + @Autowired private Environment env; + @Autowired private OpenAiClient openAiClient; + + @PostMapping(value = "/chat-completion") + public Single chatCompletion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { + + ChatCompletionRequest chatCompletionRequest = + ChatCompletionRequest.builder() + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .messages(openAiEndpoint.getChatMessages()) + .stream(false) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) + .build(); + + EdgeChain edgeChain = + openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); + + return edgeChain + .doOnNext( + c -> { + chatLog.setPromptTokens(c.getUsage().getPrompt_tokens()); + chatLog.setTotalTokens(c.getUsage().getTotal_tokens()); + chatLog.setContent(c.getChoices().get(0).getMessage().getContent()); + chatLog.setType(c.getObject()); + + chatLog.setCompletedAt(LocalDateTime.now()); + + Duration duration = + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(c.getChoices().get(0).getMessage().getContent()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile(openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); + } + }) + .toSingle(); + + } else return edgeChain.toSingle(); + } - emitter.send(res); - - chunks.incrementAndGet(); - content.append(res.getChoices().get(0).getMessage().getContent()); - - if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { - - emitter.complete(); - chatLog.setType(res.getObject()); - chatLog.setContent(content.toString()); - chatLog.setCompletedAt(LocalDateTime.now()); - chatLog.setTotalTokens(chunks.get() + chatLog.getPromptTokens()); - - Duration duration = - Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); - chatLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatLog); - - if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) - && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { - JsonnetLog jsonnetLog = new JsonnetLog(); - jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); - jsonnetLog.setContent(content.toString()); - jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); - jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); - jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); - jsonnetLog.setCreatedAt(LocalDateTime.now()); - jsonnetLog.setSelectedFile( - openAiEndpoint.getJsonnetLoader().getSelectedFile()); - jsonnetLogService.saveOrUpdate(jsonnetLog); + @PostMapping( + value = "/chat-completion-stream", + consumes = {MediaType.APPLICATION_JSON_VALUE}) + public SseEmitter chatCompletionStream(@RequestBody OpenAiChatEndpoint openAiEndpoint) { + + ChatCompletionRequest chatCompletionRequest = + ChatCompletionRequest.builder() + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .messages(openAiEndpoint.getChatMessages()) + .stream(true) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) + .build(); + SseEmitter emitter = new SseEmitter(); + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + executorService.execute( + () -> { + try { + EdgeChain edgeChain = + openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint); + + AtomInteger chunks = AtomInteger.of(0); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); + + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("<|im_start|>"); + + for (ChatMessage chatMessage : openAiEndpoint.getChatMessages()) { + stringBuilder.append(chatMessage.getContent()); + } + EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); + Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); + + chatLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); + + StringBuilder content = new StringBuilder(); + + Observable obs = edgeChain.getScheduledObservable(); + + obs.subscribe( + res -> { + try { + + emitter.send(res); + + chunks.incrementAndGet(); + content.append(res.getChoices().get(0).getMessage().getContent()); + + if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { + + emitter.complete(); + chatLog.setType(res.getObject()); + chatLog.setContent(content.toString()); + chatLog.setCompletedAt(LocalDateTime.now()); + chatLog.setTotalTokens(chunks.get() + chatLog.getPromptTokens()); + + Duration duration = + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(content.toString()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile( + openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); + } + } + + } catch (final Exception e) { + emitter.completeWithError(e); + } + }); + } else { + + Observable obs = edgeChain.getScheduledObservable(); + obs.subscribe( + res -> { + try { + emitter.send(res); + if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { + emitter.complete(); + } + + } catch (final Exception e) { + emitter.completeWithError(e); + } + }); } - } } catch (final Exception e) { - emitter.completeWithError(e); + emitter.completeWithError(e); } - }); - } else { + }); - Observable obs = edgeChain.getScheduledObservable(); - obs.subscribe( - res -> { - try { - emitter.send(res); - if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { - emitter.complete(); - } + executorService.shutdown(); + return emitter; + } - } catch (final Exception e) { - emitter.completeWithError(e); - } - }); - } - - } catch (final Exception e) { - emitter.completeWithError(e); - } - }); - - executorService.shutdown(); - return emitter; - } - - @PostMapping("/completion") - public Single completion(@RequestBody OpenAiEndpoint openAiEndpoint) { - - CompletionRequest completionRequest = - CompletionRequest.builder() - .prompt(openAiEndpoint.getRawText()) - .model(openAiEndpoint.getModel()) - .temperature(openAiEndpoint.getTemperature()) - .build(); - - EdgeChain edgeChain = - openAiClient.createCompletion(completionRequest, openAiEndpoint); - - return edgeChain.toSingle(); - } - - @PostMapping("/embeddings") - public Single embeddings(@RequestBody OpenAiEndpoint openAiEndpoint) - throws SQLException { - - EdgeChain edgeChain = - openAiClient.createEmbeddings( - new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()), - openAiEndpoint); - - if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - - EmbeddingLog embeddingLog = new EmbeddingLog(); - embeddingLog.setCreatedAt(LocalDateTime.now()); - embeddingLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - embeddingLog.setModel(openAiEndpoint.getModel()); - - return edgeChain - .doOnNext( - e -> { - embeddingLog.setPromptTokens(e.getUsage().getPrompt_tokens()); - embeddingLog.setCompletedAt(LocalDateTime.now()); - embeddingLog.setTotalTokens(e.getUsage().getTotal_tokens()); - - Duration duration = - Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt()); - embeddingLog.setLatency(duration.toMillis()); - - embeddingLogService.saveOrUpdate(embeddingLog); - }) - .toSingleWithoutScheduler(); + @PostMapping("/completion") + public Single completion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { + + CompletionRequest completionRequest = + CompletionRequest.builder() + .prompt(openAiEndpoint.getInput()) + .model(openAiEndpoint.getModel()) + .temperature(openAiEndpoint.getTemperature()) + .build(); + + EdgeChain edgeChain = + openAiClient.createCompletion(completionRequest, openAiEndpoint); + + return edgeChain.toSingle(); } - return edgeChain.toSingleWithoutScheduler(); - } -} + @PostMapping("/embeddings") + public Single embeddings(@RequestBody OpenAiEmbeddingEndpoint openAiEndpoint) + throws SQLException { + + EdgeChain edgeChain = + openAiClient.createEmbeddings( + new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()), + openAiEndpoint); + + if (Objects.nonNull(env.getProperty("postgres.db.host"))) { + + EmbeddingLog embeddingLog = new EmbeddingLog(); + embeddingLog.setCreatedAt(LocalDateTime.now()); + embeddingLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + embeddingLog.setModel(openAiEndpoint.getModel()); + + return edgeChain + .doOnNext( + e -> { + embeddingLog.setPromptTokens(e.getUsage().getPrompt_tokens()); + embeddingLog.setCompletedAt(LocalDateTime.now()); + embeddingLog.setTotalTokens(e.getUsage().getTotal_tokens()); + + Duration duration = + Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt()); + embeddingLog.setLatency(duration.toMillis()); + + embeddingLogService.saveOrUpdate(embeddingLog); + }) + .toSingleWithoutScheduler(); + } + + return edgeChain.toSingleWithoutScheduler(); + } +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java index f7ef9660b..f1d394e8b 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java @@ -1,7 +1,7 @@ package com.edgechain.service.controllers.wiki; import com.edgechain.lib.configuration.WebConfiguration; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.edgechain.lib.wiki.client.WikiClient; import com.edgechain.lib.wiki.response.WikiResponse; diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java index 978b58089..a9db8fe88 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java @@ -1,7 +1,9 @@ package com.edgechain; import org.junit.jupiter.api.Test; +import org.modelmapper.ModelMapper; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; @SpringBootTest class EdgeChainApplicationTest { diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java index 931cd7f9f..8fa7e4c3f 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java @@ -1,8 +1,10 @@ package com.edgechain.lib.endpoint.impl; import com.edgechain.lib.configuration.domain.SecurityUUID; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; import java.io.File; + import org.junit.jupiter.api.Test; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.util.ReflectionTestUtils; @@ -13,6 +15,7 @@ class BgeSmallEndpointTest { @Test @DirtiesContext + void downloadFiles() { // Retrofit needs a port System.setProperty("server.port", "8888"); @@ -42,6 +45,7 @@ void downloadFiles() { ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", null); ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); + deleteFiles(); // make sure we clean up files afterwards } } @@ -55,4 +59,4 @@ private static void deleteFiles() { File tokenizerFile = new File(BgeSmallEndpoint.TOKENIZER_PATH); tokenizerFile.delete(); } -} +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java index 2c11a420e..5178bfd63 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java @@ -1,7 +1,7 @@ package com.edgechain.lib.index.client.impl; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.index.enums.PostgresLanguage; @@ -121,7 +121,7 @@ private void createMetadataTable() { private String upsert() { WordEmbeddings we = new WordEmbeddings(); we.setId("WE1"); - we.setScore("0.86914713"); + we.setScore(0.86914713); we.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); @@ -163,12 +163,12 @@ private String insertMetadata() { private void batchUpsert() { WordEmbeddings we1 = new WordEmbeddings(); we1.setId("WE1"); - we1.setScore("101"); + we1.setScore(1.05689); we1.setValues(List.of(0.25f, 0.5f)); WordEmbeddings we2 = new WordEmbeddings(); we2.setId("WE2"); - we2.setScore("202"); + we2.setScore(2.02689); we2.setValues(List.of(0.75f, 0.9f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); @@ -258,7 +258,7 @@ private void query_noMeta() { private void query_noMeta_metric(PostgresDistanceMetric metric) { WordEmbeddings we1 = new WordEmbeddings(); we1.setId("WEQUERY"); - we1.setScore("104"); + we1.setScore(1.05589); we1.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); @@ -268,6 +268,7 @@ private void query_noMeta_metric(PostgresDistanceMetric metric) { when(mockPe.getMetric()).thenReturn(metric); when(mockPe.getWordEmbeddingsList()).thenReturn(List.of(we1)); when(mockPe.getTopK()).thenReturn(5); + when(mockPe.getUpperLimit()).thenReturn(5); when(mockPe.getMetadataTableNames()).thenReturn(null); final Data data = new Data(); @@ -295,7 +296,7 @@ private void query_meta() { private void query_meta_metric(PostgresDistanceMetric metric) { WordEmbeddings we1 = new WordEmbeddings(); we1.setId("WEQUERY"); - we1.setScore("104"); + we1.setScore(1.258); we1.setValues(List.of(0.25f, 0.5f)); PostgresEndpoint mockPe = mock(PostgresEndpoint.class); @@ -305,6 +306,7 @@ private void query_meta_metric(PostgresDistanceMetric metric) { when(mockPe.getMetric()).thenReturn(metric); when(mockPe.getWordEmbedding()).thenReturn(we1); when(mockPe.getTopK()).thenReturn(5); + when(mockPe.getUpperLimit()).thenReturn(5); when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); final Data data = new Data(); diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java index cf8c8be36..9b8ac08aa 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java @@ -1,6 +1,6 @@ package com.edgechain.openai; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.ChatCompletionRequest; import com.edgechain.lib.openai.request.ChatMessage; import com.edgechain.lib.openai.response.ChatCompletionResponse; @@ -88,15 +88,15 @@ public void testOpenAiClient_ChatCompletionResponse_ShouldMappedToPOJO(Class } @Test - @DisplayName("Test OpenAiEndpoint With Retry Mechanism") + @DisplayName("Test OpenAiChatEndpoint With Retry Mechanism") @Order(3) public void testOpenAiClient_WithRetryMechanism_ShouldThrowExceptionWithRetry(TestInfo testInfo) throws InterruptedException { System.out.println("======== " + testInfo.getDisplayName() + " ========"); - OpenAiEndpoint endpoint = - new OpenAiEndpoint( + OpenAiChatEndpoint endpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, "", // apiKey "", // orgId @@ -120,7 +120,7 @@ public void testOpenAiClient_WithRetryMechanism_ShouldThrowExceptionWithRetry(Te } @Test - @DisplayName("Test OpenAiEndpoint With No Retry Mechanism") + @DisplayName("Test OpenAiChatEndpoint With No Retry Mechanism") @Order(4) public void testOpenAiClient_WithNoRetryMechanism_ShouldThrowExceptionWithNoRetry( TestInfo testInfo) throws InterruptedException { @@ -128,8 +128,8 @@ public void testOpenAiClient_WithNoRetryMechanism_ShouldThrowExceptionWithNoRetr System.out.println("======== " + testInfo.getDisplayName() + " ========"); // Step 1 : Create OpenAi Endpoint - OpenAiEndpoint endpoint = - new OpenAiEndpoint( + OpenAiChatEndpoint endpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, "", // apiKey "", // orgId diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java index 616192713..f16c5d69a 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java @@ -1,6 +1,6 @@ package com.edgechain.pinecone; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.index.client.impl.PineconeClient; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; @@ -30,7 +30,7 @@ public class PineconeClientTest { @BeforeEach void setUp() { System.setProperty("server.port", String.valueOf(port)); - pineconeEndpoint = new PineconeEndpoint("https://arakoo.ai", "apiKey", "Pinecone"); + pineconeEndpoint = new PineconeEndpoint("https://arakoo.ai", "apiKey", "Pinecone",null); } @Test @@ -104,4 +104,4 @@ public void test_GetNamespace() { String nullNamespace = pineconeClient.getNamespace(pineconeEndpoint); assertEquals("", nullNamespace); } -} +} \ No newline at end of file diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java index 858bb9574..012c7dc16 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java @@ -1,7 +1,7 @@ package com.edgechain.postgres; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.enums.PostgresDistanceMetric; import com.edgechain.lib.index.repositories.PostgresClientMetadataRepository; import org.junit.jupiter.api.BeforeEach; diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java index 1e4559bdf..d1f947021 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java @@ -2,6 +2,7 @@ import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import org.modelmapper.ModelMapper; import org.springframework.context.ApplicationContext; import org.springframework.test.util.ReflectionTestUtils; import retrofit2.Retrofit; @@ -53,6 +54,15 @@ public Retrofit setupRetrofit() { return mockRetrofit; } + private ModelMapper setupModelMapper() { + ModelMapper mockModelMapper = mock(ModelMapper.class); + ReflectionTestUtils.setField(ModelMapper.class,"modelMapper", mockModelMapper); + // Retrofit needs a valid port + prevServerPort = System.getProperty("server.port"); + System.setProperty("server.port", "8888"); + return mockModelMapper; + } + /** * Erases the current Retrofit instance so it can be recreated. Call this once in an @AfterEach * teardown method. diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java index e587d4c56..56777ee4d 100644 --- a/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java @@ -1,7 +1,7 @@ package com.edgechain.wiki; import com.edgechain.lib.configuration.domain.SecurityUUID; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; import com.edgechain.lib.wiki.response.WikiResponse; import io.reactivex.rxjava3.observers.TestObserver; @@ -64,4 +64,4 @@ void wikiControllerTest_TestWikiContentMethod_HandlesException(TestInfo testInfo ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); } } -} +} \ No newline at end of file