diff --git a/.DS_Store b/.DS_Store
index bf4f94f08..23e40cb48 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/JS/.DS_Store b/JS/.DS_Store
index c3843fb64..8b74a4062 100644
Binary files a/JS/.DS_Store and b/JS/.DS_Store differ
diff --git a/Java/.DS_Store b/Java/.DS_Store
new file mode 100644
index 000000000..d2f94edf4
Binary files /dev/null and b/Java/.DS_Store differ
diff --git a/Java/Examples/.DS_Store b/Java/Examples/.DS_Store
index 07e27bb69..e6c911b09 100644
Binary files a/Java/Examples/.DS_Store and b/Java/Examples/.DS_Store differ
diff --git a/Java/Examples/airtable/AirtableExample.java b/Java/Examples/airtable/AirtableExample.java
new file mode 100644
index 000000000..6cd5c2cdf
--- /dev/null
+++ b/Java/Examples/airtable/AirtableExample.java
@@ -0,0 +1,134 @@
+package com.edgechain;
+
+import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint;
+import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder;
+import com.edgechain.lib.integration.airtable.query.SortOrder;
+import com.edgechain.lib.request.ArkRequest;
+import com.edgechain.lib.response.ArkResponse;
+import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
+import dev.fuxing.airtable.AirtableRecord;
+import dev.fuxing.airtable.formula.AirtableFormula;
+import dev.fuxing.airtable.formula.LogicalOperator;
+import org.json.JSONObject;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.boot.builder.SpringApplicationBuilder;
+import org.springframework.web.bind.annotation.*;
+
+import java.util.Properties;
+
+/**
+ * For the purpose of this example, create a simple table in Airtable, e.g. "Speakers", with the
+ * following basic fields: "speaker_name", "designation", "organization", "biography",
+ * "speaker_photo", "rating". To get the BASE_ID of a specific base, use the following API:
+ * https://api.airtable.com/v0/meta/bases --header 'Authorization: Bearer
+ * YOUR_PERSONAL_ACCESS_TOKEN'. You can also create complex tables with Airtable and define
+ * relationships between tables via lookups.
+ */
+@SpringBootApplication
+public class AirtableExample {
+  private static final String AIRTABLE_API_KEY = "";
+  private static final String AIRTABLE_BASE_ID = "";
+
+  private static AirtableEndpoint airtableEndpoint;
+
+  public static void main(String[] args) {
+
+    System.setProperty("server.port", "8080");
+
+    Properties properties = new Properties();
+    properties.setProperty("cors.origins", "http://localhost:4200");
+
+    new SpringApplicationBuilder(AirtableExample.class).properties(properties).run(args);
+
+    airtableEndpoint = new AirtableEndpoint(AIRTABLE_BASE_ID, AIRTABLE_API_KEY);
+  }
+
+  @RestController
+  @RequestMapping("/airtable")
+  public class AirtableController {
+
+    @GetMapping("/findAll")
+    public ArkResponse findAll(ArkRequest arkRequest) {
+
+      int pageSize = arkRequest.getIntQueryParam("pageSize");
+      String sortSpeakerName = arkRequest.getQueryParam("sortName");
+      String offset = arkRequest.getQueryParam("offset");
+
+      AirtableQueryBuilder queryBuilder = new AirtableQueryBuilder();
+      queryBuilder.pageSize(pageSize); // pageSize --> no. of records in each request
+      queryBuilder.sort("speaker_name", SortOrder.fromValue(sortSpeakerName).getValue());
+      queryBuilder.offset(offset); // move to the next page by passing the offset returned in the response
+
+      // Return only those speakers with a rating greater than or equal to 3
+      queryBuilder.filterByFormula(
+          LogicalOperator.GTE,
+          AirtableFormula.Object.field("rating"),
+          AirtableFormula.Object.value(3));
+
+      return new EdgeChain<>(airtableEndpoint.findAll("Speakers", queryBuilder)).getArkResponse();
+    }
+
+    @GetMapping("/find")
+    public ArkResponse findById(ArkRequest arkRequest) {
+      String id = arkRequest.getQueryParam("id");
+      return new EdgeChain<>(airtableEndpoint.findById("Speakers", id)).getArkResponse();
+    }
+
+    @PostMapping("/create")
+    public ArkResponse create(ArkRequest arkRequest) {
+
+      JSONObject body = arkRequest.getBody();
+      String speakerName = body.getString("name");
+      String designation = body.getString("designation");
+      int rating = body.getInt("rating");
+      String organization = body.getString("organization");
+      String biography = body.getString("biography");
+
+      // The Airtable API doesn't allow uploading blob files directly; therefore, you would need to
+      // upload the file
+      // to some cloud storage, e.g. S3, and then set the URL in Airtable.
+
+      AirtableRecord record = new AirtableRecord();
+      record.putField("speaker_name", speakerName);
+      record.putField("designation", designation);
+      record.putField("rating", rating);
+      record.putField("organization", organization);
+      record.putField("biography", biography);
+
+      return new EdgeChain<>(airtableEndpoint.create("Speakers", record)).getArkResponse();
+    }
+
+    @PostMapping("/update")
+    public ArkResponse update(ArkRequest arkRequest) {
+
+      JSONObject body = arkRequest.getBody();
+      String id = body.getString("id");
+      String speakerName = body.getString("name");
+      String designation = body.getString("designation");
+      int rating = body.getInt("rating");
+      String organization = body.getString("organization");
+      String biography = body.getString("biography");
+
+      // The Airtable API doesn't allow uploading blob files directly; therefore, you would need to
+      // upload the file
+      // to some cloud storage, e.g. S3, and then set the URL in Airtable.
+ + AirtableRecord record = new AirtableRecord(); + record.setId(id); + record.putField("speaker_name", speakerName); + record.putField("designation", designation); + record.putField("rating", rating); + record.putField("organization", organization); + record.putField("biography", biography); + + return new EdgeChain<>(airtableEndpoint.update("Speakers", record)).getArkResponse(); + } + + @DeleteMapping("/delete") + public ArkResponse delete(ArkRequest arkRequest) { + JSONObject body = arkRequest.getBody(); + String id = body.getString("id"); + return new EdgeChain<>(airtableEndpoint.delete("Speakers", id)).getArkResponse(); + } + } +} diff --git a/Java/Examples/code-interpreter/CodeInterpreter.java b/Java/Examples/code-interpreter/CodeInterpreter.java index 084876a39..cce993b95 100644 --- a/Java/Examples/code-interpreter/CodeInterpreter.java +++ b/Java/Examples/code-interpreter/CodeInterpreter.java @@ -4,6 +4,7 @@ import java.util.concurrent.TimeUnit; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import org.json.JSONException; import org.json.JSONObject; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -13,7 +14,6 @@ import org.springframework.web.bind.annotation.RestController; import com.edgechain.lib.codeInterpreter.Eval; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -28,7 +28,7 @@ public class CodeInterpreter { private static final String OPENAI_AUTH_KEY = ""; - private static OpenAiEndpoint userChatEndpoint; + private static OpenAiChatEndpoint userChatEndpoint; private static final ObjectMapper objectMapper = new ObjectMapper(); private static JsonnetLoader loader = new FileJsonnetLoader("./code-interpreter/code-interpreter.jsonnet"); @@ -48,7 +48,7 @@ public double interpret(ArkRequest arkRequest) throws JSONException { JSONObject json = arkRequest.getBody(); userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, "gpt-3.5-turbo", diff --git a/Java/Examples/htmx-ui-demo/ChatMessage.java b/Java/Examples/htmx-ui-demo/ChatMessage.java index 2f77db2c5..8d0c4b71c 100644 --- a/Java/Examples/htmx-ui-demo/ChatMessage.java +++ b/Java/Examples/htmx-ui-demo/ChatMessage.java @@ -1,30 +1,30 @@ -package com.edgechain; - -public class ChatMessage { - String role; - String content; - - public ChatMessage(String role, String content) { - this.role = role; - this.content = content; - } - - public ChatMessage() {} - - public String getRole() { - return role; - } - - public String getContent() { - return content; - } - - public void setContent(String content) { - this.content = content; - } - - @Override - public String toString() { - return "ChatMessage{" + "role='" + role + '\'' + ", content='" + content + '\'' + '}'; - } -} +package com.edgechain; + +public class ChatMessage { + String role; + String content; + + public ChatMessage(String role, String content) { + this.role = role; + this.content = content; + } + + public ChatMessage() {} + + public String getRole() { + return role; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + @Override + public String toString() { + return "ChatMessage{" + "role='" + role + '\'' + ", content='" + content + '\'' + '}'; + } +} diff --git a/Java/Examples/htmx-ui-demo/User.java b/Java/Examples/htmx-ui-demo/User.java index 
a0a76a979..d2ea16502 100644 --- a/Java/Examples/htmx-ui-demo/User.java +++ b/Java/Examples/htmx-ui-demo/User.java @@ -1,5 +1,4 @@ -package com.edgechain; -public class User { +class User { public String email; public String password; diff --git a/Java/Examples/json/JsonFormat.java b/Java/Examples/json/JsonFormat.java index d1d7f5aaf..2e2423138 100644 --- a/Java/Examples/json/JsonFormat.java +++ b/Java/Examples/json/JsonFormat.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import org.json.JSONObject; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; @@ -17,7 +18,6 @@ import org.springframework.web.client.RestTemplate; import com.edgechain.lib.constants.EndpointConstants; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; import com.edgechain.lib.jsonFormat.request.FunctionRequest; import com.edgechain.lib.jsonFormat.request.Message; import com.edgechain.lib.jsonFormat.request.OpenApiFunctionRequest; @@ -41,7 +41,7 @@ public class JsonFormat { // need only for situation endpoint private static final String OPENAI_ORG_ID = ""; - private static OpenAiEndpoint userChatEndpoint; + private static OpenAiChatEndpoint userChatEndpoint; private static JsonnetLoader loader = new FileJsonnetLoader("./json/json-format.jsonnet"); private static JsonnetLoader functionLoader = new FileJsonnetLoader("./json/function.jsonnet"); private static final ObjectMapper objectMapper = new ObjectMapper(); @@ -84,7 +84,7 @@ public class ExampleController { public String extract(ArkRequest arkRequest) { userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, "gpt-3.5-turbo", @@ -129,8 +129,8 @@ public String situation(ArkRequest arkRequest) { JSONObject json = arkRequest.getBody(); - OpenAiEndpoint userChat = - new OpenAiEndpoint( + OpenAiChatEndpoint userChat = + new OpenAiChatEndpoint( EndpointConstants.OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, OPENAI_ORG_ID, diff --git a/Java/Examples/pinecone/PineconeExample.java b/Java/Examples/pinecone/PineconeExample.java index 8f98764e6..fda4e6163 100644 --- a/Java/Examples/pinecone/PineconeExample.java +++ b/Java/Examples/pinecone/PineconeExample.java @@ -4,12 +4,13 @@ import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; import com.edgechain.lib.chains.PineconeRetrieval; -import com.edgechain.lib.chains.Retrieval; import com.edgechain.lib.context.domain.HistoryContext; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; + +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -19,13 +20,11 @@ import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.response.ArkResponse; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; -import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import 
com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import java.io.IOException; import java.io.InputStream; import java.util.*; import java.util.concurrent.TimeUnit; -import java.util.stream.IntStream; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; @@ -34,22 +33,19 @@ @SpringBootApplication public class PineconeExample { - private static final String OPENAI_AUTH_KEY = ""; - private static final String PINECONE_AUTH_KEY = ""; - private static final String PINECONE_QUERY_API = ""; - private static final String PINECONE_UPSERT_API = ""; - private static final String PINECONE_DELETE = ""; - - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - private static PineconeEndpoint upsertPineconeEndpoint; - private static PineconeEndpoint queryPineconeEndpoint; + private static final String PINECONE_AUTH_KEY = ""; + private static final String PINECONE_API = ""; // Only API + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; - private static PineconeEndpoint deletePineconeEndpoint; + private static PineconeEndpoint pineconeEndpoint; private static RedisHistoryContextEndpoint contextEndpoint; + // It's recommended to perform localized instantiation for thread-safe approach. private JsonnetLoader queryLoader = new FileJsonnetLoader("./pinecone/pinecone-query.jsonnet"); private JsonnetLoader chatLoader = new FileJsonnetLoader("./pinecone/pinecone-chat.jsonnet"); @@ -65,7 +61,7 @@ public static void main(String[] args) { // Redis Configuration properties.setProperty("redis.url", ""); - properties.setProperty("redis.port", "12285"); + properties.setProperty("redis.port", ""); properties.setProperty("redis.username", "default"); properties.setProperty("redis.password", ""); properties.setProperty("redis.ttl", "3600"); @@ -78,36 +74,41 @@ public static void main(String[] args) { new SpringApplicationBuilder(PineconeExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... 
- ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, + gpt3Endpoint = + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - gpt3Endpoint = - new OpenAiEndpoint( + gpt3StreamEndpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, + OPENAI_ORG_ID, "gpt-3.5-turbo", "user", 0.7, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - upsertPineconeEndpoint = - new PineconeEndpoint( - PINECONE_UPSERT_API, - PINECONE_AUTH_KEY, + OpenAiEmbeddingEndpoint ada002 = + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - queryPineconeEndpoint = + pineconeEndpoint = new PineconeEndpoint( - PINECONE_QUERY_API, PINECONE_AUTH_KEY, new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); - - deletePineconeEndpoint = - new PineconeEndpoint( - PINECONE_DELETE, PINECONE_AUTH_KEY, new FixedDelay(4, 5, TimeUnit.SECONDS)); + PINECONE_API, + PINECONE_AUTH_KEY, + ada002, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); @@ -157,46 +158,27 @@ public class PineconeController { // Namespace is optional (if not provided, it will be using Empty String "") @PostMapping("/pinecone/upsert") // /v1/examples/openai/upsert?namespace=machine-learning public void upsertPinecone(ArkRequest arkRequest) throws IOException { - String namespace = arkRequest.getQueryParam("namespace"); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - - // Configure Pinecone - upsertPineconeEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); + PineconeRetrieval retrieval = + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); - /** - * Retrieval Class is basically used to generate embeddings & upsert it to VectorDB; If OpenAI - * Embedding Endpoint is not provided; then Doc2Vec constructor is used If the model is not - * provided, then it will emit an error - */ - Retrieval retrieval = - new PineconeRetrieval(upsertPineconeEndpoint, ada002Embedding, arkRequest); - - IntStream.range(0, arr.length).parallel().forEach(i -> retrieval.upsert(arr[i])); + retrieval.upsert(); } @PostMapping(value = "/pinecone/query") public ArkResponse query(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); + String namespace = arkRequest.getQueryParam("namespace"); - // Configure Pinecone - queryPineconeEndpoint.setNamespace(namespace); - - // Step 1: Chain ==> Get Embeddings From Input & Then Query To Pinecone - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Step 2: Chain ==> Query Embeddings from Pinecone + // Chain 1 ==> Query Embeddings from Pinecone EdgeChain> queryChain = - new EdgeChain<>(queryPineconeEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(pineconeEndpoint.query(query, namespace, topK, arkRequest)); - // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to - // OpenAI + // Chain ===> Our queryFn passes takes list and passes each response with base prompt EdgeChain> gpt3Chain = 
queryChain.transform(wordEmbeddings -> queryFn(wordEmbeddings, arkRequest)); @@ -215,14 +197,8 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - - // Configure Pinecone - queryPineconeEndpoint.setNamespace(namespace); - - // Configure GPT3endpoint - gpt3Endpoint.setStream(stream); + String namespace = arkRequest.getQueryParam("namespace"); // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -238,22 +214,19 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from Pinecone & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method + // Chain 1 ==> Query Embeddings from Pinecone & Then concatenate it (preparing for prompt) EdgeChain> pineconeChain = - new EdgeChain<>(queryPineconeEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(pineconeEndpoint.query(query, namespace, topK, arkRequest)); - // Chain 3 ===> Transform String of Queries into List + // Chain 2 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(pineconeChain) .transform( pineconeResponse -> { + List wordEmbeddings = pineconeResponse.get(); List queryList = new ArrayList<>(); - pineconeResponse.get().forEach(q -> queryList.add(q.getId())); + wordEmbeddings.forEach(q -> queryList.add(q.getId())); return String.join("\n", queryList); }); @@ -261,16 +234,16 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "PineconeChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -287,8 +260,13 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response - and once GPT chain indicates that it is finished, we will save the following into Postgres + and once GPT chain indicates that it is finished, we will save the following into Redis Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory */ @@ -318,8 +296,7 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) { @DeleteMapping("/pinecone/deleteAll") public ArkResponse deletePinecone(ArkRequest arkRequest) { String namespace = arkRequest.getQueryParam("namespace"); - deletePineconeEndpoint.setNamespace(namespace); - return new EdgeChain<>(deletePineconeEndpoint.deleteAll()).getArkResponse(); + return new EdgeChain<>(pineconeEndpoint.deleteAll(namespace)).getArkResponse(); } public List queryFn( diff --git a/Java/Examples/postgresql/PostgreSQLExample.java b/Java/Examples/postgresql/PostgreSQLExample.java index 2c5e0b36c..7b1782e3c 100644 --- a/Java/Examples/postgresql/PostgreSQLExample.java +++ b/Java/Examples/postgresql/PostgreSQLExample.java @@ -4,12 +4,14 @@ import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; import com.edgechain.lib.chains.PostgresRetrieval; -import com.edgechain.lib.chains.Retrieval; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -26,7 +28,6 @@ import java.io.InputStream; import java.util.*; import java.util.concurrent.TimeUnit; -import java.util.stream.IntStream; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -36,13 +37,14 @@ @SpringBootApplication public class PostgreSQLExample { - private static final String OPENAI_AUTH_KEY = ""; - - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; + // For thread safe, instantitate it in methods... private JsonnetLoader queryLoader = new FileJsonnetLoader("./postgres/postgres-query.jsonnet"); private JsonnetLoader chatLoader = new FileJsonnetLoader("./postgres/postgres-chat.jsonnet"); @@ -62,30 +64,48 @@ public static void main(String[] args) { // If you want to use PostgreSQL only; then just provide dbHost, dbUsername & dbPassword. // If you haven't specified PostgreSQL, then logs won't be stored. properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", "postgres"); + properties.setProperty("postgres.db.username", ""); properties.setProperty("postgres.db.password", ""); new SpringApplicationBuilder(PostgreSQLExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... 
- ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, + gpt3Endpoint = + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - gpt3Endpoint = - new OpenAiEndpoint( + gpt3StreamEndpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, + OPENAI_ORG_ID, "gpt-3.5-turbo", "user", - 0.7, + 0.85, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + OpenAiEmbeddingEndpoint adaEmbedding = + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + + // Defining tablename and namespace... postgresEndpoint = - new PostgresEndpoint("spring_vectors", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + new PostgresEndpoint( + "pg_vectors", + "machine-learning", + adaEmbedding, + new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS)); + contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -136,39 +156,36 @@ public class PostgreSQLController { */ @PostMapping("/postgres/upsert") public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); String filename = arkRequest.getMultiPart("file").getSubmittedFileName(); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - postgresEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); - final Retrieval retrieval = - new PostgresRetrieval(postgresEndpoint, filename, 1536, ada002Embedding, arkRequest); + PostgresRetrieval retrieval = + new PostgresRetrieval( + arr, postgresEndpoint, 1536, filename, PostgresLanguage.ENGLISH, arkRequest); + + // retrieval.setBatchSize(50); // Modifying batchSize....(Default is 30) - IntStream.range(0, arr.length).parallel().forEach(i -> retrieval.upsert(arr[i])); + // Getting ids from upsertion... Internally, it automatically parallelizes the operation... 
+ List ids = retrieval.upsert(); + + ids.forEach(System.out::println); + + System.out.println("Size: " + ids.size()); // Printing the UUIDs } @PostMapping(value = "/postgres/query") public ArkResponse query(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - postgresEndpoint.setNamespace(namespace); - - // Chain 1==> Get Embeddings From Input & Then Query To PostgreSQL - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from PostgreSQL + // Chain 1==> Query Embeddings from PostgreSQL EdgeChain> queryChain = new EdgeChain<>( postgresEndpoint.query( - embeddingsChain.get(), PostgresDistanceMetric.IP, topK, 10)); // defining probes + List.of(query), PostgresDistanceMetric.COSINE, topK, topK, 10, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -184,15 +201,9 @@ public ArkResponse chat(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure PostgresEndpoint - postgresEndpoint.setNamespace(namespace); - - gpt3Endpoint.setStream(stream); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -206,24 +217,23 @@ public ArkResponse chat(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - // Chain 2 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method + EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.L2, topK)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.COSINE, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(postgresChain) .transform( postgresResponse -> { + List postgresWordEmbeddingsList = + postgresResponse.get(); List queryList = new ArrayList<>(); - postgresResponse.get().forEach(q -> queryList.add(q.getRawText())); + postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); }); @@ -231,16 +241,16 @@ public ArkResponse chat(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -257,6 +267,12 @@ public ArkResponse chat(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "PostgresChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Postgres Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Java/Examples/react-chain/ReactChainApplication.java b/Java/Examples/react-chain/ReactChainApplication.java index b8c848f25..d3e612155 100644 --- a/Java/Examples/react-chain/ReactChainApplication.java +++ b/Java/Examples/react-chain/ReactChainApplication.java @@ -1,6 +1,6 @@ package com.edgechain; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -9,9 +9,10 @@ import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; -import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import java.util.Properties; import java.util.concurrent.TimeUnit; @@ -22,8 +23,9 @@ public class ReactChainApplication { private static final String OPENAI_AUTH_KEY = ""; - - private static OpenAiEndpoint userChatEndpoint; + private static final String OPENAI_ORG_ID = ""; + private static OpenAiChatEndpoint userChatEndpoint; + private static JsonnetLoader loader = new FileJsonnetLoader("./react-chain/react-chain.jsonnet"); public static void main(String[] args) { System.setProperty("server.port", "8080"); @@ -45,64 +47,83 @@ public static void main(String[] args) { new SpringApplicationBuilder(ReactChainApplication.class).properties(properties).run(args); userChatEndpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, + OPENAI_ORG_ID, "gpt-3.5-turbo", "user", 0.7, + false, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); } @RestController @RequestMapping("/v1/examples") - public class ExampleController { + public class ReactChainController { - @GetMapping(value = "/react-chain") + @PostMapping(value = "/react-chain") public String reactChain(ArkRequest arkRequest) { String prompt = (String) arkRequest.getBody().get("prompt"); - StringBuilder context = new StringBuilder(); - JsonnetLoader loader = - new FileJsonnetLoader("./react-chain.jsonnet") - .put("context", new JsonnetArgs(DataType.STRING, "This is context")) - .put("gptResponse", new JsonnetArgs(DataType.STRING, "")) - .loadOrReload(); - String preset = loader.get("preset"); - - prompt = preset + " \nQuestion: " + prompt; - - String gptResponse = - userChatEndpoint - .chatCompletion(prompt, "React-Chain", 
arkRequest) - .blockingFirst() - .getChoices() - .get(0) - .getMessage() - .getContent(); - context.append(prompt); - loader.put("context", new JsonnetArgs(DataType.STRING, context.toString())); + + loader.put("context", new JsonnetArgs(DataType.STRING, "This is context")); + loader.put("gptResponse", new JsonnetArgs(DataType.STRING, "")); + loader.put("question", new JsonnetArgs(DataType.STRING, prompt)); + loader.put("text", new JsonnetArgs(DataType.STRING, "")); + + try { + loader.loadOrReload(); + } catch (Exception e) { + e.printStackTrace(); + return "Please broaden the search query!"; + } + prompt = loader.get("initialPrompt"); + + String gptResponse = gptFn(prompt, arkRequest); + + loader.put("context", new JsonnetArgs(DataType.STRING, prompt)); loader.put("gptResponse", new JsonnetArgs(DataType.STRING, gptResponse)); while (!checkIfFinished(gptResponse)) { - loader.loadOrReload(); + try { + loader.loadOrReload(); + } catch (Exception e) { + return "Please broaden the search query or try again!"; + } + + String observation = loader.get("observation"); + if (observation.isEmpty()) + return "No info found on Wiki! Please broaden the search query or try again!"; + prompt = loader.get("prompt"); - gptResponse = - userChatEndpoint - .chatCompletion(prompt, "React-Chain", arkRequest) - .blockingFirst() - .getChoices() - .get(0) - .getMessage() - .getContent(); - context.append("\n" + prompt); - loader.put("context", new JsonnetArgs(DataType.STRING, context.toString())); + gptResponse = gptFn(prompt, arkRequest); + + loader.put("context", new JsonnetArgs(DataType.STRING, prompt)); loader.put("gptResponse", new JsonnetArgs(DataType.STRING, gptResponse)); } - return gptResponse.substring(gptResponse.indexOf("Finish[") + 7, gptResponse.indexOf("]")); + + // Extracting the final answer + loader.put("text", new JsonnetArgs(DataType.STRING, gptResponse)); + + try { + loader.loadOrReload(); + return loader.get("finalAns"); + } catch (Exception e) { + return "Please broaden the search query or try again!"; + } } private boolean checkIfFinished(String gptResponse) { return gptResponse.contains("Finish"); } + + private String gptFn(String prompt, ArkRequest arkRequest) { + return new EdgeChain<>(userChatEndpoint.chatCompletion(prompt, "React-Chain", arkRequest)) + .get() + .getChoices() + .get(0) + .getMessage() + .getContent(); + } } } diff --git a/Java/Examples/react-chain/react-chain.jsonnet b/Java/Examples/react-chain/react-chain.jsonnet index 5f2c25d20..39f3d5da5 100644 --- a/Java/Examples/react-chain/react-chain.jsonnet +++ b/Java/Examples/react-chain/react-chain.jsonnet @@ -45,6 +45,7 @@ local preset = ||| Action 3: Finish[yes] **ALL THE OBSERVATIONS WILL BE PROVIDED BY THE USER, YOU DON'T HAVE TO PROVIDE ANY OBSERVATION** + Question: {} |||; //To extract action from the response @@ -57,14 +58,27 @@ local extractThought(str) = local thought = xtr.strings.substringAfter(xtr.strings.substringBefore(str, "Action"), ":"); thought; +//Replace the {} in the preset with the question +local updateQueryPrompt(question) = + local updatedPrompt = xtr.replace(preset, '{}', question + "\n"); + updatedPrompt; + +//Extract the final answer +local extractFinalAns(text) = + local finalAns = xtr.strings.substringAfter(xtr.strings.substringBeforeLast(xtr.strings.substringAfter(text, "Finish["), "]"), "["); + finalAns; + +local initialPrompt = updateQueryPrompt(payload.question); local gptResponse = payload.gptResponse; //this will be populated from the java code after the prompt is submitted to gpt local 
action = extractAction(gptResponse); local thought = extractThought(gptResponse); -local searchResponse = std.substr(callFunction("search")(action), 0, 200); //extract action from response and insert here +local searchResponse = std.substr(callFunction("search")(action), 0, 400); //extract action from response and insert here local observation = xtr.join(["Observation:", searchResponse], ''); -local context = payload.context + "\n" + gptResponse + "\n" + observation; -local prompt = xtr.strings.appendIfMissing(context, "\n" + observation); +local context = payload.context; +local prompt = xtr.join([context, gptResponse, observation], '\n'); +local finalAns = extractFinalAns(payload.text); { + initialPrompt: initialPrompt, observation: observation, thought: thought, action: action, @@ -72,6 +86,7 @@ local prompt = xtr.strings.appendIfMissing(context, "\n" + observation); prompt: prompt, context: context, searchResponse: searchResponse, - gptResponse: gptResponse + gptResponse: gptResponse, + finalAns: finalAns } diff --git a/Java/Examples/redis/RedisExample.java b/Java/Examples/redis/RedisExample.java index 603970750..919b4cfa7 100644 --- a/Java/Examples/redis/RedisExample.java +++ b/Java/Examples/redis/RedisExample.java @@ -4,11 +4,13 @@ import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; import com.edgechain.lib.chains.RedisRetrieval; -import com.edgechain.lib.chains.Retrieval; import com.edgechain.lib.chunk.enums.LangType; import com.edgechain.lib.context.domain.HistoryContext; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; @@ -25,7 +27,6 @@ import java.io.InputStream; import java.util.*; import java.util.concurrent.TimeUnit; -import java.util.stream.IntStream; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -35,9 +36,12 @@ @SpringBootApplication public class RedisExample { - private static final String OPENAI_AUTH_KEY = ""; - private static OpenAiEndpoint ada002Embedding; - private static OpenAiEndpoint gpt3Endpoint; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static OpenAiChatEndpoint gpt3Endpoint; + + private static OpenAiChatEndpoint gpt3StreamEndpoint; + private static RedisEndpoint redisEndpoint; private static RedisHistoryContextEndpoint contextEndpoint; @@ -70,25 +74,42 @@ public static void main(String[] args) { new SpringApplicationBuilder(RedisExample.class).properties(properties).run(args); - // Variables Initialization ==> Endpoints must be intialized in main method... 
- ada002Embedding = - new OpenAiEndpoint( - OPENAI_EMBEDDINGS_API, + gpt3Endpoint = + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, - "text-embedding-ada-002", - new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); - gpt3Endpoint = - new OpenAiEndpoint( + gpt3StreamEndpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, + OPENAI_ORG_ID, "gpt-3.5-turbo", "user", - 0.7, + 0.85, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + OpenAiEmbeddingEndpoint ada002Endpoint = + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + redisEndpoint = - new RedisEndpoint("vector_index", new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + new RedisEndpoint( + "vector_index", + "machine-learning", + ada002Endpoint, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + contextEndpoint = new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS)); } @@ -123,20 +144,15 @@ public class RedisController { /********************** REDIS WITH OPENAI ****************************/ // Namespace is optional (if not provided, it will be using namespace will be "knowledge") + /** + * Both IndexName & namespace are integral for upsert & performing similarity search; If you are + * creating different namespace; recommended to use different index_name because filtering is + * done by index_name * + */ @PostMapping("/redis/upsert") // /v1/examples/openai/upsert?namespace=machine-learning public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - /** - * Both IndexName & namespace are integral for upsert & performing similarity search; If you - * are creating different namespace; recommended to use different index_name because filtering - * is done by index_name * - */ - // Configure RedisEndpoint - redisEndpoint.setNamespace(namespace); - /** * We have two implementation for Read By Sentence: a) readBySentence(LangType, Your File) * EdgeChains sdk has predefined support to chunk by sentences w.r.t to 5 languages (english, @@ -147,14 +163,12 @@ public void upsert(ArkRequest arkRequest) throws IOException { String[] arr = pdfReader.readBySentence(LangType.EN, file); /** - * Retrieval Class is basically used to generate embeddings & upsert it to VectorDB; If OpenAI - * Embedding Endpoint is not provided; then Doc2Vec constructor is used If the model is not - * provided, then it will emit an error + * Retrieval Class is basically used to generate embeddings & upsert it to VectorDB + * asynchronously...; */ - Retrieval retrieval = - new RedisRetrieval( - redisEndpoint, ada002Embedding, 1536, RedisDistanceMetric.COSINE, arkRequest); - IntStream.range(0, arr.length).parallel().forEach(i -> retrieval.upsert(arr[i])); + RedisRetrieval retrieval = + new RedisRetrieval(arr, redisEndpoint, 1536, RedisDistanceMetric.COSINE, arkRequest); + retrieval.upsert(); } /** @@ -166,19 +180,12 @@ public void upsert(ArkRequest arkRequest) throws IOException { @PostMapping(value = "/redis/similarity-search") public ArkResponse similaritySearch(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = 
arkRequest.getIntQueryParam("topK"); - redisEndpoint.setNamespace(namespace); - - // Chain 1 ==> Generate Embeddings Using Ada002 - EdgeChain ada002Chain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Pass those embeddings to Redis & Return Score/values (Similarity search) + // Chain 1 ==> Pass those embeddings to Redis & Return Score/values (Similarity search) EdgeChain> redisQueries = - new EdgeChain<>(redisEndpoint.query(ada002Chain.get(), topK)); + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); return redisQueries.getArkResponse(); } @@ -186,19 +193,12 @@ public ArkResponse similaritySearch(ArkRequest arkRequest) { @PostMapping(value = "/redis/query") public ArkResponse queryRedis(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - redisEndpoint.setNamespace(namespace); - - // Chain 1==> Get Embeddings From Input & Then Query To Redis - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from Redis + // Chain 1 ==> Query Embeddings from Redis EdgeChain> queryChain = - new EdgeChain<>(redisEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -213,14 +213,8 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // configure GPT3Endpoint - gpt3Endpoint.setStream(stream); - - redisEndpoint.setNamespace(namespace); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -235,22 +229,20 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input - EdgeChain embeddingsChain = - new EdgeChain<>(ada002Embedding.embeddings(query, arkRequest)); + // Chain 1==> Query Embeddings from Redis & Then concatenate it (preparing for prompt) - // Chain 2 ==> Query Embeddings from Redis & Then concatenate it (preparing for prompt) - // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain> redisChain = - new EdgeChain<>(redisEndpoint.query(embeddingsChain.get(), topK)); + new EdgeChain<>(redisEndpoint.query(query, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List + // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain queryChain = new EdgeChain<>(redisChain) .transform( redisResponse -> { + List wordEmbeddings = redisResponse.get(); List queryList = new ArrayList<>(); - redisResponse.get().forEach(q -> queryList.add(q.getId())); + wordEmbeddings.forEach(q -> queryList.add(q.getId())); return String.join("\n", queryList); }); @@ -258,16 +250,16 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); - // (FOR NON STREAMING) // If 
it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -284,6 +276,11 @@ public ArkResponse chatWithRedis(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Redis Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Java/Examples/supabase-miniLM/SupabaseMiniLMExample.java b/Java/Examples/supabase-miniLM/SupabaseMiniLMExample.java index a58cdb94c..dba35663e 100644 --- a/Java/Examples/supabase-miniLM/SupabaseMiniLMExample.java +++ b/Java/Examples/supabase-miniLM/SupabaseMiniLMExample.java @@ -1,13 +1,15 @@ package com.edgechain; import com.edgechain.lib.chains.PostgresRetrieval; -import com.edgechain.lib.chains.Retrieval; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; -import com.edgechain.lib.endpoint.impl.*; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -33,20 +35,19 @@ import java.util.Objects; import java.util.Properties; import java.util.concurrent.TimeUnit; -import java.util.stream.IntStream; import static com.edgechain.lib.constants.EndpointConstants.OPENAI_CHAT_COMPLETION_API; @SpringBootApplication public class SupabaseMiniLMExample { - private static final String OPENAI_AUTH_KEY = ""; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID - private static OpenAiEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3Endpoint; + private static OpenAiChatEndpoint gpt3StreamEndpoint; private static PostgresEndpoint postgresEndpoint; private static PostgreSQLHistoryContextEndpoint contextEndpoint; - private static MiniLMEndpoint miniLMEndpoint; - private JsonnetLoader queryLoader = new FileJsonnetLoader("./supabase-miniLM/postgres-query.jsonnet"); private JsonnetLoader chatLoader = @@ -61,6 +62,9 @@ public static void main(String[] args) { properties.setProperty("supabase.url", ""); properties.setProperty("supabase.annon.key", ""); + // For JWT decode + properties.setProperty("jwt.secret", ""); + // Adding Cors ==> You can configure multiple cors w.r.t your urls.; properties.setProperty("cors.origins", "http://localhost:4200"); @@ -73,18 +77,27 @@ public static void main(String[] args) { properties.setProperty("postgres.db.username", 
"postgres"); properties.setProperty("postgres.db.password", ""); - // For JWT decode - properties.setProperty("jwt.secret", ""); - new SpringApplicationBuilder(SupabaseMiniLMExample.class).properties(properties).run(args); gpt3Endpoint = - new OpenAiEndpoint( + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, + OPENAI_ORG_ID, "gpt-3.5-turbo", "user", - 0.7, + 0.85, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + + gpt3StreamEndpoint = + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.85, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); // Creating MiniLM Endpoint @@ -92,12 +105,16 @@ public static void main(String[] args) { // download on fly. // All the requests will wait until the model is download & loaded once into the application... // As you can see, the model is not download; so it will download on fly... - miniLMEndpoint = new MiniLMEndpoint(MiniLMModel.ALL_MINILM_L12_V2); + MiniLMEndpoint miniLMEndpoint = new MiniLMEndpoint(MiniLMModel.ALL_MINILM_L12_V2); // Creating PostgresEndpoint ==> We create a new table because miniLM supports 384 dimensional // vectors; postgresEndpoint = - new PostgresEndpoint("minilm_vectors", new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); + new PostgresEndpoint( + "minilm_vectors", + "minilm-ns", + miniLMEndpoint, + new ExponentialDelay(2, 3, 2, TimeUnit.SECONDS)); contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS)); } @@ -150,39 +167,37 @@ public class SupabaseController { @PostMapping("/miniLM/upsert") @PreAuthorize("hasAnyAuthority('authenticated')") public void upsert(ArkRequest arkRequest) throws IOException { - - String namespace = arkRequest.getQueryParam("namespace"); String filename = arkRequest.getMultiPart("file").getSubmittedFileName(); InputStream file = arkRequest.getMultiPart("file").getInputStream(); - postgresEndpoint.setNamespace(namespace); - String[] arr = pdfReader.readByChunkSize(file, 512); - Retrieval retrieval = - new PostgresRetrieval(postgresEndpoint, filename, 384, miniLMEndpoint, arkRequest); + PostgresRetrieval retrieval = + new PostgresRetrieval( + arr, postgresEndpoint, 384, filename, PostgresLanguage.ENGLISH, arkRequest); - IntStream.range(0, arr.length).parallel().forEach(i -> retrieval.upsert(arr[i])); + // retrieval.setBatchSize(50); // Modifying batchSize.... + + // Getting ids from upsertion... Internally, it automatically parallelizes the operation... 
+ List ids = retrieval.upsert(); + + ids.forEach(System.out::println); + + System.out.println("Size: " + ids.size()); // Printing the UUIDs } @PostMapping(value = "/miniLM/query") @PreAuthorize("hasAnyAuthority('authenticated')") public ArkResponse queryPostgres(ArkRequest arkRequest) { - String namespace = arkRequest.getQueryParam("namespace"); String query = arkRequest.getBody().getString("query"); int topK = arkRequest.getIntQueryParam("topK"); - postgresEndpoint.setNamespace(namespace); - - // Chain 1==> Get Embeddings From Input using MiniLM & Then Query To PostgreSQL - EdgeChain embeddingsChain = - new EdgeChain<>(miniLMEndpoint.embeddings(query, arkRequest)); - // Chain 2 ==> Query Embeddings from PostgreSQL EdgeChain> queryChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.L2, topK)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); // Chain 3 ===> Our queryFn passes takes list and passes each response with base prompt to // OpenAI @@ -199,15 +214,9 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { String contextId = arkRequest.getQueryParam("id"); String query = arkRequest.getBody().getString("query"); - String namespace = arkRequest.getQueryParam("namespace"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure PostgresEndpoint - postgresEndpoint.setNamespace(namespace); - - gpt3Endpoint.setStream(stream); - // Get HistoryContext HistoryContext historyContext = contextEndpoint.get(contextId); @@ -222,23 +231,22 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { // Extract topK value from JsonnetLoader; int topK = chatLoader.getInt("topK"); - // Chain 1 ==> Get Embeddings From Input using MiniLM - EdgeChain embeddingsChain = - new EdgeChain<>(miniLMEndpoint.embeddings(query, arkRequest)); - - // Chain 2 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) + // Chain 1 ==> Query Embeddings from PostgreSQL & Then concatenate it (preparing for prompt) // let's say topK=5; then we concatenate List into a string using String.join method EdgeChain> postgresChain = new EdgeChain<>( - postgresEndpoint.query(embeddingsChain.get(), PostgresDistanceMetric.L2, topK)); + postgresEndpoint.query( + List.of(query), PostgresDistanceMetric.L2, topK, topK, arkRequest)); // Chain 3 ===> Transform String of Queries into List EdgeChain queryChain = new EdgeChain<>(postgresChain) .transform( postgresResponse -> { + List postgresWordEmbeddingsList = + postgresResponse.get(); List queryList = new ArrayList<>(); - postgresResponse.get().forEach(q -> queryList.add(q.getRawText())); + postgresWordEmbeddingsList.forEach(q -> queryList.add(q.getRawText())); return String.join("\n", queryList); }); @@ -246,17 +254,17 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { EdgeChain promptChain = queryChain.transform(queries -> chatFn(historyContext.getResponse(), queries)); - // Chain 5 ==> Pass the Prompt To Gpt3 - EdgeChain gpt3Chain = - new EdgeChain<>( - gpt3Endpoint.chatCompletion( - promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); - // (FOR NON STREAMING) // If it's not stream ==> // Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. 
ChatHistory if (!stream) { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion( + promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); + // Chain 6 EdgeChain historyUpdatedChain = gpt3Chain.doOnNext( @@ -273,6 +281,12 @@ public ArkResponse chatWithPostgres(ArkRequest arkRequest) { // For STREAMING Version else { + // Chain 5 ==> Pass the Prompt To Gpt3 + EdgeChain gpt3Chain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "MiniLMPostgresChatChain", arkRequest)); + /* As the response is in stream, so we will use StringBuilder to append the response and once GPT chain indicates that it is finished, we will save the following into Postgres Query(What is the collect stage for data maturity) + OpenAiResponse + Prev. ChatHistory diff --git a/Java/Examples/wiki/WikiExample.java b/Java/Examples/wiki/WikiExample.java index 7c3d8cab1..feedb984a 100644 --- a/Java/Examples/wiki/WikiExample.java +++ b/Java/Examples/wiki/WikiExample.java @@ -1,7 +1,7 @@ package com.edgechain; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.jsonnet.JsonnetArgs; import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.enums.DataType; @@ -27,13 +27,19 @@ @SpringBootApplication public class WikiExample { - private static final String OPENAI_AUTH_KEY = ""; + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID /* Step 3: Create OpenAiEndpoint to communicate with OpenAiServices; */ - private static OpenAiEndpoint gpt4Endpoint; + private static OpenAiChatEndpoint gpt3Endpoint; + + private static OpenAiChatEndpoint gpt3StreamEndpoint; + private static WikiEndpoint wikiEndpoint; - private final JsonnetLoader loader = new FileJsonnetLoader("./wiki/wiki.jsonnet"); + // There is a 70% chance that file1 is executed; 30% chance file2 is executed.... 
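As a rough illustration of that 70/30 split (this is an assumption about the observable behaviour, not the FileJsonnetLoader implementation itself): each time the prompt is loaded, one of the two Jsonnet files is chosen at random with the given weight, which makes the two-file constructor in the declaration that follows a lightweight way to A/B test the 5-bullet preset (wiki1.jsonnet) against the 15-bullet preset (wiki2.jsonnet).

```java
import java.util.concurrent.ThreadLocalRandom;

public class WeightedPromptSketch {

  public static void main(String[] args) {
    // Hypothetical re-creation of the weighted choice; the real selection happens
    // inside FileJsonnetLoader(70, "./wiki/wiki1.jsonnet", "./wiki/wiki2.jsonnet").
    for (int i = 0; i < 10; i++) {
      String chosen =
          ThreadLocalRandom.current().nextInt(100) < 70
              ? "./wiki/wiki1.jsonnet" // ~70% of requests: 5-bullet summary preset
              : "./wiki/wiki2.jsonnet"; // ~30% of requests: 15-bullet summary preset
      System.out.println("request " + i + " -> " + chosen);
    }
  }
}
```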
+ private final JsonnetLoader loader = + new FileJsonnetLoader(70, "./wiki/wiki1.jsonnet", "./wiki/wiki2.jsonnet"); public static void main(String[] args) { System.setProperty("server.port", "8080"); @@ -48,20 +54,32 @@ public static void main(String[] args) { properties.setProperty("spring.jpa.properties.hibernate.format_sql", "true"); properties.setProperty("postgres.db.host", ""); - properties.setProperty("postgres.db.username", "postgres"); + properties.setProperty("postgres.db.username", ""); properties.setProperty("postgres.db.password", ""); new SpringApplicationBuilder(WikiExample.class).properties(properties).run(args); wikiEndpoint = new WikiEndpoint(); - gpt4Endpoint = - new OpenAiEndpoint( + gpt3Endpoint = + new OpenAiChatEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "gpt-3.5-turbo", + "user", + 0.7, + new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); + + gpt3StreamEndpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, OPENAI_AUTH_KEY, - "gpt-4", + OPENAI_ORG_ID, + "gpt-3.5-turbo", "user", 0.7, + true, new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS)); } @@ -80,19 +98,11 @@ public ArkResponse wikiSummary(ArkRequest arkRequest) { String query = arkRequest.getQueryParam("query"); boolean stream = arkRequest.getBooleanHeader("stream"); - // Configure GPT4Endpoint - gpt4Endpoint.setStream(stream); - // Chain 1 ==> WikiChain EdgeChain wikiChain = new EdgeChain<>(wikiEndpoint.getPageContent(query)); // Chain 2 ===> Creating Prompt Chain & Return ChatCompletion EdgeChain promptChain = wikiChain.transform(this::fn); - - // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable - EdgeChain openAiChain = - new EdgeChain<>(gpt4Endpoint.chatCompletion(promptChain.get(), "WikiChain", arkRequest)); - /** * The best part is flexibility with just one method EdgeChainsSDK will return response either * in json or stream; The real magic happens here. Streaming happens only if your logic allows @@ -101,8 +111,25 @@ public ArkResponse wikiSummary(ArkRequest arkRequest) { // Note: When you call getArkResponse() or getArkStreamResponse() ==> Only then your streams // are executed... - if (stream) return openAiChain.getArkStreamResponse(); - else return openAiChain.getArkResponse(); + if (stream) { + + // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable + EdgeChain openAiChain = + new EdgeChain<>( + gpt3StreamEndpoint.chatCompletion( + promptChain.get(), "WikiChain", loader, arkRequest)); + + return openAiChain.getArkStreamResponse(); + + } else { + + // Chain 3 ==> Pass Prompt to ChatCompletion API & Return ArkResponseObservable + EdgeChain openAiChain = + new EdgeChain<>( + gpt3Endpoint.chatCompletion(promptChain.get(), "WikiChain", loader, arkRequest)); + + return openAiChain.getArkResponse(); + } } private String fn(WikiResponse wiki) { diff --git a/Java/Examples/wiki/wiki.jsonnet b/Java/Examples/wiki/wiki1.jsonnet similarity index 93% rename from Java/Examples/wiki/wiki.jsonnet rename to Java/Examples/wiki/wiki1.jsonnet index 051d088ce..347054fdb 100644 --- a/Java/Examples/wiki/wiki.jsonnet +++ b/Java/Examples/wiki/wiki1.jsonnet @@ -12,7 +12,7 @@ local preset = ||| 2. - ... 
``` - Now, given the data, create a 30-bullet point summary of: + Now, given the data, create a 5-bullet point summary of: |||; local keepContext = payload.keepContext; local context = if keepContext == "true" then payload.context else ""; diff --git a/Java/Examples/wiki/wiki2.jsonnet b/Java/Examples/wiki/wiki2.jsonnet new file mode 100644 index 000000000..a0c9e064c --- /dev/null +++ b/Java/Examples/wiki/wiki2.jsonnet @@ -0,0 +1,23 @@ +local keepMaxTokens = payload.keepMaxTokens; +local maxTokens = if keepMaxTokens == "true" then payload.maxTokens else 5120; + +local preset = ||| + Just consider yourself as a summary generator bot. You should detect the language and the characters the user is writing in, and reply in the same character set and language. + You should follow the following template while answering the user: + ``` + 1. - + 2. - + ... + ``` + Now, given the data, create a 15-bullet point summary of: + |||; +local keepContext = payload.keepContext; +local context = if keepContext == "true" then payload.context else ""; +local prompt = std.join("\n", [preset, context]); +{ + "maxTokens": maxTokens, + "typeOfKeepContext": xtr.type(keepContext), + "preset" : preset, + "context": context, + "prompt": if(std.length(prompt) > xtr.parseNum(maxTokens)) then std.substr(prompt, 0, xtr.parseNum(maxTokens)) else prompt +} \ No newline at end of file diff --git a/Java/Examples/zapier/ZapierExample.java b/Java/Examples/zapier/ZapierExample.java new file mode 100644 index 000000000..f50c93f40 --- /dev/null +++ b/Java/Examples/zapier/ZapierExample.java @@ -0,0 +1,266 @@ +package com.edgechain; + +// DEPS com.amazonaws:aws-java-sdk-s3:1.12.554 + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; +import com.amazonaws.services.s3.model.ListObjectsV2Request; +import com.amazonaws.services.s3.model.ListObjectsV2Result; +import com.amazonaws.services.s3.model.S3Object; +import com.amazonaws.services.s3.model.S3ObjectSummary; +import com.edgechain.lib.chains.PineconeRetrieval; +import com.edgechain.lib.chunk.Chunker; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.ArkResponse; +import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import me.xuender.unidecode.Unidecode; +import org.apache.commons.io.IOUtils; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.builder.SpringApplicationBuilder; +import org.springframework.context.annotation.Bean; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.util.retry.Retry; +import java.io.*; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + 
+import static com.edgechain.lib.constants.EndpointConstants.OPENAI_EMBEDDINGS_API; + +/** + * Objective (1): 1. Trigger: use a Zapier Webhook to pass a list of URLs (in this example, Wikipedia + * URLs). 2. Action: use 'Web Page Parser by Zapier' to extract the entire page content from each URL (this content is + * later upserted into Pinecone). 3. Action: stringify the JSON. 4. Action: save each + * JSON file (the parsed content of each URL) to Amazon S3. This part of the process is entirely automated by Zapier; + * you need to create a Zap for it, and the Zapier Hook is used to trigger the ETL process (the hook + * requests are parallelized)... ========================================================== 5. Then, we extract + * each file from Amazon S3. 6. We upsert the normalized content to Pinecone with a chunk size of 512... + * You can choose any file storage: S3, Dropbox, Google Drive, etc. + */ + +/** + * Objective (2): extracting PDFs via PDF4Me. Create a Zap using the following steps: 1. + * Trigger ==> integrate a Google Drive folder so the Zap runs when a new file is added (it is not instant; it is + * scheduled internally by Zapier, and you can also trigger it via a Webhook). 2. Action ==> + * extract text from the PDF using PDF4Me (the free plan allows 20 API calls). 3. Action ==> use Code by Zapier + * to stringify the JSON response from PDF4Me. 4. Action ==> save it to Amazon S3. Now, from the + * EdgeChains SDK, we extract the files from S3 & upsert them to Pinecone with a chunk size of 512. + */ +@SpringBootApplication +public class ZapierExample { + + private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY + private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID + private static final String ZAPIER_HOOK_URL = + ""; // E.g. https://hooks.zapier.com/hooks/catch/18785910/2ia657b + + private static final String PINECONE_AUTH_KEY = ""; + private static final String PINECONE_API = ""; + + private static PineconeEndpoint pineconeEndpoint; + + public static void main(String[] args) { + + System.setProperty("server.port", "8080"); + + new SpringApplicationBuilder(ZapierExample.class).run(args); + + OpenAiEmbeddingEndpoint adaEmbedding = + new OpenAiEmbeddingEndpoint( + OPENAI_EMBEDDINGS_API, + OPENAI_AUTH_KEY, + OPENAI_ORG_ID, + "text-embedding-ada-002", + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + + pineconeEndpoint = + new PineconeEndpoint( + PINECONE_API, + PINECONE_AUTH_KEY, + adaEmbedding, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + } + + @Bean + public AmazonS3 s3Client() { + + String accessKey = ""; // YOUR_AWS_S3_ACCESS_KEY + String secretKey = ""; // YOUR_AWS_S3_SECRET_KEY + + BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey); + return AmazonS3ClientBuilder.standard() + .withRegion(Regions.fromName("us-east-1")) + .withCredentials(new AWSStaticCredentialsProvider(awsCredentials)) + .build(); + } + + @RestController + public class ZapierController { + + @Autowired private AmazonS3 s3Client; + + // List of wiki URLs triggered in parallel via the WebHook. They are automatically parsed, + // transformed to JSON and stored in AWS S3.
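To make the trigger concrete, here is a hedged sketch of how a client could kick off that ETL step: it POSTs a JSON body whose "urls" array matches what the /etl route below reads via body.getJSONArray("urls"). The WebClient usage mirrors the zapWebHook helper later in this class; the localhost base URL is an assumption based on the server.port set above, and the effective path may differ if a servlet context path is configured. The example URLs and the /etl route itself follow right after this sketch.

```java
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;

public class EtlTriggerSketch {

  public static void main(String[] args) {
    // Body shape expected by the /etl route: {"urls": ["...", "..."]}
    JSONObject body = new JSONObject();
    body.put(
        "urls",
        new JSONArray()
            .put("https://en.wikipedia.org/wiki/The_Weather_Company")
            .put("https://en.wikipedia.org/wiki/Quora"));

    WebClient.builder()
        .baseUrl("http://localhost:8080") // assumption: matches server.port set in main()
        .build()
        .post()
        .uri("/etl")
        .contentType(MediaType.APPLICATION_JSON)
        .body(BodyInserters.fromValue(body.toString()))
        .retrieve()
        .toBodilessEntity() // /etl returns no body; we only care that the trigger succeeded
        .block();
  }
}
```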
+ /* + Examples: + "https://en.wikipedia.org/wiki/The_Weather_Company", + "https://en.wikipedia.org/wiki/Microsoft_Bing", + "https://en.wikipedia.org/wiki/Quora", + "https://en.wikipedia.org/wiki/Steve_Jobs", + "https://en.wikipedia.org/wiki/Michael_Jordan", + */ + @PostMapping("/etl") + public void performETL(ArkRequest arkRequest) { + + JSONObject body = arkRequest.getBody(); + JSONArray jsonArray = body.getJSONArray("urls"); + + IntStream.range(0, jsonArray.length()) + .parallel() + .forEach( + index -> { + String url = jsonArray.getString(index); + // For Logging + System.out.printf("Url %s: %s\n", index, url); + + // Trigger Zapier WebHook + this.zapWebHook(url); + }); + } + + @PostMapping("/upsert-urls") + public void upsertParsedURLs(ArkRequest arkRequest) throws IOException { + String namespace = arkRequest.getQueryParam("namespace"); + JSONObject body = arkRequest.getBody(); + String bucketName = body.getString("bucketName"); + + // Get all the files from S3 bucket + ListObjectsV2Request listObjectsRequest = + new ListObjectsV2Request().withBucketName(bucketName); + + ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); + + for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { + String key = objectSummary.getKey(); + + if (key.endsWith(".txt")) { + + S3Object object = s3Client.getObject(bucketName, key); + InputStream objectData = object.getObjectContent(); + String content = IOUtils.toString(objectData, StandardCharsets.UTF_8); + + JSONObject jsonObject = new JSONObject(content); + + // These fields are specified from Zapier Action... + System.out.println("Domain: " + jsonObject.getString("domain")); // en.wikipedia.org + System.out.println("Title: " + jsonObject.getString("title")); // Barack Obama + System.out.println("Author: " + jsonObject.getString("author")); // Wikipedia Contributers + System.out.println("Word Count: " + jsonObject.get("word_count")); // 23077 + System.out.println( + "Date Published: " + jsonObject.getString("date_published")); // Publish date + + // Now, we extract content and Chunk it by 512 size; then upsert it to Pinecone + // Normalize the extracted content.... + + String normalizedText = + Unidecode.decode(jsonObject.getString("content")).replaceAll("[\t\n\r]+", " "); + Chunker chunker = new Chunker(normalizedText); + String[] arr = chunker.byChunkSize(512); + + // Upsert to Pinecone: + PineconeRetrieval retrieval = + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + + retrieval.upsert(); + + System.out.println("File is parsed: " + key); // For Logging + } + } + } + + @PostMapping("/upsert-pdfs") + public void upsertPDFs(ArkRequest arkRequest) throws IOException { + String namespace = arkRequest.getQueryParam("namespace"); + JSONObject body = arkRequest.getBody(); + + String bucketName = body.getString("bucketName"); + + // Get all the files from S3 bucket + ListObjectsV2Request listObjectsRequest = + new ListObjectsV2Request().withBucketName(bucketName); + + ListObjectsV2Result objectListing = s3Client.listObjectsV2(listObjectsRequest); + + for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { + String key = objectSummary.getKey(); + + if (key.endsWith("-pdf.txt")) { + + S3Object object = s3Client.getObject(bucketName, key); + InputStream objectData = object.getObjectContent(); + String content = IOUtils.toString(objectData, StandardCharsets.UTF_8); + + JSONObject jsonObject = new JSONObject(content); + + // These fields are specified from Zapier Action... 
+ // Now, we extract content and Chunk it by 512 size; then upsert it to Pinecone + // Normalize the extracted content.... + + System.out.println("Filename: " + jsonObject.getString("filename")); // abcd.pdf + System.out.println("Extension: " + jsonObject.getString("file_extension")); // pdf + + String normalizedText = + Unidecode.decode(jsonObject.getString("text")).replaceAll("[\t\n\r]+", " "); + Chunker chunker = new Chunker(normalizedText); + String[] arr = chunker.byChunkSize(512); + + // Upsert to Pinecone: + PineconeRetrieval retrieval = + new PineconeRetrieval(arr, pineconeEndpoint, namespace, arkRequest); + + retrieval.upsert(); + + System.out.println("File is parsed: " + key); // For Logging + } + } + } + + @DeleteMapping("/pinecone/deleteAll") + public ArkResponse deletePinecone(ArkRequest arkRequest) { + String namespace = arkRequest.getQueryParam("namespace"); + return new EdgeChain<>(pineconeEndpoint.deleteAll(namespace)).getArkResponse(); + } + + private void zapWebHook(String url) { + + WebClient webClient = WebClient.builder().baseUrl(ZAPIER_HOOK_URL).build(); + + JSONObject json = new JSONObject(); + json.put("url", url); + + webClient + .post() + .contentType(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromValue(json.toString())) + .retrieve() + .bodyToMono(String.class) + .retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(20))) // Using Fixed Delay.. + .block(); + } + } +} diff --git a/Java/FlySpring/.DS_Store b/Java/FlySpring/.DS_Store new file mode 100644 index 000000000..b945340c8 Binary files /dev/null and b/Java/FlySpring/.DS_Store differ diff --git a/Java/FlySpring/autoroute/dependency-reduced-pom.xml b/Java/FlySpring/autoroute/dependency-reduced-pom.xml index 8bde02dcf..d186c8cb2 100644 --- a/Java/FlySpring/autoroute/dependency-reduced-pom.xml +++ b/Java/FlySpring/autoroute/dependency-reduced-pom.xml @@ -1,205 +1,115 @@ - - - - spring-boot-starter-parent - org.springframework.boot - 3.0.0 - pom.xml - - 4.0.0 - com.flyspring - autoroute-spring-boot-starter - autoroute - 0.1.1-SNAPSHOT - AutoRoute java functions to routes - - - - maven-shade-plugin - 3.4.1 - - - package - - shade - - - - - - - - - - gofly - - - - maven-antrun-plugin - 3.1.0 - - - package - - run - - - - - - - - - - - - - - - - org.junit.jupiter - junit-jupiter-api - 5.9.2 - test - - - opentest4j - org.opentest4j - - - junit-platform-commons - org.junit.platform - - - apiguardian-api - org.apiguardian - - - - - org.mockito - mockito-junit-jupiter - 5.3.0 - test - - - mockito-core - org.mockito - - - - - junit - junit - 4.4 - test - - - org.springframework - spring-mock - 2.0.8 - test - - - commons-io - commons-io - 2.6 - provided - - - org.springframework.boot - spring-boot-starter-webflux - 3.0.5 - provided - - - org.reflections - reflections - 0.9.12 - provided - - - org.springframework.boot - spring-boot-starter-oauth2-resource-server - 3.0.0 - provided - - - org.springframework.security - spring-security-oauth2-jose - 6.0.0 - provided - - - org.glowroot - glowroot-agent-api - 0.13.6 - provided - - - org.projectlombok - lombok - 1.18.24 - provided - true - - - io.reactivex.rxjava3 - rxjava - 3.1.6 - provided - - - io.reactivex - rxjava-reactive-streams - 1.2.1 - provided - - - io.projectreactor.addons - reactor-adapter - 3.5.1 - provided - - - org.apache.pdfbox - pdfbox - 2.0.28 - provided - - - com.squareup.okhttp3 - okhttp - 4.10.0 - provided - - - commons-fileupload - commons-fileupload - 1.5 - provided - - - me.xuender - unidecode - 0.0.7 - provided - - - org.apache.tika - tika-core - 
2.7.0 - provided - - - org.apache.tika - tika-parsers-standard-package - 2.7.0 - provided - - - - 17 - 17 - - + + + + edgechain-parent + com.flyspring + 0.0.1-SNAPSHOT + + 4.0.0 + autoroute + autoroute + 0.1.1-SNAPSHOT + AutoRoute java functions to routes + + clean install + + + maven-shade-plugin + ${maven-shade.version} + + + package + + shade + + + + + + + + + gofly + + + + maven-antrun-plugin + 3.1.0 + + + package + + run + + + + + + + + + + + + + + + + org.junit.jupiter + junit-jupiter-api + 5.9.3 + test + + + opentest4j + org.opentest4j + + + junit-platform-commons + org.junit.platform + + + apiguardian-api + org.apiguardian + + + + + org.mockito + mockito-junit-jupiter + 5.3.1 + test + + + mockito-core + org.mockito + + + + + junit + junit + 4.13.2 + test + + + hamcrest-core + org.hamcrest + + + + + org.springframework + spring-mock + 2.0.8 + test + + + + 1.5 + 2.0.28 + 0.13.6 + 0.9.12 + 2.0.8 + + diff --git a/Java/FlySpring/autoroute/pom.xml b/Java/FlySpring/autoroute/pom.xml index 8612af0e4..5694a6799 100644 --- a/Java/FlySpring/autoroute/pom.xml +++ b/Java/FlySpring/autoroute/pom.xml @@ -1,76 +1,84 @@ - + 4.0.0 - com.flyspring - autoroute-spring-boot-starter + + com.flyspring + edgechain-parent + 0.0.1-SNAPSHOT + + + autoroute 0.1.1-SNAPSHOT autoroute AutoRoute java functions to routes - - org.springframework.boot - spring-boot-starter-parent - 3.0.0 - - - 17 - 17 + 1.5 + 0.13.6 + 0.9.12 + 2.0.28 + 2.0.8 org.junit.jupiter junit-jupiter-api - 5.9.2 test + org.mockito mockito-junit-jupiter - 5.3.0 test + junit junit - 4.4 test + org.springframework spring-mock - 2.0.8 + ${spring-mock.version} test + commons-io commons-io - 2.6 + org.springframework.boot spring-boot-starter-webflux - 3.0.5 + org.reflections - reflections - 0.9.12 + reflections + ${org-reflections.version} + org.springframework.boot spring-boot-starter-oauth2-resource-server + org.springframework.security spring-security-oauth2-jose + org.glowroot glowroot-agent-api - 0.13.6 + ${glowroot.version} @@ -82,104 +90,104 @@ io.reactivex.rxjava3 rxjava - 3.1.6 io.reactivex rxjava-reactive-streams - 1.2.1 + ${rxjava-reactive-streams.version} io.projectreactor.addons reactor-adapter - 3.5.1 org.apache.pdfbox pdfbox - 2.0.28 + ${pdfbox.version} com.squareup.okhttp3 okhttp - 4.10.0 - commons-fileupload - commons-fileupload - 1.5 - + commons-fileupload + commons-fileupload + ${commons-fileupload.version} + me.xuender unidecode - 0.0.7 + ${unidecode.version} org.apache.tika tika-core - 2.7.0 + ${apache-tika.version} org.apache.tika tika-parsers-standard-package - 2.7.0 + ${apache-tika.version} - - + - - gofly - - - - maven-antrun-plugin - 3.1.0 - - - package - - - - - - - run - - - - - - - - + + gofly + + true + + + + + maven-antrun-plugin + ${maven-antrun.version} + + + package + + + + + + + run + + + + + + + + + - - - org.apache.maven.plugins - maven-shade-plugin - 3.4.1 - - - - - - package - - shade - - - - - - - + clean install + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade.version} + + + package + + shade + + + + + + + \ No newline at end of file diff --git a/Java/FlySpring/edgechain-app/.DS_Store b/Java/FlySpring/edgechain-app/.DS_Store new file mode 100644 index 000000000..b0fd969a0 Binary files /dev/null and b/Java/FlySpring/edgechain-app/.DS_Store differ diff --git a/Java/FlySpring/edgechain-app/.gitignore b/Java/FlySpring/edgechain-app/.gitignore index 848775e32..3f969a6bf 100644 --- a/Java/FlySpring/edgechain-app/.gitignore +++ b/Java/FlySpring/edgechain-app/.gitignore @@ -45,3 +45,4 @@ build/ 
/src/main/java/com/edgechain/HydeExample.java /model/ +/src/main/java/com/edgechain/SupabaseMiniLMExample.java diff --git a/Java/FlySpring/edgechain-app/chat/README.md b/Java/FlySpring/edgechain-app/chat/README.md new file mode 100644 index 000000000..cebd9efee --- /dev/null +++ b/Java/FlySpring/edgechain-app/chat/README.md @@ -0,0 +1,21 @@ +# Simple Chat test + +- Edit SimpleApp.java and set your OpenAI key + +- Run the server using `./run.sh` or enter `java -jar ../target/edgechain.jar jbang SimpleApp.java` + +- Wait for server to start + +- In a separate terminal call the server using `./callserver.sh` or enter + +```bash +curl --location 'localhost:8080/v1/examples/gpt/ask' \ +--header 'Content-Type: application/json' \ +--data '{ + "prompt": "Who was Nikola Tesla?" +}' +``` + + - After a short time text should appear similar to `Ah, my dear interlocutor, allow me to regale you with the tale of Nikola Tesla! Born in 1856, this remarkable gentleman ` ... + +- Close the server terminal using `ctrl+c` diff --git a/Java/FlySpring/edgechain-app/chat/SimpleApp.java b/Java/FlySpring/edgechain-app/chat/SimpleApp.java new file mode 100644 index 000000000..a7a2dacc2 --- /dev/null +++ b/Java/FlySpring/edgechain-app/chat/SimpleApp.java @@ -0,0 +1,83 @@ +package com.edgechain; + +import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.openai.client.OpenAiClient; +import com.edgechain.lib.openai.request.ChatCompletionRequest; +import com.edgechain.lib.openai.request.ChatMessage; +import com.edgechain.lib.openai.response.ChatCompletionResponse; +import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static com.edgechain.lib.constants.EndpointConstants.OPENAI_CHAT_COMPLETION_API; + +@SpringBootApplication +public class SimpleApp { + + private final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI KEY + + public static void main(String[] args) { + System.setProperty("server.port", "8080"); + SpringApplication.run(SimpleApp.class, args); + } + + @RestController + @RequestMapping("/v1/examples") + public class Conversation { + + private List messages; + + public Conversation() { + messages = new ArrayList<>(); + messages.add( + new ChatMessage( + "system", + "You are a helpful, polite, old English assistant. 
Answer the user prompt with a bit" + + " of humor.")); + } + + @PostMapping("/gpt/ask") + public ResponseEntity ask(@RequestBody String prompt) { + updateMessageList("user", prompt); + String model = "gpt-3.5-turbo"; + ChatCompletionRequest chatCompletionRequest = + new ChatCompletionRequest( + model, 0.7, // temperature + messages, false, null, null, null, null, null, null, null); + OpenAiClient openAiClient = new OpenAiClient(); + OpenAiEndpoint chatEndpoint = + new OpenAiEndpoint( + OPENAI_CHAT_COMPLETION_API, + OPENAI_AUTH_KEY, + model, + "user", + 0.7, + false, + new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS)); + openAiClient.setEndpoint(chatEndpoint); + EdgeChain chatCompletion = + openAiClient.createChatCompletion(chatCompletionRequest); + String response = chatCompletion.get().getChoices().get(0).getMessage().getContent(); + System.out.println(response); + updateMessageList("assistant", response); + return new ResponseEntity<>(response, HttpStatus.OK); + } + + private void updateMessageList(String role, String content) { + messages.add(new ChatMessage(role, content)); + + if (messages.size() > 20) { + messages.remove(0); + } + } + } +} diff --git a/Java/FlySpring/edgechain-app/chat/callserver.sh b/Java/FlySpring/edgechain-app/chat/callserver.sh new file mode 100755 index 000000000..ec4d2ffa7 --- /dev/null +++ b/Java/FlySpring/edgechain-app/chat/callserver.sh @@ -0,0 +1,5 @@ +curl --location 'localhost:8080/v1/examples/gpt/ask' \ +--header 'Content-Type: application/json' \ +--data '{ + "prompt": "Who was Nikola Tesla?" +}' diff --git a/Java/FlySpring/edgechain-app/pom.xml b/Java/FlySpring/edgechain-app/pom.xml index 27b8f282b..1d66b78e2 100644 --- a/Java/FlySpring/edgechain-app/pom.xml +++ b/Java/FlySpring/edgechain-app/pom.xml @@ -1,378 +1,402 @@ - - 4.0.0 - com.edgechain - edgechain-app - 1.0.0 - edgechain - EdgeChains SDK. 
- jar - - 17 - 17 - 17 - 0.23.0 - - - - - org.springframework.boot - spring-boot-starter-data-redis - - - - org.springframework.boot - spring-boot-starter-web - - - - org.springframework.boot - spring-boot-starter-security - - - - org.springframework.boot - spring-boot-starter-webflux - - - - org.springframework.boot - spring-boot-starter-data-jpa - - - - - redis.clients - jedis - 4.3.1 - - - - org.postgresql - postgresql - runtime - - - - javax.validation - validation-api - 2.0.1.Final - - - - com.github.f4b6a3 - uuid-creator - 5.2.0 - - - - org.hibernate.validator - hibernate-validator - 6.1.6.Final - - - - org.modelmapper - modelmapper - 3.1.1 - - - - io.jsonwebtoken - jjwt - 0.9.1 - - - - javax.xml.bind - jaxb-api - 2.3.1 - - - - - com.squareup.retrofit2 - retrofit - 2.9.0 - - - - com.squareup.retrofit2 - adapter-rxjava3 - 2.9.0 - - - - com.squareup.retrofit2 - converter-jackson - 2.9.0 - - - - org.apache.opennlp - opennlp-tools - 2.2.0 - - - - io.reactivex.rxjava3 - rxjava - 3.1.6 - - - - io.reactivex - rxjava-reactive-streams - 1.2.1 - - - - io.projectreactor.addons - reactor-adapter - 3.5.1 - - - - me.xuender - unidecode - 0.0.7 - - - - org.apache.tika - tika-core - 2.7.0 - - - - org.apache.tika - tika-parsers-standard-package - 2.7.0 - - - - ai.djl - api - ${djl.version} - - - - ai.djl - basicdataset - ${djl.version} - - - - ai.djl.huggingface - tokenizers - 0.23.0 - - - - ai.djl - model-zoo - ${djl.version} - - - - - - ai.djl.pytorch - pytorch-engine - ${djl.version} - - - - ai.djl.pytorch - pytorch-model-zoo - ${djl.version} - - - - - io.github.jam01 - xtrasonnet - 0.5.3 - - - - com.knuddels - jtokkit - 0.6.1 - - - - net.objecthunter - exp4j - 0.4.8 - - - - - info.picocli - picocli-spring-boot-starter - 4.7.0 - - - - net.lingala.zip4j - zip4j - 2.11.3 - - - - org.zeroturnaround - zt-exec - 1.12 - - - - org.testcontainers - testcontainers - 1.17.6 - - - - org.testcontainers - postgresql - - - - org.springframework.boot - spring-boot-starter-test - test - - - - ai.djl.onnxruntime - onnxruntime-engine - 0.23.0 - runtime - - - - - - - - - - org.springframework.boot - spring-boot-dependencies - 3.1.0 - pom - import - - - - - - - - - - - - org.springframework.boot - spring-boot-maven-plugin - - - - org.apache.maven.plugins - maven-shade-plugin - 3.3.0 - - - org.springframework.boot - spring-boot-maven-plugin - 2.7.0 - - - - false - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - edgechain - - - META-INF/spring.handlers - META-INF/spring.schemas + + 4.0.0 + + com.flyspring + edgechain-parent + 0.0.1-SNAPSHOT + + + edgechain-app + 1.0.0 + edgechain + EdgeChains SDK. 
+ jar + + + 4.4.0 + 0.23.0 + 0.4.8 + 6.0.23.Final + 2.3.1 + 0.9.1 + 0.6.1 + 3.1.1 + 2.2.0 + 2.9.0 + 5.2.0 + 2.0.1.Final + 0.5.3 + + + + + org.springframework.boot + spring-boot-starter-data-redis + + + + org.springframework.boot + spring-boot-starter-web + + + + org.springframework.boot + spring-boot-starter-security + + + + org.springframework.boot + spring-boot-starter-webflux + + + + org.springframework.boot + spring-boot-starter-data-jpa + + + + redis.clients + jedis + + + + org.postgresql + postgresql + compile + + + + dev.fuxing + airtable-api + 0.3.2 + + + + javax.validation + validation-api + ${validation-api.version} + + + + com.github.f4b6a3 + uuid-creator + ${uuid-creator.version} + + + + org.hibernate.validator + hibernate-validator + ${hibernate-validator} + + + + org.modelmapper + modelmapper + ${modelmapper.version} + + + + io.jsonwebtoken + jjwt + ${jsonwebtoken.version} + + + + javax.xml.bind + jaxb-api + ${jaxb-api.version} + + + + com.squareup.retrofit2 + retrofit + ${retrofit2.version} + + + + com.squareup.retrofit2 + adapter-rxjava3 + ${retrofit2.version} + + + + com.squareup.retrofit2 + converter-jackson + ${retrofit2.version} + + + + org.apache.opennlp + opennlp-tools + ${opennlp.version} + + + + io.reactivex.rxjava3 + rxjava + + + + io.reactivex + rxjava-reactive-streams + ${rxjava-reactive-streams.version} + + + + io.projectreactor.addons + reactor-adapter + + + + me.xuender + unidecode + ${unidecode.version} + + + + org.apache.tika + tika-core + ${apache-tika.version} + + + + org.apache.tika + tika-parsers-standard-package + ${apache-tika.version} + + + xml-apis + xml-apis + + + + + + ai.djl + api + ${djl.version} + + + + ai.djl + basicdataset + ${djl.version} + + + + ai.djl.huggingface + tokenizers + ${djl.version} + + + + ai.djl + model-zoo + ${djl.version} + + + + + + ai.djl.pytorch + pytorch-engine + ${djl.version} + + + + ai.djl.pytorch + pytorch-model-zoo + ${djl.version} + + + + + io.github.jam01 + xtrasonnet + ${xtrasonnet.version} + + + + com.knuddels + jtokkit + ${jtokkit.version} + + + + net.objecthunter + exp4j + ${exp4j.version} + + + + + info.picocli + picocli-spring-boot-starter + ${picocli.version} + + + + net.lingala.zip4j + zip4j + ${zip4j.version} + + + + org.zeroturnaround + zt-exec + ${zeroturnaround.version} + + + + org.testcontainers + testcontainers + + + + org.testcontainers + postgresql + + + + org.springframework.boot + spring-boot-starter-test + + + com.vaadin.external.google + android-json + + + test + + + + ai.djl.onnxruntime + onnxruntime-engine + ${djl.version} + runtime + + + + org.testcontainers + junit-jupiter + test + + + + com.auth0 + java-jwt + ${auth0-jwt.version} + test + + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade.version} + + + + org.springframework.boot + spring-boot-maven-plugin + ${spring-boot.version} + + + + + false + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + shade-jar-with-dependencies + package + + shade + + + + edgechain + + + META-INF/spring.handlers + META-INF/spring.schemas - - META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports - - - - - com.edgechain.EdgeChainApplication - - - - - - - - - - org.apache.maven.plugins - maven-antrun-plugin - 1.8 - - - download-and-unpack-jbang - generate-resources - - run - - - - - - - - - - - - - - - - - - - - - - - - - - + + META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports + + + + + com.edgechain.EdgeChainApplication + + + + + + + + + + 
org.apache.maven.plugins + maven-antrun-plugin + ${maven-antrun.version} + + + download-and-unpack-jbang + generate-resources + + run + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java index 1e00dfc6c..23714c735 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/EdgeChainApplication.java @@ -1,5 +1,7 @@ package com.edgechain; +import java.net.URL; +import java.nio.file.Paths; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.SpringApplication; @@ -7,14 +9,9 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.context.annotation.Bean; -import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; -import java.net.URL; -import java.nio.file.Paths; - @SpringBootApplication -@EnableScheduling public class EdgeChainApplication { private static final Logger logger = LoggerFactory.getLogger(EdgeChainApplication.class); @@ -22,8 +19,9 @@ public class EdgeChainApplication { public static void main(String[] args) { logger.info("Please avoid special symbols such as space in naming the directory."); - System.setProperty("jar.name", getJarFileName(EdgeChainApplication.class)); - logger.info("Executed jar file: " + System.getProperty("jar.name")); + String jarFileName = getJarFileName(EdgeChainApplication.class); + System.setProperty("jar.name", jarFileName); + logger.info("Executed jar file: {}", jarFileName); SpringApplication springApplication = new SpringApplicationBuilder() @@ -35,7 +33,7 @@ public static void main(String[] args) { } @Bean(name = "mvcHandlerMappingIntrospector") - public HandlerMappingIntrospector mvcHandlerMappingIntrospector() { + HandlerMappingIntrospector mvcHandlerMappingIntrospector() { return new HandlerMappingIntrospector(); } @@ -44,13 +42,15 @@ private static String getJarFileName(Class clazz) { if (classResource == null) { throw new RuntimeException("class resource is null"); } + String url = classResource.toString(); + logger.info("class url: {}", url); if (url.startsWith("jar:file:")) { String path = url.replaceAll("^jar:(file:.*[.]jar)!/.*", "$1"); try { return Paths.get(new URL(path).toURI()).toString(); } catch (Exception e) { - throw new RuntimeException("Invalid jar file"); + throw new RuntimeException("Invalid jar file", e); } } throw new RuntimeException("Invalid jar file"); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java index e1091f97b..e2d19d6df 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PineconeRetrieval.java @@ -1,47 +1,79 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import 
com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.request.ArkRequest; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PineconeRetrieval extends Retrieval { +import java.util.List; - private final Logger logger = LoggerFactory.getLogger(getClass()); +public class PineconeRetrieval { private final PineconeEndpoint pineconeEndpoint; private final ArkRequest arkRequest; - private final Endpoint endpoint; + private final String[] arr; + private String namespace; + private int batchSize = 30; public PineconeRetrieval( - PineconeEndpoint pineconeEndpoint, Endpoint endpoint, ArkRequest arkRequest) { + String[] arr, PineconeEndpoint pineconeEndpoint, String namespace, ArkRequest arkRequest) { this.pineconeEndpoint = pineconeEndpoint; - this.endpoint = endpoint; this.arkRequest = arkRequest; + this.arr = arr; + this.namespace = namespace; - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) + Logger logger = LoggerFactory.getLogger(getClass()); + if (pineconeEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (pineconeEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); + else if (pineconeEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) + logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } - @Override - public void upsert(String input) { - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) { - WordEmbeddings embeddings = - openAiEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.pineconeEndpoint.upsert(embeddings); - } else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) { - WordEmbeddings embeddings = - miniLMEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.pineconeEndpoint.upsert(embeddings); - } else - throw new RuntimeException( - "Invalid Endpoint; Only OpenAIEndpoint & MiniLMEndpoint are supported"); + public void upsert() { + Observable.fromArray(arr) + .buffer(batchSize) + .concatMapCompletable( + batch -> + Observable.fromIterable(batch) + .flatMap( + input -> + Observable.fromCallable(() -> generateEmbeddings(input)) + .subscribeOn(Schedulers.io())) + .toList() + .flatMapCompletable( + wordEmbeddingsList -> + Completable.fromAction(() -> executeBatchUpsert(wordEmbeddingsList)) + .subscribeOn(Schedulers.io()))) + .blockingAwait(); + } + + private WordEmbeddings generateEmbeddings(String input) { + return pineconeEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); + } + + private void executeBatchUpsert(List wordEmbeddingsList) { + pineconeEndpoint.batchUpsert(wordEmbeddingsList, this.namespace); + } + + public int getBatchSize() { + return batchSize; + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; } } diff --git 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java index c888c03c9..48f82c6fe 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/PostgresRetrieval.java @@ -1,93 +1,186 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.StringResponse; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PostgresRetrieval extends Retrieval { +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.*; +import java.util.stream.Collectors; - private final Logger logger = LoggerFactory.getLogger(getClass()); +public class PostgresRetrieval { - private final PostgresEndpoint postgresEndpoint; - private final int dimensions; + private final Logger logger = LoggerFactory.getLogger(this.getClass()); - private final PostgresDistanceMetric metric; - private final int lists; + private int batchSize = 30; + + private final String[] arr; private final String filename; + private final PostgresLanguage postgresLanguage; + private final ArkRequest arkRequest; - private final Endpoint endpoint; + private final PostgresEndpoint postgresEndpoint; + private final int dimensions; + private final PostgresDistanceMetric metric; + private final int lists; public PostgresRetrieval( + String[] arr, PostgresEndpoint postgresEndpoint, - String filename, int dimensions, - Endpoint endpoint, + PostgresDistanceMetric metric, + int lists, + String filename, + PostgresLanguage postgresLanguage, ArkRequest arkRequest) { - this.postgresEndpoint = postgresEndpoint; - this.dimensions = dimensions; + this.arr = arr; this.filename = filename; - this.endpoint = endpoint; + this.postgresEndpoint = postgresEndpoint; + this.postgresLanguage = postgresLanguage; this.arkRequest = arkRequest; - this.metric = PostgresDistanceMetric.COSINE; - this.lists = 2000; - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) + this.dimensions = dimensions; + this.metric = metric; + this.lists = lists; + + if (postgresEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); + else if 
(postgresEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) + logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); } public PostgresRetrieval( + String[] arr, PostgresEndpoint postgresEndpoint, - String filename, int dimensions, - Endpoint endpoint, - ArkRequest arkRequest, - PostgresDistanceMetric metric, - int lists) { + String filename, + PostgresLanguage postgresLanguage, + ArkRequest arkRequest) { + this.arr = arr; + this.filename = filename; + this.postgresLanguage = postgresLanguage; this.postgresEndpoint = postgresEndpoint; this.dimensions = dimensions; - this.filename = filename; - this.endpoint = endpoint; + this.metric = PostgresDistanceMetric.COSINE; + this.lists = 1000; this.arkRequest = arkRequest; - this.metric = metric; - this.lists = lists; - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) + if (postgresEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); + else if (postgresEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) + logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); + } + + public List upsert() { + + // Create Table... + this.postgresEndpoint.createTable(dimensions, metric, lists); + + ConcurrentLinkedQueue uuidQueue = new ConcurrentLinkedQueue<>(); + + Observable.fromArray(arr) + .buffer(batchSize) + .concatMapCompletable( + batch -> + Observable.fromIterable(batch) + .flatMap( + input -> + Observable.fromCallable(() -> generateEmbeddings(input)) + .subscribeOn(Schedulers.io())) + .toList() + .flatMapCompletable( + wordEmbeddingsList -> + Completable.fromAction( + () -> upsertAndCollectIds(wordEmbeddingsList, uuidQueue)) + .subscribeOn(Schedulers.io()))) + .blockingAwait(); + + return new ArrayList<>(uuidQueue); + } + + private WordEmbeddings generateEmbeddings(String input) { + return postgresEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); + } + + private void upsertAndCollectIds( + List wordEmbeddingsList, ConcurrentLinkedQueue uuidQueue) { + List batchUuidList = executeBatchUpsert(wordEmbeddingsList); + uuidQueue.addAll(batchUuidList); + } + + private List executeBatchUpsert(List wordEmbeddingsList) { + return this.postgresEndpoint.upsert(wordEmbeddingsList, filename, postgresLanguage).stream() + .map(StringResponse::getResponse) + .collect(Collectors.toList()); + } + + public List insertMetadata(String metadataTableName) { + + // Create Table... + this.postgresEndpoint.createMetadataTable(metadataTableName); + + ConcurrentLinkedQueue uuidQueue = new ConcurrentLinkedQueue<>(); + + CountDownLatch latch = new CountDownLatch(1); + + Observable.fromArray(arr) + .map(str -> str.replaceAll("'", "")) + .buffer(batchSize) + .flatMapCompletable( + metadataList -> + Completable.fromAction(() -> insertMetadataAndCollectIds(metadataList, uuidQueue))) + .blockingSubscribe(latch::countDown, error -> latch.countDown()); + + return new ArrayList<>(uuidQueue); + } + + public StringResponse insertOneMetadata( + String metadataTableName, String metadata, String documentDate) { + // Create Table... 
+ this.postgresEndpoint.createMetadataTable(metadataTableName); + return this.postgresEndpoint.insertMetadata(metadataTableName, metadata, documentDate); + } + + private void insertMetadataAndCollectIds( + List metadataList, ConcurrentLinkedQueue uuidQueue) { + List batchUuidList = executeBatchInsertMetadata(metadataList); + uuidQueue.addAll(batchUuidList); + } + + private List executeBatchInsertMetadata(List metadataList) { + return this.postgresEndpoint.batchInsertMetadata(metadataList).stream() + .map(StringResponse::getResponse) + .collect(Collectors.toList()); + } + + public int getBatchSize() { + return batchSize; } - @Override - public void upsert(String input) { - - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) { - WordEmbeddings embeddings = - openAiEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.postgresEndpoint.upsert( - embeddings, this.filename, this.dimensions, this.metric, this.lists); - } else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) { - WordEmbeddings embeddings = - miniLMEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.postgresEndpoint.upsert( - embeddings, this.filename, this.dimensions, this.metric, this.lists); - } else if (endpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) { - WordEmbeddings embeddings = bgeSmallEndpoint.embeddings(input, arkRequest); - this.postgresEndpoint.upsert( - embeddings, this.filename, this.dimensions, this.metric, this.lists); - } else - throw new RuntimeException( - "Invalid Endpoint; Only OpenAIEndpoint & MiniLMEndpoint are supported"); + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java index 34d4c3d50..35f5d1732 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/RedisRetrieval.java @@ -1,55 +1,87 @@ package com.edgechain.lib.chains; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.request.ArkRequest; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RedisRetrieval extends Retrieval { - - private final Logger logger = LoggerFactory.getLogger(getClass()); +import java.util.List; +public class RedisRetrieval { private final RedisEndpoint redisEndpoint; private final ArkRequest arkRequest; - private final Endpoint endpoint; + private final String[] arr; private final int dimension; private final RedisDistanceMetric metric; + private int batchSize = 30; public RedisRetrieval( + String[] arr, RedisEndpoint redisEndpoint, - Endpoint endpoint, int dimension, RedisDistanceMetric metric, ArkRequest arkRequest) 
{ this.redisEndpoint = redisEndpoint; - this.endpoint = endpoint; this.dimension = dimension; this.metric = metric; this.arkRequest = arkRequest; - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) + this.arr = arr; + + Logger logger = LoggerFactory.getLogger(getClass()); + if (redisEndpoint.getEmbeddingEndpoint() instanceof OpenAiEmbeddingEndpoint openAiEndpoint) logger.info("Using OpenAi Embedding Service: " + openAiEndpoint.getModel()); - else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) + else if (redisEndpoint.getEmbeddingEndpoint() instanceof MiniLMEndpoint miniLMEndpoint) logger.info(String.format("Using %s", miniLMEndpoint.getMiniLMModel().getName())); + else if (redisEndpoint.getEmbeddingEndpoint() instanceof BgeSmallEndpoint bgeSmallEndpoint) + logger.info(String.format("Using BgeSmall: " + bgeSmallEndpoint.getModelUrl())); + } + + public void upsert() { + + this.redisEndpoint.createIndex(redisEndpoint.getNamespace(), dimension, metric); + + Observable.fromArray(arr) + .buffer(batchSize) + .concatMapCompletable( + batch -> + Observable.fromIterable(batch) + .flatMap( + input -> + Observable.fromCallable(() -> generateEmbeddings(input)) + .subscribeOn(Schedulers.io())) + .toList() + .flatMapCompletable( + wordEmbeddingsList -> + Completable.fromAction(() -> executeBatchUpsert(wordEmbeddingsList)) + .subscribeOn(Schedulers.io()))) + .blockingAwait(); + } + + private WordEmbeddings generateEmbeddings(String input) { + return redisEndpoint + .getEmbeddingEndpoint() + .embeddings(input, arkRequest) + .firstOrError() + .blockingGet(); + } + + private void executeBatchUpsert(List wordEmbeddingsList) { + redisEndpoint.batchUpsert(wordEmbeddingsList); + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; } - @Override - public void upsert(String input) { - - if (endpoint instanceof OpenAiEndpoint openAiEndpoint) { - WordEmbeddings embeddings = - openAiEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.redisEndpoint.upsert(embeddings, dimension, metric); - } else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) { - WordEmbeddings embeddings = - miniLMEndpoint.embeddings(input, arkRequest).firstOrError().blockingGet(); - this.redisEndpoint.upsert(embeddings, dimension, metric); - } else - throw new RuntimeException( - "Invalid Endpoint; Only OpenAIEndpoint & MiniLMEndpoint are supported"); + public int getBatchSize() { + return batchSize; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/Retrieval.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/Retrieval.java deleted file mode 100644 index 39916b97f..000000000 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/chains/Retrieval.java +++ /dev/null @@ -1,6 +0,0 @@ -package com.edgechain.lib.chains; - -public abstract class Retrieval { - - public abstract void upsert(String input); -} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/PostgreSQLConfiguration.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/PostgreSQLConfiguration.java index 60945770b..4a547916e 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/PostgreSQLConfiguration.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/PostgreSQLConfiguration.java @@ -1,5 +1,6 @@ package com.edgechain.lib.configuration; +import com.zaxxer.hikari.HikariDataSource; import 
org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.jdbc.DataSourceBuilder; import org.springframework.context.annotation.Bean; @@ -8,7 +9,6 @@ import org.springframework.jdbc.core.JdbcTemplate; import javax.sql.DataSource; -import java.util.Objects; @Configuration public class PostgreSQLConfiguration { @@ -16,34 +16,23 @@ public class PostgreSQLConfiguration { @Autowired private Environment env; @Bean - public DataSource dataSource() { + DataSource dataSource() { String dbHost = env.getProperty("postgres.db.host"); String dbUsername = env.getProperty("postgres.db.username"); String dbPassword = env.getProperty("postgres.db.password"); return DataSourceBuilder.create() + .type(HikariDataSource.class) .url(dbHost) .driverClassName("org.postgresql.Driver") .username(dbUsername) .password(dbPassword) .build(); - - // return DataSourceBuilder.create() - // .type(HikariDataSource.class) - // .url(dbHost) - // .driverClassName("org.postgresql.Driver") - // .username(dbUsername) - // .password(dbPassword) - // .build(); } @Bean - public JdbcTemplate jdbcTemplate() { + JdbcTemplate jdbcTemplate() { return new JdbcTemplate(dataSource()); } - - private boolean nonNullAndNotEmpty(String val) { - return Objects.nonNull(val) && val.trim().isEmpty(); - } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/RedisConfiguration.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/RedisConfiguration.java index 947c37cd9..d5c162cea 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/RedisConfiguration.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/RedisConfiguration.java @@ -21,7 +21,7 @@ public class RedisConfiguration { @Bean @Lazy - public JedisPooled jedisPooled() { + JedisPooled jedisPooled() { int port = 6379; String host = "127.0.0.1"; @@ -40,7 +40,7 @@ public JedisPooled jedisPooled() { @Bean @Lazy - public JedisConnectionFactory jedisConnectionFactory() { + JedisConnectionFactory jedisConnectionFactory() { int port = 6379; String host = "127.0.0.1"; @@ -63,7 +63,7 @@ public JedisConnectionFactory jedisConnectionFactory() { } @Bean - public RedisTemplate redisTemplate() { + RedisTemplate redisTemplate() { RedisTemplate redisTemplate = new RedisTemplate<>(); redisTemplate.setConnectionFactory(jedisConnectionFactory()); redisTemplate.setKeySerializer(new StringRedisSerializer()); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/WebConfiguration.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/WebConfiguration.java index 519730dc8..f1f580484 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/WebConfiguration.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/configuration/WebConfiguration.java @@ -3,6 +3,8 @@ import com.edgechain.lib.configuration.domain.AuthFilter; import com.edgechain.lib.configuration.domain.MethodAuthentication; import com.edgechain.lib.configuration.domain.SecurityUUID; +import java.util.List; +import java.util.UUID; import org.modelmapper.ModelMapper; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -10,9 +12,6 @@ import org.springframework.context.annotation.Primary; import org.springframework.web.client.RestTemplate; -import java.util.List; -import java.util.UUID; - @Configuration("WebConfiguration") 
@Import(EdgeChainAutoConfiguration.class) public class WebConfiguration { @@ -20,29 +19,29 @@ public class WebConfiguration { public static final String CONTEXT_PATH = "/edgechains"; @Bean - public ModelMapper modelMapper() { + ModelMapper modelMapper() { return new ModelMapper(); } @Bean - public RestTemplate restTemplate() { + RestTemplate restTemplate() { return new RestTemplate(); } @Bean @Primary - public SecurityUUID securityUUID() { + SecurityUUID securityUUID() { return new SecurityUUID(UUID.randomUUID().toString()); } @Bean - public AuthFilter authFilter() { + AuthFilter authFilter() { AuthFilter filter = new AuthFilter(); - filter.setRequestPost(new MethodAuthentication(List.of(""), "")); - filter.setRequestGet(new MethodAuthentication(List.of(""), "")); - filter.setRequestDelete(new MethodAuthentication(List.of(""), "")); - filter.setRequestPatch(new MethodAuthentication(List.of(""), "")); - filter.setRequestPut(new MethodAuthentication(List.of(""), "")); + filter.setRequestPost(new MethodAuthentication(List.of("**"), "")); + filter.setRequestGet(new MethodAuthentication(List.of("**"), "")); + filter.setRequestDelete(new MethodAuthentication(List.of("**"), "")); + filter.setRequestPatch(new MethodAuthentication(List.of("**"), "")); + filter.setRequestPut(new MethodAuthentication(List.of("**"), "")); return filter; } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java index 08b5e7538..e5ab71034 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClient.java @@ -3,9 +3,11 @@ import com.edgechain.lib.context.client.HistoryContextClient; import com.edgechain.lib.context.client.repositories.PostgreSQLHistoryContextRepository; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; +import java.time.LocalDateTime; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -13,14 +15,11 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.time.LocalDateTime; -import java.util.Objects; - @Service public class PostgreSQLHistoryContextClient implements HistoryContextClient { - private Logger logger = LoggerFactory.getLogger(this.getClass()); + private final Logger logger = LoggerFactory.getLogger(this.getClass()); @Autowired private PostgreSQLHistoryContextRepository historyContextRepository; @@ -28,6 +27,7 @@ public class PostgreSQLHistoryContextClient private static final String PREFIX = "historycontext:"; + @Transactional @Override public EdgeChain create(String id, PostgreSQLHistoryContextEndpoint endpoint) { return new EdgeChain<>( @@ -41,6 +41,11 @@ public EdgeChain create(String id, PostgreSQLHistoryContextEndpo this.createTable(); // Create Table IF NOT EXISTS; HistoryContext context = new HistoryContext(PREFIX + id, "", LocalDateTime.now()); + + if (logger.isInfoEnabled()) { + 
logger.info("{} is added", context.getId()); + } + emitter.onNext(historyContextRepository.save(context)); emitter.onComplete(); @@ -51,6 +56,7 @@ public EdgeChain create(String id, PostgreSQLHistoryContextEndpo endpoint); } + @Transactional @Override public EdgeChain put( String id, String response, PostgreSQLHistoryContextEndpoint endpoint) { @@ -60,11 +66,14 @@ public EdgeChain put( try { HistoryContext historyContext = this.get(id, null).get(); - String input = response.replaceAll("'", ""); + String input = response.replace("'", ""); historyContext.setResponse(input); HistoryContext returnValue = this.historyContextRepository.save(historyContext); - logger.info(String.format("%s is updated", id)); + + if (logger.isInfoEnabled()) { + logger.info("{} is updated", id); + } emitter.onNext(returnValue); emitter.onComplete(); @@ -76,20 +85,23 @@ public EdgeChain put( endpoint); } + @Transactional(readOnly = true) @Override public EdgeChain get(String id, PostgreSQLHistoryContextEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - emitter.onNext( + final HistoryContext val = this.historyContextRepository .findById(id) .orElseThrow( () -> - new RuntimeException( - "PostgreSQL history_context id isn't found."))); + new RuntimeException("PostgreSQL history_context id isn't found.")); + + emitter.onNext(val); emitter.onComplete(); + } catch (final Exception e) { emitter.onError(e); } @@ -97,6 +109,7 @@ public EdgeChain get(String id, PostgreSQLHistoryContextEndpoint endpoint); } + @Transactional @Override public EdgeChain delete(String id, PostgreSQLHistoryContextEndpoint endpoint) { return new EdgeChain<>( @@ -107,8 +120,13 @@ public EdgeChain delete(String id, PostgreSQLHistoryContextEndpoint endp HistoryContext historyContext = this.get(id, null).get(); this.historyContextRepository.delete(historyContext); + if (logger.isInfoEnabled()) { + logger.info("{} is deleted", id); + } + emitter.onNext(""); emitter.onComplete(); + } catch (final Exception e) { emitter.onError(e); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java index 962839794..c12006cf2 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/context/client/impl/RedisHistoryContextClient.java @@ -1,10 +1,13 @@ package com.edgechain.lib.context.client.impl; -import com.edgechain.lib.context.domain.HistoryContext; import com.edgechain.lib.context.client.HistoryContextClient; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.context.domain.HistoryContext; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; +import java.time.LocalDateTime; +import java.util.Objects; +import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -13,25 +16,20 @@ import org.springframework.data.redis.core.RedisTemplate; import org.springframework.stereotype.Repository; -import java.time.LocalDateTime; -import java.util.Objects; -import java.util.concurrent.TimeUnit; - @Repository public class RedisHistoryContextClient implements 
HistoryContextClient { - private Logger logger = LoggerFactory.getLogger(this.getClass()); + private final Logger logger = LoggerFactory.getLogger(this.getClass()); private static final String PREFIX = "historycontext:"; - @Autowired private RedisTemplate redisTemplate; + @Autowired private RedisTemplate redisTemplate; @Autowired @Lazy private Environment env; @Override public EdgeChain create(String id, RedisHistoryContextEndpoint endpoint) { - return new EdgeChain<>( Observable.create( emitter -> { @@ -54,6 +52,10 @@ public EdgeChain create(String id, RedisHistoryContextEndpoint e this.redisTemplate.expire( key, Long.parseLong(env.getProperty("redis.ttl")), TimeUnit.SECONDS); + if (logger.isInfoEnabled()) { + logger.info("{} is added", key); + } + emitter.onNext(context); emitter.onComplete(); @@ -71,7 +73,6 @@ public EdgeChain put( Observable.create( emitter -> { try { - HistoryContext historyContext = this.get(key, null).get(); historyContext.setResponse(response); @@ -79,7 +80,9 @@ public EdgeChain put( this.redisTemplate.expire( key, Long.parseLong(env.getProperty("redis.ttl")), TimeUnit.SECONDS); - logger.info(String.format("%s is updated", key)); + if (logger.isInfoEnabled()) { + logger.info("{} is updated", key); + } emitter.onNext(historyContext); emitter.onComplete(); @@ -99,9 +102,11 @@ public EdgeChain get(String key, RedisHistoryContextEndpoint end try { Boolean b = this.redisTemplate.hasKey(key); if (Boolean.TRUE.equals(b)) { - emitter.onNext( - Objects.requireNonNull( - (HistoryContext) this.redisTemplate.opsForValue().get(key))); + + HistoryContext obj = (HistoryContext) this.redisTemplate.opsForValue().get(key); + Objects.requireNonNull(obj, "null value not allowed! key " + key); + + emitter.onNext(obj); emitter.onComplete(); } else { emitter.onError(new RuntimeException("Redis history_context id isn't found.")); @@ -116,13 +121,17 @@ public EdgeChain get(String key, RedisHistoryContextEndpoint end @Override public EdgeChain delete(String key, RedisHistoryContextEndpoint endpoint) { - return new EdgeChain<>( Observable.create( emitter -> { try { this.get(key, null).get(); this.redisTemplate.delete(key); + + if (logger.isInfoEnabled()) { + logger.info("{} is deleted", key); + } + emitter.onNext(""); emitter.onComplete(); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/EmbeddingLoggerController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/EmbeddingLoggerController.java index 2163bc918..821f27339 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/EmbeddingLoggerController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/EmbeddingLoggerController.java @@ -1,13 +1,15 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.logger.EmbeddingLogger; import com.edgechain.lib.logger.EmbeddingLogger; import com.edgechain.lib.logger.entities.EmbeddingLog; import java.util.HashMap; -import java.util.Objects; - import org.springframework.data.domain.Page; -import org.springframework.web.bind.annotation.*; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; @RestController @RequestMapping("/v1/logs/embeddings") @@ -16,8 
+18,10 @@ public class EmbeddingLoggerController { private EmbeddingLogger embeddingLogger; private EmbeddingLogger getInstance() { - if (Objects.isNull(embeddingLogger)) return embeddingLogger = new EmbeddingLogger(); - else return embeddingLogger; + if (embeddingLogger == null) { + embeddingLogger = new EmbeddingLogger(); + } + return embeddingLogger; } @GetMapping("/findAll/{page}/{size}") diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/JsonnetLoggerController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/JsonnetLoggerController.java new file mode 100644 index 000000000..8aab8ec9a --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/JsonnetLoggerController.java @@ -0,0 +1,39 @@ +package com.edgechain.lib.controllers; + +import com.edgechain.lib.logger.JsonnetLogger; +import java.util.HashMap; +import java.util.Objects; + +import com.edgechain.lib.logger.entities.JsonnetLog; +import org.springframework.data.domain.Page; +import org.springframework.web.bind.annotation.*; + +@RestController +@RequestMapping("/v1/logs/jsonnet") +public class JsonnetLoggerController { + + private JsonnetLogger jsonnetLogger; + + private JsonnetLogger getInstance() { + if (Objects.isNull(jsonnetLogger)) return jsonnetLogger = new JsonnetLogger(); + else return jsonnetLogger; + } + + @GetMapping("/findAll/{page}/{size}") + public Page findAll(@PathVariable int page, @PathVariable int size) { + return getInstance().findAll(page, size); + } + + @GetMapping("/findAll/sorted/{page}/{size}") + public Page findAllOrderByCompletedAtDesc( + @PathVariable int page, @PathVariable int size) { + return getInstance().findAllOrderByCreatedAtDesc(page, size); + } + + @PostMapping("/findByName/{page}/{size}") + public Page findAllBySelectedFileOrderByCreatedAtDesc( + @RequestBody HashMap mapper, @PathVariable int page, @PathVariable int size) { + return getInstance() + .findAllBySelectedFileOrderByCreatedAtDesc(mapper.get("filename"), page, size); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java index 1a44017c9..470ec6de3 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PgHistoryContextController.java @@ -1,7 +1,7 @@ package com.edgechain.lib.controllers; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import org.json.JSONObject; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java index 166bc7b6f..0352f04a0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/PostgresController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import 
com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; @@ -27,13 +27,8 @@ private PostgresEndpoint getInstance() { @DeleteMapping("/deleteAll") public StringResponse deletePostgres(ArkRequest arkRequest) { - String table = arkRequest.getQueryParam("table"); String namespace = arkRequest.getQueryParam("namespace"); - - getInstance().setTableName(table); - getInstance().setNamespace(namespace); - - return getInstance().deleteAll(); + return getInstance().deleteAll(table, namespace); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java index 950855fc2..42658fddb 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; import org.springframework.web.bind.annotation.DeleteMapping; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java index 9f956768a..d9cc61f0a 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/RedisHistoryContextController.java @@ -1,7 +1,7 @@ package com.edgechain.lib.controllers; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.rxjava.retry.impl.FixedDelay; import org.json.JSONObject; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java index 557cdb394..f81cd3d1a 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/controllers/SupabaseController.java @@ -1,6 +1,6 @@ package com.edgechain.lib.controllers; -import com.edgechain.lib.endpoint.impl.SupabaseEndpoint; +import com.edgechain.lib.endpoint.impl.supabase.SupabaseEndpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.supabase.response.AuthenticatedResponse; import com.edgechain.lib.supabase.response.SupabaseUser; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java index e5cb0101b..46b5f32f9 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/WordEmbeddings.java @@ -12,7 +12,7 @@ public 
class WordEmbeddings implements ArkObject, Serializable { private static final long serialVersionUID = 2210956496609994219L; private String id; private List values; - private String score; + private Double score; public WordEmbeddings() {} @@ -26,13 +26,13 @@ public WordEmbeddings(String id, List values) { this.values = values; } - public WordEmbeddings(String id, List values, String score) { + public WordEmbeddings(String id, List values, Double score) { this.id = id; this.values = values; this.score = score; } - public WordEmbeddings(String id, String score) { + public WordEmbeddings(String id, Double score) { this.id = id; this.score = score; } @@ -49,7 +49,7 @@ public void setValues(List values) { this.values = values; } - public String getScore() { + public Double getScore() { return score; } @@ -57,7 +57,7 @@ public void setId(String id) { this.id = id; } - public void setScore(String score) { + public void setScore(Double score) { this.score = score; } @@ -69,9 +69,34 @@ public String toString() { @Override public JSONObject toJson() { JSONObject json = new JSONObject(); - json.put("id", id); - json.put("values", new JSONArray(values)); - json.put("score", score); + + if (id != null) { + json.put("id", id); + } + + if (values != null) { + json.put("values", new JSONArray(values)); + } + + if (score != null) { + json.put("score", score); + } + return json; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + WordEmbeddings that = (WordEmbeddings) o; + + return id.equals(that.id); + } + + @Override + public int hashCode() { + return id.hashCode(); + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java index 905ff134f..6cc92c672 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/bgeSmall/BgeSmallClient.java @@ -16,33 +16,24 @@ import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; -import org.springframework.stereotype.Service; - import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; import java.util.LinkedList; import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; @Service public class BgeSmallClient { - private BgeSmallEndpoint endpoint; - private static volatile ZooModel bgeSmallEn; - public BgeSmallEndpoint getEndpoint() { - return endpoint; - } - - public void setEndpoint(BgeSmallEndpoint endpoint) { - this.endpoint = endpoint; - } - - public EdgeChain createEmbeddings(String input) { + public EdgeChain createEmbeddings(String input, BgeSmallEndpoint endpoint) { return new EdgeChain<>( Observable.create( @@ -69,19 +60,23 @@ private ZooModel loadSmallBgeEn() throws IOException { ZooModel r = bgeSmallEn; if (r == null) { + final Logger logger = LoggerFactory.getLogger(BgeSmallEndpoint.class); synchronized (this) { r = bgeSmallEn; if (r == null) { - Path path 
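A small usage sketch of the revised WordEmbeddings value object above: the score is now a Double rather than a String, toJson() emits only non-null fields, and equality is id-based.

import com.edgechain.lib.embeddings.WordEmbeddings;

public class WordEmbeddingsSketch {

  public static void main(String[] args) {
    // (String id, Double score) constructor; the values list is deliberately left null.
    WordEmbeddings hit = new WordEmbeddings("doc-1", 0.87);

    // Only non-null fields appear in the JSON, e.g. {"id":"doc-1","score":0.87}
    System.out.println(hit.toJson());

    // equals()/hashCode() compare ids only, so two results with the same id collapse to one.
    WordEmbeddings sameId = new WordEmbeddings("doc-1", 0.42);
    System.out.println(hit.equals(sameId)); // true
  }
}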
= Paths.get("./model"); + logger.info("Creating tokenizer"); + Path path = Paths.get(BgeSmallEndpoint.MODEL_FOLDER); HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder() .optTokenizerPath(path) .optManager(NDManager.newBaseManager("PyTorch")) .build(); + logger.info("Creating translator"); MyTextEmbeddingTranslator translator = new MyTextEmbeddingTranslator(tokenizer, Batchifier.STACK, "cls", true, true); + logger.info("Loading criteria"); Criteria criteria = Criteria.builder() .setTypes(String.class, float[].class) @@ -94,7 +89,7 @@ private ZooModel loadSmallBgeEn() throws IOException { r = criteria.loadModel(); bgeSmallEn = r; } catch (IOException | ModelNotFoundException | MalformedModelException e) { - e.printStackTrace(); + logger.error("Failed to load model", e); throw new RuntimeException(e); } } @@ -162,8 +157,9 @@ static NDArray processEmbedding( embedding = list.head(); } long[] attentionMask = encoding.getAttentionMask(); - try (NDManager ptManager = NDManager.newBaseManager("PyTorch")) { - NDArray inputAttentionMask = ptManager.create(attentionMask).toType(DataType.FLOAT32, true); + try (NDManager ptManager = NDManager.newBaseManager("PyTorch"); + NDArray array = ptManager.create(attentionMask)) { + NDArray inputAttentionMask = array.toType(DataType.FLOAT32, true); switch (pooling) { case "mean": return meanPool(embedding, inputAttentionMask, false); @@ -206,7 +202,7 @@ private static NDArray maxPool(NDArray embeddings, NDArray inputAttentionMask) { private static NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) { long[] shape = embeddings.getShape().getShape(); - NDArray weight = embeddings.getManager().arange(1, shape[0] + 1); + NDArray weight = embeddings.getManager().arange(1f, shape[0] + 1f); weight = weight.expandDims(-1).broadcast(shape); attentionMask = attentionMask.expandDims(-1).broadcast(shape).mul(weight); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java index b3c958995..b12ec6af0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/miniLLM/MiniLMClient.java @@ -9,7 +9,7 @@ import ai.djl.training.util.ProgressBar; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Observable; import org.springframework.stereotype.Service; @@ -21,8 +21,6 @@ @Service public class MiniLMClient { - private MiniLMEndpoint endpoint; - private static volatile ZooModel allMiniL6V2; private static volatile ZooModel allMiniL12V2; @@ -30,25 +28,17 @@ public class MiniLMClient { private static volatile ZooModel multiQAMiniLML6CosV1; - public MiniLMEndpoint getEndpoint() { - return endpoint; - } - - public void setEndpoint(MiniLMEndpoint endpoint) { - this.endpoint = endpoint; - } - - public EdgeChain createEmbeddings(String input) { + public EdgeChain createEmbeddings(String input, MiniLMEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - if (this.endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L6_V2)) { + if 
(endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L6_V2)) { Predictor predictor = - loadAllMiniL6V2(this.endpoint.getMiniLMModel()).newPredictor(); + loadAllMiniL6V2(endpoint.getMiniLMModel()).newPredictor(); float[] predict = predictor.predict(input); @@ -59,10 +49,10 @@ public EdgeChain createEmbeddings(String input) { emitter.onNext(new MiniLMResponse(floatList)); emitter.onComplete(); - } else if (this.endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L12_V2)) { + } else if (endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L12_V2)) { Predictor predictor = - loadAllMiniL12V2(this.endpoint.getMiniLMModel()).newPredictor(); + loadAllMiniL12V2(endpoint.getMiniLMModel()).newPredictor(); float[] predict = predictor.predict(input); @@ -73,11 +63,9 @@ public EdgeChain createEmbeddings(String input) { emitter.onNext(new MiniLMResponse(floatList)); emitter.onComplete(); - } else if (this.endpoint - .getMiniLMModel() - .equals(MiniLMModel.PARAPHRASE_MINILM_L3_V2)) { + } else if (endpoint.getMiniLMModel().equals(MiniLMModel.PARAPHRASE_MINILM_L3_V2)) { Predictor predictor = - loadParaphraseMiniLML3v2(this.endpoint.getMiniLMModel()).newPredictor(); + loadParaphraseMiniLML3v2(endpoint.getMiniLMModel()).newPredictor(); float[] predict = predictor.predict(input); @@ -92,7 +80,7 @@ public EdgeChain createEmbeddings(String input) { System.out.println("d"); ZooModel zooModel = - loadMultiQAMiniLML6CosV1(this.endpoint.getMiniLMModel()); + loadMultiQAMiniLML6CosV1(endpoint.getMiniLMModel()); Predictor predictor = zooModel.newPredictor(); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/request/OpenAiEmbeddingRequest.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/request/OpenAiEmbeddingRequest.java index b7143cc25..b460fbfa9 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/request/OpenAiEmbeddingRequest.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/request/OpenAiEmbeddingRequest.java @@ -4,6 +4,8 @@ public class OpenAiEmbeddingRequest { private String input; private String model; + public OpenAiEmbeddingRequest() {} + public OpenAiEmbeddingRequest(String model, String input) { this.model = model; this.input = input; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbedding.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbedding.java index dfcf38cc0..1a946e054 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbedding.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbedding.java @@ -1,6 +1,7 @@ package com.edgechain.lib.embeddings.response; import com.edgechain.lib.response.ArkObject; +import org.json.JSONArray; import org.json.JSONObject; import java.util.List; @@ -50,10 +51,20 @@ public String toString() { @Override public JSONObject toJson() { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("object", object); - jsonObject.put("embedding", embedding); - jsonObject.put("index", index); - return jsonObject; + JSONObject json = new JSONObject(); + + if (object != null) { + json.put("object", object); + } + + if (embedding != null) { + json.put("embedding", new JSONArray(embedding)); + } + + if (index != null) { + json.put("index", index); + } + + return json; } } diff --git 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbeddingResponse.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbeddingResponse.java index ebf52ad5e..0d65ce0f5 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbeddingResponse.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/OpenAiEmbeddingResponse.java @@ -63,11 +63,24 @@ public String toString() { @Override public JSONObject toJson() { - JSONObject jsonObject = new JSONObject(); - jsonObject.put("model", model); - jsonObject.put("object", object); - jsonObject.put("data", data.stream().map(OpenAiEmbedding::toJson).collect(Collectors.toList())); - jsonObject.put("usage", usage.toJson()); - return jsonObject; + JSONObject json = new JSONObject(); + + if (model != null) { + json.put("model", model); + } + + if (object != null) { + json.put("object", object); + } + + if (data != null) { + json.put("data", data.stream().map(OpenAiEmbedding::toJson).collect(Collectors.toList())); + } + + if (usage != null) { + json.put("usage", usage.toJson()); // Assuming Usage has a toJson method + } + + return json; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/Usage.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/Usage.java index 28b4914b8..c30c10bbc 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/Usage.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/embeddings/response/Usage.java @@ -36,8 +36,15 @@ public String toString() { @Override public JSONObject toJson() { JSONObject json = new JSONObject(); - json.put("prompt_tokens", prompt_tokens); - json.put("total_tokens", total_tokens); + + if (prompt_tokens != 0L) { + json.put("prompt_tokens", prompt_tokens); + } + + if (total_tokens != 0L) { + json.put("total_tokens", total_tokens); + } + return json; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java index 5128f9894..2bdc54901 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/Endpoint.java @@ -36,6 +36,14 @@ public Endpoint(String url, String apiKey, RetryPolicy retryPolicy) { this.retryPolicy = retryPolicy; } + public void setUrl(String url) { + this.url = url; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + public String getApiKey() { return this.apiKey; } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java deleted file mode 100644 index 0ab920b58..000000000 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpoint.java +++ /dev/null @@ -1,118 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.request.ArkRequest; -import com.edgechain.lib.retrofit.BgeSmallService; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import org.slf4j.Logger; -import 
org.slf4j.LoggerFactory; -import retrofit2.Retrofit; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.net.URL; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; -import java.util.Objects; - -public class BgeSmallEndpoint extends Endpoint { - - private Logger logger = LoggerFactory.getLogger(BgeSmallEndpoint.class); - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final BgeSmallService bgeSmallService = retrofit.create(BgeSmallService.class); - - private String input; - - private String modelUrl; - private String tokenizerUrl; - - private String callIdentifier; - private final String MODEL_PATH = "./model/model.onnx"; - private final String TOKENIZER_PATH = "./model/tokenizer.json"; - private final String MODEL_FOLDER = "./model"; - - public BgeSmallEndpoint() {} - - public BgeSmallEndpoint(String modelUrl, String tokenizerUrl) { - this.modelUrl = modelUrl; - this.tokenizerUrl = tokenizerUrl; - - logger.info("Downloading bge-small-en model. Please wait..."); - File modelFile = new File(MODEL_PATH); - File tokenizerFile = new File(TOKENIZER_PATH); - - // check if the file already exists - if (!modelFile.exists()) downloadFile(modelUrl, MODEL_PATH); - if (!tokenizerFile.exists()) downloadFile(tokenizerUrl, TOKENIZER_PATH); - logger.info("Model downloaded successfully!"); - } - - public String getModelUrl() { - return modelUrl; - } - - public String getTokenizerUrl() { - return tokenizerUrl; - } - - public String getInput() { - return input; - } - - public String getCallIdentifier() { - return callIdentifier; - } - - public BgeSmallEndpoint(RetryPolicy retryPolicy, String modelUrl, String tokenizerUrl) { - super(retryPolicy); - this.modelUrl = modelUrl; - this.tokenizerUrl = tokenizerUrl; - } - - public WordEmbeddings embeddings(String input, ArkRequest arkRequest) { - - this.input = input; // set Input - - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } - - return bgeSmallService - .embeddings(this) - .map(m -> new WordEmbeddings(input, m.getEmbedding())) - .blockingGet(); - } - - private void downloadFile(String urlStr, String path) { - - File modelFolderFile = new File(MODEL_FOLDER); - - if (!modelFolderFile.exists()) { - modelFolderFile.mkdir(); - } - - ReadableByteChannel rbc = null; - FileOutputStream fos = null; - try { - URL url = new URL(urlStr); - rbc = Channels.newChannel(url.openStream()); - fos = new FileOutputStream(path); - fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); - } catch (IOException e) { - logger.info("Error downloading model"); - e.printStackTrace(); - } finally { - assert fos != null; - try { - fos.close(); - rbc.close(); - } catch (IOException e) { - e.printStackTrace(); - } - } - } -} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java deleted file mode 100644 index 5d2938906..000000000 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PineconeEndpoint.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.retrofit.PineconeService; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import 
com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; - -import java.util.List; - -public class PineconeEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final PineconeService pineconeService = retrofit.create(PineconeService.class); - - private String namespace; - - // Getters; - private WordEmbeddings wordEmbeddings; - - private int topK; - - public PineconeEndpoint() {} - - public PineconeEndpoint(String namespace) { - this.namespace = namespace; - } - - public PineconeEndpoint(String url, String apiKey) { - super(url, apiKey); - } - - public PineconeEndpoint(String url, String apiKey, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - } - - public PineconeEndpoint(String url, String apiKey, String namespace) { - super(url, apiKey); - this.namespace = namespace; - } - - public PineconeEndpoint(String url, String apiKey, String namespace, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); - this.namespace = namespace; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - // Getters - - public WordEmbeddings getWordEmbeddings() { - return wordEmbeddings; - } - - public void setWordEmbeddings(WordEmbeddings wordEmbeddings) { - this.wordEmbeddings = wordEmbeddings; - } - - public int getTopK() { - return topK; - } - - public void setTopK(int topK) { - this.topK = topK; - } - - public StringResponse upsert(WordEmbeddings wordEmbeddings) { - this.wordEmbeddings = wordEmbeddings; - return this.pineconeService.upsert(this).blockingGet(); - } - - public Observable> query(WordEmbeddings wordEmbeddings, int topK) { - this.wordEmbeddings = wordEmbeddings; - this.topK = topK; - return Observable.fromSingle(this.pineconeService.query(this)); - } - - public StringResponse deleteAll() { - return this.pineconeService.deleteAll(this).blockingGet(); - } -} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java deleted file mode 100644 index fc8443d28..000000000 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgresEndpoint.java +++ /dev/null @@ -1,135 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.index.domain.PostgresWordEmbeddings; -import com.edgechain.lib.index.enums.PostgresDistanceMetric; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.PostgresService; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; - -import java.util.List; - -public class PostgresEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final PostgresService postgresService = retrofit.create(PostgresService.class); - - private String tableName; - - private int lists; - - private String namespace; - - private String filename; - - // Getters - private WordEmbeddings wordEmbeddings; - private PostgresDistanceMetric metric; - private int dimensions; - private int topK; - - private int probes; - - public PostgresEndpoint() {} - - public 
PostgresEndpoint(RetryPolicy retryPolicy) { - super(retryPolicy); - } - - public PostgresEndpoint(String tableName) { - this.tableName = tableName; - } - - public PostgresEndpoint(String tableName, RetryPolicy retryPolicy) { - super(retryPolicy); - this.tableName = tableName; - } - - public String getTableName() { - return tableName; - } - - public String getNamespace() { - return namespace; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - // Getters - - public WordEmbeddings getWordEmbeddings() { - return wordEmbeddings; - } - - public int getDimensions() { - return dimensions; - } - - public int getTopK() { - return topK; - } - - public String getFilename() { - return filename; - } - - public PostgresDistanceMetric getMetric() { - return metric; - } - - public int getLists() { - return lists; - } - - public int getProbes() { - return probes; - } - - // Convenience Methods - - public StringResponse upsert( - WordEmbeddings wordEmbeddings, - String filename, - int dimension, - PostgresDistanceMetric metric, - int lists) { - this.wordEmbeddings = wordEmbeddings; - this.dimensions = dimension; - this.filename = filename; - this.metric = metric; - this.lists = lists; - return this.postgresService.upsert(this).blockingGet(); - } - - public Observable> query( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) { - this.wordEmbeddings = wordEmbeddings; - this.topK = topK; - this.metric = metric; - this.probes = 1; - return Observable.fromSingle(this.postgresService.query(this)); - } - - public Observable> query( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK, int probes) { - this.wordEmbeddings = wordEmbeddings; - this.topK = topK; - this.metric = metric; - this.probes = probes; - return Observable.fromSingle(this.postgresService.query(this)); - } - - public StringResponse deleteAll() { - return this.postgresService.deleteAll(this).blockingGet(); - } -} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java deleted file mode 100644 index 1dad38b37..000000000 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisEndpoint.java +++ /dev/null @@ -1,126 +0,0 @@ -package com.edgechain.lib.endpoint.impl; - -import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.Endpoint; -import com.edgechain.lib.retrofit.RedisService; -import com.edgechain.lib.index.enums.RedisDistanceMetric; -import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.retrofit.client.RetrofitClientInstance; -import com.edgechain.lib.rxjava.retry.RetryPolicy; -import io.reactivex.rxjava3.core.Observable; -import retrofit2.Retrofit; -import java.util.List; - -public class RedisEndpoint extends Endpoint { - - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final RedisService redisService = retrofit.create(RedisService.class); - - private String indexName; - private String namespace; - - // Getters; - private WordEmbeddings wordEmbeddings; - - private int dimensions; - - private RedisDistanceMetric metric; - - private int topK; - - public RedisEndpoint() {} - - public RedisEndpoint(RetryPolicy retryPolicy) { - super(retryPolicy); - } - - public RedisEndpoint(String indexName) { - this.indexName = indexName; - } 
- - public RedisEndpoint(String indexName, RetryPolicy retryPolicy) { - super(retryPolicy); - this.indexName = indexName; - } - - public RedisEndpoint(String indexName, String namespace) { - this.indexName = indexName; - this.namespace = namespace; - } - - public RedisEndpoint(String indexName, String namespace, RetryPolicy retryPolicy) { - super(retryPolicy); - this.indexName = indexName; - this.namespace = namespace; - } - - public String getIndexName() { - return indexName; - } - - public void setIndexName(String indexName) { - this.indexName = indexName; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - // Getters - public WordEmbeddings getWordEmbeddings() { - return wordEmbeddings; - } - - public void setWordEmbeddings(WordEmbeddings wordEmbeddings) { - this.wordEmbeddings = wordEmbeddings; - } - - public int getDimensions() { - return dimensions; - } - - public void setDimensions(int dimensions) { - this.dimensions = dimensions; - } - - public RedisDistanceMetric getMetric() { - return metric; - } - - public void setMetric(RedisDistanceMetric metric) { - this.metric = metric; - } - - public int getTopK() { - return topK; - } - - public void setTopK(int topK) { - this.topK = topK; - } - - // Convenience Methods - public StringResponse upsert( - WordEmbeddings wordEmbeddings, int dimension, RedisDistanceMetric metric) { - - this.wordEmbeddings = wordEmbeddings; - this.dimensions = dimension; - this.metric = metric; - - return this.redisService.upsert(this).blockingGet(); - } - - public Observable> query(WordEmbeddings embeddings, int topK) { - this.topK = topK; - this.wordEmbeddings = embeddings; - return Observable.fromSingle(this.redisService.query(this)); - } - - public void delete(String patternName) { - this.redisService.deleteByPattern(patternName, this).blockingAwait(); - } -} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java similarity index 96% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java index 2fb095412..c99a331da 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/PostgreSQLHistoryContextEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/PostgreSQLHistoryContextEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.context; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java similarity index 96% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java index 1a420722c..9dcfcfb75 100644 --- 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/RedisHistoryContextEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/context/RedisHistoryContextEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.context; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java new file mode 100644 index 000000000..7d115898e --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/BgeSmallEndpoint.java @@ -0,0 +1,120 @@ +package com.edgechain.lib.endpoint.impl.embeddings; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.BgeSmallService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import io.reactivex.rxjava3.core.Observable; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.Objects; + +import org.modelmapper.ModelMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class BgeSmallEndpoint extends EmbeddingEndpoint { + + private static final Logger logger = LoggerFactory.getLogger(BgeSmallEndpoint.class); + + private final BgeSmallService bgeSmallService = + RetrofitClientInstance.getInstance().create(BgeSmallService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String modelUrl; + private String tokenizerUrl; + + public static final String MODEL_FOLDER = "./model"; + public static final String MODEL_PATH = MODEL_FOLDER + "/model.onnx"; + public static final String TOKENIZER_PATH = MODEL_FOLDER + "/tokenizer.json"; + + public BgeSmallEndpoint() {} + + public BgeSmallEndpoint(String modelUrl, String tokenizerUrl) { + this.modelUrl = modelUrl; + this.tokenizerUrl = tokenizerUrl; + + File modelFile = new File(MODEL_PATH); + if (!modelFile.exists()) { + logger.info( + "Downloading bge-small-en model from {} to {}. Please wait...", + modelUrl, + modelFile.getAbsolutePath()); + downloadFile(modelUrl, MODEL_PATH); + } + + File tokenizerFile = new File(TOKENIZER_PATH); + if (!tokenizerFile.exists()) { + logger.info( + "Downloading bge-small-en tokenizer from {} to {}. 
Please wait...", + tokenizerUrl, + tokenizerFile.getAbsolutePath()); + downloadFile(tokenizerUrl, TOKENIZER_PATH); + } + + logger.info("Model downloaded successfully!"); + } + + public String getModelUrl() { + return modelUrl; + } + + public String getTokenizerUrl() { + return tokenizerUrl; + } + + public void setModelUrl(String modelUrl) { + this.modelUrl = modelUrl; + } + + public void setTokenizerUrl(String tokenizerUrl) { + this.tokenizerUrl = tokenizerUrl; + } + + public BgeSmallEndpoint(RetryPolicy retryPolicy, String modelUrl, String tokenizerUrl) { + super(retryPolicy); + this.modelUrl = modelUrl; + this.tokenizerUrl = tokenizerUrl; + } + + @Override + public Observable embeddings(String input, ArkRequest arkRequest) { + BgeSmallEndpoint mapper = modelMapper.map(this, BgeSmallEndpoint.class); + mapper.setRawText(input); + + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); + + return Observable.fromSingle( + bgeSmallService.embeddings(mapper).map(m -> new WordEmbeddings(input, m.getEmbedding()))); + } + + private void downloadFile(String urlStr, String path) { + + File modelFolderFile = new File(MODEL_FOLDER); + + if (!modelFolderFile.exists()) { + logger.info("Creating directory {}", MODEL_FOLDER); + modelFolderFile.mkdir(); + } + + try { + URL url = new URL(urlStr); + try (InputStream is = url.openStream(); + ReadableByteChannel rbc = Channels.newChannel(is); + FileOutputStream fos = new FileOutputStream(path)) { + long transferred = fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + logger.info("Downloaded {} bytes", transferred); + } + } catch (IOException e) { + logger.error("Error downloading model", e); + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java new file mode 100644 index 000000000..609332c4b --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/EmbeddingEndpoint.java @@ -0,0 +1,64 @@ +package com.edgechain.lib.endpoint.impl.embeddings; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.reactivex.rxjava3.core.Observable; + +import java.io.Serializable; + +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = OpenAiEmbeddingEndpoint.class, name = "type1"), + @JsonSubTypes.Type(value = MiniLMEndpoint.class, name = "type2"), + @JsonSubTypes.Type(value = BgeSmallEndpoint.class, name = "type3"), +}) +public abstract class EmbeddingEndpoint extends Endpoint implements Serializable { + + private static final long serialVersionUID = 4201794264326630184L; + private String callIdentifier; + private String rawText; + + public EmbeddingEndpoint() {} + + public EmbeddingEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public EmbeddingEndpoint(String url) { + super(url); + } + + public EmbeddingEndpoint(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + } + + public EmbeddingEndpoint(String url, String apiKey) { + super(url, apiKey); + } + + public 
EmbeddingEndpoint(String url, String apiKey, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + } + + public abstract Observable embeddings(String input, ArkRequest arkRequest); + + public void setRawText(String rawText) { + this.rawText = rawText; + } + + public void setCallIdentifier(String callIdentifier) { + this.callIdentifier = callIdentifier; + } + + public String getRawText() { + return rawText; + } + + public String getCallIdentifier() { + return callIdentifier; + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java similarity index 62% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java index f50f247d4..840cfac2c 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/MiniLMEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/MiniLMEndpoint.java @@ -1,8 +1,7 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.embeddings; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel; -import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.MiniLMService; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; @@ -10,55 +9,46 @@ import java.util.Objects; import io.reactivex.rxjava3.core.Observable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.modelmapper.ModelMapper; import retrofit2.Retrofit; -public class MiniLMEndpoint extends Endpoint { - - private Logger logger = LoggerFactory.getLogger(MiniLMEndpoint.class); +public class MiniLMEndpoint extends EmbeddingEndpoint { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final MiniLMService miniLMService = retrofit.create(MiniLMService.class); - - private String input; + private ModelMapper modelMapper = new ModelMapper(); private MiniLMModel miniLMModel; - private String callIdentifier; - public MiniLMEndpoint() {} public MiniLMEndpoint(MiniLMModel miniLMModel) { this.miniLMModel = miniLMModel; } - public String getInput() { - return input; + public void setMiniLMModel(MiniLMModel miniLMModel) { + this.miniLMModel = miniLMModel; } public MiniLMModel getMiniLMModel() { return miniLMModel; } - public String getCallIdentifier() { - return callIdentifier; - } - public MiniLMEndpoint(RetryPolicy retryPolicy, MiniLMModel miniLMModel) { super(retryPolicy); this.miniLMModel = miniLMModel; } + @Override public Observable embeddings(String input, ArkRequest arkRequest) { - this.input = input; // set Input + MiniLMEndpoint mapper = modelMapper.map(this, MiniLMEndpoint.class); + mapper.setRawText(input); - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); return Observable.fromSingle( - miniLMService.embeddings(this).map(m -> new WordEmbeddings(input, m.getEmbedding()))); + miniLMService.embeddings(mapper).map(m -> new WordEmbeddings(input, m.getEmbedding()))); } } diff --git 
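A brief, hedged sketch of the new EmbeddingEndpoint hierarchy in use; the OpenAI URL, key, and model name are placeholders rather than values taken from this change, and running it assumes the EdgeChains runtime is up so the internal Retrofit services can be reached.

import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.embeddings.miniLLM.enums.MiniLMModel;
import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint;
import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint;
import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint;
import io.reactivex.rxjava3.core.Observable;

public class EmbeddingEndpointSketch {

  public static void main(String[] args) {
    // Local MiniLM model; no API key needed.
    EmbeddingEndpoint miniLm = new MiniLMEndpoint(MiniLMModel.ALL_MINILM_L6_V2);

    // Hosted OpenAI embeddings; url/key/orgId/model below are placeholders.
    EmbeddingEndpoint openAi =
        new OpenAiEmbeddingEndpoint(
            "https://api.openai.com/v1/embeddings", "OPENAI_API_KEY", "", "text-embedding-ada-002");

    // Every implementation exposes the same entry point; either endpoint could be used here.
    // A null ArkRequest is recorded as "URI wasn't provided" in the call identifier.
    Observable<WordEmbeddings> embeddings = miniLm.embeddings("hello world", null);
    embeddings.blockingSubscribe(w -> System.out.println(w.toJson()));

    // The @JsonTypeInfo/@JsonSubTypes "type" discriminator on EmbeddingEndpoint is what
    // lets the concrete subtype survive JSON (de)serialization inside the library.
  }
}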
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java new file mode 100644 index 000000000..6dc9dea4f --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/embeddings/OpenAiEmbeddingEndpoint.java @@ -0,0 +1,71 @@ +package com.edgechain.lib.endpoint.impl.embeddings; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.OpenAiService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.Objects; + +public class OpenAiEmbeddingEndpoint extends EmbeddingEndpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final OpenAiService openAiService = retrofit.create(OpenAiService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String orgId; + private String model; + + public OpenAiEmbeddingEndpoint() {} + + public OpenAiEmbeddingEndpoint(String url, String apiKey, String orgId, String model) { + super(url, apiKey); + this.orgId = orgId; + this.model = model; + } + + public OpenAiEmbeddingEndpoint( + String url, String apiKey, String orgId, String model, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.orgId = orgId; + this.model = model; + } + + public String getModel() { + return model; + } + + public String getOrgId() { + return orgId; + } + + public void setOrgId(String orgId) { + this.orgId = orgId; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Observable embeddings(String input, ArkRequest arkRequest) { + + OpenAiEmbeddingEndpoint mapper = modelMapper.map(this, OpenAiEmbeddingEndpoint.class); + mapper.setRawText(input); + + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); + + return Observable.fromSingle( + openAiService + .embeddings(mapper) + .map( + embeddingResponse -> + new WordEmbeddings(input, embeddingResponse.getData().get(0).getEmbedding()))); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java new file mode 100644 index 000000000..366a06992 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PineconeEndpoint.java @@ -0,0 +1,162 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.PineconeService; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; + +public class PineconeEndpoint extends 
Endpoint { + + private static final String QUERY_API = "/query"; + private static final String UPSERT_API = "/vectors/upsert"; + private static final String DELETE_API = "/vectors/delete"; + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final PineconeService pineconeService = retrofit.create(PineconeService.class); + private ModelMapper modelMapper = new ModelMapper(); + + private String originalUrl; + private String namespace; + + // Getters; + private WordEmbeddings wordEmbedding; + + private List wordEmbeddingsList; + + private int topK; + + private EmbeddingEndpoint embeddingEndpoint; + + public PineconeEndpoint() {} + + public PineconeEndpoint(String url, String apiKey, EmbeddingEndpoint embeddingEndpoint) { + super(url, apiKey); + this.originalUrl = url; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PineconeEndpoint( + String url, String apiKey, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.embeddingEndpoint = embeddingEndpoint; + this.originalUrl = url; + } + + public PineconeEndpoint( + String url, String apiKey, String namespace, EmbeddingEndpoint embeddingEndpoint) { + super(url, apiKey); + this.originalUrl = url; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PineconeEndpoint( + String url, + String apiKey, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { + super(url, apiKey, retryPolicy); + this.originalUrl = url; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + // Getters + + public void setOriginalUrl(String originalUrl) { + this.originalUrl = originalUrl; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + private void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + private void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + public int getTopK() { + return topK; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public String getOriginalUrl() { + return originalUrl; + } + + private void setTopK(int topK) { + this.topK = topK; + } + + public StringResponse upsert(WordEmbeddings wordEmbedding, String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbedding(wordEmbedding); + mapper.setUrl(mapper.getOriginalUrl().concat(UPSERT_API)); + mapper.setNamespace(namespace); + + return this.pineconeService.upsert(mapper).blockingGet(); + } + + public StringResponse batchUpsert(List wordEmbeddingsList, String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + mapper.setUrl(mapper.getOriginalUrl().concat(UPSERT_API)); + mapper.setNamespace(namespace); + + return this.pineconeService.batchUpsert(mapper).blockingGet(); + } + + public Observable> query( + String query, String namespace, int topK, ArkRequest arkRequest) { + WordEmbeddings wordEmbeddings = + new 
EdgeChain<>(getEmbeddingEndpoint().embeddings(query, arkRequest)).get(); + + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setUrl(mapper.getOriginalUrl().concat(QUERY_API)); + mapper.setNamespace(namespace); + mapper.setTopK(topK); + return Observable.fromSingle(this.pineconeService.query(mapper)); + } + + public StringResponse deleteAll(String namespace) { + PineconeEndpoint mapper = modelMapper.map(this, PineconeEndpoint.class); + mapper.setUrl(mapper.getOriginalUrl().concat(DELETE_API)); + mapper.setNamespace(namespace); + return this.pineconeService.deleteAll(mapper).blockingGet(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java new file mode 100644 index 000000000..8fcee3b52 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/PostgresEndpoint.java @@ -0,0 +1,564 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import com.edgechain.lib.index.domain.RRFWeight; +import com.edgechain.lib.index.enums.OrderRRFBy; +import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.retrofit.PostgresService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.schedulers.Schedulers; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; + +public class PostgresEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final PostgresService postgresService = retrofit.create(PostgresService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String tableName; + private int lists; + + private String id; + private String namespace; + + private String filename; + + // Getters + private WordEmbeddings wordEmbedding; + + private List wordEmbeddingsList; + + private PostgresDistanceMetric metric; + private int dimensions; + private int topK; + private int upperLimit; + + private int probes; + private String embeddingChunk; + + // Fields for metadata table + private List metadataTableNames; + private String metadata; + private String metadataId; + private List metadataList; + private String documentDate; + + /** RRF * */ + private RRFWeight textWeight; + + private RRFWeight similarityWeight; + private RRFWeight dateWeight; + + private OrderRRFBy orderRRFBy; + private String searchQuery; + + private PostgresLanguage postgresLanguage; + + // Join Table + private List idList; + + private EmbeddingEndpoint embeddingEndpoint; + + public PostgresEndpoint() {} + + public PostgresEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public PostgresEndpoint(String tableName, EmbeddingEndpoint embeddingEndpoint) { + this.tableName = tableName; + this.embeddingEndpoint = 
embeddingEndpoint; + } + + public PostgresEndpoint(String tableName, String namespace, EmbeddingEndpoint embeddingEndpoint) { + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint( + String tableName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public PostgresEndpoint( + String tableName, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { + super(retryPolicy); + this.tableName = tableName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public String getTableName() { + return tableName; + } + + public String getNamespace() { + return namespace; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + public void setMetadata(String metadata) { + this.metadata = metadata; + } + + public void setId(String id) { + this.id = id; + } + + public void setMetadataId(String metadataId) { + this.metadataId = metadataId; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + + public void setEmbeddingChunk(String embeddingChunk) { + this.embeddingChunk = embeddingChunk; + } + + public int getUpperLimit() { + return upperLimit; + } + + public void setUpperLimit(int upperLimit) { + this.upperLimit = upperLimit; + } + + private void setLists(int lists) { + this.lists = lists; + } + + private void setFilename(String filename) { + this.filename = filename; + } + + private void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + private void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + private void setMetric(PostgresDistanceMetric metric) { + this.metric = metric; + } + + private void setDimensions(int dimensions) { + this.dimensions = dimensions; + } + + private void setTopK(int topK) { + this.topK = topK; + } + + private void setProbes(int probes) { + this.probes = probes; + } + + private void setMetadataTableNames(List metadataTableNames) { + this.metadataTableNames = metadataTableNames; + } + + private void setDocumentDate(String documentDate) { + this.documentDate = documentDate; + } + + private void setTextWeight(RRFWeight textWeight) { + this.textWeight = textWeight; + } + + private void setSimilarityWeight(RRFWeight similarityWeight) { + this.similarityWeight = similarityWeight; + } + + private void setDateWeight(RRFWeight dateWeight) { + this.dateWeight = dateWeight; + } + + private void setOrderRRFBy(OrderRRFBy orderRRFBy) { + this.orderRRFBy = orderRRFBy; + } + + private void setSearchQuery(String searchQuery) { + this.searchQuery = searchQuery; + } + + private void setPostgresLanguage(PostgresLanguage postgresLanguage) { + this.postgresLanguage = postgresLanguage; + } + + private void setIdList(List idList) { + this.idList = idList; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + // Getters + + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public int getDimensions() { + return dimensions; + } + + public int getTopK() { + return topK; + } + + public String 
getFilename() { + return filename; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + public PostgresDistanceMetric getMetric() { + return metric; + } + + public int getLists() { + return lists; + } + + public int getProbes() { + return probes; + } + + public List getMetadataTableNames() { + return metadataTableNames; + } + + public String getMetadata() { + return metadata; + } + + public String getMetadataId() { + return metadataId; + } + + public String getId() { + return id; + } + + public List getMetadataList() { + return metadataList; + } + + public String getEmbeddingChunk() { + return embeddingChunk; + } + + public String getDocumentDate() { + return documentDate; + } + + public RRFWeight getTextWeight() { + return textWeight; + } + + public RRFWeight getSimilarityWeight() { + return similarityWeight; + } + + public RRFWeight getDateWeight() { + return dateWeight; + } + + public OrderRRFBy getOrderRRFBy() { + return orderRRFBy; + } + + public String getSearchQuery() { + return searchQuery; + } + + public PostgresLanguage getPostgresLanguage() { + return postgresLanguage; + } + + public List getIdList() { + return idList; + } + + public StringResponse upsert( + WordEmbeddings wordEmbeddings, + String filename, + int dimension, + PostgresDistanceMetric metric) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setFilename(filename); + mapper.setDimensions(dimension); + mapper.setMetric(metric); + return this.postgresService.upsert(mapper).blockingGet(); + } + + public StringResponse createTable(int dimensions, PostgresDistanceMetric metric, int lists) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setDimensions(dimensions); + mapper.setMetric(metric); + mapper.setLists(lists); + + return this.postgresService.createTable(mapper).blockingGet(); + } + + public StringResponse createMetadataTable(String metadataTableName) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.createMetadataTable(mapper).blockingGet(); + } + + public List upsert( + List wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + mapper.setFilename(filename); + mapper.setPostgresLanguage(postgresLanguage); + + return this.postgresService.batchUpsert(mapper).blockingGet(); + } + + public StringResponse insertMetadata( + String metadataTableName, String metadata, String documentDate) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadata(metadata); + mapper.setDocumentDate(documentDate); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.insertMetadata(mapper).blockingGet(); + } + + public List batchInsertMetadata(List metadataList) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataList(metadataList); + return this.postgresService.batchInsertMetadata(mapper).blockingGet(); + } + + public StringResponse insertIntoJoinTable( + String metadataTableName, String id, String metadataId) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setId(id); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + + 
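// Link one embedding row (id) to one metadata row (metadataId) in the join table that backs the
// metadata/RRF queries; as in the other convenience methods, per-call state is set on a
// ModelMapper copy so the shared endpoint instance is never mutated.
/*
 * Hedged usage sketch (editor's addition, not part of the diff). Table names, the COSINE constant
 * and the way the generated ids are obtained are illustrative assumptions; "embeddings" stands for
 * any configured EmbeddingEndpoint such as the OpenAiEmbeddingEndpoint introduced above:
 *
 *   PostgresEndpoint pg = new PostgresEndpoint("pg_vectors", "knowledge", embeddings);
 *   pg.createTable(1536, PostgresDistanceMetric.COSINE, 100);
 *   pg.createMetadataTable("pg_vectors_title_metadata");
 *   pg.upsert(wordEmbedding, "report.pdf", 1536, PostgresDistanceMetric.COSINE);
 *   pg.insertMetadata("pg_vectors_title_metadata", "Quarterly report", "2024-01-01");
 *   // embeddingRowId / metadataRowId would come from the two responses above
 *   pg.insertIntoJoinTable("pg_vectors_title_metadata", embeddingRowId, metadataRowId);
 */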
return this.postgresService.insertIntoJoinTable(mapper).blockingGet(); + } + + public StringResponse batchInsertIntoJoinTable( + String metadataTableName, List idList, String metadataId) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setIdList(idList); + mapper.setMetadataId(metadataId); + mapper.setMetadataTableNames(List.of(metadataTableName)); + return this.postgresService.batchInsertIntoJoinTable(mapper).blockingGet(); + } + + public Observable> query( + List inputList, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + ArkRequest arkRequest) { + + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> query( + List inputList, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + int probes, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setMetric(metric); + mapper.setProbes(probes); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.query(mapper)); + } + + public Observable> queryRRF( + String metadataTable, + List inputList, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + OrderRRFBy orderRRFBy, + String searchQuery, + PostgresLanguage postgresLanguage, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(List.of(metadataTable)); + + List endpointEmbeddingList = + Observable.fromIterable(inputList) + .buffer(inputList.size() > 1 ? 
inputList.size() / 2 : 1) + .flatMap( + bufferedList -> + Observable.fromIterable(bufferedList) + .flatMap( + res -> + Observable.fromCallable( + () -> + new EdgeChain<>( + embeddingEndpoint.embeddings(res, arkRequest)) + .get()) + .subscribeOn(Schedulers.io()))) + .toList() + .blockingGet(); + + mapper.setWordEmbeddingsList(endpointEmbeddingList); + mapper.setTextWeight(textWeight); + mapper.setSimilarityWeight(similarityWeight); + mapper.setDateWeight(dateWeight); + mapper.setOrderRRFBy(orderRRFBy); + mapper.setSearchQuery(searchQuery); + mapper.setPostgresLanguage(postgresLanguage); + mapper.setProbes(probes); + mapper.setMetric(metric); + mapper.setTopK(topK); + mapper.setUpperLimit(upperLimit); + return Observable.fromSingle(this.postgresService.queryRRF(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + WordEmbeddings wordEmbeddings, + PostgresDistanceMetric metric, + int topK) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> queryWithMetadata( + List metadataTableNames, + String input, + PostgresDistanceMetric metric, + int topK, + ArkRequest arkRequest) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + + WordEmbeddings wordEmbeddings = + new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); + + mapper.setMetadataTableNames(metadataTableNames); + mapper.setWordEmbedding(wordEmbeddings); + mapper.setTopK(topK); + mapper.setMetric(metric); + mapper.setProbes(1); + return Observable.fromSingle(this.postgresService.queryWithMetadata(mapper)); + } + + public Observable> getSimilarMetadataChunk(String embeddingChunk) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setEmbeddingChunk(embeddingChunk); + return Observable.fromSingle(this.postgresService.getSimilarMetadataChunk(mapper)); + } + + public Observable> getAllChunks(String tableName, String filename) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setFilename(filename); + + return Observable.fromSingle(this.postgresService.getAllChunks(this)); + } + + public StringResponse deleteAll(String tableName, String namespace) { + PostgresEndpoint mapper = modelMapper.map(this, PostgresEndpoint.class); + mapper.setTableName(tableName); + mapper.setNamespace(namespace); + return this.postgresService.deleteAll(mapper).blockingGet(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java new file mode 100644 index 000000000..dbc8941f8 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/index/RedisEndpoint.java @@ -0,0 +1,190 @@ +package com.edgechain.lib.endpoint.impl.index; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.endpoint.impl.embeddings.EmbeddingEndpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.RedisService; +import com.edgechain.lib.index.enums.RedisDistanceMetric; +import com.edgechain.lib.response.StringResponse; +import 
com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; +import java.util.List; + +public class RedisEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final RedisService redisService = retrofit.create(RedisService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String indexName; + private String namespace; + + // Getters; + private WordEmbeddings wordEmbedding; + private List wordEmbeddingsList; + + private int dimensions; + + private RedisDistanceMetric metric; + + private int topK; + + private String pattern; + + private EmbeddingEndpoint embeddingEndpoint; + + public RedisEndpoint() {} + + public RedisEndpoint(RetryPolicy retryPolicy) { + super(retryPolicy); + } + + public RedisEndpoint(String indexName, EmbeddingEndpoint embeddingEndpoint) { + this.indexName = indexName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint( + String indexName, EmbeddingEndpoint embeddingEndpoint, RetryPolicy retryPolicy) { + super(retryPolicy); + this.indexName = indexName; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint(String indexName, String namespace, EmbeddingEndpoint embeddingEndpoint) { + this.indexName = indexName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public RedisEndpoint( + String indexName, + String namespace, + EmbeddingEndpoint embeddingEndpoint, + RetryPolicy retryPolicy) { + super(retryPolicy); + this.indexName = indexName; + this.namespace = namespace; + this.embeddingEndpoint = embeddingEndpoint; + } + + public EmbeddingEndpoint getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setEmbeddingEndpoint(EmbeddingEndpoint embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + + public void setPattern(String pattern) { + this.pattern = pattern; + } + + public String getIndexName() { + return indexName; + } + + public void setIndexName(String indexName) { + this.indexName = indexName; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + // Getters + public WordEmbeddings getWordEmbedding() { + return wordEmbedding; + } + + public void setWordEmbedding(WordEmbeddings wordEmbedding) { + this.wordEmbedding = wordEmbedding; + } + + public int getDimensions() { + return dimensions; + } + + public void setDimensions(int dimensions) { + this.dimensions = dimensions; + } + + public void setWordEmbeddingsList(List wordEmbeddingsList) { + this.wordEmbeddingsList = wordEmbeddingsList; + } + + public RedisDistanceMetric getMetric() { + return metric; + } + + public List getWordEmbeddingsList() { + return wordEmbeddingsList; + } + + public void setMetric(RedisDistanceMetric metric) { + this.metric = metric; + } + + public int getTopK() { + return topK; + } + + public void setTopK(int topK) { + this.topK = topK; + } + + public String getPattern() { + return pattern; + } + + // Convenience Methods + public StringResponse createIndex(String namespace, int dimension, RedisDistanceMetric metric) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setDimensions(dimension); + mapper.setMetric(metric); + mapper.setNamespace(namespace); + 
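// Provision the Redis vector index (blocking) for the given namespace, dimension and distance
// metric; per-call state again goes on a ModelMapper copy so the shared endpoint stays untouched.
/*
 * Hedged usage sketch (editor's addition, not part of the diff); the index name, namespace and
 * COSINE constant are illustrative assumptions, and "embeddings" is any configured EmbeddingEndpoint:
 *
 *   RedisEndpoint redis = new RedisEndpoint("vector_index", "machine-learning", embeddings);
 *   redis.createIndex("machine-learning", 1536, RedisDistanceMetric.COSINE);
 *   redis.upsert(wordEmbedding);
 *   redis.query("What is reciprocal rank fusion?", 5, arkRequest);
 */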
+ return this.redisService.createIndex(mapper).blockingGet(); + } + + public void batchUpsert(List wordEmbeddingsList) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setWordEmbeddingsList(wordEmbeddingsList); + + this.redisService.batchUpsert(mapper).ignoreElement().blockingAwait(); + } + + public StringResponse upsert(WordEmbeddings wordEmbedding) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setWordEmbedding(wordEmbedding); + + return this.redisService.upsert(mapper).blockingGet(); + } + + public Observable> query(String input, int topK, ArkRequest arkRequest) { + + WordEmbeddings wordEmbedding = + new EdgeChain<>(embeddingEndpoint.embeddings(input, arkRequest)).get(); + + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setTopK(topK); + mapper.setWordEmbedding(wordEmbedding); + return Observable.fromSingle(this.redisService.query(mapper)); + } + + public void delete(String patternName) { + RedisEndpoint mapper = modelMapper.map(this, RedisEndpoint.class); + mapper.setPattern(patternName); + this.redisService.deleteByPattern(mapper).blockingAwait(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java new file mode 100644 index 000000000..1ca7f4943 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/integration/AirtableEndpoint.java @@ -0,0 +1,182 @@ +package com.edgechain.lib.endpoint.impl.integration; + +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder; +import com.edgechain.lib.retrofit.AirtableService; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; +import java.util.Map; + +public class AirtableEndpoint extends Endpoint { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final AirtableService airtableService = retrofit.create(AirtableService.class); + + private ModelMapper modelMapper = new ModelMapper(); + + private String baseId; + private String apiKey; + + private List ids; + private String tableName; + private List airtableRecordList; + private boolean typecast = false; + + private AirtableQueryBuilder airtableQueryBuilder; + + public AirtableEndpoint() {} + + public AirtableEndpoint(String baseId, String apiKey) { + this.baseId = baseId; + this.apiKey = apiKey; + } + + public void setIds(List ids) { + this.ids = ids; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public void setAirtableRecordList(List airtableRecordList) { + this.airtableRecordList = airtableRecordList; + } + + public void setTypecast(boolean typecast) { + this.typecast = typecast; + } + + public void setAirtableQueryBuilder(AirtableQueryBuilder airtableQueryBuilder) { + this.airtableQueryBuilder = airtableQueryBuilder; + } + + public void setBaseId(String baseId) { + this.baseId = baseId; + } + + @Override + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseId() { + return baseId; + } + + public String getApiKey() { + return apiKey; + } + + public String getTableName() { + return tableName; + } + + public 
List getIds() { + return ids; + } + + public List getAirtableRecordList() { + return airtableRecordList; + } + + public boolean isTypecast() { + return typecast; + } + + public AirtableQueryBuilder getAirtableQueryBuilder() { + return airtableQueryBuilder; + } + + public Observable> findAll(String tableName, AirtableQueryBuilder builder) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + ; + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(builder); + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + + public Observable> findAll(String tableName) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setAirtableQueryBuilder(new AirtableQueryBuilder()); + + return Observable.fromSingle(this.airtableService.findAll(mapper)); + } + + public Observable findById(String tableName, String id) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setTableName(tableName); + mapper.setIds(List.of(id)); + return Observable.fromSingle(this.airtableService.findById(mapper)); + } + + public Observable> create( + String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create( + String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> create(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.create(mapper)); + } + + public Observable> update( + String tableName, List airtableRecordList) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update( + String tableName, List airtableRecordList, boolean typecast) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(airtableRecordList); + mapper.setTableName(tableName); + mapper.setTypecast(typecast); + + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> update(String tableName, AirtableRecord airtableRecord) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setAirtableRecordList(List.of(airtableRecord)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.update(mapper)); + } + + public Observable> delete(String tableName, List ids) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + mapper.setIds(ids); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } + + public Observable> delete(String tableName, String id) { + AirtableEndpoint mapper = modelMapper.map(this, AirtableEndpoint.class); + 
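// Wrap the single record id in a singleton list so the same Retrofit delete call serves both the
// bulk and the single-id variants.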
mapper.setIds(List.of(id)); + mapper.setTableName(tableName); + return Observable.fromSingle(this.airtableService.delete(mapper)); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java new file mode 100644 index 000000000..03e52f6da --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java @@ -0,0 +1,47 @@ +package com.edgechain.lib.endpoint.impl.llm; + +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.Llama2Service; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +public class LLamaQuickstart extends Endpoint { + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); + private final ModelMapper modelMapper = new ModelMapper(); + private String query; + + public LLamaQuickstart() {} + + public LLamaQuickstart(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + } + + public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) { + super(url, retryPolicy); + this.query = query; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public Observable chatCompletion(String query, ArkRequest arkRequest) { + LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class); + mapper.setQuery(query); + return chatCompletion(mapper, arkRequest); + } + + private Observable chatCompletion( + LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) { + return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart)); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java new file mode 100644 index 000000000..96ba9912d --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java @@ -0,0 +1,180 @@ +package com.edgechain.lib.endpoint.impl.llm; + +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; +import com.edgechain.lib.request.ArkRequest; +import com.edgechain.lib.retrofit.Llama2Service; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.rxjava.retry.RetryPolicy; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.reactivex.rxjava3.core.Observable; +import org.json.JSONObject; +import org.modelmapper.ModelMapper; +import retrofit2.Retrofit; + +import java.util.List; +import java.util.Objects; + +public class Llama2Endpoint extends Endpoint { + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); + + private final ModelMapper modelMapper = new ModelMapper(); + + private String inputs; + private JSONObject parameters; + private Double temperature; + + @JsonProperty("top_k") + private Integer topK; + + @JsonProperty("top_p") + private Double topP; + + 
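// Hugging Face text-generation style sampling parameters; @JsonProperty maps each field to the
// snake_case name (top_k, top_p, do_sample, max_new_tokens, repetition_penalty) expected by the
// inference endpoint.
/*
 * Hedged usage sketch (editor's addition, not part of the diff); the URL is a placeholder and the
 * FixedDelay retry policy is assumed to be the library's stock implementation. The two-argument
 * constructor defaults to temperature 0.7 and max_new_tokens 512:
 *
 *   Llama2Endpoint llama2 =
 *       new Llama2Endpoint("https://<your-inference-host>/generate", new FixedDelay(3, 5, TimeUnit.SECONDS));
 *   new EdgeChain<>(llama2.chatCompletion("Summarise RRF in one line.", "llama-chain", arkRequest))
 *       .getArkResponse();
 */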
@JsonProperty("do_sample") + private Boolean doSample; + + @JsonProperty("max_new_tokens") + private Integer maxNewTokens; + + @JsonProperty("repetition_penalty") + private Double repetitionPenalty; + + private List stop; + private String chainName; + private String callIdentifier; + + public Llama2Endpoint() {} + + public Llama2Endpoint( + String url, + RetryPolicy retryPolicy, + Double temperature, + Integer topK, + Double topP, + Boolean doSample, + Integer maxNewTokens, + Double repetitionPenalty, + List stop) { + super(url, retryPolicy); + this.temperature = temperature; + this.topK = topK; + this.topP = topP; + this.doSample = doSample; + this.maxNewTokens = maxNewTokens; + this.repetitionPenalty = repetitionPenalty; + this.stop = stop; + } + + public Llama2Endpoint(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + this.temperature = 0.7; + this.maxNewTokens = 512; + } + + public String getInputs() { + return inputs; + } + + public void setInputs(String inputs) { + this.inputs = inputs; + } + + public JSONObject getParameters() { + return parameters; + } + + public void setParameters(JSONObject parameters) { + this.parameters = parameters; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Boolean getDoSample() { + return doSample; + } + + public void setDoSample(Boolean doSample) { + this.doSample = doSample; + } + + public Integer getMaxNewTokens() { + return maxNewTokens; + } + + public void setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + } + + public Double getRepetitionPenalty() { + return repetitionPenalty; + } + + public void setRepetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public String getChainName() { + return chainName; + } + + public void setChainName(String chainName) { + this.chainName = chainName; + } + + public String getCallIdentifier() { + return callIdentifier; + } + + public void setCallIdentifier(String callIdentifier) { + this.callIdentifier = callIdentifier; + } + + public Observable> chatCompletion( + String inputs, String chainName, ArkRequest arkRequest) { + + Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class); + mapper.setInputs(inputs); + mapper.setChainName(chainName); + return chatCompletion(mapper, arkRequest); + } + + private Observable> chatCompletion( + Llama2Endpoint mapper, ArkRequest arkRequest) { + + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); + + return Observable.fromSingle(this.llama2Service.chatCompletion(mapper)); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java similarity index 50% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java index 
d21685899..19d571ef8 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/OpenAiEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/OpenAiChatEndpoint.java @@ -1,22 +1,25 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.llm; import com.edgechain.lib.configuration.context.ApplicationContextHolder; -import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.Endpoint; +import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.openai.request.ChatMessage; +import com.edgechain.lib.openai.response.CompletionResponse; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.client.OpenAiStreamService; import com.edgechain.lib.retrofit.OpenAiService; -import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.openai.response.ChatCompletionResponse; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; import com.edgechain.lib.rxjava.retry.RetryPolicy; import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; import retrofit2.Retrofit; import java.util.List; +import java.util.Map; import java.util.Objects; -public class OpenAiEndpoint extends Endpoint { +public class OpenAiChatEndpoint extends Endpoint { private final OpenAiStreamService openAiStreamService = ApplicationContextHolder.getContext().getBean(OpenAiStreamService.class); @@ -24,15 +27,23 @@ public class OpenAiEndpoint extends Endpoint { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final OpenAiService openAiService = retrofit.create(OpenAiService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String orgId; private String model; - private String role; - private Double temperature; + private Double temperature; + private List chatMessages; private Boolean stream; + private Double topP; + private Integer n; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private Map logitBias; + private String user; - /** Getter Fields ** */ - private List chatMessages; + private String role; private String input; @@ -41,26 +52,24 @@ public class OpenAiEndpoint extends Endpoint { private String callIdentifier; - public OpenAiEndpoint() {} + private JsonnetLoader jsonnetLoader; - public OpenAiEndpoint(String url, String apiKey, String model) { - super(url, apiKey, null); - this.model = model; - } + public OpenAiChatEndpoint() {} - public OpenAiEndpoint(String url, String apiKey, String model, RetryPolicy retryPolicy) { - super(url, apiKey, retryPolicy); + public OpenAiChatEndpoint(String url, String apiKey, String model) { + super(url, apiKey, null); this.model = model; } - public OpenAiEndpoint( - String url, String apiKey, String model, String role, RetryPolicy retryPolicy) { + // For Embeddings.... 
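/*
 * Hedged usage sketch (editor's addition, not part of the diff); the API key, model name and chain
 * name are placeholders, and the chat-completions URL is the standard OpenAI one:
 *
 *   OpenAiChatEndpoint gpt = new OpenAiChatEndpoint(
 *       "https://api.openai.com/v1/chat/completions", "<OPENAI_API_KEY>", "gpt-3.5-turbo");
 *   gpt.setRole("user");
 *   gpt.setTemperature(0.7);
 *   new EdgeChain<>(gpt.chatCompletion("Hello!", "demo-chain", arkRequest)).getArkResponse();
 */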
+ public OpenAiChatEndpoint( + String url, String apiKey, String orgId, String model, RetryPolicy retryPolicy) { super(url, apiKey, retryPolicy); + this.orgId = orgId; this.model = model; - this.role = role; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String model, @@ -73,7 +82,7 @@ public OpenAiEndpoint( this.temperature = temperature; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String model, String role, Double temperature, Boolean stream) { super(url, apiKey, null); this.model = model; @@ -82,22 +91,22 @@ public OpenAiEndpoint( this.stream = stream; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, + String orgId, String model, String role, Double temperature, - Boolean stream, RetryPolicy retryPolicy) { super(url, apiKey, retryPolicy); this.model = model; this.role = role; this.temperature = temperature; - this.stream = stream; + this.orgId = orgId; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, @@ -113,7 +122,7 @@ public OpenAiEndpoint( this.stream = stream; } - public OpenAiEndpoint( + public OpenAiChatEndpoint( String url, String apiKey, String orgId, @@ -130,14 +139,22 @@ public OpenAiEndpoint( this.stream = stream; } - public String getModel() { - return model; - } - public String getRole() { return role; } + public void setRole(String role) { + this.role = role; + } + + public void setInput(String input) { + this.input = input; + } + + public String getModel() { + return model; + } + public Double getTemperature() { return temperature; } @@ -158,10 +175,6 @@ public void setModel(String model) { this.model = model; } - public void setRole(String role) { - this.role = role; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } @@ -170,14 +183,86 @@ public void setStream(Boolean stream) { this.stream = stream; } - public List getChatMessages() { - return chatMessages; + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public List getStop() { + return stop; } public String getInput() { return input; } + public void setStop(List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Map getLogitBias() { + return logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + public void setCallIdentifier(String callIdentifier) { + this.callIdentifier = callIdentifier; + } + + public List getChatMessages() { + return chatMessages; + } + + public void setChatMessages(List chatMessages) { + this.chatMessages = chatMessages; + } + + public void setJsonnetLoader(JsonnetLoader jsonnetLoader) { + this.jsonnetLoader = jsonnetLoader; + } + + public JsonnetLoader getJsonnetLoader() { + return jsonnetLoader; + } + public String getChainName() { return chainName; } @@ -193,41 +278,55 @@ public String getCallIdentifier() { public Observable chatCompletion( 
String input, String chainName, ArkRequest arkRequest) { - this.chatMessages = List.of(new ChatMessage(this.role, input)); - this.chainName = chainName; + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(List.of(new ChatMessage(this.role, input))); + mapper.setChainName(chainName); - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } + return chatCompletion(mapper, arkRequest); + } - if (Objects.nonNull(this.getStream()) && this.getStream()) - return this.openAiStreamService - .chatCompletion(this) - .map( - chatResponse -> { - if (!Objects.isNull(chatResponse.getChoices().get(0).getFinishReason())) { - chatResponse.getChoices().get(0).getMessage().setContent(""); - return chatResponse; - } else return chatResponse; - }); - else return Observable.fromSingle(this.openAiService.chatCompletion(this)); + public Observable chatCompletion( + String input, String chainName, JsonnetLoader loader, ArkRequest arkRequest) { + + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(List.of(new ChatMessage(this.role, input))); + mapper.setChainName(chainName); + mapper.setJsonnetLoader(loader); + + return chatCompletion(mapper, arkRequest); } public Observable chatCompletion( List chatMessages, String chainName, ArkRequest arkRequest) { + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(chatMessages); + mapper.setChainName(chainName); + return chatCompletion(mapper, arkRequest); + } - this.chainName = chainName; - this.chatMessages = chatMessages; + public Observable chatCompletion( + List chatMessages, + String chainName, + JsonnetLoader loader, + ArkRequest arkRequest) { + + OpenAiChatEndpoint mapper = modelMapper.map(this, OpenAiChatEndpoint.class); + mapper.setChatMessages(chatMessages); + mapper.setChainName(chainName); + mapper.setJsonnetLoader(loader); + + return chatCompletion(mapper, arkRequest); + } + + private Observable chatCompletion( + OpenAiChatEndpoint mapper, ArkRequest arkRequest) { - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } else { - this.callIdentifier = "URI wasn't provided"; - } + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); - if (Objects.nonNull(this.getStream()) && this.getStream()) + if (Objects.nonNull(getStream()) && getStream()) return this.openAiStreamService - .chatCompletion(this) + .chatCompletion(mapper) .map( chatResponse -> { if (!Objects.isNull(chatResponse.getChoices().get(0).getFinishReason())) { @@ -235,21 +334,14 @@ public Observable chatCompletion( return chatResponse; } else return chatResponse; }); - else return Observable.fromSingle(this.openAiService.chatCompletion(this)); + else return Observable.fromSingle(this.openAiService.chatCompletion(mapper)); } - public Observable embeddings(String input, ArkRequest arkRequest) { - this.input = input; // set Input - - if (Objects.nonNull(arkRequest)) { - this.callIdentifier = arkRequest.getRequestURI(); - } + public Observable completion(String input, ArkRequest arkRequest) { + if (Objects.nonNull(arkRequest)) this.callIdentifier = arkRequest.getRequestURI(); + else this.callIdentifier = "URI wasn't provided"; - return Observable.fromSingle( - openAiService - .embeddings(this) - .map( - embeddingResponse -> - new WordEmbeddings(input, 
embeddingResponse.getData().get(0).getEmbedding()))); + this.input = input; + return Observable.fromSingle(this.openAiService.completion(this)); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java similarity index 96% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java index 7ff111733..22acc7921 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/SupabaseEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/supabase/SupabaseEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.supabase; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.retrofit.SupabaseService; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java similarity index 78% rename from Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java rename to Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java index cc9e95871..f03cf381d 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/WikiEndpoint.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/wiki/WikiEndpoint.java @@ -1,4 +1,4 @@ -package com.edgechain.lib.endpoint.impl; +package com.edgechain.lib.endpoint.impl.wiki; import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.retrofit.WikiService; @@ -6,6 +6,7 @@ import com.edgechain.lib.rxjava.retry.RetryPolicy; import com.edgechain.lib.wiki.response.WikiResponse; import io.reactivex.rxjava3.core.Observable; +import org.modelmapper.ModelMapper; import retrofit2.Retrofit; public class WikiEndpoint extends Endpoint { @@ -13,13 +14,14 @@ public class WikiEndpoint extends Endpoint { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); private final WikiService wikiService = retrofit.create(WikiService.class); + private ModelMapper modelMapper = new ModelMapper(); + private String input; public WikiEndpoint() {} public WikiEndpoint(RetryPolicy retryPolicy) { super(retryPolicy); - this.input = input; } public String getInput() { @@ -31,7 +33,8 @@ public void setInput(String input) { } public Observable getPageContent(String input) { - this.input = input; - return Observable.fromSingle(this.wikiService.getPageContent(this)); + WikiEndpoint mapper = modelMapper.map(this, WikiEndpoint.class); + mapper.setInput(input); + return Observable.fromSingle(this.wikiService.getPageContent(mapper)); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/ProjectRunner.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/ProjectRunner.java index 488460574..f2f8669cd 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/ProjectRunner.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/ProjectRunner.java @@ -1,11 +1,19 @@ package com.edgechain.lib.flyfly.commands.run; -import 
static java.nio.file.StandardWatchEventKinds.*; - import com.edgechain.lib.flyfly.utils.ProjectSetup; import jakarta.annotation.PreDestroy; -import java.io.*; -import java.nio.file.*; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.FileSystems; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.WatchEvent; +import java.nio.file.WatchKey; +import java.nio.file.WatchService; import java.nio.file.attribute.BasicFileAttributes; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -15,11 +23,14 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.zeroturnaround.exec.ProcessExecutor; +import static java.nio.file.StandardWatchEventKinds.ENTRY_CREATE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_DELETE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_MODIFY; @Component public class ProjectRunner { - private final Logger log = LoggerFactory.getLogger(this.getClass()); + private final Logger logger = LoggerFactory.getLogger(this.getClass()); @Autowired TestContainersStarter testContainersStarter; @Autowired ProjectSetup projectSetup; @@ -31,30 +42,35 @@ public class ProjectRunner { public void run() { try { - log.info("Configuring the project"); - log.info("Checking if initscript exists"); + logger.info("Configuring the project"); + logger.info("Checking if initscript exists"); if (!projectSetup.initscriptExists()) { - log.info("Initscript doesn't exist"); - log.info("Adding flyfly.gradle to initscripts"); + logger.info("Initscript doesn't exist"); + logger.info("Adding flyfly.gradle to initscripts"); projectSetup.addInitscript(); } projectSetup.addAutorouteJar(); allowInfrastructureServices = isDockerInstalled(); if (allowInfrastructureServices) checkAndConfigureServices(); - log.debug("registering watcher for src files changes"); + logger.debug("registering watcher for src files changes"); registerFilesWatcher(); - log.debug("registering watcher for build file changes"); + logger.debug("registering watcher for build file changes"); registerBuildFileWatcher(); - log.info("Starting the project"); + logger.info("Starting the project"); runTheProject(); loop(); + + } catch (InterruptedException ie) { + logger.warn("interrupted", ie); + Thread.currentThread().interrupt(); + } catch (Exception e) { - e.printStackTrace(); + logger.error("failed", e); } } boolean isDockerInstalled() throws IOException, InterruptedException { - log.info("Checking if docker is installed to allow infrastructure services"); + logger.info("Checking if docker is installed to allow infrastructure services"); int exitCode; try { String[] command; @@ -62,11 +78,19 @@ boolean isDockerInstalled() throws IOException, InterruptedException { else command = new String[] {"docker", "info"}; exitCode = new ProcessExecutor().command(command).start().getProcess().waitFor(); + + } catch (InterruptedException ie) { + logger.warn("interrupted", ie); + Thread.currentThread().interrupt(); + exitCode = -1; + } catch (Exception e) { + logger.error("failed", e); exitCode = -1; } + if (exitCode != 0) { - log.warn("Couldn't find docker. Disabling infrastructure services."); + logger.warn("Couldn't find docker. 
Disabling infrastructure services."); return false; } return true; @@ -82,9 +106,9 @@ void runTheProject() throws IOException { } void checkAndConfigureServices() throws IOException { - log.info("Checking if services are needed"); - // Set supportedDBGroupIds = - // Set.of("mysql", "com.mysql", "org.postgresql", "org.mariadb.jdbc"); + logger.info("Checking if services are needed"); + // Set supportedDBGroupIds = + // Set.of("mysql", "com.mysql", "org.postgresql", "org.mariadb.jdbc"); Set supportedDBGroupIds = Set.of("org.postgresql"); BufferedReader reader = new BufferedReader(new FileReader("build.gradle")); String line; @@ -96,12 +120,12 @@ void checkAndConfigureServices() throws IOException { if (start < 0 || end < 0) continue; String groupID = line.substring(start + 1, end); if (supportedDBGroupIds.contains(groupID)) { - if (!testContainersStarter.isServiesNeeded()) break; - log.info("Found : " + groupID); + if (!testContainersStarter.isServiceNeeded()) break; + logger.info("Found : {}", groupID); switch (groupID) { - // case "mysql", "com.mysql" -> testContainersStarter.startMySQL(); + // case "mysql", "com.mysql" -> testContainersStarter.startMySQL(); case "org.postgresql" -> testContainersStarter.startPostgreSQL(); - // case "org.mariadb.jdbc" -> testContainersStarter.startMariaDB(); + // case "org.mariadb.jdbc" -> testContainersStarter.startMariaDB(); } break; } @@ -167,7 +191,7 @@ boolean didBuildFileChange() throws InterruptedException { if (p.endsWith("build.gradle")) found = true; } key.reset(); - if (found) log.info("Detected build file change ..."); + if (found) logger.info("Detected build file change ..."); return found; } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarter.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarter.java index edea76f91..2ba68603a 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarter.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarter.java @@ -1,9 +1,12 @@ package com.edgechain.lib.flyfly.commands.run; import jakarta.annotation.PreDestroy; -import java.io.*; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; import java.nio.file.FileSystems; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; @@ -14,121 +17,152 @@ public class TestContainersStarter { private final Logger log = LoggerFactory.getLogger(this.getClass()); - private static final String dbName = "test"; - private static final String userName = "test"; - private static final String password = "test"; - // private MySQLContainer mysql; + private static final String DBNAME = "test"; + private static final String USERNAME = "test"; + private static final String PASSWORD = "test"; + // private MySQLContainer mysql; private PostgreSQLContainer postgre; - // private MariaDBContainer mariaDB; - private String flyflyTempTag = "#flyfly_temp_property"; - - // public void startMySQL() throws IOException { - // if (mysql != null && mysql.isRunning()) return; - // log.info("Starting a temporary MySQL database."); - // mysql = - // new MySQLContainer<>(DockerImageName.parse("mysql:5.7")) - // .withDatabaseName(dbName) - // .withUsername(userName) - // .withPassword(password); - // 
mysql.addParameter("TC_MY_CNF", null); - // mysql.start(); - // log.info("Database started."); - // log.info("DB URL: " + mysql.getJdbcUrl()); - // log.info("DB Username: " + mysql.getUsername()); - // log.info("DB Password: " + mysql.getPassword()); - // addTempProperties(mysql.getJdbcUrl()); - // } + // private MariaDBContainer mariaDB; + static final String FLYFLYTEMPTAG = "#flyfly_temp_property"; + + private String propertiesPath; + + public TestContainersStarter() { + propertiesPath = buildPropertiesPath(); + } + + public String getPropertiesPath() { + return propertiesPath; + } + + public void setPropertiesPath(String propertiesPath) { + this.propertiesPath = propertiesPath; + } + + // public void startMySQL() throws IOException { + // if (mysql != null && mysql.isRunning()) return; + // log.info("Starting a temporary MySQL database."); + // mysql = + // new MySQLContainer<>(DockerImageName.parse("mysql:5.7")) + // .withDatabaseName(dbName) + // .withUsername(userName) + // .withPassword(password); + // mysql.addParameter("TC_MY_CNF", null); + // mysql.start(); + // log.info("Database started."); + // log.info("DB URL: " + mysql.getJdbcUrl()); + // log.info("DB Username: " + mysql.getUsername()); + // log.info("DB Password: " + mysql.getPassword()); + // addTempProperties(mysql.getJdbcUrl()); + // } public void startPostgreSQL() throws IOException { if (postgre != null && postgre.isRunning()) return; log.info("Starting a temporary PostgreSQL database."); postgre = new PostgreSQLContainer<>("postgres:14.5") - .withDatabaseName(dbName) - .withUsername(userName) - .withPassword(password); + .withDatabaseName(DBNAME) + .withUsername(USERNAME) + .withPassword(PASSWORD); postgre.addParameter("TC_MY_CNF", null); postgre.start(); + log.info("Database started."); - log.info("DB URL: " + postgre.getJdbcUrl()); - log.info("DB Username: " + postgre.getUsername()); - log.info("DB Password: " + postgre.getPassword()); + log.info("DB URL: {}", postgre.getJdbcUrl()); + log.info("DB Username: {}", postgre.getUsername()); + log.info("DB Password: {}", postgre.getPassword()); + addTempProperties(postgre.getJdbcUrl()); } - // public void startMariaDB() throws IOException { - // if (postgre != null && postgre.isRunning()) return; - // log.info("Starting a temporary MariaDB database."); - // mariaDB = - // new MariaDBContainer<>("mariadb:10.3.6") - // .withDatabaseName(dbName) - // .withUsername(userName) - // .withPassword(password); - // mariaDB.addParameter("TC_MY_CNF", null); - // mariaDB.start(); - // log.info("Database started."); - // log.info("DB URL: " + mariaDB.getJdbcUrl()); - // log.info("DB Username: " + mariaDB.getUsername()); - // log.info("DB Password: " + mariaDB.getPassword()); - // addTempProperties(mariaDB.getJdbcUrl()); - // } + public void stopPostgreSQL() throws IOException { + try { + removeTempProperties(); + } catch (IOException e) { + } + // if (mysql != null && mysql.isRunning()) mysql.close(); + log.info("Stopping temporary PostgreSQL database."); + if (postgre != null && postgre.isRunning()) postgre.close(); + // if (mariaDB != null && mariaDB.isRunning()) mariaDB.close(); + } + + // public void startMariaDB() throws IOException { + // if (postgre != null && postgre.isRunning()) return; + // log.info("Starting a temporary MariaDB database."); + // mariaDB = + // new MariaDBContainer<>("mariadb:10.3.6") + // .withDatabaseName(dbName) + // .withUsername(userName) + // .withPassword(password); + // mariaDB.addParameter("TC_MY_CNF", null); + // mariaDB.start(); + // 
log.info("Database started."); + // log.info("DB URL: " + mariaDB.getJdbcUrl()); + // log.info("DB Username: " + mariaDB.getUsername()); + // log.info("DB Password: " + mariaDB.getPassword()); + // addTempProperties(mariaDB.getJdbcUrl()); + // } public void addTempProperties(String url) throws IOException { log.info("Appending temporary DB configuration to application.properties"); - BufferedWriter writer = new BufferedWriter(new FileWriter(getPropertiesPath(), true)); - writer.newLine(); - writer.append(flyflyTempTag); - writer.newLine(); - writer.append("spring.datasource.url=" + url); - writer.newLine(); - writer.append(flyflyTempTag); - writer.newLine(); - writer.append("spring.datasource.username=" + userName); - writer.newLine(); - writer.append(flyflyTempTag); - writer.newLine(); - writer.append("spring.datasource.password=" + password); - writer.flush(); - writer.close(); + try (FileWriter fw = new FileWriter(propertiesPath, true); + BufferedWriter writer = new BufferedWriter(fw)) { + writer.newLine(); + writer.append(FLYFLYTEMPTAG); + writer.newLine(); + writer.append("spring.datasource.url=").append(url); + writer.newLine(); + writer.append(FLYFLYTEMPTAG); + writer.newLine(); + writer.append("spring.datasource.username=").append(USERNAME); + writer.newLine(); + writer.append(FLYFLYTEMPTAG); + writer.newLine(); + writer.append("spring.datasource.password=").append(PASSWORD); + writer.flush(); + } } public void removeTempProperties() throws IOException { - BufferedReader reader = new BufferedReader(new FileReader(getPropertiesPath())); - StringBuilder sb = new StringBuilder(); + log.info("Removing temporary DB configuration from application.properties"); boolean tempNotFound = true; - String line; - while ((line = reader.readLine()) != null) { - if (line.contains(flyflyTempTag)) { - tempNotFound = false; - reader.readLine(); - continue; + StringBuilder sb = new StringBuilder(); + try (FileReader fr = new FileReader(propertiesPath); + BufferedReader reader = new BufferedReader(fr)) { + String line; + while ((line = reader.readLine()) != null) { + if (line.contains(FLYFLYTEMPTAG)) { + tempNotFound = false; + reader.readLine(); // skip next line + continue; + } + sb.append(line).append("\n"); } - sb.append(line + "\n"); } - reader.close(); if (tempNotFound) return; - BufferedWriter writer = new BufferedWriter(new FileWriter(getPropertiesPath())); - writer.write(sb.toString()); - writer.flush(); - writer.close(); + try (FileWriter fw = new FileWriter(propertiesPath); + BufferedWriter writer = new BufferedWriter(fw)) { + writer.write(sb.toString()); + writer.flush(); + } } - public boolean isServiesNeeded() throws IOException { - BufferedReader reader = new BufferedReader(new FileReader(getPropertiesPath())); - String line; - String datafield = "spring.datasource.url"; - while ((line = reader.readLine()) != null) { - if (line.contains(datafield)) { - reader.close(); - return false; + public boolean isServiceNeeded() throws IOException { + final String datafield = "spring.datasource.url"; + try (FileReader fr = new FileReader(propertiesPath); + BufferedReader reader = new BufferedReader(fr)) { + String line; + while ((line = reader.readLine()) != null) { + if (line.contains(datafield)) { + return false; + } } } - reader.close(); return true; } - public String getPropertiesPath() { + private static String buildPropertiesPath() { String separator = FileSystems.getDefault().getSeparator(); return System.getProperty("user.dir") + separator @@ -144,11 +178,8 @@ public String 
getPropertiesPath() { @PreDestroy public void destroy() { try { - removeTempProperties(); + stopPostgreSQL(); } catch (IOException e) { } - // if (mysql != null && mysql.isRunning()) mysql.close(); - if (postgre != null && postgre.isRunning()) postgre.close(); - // if (mariaDB != null && mariaDB.isRunning()) mariaDB.close(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/utils/ProjectSetup.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/utils/ProjectSetup.java index d6b86789d..cf26a9e7d 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/utils/ProjectSetup.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/flyfly/utils/ProjectSetup.java @@ -25,7 +25,7 @@ public class ProjectSetup { private String glowrootDir = flyflyDir + separator + "glowroot"; public boolean initscriptExists() { - log.debug("Checking if flyfly.gradle exists in " + initscriptDir); + log.debug("Checking if flyfly.gradle exists in {}", initscriptDir); return new File(initscriptDir).exists(); } @@ -44,7 +44,7 @@ public void addAutorouteJar() throws IOException { } public boolean formatScriptExists() { - log.debug("Checking if format.gradle exists in " + formatScriptDir); + log.debug("Checking if format.gradle exists in {}", formatScriptDir); return new File(formatScriptDir).exists(); } @@ -55,7 +55,7 @@ public void addFormatScript() throws IOException { } public boolean glowrootAgentExists() { - log.debug("Checking if glowroot folder exists in " + glowrootDir); + log.debug("Checking if glowroot folder exists in {}", glowrootDir); return new File(glowrootDir).exists(); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java index adcfd1c65..6b5d84f75 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PineconeClient.java @@ -1,6 +1,6 @@ package com.edgechain.lib.index.client.impl; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.index.request.pinecone.PineconeUpsert; import com.edgechain.lib.embeddings.WordEmbeddings; import com.edgechain.lib.response.StringResponse; @@ -18,36 +18,51 @@ @Service public class PineconeClient { - private PineconeEndpoint endpoint; - private String namespace; + public EdgeChain upsert(PineconeEndpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { - public PineconeEndpoint getEndpoint() { - return endpoint; - } + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("Api-Key", endpoint.getApiKey()); + + PineconeUpsert pinecone = new PineconeUpsert(); + pinecone.setVectors(List.of(endpoint.getWordEmbedding())); + pinecone.setNamespace(getNamespace(endpoint)); + + HttpEntity entity = new HttpEntity<>(pinecone, headers); + + ResponseEntity response = + new RestTemplate() + .exchange(endpoint.getUrl(), HttpMethod.POST, entity, String.class); + + emitter.onNext(new StringResponse(response.getBody())); + emitter.onComplete(); - public void setEndpoint(PineconeEndpoint endpoint) { - this.endpoint = endpoint; + } catch (final 
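Aside: a minimal usage sketch of the reworked TestContainersStarter lifecycle. The properties path below is a placeholder (the real default is derived from the working directory), and Docker must be running locally so Testcontainers can pull postgres:14.5:

    void runWithTemporaryDatabase() throws IOException {
      TestContainersStarter starter = new TestContainersStarter();
      starter.setPropertiesPath("build/tmp/application.properties"); // placeholder path for this sketch

      if (starter.isServiceNeeded()) { // true only while no spring.datasource.url is configured
        starter.startPostgreSQL();     // boots the container and appends #flyfly_temp_property entries
      }
      // ... run the application against the temporary database ...
      starter.stopPostgreSQL();        // strips the tagged properties again and closes the container
    }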
Exception e) { + emitter.onError(e); + } + }), + endpoint); } - public EdgeChain upsert(WordEmbeddings wordEmbeddings) { + public EdgeChain batchUpsert(PineconeEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? "" - : endpoint.getNamespace(); - HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); headers.setContentType(MediaType.APPLICATION_JSON); headers.set("Api-Key", endpoint.getApiKey()); PineconeUpsert pinecone = new PineconeUpsert(); - pinecone.setVectors(List.of(wordEmbeddings)); - pinecone.setNamespace(namespace); + pinecone.setVectors(endpoint.getWordEmbeddingsList()); + pinecone.setNamespace(getNamespace(endpoint)); HttpEntity entity = new HttpEntity<>(pinecone, headers); @@ -65,18 +80,13 @@ public EdgeChain upsert(WordEmbeddings wordEmbeddings) { endpoint); } - public EdgeChain> query(WordEmbeddings wordEmbeddings, int topK) { + public EdgeChain> query(PineconeEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? "" - : endpoint.getNamespace(); - HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); headers.setContentType(MediaType.APPLICATION_JSON); @@ -86,9 +96,9 @@ public EdgeChain> query(WordEmbeddings wordEmbeddings, int Map payload = new LinkedHashMap<>(); payload.put("includeValues", true); payload.put("includeMetadata", false); - payload.put("vector", wordEmbeddings.getValues()); - payload.put("top_k", topK); - payload.put("namespace", this.namespace); + payload.put("vector", endpoint.getWordEmbedding().getValues()); + payload.put("top_k", endpoint.getTopK()); + payload.put("namespace", getNamespace(endpoint)); HttpEntity> entity = new HttpEntity<>(payload, headers); @@ -105,18 +115,13 @@ public EdgeChain> query(WordEmbeddings wordEmbeddings, int endpoint); } - public EdgeChain deleteAll() { + public EdgeChain deleteAll(PineconeEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? "" - : endpoint.getNamespace(); - HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); headers.setContentType(MediaType.APPLICATION_JSON); @@ -124,7 +129,7 @@ public EdgeChain deleteAll() { Map body = new HashMap<>(); body.put("deleteAll", true); - body.put("namespace", namespace); + body.put("namespace", getNamespace(endpoint)); HttpEntity> entity = new HttpEntity<>(body, headers); @@ -134,7 +139,7 @@ public EdgeChain deleteAll() { emitter.onNext( new StringResponse( "Word embeddings are successfully deleted for namespace:" - + this.namespace)); + + getNamespace(endpoint))); emitter.onComplete(); } catch (final Exception e) { @@ -158,4 +163,10 @@ private List parsePredict(String body) throws IOException { return words2VecList; } + + public String getNamespace(PineconeEndpoint endpoint) { + return (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) + ? 
"" + : endpoint.getNamespace(); + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java index e3652b284..44b3425b4 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/PostgresClient.java @@ -1,71 +1,214 @@ package com.edgechain.lib.index.client.impl; -import com.edgechain.lib.configuration.context.ApplicationContextHolder; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; -import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.repositories.PostgresClientMetadataRepository; import com.edgechain.lib.index.repositories.PostgresClientRepository; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import io.reactivex.rxjava3.core.Observable; -import org.springframework.stereotype.Service; +import java.math.BigDecimal; +import java.sql.Date; import java.sql.Timestamp; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import org.postgresql.util.PGobject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; @Service public class PostgresClient { - private PostgresEndpoint postgresEndpoint; - private String namespace; + private static final Logger logger = LoggerFactory.getLogger(PostgresClient.class); + + private static final TypeReference> FLOAT_TYPE_REF = new TypeReference<>() {}; - private final PostgresClientRepository repository = - ApplicationContextHolder.getContext().getBean(PostgresClientRepository.class); + @Autowired private PostgresClientRepository repository; - public PostgresEndpoint getPostgresEndpoint() { - return postgresEndpoint; + @Autowired private PostgresClientMetadataRepository metadataRepository; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + public EdgeChain createTable(PostgresEndpoint postgresEndpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + this.repository.createTable(postgresEndpoint); + emitter.onNext(new StringResponse("Table: " + postgresEndpoint.getTableName())); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + })); } - public void setPostgresEndpoint(PostgresEndpoint postgresEndpoint) { - this.postgresEndpoint = postgresEndpoint; + public EdgeChain createMetadataTable(PostgresEndpoint postgresEndpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + this.metadataRepository.createTable(postgresEndpoint); + emitter.onNext( + new StringResponse( + "Table: " + postgresEndpoint.getMetadataTableNames().get(0))); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + })); } - public String getNamespace() { - return namespace; + public EdgeChain> 
batchUpsert(PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + // Upsert Embeddings + List strings = + this.repository.batchUpsertEmbeddings( + postgresEndpoint.getTableName(), + postgresEndpoint.getWordEmbeddingsList(), + postgresEndpoint.getFilename(), + getNamespace(postgresEndpoint), + postgresEndpoint.getPostgresLanguage()); + + List stringResponseList = + strings.stream().map(StringResponse::new).toList(); + + emitter.onNext(stringResponseList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); } - public void setNamespace(String namespace) { - this.namespace = namespace; + public EdgeChain upsert(PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + // Upsert Embeddings + String embeddingId = + this.repository.upsertEmbeddings( + postgresEndpoint.getTableName(), + postgresEndpoint.getWordEmbedding(), + postgresEndpoint.getFilename(), + getNamespace(postgresEndpoint), + postgresEndpoint.getPostgresLanguage()); + + emitter.onNext(new StringResponse(embeddingId)); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); } - public EdgeChain upsert(WordEmbeddings wordEmbeddings) { + public EdgeChain insertMetadata(PostgresEndpoint postgresEndpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { + String metadata = postgresEndpoint.getMetadata(); + String input = metadata.replace("'", ""); - this.namespace = - (Objects.isNull(postgresEndpoint.getNamespace()) - || postgresEndpoint.getNamespace().isEmpty()) - ? "knowledge" - : postgresEndpoint.getNamespace(); + String metadataId = + this.metadataRepository.insertMetadata( + postgresEndpoint.getTableName(), + postgresEndpoint.getMetadataTableNames().get(0), + input, + postgresEndpoint.getDocumentDate()); - // Create Table - this.repository.createTable(postgresEndpoint); + emitter.onNext(new StringResponse(metadataId)); + emitter.onComplete(); - String input = wordEmbeddings.getId().replaceAll("'", ""); + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } - // Upsert Embeddings - this.repository.upsertEmbeddings( + public EdgeChain> batchInsertMetadata(PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + // Insert metadata + List strings = + this.metadataRepository.batchInsertMetadata( + postgresEndpoint.getTableName(), + postgresEndpoint.getMetadataTableNames().get(0), + postgresEndpoint.getMetadataList()); + + List stringResponseList = + strings.stream().map(StringResponse::new).toList(); + + emitter.onNext(stringResponseList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain insertIntoJoinTable(PostgresEndpoint postgresEndpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + this.metadataRepository.insertIntoJoinTable(postgresEndpoint); + + emitter.onNext(new StringResponse("Inserted")); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain batchInsertIntoJoinTable(PostgresEndpoint postgresEndpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + this.metadataRepository.batchInsertIntoJoinTable( postgresEndpoint.getTableName(), - input, - postgresEndpoint.getFilename(), - 
wordEmbeddings, - this.namespace); + postgresEndpoint.getMetadataTableNames().get(0), + postgresEndpoint.getIdList(), + postgresEndpoint.getMetadataId()); - emitter.onNext(new StringResponse("Upserted")); + emitter.onNext(new StringResponse("Inserted")); emitter.onComplete(); } catch (final Exception e) { @@ -75,39 +218,290 @@ public EdgeChain upsert(WordEmbeddings wordEmbeddings) { postgresEndpoint); } - public EdgeChain> query( - WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK, int probes) { + public EdgeChain> query(PostgresEndpoint postgresEndpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { - this.namespace = - (Objects.isNull(postgresEndpoint.getNamespace()) - || postgresEndpoint.getNamespace().isEmpty()) - ? "knowledge" - : postgresEndpoint.getNamespace(); + List wordEmbeddingsList = new ArrayList<>(); + + List> embeddings = + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); List> rows = this.repository.query( postgresEndpoint.getTableName(), - this.namespace, - probes, - metric, - wordEmbeddings, - topK); + getNamespace(postgresEndpoint), + postgresEndpoint.getProbes(), + postgresEndpoint.getMetric(), + embeddings, + postgresEndpoint.getTopK(), + postgresEndpoint.getUpperLimit()); + + for (Map row : rows) { + + PostgresWordEmbeddings val = new PostgresWordEmbeddings(); + val.setId(Objects.nonNull(row.get("id")) ? row.get("id").toString() : null); + val.setRawText( + Objects.nonNull(row.get("raw_text")) ? (String) row.get("raw_text") : null); + val.setFilename( + Objects.nonNull(row.get("filename")) ? (String) row.get("filename") : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) ? (String) row.get("namespace") : null); + + val.setScore( + Objects.nonNull(row.get("score")) ? (Double) row.get("score") : null); + + PGobject pgObject = (PGobject) row.get("embedding"); + String jsonString = pgObject.getValue(); + List values = objectMapper.readerFor(FLOAT_TYPE_REF).readValue(jsonString); + val.setValues(values); + + wordEmbeddingsList.add(val); + } + emitter.onNext(wordEmbeddingsList); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain> queryRRF(PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { List wordEmbeddingsList = new ArrayList<>(); + List> embeddings = + postgresEndpoint.getWordEmbeddingsList().stream() + .map(WordEmbeddings::getValues) + .toList(); - for (Map row : rows) { + List> rows = + this.repository.queryRRF( + postgresEndpoint.getTableName(), + getNamespace(postgresEndpoint), + postgresEndpoint.getMetadataTableNames().get(0), + embeddings, + postgresEndpoint.getTextWeight(), + postgresEndpoint.getSimilarityWeight(), + postgresEndpoint.getDateWeight(), + postgresEndpoint.getSearchQuery(), + postgresEndpoint.getPostgresLanguage(), + postgresEndpoint.getProbes(), + postgresEndpoint.getMetric(), + postgresEndpoint.getTopK(), + postgresEndpoint.getUpperLimit(), + postgresEndpoint.getOrderRRFBy()); + + for (Map row : rows) { PostgresWordEmbeddings val = new PostgresWordEmbeddings(); - val.setId((String) row.get("id")); + val.setId(Objects.nonNull(row.get("id")) ? row.get("id").toString() : null); + val.setRawText( + Objects.nonNull(row.get("raw_text")) ? 
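Aside: a rough sketch of how the new PostgresClient pieces fit together for ingestion, calling the repositories added later in this patch directly. Table, file and metadata names are placeholders, the repositories are assumed to be Spring-injected, and table creation (createTable/createMetadataTable) is omitted because it needs a fully configured PostgresEndpoint:

    @Autowired private PostgresClientRepository repository;
    @Autowired private PostgresClientMetadataRepository metadataRepository;

    void ingest(List<WordEmbeddings> chunks) {
      // 1. Upsert the chunk embeddings; the generated row ids come back.
      List<String> ids =
          repository.batchUpsertEmbeddings(
              "speakers_vectors", chunks, "talk.pdf", "knowledge", PostgresLanguage.ENGLISH);

      // 2. Store one metadata row (here a title plus a 'Month DD, YYYY' date) in the companion table.
      String metadataId =
          metadataRepository.insertMetadata(
              "speakers_vectors", "title_metadata", "GenAI keynote", "August 01, 2023");

      // 3. Link every embedding row to that metadata row through the *_join_* table.
      metadataRepository.batchInsertIntoJoinTable(
          "speakers_vectors", "title_metadata", ids, metadataId);
    }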
(String) row.get("raw_text") : null); + + val.setFilename( + Objects.nonNull(row.get("filename")) ? (String) row.get("filename") : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) ? (String) row.get("namespace") : null); + + BigDecimal bigDecimal = + Objects.nonNull(row.get("rrf_score")) + ? (BigDecimal) row.get("rrf_score") + : null; + val.setScore(bigDecimal.doubleValue()); + + if (postgresEndpoint.getMetadataTableNames().get(0).contains("title")) { + val.setTitleMetadata( + Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null); + } else { + val.setMetadata( + Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null); + } + Date documentDate = + Objects.nonNull(row.get("document_date")) + ? (Date) row.get("document_date") + : null; + val.setDocumentDate(documentDate.toString()); + + wordEmbeddingsList.add(val); + } + emitter.onNext(wordEmbeddingsList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain> queryWithMetadata( + PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + List wordEmbeddingsList = new ArrayList<>(); + if (postgresEndpoint.getMetadataTableNames() != null) { + try { + List metadataTableNames = postgresEndpoint.getMetadataTableNames(); + int numberOfMetadataTables = metadataTableNames.size(); + + /* + * This map will store the pairs + * We need to extract the title info from another metadata table. + * So instead of having extra PostgresWordEmbeddings objects for the title info + * We can store the title info in a map corresponding to the id key of the embeddings table + * Then after the loop is over we can inject the title field in the correct PostgresWordEmbeddings object by using the id key. + */ + Map titleMetadataMap = new HashMap<>(); + Map dateMetadataMap = new HashMap<>(); + for (String metadataTableName : metadataTableNames) { + List> rows = + this.metadataRepository.queryWithMetadata( + postgresEndpoint.getTableName(), + metadataTableName, + getNamespace(postgresEndpoint), + postgresEndpoint.getProbes(), + postgresEndpoint.getMetric(), + postgresEndpoint.getWordEmbedding().getValues(), + postgresEndpoint.getTopK()); + + // To filter out duplicate context chunks + Set contextChunkIds = new HashSet<>(); + for (Map row : rows) { + String metadataId = row.get("metadata_id").toString(); + if (!metadataTableName.contains("title_metadata") + && contextChunkIds.contains(metadataId)) continue; + + PostgresWordEmbeddings val = new PostgresWordEmbeddings(); + final String idStr = + Objects.nonNull(row.get("id")) ? row.get("id").toString() : null; + val.setId(idStr); + val.setRawText( + Objects.nonNull(row.get("raw_text")) + ? (String) row.get("raw_text") + : null); + val.setFilename( + Objects.nonNull(row.get("filename")) + ? (String) row.get("filename") + : null); + val.setTimestamp( + Objects.nonNull(row.get("timestamp")) + ? ((Timestamp) row.get("timestamp")).toLocalDateTime() + : null); + val.setNamespace( + Objects.nonNull(row.get("namespace")) + ? (String) row.get("namespace") + : null); + val.setScore( + Objects.nonNull(row.get("score")) ? 
(Double) row.get("score") : null); + + // Add metadata fields in response + if (metadataTableName.contains("title_metadata")) { + titleMetadataMap.put(idStr, (String) row.get("metadata")); + dateMetadataMap.put(idStr, (String) row.get("document_date")); + + // For checking if only one metadata table is present which is the title + // table + if (numberOfMetadataTables > 1) continue; + } else { + val.setMetadata((String) row.get("metadata")); + } + contextChunkIds.add(metadataId); + wordEmbeddingsList.add(val); + } + + // Insert the title and date fields into their respective + // PostgresWordEmbeddings + for (PostgresWordEmbeddings wordEmbedding : wordEmbeddingsList) { + String id = wordEmbedding.getId(); + if (titleMetadataMap.containsKey(id)) { + wordEmbedding.setTitleMetadata(titleMetadataMap.get(id)); + } + if (dateMetadataMap.containsKey(id)) { + wordEmbedding.setDocumentDate(dateMetadataMap.get(id)); + } + } + } + } catch (Exception e) { + logger.warn("ignored query error", e); + } + } + + emitter.onNext(wordEmbeddingsList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain> getAllChunks(PostgresEndpoint postgresEndpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + List wordEmbeddingsList = new ArrayList<>(); + List> rows = this.repository.getAllChunks(postgresEndpoint); + for (Map row : rows) { + PostgresWordEmbeddings val = new PostgresWordEmbeddings(); + val.setId(row.get("id").toString()); val.setRawText((String) row.get("raw_text")); val.setFilename((String) row.get("filename")); - val.setTimestamp(((Timestamp) row.get("timestamp")).toLocalDateTime()); - val.setNamespace((String) row.get("namespace")); - val.setScore((Double) row.get("score")); + PGobject pgObject = (PGobject) row.get("embedding"); + String jsonString = pgObject.getValue(); + List values = objectMapper.readerFor(FLOAT_TYPE_REF).readValue(jsonString); + val.setValues(values); + wordEmbeddingsList.add(val); + } + emitter.onNext(wordEmbeddingsList); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + }), + postgresEndpoint); + } + + public EdgeChain> getSimilarMetadataChunk( + PostgresEndpoint postgresEndpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + List wordEmbeddingsList = new ArrayList<>(); + List> rows = + this.metadataRepository.getSimilarMetadataChunk( + postgresEndpoint.getTableName(), + postgresEndpoint.getMetadataTableNames().get(0), + postgresEndpoint.getEmbeddingChunk()); + for (Map row : rows) { + + PostgresWordEmbeddings val = new PostgresWordEmbeddings(); + val.setMetadataId(row.get("metadata_id").toString()); + val.setMetadata((String) row.get("metadata")); wordEmbeddingsList.add(val); } @@ -122,23 +516,17 @@ public EdgeChain> query( postgresEndpoint); } - public EdgeChain deleteAll() { + public EdgeChain deleteAll(PostgresEndpoint postgresEndpoint) { return new EdgeChain<>( Observable.create( emitter -> { - this.namespace = - (Objects.isNull(postgresEndpoint.getNamespace()) - || postgresEndpoint.getNamespace().isEmpty()) - ? 
"knowledge" - : postgresEndpoint.getNamespace(); - + String namespace = getNamespace(postgresEndpoint); try { - this.repository.deleteAll(postgresEndpoint.getTableName(), this.namespace); + this.repository.deleteAll(postgresEndpoint.getTableName(), namespace); emitter.onNext( new StringResponse( - "Word embeddings are successfully deleted for namespace:" - + this.namespace)); + "Word embeddings are successfully deleted for namespace:" + namespace)); emitter.onComplete(); } catch (final Exception e) { emitter.onError(e); @@ -146,4 +534,11 @@ public EdgeChain deleteAll() { }), postgresEndpoint); } + + private String getNamespace(PostgresEndpoint postgresEndpoint) { + return (Objects.isNull(postgresEndpoint.getNamespace()) + || postgresEndpoint.getNamespace().isEmpty()) + ? "knowledge" + : postgresEndpoint.getNamespace(); + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java index 7f1c3790b..c436e1033 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/client/impl/RedisClient.java @@ -1,7 +1,7 @@ package com.edgechain.lib.index.client.impl; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.enums.RedisDistanceMetric; import com.edgechain.lib.index.responses.RedisDocument; import com.edgechain.lib.index.responses.RedisProperty; @@ -16,7 +16,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.*; import redis.clients.jedis.search.*; import java.util.*; @@ -30,84 +30,105 @@ public class RedisClient { + " local res = redis.call('del', k)" + " end"; - private RedisEndpoint endpoint; - - private String indexName; - private String namespace; - - public RedisEndpoint getEndpoint() { - return endpoint; - } - - public void setEndpoint(RedisEndpoint endpoint) { - this.endpoint = endpoint; - } - private final Logger logger = LoggerFactory.getLogger(this.getClass()); @Autowired private JedisPooled jedisPooled; - public EdgeChain upsert( - WordEmbeddings words2Vec, int dimension, RedisDistanceMetric metric) { - + public EdgeChain createIndex(RedisEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { try { + this.createSearchIndex( + getNamespace(endpoint), + endpoint.getIndexName(), + endpoint.getDimensions(), + endpoint.getMetric()); + emitter.onNext(new StringResponse("Created Index ~ ")); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + })); + } - this.indexName = endpoint.getIndexName(); - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? 
"knowledge" - : endpoint.getNamespace(); + public EdgeChain upsert(RedisEndpoint endpoint) { - this.createSearchIndex(dimension, RedisDistanceMetric.getDistanceMetric(metric)); + return new EdgeChain<>( + Observable.create( + emitter -> { + try (Jedis jedis = new Jedis(jedisPooled.getPool().getResource())) { Map map = new HashMap<>(); - map.put("id".getBytes(), words2Vec.getId().getBytes()); + map.put("id".getBytes(), endpoint.getWordEmbedding().getId().getBytes()); map.put( "values".getBytes(), - FloatUtils.toByteArray(FloatUtils.toFloatArray(words2Vec.getValues()))); + FloatUtils.toByteArray( + FloatUtils.toFloatArray(endpoint.getWordEmbedding().getValues()))); long v = - jedisPooled.hset((this.namespace + ":" + words2Vec.getId()).getBytes(), map); - - jedisPooled.getPool().returnResource(jedisPooled.getPool().getResource()); + jedis.hset( + (getNamespace(endpoint) + ":" + endpoint.getWordEmbedding().getId()) + .getBytes(), + map); emitter.onNext(new StringResponse("Created ~ " + v)); emitter.onComplete(); } catch (Exception ex) { - jedisPooled.getPool().returnBrokenResource(jedisPooled.getPool().getResource()); emitter.onError(ex); } }), endpoint); } - public EdgeChain> query(WordEmbeddings words2Vec, int topK) { - + public EdgeChain batchUpsert(RedisEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { - try { + try (Jedis jedis = new Jedis(jedisPooled.getPool().getResource())) { + + Pipeline pipeline = jedis.pipelined(); + + for (WordEmbeddings w : endpoint.getWordEmbeddingsList()) { + Map map = new HashMap<>(); + map.put("id".getBytes(), w.getId().getBytes()); + map.put( + "values".getBytes(), + FloatUtils.toByteArray(FloatUtils.toFloatArray(w.getValues()))); + + pipeline.hmset((getNamespace(endpoint) + ":" + w.getId()).getBytes(), map); + } - this.indexName = endpoint.getIndexName(); - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? 
"knowledge" - : endpoint.getNamespace(); + pipeline.sync(); + emitter.onNext(new StringResponse("Batch Processing Completed")); + emitter.onComplete(); + } catch (Exception ex) { + + emitter.onError(ex); + } + }), + endpoint); + } + + public EdgeChain> query(RedisEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { Query query = new Query("*=>[KNN $k @values $values]") .addParam( "values", - FloatUtils.toByteArray(FloatUtils.toFloatArray(words2Vec.getValues()))) - .addParam("k", topK) + FloatUtils.toByteArray( + FloatUtils.toFloatArray(endpoint.getWordEmbedding().getValues()))) + .addParam("k", endpoint.getTopK()) .returnFields("id", "__values_score") .setSortBy("__values_score", false) .dialect(2); - SearchResult searchResult = jedisPooled.ftSearch(this.indexName, query); + SearchResult searchResult = jedisPooled.ftSearch(endpoint.getIndexName(), query); String body = JsonUtils.convertToString(searchResult); @@ -121,53 +142,46 @@ public EdgeChain> query(WordEmbeddings words2Vec, int topK) ArrayList properties = iterator.next().getProperties(); words2VecList.add( new WordEmbeddings( - properties.get(1).getId(), - String.valueOf(properties.get(0).get__values_score()))); + properties.get(1).getId(), properties.get(0).get__values_score())); } emitter.onNext(words2VecList); emitter.onComplete(); } catch (Exception ex) { + jedisPooled.getPool().returnBrokenResource(jedisPooled.getPool().getResource()); emitter.onError(ex); } }), endpoint); } - public EdgeChain deleteByPattern(String pattern) { + public EdgeChain deleteByPattern(RedisEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { - try { + try (Jedis jedis = new Jedis(jedisPooled.getPool().getResource())) { - this.indexName = endpoint.getIndexName(); - this.namespace = - (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) - ? 
"knowledge" - : endpoint.getNamespace(); - - jedisPooled.eval(String.format(REDIS_DELETE_SCRIPT_IN_LUA, pattern)); - - jedisPooled.getPool().returnResource(jedisPooled.getPool().getResource()); + jedis.eval(String.format(REDIS_DELETE_SCRIPT_IN_LUA, endpoint.getPattern())); emitter.onNext( new StringResponse( - "Word embeddings are successfully deleted for pattern:" + pattern)); + "Word embeddings are successfully deleted for pattern:" + + endpoint.getPattern())); emitter.onComplete(); } catch (Exception ex) { - jedisPooled.getPool().returnBrokenResource(jedisPooled.getPool().getResource()); emitter.onError(ex); } }), endpoint); } - private void createSearchIndex(int dimension, String metric) { + private void createSearchIndex( + String namespace, String indexName, int dimension, RedisDistanceMetric metric) { try { - Map map = jedisPooled.ftInfo(this.indexName); + Map map = jedisPooled.ftInfo(indexName); if (Objects.nonNull(map)) { return; } @@ -184,12 +198,18 @@ private void createSearchIndex(int dimension, String metric) { .addTextField("id", 1) .addVectorField("values", Schema.VectorField.VectorAlgo.HNSW, attributes); - IndexDefinition indexDefinition = new IndexDefinition().setPrefixes(this.namespace); + IndexDefinition indexDefinition = new IndexDefinition().setPrefixes(namespace); String ftCreate = jedisPooled.ftCreate( - this.indexName, IndexOptions.defaultOptions().setDefinition(indexDefinition), schema); + indexName, IndexOptions.defaultOptions().setDefinition(indexDefinition), schema); logger.info("Redis search vector_index created ~ " + ftCreate); } + + private String getNamespace(RedisEndpoint endpoint) { + return (Objects.isNull(endpoint.getNamespace()) || endpoint.getNamespace().isEmpty()) + ? "knowledge" + : endpoint.getNamespace(); + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java index 862c44025..cd2263fe8 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/PostgresWordEmbeddings.java @@ -9,8 +9,9 @@ public class PostgresWordEmbeddings implements ArkObject { - private Long embedding_id; - + // private Long embedding_id; + // + // private Integer embedding_id; private String id; private String rawText; @@ -23,15 +24,20 @@ public class PostgresWordEmbeddings implements ArkObject { private LocalDateTime timestamp; - private Double score; // will be added + private Double score; - public Long getEmbedding_id() { - return embedding_id; - } + private String metadata; + private String metadataId; + private String titleMetadata; + private String documentDate; - public void setEmbedding_id(Long embedding_id) { - this.embedding_id = embedding_id; - } + // public Integer getEmbedding_id() { + // return embedding_id; + // } + // + // public void setEmbedding_id(Integer embedding_id) { + // this.embedding_id = embedding_id; + // } public String getId() { return id; @@ -89,16 +95,104 @@ public void setFilename(String filename) { this.filename = filename; } + public String getMetadata() { + return metadata; + } + + public void setMetadata(String metadata) { + this.metadata = metadata; + } + + public String getMetadataId() { + return metadataId; + } + + public void setMetadataId(String metadataId) { + this.metadataId = metadataId; + } + + public String getTitleMetadata() { + 
return titleMetadata; + } + + public void setTitleMetadata(String titleMetadata) { + this.titleMetadata = titleMetadata; + } + + public String getDocumentDate() { + return documentDate; + } + + public void setDocumentDate(String documentDate) { + this.documentDate = documentDate; + } + @Override public JSONObject toJson() { JSONObject json = new JSONObject(); - json.put("id", id); - json.put("rawText", rawText); - json.put("namespace", namespace); - json.put("filename", filename); - json.put("values", new JSONArray(values)); - json.put("timestamp", timestamp.toString()); - json.put("score", score); + + // if (embedding_id != null) { + // json.put("embedding_id", embedding_id); + // } + + if (id != null) { + json.put("id", id); + } + + if (rawText != null) { + json.put("rawText", rawText); + } + + if (namespace != null) { + json.put("namespace", namespace); + } + + if (filename != null) { + json.put("filename", filename); + } + + if (values != null) { + json.put("values", new JSONArray(values)); + } + + if (timestamp != null) { + json.put("timestamp", timestamp.toString()); + } + + if (score != null && !Double.isNaN(score)) { + json.put("score", score); + } + + if (titleMetadata != null) { + json.put("titleMetadata", titleMetadata); + } + + if (documentDate != null) { + json.put("documentDate", documentDate); + } + + if (metadata != null) { + json.put("metadata", metadata); + } + if (metadataId != null) { + json.put("metadataId", metadataId); + } + return json; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PostgresWordEmbeddings that = (PostgresWordEmbeddings) o; + + return id.equals(that.id); + } + + @Override + public int hashCode() { + return id.hashCode(); + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java new file mode 100644 index 000000000..19b6f6974 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/domain/RRFWeight.java @@ -0,0 +1,45 @@ +package com.edgechain.lib.index.domain; + +import com.edgechain.lib.index.enums.BaseWeight; + +import java.util.StringJoiner; + +public class RRFWeight { + + private BaseWeight baseWeight = BaseWeight.W1_0; + private double fineTuneWeight = 0.5; + + public RRFWeight() {} + + public RRFWeight(BaseWeight baseWeight, double fineTuneWeight) { + this.baseWeight = baseWeight; + this.fineTuneWeight = fineTuneWeight; + + if (fineTuneWeight < 0 || fineTuneWeight > 1.0) + throw new IllegalArgumentException("Weights must be between 0 and 1."); + } + + public void setBaseWeight(BaseWeight baseWeight) { + this.baseWeight = baseWeight; + } + + public void setFineTuneWeight(double fineTuneWeight) { + this.fineTuneWeight = fineTuneWeight; + } + + public BaseWeight getBaseWeight() { + return baseWeight; + } + + public double getFineTuneWeight() { + return fineTuneWeight; + } + + @Override + public String toString() { + return new StringJoiner(", ", RRFWeight.class.getSimpleName() + "[", "]") + .add("baseWeight=" + baseWeight) + .add("fineTuneWeight=" + fineTuneWeight) + .toString(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java new file mode 100644 index 000000000..e998f8a09 --- /dev/null +++ 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/BaseWeight.java @@ -0,0 +1,32 @@ +package com.edgechain.lib.index.enums; + +public enum BaseWeight { + W1_0(1.0), + W1_25(1.25), + W1_5(1.5), + W1_75(1.75), + W2_0(2.0), + W2_25(2.25), + W2_5(2.5), + W2_75(2.75), + W3_0(3.0); + + private final double value; + + BaseWeight(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + public static BaseWeight fromDouble(double value) { + for (BaseWeight baseWeight : BaseWeight.values()) { + if (baseWeight.getValue() == value) { + return baseWeight; + } + } + throw new IllegalArgumentException("Invalid BaseWeight value: " + value); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java new file mode 100644 index 000000000..db857d0e2 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/OrderRRFBy.java @@ -0,0 +1,23 @@ +package com.edgechain.lib.index.enums; + +public enum OrderRRFBy { + DEFAULT, // Preferred Way; ordered by rrf_score; (relevance over freshness) + TEXT_RANK, // First Ordered By Text_Rank; then ordered by rrf_score (text_rank preferred, then + // relevance) + SIMILARITY, // First Ordered by Similarity; then ordered by rrf_score; (similarity preferred, then + // relevance) + DATE_RANK; // First Ordered by date_rank; then ordered by rrf_score; (freshness preferred, then + + // relevance) + + public static OrderRRFBy fromString(String value) { + if (value != null) { + for (OrderRRFBy orderRRFBy : OrderRRFBy.values()) { + if (orderRRFBy.name().equalsIgnoreCase(value)) { + return orderRRFBy; + } + } + } + throw new IllegalArgumentException("Invalid OrderRRFBy value: " + value); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java new file mode 100644 index 000000000..01f7e82df --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/enums/PostgresLanguage.java @@ -0,0 +1,43 @@ +package com.edgechain.lib.index.enums; + +public enum PostgresLanguage { + SIMPLE("simple"), + ARABIC("arabic"), + ARMENIAN("armenian"), + BASQUE("basque"), + CATALAN("catalan"), + DANISH("danish"), + DUTCH("dutch"), + ENGLISH("english"), + FINNISH("finnish"), + FRENCH("french"), + GERMAN("german"), + GREEK("greek"), + HINDI("hindi"), + HUNGARIAN("hungarian"), + INDONESIAN("indonesian"), + IRISH("irish"), + ITALIAN("italian"), + LITHUANIAN("lithuanian"), + NEPALI("nepali"), + NORWEGIAN("norwegian"), + PORTUGUESE("portuguese"), + ROMANIAN("romanian"), + RUSSIAN("russian"), + SERBIAN("serbian"), + SPANISH("spanish"), + SWEDISH("swedish"), + TAMIL("tamil"), + TURKISH("turkish"), + YIDDISH("yiddish"); + + private final String value; + + PostgresLanguage(String value) { + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java new file mode 100644 index 000000000..5f340aed3 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientMetadataRepository.java @@ -0,0 
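Aside: a small sketch of how the new ranking types compose; how these values reach PostgresEndpoint for queryRRF is not shown in this hunk, so the wiring is assumed:

    // Weights for the reciprocal-rank-fusion query; fineTuneWeight must stay within [0, 1].
    RRFWeight textWeight       = new RRFWeight(BaseWeight.W1_0, 0.5);
    RRFWeight similarityWeight = new RRFWeight(BaseWeight.W2_5, 0.7);
    RRFWeight dateWeight       = new RRFWeight(BaseWeight.W1_5, 0.2);

    // Case-insensitive lookup; resolves to OrderRRFBy.DATE_RANK (freshness preferred, then relevance).
    OrderRRFBy order = OrderRRFBy.fromString("date_rank");

Note that the range check lives only in the two-argument RRFWeight constructor; setFineTuneWeight() accepts any value.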
+1,208 @@ +package com.edgechain.lib.index.repositories; + +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.utils.FloatUtils; +import com.github.f4b6a3.uuid.UuidCreator; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Repository; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import java.util.*; + +@Repository +public class PostgresClientMetadataRepository { + + @Autowired private JdbcTemplate jdbcTemplate; + + @Transactional + public void createTable(PostgresEndpoint postgresEndpoint) { + String metadataTable = postgresEndpoint.getMetadataTableNames().get(0); + jdbcTemplate.execute( + String.format( + "CREATE TABLE IF NOT EXISTS %s (metadata_id UUID PRIMARY KEY, metadata TEXT NOT NULL," + + " document_date DATE);", + postgresEndpoint.getTableName() + "_" + metadataTable)); + + // Create a JOIN table + jdbcTemplate.execute( + String.format( + "CREATE TABLE IF NOT EXISTS %s (id UUID UNIQUE NOT NULL, metadata_id UUID NOT NULL, " + + "FOREIGN KEY (id) REFERENCES %s(id) ON DELETE CASCADE, " + + "FOREIGN KEY (metadata_id) REFERENCES %s(metadata_id) ON DELETE CASCADE, " + + "PRIMARY KEY (id, metadata_id));", + postgresEndpoint.getTableName() + "_join_" + metadataTable, + postgresEndpoint.getTableName(), + postgresEndpoint.getTableName() + "_" + metadataTable)); + + jdbcTemplate.execute( + String.format( + "CREATE INDEX IF NOT EXISTS idx_%s ON %s (metadata_id);", + postgresEndpoint.getTableName() + "_join_" + metadataTable, + postgresEndpoint.getTableName() + "_join_" + metadataTable)); + } + + @Transactional + public List batchInsertMetadata( + String table, String metadataTableName, List metadataList) { + + Set uuidSet = new HashSet<>(); + + for (int i = 0; i < metadataList.size(); i++) { + + String metadata = metadataList.get(i).replace("'", ""); + + UUID metadataId = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (metadata_id, metadata) VALUES ('%s', ?) 
RETURNING metadata_id;", + table.concat("_").concat(metadataTableName), UuidCreator.getTimeOrderedEpoch()), + UUID.class, + metadata); + + if (metadataId != null) { + uuidSet.add(metadataId.toString()); + } + } + + return new ArrayList<>(uuidSet); + } + + @Transactional + public String insertMetadata( + String table, String metadataTableName, String metadata, String documentDate) { + + metadata = metadata.replace("'", ""); + + UUID metadataId = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (metadata_id, metadata, document_date) VALUES ('%s', ?," + + " TO_DATE(NULLIF(?, ''), 'Month DD, YYYY')) RETURNING metadata_id;", + table.concat("_").concat(metadataTableName), UuidCreator.getTimeOrderedEpoch()), + UUID.class, + metadata, + documentDate); + + return Objects.requireNonNull(metadataId).toString(); + } + + @Transactional + public void insertIntoJoinTable(PostgresEndpoint postgresEndpoint) { + String joinTableName = + postgresEndpoint.getTableName() + + "_join_" + + postgresEndpoint.getMetadataTableNames().get(0); + jdbcTemplate.execute( + String.format( + "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET" + + " metadata_id = EXCLUDED.metadata_id;", + joinTableName, + UUID.fromString(postgresEndpoint.getId()), + UUID.fromString(postgresEndpoint.getMetadataId()))); + } + + @Transactional + public void batchInsertIntoJoinTable( + String tableName, String metadataTableName, List idList, String metadataId) { + String joinTableName = tableName + "_join_" + metadataTableName; + List sqlStatements = new ArrayList<>(); + for (String id : idList) { + sqlStatements.add( + String.format( + "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET" + + " metadata_id = EXCLUDED.metadata_id;", + joinTableName, UUID.fromString(id), UUID.fromString(metadataId))); + } + jdbcTemplate.batchUpdate(sqlStatements.toArray(new String[0])); + } + + @Transactional(readOnly = true, propagation = Propagation.REQUIRED) + public List> queryWithMetadata( + String tableName, + String metadataTableName, + String namespace, + int probes, + PostgresDistanceMetric metric, + List values, + int topK) { + + String embeddings = Arrays.toString(FloatUtils.toFloatArray(values)); + + jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); + String joinTable = tableName + "_join_" + metadataTableName; + + if (metric.equals(PostgresDistanceMetric.IP)) { + return jdbcTemplate.queryForList( + String.format( + "SELECT e.id, metadata, TO_CHAR(document_date, 'Month DD, YYYY') as document_date," + + " j.metadata_id, raw_text, namespace, filename, timestamp, ( embedding <#>" + + " '%s') * -1 AS score FROM %s e INNER JOIN %s j ON e.id = j.id INNER JOIN %s m" + + " ON j.metadata_id = m.metadata_id WHERE namespace='%s' ORDER BY embedding %s" + + " '%s' LIMIT %s;", + embeddings, + tableName, + joinTable, + tableName.concat("_").concat(metadataTableName), + namespace, + PostgresDistanceMetric.getDistanceMetric(metric), + embeddings, + topK)); + + } else if (metric.equals(PostgresDistanceMetric.COSINE)) { + return jdbcTemplate.queryForList( + String.format( + "SELECT e.id, metadata, TO_CHAR(document_date, 'Month DD, YYYY') as document_date," + + " j.metadata_id, raw_text, namespace, filename, timestamp, 1 - ( embedding <=>" + + " '%s') AS score FROM %s e INNER JOIN %s j ON e.id = j.id INNER JOIN %s m ON" + + " j.metadata_id = m.metadata_id WHERE namespace='%s' ORDER BY embedding %s '%s'" + + " LIMIT %s;", + embeddings, + tableName, + joinTable, 
+ tableName.concat("_").concat(metadataTableName), + namespace, + PostgresDistanceMetric.getDistanceMetric(metric), + embeddings, + topK)); + } else { + return jdbcTemplate.queryForList( + String.format( + "SELECT e.id, metadata, TO_CHAR(document_date, 'Month DD, YYYY') as document_date," + + " j.metadata_id, raw_text, namespace, filename, timestamp, (embedding <-> '%s')" + + " AS score FROM %s e INNER JOIN %s j ON e.id = j.id INNER JOIN %s m ON" + + " j.metadata_id = m.metadata_id WHERE namespace='%s' ORDER BY embedding %s '%s'" + + " ASC LIMIT %s;", + embeddings, + tableName, + joinTable, + tableName.concat("_").concat(metadataTableName), + namespace, + PostgresDistanceMetric.getDistanceMetric(metric), + embeddings, + topK)); + } + } + + // Full-text search + @Transactional(readOnly = true, propagation = Propagation.REQUIRED) + public List> getSimilarMetadataChunk( + String table, String metadataTableName, String embeddingChunk) { + // Remove special characters and replace with a space + String cleanEmbeddingChunk = + embeddingChunk.replaceAll("[^a-zA-Z0-9\\s]", " ").replaceAll("\\s+", " ").trim(); + + String tableName = table.concat("_").concat(metadataTableName); + + // Split the embeddingChunk into words and join them with the '|' (OR) operator + String tsquery = String.join(" | ", cleanEmbeddingChunk.split("\\s+")); + return jdbcTemplate.queryForList( + String.format( + "SELECT *, ts_rank(to_tsvector(%s.metadata), query) as rank_metadata " + + "FROM %s, to_tsvector(%s.metadata) document, to_tsquery('%s') query " + + "WHERE query @@ document ORDER BY rank_metadata DESC", + tableName, tableName, tableName, tsquery)); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java index ea3e641c7..a626c5741 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/index/repositories/PostgresClientRepository.java @@ -1,8 +1,11 @@ package com.edgechain.lib.index.repositories; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.index.domain.RRFWeight; +import com.edgechain.lib.index.enums.OrderRRFBy; import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; import com.edgechain.lib.utils.FloatUtils; import com.github.f4b6a3.uuid.UuidCreator; import org.springframework.beans.factory.annotation.Autowired; @@ -23,60 +26,142 @@ public class PostgresClientRepository { public void createTable(PostgresEndpoint postgresEndpoint) { jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector;"); - jdbcTemplate.execute( + jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;"); + + String checkTableQuery = String.format( - "CREATE TABLE IF NOT EXISTS %s (embedding_id SERIAL PRIMARY KEY, id VARCHAR(255) NOT" - + " NULL UNIQUE, raw_text TEXT NOT NULL UNIQUE, embedding vector(%s), timestamp" - + " TIMESTAMP NOT NULL, namespace TEXT, filename VARCHAR(255));", - postgresEndpoint.getTableName(), postgresEndpoint.getDimensions())); + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '%s'", + postgresEndpoint.getTableName()); + + int tableExists = 
jdbcTemplate.queryForObject(checkTableQuery, Integer.class); + + String indexName; + String vectorOps; if (PostgresDistanceMetric.L2.equals(postgresEndpoint.getMetric())) { - jdbcTemplate.execute( - String.format( - "CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_l2_ops) WITH" - + " (lists = %s);", - postgresEndpoint.getTableName().concat("_").concat("l2_idx"), - postgresEndpoint.getTableName(), - postgresEndpoint.getLists())); + indexName = postgresEndpoint.getTableName().concat("_").concat("l2_idx"); + vectorOps = "vector_l2_ops"; } else if (PostgresDistanceMetric.COSINE.equals(postgresEndpoint.getMetric())) { + indexName = postgresEndpoint.getTableName().concat("_").concat("cosine_idx"); + vectorOps = "vector_cosine_ops"; + } else { + indexName = postgresEndpoint.getTableName().concat("_").concat("ip_idx"); + vectorOps = "vector_ip_ops"; + } + + String indexQuery = + String.format( + "CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding %s) WITH" + + " (lists = %s);", + indexName, postgresEndpoint.getTableName(), vectorOps, postgresEndpoint.getLists()); + + String tsvIndexQuery = + String.format( + "CREATE INDEX IF NOT EXISTS %s ON %s USING GIN(tsv);", + postgresEndpoint.getTableName().concat("_tsv_idx"), postgresEndpoint.getTableName()); + + if (tableExists == 0) { + jdbcTemplate.execute( String.format( - "CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_cosine_ops) WITH" - + " (lists = %s);", - postgresEndpoint.getTableName().concat("_").concat("cosine_idx"), - postgresEndpoint.getTableName(), - postgresEndpoint.getLists())); + "CREATE TABLE IF NOT EXISTS %s (id UUID PRIMARY KEY, " + + " raw_text TEXT NOT NULL UNIQUE, embedding vector(%s), timestamp" + + " TIMESTAMP NOT NULL, namespace TEXT, filename VARCHAR(255), tsv TSVECTOR);", + postgresEndpoint.getTableName(), postgresEndpoint.getDimensions())); + + jdbcTemplate.execute(indexQuery); + jdbcTemplate.execute(tsvIndexQuery); + } else { - jdbcTemplate.execute( + + String checkIndexQuery = String.format( - "CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_ip_ops) WITH" - + " (lists = %s);", - postgresEndpoint.getTableName().concat("_").concat("ip_idx"), - postgresEndpoint.getTableName(), - postgresEndpoint.getLists())); + "SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s';", + postgresEndpoint.getTableName(), indexName); + + Integer indexExists = jdbcTemplate.queryForObject(checkIndexQuery, Integer.class); + + if (indexExists != null && indexExists != 1) + throw new RuntimeException( + "No index is specified therefore use the following SQL:\n" + indexQuery); } } @Transactional - public void upsertEmbeddings( + public List batchUpsertEmbeddings( String tableName, - String input, + List wordEmbeddingsList, String filename, + String namespace, + PostgresLanguage language) { + + Set uuidSet = new HashSet<>(); + + for (int i = 0; i < wordEmbeddingsList.size(); i++) { + WordEmbeddings wordEmbeddings = wordEmbeddingsList.get(i); + + if (wordEmbeddings != null && wordEmbeddings.getValues() != null) { + + float[] floatArray = FloatUtils.toFloatArray(wordEmbeddings.getValues()); + String rawText = wordEmbeddings.getId().replace("'", ""); + + UUID id = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename, tsv)" + + " VALUES ('%s', ?, '%s', '%s', '%s', '%s', TO_TSVECTOR('%s', '%s')) ON" + + " CONFLICT (raw_text) DO UPDATE SET embedding = EXCLUDED.embedding" + + " RETURNING id;", + 
tableName, + UuidCreator.getTimeOrderedEpoch(), + Arrays.toString(floatArray), + LocalDateTime.now(), + namespace, + filename, + language.getValue(), + rawText), + UUID.class, + rawText); + + if (id != null) { + uuidSet.add(id.toString()); + } + } + } + + return new ArrayList<>(uuidSet); + } + + @Transactional + public String upsertEmbeddings( + String tableName, WordEmbeddings wordEmbeddings, - String namespace) { + String filename, + String namespace, + PostgresLanguage language) { - jdbcTemplate.execute( - String.format( - "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename) VALUES ('%s'," - + " '%s', '%s', '%s', '%s', '%s') ON CONFLICT (raw_text) DO UPDATE SET embedding =" - + " EXCLUDED.embedding;", - tableName, - UuidCreator.getTimeOrderedEpoch().toString(), - input, - Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())), - LocalDateTime.now(), - namespace, - filename)); + float[] floatArray = FloatUtils.toFloatArray(wordEmbeddings.getValues()); + String rawText = wordEmbeddings.getId().replace("'", ""); + + UUID uuid = + jdbcTemplate.queryForObject( + String.format( + "INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename, tsv)" + + " VALUES ('%s', ?, '%s', '%s', '%s', '%s', TO_TSVECTOR('%s', '%s')) ON" + + " CONFLICT (raw_text) DO UPDATE SET embedding = EXCLUDED.embedding RETURNING" + + " id;", + tableName, + UuidCreator.getTimeOrderedEpoch(), + Arrays.toString(floatArray), + LocalDateTime.now(), + namespace, + filename, + language.getValue(), + rawText), + UUID.class, + rawText); + + return Objects.requireNonNull(uuid).toString(); } @Transactional(readOnly = true, propagation = Propagation.REQUIRED) @@ -85,51 +170,220 @@ public List> query( String namespace, int probes, PostgresDistanceMetric metric, - WordEmbeddings wordEmbeddings, - int topK) { - - String embeddings = Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())); + List> values, + int topK, + int upperLimit) { jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); - if (metric.equals(PostgresDistanceMetric.IP)) { - return jdbcTemplate.queryForList( - String.format( - "SELECT id, raw_text, namespace, filename, timestamp, ( embedding <#> '%s') * -1 AS" - + " score FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())), - topK)); + StringBuilder query = new StringBuilder(); + + for (int i = 0; i < values.size(); i++) { + + String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); - } else if (metric.equals(PostgresDistanceMetric.COSINE)) { + query.append("(").append("SELECT id, raw_text, embedding, namespace, filename, timestamp,"); + switch (metric) { + case COSINE -> query + .append(String.format("1 - (embedding <=> '%s') AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + case IP -> query + .append(String.format("(embedding <#> '%s') * -1 AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + case L2 -> query + 
.append(String.format("embedding <-> '%s' AS score ", embeddings)) + .append(" FROM ") + .append(tableName) + .append(" WHERE namespace = ") + .append("'") + .append(namespace) + .append("'") + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT "); + default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); + } + query.append(topK).append(")"); + + if (i < values.size() - 1) { + query.append(" UNION ALL ").append("\n"); + } + } + + if (values.size() > 1) { return jdbcTemplate.queryForList( String.format( - "SELECT id, raw_text, namespace, filename, timestamp, 1 - ( embedding <=> '%s') AS" - + " score FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())), - topK)); + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER" + + " BY score DESC LIMIT %s;", + query, upperLimit)); } else { + return jdbcTemplate.queryForList(query.toString()); + } + } + + public List> queryRRF( + String tableName, + String namespace, + String metadataTableName, + List> values, + RRFWeight textWeight, + RRFWeight similarityWeight, + RRFWeight dateWeight, + String searchQuery, + PostgresLanguage language, + int probes, + PostgresDistanceMetric metric, + int topK, + int upperLimit, + OrderRRFBy orderRRFBy) { + + jdbcTemplate.execute(String.format("SET LOCAL ivfflat.probes = %s;", probes)); + + StringBuilder query = new StringBuilder(); + + for (int i = 0; i < values.size(); i++) { + String embeddings = Arrays.toString(FloatUtils.toFloatArray(values.get(i))); + + query + .append("(") + .append( + "SELECT id, raw_text, document_date, metadata, namespace, filename, timestamp, \n") + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY text_rank DESC) + %s) + \n", + textWeight.getBaseWeight().getValue(), textWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY similarity DESC) + %s) + \n", + similarityWeight.getBaseWeight().getValue(), + similarityWeight.getFineTuneWeight())) + .append( + String.format( + "%s / (ROW_NUMBER() OVER (ORDER BY date_rank DESC) + %s) AS rrf_score\n", + dateWeight.getBaseWeight().getValue(), dateWeight.getFineTuneWeight())) + .append("FROM ( ") + .append( + "SELECT sv.id, sv.raw_text, sv.namespace, sv.filename, sv.timestamp," + + " svtm.document_date, svtm.metadata, ") + .append( + String.format( + "ts_rank_cd(sv.tsv, plainto_tsquery('%s', '%s')) AS text_rank, ", + language.getValue(), searchQuery)); + + switch (metric) { + case COSINE -> query.append( + String.format("1 - (sv.embedding <=> '%s') AS similarity, ", embeddings)); + case IP -> query.append( + String.format("(sv.embedding <#> '%s') * -1 AS similarity, ", embeddings)); + case L2 -> query.append(String.format("sv.embedding <-> '%s' AS similarity, ", embeddings)); + default -> throw new IllegalArgumentException("Invalid similarity measure: " + metric); + } + + query + .append("CASE ") + .append("WHEN svtm.document_date IS NULL THEN 0 ") // Null date handling + .append( + "ELSE EXTRACT(YEAR FROM svtm.document_date) * 365 + EXTRACT(DOY FROM" + + " svtm.document_date) ") + .append("END AS date_rank ") + .append("FROM ") + .append( + String.format( + "(SELECT id, raw_text, embedding, tsv, namespace, filename, timestamp from %s" + + " WHERE namespace = '%s'", + tableName, namespace)); + + switch 
(metric) { + case COSINE -> query + .append(" ORDER BY embedding <=> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); + case IP -> query + .append(" ORDER BY embedding <#> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); + case L2 -> query + .append(" ORDER BY embedding <-> ") + .append("'") + .append(embeddings) + .append("'") + .append(" LIMIT ") + .append(topK); + default -> throw new IllegalArgumentException("Invalid metric: " + metric); + } + query + .append(")") + .append(" sv ") + .append("JOIN ") + .append(tableName.concat("_join_").concat(metadataTableName)) + .append(" jtm ON sv.id = jtm.id ") + .append("JOIN ") + .append(tableName.concat("_").concat(metadataTableName)) + .append(" svtm ON jtm.metadata_id = svtm.metadata_id ") + .append(") subquery "); + + switch (orderRRFBy) { + case TEXT_RANK -> query.append("ORDER BY text_rank DESC, rrf_score DESC"); + case SIMILARITY -> query.append("ORDER BY similarity DESC, rrf_score DESC"); + case DATE_RANK -> query.append("ORDER BY date_rank DESC, rrf_score DESC"); + case DEFAULT -> query.append("ORDER BY rrf_score DESC"); + default -> throw new IllegalArgumentException("Invalid orderRRFBy value"); + } + + query.append(" LIMIT ").append(topK).append(")"); + if (i < values.size() - 1) { + query.append(" UNION ALL ").append("\n"); + } + } + + if (values.size() > 1) { return jdbcTemplate.queryForList( String.format( - "SELECT id, raw_text, namespace, filename, timestamp, (embedding <-> '%s') AS score" - + " FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' ASC LIMIT %s;", - embeddings, - tableName, - namespace, - PostgresDistanceMetric.getDistanceMetric(metric), - Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())), - topK)); + "SELECT * FROM (SELECT DISTINCT ON (result.id) * FROM ( %s ) result) subquery ORDER" + + " BY rrf_score DESC LIMIT %s;", + query, upperLimit)); + } else { + return jdbcTemplate.queryForList(query.toString()); } } + @Transactional(readOnly = true) + public List> getAllChunks(PostgresEndpoint endpoint) { + return jdbcTemplate.queryForList( + String.format( + "SELECT id, raw_text, embedding, filename from %s WHERE filename = '%s';", + endpoint.getTableName(), endpoint.getFilename())); + } + @Transactional public void deleteAll(String tableName, String namespace) { jdbcTemplate.execute( diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java new file mode 100644 index 000000000..166d7c320 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/client/AirtableClient.java @@ -0,0 +1,188 @@ +package com.edgechain.lib.integration.airtable.client; + +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import com.edgechain.lib.integration.airtable.query.AirtableQueryBuilder; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import dev.fuxing.airtable.AirtableApi; +import dev.fuxing.airtable.AirtableRecord; +import dev.fuxing.airtable.AirtableTable; +import io.reactivex.rxjava3.core.Observable; +import org.springframework.stereotype.Service; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@Service +public class AirtableClient { + + public EdgeChain> findAll(AirtableEndpoint endpoint) { + + return new 
EdgeChain<>( + Observable.create( + emitter -> { + try { + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + AirtableQueryBuilder airtableQueryBuilder = endpoint.getAirtableQueryBuilder(); + + AirtableTable.PaginationList list = + table.list( + querySpec -> { + int maxRecords = airtableQueryBuilder.getMaxRecords(); + int pageSize = airtableQueryBuilder.getPageSize(); + String sortField = airtableQueryBuilder.getSortField(); + String sortDirection = airtableQueryBuilder.getSortDirection(); + String offset = airtableQueryBuilder.getOffset(); + List fields = airtableQueryBuilder.getFields(); + String filterByFormula = airtableQueryBuilder.getFilterByFormula(); + String view = airtableQueryBuilder.getView(); + String cellFormat = airtableQueryBuilder.getCellFormat(); + String timeZone = airtableQueryBuilder.getTimeZone(); + String userLocale = airtableQueryBuilder.getUserLocale(); + + querySpec.maxRecords(maxRecords); + querySpec.pageSize(pageSize); + + if (sortField != null && sortDirection != null) { + querySpec.sort(sortField, sortDirection); + } + + if (offset != null) { + querySpec.offset(offset); + } + + if (fields != null) { + querySpec.fields(fields); + } + + if (filterByFormula != null) { + querySpec.filterByFormula(filterByFormula); + } + + if (view != null) { + querySpec.view(view); + } + + if (cellFormat != null) { + querySpec.cellFormat(cellFormat); + } + + if (timeZone != null) { + querySpec.timeZone(timeZone); + } + + if (userLocale != null) { + querySpec.userLocale(userLocale); + } + }); + + Map mapper = new HashMap<>(); + mapper.put("data", list); + mapper.put("offset", list.getOffset()); + + emitter.onNext(mapper); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain findById(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + String id = endpoint.getIds().get(0); + + if (Objects.isNull(id) || id.isEmpty()) + throw new RuntimeException("Id cannot be null"); + + AirtableRecord record = table.get(id); + + if (Objects.isNull(record)) + throw new RuntimeException("Id: " + id + " does not exist"); + + emitter.onNext(record); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain> create(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List airtableRecordList = + table.post(endpoint.getAirtableRecordList(), endpoint.isTypecast()); + + emitter.onNext(airtableRecordList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } + + public EdgeChain> update(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List airtableRecordList = + table.put(endpoint.getAirtableRecordList(), endpoint.isTypecast()); + + emitter.onNext(airtableRecordList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } 
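The findAll wrapper above hands back both the records and the offset taken from AirtableTable.PaginationList, which is how a caller pages through a table: keep requesting with the last returned offset until none comes back. Below is a minimal paging sketch against the same dev.fuxing.airtable client; it is illustrative rather than part of this changeset, the token, base id, table name and page size are placeholders, and it assumes PaginationList can be iterated for its AirtableRecord entries.

import dev.fuxing.airtable.AirtableApi;
import dev.fuxing.airtable.AirtableRecord;
import dev.fuxing.airtable.AirtableTable;

public class AirtablePaginationSketch {

  public static void main(String[] args) {
    // Placeholders: substitute a real personal access token, base id and table name.
    AirtableApi api = new AirtableApi("YOUR_PERSONAL_ACCESS_TOKEN");
    AirtableTable table = api.base("YOUR_BASE_ID").table("YOUR_TABLE");

    String offset = null;
    do {
      final String currentOffset = offset;

      AirtableTable.PaginationList page =
          table.list(
              querySpec -> {
                querySpec.pageSize(100); // records returned per request
                if (currentOffset != null) {
                  querySpec.offset(currentOffset); // resume where the previous page ended
                }
              });

      // Assumes PaginationList iterates over the page's AirtableRecord entries.
      for (AirtableRecord record : page) {
        System.out.println(record);
      }

      offset = page.getOffset(); // null once the last page has been returned
    } while (offset != null);
  }
}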
+ + public EdgeChain> delete(AirtableEndpoint endpoint) { + + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + AirtableApi api = new AirtableApi(endpoint.getApiKey()); + AirtableTable table = api.base(endpoint.getBaseId()).table(endpoint.getTableName()); + + List deleteIdsList = table.delete(endpoint.getIds()); + + emitter.onNext(deleteIdsList); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + })); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java new file mode 100644 index 000000000..7e6945568 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/AirtableQueryBuilder.java @@ -0,0 +1,130 @@ +package com.edgechain.lib.integration.airtable.query; + +import dev.fuxing.airtable.formula.AirtableFormula; +import dev.fuxing.airtable.formula.AirtableFunction; +import dev.fuxing.airtable.formula.AirtableOperator; + +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +public class AirtableQueryBuilder { + private String offset; + private List fields; + private String filterByFormula; + private int maxRecords = 100; + private int pageSize = 100; + private String sortField; + private String sortDirection; + private String view; + private String cellFormat; + private String timeZone; + private String userLocale; + + public void offset(String offset) { + this.offset = offset; + } + + public void fields(String... fields) { + this.fields = Arrays.asList(fields); + } + + public void filterByFormula(String formula) { + this.filterByFormula = formula; + } + + public void filterByFormula(AirtableFunction function, AirtableFormula.Object... objects) { + this.filterByFormula = function.apply(objects); + } + + public void filterByFormula( + AirtableOperator operator, + AirtableFormula.Object left, + AirtableFormula.Object right, + AirtableFormula.Object... 
others) { + this.filterByFormula = operator.apply(left, right, others); + } + + public void maxRecords(int maxRecords) { + this.maxRecords = maxRecords; + } + + public void pageSize(int pageSize) { + this.pageSize = pageSize; + } + + public void sort(String field, String direction) { + this.sortField = field; + this.sortDirection = direction; + } + + public void view(String view) { + this.view = view; + } + + public void cellFormat(String cellFormat) { + this.cellFormat = cellFormat; + } + + public void timeZone(String timeZone) { + this.timeZone = timeZone; + } + + public void timeZone(ZoneId zoneId) { + this.timeZone = zoneId.getId(); + } + + public void userLocale(String userLocale) { + this.userLocale = userLocale; + } + + public void userLocale(Locale locale) { + this.userLocale = locale.toLanguageTag().toLowerCase(); + } + + // Getters for QuerySpec fields + public String getOffset() { + return offset; + } + + public List getFields() { + return fields; + } + + public String getFilterByFormula() { + return filterByFormula; + } + + public int getMaxRecords() { + return maxRecords; + } + + public int getPageSize() { + return pageSize; + } + + public String getSortField() { + return sortField; + } + + public String getSortDirection() { + return sortDirection; + } + + public String getView() { + return view; + } + + public String getCellFormat() { + return cellFormat; + } + + public String getTimeZone() { + return timeZone; + } + + public String getUserLocale() { + return userLocale; + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java new file mode 100644 index 000000000..9b203f024 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/integration/airtable/query/SortOrder.java @@ -0,0 +1,25 @@ +package com.edgechain.lib.integration.airtable.query; + +public enum SortOrder { + ASC("asc"), + DESC("desc"); + + private final String value; + + SortOrder(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static SortOrder fromValue(String value) { + for (SortOrder sortOrder : SortOrder.values()) { + if (sortOrder.value.equalsIgnoreCase(value)) { + return sortOrder; + } + } + throw new IllegalArgumentException("Invalid SortOrder value: " + value); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetArgs.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetArgs.java index 8bdbd4a33..945fd1ca0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetArgs.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetArgs.java @@ -6,7 +6,9 @@ public class JsonnetArgs { private DataType dataType; - private final String val; + private String val; + + public JsonnetArgs() {} public JsonnetArgs(DataType dataType, String val) { diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java index c3a8a932b..c9fcf20eb 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/JsonnetLoader.java @@ -2,6 +2,7 @@ import com.edgechain.lib.jsonnet.enums.DataType; import 
com.edgechain.lib.jsonnet.exceptions.JsonnetLoaderException; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.ObjectMapper; import io.github.jam01.xtrasonnet.Transformer; import org.apache.commons.io.FileUtils; @@ -14,10 +15,21 @@ import java.io.*; import java.util.*; -public abstract class JsonnetLoader { +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) +public abstract class JsonnetLoader implements Serializable { private final Logger logger = LoggerFactory.getLogger(this.getClass()); + private String f1; + private String f2; + + private String metadata; + private String selectedFile; + + private int threshold = 0; + + private String splitSize; + private Map args = new HashMap<>(); private Map xtraArgsMap = new HashMap<>(); private static final ObjectMapper objectMapper = new ObjectMapper(); @@ -25,6 +37,33 @@ public abstract class JsonnetLoader { public JsonnetLoader() {} + public JsonnetLoader(String f1) { + this.f1 = f1; + } + + public JsonnetLoader(int threshold, String f1, String f2) { + this.f1 = f1; + this.f2 = f2; + if (threshold >= 1 && threshold < 100) { + this.threshold = threshold; + this.splitSize = + String.valueOf(threshold).concat("-").concat(String.valueOf((100 - threshold))); + } else throw new RuntimeException("Threshold has to be b/w 1 and 100"); + } + + public void load(InputStream in1, InputStream in2) { + int randValue = (int) (Math.random() * 101); + if (randValue <= threshold) { + this.selectedFile = getF1(); + logger.info("Using File: " + getF1()); + load(in1); + } else { + this.selectedFile = getF2(); + logger.info("Using File: " + getF2()); + load(in2); + } + } + public void load(InputStream inputStream) { try { preconfigured(); @@ -70,6 +109,8 @@ public void load(InputStream inputStream) { .build() .transform(serializeXtraArgs(xtraArgsMap)); // Get the String Output & Transform it into JsonnetSchema + + this.metadata = res; this.jsonObject = new JSONObject(res); // Delete File @@ -126,4 +167,28 @@ private void preconfigured() { if (Objects.isNull(args.get("keepMaxTokens"))) args.put("keepMaxTokens", new JsonnetArgs(DataType.BOOLEAN, "false")); } + + public String getMetadata() { + return metadata; + } + + public String getSelectedFile() { + return selectedFile; + } + + public String getF1() { + return f1; + } + + public String getF2() { + return f2; + } + + public int getThreshold() { + return threshold; + } + + public String getSplitSize() { + return splitSize; + } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java index ca3d0665b..33120fc99 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/XtraSonnetCustomFunc.java @@ -1,6 +1,6 @@ package com.edgechain.lib.jsonnet; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.github.jam01.xtrasonnet.DataFormatService; import io.github.jam01.xtrasonnet.header.Header; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/impl/FileJsonnetLoader.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/impl/FileJsonnetLoader.java index 0743e402a..529b98902 100644 --- 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/impl/FileJsonnetLoader.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/jsonnet/impl/FileJsonnetLoader.java @@ -2,32 +2,75 @@ import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.exceptions.JsonnetLoaderException; +import org.apache.commons.io.FilenameUtils; import java.io.*; public class FileJsonnetLoader extends JsonnetLoader { - private String filePath; + private String filePath1; + private String filePath2; + + public String getFilePath1() { + return filePath1; + } + + public void setFilePath1(String filePath1) { + this.filePath1 = filePath1; + } + + public String getFilePath2() { + return filePath2; + } + + public void setFilePath2(String filePath2) { + this.filePath2 = filePath2; + } + + public FileJsonnetLoader() {} public FileJsonnetLoader(String filePath) { - this.filePath = filePath; + super(filePath); + this.filePath1 = filePath; + + if (!new File(filePath).exists()) { + throw new JsonnetLoaderException("File not found - " + filePath); + } } - public FileJsonnetLoader() { - super(); + public FileJsonnetLoader(int threshold, String filePath1, String filePath2) { + super(threshold, FilenameUtils.getName(filePath1), FilenameUtils.getName(filePath2)); + this.filePath1 = filePath1; + this.filePath2 = filePath2; + + if (!new File(filePath1).exists()) { + throw new JsonnetLoaderException("File not found - " + filePath1); + } + + if (!new File(filePath2).exists()) { + throw new JsonnetLoaderException("File not found. " + filePath2); + } } @Override public JsonnetLoader loadOrReload() { - try (InputStream in = new FileInputStream(filePath)) { - this.load(in); - return this; - } catch (final Exception e) { - throw new JsonnetLoaderException(e.getMessage()); - } - } - public String getFilePath() { - return filePath; + if (getThreshold() >= 1 && getThreshold() < 100) { + try (InputStream in1 = new FileInputStream(filePath1); + InputStream in2 = new FileInputStream(filePath2)) { + load(in1, in2); + return this; + } catch (final Exception e) { + throw new JsonnetLoaderException(e.getMessage()); + } + + } else { + try (InputStream in = new FileInputStream(filePath1)) { + load(in); + return this; + } catch (final Exception e) { + throw new JsonnetLoaderException(e.getMessage()); + } + } } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java new file mode 100644 index 000000000..ee47fe298 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java @@ -0,0 +1,57 @@ +package com.edgechain.lib.llama2; + +import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; +import com.edgechain.lib.llama2.request.LLamaCompletionRequest; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.reactivex.rxjava3.core.Observable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import java.util.List; + +@Service +public class LLamaClient { + @Autowired private ObjectMapper 
objectMapper; + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); + + public EdgeChain> createChatCompletion( + LLamaCompletionRequest request, LLamaQuickstart endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + logger.info("Logging ChatCompletion...."); + + logger.info("==============REQUEST DATA================"); + logger.info(request.toString()); + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(request, headers); + // + String response = + restTemplate.postForObject(endpoint.getUrl(), entity, String.class); + + List chatCompletionResponse = + objectMapper.readValue(response, new TypeReference<>() {}); + emitter.onNext(chatCompletionResponse); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java new file mode 100644 index 000000000..481a890f8 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java @@ -0,0 +1,89 @@ +package com.edgechain.lib.llama2; + +import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; +import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; +import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; +import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.reactivex.rxjava3.core.Observable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.*; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@Service +public class Llama2Client { + @Autowired private ObjectMapper objectMapper; + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); + + public EdgeChain> createChatCompletion( + Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + logger.info("Logging ChatCompletion...."); + + logger.info("==============REQUEST DATA================"); + logger.info(request.toString()); + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(request, headers); + // + String response = + restTemplate.postForObject(endpoint.getUrl(), entity, String.class); + + List chatCompletionResponse = + objectMapper.readValue(response, new TypeReference<>() {}); + emitter.onNext(chatCompletionResponse); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createGetChatCompletion(LLamaQuickstart endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.set("User-Agent", "insomnia/8.2.0"); + HttpEntity entity = new 
HttpEntity<>(headers); + + Map param = Collections.singletonMap("query", endpoint.getQuery()); + + String endpointUrl = endpoint.getUrl() + "?query={query}"; + + ResponseEntity response = + restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param); + + logger.info("\nRESPONSE DATA {}\n", response.getBody()); + + emitter.onNext(response.getBody()); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java new file mode 100644 index 000000000..eea4b0d99 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java @@ -0,0 +1,88 @@ +package com.edgechain.lib.llama2.request; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.StringJoiner; + +public class LLamaCompletionRequest { + @JsonProperty("text_inputs") + private String textInputs; + + @JsonProperty("return_full_text") + private Boolean returnFullText; + + @JsonProperty("top_k") + private Integer topK; + + public LLamaCompletionRequest() {} + + public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) { + this.textInputs = textInputs; + this.returnFullText = returnFullText; + this.topK = topK; + } + + @Override + public String toString() { + return new StringJoiner(", ", LLamaCompletionRequest.class.getSimpleName() + "{", "}") + .add("\"text_inputs:\"" + textInputs) + .add("\"return_full_text:\"" + returnFullText) + .add("\"top_k:\"" + topK) + .toString(); + } + + public static LlamaSupportChatCompletionRequestBuilder builder() { + return new LlamaSupportChatCompletionRequestBuilder(); + } + + public String getTextInputs() { + return textInputs; + } + + public void setTextInputs(String textInputs) { + this.textInputs = textInputs; + } + + public Boolean getReturnFullText() { + return returnFullText; + } + + public void setReturnFullText(Boolean returnFullText) { + this.returnFullText = returnFullText; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public static class LlamaSupportChatCompletionRequestBuilder { + private String textInputs; + private Boolean returnFullText; + private Integer topK; + + private LlamaSupportChatCompletionRequestBuilder() {} + + public LlamaSupportChatCompletionRequestBuilder textInputs(String textInputs) { + this.textInputs = textInputs; + return this; + } + + public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFullText) { + this.returnFullText = returnFullText; + return this; + } + + public LlamaSupportChatCompletionRequestBuilder topK(Integer topK) { + this.topK = topK; + return this; + } + + public LLamaCompletionRequest build() { + return new LLamaCompletionRequest(textInputs, returnFullText, topK); + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java new file mode 100644 index 000000000..17f190e93 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java @@ -0,0 +1,67 @@ +package com.edgechain.lib.llama2.request; + +import org.json.JSONObject; 
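Llama2Client above posts this request type as JSON and maps the array response with Jackson, so a caller only needs to supply inputs and, optionally, a parameters object. A minimal construction sketch follows; it is illustrative rather than taken from this changeset, and the parameter keys ("max_new_tokens", "temperature") are placeholders for whatever the target inference endpoint actually accepts.

import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import org.json.JSONObject;

public class Llama2RequestSketch {

  public static void main(String[] args) {
    // Parameter names are illustrative; the request class itself places no constraints on them.
    JSONObject parameters =
        new JSONObject().put("max_new_tokens", 256).put("temperature", 0.7);

    Llama2ChatCompletionRequest request =
        Llama2ChatCompletionRequest.builder()
            .inputs("Summarize the last release notes in one sentence.")
            .parameters(parameters)
            .build();

    System.out.println(request); // uses the toString() defined in this class
  }
}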
+ +import java.util.StringJoiner; + +public class Llama2ChatCompletionRequest { + + private String inputs; + private JSONObject parameters; + + public Llama2ChatCompletionRequest() {} + + public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) { + this.inputs = inputs; + this.parameters = parameters; + } + + @Override + public String toString() { + return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]") + .add("\"inputs:\"" + inputs) + .add("\"parameters:\"" + parameters) + .toString(); + } + + public static Llama2ChatCompletionRequestBuilder builder() { + return new Llama2ChatCompletionRequestBuilder(); + } + + public String getInputs() { + return inputs; + } + + public void setInputs(String inputs) { + this.inputs = inputs; + } + + public JSONObject getParameters() { + return parameters; + } + + public void setParameters(JSONObject parameters) { + this.parameters = parameters; + } + + public static class Llama2ChatCompletionRequestBuilder { + private String inputs; + private JSONObject parameters; + + private Llama2ChatCompletionRequestBuilder() {} + + public Llama2ChatCompletionRequestBuilder inputs(String inputs) { + this.inputs = inputs; + return this; + } + + public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) { + this.parameters = parameters; + return this; + } + + public Llama2ChatCompletionRequest build() { + return new Llama2ChatCompletionRequest(inputs, parameters); + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java new file mode 100644 index 000000000..033d78286 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java @@ -0,0 +1,18 @@ +package com.edgechain.lib.llama2.response; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class Llama2ChatCompletionResponse { + @JsonProperty("generated_text") + private String generatedText; + + public Llama2ChatCompletionResponse() {} + + public String getGeneratedText() { + return generatedText; + } + + public void setGeneratedText(String generatedText) { + this.generatedText = generatedText; + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/EmbeddingLogger.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/EmbeddingLogger.java index c1e30b9b5..5fd9129aa 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/EmbeddingLogger.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/EmbeddingLogger.java @@ -1,16 +1,13 @@ package com.edgechain.lib.logger; -import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; import com.edgechain.lib.retrofit.logger.EmbeddingLoggerService; -import com.edgechain.lib.retrofit.logger.EmbeddingLoggerService; +import java.util.HashMap; import org.springframework.data.domain.Page; import org.springframework.web.bind.annotation.PathVariable; import retrofit2.Retrofit; -import java.util.HashMap; - public class EmbeddingLogger { private final Retrofit retrofit = RetrofitClientInstance.getInstance(); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/JsonnetLogger.java 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/JsonnetLogger.java new file mode 100644 index 000000000..ab8dbe77c --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/JsonnetLogger.java @@ -0,0 +1,37 @@ +package com.edgechain.lib.logger; + +import com.edgechain.lib.logger.entities.JsonnetLog; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import java.util.HashMap; + +import com.edgechain.lib.retrofit.logger.JsonnetLoggerService; +import org.springframework.data.domain.Page; +import retrofit2.Retrofit; + +public class JsonnetLogger { + + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final JsonnetLoggerService jsonnetLoggerService = + retrofit.create(JsonnetLoggerService.class); + + public JsonnetLogger() {} + + public Page findAll(int page, int size) { + return this.jsonnetLoggerService.findAll(page, size).blockingGet(); + } + + public Page findAllOrderByCreatedAtDesc(int page, int size) { + return this.jsonnetLoggerService.findAllOrderByCreatedAtDesc(page, size).blockingGet(); + } + + public Page findAllBySelectedFileOrderByCreatedAtDesc( + String filename, int page, int size) { + + HashMap mapper = new HashMap<>(); + mapper.put("filename", filename); + + return this.jsonnetLoggerService + .findAllBySelectedFileOrderByCreatedAtDesc(mapper, page, size) + .blockingGet(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/ChatCompletionLog.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/ChatCompletionLog.java index 5960c5352..74b2e0a11 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/ChatCompletionLog.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/ChatCompletionLog.java @@ -5,6 +5,7 @@ import javax.validation.constraints.NotBlank; import java.time.LocalDateTime; +import java.util.StringJoiner; import java.util.UUID; @Table(name = "chat_completion_logs") @@ -45,6 +46,16 @@ public class ChatCompletionLog { @Column(columnDefinition = "TEXT") private String content; + private Double presencePenalty; + private Double frequencyPenalty; + + @Column(name = "top_p") + private Double topP; + + private Integer n; + + private Double temperature; + private Long latency; private Long promptTokens; @@ -159,23 +170,66 @@ public void setCallIdentifier(String callIdentifier) { this.callIdentifier = callIdentifier; } + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + @Override public String toString() { - final StringBuilder sb = new StringBuilder("ChatCompletionLog{"); - sb.append("chatCompletionId=").append(chatCompletionId); - sb.append(", id='").append(id).append('\''); - sb.append(", name='").append(name).append('\''); - sb.append(", 
callIdentifier='").append(callIdentifier).append('\''); - sb.append(", type='").append(type).append('\''); - sb.append(", createdAt=").append(createdAt); - sb.append(", completedAt=").append(completedAt); - sb.append(", model='").append(model).append('\''); - sb.append(", input='").append(input).append('\''); - sb.append(", content='").append(content).append('\''); - sb.append(", latency=").append(latency); - sb.append(", promptTokens=").append(promptTokens); - sb.append(", totalTokens=").append(totalTokens); - sb.append('}'); - return sb.toString(); + return new StringJoiner(", ", ChatCompletionLog.class.getSimpleName() + "[", "]") + .add("id='" + id + "'") + .add("name='" + name + "'") + .add("callIdentifier='" + callIdentifier + "'") + .add("type='" + type + "'") + .add("temperature=" + temperature) + .add("createdAt=" + createdAt) + .add("completedAt=" + completedAt) + .add("model='" + model + "'") + .add("input='" + input + "'") + .add("content='" + content + "'") + .add("presencePenalty=" + presencePenalty) + .add("frequencyPenalty=" + frequencyPenalty) + .add("topP=" + topP) + .add("n=" + n) + .add("latency=" + latency) + .add("promptTokens=" + promptTokens) + .add("totalTokens=" + totalTokens) + .toString(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/JsonnetLog.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/JsonnetLog.java new file mode 100644 index 000000000..d9dce0784 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/entities/JsonnetLog.java @@ -0,0 +1,122 @@ +package com.edgechain.lib.logger.entities; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import jakarta.persistence.*; + +import java.time.LocalDateTime; +import java.util.StringJoiner; +import java.util.UUID; + +@Table(name = "jsonnet_logs") +@Entity(name = "JsonnetLog") +public class JsonnetLog { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @JsonIgnore + private Long jsonnetLogId; + + @Column(nullable = false, unique = true) + private String id; + + @Column(nullable = false) + private String splitSize; + + @Column(nullable = false, columnDefinition = "TEXT") + private String metadata; + + @Column(columnDefinition = "TEXT") + private String content; + + private String selectedFile; + + @Column(nullable = false) + private String f1; + + @Column(nullable = false) + private String f2; + + private LocalDateTime createdAt; + + @PrePersist + protected void onCreate() { + setId(UUID.randomUUID().toString()); + } + + public void setId(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public String getSplitSize() { + return splitSize; + } + + public void setSplitSize(String splitSize) { + this.splitSize = splitSize; + } + + public String getMetadata() { + return metadata; + } + + public void setMetadata(String jsonString) { + this.metadata = jsonString; + } + + public String getF1() { + return f1; + } + + public void setF1(String f1) { + this.f1 = f1; + } + + public String getF2() { + return f2; + } + + public void setF2(String f2) { + this.f2 = f2; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public String getSelectedFile() { + return selectedFile; + } + + public void setSelectedFile(String 
selectedFile) { + this.selectedFile = selectedFile; + } + + @Override + public String toString() { + return new StringJoiner(", ", JsonnetLog.class.getSimpleName() + "[", "]") + .add("id='" + id + "'") + .add("splitSize=" + splitSize) + .add("metadata='" + metadata + "'") + .add("content='" + content + "'") + .add("f1='" + f1 + "'") + .add("f2='" + f2 + "'") + .add("createdAt=" + createdAt) + .toString(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/repositories/JsonnetLogRepository.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/repositories/JsonnetLogRepository.java new file mode 100644 index 000000000..a90955ea8 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/repositories/JsonnetLogRepository.java @@ -0,0 +1,18 @@ +package com.edgechain.lib.logger.repositories; + +import com.edgechain.lib.logger.entities.JsonnetLog; +import org.jetbrains.annotations.NotNull; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.data.jpa.repository.JpaRepository; + +public interface JsonnetLogRepository extends JpaRepository { + + @Override + @NotNull + Page findAll(@NotNull Pageable pageable); + + Page findAllByOrderByCreatedAtDesc(Pageable pageable); + + Page findAllBySelectedFileOrderByCreatedAtDesc(String filename, Pageable pageable); +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/ChatCompletionLogService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/ChatCompletionLogService.java index 9871831f6..2e8142ead 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/ChatCompletionLogService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/ChatCompletionLogService.java @@ -88,9 +88,14 @@ public void createTable() { + " model VARCHAR(255) NOT NULL,\n" + " input TEXT NOT NULL,\n" + " content TEXT,\n" + + " presence_penalty DOUBLE PRECISION,\n" + + " frequency_penalty DOUBLE PRECISION,\n" + + " top_p DOUBLE PRECISION,\n" + + " n INTEGER,\n" + + " temperature DOUBLE PRECISION,\n" + " latency BIGINT,\n" + " prompt_tokens BIGINT,\n" - + " total_tokens BIGINT\n" + + " total_tokens BIGINT" + ");"); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/EmbeddingLogService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/EmbeddingLogService.java index 8f4ad06e6..986c65a40 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/EmbeddingLogService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/EmbeddingLogService.java @@ -1,8 +1,6 @@ package com.edgechain.lib.logger.services; import com.edgechain.lib.logger.entities.EmbeddingLog; -import com.edgechain.lib.logger.entities.EmbeddingLog; -import com.edgechain.lib.logger.repositories.EmbeddingLogRepository; import com.edgechain.lib.logger.repositories.EmbeddingLogRepository; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Page; @@ -17,6 +15,7 @@ public class EmbeddingLogService { @Autowired private EmbeddingLogRepository embeddingLogRepository; @Autowired private JdbcTemplate jdbcTemplate; + @Transactional public EmbeddingLog saveOrUpdate(EmbeddingLog embeddingLog) { this.createTable(); return this.embeddingLogRepository.save(embeddingLog); 
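One note on the JsonnetLogRepository introduced above: Spring Data JPA derives the query from the method name, so findAllBySelectedFileOrderByCreatedAtDesc filters on selected_file, orders by created_at descending, and pages with the supplied Pageable. The call-site sketch below is illustrative rather than part of this changeset; the page size is an arbitrary placeholder, and the Page<JsonnetLog> type is assumed from the repository's JpaRepository declaration.

import com.edgechain.lib.logger.entities.JsonnetLog;
import com.edgechain.lib.logger.repositories.JsonnetLogRepository;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;

public class JsonnetLogQuerySketch {

  // Spring Data derives: WHERE selected_file = :filename ORDER BY created_at DESC, paged.
  public static Page<JsonnetLog> latestLogsForFile(JsonnetLogRepository repository, String filename) {
    return repository.findAllBySelectedFileOrderByCreatedAtDesc(filename, PageRequest.of(0, 20));
  }
}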
@@ -64,19 +63,23 @@ public Page findAllByLatencyGreaterThanEqual(long latency, Pageabl return this.embeddingLogRepository.findAllByLatencyGreaterThanEqual(latency, pageable); } + private static final String SQL_CREATE_TABLE = + """ + CREATE TABLE IF NOT EXISTS embedding_logs ( + embedding_id SERIAL PRIMARY KEY, + id VARCHAR(255) NOT NULL UNIQUE, + call_identifier VARCHAR(255) NOT NULL, + created_at TIMESTAMP, + completed_at TIMESTAMP, + model VARCHAR(255) NOT NULL, + latency BIGINT, + prompt_tokens BIGINT, + total_tokens BIGINT + ); + """; + @Transactional public void createTable() { - jdbcTemplate.execute( - "CREATE TABLE IF NOT EXISTS embedding_logs (\n" - + " embedding_id SERIAL PRIMARY KEY,\n" - + " id VARCHAR(255) NOT NULL UNIQUE,\n" - + " call_identifier VARCHAR(255) NOT NULL,\n" - + " created_at TIMESTAMP,\n" - + " completed_at TIMESTAMP,\n" - + " model VARCHAR(255) NOT NULL,\n" - + " latency BIGINT,\n" - + " prompt_tokens BIGINT,\n" - + " total_tokens BIGINT\n" - + ");;"); + jdbcTemplate.execute(SQL_CREATE_TABLE); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/JsonnetLogService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/JsonnetLogService.java new file mode 100644 index 000000000..df4f91d58 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/logger/services/JsonnetLogService.java @@ -0,0 +1,54 @@ +package com.edgechain.lib.logger.services; + +import com.edgechain.lib.logger.entities.JsonnetLog; +import com.edgechain.lib.logger.repositories.JsonnetLogRepository; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +public class JsonnetLogService { + + @Autowired private JsonnetLogRepository jsonnetLogRepository; + @Autowired private JdbcTemplate jdbcTemplate; + + public JsonnetLog saveOrUpdate(JsonnetLog jsonnetLog) { + this.createTable(); + return jsonnetLogRepository.save(jsonnetLog); + } + + @Transactional(readOnly = true) + public Page findAll(Pageable pageable) { + return this.jsonnetLogRepository.findAll(pageable); + } + + @Transactional(readOnly = true) + public Page findAllOrderByCreatedAtDesc(Pageable pageable) { + return this.jsonnetLogRepository.findAllByOrderByCreatedAtDesc(pageable); + } + + @Transactional(readOnly = true) + public Page findAllBySelectedFileOrderByCreatedAtDesc( + String filename, Pageable pageable) { + return this.jsonnetLogRepository.findAllBySelectedFileOrderByCreatedAtDesc(filename, pageable); + } + + @Transactional + public void createTable() { + jdbcTemplate.execute( + "CREATE TABLE IF NOT EXISTS jsonnet_logs (\n" + + " jsonnet_log_id SERIAL PRIMARY KEY,\n" + + " id VARCHAR(255) NOT NULL UNIQUE,\n" + + " split_size VARCHAR(255) NOT NULL,\n" + + " metadata TEXT NOT NULL,\n" + + " content TEXT,\n" + + " selected_file VARCHAR(255),\n" + + " f1 VARCHAR(255) NOT NULL,\n" + + " f2 VARCHAR(255) NOT NULL,\n" + + " created_at TIMESTAMP\n" + + ");\n"); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java index e406cab3f..fffd19dab 100644 --- 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/client/OpenAiClient.java @@ -3,7 +3,8 @@ import com.edgechain.lib.constants.EndpointConstants; import com.edgechain.lib.embeddings.request.OpenAiEmbeddingRequest; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.ChatCompletionRequest; import com.edgechain.lib.openai.request.CompletionRequest; import com.edgechain.lib.openai.response.ChatCompletionResponse; @@ -27,17 +28,8 @@ public class OpenAiClient { private final Logger logger = LoggerFactory.getLogger(getClass()); private final RestTemplate restTemplate = new RestTemplate(); - private OpenAiEndpoint endpoint; - - public OpenAiEndpoint getEndpoint() { - return endpoint; - } - - public void setEndpoint(OpenAiEndpoint endpoint) { - this.endpoint = endpoint; - } - - public EdgeChain createChatCompletion(ChatCompletionRequest request) { + public EdgeChain createChatCompletion( + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { return new EdgeChain<>( Observable.create( @@ -74,9 +66,12 @@ public EdgeChain createChatCompletion(ChatCompletionRequ } public EdgeChain createChatCompletionStream( - ChatCompletionRequest request) { + ChatCompletionRequest request, OpenAiChatEndpoint endpoint) { try { + logger.info("Logging ChatCompletion Stream...."); + logger.info(request.toString()); + return new EdgeChain<>( RxJava3Adapter.fluxToObservable( WebClient.builder() @@ -102,7 +97,8 @@ public EdgeChain createChatCompletionStream( } } - public EdgeChain createCompletion(CompletionRequest request) { + public EdgeChain createCompletion( + CompletionRequest request, OpenAiChatEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { @@ -128,7 +124,8 @@ public EdgeChain createCompletion(CompletionRequest request) endpoint); } - public EdgeChain createEmbeddings(OpenAiEmbeddingRequest request) { + public EdgeChain createEmbeddings( + OpenAiEmbeddingRequest request, OpenAiEmbeddingEndpoint endpoint) { return new EdgeChain<>( Observable.create( emitter -> { diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java index e304a4232..52353f076 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/providers/OpenAiCompletionProvider.java @@ -1,15 +1,15 @@ package com.edgechain.lib.openai.providers; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.CompletionRequest; import com.edgechain.lib.response.StringResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; /** Going to be removed * */ public class OpenAiCompletionProvider { - private final OpenAiEndpoint endpoint; + private final OpenAiChatEndpoint endpoint; - public OpenAiCompletionProvider(OpenAiEndpoint endpoint) { + public OpenAiCompletionProvider(OpenAiChatEndpoint endpoint) { this.endpoint = endpoint; } diff --git 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatCompletionRequest.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatCompletionRequest.java index 952d2e06a..4fe7ce4dd 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatCompletionRequest.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatCompletionRequest.java @@ -1,22 +1,70 @@ package com.edgechain.lib.openai.request; -import java.util.List; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.*; public class ChatCompletionRequest { + private static final Double DEFAULT_TEMPERATURE = 0.7; + private static final Boolean DEFAULT_STREAM = false; + private static final Double DEFAULT_TOP_P = 1.0; + private static final Integer DEFAULT_N = 1; + private static final List DEFAULT_STOP = Collections.emptyList(); + private static final Double DEFAULT_PRESENCE_PENALTY = 0.0; + private static final Double DEFAULT_FREQUENCY_PENALTY = 0.0; + private static final Map DEFAULT_LOGIT_BIAS = Collections.emptyMap(); + private static final String DEFAULT_USER = ""; + private String model; private Double temperature; private List messages; private Boolean stream; + @JsonProperty("top_p") + private Double topP; + + private Integer n; + + private List stop; + + @JsonProperty("presence_penalty") + private Double presencePenalty; + + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + @JsonProperty("logit_bias") + private Map logitBias; + + private String user; + public ChatCompletionRequest() {} public ChatCompletionRequest( - String model, List messages, Double temperature, Boolean stream) { + String model, + Double temperature, + List messages, + Boolean stream, + Double topP, + Integer n, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Map logitBias, + String user) { this.model = model; - this.temperature = temperature; + this.temperature = (temperature != null) ? temperature : DEFAULT_TEMPERATURE; this.messages = messages; - this.stream = stream; + this.stream = (stream != null) ? stream : DEFAULT_STREAM; + this.topP = (topP != null) ? topP : DEFAULT_TOP_P; + this.n = (n != null) ? n : DEFAULT_N; + this.stop = (stop != null) ? stop : DEFAULT_STOP; + this.presencePenalty = (presencePenalty != null) ? presencePenalty : DEFAULT_PRESENCE_PENALTY; + this.frequencyPenalty = + (frequencyPenalty != null) ? frequencyPenalty : DEFAULT_FREQUENCY_PENALTY; + this.logitBias = (logitBias != null) ? logitBias : DEFAULT_LOGIT_BIAS; + this.user = (user != null) ? 
user : DEFAULT_USER; } public String getModel() { @@ -51,15 +99,49 @@ public void setStream(Boolean stream) { this.stream = stream; } + public Double getTopP() { + return topP; + } + + public Integer getN() { + return n; + } + + public List getStop() { + return stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public Map getLogitBias() { + return logitBias; + } + + public String getUser() { + return user; + } + @Override public String toString() { - final StringBuilder sb = new StringBuilder("ChatCompletionRequest{"); - sb.append("model='").append(model).append('\''); - sb.append(", temperature=").append(temperature); - sb.append(", messages=").append(messages); - sb.append(", stream=").append(stream); - sb.append('}'); - return sb.toString(); + return new StringJoiner(", ", ChatCompletionRequest.class.getSimpleName() + "[", "]") + .add("model='" + model + "'") + .add("temperature=" + temperature) + .add("messages=" + messages) + .add("stream=" + stream) + .add("topP=" + topP) + .add("n=" + n) + .add("stop=" + stop) + .add("presencePenalty=" + presencePenalty) + .add("frequencyPenalty=" + frequencyPenalty) + .add("logitBias=" + logitBias) + .add("user='" + user + "'") + .toString(); } public static ChatCompletionRequestBuilder builder() { @@ -70,8 +152,26 @@ public static class ChatCompletionRequestBuilder { private String model; private Double temperature; private List messages; + private Boolean stream; + + @JsonProperty("top_p") + private Double topP; - private Boolean stream = Boolean.FALSE; + private Integer n; + private List stop; + + @JsonProperty("presence_penalty") + private Double presencePenalty; + + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + @JsonProperty("logit_bias") + private Map logitBias; + + private String user; + + private ChatCompletionRequestBuilder() {} public ChatCompletionRequestBuilder model(String model) { this.model = model; @@ -88,13 +188,59 @@ public ChatCompletionRequestBuilder messages(List messages) { return this; } - public ChatCompletionRequestBuilder stream(Boolean value) { - this.stream = value; + public ChatCompletionRequestBuilder stream(Boolean stream) { + this.stream = stream; + return this; + } + + public ChatCompletionRequestBuilder topP(Double topP) { + this.topP = topP; + return this; + } + + public ChatCompletionRequestBuilder n(Integer n) { + this.n = n; + return this; + } + + public ChatCompletionRequestBuilder stop(List stop) { + this.stop = stop; + return this; + } + + public ChatCompletionRequestBuilder presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public ChatCompletionRequestBuilder frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public ChatCompletionRequestBuilder logitBias(Map logitBias) { + this.logitBias = logitBias; + return this; + } + + public ChatCompletionRequestBuilder user(String user) { + this.user = user; return this; } public ChatCompletionRequest build() { - return new ChatCompletionRequest(model, messages, temperature, stream); + return new ChatCompletionRequest( + model, + temperature, + messages, + stream, + topP, + n, + stop, + presencePenalty, + frequencyPenalty, + logitBias, + user); } } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatMessage.java 
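The reworked ChatCompletionRequest above substitutes a default for every optional OpenAI parameter that is left null. A short usage sketch of the builder declared in the class (the model name and user tag are illustrative; the messages list is assumed to hold the ChatMessage type used elsewhere in this diff):

  static ChatCompletionRequest exampleRequest(List<ChatMessage> messages) {
    // stream, stop, presence/frequency penalties and logitBias are left unset here, so the
    // constructor falls back to DEFAULT_STREAM, DEFAULT_STOP, 0.0, 0.0 and an empty map.
    return ChatCompletionRequest.builder()
        .model("gpt-3.5-turbo")   // illustrative model name
        .messages(messages)
        .temperature(0.7)
        .topP(1.0)
        .n(1)
        .user("example-user")
        .build();
  }
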
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatMessage.java index 6ea6007b6..44642313b 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatMessage.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/request/ChatMessage.java @@ -22,6 +22,10 @@ public String getContent() { return content; } + public void setRole(String role) { + this.role = role; + } + public void setContent(String content) { this.content = content; } @@ -33,8 +37,15 @@ public String toString() { public JSONObject toJson() { JSONObject json = new JSONObject(); - json.put("role", role); - json.put("content", content); + + if (role != null) { + json.put("role", role); + } + + if (content != null) { + json.put("content", content); + } + return json; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/response/ChatCompletionResponse.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/response/ChatCompletionResponse.java index d44acfc67..fd4ced252 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/response/ChatCompletionResponse.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/openai/response/ChatCompletionResponse.java @@ -86,13 +86,31 @@ public String toString() { @Override public JSONObject toJson() { JSONObject json = new JSONObject(); - json.put("id", id); - json.put("object", object); + + if (id != null) { + json.put("id", id); + } + + if (object != null) { + json.put("object", object); + } + json.put("created", created); - json.put("model", model); - json.put( - "choices", choices.stream().map(ChatCompletionChoice::toJson).collect(Collectors.toList())); - json.put("usage", usage.toJson()); + + if (model != null) { + json.put("model", model); + } + + if (choices != null) { + json.put( + "choices", + choices.stream().map(ChatCompletionChoice::toJson).collect(Collectors.toList())); + } + + if (usage != null) { + json.put("usage", usage.toJson()); + } + return json; } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java index dd047758f..875bba38c 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/request/ArkRequest.java @@ -9,12 +9,12 @@ import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; +import java.io.BufferedReader; import java.io.IOException; import java.security.Principal; import java.util.Collection; import java.util.Enumeration; import java.util.Objects; -import java.util.stream.Collectors; public class ArkRequest { @@ -99,8 +99,15 @@ public int getIntQueryParam(String key) { } public JSONObject getBody() { - try { - return new JSONObject(this.request.getReader().lines().collect(Collectors.joining())); + + StringBuilder jsonContent = new StringBuilder(); + + try (BufferedReader reader = request.getReader()) { + String line; + while ((line = reader.readLine()) != null) { + jsonContent.append(line); + } + return new JSONObject(jsonContent.toString()); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitter.java 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitter.java index 97a26403e..b2dda1187 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitter.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitter.java @@ -12,6 +12,6 @@ public ArkEmitter(EdgeChain edgeChain) { } public ArkEmitter(Observable observable) { - this.observer = new ArkEmitterObserver(observable, this); + this.observer = new ArkEmitterObserver<>(observable, this); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitterObserver.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitterObserver.java index fc0f2fbd8..16140aab7 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitterObserver.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkEmitterObserver.java @@ -39,6 +39,7 @@ public void onError(@NonNull Throwable e) { @Override public void onComplete() { + if (!completed) { completed = true; responseBodyEmitter.complete(); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkObservable.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkObservable.java index 88c816f56..8f1f56fdd 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkObservable.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/response/ArkObservable.java @@ -9,7 +9,7 @@ public class ArkObservable extends DeferredResult implements ArkResponse { private final ArkObserver observer; public ArkObservable(EdgeChain edgeChain) { - observer = new ArkObserver<>(edgeChain.getScheduledObservable(), this); + observer = new ArkObserver<>(edgeChain.getObservable(), this); } public ArkObservable(Observable observable) { diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java new file mode 100644 index 000000000..dc5744a9d --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/AirtableService.java @@ -0,0 +1,29 @@ +package com.edgechain.lib.retrofit; + +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Single; +import retrofit2.http.Body; +import retrofit2.http.HTTP; +import retrofit2.http.POST; + +import java.util.List; +import java.util.Map; + +public interface AirtableService { + + @POST("airtable/findAll") + Single> findAll(@Body AirtableEndpoint endpoint); + + @POST("airtable/findById") + Single findById(@Body AirtableEndpoint endpoint); + + @POST("airtable/create") + Single> create(@Body AirtableEndpoint endpoint); + + @POST("airtable/update") + Single> update(@Body AirtableEndpoint endpoint); + + @HTTP(method = "DELETE", path = "airtable/delete", hasBody = true) + Single> delete(@Body AirtableEndpoint endpoint); +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java index 3f8302378..ad7526fe5 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/BgeSmallService.java @@ -1,7 +1,7 @@ 
package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java new file mode 100644 index 000000000..5f18ace6c --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java @@ -0,0 +1,18 @@ +package com.edgechain.lib.retrofit; + +import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; +import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; +import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; +import io.reactivex.rxjava3.core.Single; +import retrofit2.http.Body; +import retrofit2.http.POST; + +import java.util.List; + +public interface Llama2Service { + @POST(value = "llama/chat-completion") + Single> chatCompletion(@Body Llama2Endpoint llama2Endpoint); + + @POST(value = "llama/chat-completion") + Single llamaCompletion(@Body LLamaQuickstart lLamaQuickstart); +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java index a0cfaff77..96e09ba1b 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/MiniLMService.java @@ -2,7 +2,7 @@ import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java index c1bd2e116..814b815a1 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/OpenAiService.java @@ -1,9 +1,10 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.response.ChatCompletionResponse; -import io.reactivex.rxjava3.core.Completable; +import com.edgechain.lib.openai.response.CompletionResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.POST; @@ -11,11 +12,11 @@ public interface OpenAiService { @POST(value = "openai/chat-completion") - Single chatCompletion(@Body OpenAiEndpoint openAiEndpoint); + Single chatCompletion(@Body OpenAiChatEndpoint OpenAiChatEndpoint); @POST(value = "openai/completion") - Single completion(@Body OpenAiEndpoint openAiEndpoint); + Single completion(@Body OpenAiChatEndpoint openAiChatEndpoint); @POST(value = "openai/embeddings") - Single embeddings(@Body OpenAiEndpoint openAiEndpoint); + Single embeddings(@Body 
OpenAiEmbeddingEndpoint openAiEmbeddingEndpoint); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java index 9d5282377..0a74c27dc 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PineconeService.java @@ -1,7 +1,7 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; @@ -15,6 +15,9 @@ public interface PineconeService { @POST(value = "index/pinecone/upsert") Single upsert(@Body PineconeEndpoint pineconeEndpoint); + @POST(value = "index/pinecone/batch-upsert") + Single batchUpsert(@Body PineconeEndpoint pineconeEndpoint); + @POST(value = "index/pinecone/query") Single> query(@Body PineconeEndpoint pineconeEndpoint); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java index b1b8325d9..d4b28f3dc 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgreSQLContextService.java @@ -2,7 +2,7 @@ import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import retrofit2.http.*; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java index bdbd229b7..a8701d952 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/PostgresService.java @@ -1,6 +1,6 @@ package com.edgechain.lib.retrofit; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Single; @@ -12,13 +12,46 @@ public interface PostgresService { + @POST(value = "index/postgres/create-table") + Single createTable(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/metadata/create-table") + Single createMetadataTable(@Body PostgresEndpoint postgresEndpoint); + @POST(value = "index/postgres/upsert") Single upsert(@Body PostgresEndpoint postgresEndpoint); - // + @POST(value = "index/postgres/batch-upsert") + Single> batchUpsert(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/metadata/insert") + Single insertMetadata(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/metadata/batch-insert") + Single> batchInsertMetadata(@Body PostgresEndpoint postgresEndpoint); + + 
@POST(value = "index/postgres/join/insert") + Single insertIntoJoinTable(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/join/batch-insert") + Single batchInsertIntoJoinTable(@Body PostgresEndpoint postgresEndpoint); + @POST(value = "index/postgres/query") Single> query(@Body PostgresEndpoint postgresEndpoint); + @POST(value = "index/postgres/query-rrf") + Single> queryRRF(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/metadata/query") + Single> queryWithMetadata(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/chunks") + Single> getAllChunks(@Body PostgresEndpoint postgresEndpoint); + + @POST(value = "index/postgres/similarity-metadata") + Single> getSimilarMetadataChunk( + @Body PostgresEndpoint postgresEndpoint); + @POST("index/postgres/probes") Single probes(@Body PostgresEndpoint postgresEndpoint); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java index bc327e916..faa4843e0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisContextService.java @@ -3,7 +3,7 @@ import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import retrofit2.http.*; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java index 513fb5dc5..b29ba59f7 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/RedisService.java @@ -1,25 +1,30 @@ package com.edgechain.lib.retrofit; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.response.StringResponse; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; import retrofit2.http.HTTP; import retrofit2.http.POST; -import retrofit2.http.Query; import java.util.List; public interface RedisService { + @POST(value = "index/redis/create-index") + Single createIndex(@Body RedisEndpoint redisEndpoint); + @POST(value = "index/redis/upsert") Single upsert(@Body RedisEndpoint redisEndpoint); + @POST(value = "index/redis/batch-upsert") + Single batchUpsert(@Body RedisEndpoint redisEndpoint); + @POST(value = "index/redis/query") Single> query(@Body RedisEndpoint redisEndpoint); @HTTP(method = "DELETE", path = "index/redis/delete", hasBody = true) - Completable deleteByPattern(@Query("pattern") String pattern, @Body RedisEndpoint redisEndpoint); + Completable deleteByPattern(@Body RedisEndpoint redisEndpoint); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java index d5b2d33bf..6b24c965f 100644 --- 
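The retrofit interfaces above (PostgresService, RedisService, and the rest) are plain Retrofit declarations, so a caller obtains them from RetrofitClientInstance and then blocks on or subscribes to the returned Single. A minimal sketch, assuming the PostgresEndpoint is built elsewhere (its constructors are not part of this diff) and that the stripped generics are StringResponse and PostgresWordEmbeddings as in the service controllers further below:

  // Hypothetical helper showing how the new PostgresService endpoints are reached.
  static void createAndQuery(PostgresEndpoint endpoint) {
    PostgresService service =
        RetrofitClientInstance.getInstance().create(PostgresService.class);

    // blockingGet() keeps the sketch short; real callers would stay reactive.
    StringResponse created = service.createTable(endpoint).blockingGet();
    List<PostgresWordEmbeddings> hits = service.query(endpoint).blockingGet();

    System.out.println(created + " -> " + hits.size() + " matches");
  }
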
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/WikiService.java @@ -1,6 +1,6 @@ package com.edgechain.lib.retrofit; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.wiki.response.WikiResponse; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java index 4bdab038e..4392bb348 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/OpenAiStreamService.java @@ -2,11 +2,10 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.configuration.domain.SecurityUUID; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.response.ChatCompletionResponse; import com.edgechain.lib.utils.JsonUtils; import io.reactivex.rxjava3.core.Observable; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -22,10 +21,7 @@ public class OpenAiStreamService { @Autowired private SecurityUUID securityUUID; - public Observable chatCompletion(OpenAiEndpoint openAiEndpoint) { - - logger.info("Logging Chat Completion Stream...."); - logger.info("Prompt: " + StringUtils.join(openAiEndpoint.getChatMessages())); + public Observable chatCompletion(OpenAiChatEndpoint endpoint) { return RxJava3Adapter.fluxToObservable( WebClient.builder() @@ -43,7 +39,7 @@ public Observable chatCompletion(OpenAiEndpoint openAiEn httpHeaders.set("stream", "true"); httpHeaders.set("Authorization", securityUUID.getAuthKey()); }) - .bodyValue(JsonUtils.convertToString(openAiEndpoint)) + .bodyValue(JsonUtils.convertToString(endpoint)) .retrieve() .bodyToFlux(ChatCompletionResponse.class)); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java index f2b7c6b4b..b0fa2921d 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/client/RetrofitClientInstance.java @@ -13,6 +13,12 @@ import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import okhttp3.ConnectionPool; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; @@ -21,23 +27,25 @@ import retrofit2.adapter.rxjava3.RxJava3CallAdapterFactory; import retrofit2.converter.jackson.JacksonConverterFactory; -import java.lang.reflect.Type; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.TimeUnit; - public class RetrofitClientInstance { - private static final String BASE_URL = "http://0.0.0.0"; + private 
RetrofitClientInstance() { + // no + } - private static SecurityUUID securityUUID = - ApplicationContextHolder.getContext().getBean(SecurityUUID.class); + private static final String BASE_URL = "http://0.0.0.0"; + private static SecurityUUID securityUUID; private static Retrofit retrofit; public static Retrofit getInstance() { if (retrofit == null) { + // tests may set this to a mock - do not overwrite it if present. + if (securityUUID == null) { + securityUUID = ApplicationContextHolder.getContext().getBean(SecurityUUID.class); + } + return retrofit = new Retrofit.Builder() .baseUrl( @@ -46,43 +54,49 @@ public static Retrofit getInstance() { + System.getProperty("server.port") + WebConfiguration.CONTEXT_PATH + "/") - .addConverterFactory(JacksonBuilder()) + .addConverterFactory(createJacksonFactory()) .addCallAdapterFactory(RxJava3CallAdapterFactory.create()) .client( new OkHttpClient.Builder() + .connectionPool(new ConnectionPool(10, 5, TimeUnit.MINUTES)) .addInterceptor( chain -> { - Request original = chain.request(); - Request request = - original - .newBuilder() - .header("Authorization", securityUUID.getAuthKey()) - .build(); - Response response = chain.proceed(request); - String body = response.body().string(); + try { + Request original = chain.request(); + final String authKey = securityUUID.getAuthKey(); + Request request = + original.newBuilder().header("Authorization", authKey).build(); + Response response = chain.proceed(request); + String body = response.body().string(); - String errorMessage = ""; + String errorMessage = ""; - if (!response.isSuccessful()) { - // Create a new Gson object - Gson gson = new Gson(); + if (!response.isSuccessful()) { + // Create a new Gson object + Gson gson = new Gson(); - // Define the type for the map - Type type = new TypeToken>() {}.getType(); + // Define the type for the map + Type type = new TypeToken>() {}.getType(); - // Convert JSON string into a map - Map map = gson.fromJson(body, type); + // Convert JSON string into a map + Map map = gson.fromJson(body, type); - if (Objects.nonNull(map)) { - errorMessage = map.toString(); + if (Objects.nonNull(map)) { + errorMessage = map.toString(); + } } - } - return response - .newBuilder() - .body(ResponseBody.create(body, response.body().contentType())) - .message(errorMessage) - .build(); + return response + .newBuilder() + .body(ResponseBody.create(body, response.body().contentType())) + .message(errorMessage) + .build(); + } catch (Exception e) { + // Interceptor can handle only IOException. Anything else = stall. + // Solution: wrap any exception in an IOException. 
+ // Read more here: https://github.com/square/retrofit/issues/3453 + throw new IOException(e); + } }) .connectTimeout(15, TimeUnit.MINUTES) .readTimeout(20, TimeUnit.MINUTES) @@ -92,7 +106,7 @@ public static Retrofit getInstance() { return retrofit; } - private static JacksonConverterFactory JacksonBuilder() { + private static JacksonConverterFactory createJacksonFactory() { ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); objectMapper.registerModule(new JavaTimeModule()); @@ -100,7 +114,6 @@ private static JacksonConverterFactory JacksonBuilder() { objectMapper.registerModule(new Jdk8Module()); objectMapper.registerModule(new PageJacksonModule()); objectMapper.registerModule(new SortJacksonModule()); - objectMapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); return JacksonConverterFactory.create(objectMapper); } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/logger/JsonnetLoggerService.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/logger/JsonnetLoggerService.java new file mode 100644 index 000000000..85bf99e7f --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/logger/JsonnetLoggerService.java @@ -0,0 +1,24 @@ +package com.edgechain.lib.retrofit.logger; + +import com.edgechain.lib.logger.entities.JsonnetLog; +import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; +import org.springframework.data.domain.Page; +import retrofit2.http.Body; +import retrofit2.http.GET; +import retrofit2.http.POST; +import retrofit2.http.Path; + +public interface JsonnetLoggerService { + + @GET(value = "logs/jsonnet/findAll/{page}/{size}") + Single> findAll(@Path("page") int page, @Path("size") int size); + + @GET(value = "logs/jsonnet/findAll/sorted/{page}/{size}") + Single> findAllOrderByCreatedAtDesc( + @Path("page") int page, @Path("size") int size); + + @POST(value = "logs/jsonnet/findByName/sorted/{page}/{size}") + Single> findAllBySelectedFileOrderByCreatedAtDesc( + @Body HashMap mapper, @Path("page") int page, @Path("size") int size); +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/AbstractEdgeChain.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/AbstractEdgeChain.java index 2546488aa..f4e3cdb99 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/AbstractEdgeChain.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/AbstractEdgeChain.java @@ -2,10 +2,19 @@ import com.edgechain.lib.rxjava.retry.RetryPolicy; import io.reactivex.rxjava3.annotations.NonNull; -import io.reactivex.rxjava3.core.*; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Notification; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.core.ObservableSource; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; -import io.reactivex.rxjava3.functions.*; - +import io.reactivex.rxjava3.functions.Action; +import io.reactivex.rxjava3.functions.BiFunction; +import io.reactivex.rxjava3.functions.BooleanSupplier; +import io.reactivex.rxjava3.functions.Consumer; +import io.reactivex.rxjava3.functions.Function; +import io.reactivex.rxjava3.functions.Predicate; import java.io.Serializable; public 
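The interceptor change in RetrofitClientInstance above wraps any non-IO failure in an IOException, because OkHttp interceptors may only throw IOException and anything else stalls the call (see the linked retrofit issue). The same pattern in isolation, as a stand-alone sketch with a placeholder header value:

  import java.io.IOException;
  import okhttp3.Interceptor;
  import okhttp3.Request;
  import okhttp3.Response;

  // Minimal sketch of the wrap-in-IOException pattern used above.
  class WrapExceptionsInterceptor implements Interceptor {
    @Override
    public Response intercept(Chain chain) throws IOException {
      try {
        Request request =
            chain.request().newBuilder()
                .header("Authorization", "placeholder-auth-key") // the diff resolves this from SecurityUUID
                .build();
        return chain.proceed(request);
      } catch (IOException e) {
        throw e; // already legal for an interceptor
      } catch (Exception e) {
        throw new IOException(e); // anything else would otherwise hang the Retrofit call
      }
    }
  }
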
abstract class AbstractEdgeChain implements Serializable { @@ -14,7 +23,7 @@ public abstract class AbstractEdgeChain implements Serializable { protected Observable observable; - public AbstractEdgeChain(Observable observable) { + protected AbstractEdgeChain(Observable observable) { this.observable = observable; } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java index a030780d5..7e3f42cb6 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/rxjava/transformer/observable/EdgeChain.java @@ -178,6 +178,13 @@ public Single toSingle() { else return this.observable.subscribeOn(Schedulers.io()).firstOrError(); } + public Single toSingleWithoutScheduler() { + + if (RetryUtils.available(endpoint)) + return this.observable.retryWhen(endpoint.getRetryPolicy()).firstOrError(); + else return this.observable.firstOrError(); + } + @Override public T get() { if (RetryUtils.available(endpoint)) @@ -220,10 +227,10 @@ public void completed(Action onComplete, Consumer onError) { } public ArkResponse getArkResponse() { - return new ArkObservable<>(this.observable.subscribeOn(Schedulers.io())); + return new ArkObservable<>(this.observable); } public ArkResponse getArkStreamResponse() { - return new ArkEmitter<>(this.observable.subscribeOn(Schedulers.io())); + return new ArkEmitter<>(this.observable); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtFilter.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtFilter.java index 9fe5614ca..480b62fc9 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtFilter.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtFilter.java @@ -13,7 +13,6 @@ import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.Objects; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; @@ -60,6 +59,9 @@ protected void doFilterInternal( try { claimsJws = jwtHelper.parseToken(token); } catch (final Exception e) { + // use Spring Security logger here instead of SLF4J + logger.info("JWT not accepted: %s".formatted(e.getMessage())); + ErrorResponse errorResponse = new ErrorResponse(e.getMessage()); response.setContentType(MediaType.APPLICATION_JSON_VALUE); response.getWriter().print(JsonUtils.convertToString(errorResponse)); @@ -70,6 +72,9 @@ protected void doFilterInternal( String email = (String) claimsJws.getBody().get("email"); String role = (String) claimsJws.getBody().get("role"); + // use Spring Security logger here instead of SLF4J + logger.info("JWT email=%s role=%s".formatted(email, role)); + User user = new User(); user.setEmail(email); user.setAccessToken(token); diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtHelper.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtHelper.java index a97d6112b..2ff8fcf58 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtHelper.java +++ 
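EdgeChain above gains toSingleWithoutScheduler() next to toSingle(), and getArkResponse()/getArkStreamResponse() no longer force a subscribeOn(Schedulers.io()). A small sketch of the distinction (the generic type and helper are illustrative):

  // Hypothetical helper: same retry behaviour, different threading.
  static <T> Single<T> resolve(EdgeChain<T> chain, boolean offloadToIoScheduler) {
    return offloadToIoScheduler
        ? chain.toSingle()                  // subscribes on Schedulers.io()
        : chain.toSingleWithoutScheduler(); // no scheduler hop; several controllers below
                                            // use this variant for the batch endpoints
  }
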
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/JwtHelper.java @@ -2,6 +2,7 @@ import io.jsonwebtoken.*; import java.security.Key; +import java.util.Objects; import javax.crypto.spec.SecretKeySpec; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.env.Environment; @@ -14,10 +15,10 @@ public class JwtHelper { public Jws parseToken(String accessToken) { try { - Key hmacKey = - new SecretKeySpec( - env.getProperty("jwt.secret").getBytes(), SignatureAlgorithm.HS256.getJcaName()); - + final String secret = env.getProperty("jwt.secret"); + Objects.requireNonNull(secret, "JWT secret not set"); + final byte[] bytes = secret.getBytes(); + final Key hmacKey = new SecretKeySpec(bytes, SignatureAlgorithm.HS256.getJcaName()); return Jwts.parser().setSigningKey(hmacKey).parseClaimsJws(accessToken); } catch (MalformedJwtException e) { @@ -33,27 +34,12 @@ public Jws parseToken(String accessToken) { } } - // validate public boolean validate(String accessToken) { try { - Key hmacKey = - new SecretKeySpec( - env.getProperty("jwt.secret").getBytes(), SignatureAlgorithm.HS256.getJcaName()); - - // String encoded = - // Base64.getEncoder().encodeToString(this.supabaseEnv.getJwtSecret().getBytes()); - Jwts.parser().setSigningKey(hmacKey).parseClaimsJws(accessToken); + parseToken(accessToken); return true; - } catch (MalformedJwtException e) { - throw new JwtException("Token Malformed"); - } catch (UnsupportedJwtException e) { - throw new JwtException("Token Unsupported"); - } catch (ExpiredJwtException e) { - throw new JwtException("Token Expired"); - } catch (IllegalArgumentException e) { - throw new JwtException("Token Empty"); - } catch (SignatureException e) { - throw new JwtException("Token Signature Failed"); + } catch (JwtException e) { + return false; } } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/WebSecurity.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/WebSecurity.java index cd35fabc8..bc6494253 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/WebSecurity.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/supabase/security/WebSecurity.java @@ -2,6 +2,9 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.configuration.domain.AuthFilter; +import java.util.Arrays; +import java.util.Objects; +import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.context.annotation.Bean; @@ -15,6 +18,7 @@ import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configurers.AuthorizeHttpRequestsConfigurer; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; @@ -25,9 +29,6 @@ import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.filter.CorsFilter; -import java.util.Arrays; -import java.util.Objects; - @EnableWebSecurity @EnableMethodSecurity @Configuration @@ 
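With the JwtHelper cleanup above, validate() now delegates to parseToken() and converts any JwtException into false, while parseToken() still throws with a specific reason (malformed, unsupported, expired, empty, or bad signature). A short usage sketch, assuming parseToken() returns Jws<Claims> as its use in JwtFilter suggests:

  // Hypothetical caller: cheap boolean check first, detailed claims only for valid tokens.
  static boolean hasRole(JwtHelper jwtHelper, String accessToken, String expectedRole) {
    if (!jwtHelper.validate(accessToken)) {
      return false; // validate() swallows the JwtException and reports a boolean
    }
    Jws<Claims> claims = jwtHelper.parseToken(accessToken); // still throws a reasoned JwtException on bad tokens
    return expectedRole.equals(claims.getBody().get("role"));
  }
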
-38,62 +39,94 @@ public class WebSecurity { @Autowired private JwtFilter jwtFilter; @Bean - public AuthenticationManager authenticationManager(AuthenticationConfiguration config) - throws Exception { + AuthenticationManager authenticationManager(AuthenticationConfiguration config) throws Exception { return config.getAuthenticationManager(); } @Bean - public PasswordEncoder passwordEncoder() { + PasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } @Bean - public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { - - return http.cors() - .configurationSource(corsConfiguration()) - .and() - .csrf() - .disable() - .authorizeHttpRequests( - (auth) -> { - try { - auth.requestMatchers("" + WebConfiguration.CONTEXT_PATH + "/**") - .permitAll() - .requestMatchers(HttpMethod.POST, authFilter.getRequestPost().getRequests()) - .hasAnyAuthority(authFilter.getRequestPost().getAuthorities()) - // - .requestMatchers(HttpMethod.GET, authFilter.getRequestGet().getRequests()) - .hasAnyAuthority(authFilter.getRequestGet().getAuthorities()) - .requestMatchers(HttpMethod.DELETE, authFilter.getRequestDelete().getRequests()) - .hasAnyAuthority(authFilter.getRequestDelete().getAuthorities()) - .requestMatchers(HttpMethod.PUT, authFilter.getRequestPut().getRequests()) - .hasAnyAuthority(authFilter.getRequestPut().getAuthorities()) - .requestMatchers(HttpMethod.PATCH, authFilter.getRequestPatch().getRequests()) - .hasAnyAuthority(authFilter.getRequestPatch().getAuthorities()) - .anyRequest() - .permitAll(); - - } catch (Exception e) { - throw new RuntimeException(e); - } - }) + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + + http.cors(cors -> cors.configurationSource(corsConfiguration())) + .csrf(csrf -> csrf.disable()) + .authorizeHttpRequests(auth -> buildAuth(auth)) .httpBasic(Customizer.withDefaults()) .addFilterBefore(jwtFilter, UsernamePasswordAuthenticationFilter.class) - .sessionManagement() - .sessionCreationPolicy(SessionCreationPolicy.STATELESS) - .and() - .exceptionHandling() - .and() + .sessionManagement( + management -> management.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) + .exceptionHandling(Customizer.withDefaults()) .securityContext(c -> c.requireExplicitSave(false)) - .formLogin() - .disable() - .build(); + .formLogin(login -> login.disable()); + return http.build(); + } + + private AuthorizeHttpRequestsConfigurer.AuthorizationManagerRequestMatcherRegistry + buildAuth( + AuthorizeHttpRequestsConfigurer.AuthorizationManagerRequestMatcherRegistry + auth) { + AuthorizeHttpRequestsConfigurer.AuthorizationManagerRequestMatcherRegistry reg = + auth.requestMatchers("" + WebConfiguration.CONTEXT_PATH + "/**").permitAll(); + + reg = + applyAuth( + reg.requestMatchers( + HttpMethod.POST, safeRequests(authFilter.getRequestPost().getRequests(), "POST")), + authFilter.getRequestPost().getAuthorities()); + reg = + applyAuth( + reg.requestMatchers( + HttpMethod.GET, safeRequests(authFilter.getRequestGet().getRequests(), "GET")), + authFilter.getRequestGet().getAuthorities()); + reg = + applyAuth( + reg.requestMatchers( + HttpMethod.DELETE, + safeRequests(authFilter.getRequestDelete().getRequests(), "DELETE")), + authFilter.getRequestDelete().getAuthorities()); + reg = + applyAuth( + reg.requestMatchers( + HttpMethod.PUT, safeRequests(authFilter.getRequestPut().getRequests(), "PUT")), + authFilter.getRequestPut().getAuthorities()); + reg = + applyAuth( + reg.requestMatchers( + HttpMethod.PATCH, + 
safeRequests(authFilter.getRequestPatch().getRequests(), "PATCH")), + authFilter.getRequestPatch().getAuthorities()); + + reg = reg.anyRequest().permitAll(); + return reg; + } + + private String[] safeRequests(String[] src, String method) { + if (src == null || src.length == 0 || (src.length == 1 && src[0].isEmpty())) { + LoggerFactory.getLogger(getClass()) + .warn( + "Http {} security request patterns outdated. Fixed to a list with one String \"**\" -" + + " please update your configuration", + method); + return new String[] {"**"}; + } else { + return src; + } + } + + private AuthorizeHttpRequestsConfigurer.AuthorizationManagerRequestMatcherRegistry + applyAuth(AuthorizeHttpRequestsConfigurer.AuthorizedUrl url, String[] auths) { + if (auths == null || auths.length == 0 || (auths.length == 1 && auths[0].isEmpty())) { + return url.permitAll(); + } else { + return url.hasAnyAuthority(auths); + } } @Bean - public CorsConfigurationSource corsConfiguration() { + CorsConfigurationSource corsConfiguration() { CorsConfiguration configuration = new CorsConfiguration(); @@ -126,7 +159,7 @@ public CorsConfigurationSource corsConfiguration() { } @Bean - public FilterRegistrationBean corsFilter() { + FilterRegistrationBean corsFilter() { FilterRegistrationBean bean = new FilterRegistrationBean<>(new CorsFilter(corsConfiguration())); bean.setOrder(Ordered.HIGHEST_PRECEDENCE); @@ -134,7 +167,7 @@ public FilterRegistrationBean corsFilter() { } @Bean - public FilterRegistrationBean jwtFilterFilterRegistrationBean(JwtFilter jwtFilter) { + FilterRegistrationBean jwtFilterFilterRegistrationBean(JwtFilter jwtFilter) { FilterRegistrationBean registrationBean = new FilterRegistrationBean<>(jwtFilter); registrationBean.setEnabled(false); return registrationBean; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java new file mode 100644 index 000000000..a3fb83282 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/utils/ContextReorder.java @@ -0,0 +1,50 @@ +package com.edgechain.lib.utils; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +@Component +public class ContextReorder { + + public List reorderWordEmbeddings(List wordEmbeddingsList) { + + wordEmbeddingsList.sort(Comparator.comparingDouble(WordEmbeddings::getScore).reversed()); + + int mid = wordEmbeddingsList.size() / 2; + + List modifiedList = new ArrayList<>(wordEmbeddingsList.subList(0, mid)); + + List secondHalfList = + wordEmbeddingsList.subList(mid, wordEmbeddingsList.size()); + secondHalfList.sort(Comparator.comparingDouble(WordEmbeddings::getScore)); + + modifiedList.addAll(secondHalfList); + + return modifiedList; + } + + public List reorderPostgresWordEmbeddings( + List postgresWordEmbeddings) { + + postgresWordEmbeddings.sort( + Comparator.comparingDouble(PostgresWordEmbeddings::getScore).reversed()); + + int mid = postgresWordEmbeddings.size() / 2; + + List modifiedList = + new ArrayList<>(postgresWordEmbeddings.subList(0, mid)); + + List secondHalfList = + postgresWordEmbeddings.subList(mid, postgresWordEmbeddings.size()); + secondHalfList.sort(Comparator.comparingDouble(PostgresWordEmbeddings::getScore)); + + modifiedList.addAll(secondHalfList); + + return modifiedList; + } +} 
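The ContextReorder utility above sorts by score descending, keeps the top half in that order, and appends the bottom half re-sorted ascending, so the strongest chunks end up at both ends of the context and the weakest land in the middle. The same ordering traced with plain scores (a stand-alone illustration, not the actual WordEmbeddings type):

  import java.util.ArrayList;
  import java.util.Comparator;
  import java.util.List;

  class ContextReorderTrace {
    public static void main(String[] args) {
      List<Double> scores = new ArrayList<>(List.of(0.5, 0.9, 0.1, 0.7, 0.3));
      scores.sort(Comparator.reverseOrder());                                  // [0.9, 0.7, 0.5, 0.3, 0.1]
      int mid = scores.size() / 2;                                             // 2
      List<Double> reordered = new ArrayList<>(scores.subList(0, mid));        // [0.9, 0.7]
      List<Double> tail = new ArrayList<>(scores.subList(mid, scores.size())); // [0.5, 0.3, 0.1]
      tail.sort(Comparator.naturalOrder());                                    // [0.1, 0.3, 0.5]
      reordered.addAll(tail);
      System.out.println(reordered); // [0.9, 0.7, 0.1, 0.3, 0.5] - best chunks at both ends
    }
  }
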
diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java index bd647cd72..677b543c4 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/wiki/client/WikiClient.java @@ -1,6 +1,6 @@ package com.edgechain.lib.wiki.client; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.edgechain.lib.wiki.response.WikiResponse; import com.fasterxml.jackson.databind.JsonNode; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java index 41210f1cc..57e23b21a 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/bgeSmall/BgeSmallController.java @@ -3,7 +3,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.bgeSmall.BgeSmallClient; import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse; -import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.services.EmbeddingLogService; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; @@ -32,10 +32,8 @@ public class BgeSmallController { @PostMapping public Single embeddings(@RequestBody BgeSmallEndpoint bgeSmallEndpoint) { - this.bgeSmallClient.setEndpoint(bgeSmallEndpoint); - EdgeChain edgeChain = - this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getInput()); + this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getRawText(), bgeSmallEndpoint); if (Objects.nonNull(env.getProperty("postgres.db.host"))) { @@ -53,9 +51,9 @@ public Single embeddings(@RequestBody BgeSmallEndpoint bgeSmal embeddingLog.setLatency(duration.toMillis()); embeddingLogService.saveOrUpdate(embeddingLog); }) - .toSingle(); + .toSingleWithoutScheduler(); } - return edgeChain.toSingle(); + return edgeChain.toSingleWithoutScheduler(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java index cf5d8d16b..fce8831c0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/PostgreSQLHistoryContextController.java @@ -4,7 +4,7 @@ import com.edgechain.lib.context.client.impl.PostgreSQLHistoryContextClient; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.PostgreSQLHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.PostgreSQLHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import 
org.springframework.beans.factory.annotation.Autowired; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java index ce97d689a..1a4c4a402 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/context/RedisHistoryContextController.java @@ -4,7 +4,7 @@ import com.edgechain.lib.context.client.impl.RedisHistoryContextClient; import com.edgechain.lib.context.domain.ContextPutRequest; import com.edgechain.lib.context.domain.HistoryContext; -import com.edgechain.lib.endpoint.impl.RedisHistoryContextEndpoint; +import com.edgechain.lib.endpoint.impl.context.RedisHistoryContextEndpoint; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import org.springframework.beans.factory.annotation.Autowired; diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java index 8a7a8a527..98d3cd9f3 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PineconeController.java @@ -2,10 +2,9 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.PineconeEndpoint; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; import com.edgechain.lib.index.client.impl.PineconeClient; import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Single; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -20,32 +19,21 @@ public class PineconeController { @PostMapping("/upsert") public Single upsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.upsert(pineconeEndpoint).toSingle(); + } - pineconeClient.setEndpoint(pineconeEndpoint); - - EdgeChain edgeChain = - pineconeClient.upsert(pineconeEndpoint.getWordEmbeddings()); - - return edgeChain.toSingle(); + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody PineconeEndpoint pineconeEndpoint) { + return pineconeClient.batchUpsert(pineconeEndpoint).toSingleWithoutScheduler(); } @PostMapping("/query") public Single> query(@RequestBody PineconeEndpoint pineconeEndpoint) { - - pineconeClient.setEndpoint(pineconeEndpoint); - - EdgeChain> edgeChain = - pineconeClient.query(pineconeEndpoint.getWordEmbeddings(), pineconeEndpoint.getTopK()); - - return edgeChain.toSingle(); + return pineconeClient.query(pineconeEndpoint).toSingle(); } @DeleteMapping("/deleteAll") public Single deleteAll(@RequestBody PineconeEndpoint pineconeEndpoint) { - - pineconeClient.setEndpoint(pineconeEndpoint); - - EdgeChain edgeChain = pineconeClient.deleteAll(); - return edgeChain.toSingle(); + return pineconeClient.deleteAll(pineconeEndpoint).toSingle(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java index 58491b687..a83f48db3 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/PostgresController.java @@ -1,7 +1,7 @@ package com.edgechain.service.controllers.index; import com.edgechain.lib.configuration.WebConfiguration; -import com.edgechain.lib.endpoint.impl.PostgresEndpoint; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; import com.edgechain.lib.index.client.impl.PostgresClient; import com.edgechain.lib.index.domain.PostgresWordEmbeddings; import com.edgechain.lib.response.StringResponse; @@ -20,38 +20,85 @@ public class PostgresController { @Autowired @Lazy private PostgresClient postgresClient; + @PostMapping("/create-table") + public Single createTable(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createTable(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/create-table") + public Single createMetadataTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.createMetadataTable(postgresEndpoint).toSingle(); + } + @PostMapping("/upsert") public Single upsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.upsert(postgresEndpoint).toSingle(); + } - this.postgresClient.setPostgresEndpoint(postgresEndpoint); - EdgeChain edgeChain = - this.postgresClient.upsert(postgresEndpoint.getWordEmbeddings()); + @PostMapping("/batch-upsert") + public Single> batchUpsert(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchUpsert(postgresEndpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/metadata/insert") + public Single insertMetadata(@RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.insertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/metadata/batch-insert") + public Single> batchInsertMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.batchInsertMetadata(postgresEndpoint).toSingle(); + } + + @PostMapping("/join/insert") + public Single insertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = this.postgresClient.insertIntoJoinTable(postgresEndpoint); + return edgeChain.toSingle(); + } + @PostMapping("/join/batch-insert") + public Single batchInsertIntoJoinTable( + @RequestBody PostgresEndpoint postgresEndpoint) { + EdgeChain edgeChain = + this.postgresClient.batchInsertIntoJoinTable(postgresEndpoint); return edgeChain.toSingle(); } @PostMapping("/query") public Single> query( @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.query(postgresEndpoint).toSingle(); + } + + @PostMapping("/query-rrf") + public Single> queryRRF( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryRRF(postgresEndpoint).toSingle(); + } - this.postgresClient.setPostgresEndpoint(postgresEndpoint); + @PostMapping("/metadata/query") + public Single> queryWithMetadata( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.queryWithMetadata(postgresEndpoint).toSingle(); + } - EdgeChain> edgeChain = - this.postgresClient.query( - postgresEndpoint.getWordEmbeddings(), - postgresEndpoint.getMetric(), - postgresEndpoint.getTopK(), - postgresEndpoint.getProbes()); + @PostMapping("/chunks") + public 
Single> getAllChunks( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getAllChunks(postgresEndpoint).toSingle(); + } - return edgeChain.toSingle(); + @PostMapping("/similarity-metadata") + public Single> getSimilarMetadataChunk( + @RequestBody PostgresEndpoint postgresEndpoint) { + return this.postgresClient.getSimilarMetadataChunk(postgresEndpoint).toSingle(); } @DeleteMapping("/deleteAll") public Single deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) { - - this.postgresClient.setPostgresEndpoint(postgresEndpoint); - - EdgeChain edgeChain = this.postgresClient.deleteAll(); - return edgeChain.toSingle(); + return this.postgresClient.deleteAll(postgresEndpoint).toSingle(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java index 6fcd2e566..1d4916f38 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/index/RedisController.java @@ -2,10 +2,9 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.WordEmbeddings; -import com.edgechain.lib.endpoint.impl.RedisEndpoint; +import com.edgechain.lib.endpoint.impl.index.RedisEndpoint; import com.edgechain.lib.index.client.impl.RedisClient; import com.edgechain.lib.response.StringResponse; -import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import org.springframework.beans.factory.annotation.Autowired; @@ -20,38 +19,28 @@ public class RedisController { @Autowired @Lazy private RedisClient redisClient; + @PostMapping("/create-index") + public Single createIndex(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.createIndex(redisEndpoint).toSingle(); + } + @PostMapping("/upsert") public Single upsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.upsert(redisEndpoint).toSingle(); + } - this.redisClient.setEndpoint(redisEndpoint); - - EdgeChain edgeChain = - this.redisClient.upsert( - redisEndpoint.getWordEmbeddings(), - redisEndpoint.getDimensions(), - redisEndpoint.getMetric()); - - return edgeChain.toSingle(); + @PostMapping("/batch-upsert") + public Single batchUpsert(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.batchUpsert(redisEndpoint).toSingleWithoutScheduler(); } @PostMapping("/query") public Single> query(@RequestBody RedisEndpoint redisEndpoint) { - - this.redisClient.setEndpoint(redisEndpoint); - - EdgeChain> edgeChain = - this.redisClient.query(redisEndpoint.getWordEmbeddings(), redisEndpoint.getTopK()); - - return edgeChain.toSingle(); + return this.redisClient.query(redisEndpoint).toSingle(); } @DeleteMapping("/delete") - public Completable deleteByPattern( - @RequestParam("pattern") String pattern, @RequestBody RedisEndpoint redisEndpoint) { - - this.redisClient.setEndpoint(redisEndpoint); - - EdgeChain edgeChain = this.redisClient.deleteByPattern(pattern); - return edgeChain.await(); + public Completable deleteByPattern(@RequestBody RedisEndpoint redisEndpoint) { + return this.redisClient.deleteByPattern(redisEndpoint).await(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java new file mode 100644 index 000000000..741fe2362 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/integration/AirtableController.java @@ -0,0 +1,44 @@ +package com.edgechain.service.controllers.integration; + +import com.edgechain.lib.configuration.WebConfiguration; +import com.edgechain.lib.endpoint.impl.integration.AirtableEndpoint; +import com.edgechain.lib.integration.airtable.client.AirtableClient; +import dev.fuxing.airtable.AirtableRecord; +import io.reactivex.rxjava3.core.Single; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; + +import java.util.List; +import java.util.Map; + +@RestController("Service AirtableController") +@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/airtable") +public class AirtableController { + + @Autowired private AirtableClient airtableClient; + + @PostMapping("/findAll") + public Single> findAll(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findAll(endpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/findById") + public Single findById(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.findById(endpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/create") + public Single> create(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.create(endpoint).toSingleWithoutScheduler(); + } + + @PostMapping("/update") + public Single> update(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.update(endpoint).toSingleWithoutScheduler(); + } + + @DeleteMapping("/delete") + public Single> delete(@RequestBody AirtableEndpoint endpoint) { + return airtableClient.delete(endpoint).toSingleWithoutScheduler(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java new file mode 100644 index 000000000..d998c28bf --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java @@ -0,0 +1,30 @@ +package com.edgechain.service.controllers.llama2; + +import com.edgechain.lib.configuration.WebConfiguration; +import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; +import com.edgechain.lib.llama2.Llama2Client; +import com.edgechain.lib.logger.services.ChatCompletionLogService; +import com.edgechain.lib.logger.services.JsonnetLogService; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import io.reactivex.rxjava3.core.Single; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.env.Environment; +import org.springframework.web.bind.annotation.*; + +@RestController("Service Llama2Controller") +@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama") +public class Llama2Controller { + @Autowired private ChatCompletionLogService chatCompletionLogService; + + @Autowired private JsonnetLogService jsonnetLogService; + + @Autowired private Environment env; + @Autowired private Llama2Client llama2Client; + + @PostMapping(value = "/chat-completion") + public Single getChatCompletion(@RequestBody LLamaQuickstart endpoint) { + + EdgeChain edgeChain = llama2Client.createGetChatCompletion(endpoint); + return edgeChain.toSingle(); + } +} diff --git 
a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/EmbeddingLogController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/EmbeddingLogController.java index d01d214b5..3cad9c5b0 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/EmbeddingLogController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/EmbeddingLogController.java @@ -3,13 +3,16 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.services.EmbeddingLogService; -import com.edgechain.lib.logger.services.EmbeddingLogService; +import java.util.HashMap; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; -import org.springframework.web.bind.annotation.*; - -import java.util.HashMap; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; @RestController("Service EmbeddingLogController") @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/logs/embeddings") diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/JsonnetLogController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/JsonnetLogController.java new file mode 100644 index 000000000..80da69078 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/logging/JsonnetLogController.java @@ -0,0 +1,36 @@ +package com.edgechain.service.controllers.logging; + +import com.edgechain.lib.configuration.WebConfiguration; +import com.edgechain.lib.logger.entities.JsonnetLog; +import java.util.HashMap; + +import com.edgechain.lib.logger.services.JsonnetLogService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageRequest; +import org.springframework.web.bind.annotation.*; + +@RestController("Service JsonnetLogControllers") +@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/logs/jsonnet") +public class JsonnetLogController { + + @Autowired private JsonnetLogService jsonnetLogService; + + @GetMapping("/findAll/{page}/{size}") + public Page findAll(@PathVariable int page, @PathVariable int size) { + return this.jsonnetLogService.findAll(PageRequest.of(page, size)); + } + + @GetMapping("/findAll/sorted/{page}/{size}") + public Page findAllByOrderByCreatedAtDesc( + @PathVariable int page, @PathVariable int size) { + return this.jsonnetLogService.findAllOrderByCreatedAtDesc(PageRequest.of(page, size)); + } + + @PostMapping("/findByName/sorted/{page}/{size}") + public Page findAllBySelectedFileOrderByCreatedAtDesc( + @RequestBody HashMap mapper, @PathVariable int page, @PathVariable int size) { + return this.jsonnetLogService.findAllBySelectedFileOrderByCreatedAtDesc( + mapper.get("filename"), PageRequest.of(page, size)); + } +} diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java 
b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java index 56824fc07..9ab3ea75e 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/miniLM/MiniLMController.java @@ -3,7 +3,7 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.miniLLM.MiniLMClient; import com.edgechain.lib.embeddings.miniLLM.response.MiniLMResponse; -import com.edgechain.lib.endpoint.impl.MiniLMEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.MiniLMEndpoint; import com.edgechain.lib.logger.entities.EmbeddingLog; import com.edgechain.lib.logger.services.EmbeddingLogService; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; @@ -32,10 +32,8 @@ public class MiniLMController { @PostMapping public Single embeddings(@RequestBody MiniLMEndpoint miniLMEndpoint) { - this.miniLMClient.setEndpoint(miniLMEndpoint); - EdgeChain edgeChain = - this.miniLMClient.createEmbeddings(miniLMEndpoint.getInput()); + this.miniLMClient.createEmbeddings(miniLMEndpoint.getRawText(), miniLMEndpoint); if (Objects.nonNull(env.getProperty("postgres.db.host"))) { @@ -53,9 +51,9 @@ public Single embeddings(@RequestBody MiniLMEndpoint miniLMEndpo embeddingLog.setLatency(duration.toMillis()); embeddingLogService.saveOrUpdate(embeddingLog); }) - .toSingle(); + .toSingleWithoutScheduler(); } - return edgeChain.toSingle(); + return edgeChain.toSingleWithoutScheduler(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java index addb60116..6d3c6f8a1 100644 --- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/openai/OpenAiController.java @@ -3,11 +3,14 @@ import com.edgechain.lib.configuration.WebConfiguration; import com.edgechain.lib.embeddings.request.OpenAiEmbeddingRequest; import com.edgechain.lib.embeddings.response.OpenAiEmbeddingResponse; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.embeddings.OpenAiEmbeddingEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.logger.entities.ChatCompletionLog; import com.edgechain.lib.logger.entities.EmbeddingLog; +import com.edgechain.lib.logger.entities.JsonnetLog; import com.edgechain.lib.logger.services.ChatCompletionLogService; import com.edgechain.lib.logger.services.EmbeddingLogService; +import com.edgechain.lib.logger.services.JsonnetLogService; import com.edgechain.lib.openai.client.OpenAiClient; import com.edgechain.lib.openai.request.ChatCompletionRequest; import com.edgechain.lib.openai.request.ChatMessage; @@ -32,8 +35,6 @@ import java.sql.SQLException; import java.time.Duration; import java.time.LocalDateTime; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -44,12 +45,14 @@ public class OpenAiController { @Autowired private ChatCompletionLogService chatCompletionLogService; @Autowired private EmbeddingLogService embeddingLogService; + @Autowired private JsonnetLogService jsonnetLogService; @Autowired private 
Environment env; @Autowired private OpenAiClient openAiClient; @PostMapping(value = "/chat-completion") - public Single chatCompletion(@RequestBody OpenAiEndpoint openAiEndpoint) { + public Single chatCompletion( + @RequestBody OpenAiChatEndpoint openAiEndpoint) { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() @@ -57,38 +60,61 @@ public Single chatCompletion(@RequestBody OpenAiEndpoint .temperature(openAiEndpoint.getTemperature()) .messages(openAiEndpoint.getChatMessages()) .stream(false) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) .build(); - this.openAiClient.setEndpoint(openAiEndpoint); - EdgeChain edgeChain = - openAiClient.createChatCompletion(chatCompletionRequest); + openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint); if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - ChatCompletionLog chatCompletionLog = new ChatCompletionLog(); - chatCompletionLog.setName(openAiEndpoint.getChainName()); - chatCompletionLog.setCreatedAt(LocalDateTime.now()); - chatCompletionLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatCompletionLog.setInput(StringUtils.join(openAiEndpoint.getChatMessages())); - chatCompletionLog.setModel(openAiEndpoint.getModel()); + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); return edgeChain .doOnNext( c -> { - chatCompletionLog.setPromptTokens(c.getUsage().getPrompt_tokens()); - chatCompletionLog.setTotalTokens(c.getUsage().getTotal_tokens()); - chatCompletionLog.setContent(c.getChoices().get(0).getMessage().getContent()); - chatCompletionLog.setType(c.getObject()); + chatLog.setPromptTokens(c.getUsage().getPrompt_tokens()); + chatLog.setTotalTokens(c.getUsage().getTotal_tokens()); + chatLog.setContent(c.getChoices().get(0).getMessage().getContent()); + chatLog.setType(c.getObject()); - chatCompletionLog.setCompletedAt(LocalDateTime.now()); + chatLog.setCompletedAt(LocalDateTime.now()); Duration duration = - Duration.between( - chatCompletionLog.getCreatedAt(), chatCompletionLog.getCompletedAt()); - chatCompletionLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatCompletionLog); + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(c.getChoices().get(0).getMessage().getContent()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + 
jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile(openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); + } }) .toSingle(); @@ -98,7 +124,7 @@ public Single chatCompletion(@RequestBody OpenAiEndpoint @PostMapping( value = "/chat-completion-stream", consumes = {MediaType.APPLICATION_JSON_VALUE}) - public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoint) { + public SseEmitter chatCompletionStream(@RequestBody OpenAiChatEndpoint openAiEndpoint) { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() @@ -106,10 +132,14 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin .temperature(openAiEndpoint.getTemperature()) .messages(openAiEndpoint.getChatMessages()) .stream(true) + .topP(openAiEndpoint.getTopP()) + .n(openAiEndpoint.getN()) + .stop(openAiEndpoint.getStop()) + .presencePenalty(openAiEndpoint.getPresencePenalty()) + .frequencyPenalty(openAiEndpoint.getFrequencyPenalty()) + .logitBias(openAiEndpoint.getLogitBias()) + .user(openAiEndpoint.getUser()) .build(); - - this.openAiClient.setEndpoint(openAiEndpoint); - SseEmitter emitter = new SseEmitter(); ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -117,18 +147,24 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin () -> { try { EdgeChain edgeChain = - openAiClient.createChatCompletionStream(chatCompletionRequest); + openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint); AtomInteger chunks = AtomInteger.of(0); if (Objects.nonNull(env.getProperty("postgres.db.host"))) { - ChatCompletionLog chatCompletionLog = new ChatCompletionLog(); - chatCompletionLog.setName(openAiEndpoint.getChainName()); - chatCompletionLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); - chatCompletionLog.setInput(StringUtils.join(openAiEndpoint.getChatMessages())); - chatCompletionLog.setModel(openAiEndpoint.getModel()); - chatCompletionLog.setCreatedAt(LocalDateTime.now()); + ChatCompletionLog chatLog = new ChatCompletionLog(); + chatLog.setName(openAiEndpoint.getChainName()); + chatLog.setCreatedAt(LocalDateTime.now()); + chatLog.setCallIdentifier(openAiEndpoint.getCallIdentifier()); + chatLog.setInput(StringUtils.join(chatCompletionRequest.getMessages())); + chatLog.setModel(chatCompletionRequest.getModel()); + + chatLog.setPresencePenalty(chatCompletionRequest.getPresencePenalty()); + chatLog.setFrequencyPenalty(chatCompletionRequest.getFrequencyPenalty()); + chatLog.setTopP(chatCompletionRequest.getTopP()); + chatLog.setN(chatCompletionRequest.getN()); + chatLog.setTemperature(chatCompletionRequest.getTemperature()); StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append("<|im_start|>"); @@ -139,12 +175,7 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); - List strings = new ArrayList<>(); - for (ChatMessage chatMessage : openAiEndpoint.getChatMessages()) { - strings.add(chatMessage.getContent()); - } - - chatCompletionLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); + chatLog.setPromptTokens((long) enc.countTokens(stringBuilder.toString())); StringBuilder content = new StringBuilder(); 
@@ -162,20 +193,30 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin if (Objects.nonNull(res.getChoices().get(0).getFinishReason())) { emitter.complete(); - - chatCompletionLog.setType(res.getObject()); - chatCompletionLog.setContent(content.toString()); - chatCompletionLog.setCompletedAt(LocalDateTime.now()); - chatCompletionLog.setTotalTokens( - chunks.get() + chatCompletionLog.getPromptTokens()); + chatLog.setType(res.getObject()); + chatLog.setContent(content.toString()); + chatLog.setCompletedAt(LocalDateTime.now()); + chatLog.setTotalTokens(chunks.get() + chatLog.getPromptTokens()); Duration duration = - Duration.between( - chatCompletionLog.getCreatedAt(), - chatCompletionLog.getCompletedAt()); - chatCompletionLog.setLatency(duration.toMillis()); - - chatCompletionLogService.saveOrUpdate(chatCompletionLog); + Duration.between(chatLog.getCreatedAt(), chatLog.getCompletedAt()); + chatLog.setLatency(duration.toMillis()); + + chatCompletionLogService.saveOrUpdate(chatLog); + + if (Objects.nonNull(openAiEndpoint.getJsonnetLoader()) + && openAiEndpoint.getJsonnetLoader().getThreshold() >= 1) { + JsonnetLog jsonnetLog = new JsonnetLog(); + jsonnetLog.setMetadata(openAiEndpoint.getJsonnetLoader().getMetadata()); + jsonnetLog.setContent(content.toString()); + jsonnetLog.setF1(openAiEndpoint.getJsonnetLoader().getF1()); + jsonnetLog.setF2(openAiEndpoint.getJsonnetLoader().getF2()); + jsonnetLog.setSplitSize(openAiEndpoint.getJsonnetLoader().getSplitSize()); + jsonnetLog.setCreatedAt(LocalDateTime.now()); + jsonnetLog.setSelectedFile( + openAiEndpoint.getJsonnetLoader().getSelectedFile()); + jsonnetLogService.saveOrUpdate(jsonnetLog); + } } } catch (final Exception e) { @@ -209,7 +250,7 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin } @PostMapping("/completion") - public Single completion(@RequestBody OpenAiEndpoint openAiEndpoint) { + public Single completion(@RequestBody OpenAiChatEndpoint openAiEndpoint) { CompletionRequest completionRequest = CompletionRequest.builder() @@ -218,22 +259,20 @@ public Single completion(@RequestBody OpenAiEndpoint openAiE .temperature(openAiEndpoint.getTemperature()) .build(); - this.openAiClient.setEndpoint(openAiEndpoint); - - EdgeChain edgeChain = openAiClient.createCompletion(completionRequest); + EdgeChain edgeChain = + openAiClient.createCompletion(completionRequest, openAiEndpoint); return edgeChain.toSingle(); } @PostMapping("/embeddings") - public Single embeddings(@RequestBody OpenAiEndpoint openAiEndpoint) - throws SQLException { - - this.openAiClient.setEndpoint(openAiEndpoint); + public Single embeddings( + @RequestBody OpenAiEmbeddingEndpoint openAiEndpoint) throws SQLException { EdgeChain edgeChain = openAiClient.createEmbeddings( - new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getInput())); + new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()), + openAiEndpoint); if (Objects.nonNull(env.getProperty("postgres.db.host"))) { @@ -255,9 +294,9 @@ public Single embeddings(@RequestBody OpenAiEndpoint op embeddingLogService.saveOrUpdate(embeddingLog); }) - .toSingle(); + .toSingleWithoutScheduler(); } - return edgeChain.toSingle(); + return edgeChain.toSingleWithoutScheduler(); } } diff --git a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java index f7ef9660b..f1d394e8b 100644 
--- a/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java +++ b/Java/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/wiki/WikiController.java @@ -1,7 +1,7 @@ package com.edgechain.service.controllers.wiki; import com.edgechain.lib.configuration.WebConfiguration; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.edgechain.lib.wiki.client.WikiClient; import com.edgechain.lib.wiki.response.WikiResponse; diff --git a/Java/FlySpring/edgechain-app/src/main/resources/application-test.properties b/Java/FlySpring/edgechain-app/src/main/resources/application-test.properties new file mode 100644 index 000000000..311c86783 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/resources/application-test.properties @@ -0,0 +1,5 @@ +# this profile will be used only when @ActiveProfiles("test") + +# uncomment these log level lines to follow what Spring is doing for ROLES +# logging.level.org.springframework.security=TRACE +# logging.level.org.springframework.security.web.FilterChainProxy=INFO diff --git a/Java/FlySpring/edgechain-app/src/main/resources/schema.sql b/Java/FlySpring/edgechain-app/src/main/resources/schema.sql new file mode 100644 index 000000000..a4fed6928 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/main/resources/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE history_context ( + id VARCHAR(255) NOT NULL PRIMARY KEY, + response VARCHAR(1024), + created_at TIMESTAMP +); \ No newline at end of file diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java new file mode 100644 index 000000000..978b58089 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/EdgeChainApplicationTest.java @@ -0,0 +1,13 @@ +package com.edgechain; + +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.context.SpringBootTest; + +@SpringBootTest +class EdgeChainApplicationTest { + + @Test + void test() { + // simply checks server can start + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/chunker/ChunkerTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/chunker/ChunkerTest.java index e6edef4d7..08f6df3e0 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/chunker/ChunkerTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/chunker/ChunkerTest.java @@ -1,10 +1,9 @@ package com.edgechain.chunker; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; - import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.Timeout; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.test.context.SpringBootTest; @@ -12,6 +11,11 @@ import com.edgechain.lib.chunk.Chunker; import com.edgechain.lib.chunk.enums.LangType; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; + @SpringBootTest public class ChunkerTest { @@ -82,4 +86,61 @@ public void chunker_BySentenceWithEmptyInput_ReturnEmptyArray(TestInfo testInfo) String[] expected = {"This is a test."}; assertArrayEquals(expected, result); } + + @Test + @DisplayName("Test By Very Small ChunkSize ") + void 
chunker_ByVerySmallChunkSize_ReturnedExpectedValue() { + String input = "This is Testing"; + Chunker chunker = new Chunker(input); + + String[] result = chunker.byChunkSize(1); + + String[] expected = {"T", "h", "i", "s", "i", "s", "T", "e", "s", "t", "i", "n", "g"}; + assertNotEquals(expected, result); + } + + @Test + @DisplayName("Test By ChunkSize - Input Contains Whitespace") + void chunker_ByChunkSize_InputWhiteSpaceCharacter_ReturnedExpectedValue() { + String input = "\n\t\t"; + Chunker chunker = new Chunker(input); + int chunkSize = 5; + + String[] result = chunker.byChunkSize(chunkSize); + + String[] expected = {""}; + assertArrayEquals(expected, result); + } + + @Test + @DisplayName("Test By Sentence - Contains Only Spaces") + void chunker_BySentence_InputContainsOnlySpaces_ReturnedExpectedValue() { + String input = " "; + Chunker chunker = new Chunker(input); + + String[] result = chunker.bySentence(LangType.EN); + logger.info(Arrays.toString(result)); + String[] expected = {}; + assertArrayEquals(expected, result); + assertEquals(expected.length, result.length); + } + + @Test + @DisplayName("Performance Test With Large String") + @Timeout(value = 5, unit = TimeUnit.SECONDS) + void chunker_Performance_LargeInputString_ReturnedExpectedValue() { + String input = "E".repeat(10000); + Chunker chunker = new Chunker(input); + int chunkSize = 5; + + long startTime = System.currentTimeMillis(); + String[] result = chunker.byChunkSize(chunkSize); + long endTime = System.currentTimeMillis(); + long totalExecutionTime = endTime - startTime; + logger.info(String.valueOf(totalExecutionTime)); + + long maxExecutionTime = 5000; // Execution time in mills + assertEquals(2000, result.length); + assertTrue(totalExecutionTime <= maxExecutionTime); + } } diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/codeInterpreter/CodeInterpreterTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/codeInterpreter/CodeInterpreterTest.java index e4d0527aa..ff6d22dde 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/codeInterpreter/CodeInterpreterTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/codeInterpreter/CodeInterpreterTest.java @@ -1,9 +1,7 @@ package com.edgechain.codeInterpreter; import static org.junit.Assert.assertFalse; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -65,4 +63,14 @@ public void test_empty_example_extraction() throws Exception { assertNotNull(extractedValue); assertFalse(extractedValue.contains(prompt)); } + + @Test + @DisplayName("Test for empty input") + void test_emptyInput_ReturnedExpectedValue() { + String inputJsonnet = ""; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + + assertThrows(Exception.class, () -> jsonnetLoader.load(inputStream)); + } } diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/jsonnet/JsonnetLoaderTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/jsonnet/JsonnetLoaderTest.java index 5dcdcbfc8..12e6b77fa 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/jsonnet/JsonnetLoaderTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/jsonnet/JsonnetLoaderTest.java @@ -1,9 
+1,5 @@ package com.edgechain.jsonnet; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -20,6 +16,9 @@ import com.edgechain.lib.jsonnet.JsonnetLoader; import com.edgechain.lib.jsonnet.impl.FileJsonnetLoader; +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; + @SpringBootTest public class JsonnetLoaderTest { @@ -106,4 +105,36 @@ public void test_external_variable_xtrasonnet() throws Exception { assertNotNull(externalVar); assertEquals(externalVar, "5"); } + + @Test + void jsonLoader_LoadJsonnet_WithInvalidJsonnet_ThrowsException() { + String inputJsonnet = "This is a test sentence."; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + assertThrows(Exception.class, () -> jsonnetLoader.load(inputStream)); + } + + @Test + void jsonLoader_LoadJsonnet_WithEmptyJsonnet_ThrowsException() { + String inputJsonnet = "{}"; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + jsonnetLoader.load(inputStream); + assertThrows(Exception.class, () -> jsonnetLoader.get("jsonnet")); + } + + @Test + void jsonLoader_LoadJsonnetWithArrayOfObjects_ReturnExpectedValue(TestInfo testInfo) { + String inputJsonnet = "{ \"objects\": [{ \"key\": \"value1\" }, { \"key\": \"value2\" }] }"; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + + jsonnetLoader.load(inputStream); + JSONArray objects = jsonnetLoader.getArray("objects"); + + assertNotNull(objects); + assertEquals(2, objects.length()); + assertEquals("value1", objects.getJSONObject(0).getString("key")); + assertEquals("value2", objects.getJSONObject(1).getString("key")); + } } diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClientTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClientTest.java new file mode 100644 index 000000000..5e4228e33 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/impl/PostgreSQLHistoryContextClientTest.java @@ -0,0 +1,91 @@ +package com.edgechain.lib.context.client.impl; + +import com.edgechain.lib.context.domain.HistoryContext; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import com.edgechain.testutil.PostgresTestContainer; +import com.edgechain.testutil.PostgresTestContainer.PostgresImage; +import com.zaxxer.hikari.HikariConfig; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.test.annotation.DirtiesContext; +import org.testcontainers.junit.jupiter.Testcontainers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +
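+// Drives PostgreSQLHistoryContextClient end to end against a throwaway PostgreSQL container: +// create -> put -> get -> delete, then a lookup of a missing id to verify the error message.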
+@Testcontainers(disabledWithoutDocker = true) +@SpringBootTest(webEnvironment = WebEnvironment.NONE) +@DirtiesContext +class PostgreSQLHistoryContextClientTest { + + private static final Logger LOGGER = + LoggerFactory.getLogger(PostgreSQLHistoryContextClientTest.class); + + private static PostgresTestContainer instance = new PostgresTestContainer(PostgresImage.PLAIN); + + @BeforeAll + static void baseSetupAll() { + instance.start(); + } + + @AfterAll + static void baseTeardownAll() { + instance.stop(); + } + + @Autowired private HikariConfig hikariConfig; + @Autowired private PostgreSQLHistoryContextClient service; + + @Test + void allMethods() { + // hikari has own copy of properties so set these here + hikariConfig.setJdbcUrl(instance.getJdbcUrl()); + hikariConfig.setUsername(instance.getUsername()); + hikariConfig.setPassword(instance.getPassword()); + + final Data data = new Data(); + + final EdgeChain create = service.create("DAVE", null); + create.toSingle().blockingSubscribe(s -> data.id = s.getId(), e -> data.failed = true); + assertFalse(data.failed); + assertNotNull(data.id); + LOGGER.info("create OK id={}", data.id); + + final EdgeChain put = service.put(data.id, "COW", null); + put.toSingle().blockingSubscribe(s -> {}, e -> data.failed = true); + assertFalse(data.failed); + LOGGER.info("put OK"); + + final EdgeChain get = service.get(data.id, null); + get.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.failed = true); + assertFalse(data.failed); + assertEquals("COW", data.val); + LOGGER.info("get OK val={}", data.val); + + EdgeChain delete = service.delete(data.id, null); + delete.toSingle().blockingSubscribe(s -> data.val = s, e -> data.failed = true); + assertFalse(data.failed); + assertEquals("", data.val); + LOGGER.info("delete OK val={}", data.val); + + final EdgeChain getMissing = service.get("not_there", null); + getMissing + .toSingle() + .blockingSubscribe(s -> data.failed = true, e -> data.val = e.getMessage()); + assertFalse(data.failed); + assertEquals("PostgreSQL history_context id isn't found.", data.val); + LOGGER.info("get-NotFound OK val={}", data.val); + } + + private static class Data { + public boolean failed; + public String id; + public String val; + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/repositories/PostgreSQLHistoryContextRepositoryTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/repositories/PostgreSQLHistoryContextRepositoryTest.java new file mode 100644 index 000000000..88801627b --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/context/client/repositories/PostgreSQLHistoryContextRepositoryTest.java @@ -0,0 +1,85 @@ +package com.edgechain.lib.context.client.repositories; + +import com.edgechain.lib.context.domain.HistoryContext; +import com.edgechain.testutil.PostgresTestContainer; +import org.junit.jupiter.api.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase; +import org.springframework.boot.test.autoconfigure.orm.jpa.DataJpaTest; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; +import org.springframework.test.context.jdbc.Sql; + +import java.time.LocalDateTime; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.*; + +@DataJpaTest 
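+// replace = NONE keeps the Testcontainers datasource registered below via @DynamicPropertySource instead of an in-memory replacement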
+@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE) +@Sql(scripts = {"classpath:schema.sql"}) +class PostgreSQLHistoryContextRepositoryTest { + + Logger logger = LoggerFactory.getLogger(getClass()); + + @Autowired private PostgreSQLHistoryContextRepository repository; + + private static final PostgresTestContainer instance = + new PostgresTestContainer(PostgresTestContainer.PostgresImage.VECTOR); + + @BeforeAll + static void setupAll() { + instance.start(); + } + + @AfterAll + static void tearAll() { + instance.stop(); + } + + @BeforeEach + void setUp() { + repository.deleteAll(); + } + + @DynamicPropertySource + static void setProperties(DynamicPropertyRegistry registry) { + registry.add("spring.datasource.url", instance::getJdbcUrl); + registry.add("spring.datasource.username", instance::getUsername); + registry.add("spring.datasource.password", instance::getPassword); + } + + @Test + void test_Save_And_Retrieve_History_Context() { + HistoryContext historyContext = getHistoryContext(); + repository.save(historyContext); + + Optional result = repository.findById("1"); + logger.info("history context {}", result); + + assertTrue(result.isPresent()); + } + + @Test + void test_Delete_History_Context() { + HistoryContext historyContext = getHistoryContext(); + repository.save(historyContext); + + repository.deleteById("1"); + Optional result = repository.findById("1"); + + assertTrue(result.isEmpty()); + } + + @Test + void test_Find_By_Non_Exist_Context() { + Optional result = repository.findById("10"); + assertTrue(result.isEmpty()); + } + + private HistoryContext getHistoryContext() { + return new HistoryContext("1", "testing history context", LocalDateTime.now()); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java new file mode 100644 index 000000000..b7cae1606 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/endpoint/impl/BgeSmallEndpointTest.java @@ -0,0 +1,60 @@ +package com.edgechain.lib.endpoint.impl; + +import com.edgechain.lib.configuration.domain.SecurityUUID; +import com.edgechain.lib.endpoint.impl.embeddings.BgeSmallEndpoint; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import java.io.File; + +import org.junit.jupiter.api.Test; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.util.ReflectionTestUtils; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +class BgeSmallEndpointTest { + + @Test + @DirtiesContext + void downloadFiles() { + // Retrofit needs a port + System.setProperty("server.port", "8888"); + + // give Retrofit a mock SecurityUUID instance so it does not call the context + SecurityUUID mockSecurityUUID = mock(SecurityUUID.class); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", mockSecurityUUID); + + try { + // GIVEN we have no local files + deleteFiles(); + + // WHEN we create the endpoint instance + // (get tiny JSON files as example download data) + new BgeSmallEndpoint( + "https://jsonplaceholder.typicode.com/posts/1", + "https://jsonplaceholder.typicode.com/posts/2"); + + // THEN the files now exist + File modelFile = new File(BgeSmallEndpoint.MODEL_PATH); + assertTrue(modelFile.exists()); + + File tokenizerFile = new File(BgeSmallEndpoint.TOKENIZER_PATH); + assertTrue(tokenizerFile.exists()); + }
finally { + // reset the Retrofit instance + ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", null); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); + + deleteFiles(); // make sure we clean up files afterwards + } + } + + // === HELPER METHODS === + + private static void deleteFiles() { + File modelFile = new File(BgeSmallEndpoint.MODEL_PATH); + modelFile.delete(); + + File tokenizerFile = new File(BgeSmallEndpoint.TOKENIZER_PATH); + tokenizerFile.delete(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarterTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarterTest.java new file mode 100644 index 000000000..b74c11936 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/flyfly/commands/run/TestContainersStarterTest.java @@ -0,0 +1,105 @@ +package com.edgechain.lib.flyfly.commands.run; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.DockerClientFactory; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class TestContainersStarterTest { + + private static final Logger LOGGER = LoggerFactory.getLogger(TestContainersStarterTest.class); + + private TestContainersStarter starter; + + private String tempFilename; + + @BeforeEach + void setup() { + tempFilename = org.assertj.core.util.Files.newTemporaryFile().getAbsolutePath(); + + starter = new TestContainersStarter(); + starter.setPropertiesPath(tempFilename); + } + + @AfterEach + void teardown() { + new File(tempFilename).delete(); + } + + @Test + void addTempProperties() { + try { + starter.addTempProperties("DAVE"); + + List lines = + org.assertj.core.util.Files.linesOf(new File(tempFilename), StandardCharsets.UTF_8); + assertEquals(TestContainersStarter.FLYFLYTEMPTAG, lines.get(1)); + assertTrue(lines.get(2).contains("=DAVE")); + assertEquals(TestContainersStarter.FLYFLYTEMPTAG, lines.get(3)); + assertEquals(TestContainersStarter.FLYFLYTEMPTAG, lines.get(5)); + } catch (IOException e) { + fail("could not finish test", e); + } + } + + @Test + void addTempAndThenRemove() { + try { + Files.writeString(Paths.get(tempFilename), "FIRST LINE\n", StandardCharsets.UTF_8); + + assertTrue(starter.isServiceNeeded()); + + starter.addTempProperties("DAVE"); + assertFalse(starter.isServiceNeeded()); + + starter.removeTempProperties(); + assertTrue(starter.isServiceNeeded()); + + String result = Files.readString(Paths.get(tempFilename), StandardCharsets.UTF_8); + assertEquals("FIRST LINE\n\n", result); // addTemp.. 
adds a blank line at the start + } catch (IOException e) { + fail("could not finish test", e); + } + } + + @Test + void startAndStopContainer() { + if (!isDockerAvailable()) { + LOGGER.warn("Docker is not running - test skipped"); + return; + } + + try { + starter.startPostgreSQL(); + } catch (IOException e) { + fail("not able to start PostgreSQL", e); + } finally { + try { + starter.stopPostgreSQL(); + } catch (IOException e2) { + fail("failed to stop PostgreSQL", e2); + } + } + } + + boolean isDockerAvailable() { + try { + DockerClientFactory.instance().client(); + return true; + } catch (Throwable ex) { + return false; + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java new file mode 100644 index 000000000..5178bfd63 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/index/client/impl/PostgresClientTest.java @@ -0,0 +1,387 @@ +package com.edgechain.lib.index.client.impl; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.index.domain.PostgresWordEmbeddings; +import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.enums.PostgresLanguage; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import com.edgechain.testutil.PostgresTestContainer; +import com.edgechain.testutil.PostgresTestContainer.PostgresImage; +import com.zaxxer.hikari.HikariConfig; +import java.util.List; +import java.util.stream.Collectors; + +import io.reactivex.rxjava3.observers.TestObserver; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.test.annotation.DirtiesContext; +import org.testcontainers.junit.jupiter.Testcontainers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@Testcontainers(disabledWithoutDocker = true) +@SpringBootTest(webEnvironment = WebEnvironment.NONE) +@DirtiesContext +class PostgresClientTest { + + private static final Logger LOGGER = LoggerFactory.getLogger(PostgresClientTest.class); + + private static final float FLOAT_ERROR_MARGIN = 0.0001f; + + private static PostgresTestContainer instance = new PostgresTestContainer(PostgresImage.VECTOR); + + @BeforeAll + static void setupAll() { + instance.start(); + } + + @AfterAll + static void teardownAll() { + instance.stop(); + } + + @Autowired private HikariConfig hikariConfig; + + @Autowired private PostgresClient service; + + @Test + void allMethods() { + // hikari has own copy of properties so set these here + hikariConfig.setJdbcUrl(instance.getJdbcUrl()); + hikariConfig.setUsername(instance.getUsername()); + hikariConfig.setPassword(instance.getPassword()); + + createTable(); + createMetadataTable(); + + deleteAll(); // check delete before we get foreign keys + + String uuid1 = 
upsert(); + batchUpsert(); + + query_noMeta(); + + String uuid2 = insertMetadata(); + + batchInsertMetadata(); + insertIntoJoinTable(uuid1, uuid2); + + query_meta(); + getChunks(); + getSimilarChunks(); + } + + private void createTable() { + createTable_metric(PostgresDistanceMetric.COSINE, "t_embedding"); + } + + private void createTable_metric(PostgresDistanceMetric metric, String tableName) { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn(tableName); + when(mockPe.getLists()).thenReturn(1); + when(mockPe.getDimensions()).thenReturn(2); + when(mockPe.getMetric()).thenReturn(metric); + + TestObserver test = service.createTable(mockPe).getObservable().test(); + + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + test.assertNoErrors(); + LOGGER.info("createTable (metric={}) response: '{}'", metric, tableName); + } + + private void createMetadataTable() { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + + TestObserver test = service.createMetadataTable(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + LOGGER.info("createMetadataTable response: '{}'", test.values().get(0).getResponse()); + } + + private String upsert() { + WordEmbeddings we = new WordEmbeddings(); + we.setId("WE1"); + we.setScore(0.86914713); + we.setValues(List.of(0.25f, 0.5f)); + + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getWordEmbedding()).thenReturn(we); + when(mockPe.getFilename()).thenReturn("readme.pdf"); + when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getPostgresLanguage()).thenReturn(PostgresLanguage.ENGLISH); + + TestObserver test = service.upsert(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + test.assertNoErrors(); + + return test.values().get(0).getResponse(); + } + + private String insertMetadata() { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getMetadata()).thenReturn("This is a sample text"); + when(mockPe.getDocumentDate()).thenReturn("November 11, 2015"); + + TestObserver test = service.insertMetadata(mockPe).getObservable().test(); + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + test.assertNoErrors(); + return test.values().get(0).getResponse(); + } + + private void batchUpsert() { + WordEmbeddings we1 = new WordEmbeddings(); + we1.setId("WE1"); + we1.setScore(1.05689); + we1.setValues(List.of(0.25f, 0.5f)); + + WordEmbeddings we2 = new WordEmbeddings(); + we2.setId("WE2"); + we2.setScore(2.02689); + we2.setValues(List.of(0.75f, 0.9f)); + + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getWordEmbeddingsList()).thenReturn(List.of(we1, we2)); + when(mockPe.getFilename()).thenReturn("readme.pdf"); + when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getPostgresLanguage()).thenReturn(PostgresLanguage.ENGLISH); + + final Data data = new Data(); + EdgeChain> result = service.batchUpsert(mockPe); + 
result + .toSingle() + .blockingSubscribe( + s -> data.val = s.stream().map(r -> r.getResponse()).collect(Collectors.joining(",")), + e -> data.error = e); + if (data.error != null) { + fail("batchUpsert failed", data.error); + } + LOGGER.info("batchUpsert response: '{}'", data.val); + } + + private void batchInsertMetadata() { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getMetadataList()).thenReturn(List.of("text1", "text2")); + + final Data data = new Data(); + EdgeChain> result = service.batchInsertMetadata(mockPe); + result + .toSingle() + .blockingSubscribe( + s -> data.val = s.stream().map(r -> r.getResponse()).collect(Collectors.joining(",")), + e -> data.error = e); + if (data.error != null) { + fail("batchInsertMetadata failed", data.error); + } + LOGGER.info("batchInsertMetadata response: '{}'", data.val); + } + + private void insertIntoJoinTable(String uuid1, String uuid2) { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getId()).thenReturn(uuid1); + when(mockPe.getMetadataId()).thenReturn(uuid2); + + TestObserver test = service.insertIntoJoinTable(mockPe).getObservable().test(); + + try { + test.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + test.assertNoErrors(); + } + + private void deleteAll() { + deleteAll_namespace(null, "knowledge"); + deleteAll_namespace("", "knowledge"); + deleteAll_namespace("testns", "testns"); + } + + private void deleteAll_namespace(String namespace, String expected) { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getNamespace()).thenReturn(namespace); + + final Data data = new Data(); + EdgeChain result = service.deleteAll(mockPe); + result.toSingle().blockingSubscribe(s -> data.val = s.getResponse(), e -> data.error = e); + if (data.error != null) { + fail("deleteAll failed", data.error); + } + LOGGER.info("deleteAll (namespace={}) response: '{}'", namespace, data.val); + assertTrue(data.val.endsWith(expected)); + } + + private void query_noMeta() { + query_noMeta_metric(PostgresDistanceMetric.COSINE); + query_noMeta_metric(PostgresDistanceMetric.IP); + query_noMeta_metric(PostgresDistanceMetric.L2); + } + + private void query_noMeta_metric(PostgresDistanceMetric metric) { + WordEmbeddings we1 = new WordEmbeddings(); + we1.setId("WEQUERY"); + we1.setScore(1.05589); + we1.setValues(List.of(0.25f, 0.5f)); + + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getProbes()).thenReturn(5); + when(mockPe.getMetric()).thenReturn(metric); + when(mockPe.getWordEmbeddingsList()).thenReturn(List.of(we1)); + when(mockPe.getTopK()).thenReturn(5); + when(mockPe.getUpperLimit()).thenReturn(5); + when(mockPe.getMetadataTableNames()).thenReturn(null); + + final Data data = new Data(); + EdgeChain> result = service.query(mockPe); + result + .toSingle() + .blockingSubscribe( + s -> data.val = s.stream().map(r -> r.getRawText()).collect(Collectors.joining(",")), + e -> data.error = e); + if (data.error != null) { + fail("query (no meta) failed", data.error); + } + LOGGER.info("query no meta (metric={}) 
response: '{}'", metric, data.val); + + // WE1 from single upsert, and WE2 from batch upsert + assertTrue(data.val.contains("WE1") && data.val.contains("WE2")); + } + + private void query_meta() { + query_meta_metric(PostgresDistanceMetric.COSINE); + query_meta_metric(PostgresDistanceMetric.IP); + query_meta_metric(PostgresDistanceMetric.L2); + } + + private void query_meta_metric(PostgresDistanceMetric metric) { + WordEmbeddings we1 = new WordEmbeddings(); + we1.setId("WEQUERY"); + we1.setScore(1.258); + we1.setValues(List.of(0.25f, 0.5f)); + + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getNamespace()).thenReturn("testns"); + when(mockPe.getProbes()).thenReturn(20); + when(mockPe.getMetric()).thenReturn(metric); + when(mockPe.getWordEmbedding()).thenReturn(we1); + when(mockPe.getTopK()).thenReturn(5); + when(mockPe.getUpperLimit()).thenReturn(5); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + + final Data data = new Data(); + EdgeChain> result = service.queryWithMetadata(mockPe); + result + .toSingle() + .blockingSubscribe( + s -> data.val = s.stream().map(r -> r.getRawText()).collect(Collectors.joining(",")), + e -> data.error = e); + if (data.error != null) { + fail("query (meta) failed", data.error); + } + LOGGER.info("query with meta (metric={}) response: '{}'", metric, data.val); + + // WE1 from single joined upsert + assertTrue(data.val.contains("WE1")); + } + + private void getChunks() { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getFilename()).thenReturn("readme.pdf"); + + final Data data = new Data(); + EdgeChain> result = service.getAllChunks(mockPe); + result + .toSingle() + .blockingSubscribe( + s -> { + data.list = s; + data.val = s.stream().map(r -> r.getRawText()).collect(Collectors.joining(",")); + }, + e -> data.error = e); + if (data.error != null) { + fail("getChunks failed", data.error); + } + LOGGER.info("getChunks response: '{}'", data.val); + + // WE1 from single upsert, and WE2 from batch upsert + assertTrue(data.val.contains("WE1") && data.val.contains("WE2")); + + PostgresWordEmbeddings first = data.list.get(0); + assertEquals(0.25f, first.getValues().get(0), FLOAT_ERROR_MARGIN); + assertEquals(0.5f, first.getValues().get(1), FLOAT_ERROR_MARGIN); + + PostgresWordEmbeddings second = data.list.get(1); + assertEquals(0.75f, second.getValues().get(0), FLOAT_ERROR_MARGIN); + assertEquals(0.9f, second.getValues().get(1), FLOAT_ERROR_MARGIN); + } + + private void getSimilarChunks() { + PostgresEndpoint mockPe = mock(PostgresEndpoint.class); + when(mockPe.getTableName()).thenReturn("t_embedding"); + when(mockPe.getMetadataTableNames()).thenReturn(List.of("title_metadata")); + when(mockPe.getEmbeddingChunk()).thenReturn("how to test this"); + + final Data data = new Data(); + EdgeChain> result = service.getSimilarMetadataChunk(mockPe); + result + .toSingle() + .blockingSubscribe( + s -> { + data.list = s; + data.val = s.stream().map(r -> r.getRawText()).collect(Collectors.joining(",")); + }, + e -> data.error = e); + if (data.error != null) { + fail("getSimilarMetadataChunk failed", data.error); + } + LOGGER.info("getSimilarMetadataChunk response: '{}'", data.val); + } + + private static class Data { + public Throwable error; + public String val; + public List list; + } +} diff --git 
a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/logger/services/EmbeddingLogServiceTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/logger/services/EmbeddingLogServiceTest.java new file mode 100644 index 000000000..37054dd1e --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/logger/services/EmbeddingLogServiceTest.java @@ -0,0 +1,67 @@ +package com.edgechain.lib.logger.services; + +import org.junit.jupiter.api.Test; +import org.postgresql.ds.PGSimpleDataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.util.ReflectionTestUtils; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import static org.junit.jupiter.api.Assertions.fail; + +@Testcontainers(disabledWithoutDocker = true) +class EmbeddingLogServiceTest { + + private static PostgresTestContainer instance = new PostgresTestContainer(); + + @Test + void test() { + instance.start(); + try { + // create datasource and template using Docker properties + final PGSimpleDataSource datasource = new PGSimpleDataSource(); + datasource.setUrl(instance.getJdbcUrl()); + datasource.setUser(instance.getUsername()); + datasource.setPassword(instance.getPassword()); + + final JdbcTemplate template = new JdbcTemplate(datasource); + + // create service using template + final EmbeddingLogService service = new EmbeddingLogService(); + ReflectionTestUtils.setField(service, "jdbcTemplate", template); + + service.createTable(); + + } catch (Exception e) { + fail("could not create table", e); + + } finally { + instance.stop(); + } + } + + public static class PostgresTestContainer extends PostgreSQLContainer { + + private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestContainer.class); + + private static final String DOCKER_IMAGE = + PostgreSQLContainer.IMAGE + ":" + PostgreSQLContainer.DEFAULT_TAG; + + public PostgresTestContainer() { + super(DOCKER_IMAGE); + } + + @Override + public void start() { + LOGGER.info("starting container"); + super.start(); + } + + @Override + public void stop() { + LOGGER.info("stopping container"); + super.stop(); + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityConfigFixTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityConfigFixTest.java new file mode 100644 index 000000000..01d0ba63d --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityConfigFixTest.java @@ -0,0 +1,119 @@ +package com.edgechain.lib.supabase.security; + +import com.edgechain.lib.configuration.domain.AuthFilter; +import com.edgechain.lib.configuration.domain.MethodAuthentication; +import com.edgechain.testutil.TestJwtCreator; +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import 
org.springframework.test.context.ContextConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +@AutoConfigureMockMvc +@ContextConfiguration(name = "contextWithEmptyPatternsAuthFilter") +class WebSecurityConfigFixTest { + + // ====== TEST JWT-BASED SECURITY WITH NO ROLES (bearer token must be in header) ====== + + private static final String FULL_NONCONTEXT_PATH = "/v0/endpoint/"; + + @BeforeAll + static void setupAll() { + System.setProperty("jwt.secret", "edge-chain-unit-test-jwt-secret"); + } + + @TestConfiguration + public static class AuthFilterTestConfig { + @Bean + AuthFilter authFilter() { + // provide an AuthFilter with an empty string in the patterns list. + // the security class should fix this to ** + AuthFilter auth = new AuthFilter(); + auth.setRequestGet(new MethodAuthentication(List.of(""), "")); + auth.setRequestDelete(new MethodAuthentication(List.of(""), "")); + auth.setRequestPatch(new MethodAuthentication(List.of(""), "")); + auth.setRequestPost(new MethodAuthentication(List.of(""), "")); + auth.setRequestPut(new MethodAuthentication(List.of(""), "")); + return auth; + } + } + + @Autowired private MockMvc mvc; + + @Autowired private JwtHelper jwtHelper; + + @Test + void validateJwt() { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + assertTrue(jwtHelper.validate(jwt)); + } + + @Test + void getEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + get(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void postEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + post(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void deleteEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + delete(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void patchEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + patch(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void putEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + put(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + 
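+ // 404 (not 401/403) means the request cleared the security filter; no controller is mapped at this path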
.andExpect(status().isNotFound()); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityContextPathTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityContextPathTest.java new file mode 100644 index 000000000..9b6482670 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityContextPathTest.java @@ -0,0 +1,79 @@ +package com.edgechain.lib.supabase.security; + +import com.edgechain.lib.configuration.WebConfiguration; +import com.edgechain.lib.configuration.domain.SecurityUUID; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +@AutoConfigureMockMvc +class WebSecurityContextPathTest { + + // ====== TEST CONTEXT-BASED SECURITY (uuid must be in header) ====== + + private static final String FULL_CONTEXT_PATH = WebConfiguration.CONTEXT_PATH + "/"; + + @Autowired private MockMvc mvc; + + @Autowired private SecurityUUID securityUUID; + + @Test + void getContextEndpoint() throws Exception { + mvc.perform( + get(FULL_CONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", securityUUID.getAuthKey())) + .andExpect(status().isNotFound()); + } + + @Test + void postContextEndpoint() throws Exception { + mvc.perform( + post(FULL_CONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", securityUUID.getAuthKey())) + .andExpect(status().isNotFound()); + } + + @Test + void deleteContextEndpoint() throws Exception { + mvc.perform( + delete(FULL_CONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", securityUUID.getAuthKey())) + .andExpect(status().isNotFound()); + } + + @Test + void patchContextEndpoint() throws Exception { + mvc.perform( + patch(FULL_CONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", securityUUID.getAuthKey())) + .andExpect(status().isNotFound()); + } + + @Test + void putContextEndpoint() throws Exception { + mvc.perform( + put(FULL_CONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", securityUUID.getAuthKey())) + .andExpect(status().isNotFound()); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtNoRoleTest.java 
b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtNoRoleTest.java new file mode 100644 index 000000000..6d401409e --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtNoRoleTest.java @@ -0,0 +1,96 @@ +package com.edgechain.lib.supabase.security; + +import com.edgechain.testutil.TestJwtCreator; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +@AutoConfigureMockMvc +class WebSecurityJwtNoRoleTest { + + // ====== TEST JWT-BASED SECURITY WITH NO ROLES (bearer token must be in header) ====== + + private static final String FULL_NONCONTEXT_PATH = "/v0/endpoint/"; + + @BeforeAll + static void setupAll() { + System.setProperty("jwt.secret", "edge-chain-unit-test-jwt-secret"); + } + + @Autowired private MockMvc mvc; + + @Autowired private JwtHelper jwtHelper; + + @Test + void validateJwt() { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + assertTrue(jwtHelper.validate(jwt)); + } + + @Test + void getEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + get(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void postEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + post(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void deleteEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + delete(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void patchEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + patch(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void putEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_IGNORED"); + mvc.perform( + put(FULL_NONCONTEXT_PATH) + .content("{}") + 
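+ // an empty JSON body is sufficient; only the security outcome of the request is asserted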
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtWithRoleTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtWithRoleTest.java new file mode 100644 index 000000000..81da385eb --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/lib/supabase/security/WebSecurityJwtWithRoleTest.java @@ -0,0 +1,173 @@ +package com.edgechain.lib.supabase.security; + +import com.edgechain.lib.configuration.domain.AuthFilter; +import com.edgechain.lib.configuration.domain.MethodAuthentication; +import com.edgechain.testutil.TestJwtCreator; +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@ActiveProfiles("test") +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +@AutoConfigureMockMvc +@ContextConfiguration(name = "contextWithTestRolesAuthFilter") +class WebSecurityJwtWithRoleTest { + + // ====== TEST JWT-BASED SECURITY WITH ROLES (bearer token must be in header) ====== + + private static final String FULL_NONCONTEXT_PATH = "/v0/endpoint/"; + + @TestConfiguration + public static class AuthFilterTestConfig { + @Bean + AuthFilter authFilter() { + // provide an AuthFilter with roles to check we test the correct role for each method + AuthFilter auth = new AuthFilter(); + auth.setRequestGet(new MethodAuthentication(List.of("**"), "ROLE_ADMIN1", "ROLE_AI1")); + auth.setRequestDelete(new MethodAuthentication(List.of("**"), "ROLE_ADMIN2", "ROLE_AI2")); + auth.setRequestPatch(new MethodAuthentication(List.of("**"), "ROLE_ADMIN3", "ROLE_AI3")); + auth.setRequestPost(new MethodAuthentication(List.of("**"), "ROLE_ADMIN4", "ROLE_AI4")); + auth.setRequestPut(new MethodAuthentication(List.of("**"), "ROLE_ADMIN5", "ROLE_AI5")); + return auth; + } + } + + @BeforeAll + static void setupAll() { + System.setProperty("jwt.secret", "edge-chain-unit-test-jwt-secret"); + } + + @Autowired private MockMvc mvc; + + @Autowired private JwtHelper jwtHelper; + + @Test + void validateJwt() { + String jwt = TestJwtCreator.generate("ROLE_TEST"); + 
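+ // validate() should only check the signature and expiry, so the specific role claim does not matter here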
assertTrue(jwtHelper.validate(jwt)); + } + + @Test + void getEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_ADMIN1"); + mvc.perform( + get(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void getEndpoint_notAuth() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_NO_ACCESS"); + mvc.perform( + get(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isForbidden()); + } + + @Test + void postEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_ADMIN4"); + mvc.perform( + post(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void postEndpoint_notAuth() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_NO_ACCESS"); + mvc.perform( + post(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isForbidden()); + } + + @Test + void deleteEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_ADMIN2"); + mvc.perform( + delete(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void deleteEndpoint_notAuth() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_NO_ACCESS"); + mvc.perform( + delete(FULL_NONCONTEXT_PATH) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isForbidden()); + } + + @Test + void patchEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_ADMIN3"); + mvc.perform( + patch(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void patchEndpoint_notAuth() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_NO_ACCESS"); + mvc.perform( + patch(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isForbidden()); + } + + @Test + void putEndpoint() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_ADMIN5"); + mvc.perform( + put(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isNotFound()); + } + + @Test + void putEndpoint_notAuth() throws Exception { + String jwt = TestJwtCreator.generate("ROLE_NO_ACCESS"); + mvc.perform( + put(FULL_NONCONTEXT_PATH) + .content("{}") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .header("Authorization", "Bearer " + jwt)) + .andExpect(status().isForbidden()); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java index cf8c8be36..9b8ac08aa 100644 --- 
a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/openai/OpenAiClientTest.java @@ -1,6 +1,6 @@ package com.edgechain.openai; -import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; +import com.edgechain.lib.endpoint.impl.llm.OpenAiChatEndpoint; import com.edgechain.lib.openai.request.ChatCompletionRequest; import com.edgechain.lib.openai.request.ChatMessage; import com.edgechain.lib.openai.response.ChatCompletionResponse; @@ -88,15 +88,15 @@ public void testOpenAiClient_ChatCompletionResponse_ShouldMappedToPOJO(Class } @Test - @DisplayName("Test OpenAiEndpoint With Retry Mechanism") + @DisplayName("Test OpenAiChatEndpoint With Retry Mechanism") @Order(3) public void testOpenAiClient_WithRetryMechanism_ShouldThrowExceptionWithRetry(TestInfo testInfo) throws InterruptedException { System.out.println("======== " + testInfo.getDisplayName() + " ========"); - OpenAiEndpoint endpoint = - new OpenAiEndpoint( + OpenAiChatEndpoint endpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, "", // apiKey "", // orgId @@ -120,7 +120,7 @@ public void testOpenAiClient_WithRetryMechanism_ShouldThrowExceptionWithRetry(Te } @Test - @DisplayName("Test OpenAiEndpoint With No Retry Mechanism") + @DisplayName("Test OpenAiChatEndpoint With No Retry Mechanism") @Order(4) public void testOpenAiClient_WithNoRetryMechanism_ShouldThrowExceptionWithNoRetry( TestInfo testInfo) throws InterruptedException { @@ -128,8 +128,8 @@ public void testOpenAiClient_WithNoRetryMechanism_ShouldThrowExceptionWithNoRetr System.out.println("======== " + testInfo.getDisplayName() + " ========"); // Step 1 : Create OpenAi Endpoint - OpenAiEndpoint endpoint = - new OpenAiEndpoint( + OpenAiChatEndpoint endpoint = + new OpenAiChatEndpoint( OPENAI_CHAT_COMPLETION_API, "", // apiKey "", // orgId diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java index 0f49fb8d6..f300ac6b8 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/pinecone/PineconeClientTest.java @@ -1,6 +1,107 @@ package com.edgechain.pinecone; +import com.edgechain.lib.endpoint.impl.index.PineconeEndpoint; +import com.edgechain.lib.index.client.impl.PineconeClient; +import com.edgechain.lib.response.StringResponse; +import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; -@SpringBootTest -public class PineconeClientTest {} +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +public class PineconeClientTest { + + @LocalServerPort private int port; + + @Autowired private PineconeClient pineconeClient; + + private PineconeEndpoint pineconeEndpoint; + + @BeforeEach + void setUp() { + 
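+ // the Retrofit-backed client resolves its base URL from server.port, so point it at the random test port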
System.setProperty("server.port", String.valueOf(port)); + pineconeEndpoint = new PineconeEndpoint("https://arakoo.ai", "apiKey", "Pinecone", null); + } + + @Test + @DisplayName("Test Upsert") + public void test_Upsert() { + RestTemplate restTemplate = mock(RestTemplate.class); + ResponseEntity responseEntity = ResponseEntity.ok("some dummy data"); + when(restTemplate.exchange(anyString(), any(), any(), eq(String.class))) + .thenReturn(responseEntity); + + EdgeChain result = pineconeClient.upsert(pineconeEndpoint); + + assertNotNull(result); + } + + @Test + @DisplayName("Test Batch Upsert") + public void test_Batch_Upsert() { + RestTemplate restTemplate = mock(RestTemplate.class); + ResponseEntity responseEntity = ResponseEntity.ok("some dummy data"); + when(restTemplate.exchange(anyString(), any(), any(), eq(String.class))) + .thenReturn(responseEntity); + + EdgeChain result = pineconeClient.batchUpsert(pineconeEndpoint); + + assertNotNull(result); + } + + @Test + @DisplayName("Test Query") + public void test_Query() { + RestTemplate restTemplate = mock(RestTemplate.class); + ResponseEntity responseEntity = ResponseEntity.ok("some dummy data"); + when(restTemplate.exchange(anyString(), any(), any(), eq(String.class))) + .thenReturn(responseEntity); + + EdgeChain queryResult = pineconeClient.batchUpsert(pineconeEndpoint); + + assertNotNull(queryResult); + assertNull(queryResult.get().getResponse()); + } + + @Test + @DisplayName("Test Delete All") + public void test_Delete_All() { + RestTemplate restTemplate = mock(RestTemplate.class); + ResponseEntity responseEntity = ResponseEntity.ok("some dummy data"); + when(restTemplate.exchange(anyString(), any(), any(), eq(String.class))) + .thenReturn(responseEntity); + + EdgeChain deleteResult = pineconeClient.deleteAll(pineconeEndpoint); + + assertNotNull(deleteResult); + assertTrue(deleteResult.get().getResponse().endsWith("Pinecone")); + } + + @Test + @DisplayName("Test Get Namespace") + public void test_GetNamespace() { + // Test for non-empty namespace + String nonEmptyNamespace = pineconeClient.getNamespace(pineconeEndpoint); + assertEquals("Pinecone", nonEmptyNamespace); + + // Test for empty namespace + pineconeEndpoint.setNamespace(""); + String emptyNamespace = pineconeClient.getNamespace(pineconeEndpoint); + assertEquals("", emptyNamespace); + + // Test for null namespace + pineconeEndpoint.setNamespace(null); + String nullNamespace = pineconeClient.getNamespace(pineconeEndpoint); + assertEquals("", nullNamespace); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java new file mode 100644 index 000000000..012c7dc16 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/postgres/PostgresClientMetadataRepositoryTest.java @@ -0,0 +1,185 @@ +package com.edgechain.postgres; + +import com.edgechain.lib.embeddings.WordEmbeddings; +import com.edgechain.lib.endpoint.impl.index.PostgresEndpoint; +import com.edgechain.lib.index.enums.PostgresDistanceMetric; +import com.edgechain.lib.index.repositories.PostgresClientMetadataRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import 
org.mockito.junit.MockitoJUnitRunner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.jdbc.core.JdbcTemplate; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +@SpringBootTest +@RunWith(MockitoJUnitRunner.class) +public class PostgresClientMetadataRepositoryTest { + + private final Logger logger = LoggerFactory.getLogger(this.getClass()); + @Mock private JdbcTemplate jdbcTemplate; + + PostgresEndpoint postgresEndpoint; + @InjectMocks private PostgresClientMetadataRepository repository; + + @Captor private ArgumentCaptor sqlQueryCaptor; + + @BeforeEach + public void setUp() { + postgresEndpoint = mock(PostgresEndpoint.class); + } + + @Test + @DisplayName( + "Test if jdbcTemplate execute method is called twice for createTable() and verifying sql" + + " queries") + public void testCreateTable_NonEmptyMetadataTableNames() { + // Arrange + when(postgresEndpoint.getMetadataTableNames()) + .thenReturn(Collections.singletonList("metadataTestTable")); + + // Act + repository.createTable(postgresEndpoint); + + // Assert + verify(jdbcTemplate, times(3)).execute(sqlQueryCaptor.capture()); + } + + @Test + @DisplayName("createTable() should throw error when the metadata table names list is empty") + public void testCreateTable_EmptyMetadataTableNames() { + // Arrange + when(postgresEndpoint.getMetadataTableNames()).thenReturn(Collections.emptyList()); + + // Act and Assert + assertThrows(IndexOutOfBoundsException.class, () -> repository.createTable(postgresEndpoint)); + verify(jdbcTemplate, times(0)).execute(sqlQueryCaptor.capture()); + } + + @Test + @DisplayName("Insert metadata must throw NullPointerException when metadata ID is null") + public void testInsertMetadata_ThrowsNullPointerException() { + + // Arrange + String tablename = "table"; + String metadataTableName = "metadata_table"; + String metadata = "example_metadata"; + String documentDate = "Aug 01, 2023"; + + // Mock jdbcTemplate.queryForObject to return null + when(jdbcTemplate.queryForObject(anyString(), eq(UUID.class), any(Object[].class))) + .thenReturn(null); + + // Act and Assert + assertThrows( + NullPointerException.class, + () -> { + repository.insertMetadata(tablename, metadataTableName, metadata, documentDate); + }); + + // Verify that jdbcTemplate.queryForObject was called with the correct SQL query and arguments + verify(jdbcTemplate, times(1)) + .queryForObject(sqlQueryCaptor.capture(), eq(UUID.class), any(Object[].class)); + } + + @Test + @DisplayName("Insert entry into the join table") + public void testInsertIntoJoinTable() { + // Arrange + String id = UUID.randomUUID().toString(); + String metadataId = UUID.randomUUID().toString(); + when(postgresEndpoint.getTableName()).thenReturn("embedding_table"); + when(postgresEndpoint.getMetadataTableNames()) + .thenReturn(Collections.singletonList("metadata_table")); + when(postgresEndpoint.getId()).thenReturn(id); + when(postgresEndpoint.getMetadataId()).thenReturn(metadataId); + String joinTable = + postgresEndpoint.getTableName() + + "_join_" + + postgresEndpoint.getMetadataTableNames().get(0); + + // Act + repository.insertIntoJoinTable(postgresEndpoint); + + // Assert + verify(jdbcTemplate, times(1)).execute(sqlQueryCaptor.capture()); + + // Verify the captured query and actual query is same or not + String capturedQuery = 
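+ // getValue() returns the single SQL statement captured from jdbcTemplate.execute(...)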
sqlQueryCaptor.getValue(); + String expectedQuery = + String.format( + "INSERT INTO %s (id, metadata_id) VALUES ('%s', '%s') ON CONFLICT (id) DO UPDATE SET" + + " metadata_id = EXCLUDED.metadata_id;", + joinTable, postgresEndpoint.getId(), postgresEndpoint.getMetadataId()); + assertEquals(expectedQuery, capturedQuery); + } + + @Test + @DisplayName("Query with metadata") + public void testQueryWithMetadata() { + // Arrange + String tableName = "embedding_table"; + String metadataTableName = "metadata_table"; + String namespace = "example_namespace"; + int probes = 1; + PostgresDistanceMetric metric = PostgresDistanceMetric.L2; + List wordEmbeddingValues = List.of(0.1f, 0.2f, 0.3f); + WordEmbeddings wordEmbeddings = new WordEmbeddings("", wordEmbeddingValues); + int topK = 5; + String metadataId = UUID.randomUUID().toString(); + String id = UUID.randomUUID().toString(); + + // Mock queryForList method to return a dummy result + List> dummyResult = + List.of( + Map.of( + "id", + id, + "metadata", + "example_metadata", + "document_date", + "Aug 01, 2023", + "metadata_id", + metadataId, + "raw_text", + "example_raw_text", + "namespace", + "example_namespace", + "filename", + "example_filename", + "timestamp", + "example_timestamp", + "score", + 0.5)); + when(jdbcTemplate.queryForList(anyString())).thenReturn(dummyResult); + + // Act + List> result = + repository.queryWithMetadata( + tableName, + metadataTableName, + namespace, + probes, + metric, + wordEmbeddings.getValues(), + topK); + + // Assert + verify(jdbcTemplate).queryForList(sqlQueryCaptor.capture()); + assertEquals(dummyResult, result); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/reactChain/ReactChainTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/reactChain/ReactChainTest.java index fc8de64ab..9d6ec81be 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/reactChain/ReactChainTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/reactChain/ReactChainTest.java @@ -1,6 +1,7 @@ package com.edgechain.reactChain; import com.edgechain.lib.jsonnet.JsonnetLoader; +import com.edgechain.lib.jsonnet.exceptions.JsonnetLoaderException; import com.edgechain.lib.jsonnet.impl.FileJsonnetLoader; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -9,8 +10,8 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; @SpringBootTest public class ReactChainTest { @@ -76,4 +77,48 @@ local callFunction(funcName) = assertNotNull(searchFunction); assertEquals(searchFunction, "udf.fn"); } + + @Test + @DisplayName("Test extractAction with invalid input") + void test_extractAction_WithInvalidJsonnet() throws Exception { + String inputJsonnet = "This is invalid jsonnet."; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + assertThrows(Exception.class, () -> jsonnetLoader.load(inputStream)); + } + + @Test + @DisplayName("Test extractAction with empty input") + void test_extractAction_withEmptyJsonnet() throws Exception { + String inputJsonnet = ""; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + assertThrows(Exception.class, () -> 
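+ // nothing has been loaded yet, so reading any variable is expected to fail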
jsonnetLoader.get("action")); + assertThrows(Exception.class, () -> jsonnetLoader.load(inputStream)); + } + + @Test + @DisplayName("Test extractThought - invalid input") + void test_extractThought_WithInvalidInput() { + String inputJsonnet = "This is not a valid jsonnet pattern"; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + assertThrows(Exception.class, () -> jsonnetLoader.load(inputStream)); + } + + @Test + @DisplayName("Test Mapper - Missing function") + public void test_mapper_MissingFunction_ReturnedExpectedResult() { + String inputJsonnet = + """ + local config = { + "edgechains.config": { + "mapper": {}, + }, + }; + """; + InputStream inputStream = new ByteArrayInputStream(inputJsonnet.getBytes()); + JsonnetLoader jsonnetLoader = new FileJsonnetLoader(); + assertThrows(JsonnetLoaderException.class, () -> jsonnetLoader.load(inputStream)); + } } diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/PostgresTestContainer.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/PostgresTestContainer.java new file mode 100644 index 000000000..6feab846b --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/PostgresTestContainer.java @@ -0,0 +1,41 @@ +package com.edgechain.testutil; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.utility.DockerImageName; + +public class PostgresTestContainer extends PostgreSQLContainer { + + public enum PostgresImage { + PLAIN, + VECTOR + }; + + private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestContainer.class); + + // private static final String DOCKER_IMAGE = PostgreSQLContainer.IMAGE + ":" + + // PostgreSQLContainer.DEFAULT_TAG; + + private static final DockerImageName IMAGE = DockerImageName.parse("postgres").withTag("14.5"); + + private static final DockerImageName VECTOR_IMAGE = + DockerImageName.parse("ankane/pgvector").asCompatibleSubstituteFor("postgres"); + + public PostgresTestContainer(PostgresImage img) { + super(img == PostgresImage.VECTOR ? VECTOR_IMAGE : IMAGE); + } + + @Override + public void start() { + LOGGER.info("starting container"); + super.start(); + LOGGER.info("TEST with Docker PostgreSQL url={}", getJdbcUrl()); + } + + @Override + public void stop() { + LOGGER.info("stopping container"); + super.stop(); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java new file mode 100644 index 000000000..78b605896 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestConfigSupport.java @@ -0,0 +1,77 @@ +package com.edgechain.testutil; + +import com.edgechain.lib.configuration.context.ApplicationContextHolder; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import org.modelmapper.ModelMapper; +import org.springframework.context.ApplicationContext; +import org.springframework.test.util.ReflectionTestUtils; +import retrofit2.Retrofit; +import static org.mockito.Mockito.mock; + +/** Two useful pairs of functions to set private static fields. */ +public final class TestConfigSupport { + + private ApplicationContext prevAppContext; + private String prevServerPort; + + /** + * Creates and forcefully uses a mock application context. Previous value is remembered. 
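+ * The saved context is restored by {@link #tearDownAppContext()}.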
Call this + * once in a @BeforeEach setup method. + * + * @return a mock application context + */ + public ApplicationContext setupAppContext() { + prevAppContext = ApplicationContextHolder.getContext(); + + ApplicationContext mockAppContext = mock(ApplicationContext.class); + ReflectionTestUtils.setField(ApplicationContextHolder.class, "context", mockAppContext); + + return mockAppContext; + } + + /** + * Restore a previously saved application context. Call this once in an @AfterEach teardown + * method. + */ + public void tearDownAppContext() { + ReflectionTestUtils.setField(ApplicationContextHolder.class, "context", prevAppContext); + } + + /** + * Creates and forcefully uses a mock Retrofit instance. Call this once in a @BeforeEach setup + * method. + * + * @return a mock Retrofit instance + */ + public Retrofit setupRetrofit() { + Retrofit mockRetrofit = mock(Retrofit.class); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", mockRetrofit); + + // Retrofit needs a valid port + prevServerPort = System.getProperty("server.port"); + System.setProperty("server.port", "8888"); + + return mockRetrofit; + } + + private ModelMapper setupModelMapper() { + ModelMapper mockModelMapper = mock(ModelMapper.class); + ReflectionTestUtils.setField(ModelMapper.class, "modelMapper", mockModelMapper); + // Retrofit needs a valid port + prevServerPort = System.getProperty("server.port"); + System.setProperty("server.port", "8888"); + return mockModelMapper; + } + + /** + * Erases the current Retrofit instance so it can be recreated. Call this once in an @AfterEach + * teardown method. + */ + public void tearDownRetrofit() { + ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); + + if (prevServerPort != null) { + System.setProperty("server.port", prevServerPort); + } + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestJwtCreator.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestJwtCreator.java new file mode 100644 index 000000000..07a633fcf --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/testutil/TestJwtCreator.java @@ -0,0 +1,25 @@ +package com.edgechain.testutil; + +import com.auth0.jwt.JWT; +import com.auth0.jwt.algorithms.Algorithm; +import java.util.Date; + +public final class TestJwtCreator { + + private TestJwtCreator() { + // no + } + + public static String generate(String role) { + Algorithm algo = Algorithm.HMAC256(System.getProperty("jwt.secret").getBytes()); + Date d = new Date(); + return JWT.create() + .withSubject("example JWT for testing") + .withClaim("email", "admin@machine.local") + .withClaim("role", role) + .withIssuer("edgechain-tester") + .withIssuedAt(d) + .withExpiresAt(new Date(d.getTime() + 25000)) + .sign(algo); + } +} diff --git a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java index 8d8b20a98..ec556317b 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java +++ b/Java/FlySpring/edgechain-app/src/test/java/com/edgechain/wiki/WikiClientTest.java @@ -1,19 +1,30 @@ package com.edgechain.wiki; -import com.edgechain.lib.endpoint.impl.WikiEndpoint; +import com.edgechain.lib.configuration.domain.SecurityUUID; +import com.edgechain.lib.endpoint.impl.wiki.WikiEndpoint; +import com.edgechain.lib.retrofit.client.RetrofitClientInstance; +import com.edgechain.lib.wiki.response.WikiResponse; 
import io.reactivex.rxjava3.observers.TestObserver; -import org.junit.jupiter.api.*; +import java.io.IOException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.TestMethodOrder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.test.context.SpringBootTest; -import static org.junit.jupiter.api.Assertions.*; - -import com.edgechain.lib.wiki.response.WikiResponse; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.util.ReflectionTestUtils; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; -@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) @TestMethodOrder(MethodOrderer.OrderAnnotation.class) -public class WikiClientTest { +class WikiClientTest { @LocalServerPort private int port; @@ -21,25 +32,36 @@ public class WikiClientTest { @BeforeEach public void setup() { - System.setProperty("server.port", "" + port); + System.setProperty("server.port", String.valueOf(port)); } @Test - @DisplayName("Test WikiContent Method Returns WikiResponse") - @Order(1) - public void wikiControllerTest_TestWikiContentMethod_ReturnsWikiResponse(TestInfo testInfo) + @DisplayName("Test WikiContent Method Handles Exception") + @DirtiesContext + void wikiControllerTest_TestWikiContentMethod_HandlesException(TestInfo testInfo) throws InterruptedException { + try { + // create a mock instance that will generate a non-IOException in the interceptor + SecurityUUID mockSecurityUUID = mock(SecurityUUID.class); + when(mockSecurityUUID.getAuthKey()).thenThrow(new RuntimeException("FORCED TEST EXCEPTION")); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", mockSecurityUUID); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); - logger.info("======== " + testInfo.getDisplayName() + " ========"); + logger.info("======== {} ========", testInfo.getDisplayName()); - // Prepare test data - WikiEndpoint wikiEndpoint = new WikiEndpoint(); - TestObserver test = wikiEndpoint.getPageContent("Barack Obama").test(); + // Prepare test data + WikiEndpoint wikiEndpoint = new WikiEndpoint(); + TestObserver test = wikiEndpoint.getPageContent("Barack Obama").test(); - test.await(); + test.await(); - logger.info(test.values().toString()); + logger.info("{}", test.values().toString()); - test.assertNoErrors(); + test.assertError(IOException.class); + } finally { + // reset instance + ReflectionTestUtils.setField(RetrofitClientInstance.class, "securityUUID", null); + ReflectionTestUtils.setField(RetrofitClientInstance.class, "retrofit", null); + } } } diff --git a/Java/FlySpring/edgechain-app/src/test/java/resources/ChatCompletionRequest.json b/Java/FlySpring/edgechain-app/src/test/java/resources/ChatCompletionRequest.json index fe735f847..542a34b86 100644 --- a/Java/FlySpring/edgechain-app/src/test/java/resources/ChatCompletionRequest.json +++ b/Java/FlySpring/edgechain-app/src/test/java/resources/ChatCompletionRequest.json @@ -1,9 +1,18 @@ { - "model" : "gpt-3.5-turbo", - "temperature" : 0.7, - "messages" : [ { - "role" : "user", - "content" : "Can you write two unique 
sentences on Java Language?" - } ], - "stream" : false + "model": "gpt-3.5-turbo", + "temperature": 0.7, + "messages": [ + { + "role": "user", + "content": "Can you write two unique sentences on Java Language?" + } + ], + "stream": false, + "n": 1, + "stop": [], + "user": "", + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "logit_bias": {} } \ No newline at end of file diff --git a/Java/FlySpring/edgechain-app/src/test/resources/logback-test.xml b/Java/FlySpring/edgechain-app/src/test/resources/logback-test.xml new file mode 100644 index 000000000..c050220c4 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/resources/logback-test.xml @@ -0,0 +1,18 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n + + + + + + + + + + + + + + diff --git a/Java/FlySpring/edgechain-app/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/Java/FlySpring/edgechain-app/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 000000000..1f0955d45 --- /dev/null +++ b/Java/FlySpring/edgechain-app/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline diff --git a/Java/FlySpring/flyfly/.DS_Store b/Java/FlySpring/flyfly/.DS_Store new file mode 100644 index 000000000..2ff3bf83d Binary files /dev/null and b/Java/FlySpring/flyfly/.DS_Store differ diff --git a/Java/FlySpring/flyfly/pom.xml b/Java/FlySpring/flyfly/pom.xml index aa917240b..957692182 100644 --- a/Java/FlySpring/flyfly/pom.xml +++ b/Java/FlySpring/flyfly/pom.xml @@ -1,87 +1,85 @@ - + 4.0.0 + - org.springframework.boot - spring-boot-starter-parent - 3.0.0 - + com.flyspring + edgechain-parent + 0.0.1-SNAPSHOT - com.flyspring + flyfly 0.0.1-SNAPSHOT flyfly flyfly CLI + - 17 + 3.3.9 - - - - org.testcontainers - testcontainers-bom - 1.17.6 - pom - import - - - info.picocli picocli-spring-boot-starter - 4.7.0 + ${picocli.version} + org.apache.commons commons-lang3 + commons-io commons-io - 2.11.0 + net.lingala.zip4j zip4j - 2.11.3 + ${zip4j.version} + org.zeroturnaround zt-exec - 1.12 + ${zeroturnaround.version} + org.apache.maven maven-model - 3.3.9 + ${maven-model.version} + org.projectlombok lombok - 1.18.20 - provided + true + org.testcontainers testcontainers - 1.17.6 + org.testcontainers mysql + - mysql - mysql-connector-java + com.mysql + mysql-connector-j org.testcontainers postgresql + org.postgresql postgresql @@ -91,10 +89,12 @@ org.testcontainers mariadb + org.mariadb.jdbc mariadb-java-client + org.springframework.boot spring-boot-starter-test @@ -103,19 +103,21 @@ + clean install org.springframework.boot spring-boot-maven-plugin + ${spring-boot.version} - com.flyspring.flyfly.FlyflyApplication - + com.flyspring.flyfly.FlyflyApplication + org.apache.maven.plugins maven-antrun-plugin - 1.8 + ${maven-antrun.version} download-and-unpack-jbang @@ -126,18 +128,20 @@ - + + dest="${project.build.directory}/jbang" /> - + + file="${project.build.directory}/jbang/jbang/bin/jbang.jar" + tofile="${project.basedir}/src/main/resources/jbang.jar" /> @@ -156,17 +160,21 @@ gofly + + true + maven-antrun-plugin - 3.1.0 + ${maven-antrun.version} package - + @@ -180,13 +188,4 @@ - - - - - - - - - diff --git a/Java/FlySpring/pom.xml b/Java/FlySpring/pom.xml new file mode 100644 index 000000000..46befc706 --- /dev/null +++ b/Java/FlySpring/pom.xml @@ -0,0 +1,102 @@ + + + 4.0.0 + + com.flyspring + edgechain-parent + pom + 0.0.1-SNAPSHOT + EdgeChain parent + Parent POM for EdgeChain + + + 17 + 17 + 17 + + UTF-8 + UTF-8 + + 3.4.1 + 
3.1.0 + + 3.1.3 + + 2.7.0 + 2.11.0 + 4.7.0 + 1.2.1 + 1.19.0 + 0.0.7 + 1.12 + 2.11.3 + + + + + + org.springframework.boot + spring-boot-dependencies + ${spring-boot.version} + pom + import + + + + commons-io + commons-io + ${commons-io.version} + + + + org.testcontainers + testcontainers-bom + ${testcontainers.version} + pom + import + + + + org.testcontainers + testcontainers + ${testcontainers.version} + + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + + + + org.testcontainers + mysql + ${testcontainers.version} + + + + org.testcontainers + postgresql + ${testcontainers.version} + + + + org.testcontainers + mariadb + ${testcontainers.version} + + + + + + autoroute + flyfly + edgechain-app + + + + clean install + + diff --git a/Java/FlySpring/readme.md b/Java/FlySpring/readme.md index 891fc9f00..02d5c69bb 100644 --- a/Java/FlySpring/readme.md +++ b/Java/FlySpring/readme.md @@ -1,14 +1,21 @@ # flyfly CLI + ## Installation & Usage + cd to autoroute directory -```console -mvn clean package -P gofly + +```bash +mvn clean package ``` + cd to flyfly directory -```console -mvn clean package -P gofly + +```bash +mvn clean package ``` + Now cd into Examples/starter :- flyfly is ready to roll! + ```bash java -jar flyfly.jar ``` @@ -16,10 +23,13 @@ java -jar flyfly.jar ## Commands ### run + Runs the Spring Boot application. If the project declares JPA and a database driver (connector) in the build file, and application.properties has no 'spring.datasource.url', the CLI starts a Testcontainers database and adds temporary values to application.properties so the application can run. This requires the driver to be supported by the CLI and Docker to be installed. -Currently supported DBs are: MySQL, Postgres, and MariaDB. +Currently supported DBs are: MySQL, Postgres, and MariaDB. + ### format + Formats the code using Spotless. P.S. - New examples will be added soon! \ No newline at end of file
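As a rough usage sketch of the two commands described in the README above (passing the command name as the first argument is an assumption based on the command list, not something this diff shows):

```bash
# run the Spring Boot app; with JPA plus a supported driver and no spring.datasource.url,
# a temporary Testcontainers database is started (requires Docker)
java -jar flyfly.jar run

# format the sources with Spotless
java -jar flyfly.jar format
```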