From 3c2d8b32b90979d2003b26ea3bc36d0a2be65a09 Mon Sep 17 00:00:00 2001 From: Ruby K V Date: Tue, 3 Dec 2024 15:19:38 -0600 Subject: [PATCH 1/4] deleted provider specific client factory, created generic client factory, moved hardcoded url to property file --- .DS_Store | Bin 0 -> 6148 bytes pom.xml | 16 ++-- src/main/java/com/r7b7/App.java | 27 ++++++ .../r7b7/client/DefaultAnthropicClient.java | 21 +++-- .../factory/AnthropicClientFactory.java | 24 ----- .../client/factory/GroqClientFactory.java | 24 ----- .../r7b7/client/factory/LLMClientFactory.java | 85 ++++++++++++++++++ .../client/factory/OllamaClientFactory.java | 24 ----- .../client/factory/OpenAIClientFactory.java | 24 ----- .../java/com/r7b7/config/PropertyConfig.java | 18 ++++ .../com/r7b7/service/AnthropicService.java | 4 +- .../java/com/r7b7/service/GroqService.java | 4 +- .../java/com/r7b7/service/OllamaService.java | 4 +- .../java/com/r7b7/service/OpenAIService.java | 7 +- src/main/resources/application.properties | 6 ++ 15 files changed, 173 insertions(+), 115 deletions(-) create mode 100644 .DS_Store create mode 100644 src/main/java/com/r7b7/App.java delete mode 100644 src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java delete mode 100644 src/main/java/com/r7b7/client/factory/GroqClientFactory.java create mode 100644 src/main/java/com/r7b7/client/factory/LLMClientFactory.java delete mode 100644 src/main/java/com/r7b7/client/factory/OllamaClientFactory.java delete mode 100644 src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java create mode 100644 src/main/java/com/r7b7/config/PropertyConfig.java create mode 100644 src/main/resources/application.properties diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5bedb1ce95563ffc067f4aaa154b744c40a6c86c GIT binary patch literal 6148 zcmeHKO>fgc5S>jzW2>r^10XI)mbg}F`T;`f#Z8k#CE!q__E1o;>!?_`-Y9kmQWVKY z5I=?+KZXCo3Eu2(D&ho=2+d5wowr;SBQs|iaO;qrppTJbA`1*1%=>U(U6J?_xqo?(KH`L zse0b8(I`#wqSN^%YW2pQji%!`o6hUtt(*o$SWNP6IKJTCODWT65r@%PGMo-tcaLRO zgh@6WtAZpRg7WG-$znO}%1IXIsy5IAj_bOE*8SP6v+X^7Jm22.18.1 - org.junit.jupiter - junit-jupiter-api - 5.11.3 - test - + org.junit.jupiter + junit-jupiter-api + 5.11.3 + test + org.mockito mockito-core @@ -55,5 +55,11 @@ 3.0.0 + + + src/main/resources + true + + \ No newline at end of file diff --git a/src/main/java/com/r7b7/App.java b/src/main/java/com/r7b7/App.java new file mode 100644 index 0000000..88ac4bb --- /dev/null +++ b/src/main/java/com/r7b7/App.java @@ -0,0 +1,27 @@ +package com.r7b7; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +public class App { + public static void main(String[] args) throws FileNotFoundException, IOException { + Properties properties = new Properties(); + try (InputStream input = App.class.getClassLoader().getResourceAsStream("application.properties")) { + if (input == null) { + System.out.println("Sorry, unable to find config.properties"); + return; + } + // Load properties file + properties.load(input); + + // Access properties + String propertyValue = properties.getProperty("hospai.openai.url"); + System.out.println("Property value: " + propertyValue); + } catch (Exception ex) { + ex.printStackTrace(); + } + } +} diff --git a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java index 52735db..c5a1915 100644 --- a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java +++ b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java @@ -6,25 +6,36 @@ import java.net.http.HttpResponse; import java.util.List; import java.util.Map; +import java.util.Properties; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.r7b7.client.model.AnthropicResponse; import com.r7b7.client.model.Message; +import com.r7b7.config.PropertyConfig; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.ErrorResponse; public class DefaultAnthropicClient implements AnthropicClient { - private String ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"; - private String ANTHROPIC_VERSION = "2023-06-01"; - private Integer MAX_TOKENS = 1024; + private String ANTHROPIC_API_URL; + private String ANTHROPIC_VERSION; + private String MAX_TOKENS; - public DefaultAnthropicClient() {} + public DefaultAnthropicClient() { + try { + Properties properties = PropertyConfig.loadConfig(); + ANTHROPIC_API_URL = properties.getProperty("hospai.anthropic.url"); + ANTHROPIC_VERSION = properties.getProperty("hospai.anthropic.version"); + MAX_TOKENS = properties.getProperty("hospai.anthropic.maxTokens"); + } catch (Exception ex) { + throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY"); + } + } @Override - public CompletionResponse generateCompletion(CompletionRequest request){ + public CompletionResponse generateCompletion(CompletionRequest request) { try { ObjectMapper objectMapper = new ObjectMapper(); ArrayNode arrayNode = objectMapper.valueToTree(request.messages()); diff --git a/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java b/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java deleted file mode 100644 index 70ba673..0000000 --- a/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.r7b7.client.factory; - -import com.r7b7.client.AnthropicClient; -import com.r7b7.client.DefaultAnthropicClient; - -public class AnthropicClientFactory { - private static AnthropicClient currentClient; - - public static AnthropicClient createDefaultClient() { - DefaultAnthropicClient client = new DefaultAnthropicClient(); - return client; - } - - public static void setClient(AnthropicClient client) { - currentClient = client; - } - - public static AnthropicClient getClient() { - if (null == currentClient) { - currentClient = createDefaultClient(); - } - return currentClient; - } -} diff --git a/src/main/java/com/r7b7/client/factory/GroqClientFactory.java b/src/main/java/com/r7b7/client/factory/GroqClientFactory.java deleted file mode 100644 index 3d44887..0000000 --- a/src/main/java/com/r7b7/client/factory/GroqClientFactory.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.r7b7.client.factory; - -import com.r7b7.client.DefaultGroqClient; -import com.r7b7.client.GroqClient; - -public class GroqClientFactory { - private static GroqClient currentClient; - - public static GroqClient createDefaultClient() { - DefaultGroqClient client = new DefaultGroqClient(); - return client; - } - - public static void setClient(GroqClient client) { - currentClient = client; - } - - public static GroqClient getClient() { - if (null == currentClient) { - currentClient = createDefaultClient(); - } - return currentClient; - } -} diff --git a/src/main/java/com/r7b7/client/factory/LLMClientFactory.java b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java new file mode 100644 index 0000000..6b00123 --- /dev/null +++ b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java @@ -0,0 +1,85 @@ +package com.r7b7.client.factory; + +import com.r7b7.client.AnthropicClient; +import com.r7b7.client.DefaultAnthropicClient; +import com.r7b7.client.DefaultGroqClient; +import com.r7b7.client.DefaultOllamaClient; +import com.r7b7.client.DefaultOpenAIClient; +import com.r7b7.client.GroqClient; +import com.r7b7.client.OllamaClient; +import com.r7b7.client.OpenAIClient; + +public class LLMClientFactory { + private static AnthropicClient currentAnthropicClient; + private static GroqClient currentGroqClient; + private static OllamaClient currentOllamaClient; + private static OpenAIClient currentOpenAIClient; + + // Open AI Client + public static OpenAIClient createDefaultOpenAIClient() { + DefaultOpenAIClient client = new DefaultOpenAIClient(); + return client; + } + + public static void setOpenAIClient(OpenAIClient client) { + currentOpenAIClient = client; + } + + public static OpenAIClient getOpenAIClient() { + if (null == currentOpenAIClient) { + currentOpenAIClient = createDefaultOpenAIClient(); + } + return currentOpenAIClient; + } + + // Anthropic Client + public static AnthropicClient createDefaultAnthropicClient() { + DefaultAnthropicClient client = new DefaultAnthropicClient(); + return client; + } + + public static void setAnthropicClient(AnthropicClient client) { + currentAnthropicClient = client; + } + + public static AnthropicClient getAnthropicClient() { + if (null == currentAnthropicClient) { + currentAnthropicClient = createDefaultAnthropicClient(); + } + return currentAnthropicClient; + } + + // Groq Client + public static GroqClient createDefaultGroqClient() { + DefaultGroqClient client = new DefaultGroqClient(); + return client; + } + + public static void setGroqClient(GroqClient client) { + currentGroqClient = client; + } + + public static GroqClient getGroqClient() { + if (null == currentGroqClient) { + currentGroqClient = createDefaultGroqClient(); + } + return currentGroqClient; + } + + // Ollama Client + public static OllamaClient createDefaultOllamaClient() { + DefaultOllamaClient client = new DefaultOllamaClient(); + return client; + } + + public static void setOllamaClient(OllamaClient client) { + currentOllamaClient = client; + } + + public static OllamaClient getOllamaClient() { + if (null == currentOllamaClient) { + currentOllamaClient = createDefaultOllamaClient(); + } + return currentOllamaClient; + } +} diff --git a/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java b/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java deleted file mode 100644 index 20165a0..0000000 --- a/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.r7b7.client.factory; - -import com.r7b7.client.DefaultOllamaClient; -import com.r7b7.client.OllamaClient; - -public class OllamaClientFactory { - private static OllamaClient currentClient; - - public static OllamaClient createDefaultClient() { - DefaultOllamaClient client = new DefaultOllamaClient(); - return client; - } - - public static void setClient(OllamaClient client) { - currentClient = client; - } - - public static OllamaClient getClient() { - if (null == currentClient) { - currentClient = createDefaultClient(); - } - return currentClient; - } -} diff --git a/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java b/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java deleted file mode 100644 index 3fa4cb7..0000000 --- a/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.r7b7.client.factory; - -import com.r7b7.client.DefaultOpenAIClient; -import com.r7b7.client.OpenAIClient; - -public class OpenAIClientFactory { - private static OpenAIClient currentClient; - - public static OpenAIClient createDefaultClient() { - DefaultOpenAIClient client = new DefaultOpenAIClient(); - return client; - } - - public static void setClient(OpenAIClient client) { - currentClient = client; - } - - public static OpenAIClient getClient() { - if (null == currentClient) { - currentClient = createDefaultClient(); - } - return currentClient; - } -} diff --git a/src/main/java/com/r7b7/config/PropertyConfig.java b/src/main/java/com/r7b7/config/PropertyConfig.java new file mode 100644 index 0000000..94c3157 --- /dev/null +++ b/src/main/java/com/r7b7/config/PropertyConfig.java @@ -0,0 +1,18 @@ +package com.r7b7.config; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +import com.r7b7.App; + +public class PropertyConfig { + + public static Properties loadConfig() throws IOException { + Properties properties = new Properties(); + try (InputStream input = App.class.getClassLoader().getResourceAsStream("application.properties")) { + properties.load(input); + } + return properties; + } +} diff --git a/src/main/java/com/r7b7/service/AnthropicService.java b/src/main/java/com/r7b7/service/AnthropicService.java index 53ff208..0daf3ac 100644 --- a/src/main/java/com/r7b7/service/AnthropicService.java +++ b/src/main/java/com/r7b7/service/AnthropicService.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import com.r7b7.client.AnthropicClient; -import com.r7b7.client.factory.AnthropicClientFactory; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Param; @@ -24,7 +24,7 @@ public AnthropicService(String apiKey, String model) { @Override public LLMResponse generateResponse(LLMRequest request) { - AnthropicClient client = AnthropicClientFactory.getClient(); + AnthropicClient client = LLMClientFactory.getAnthropicClient(); Map platformAllignedParams = null; platformAllignedParams = getPlatformAllignedParams(request); diff --git a/src/main/java/com/r7b7/service/GroqService.java b/src/main/java/com/r7b7/service/GroqService.java index 3e49771..3eeac65 100644 --- a/src/main/java/com/r7b7/service/GroqService.java +++ b/src/main/java/com/r7b7/service/GroqService.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import com.r7b7.client.GroqClient; -import com.r7b7.client.factory.GroqClientFactory; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Param; @@ -25,7 +25,7 @@ public GroqService(String apiKey, String model) { @Override public LLMResponse generateResponse(LLMRequest request) { CompletionResponse response = null; - GroqClient client = GroqClientFactory.getClient(); + GroqClient client = LLMClientFactory.getGroqClient(); Map platformAllignedParams = getPlatformAllignedParams(request); response = client diff --git a/src/main/java/com/r7b7/service/OllamaService.java b/src/main/java/com/r7b7/service/OllamaService.java index 0c2e3a5..cc3fea3 100644 --- a/src/main/java/com/r7b7/service/OllamaService.java +++ b/src/main/java/com/r7b7/service/OllamaService.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import com.r7b7.client.OllamaClient; -import com.r7b7.client.factory.OllamaClientFactory; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Param; @@ -22,7 +22,7 @@ public OllamaService(String model) { @Override public LLMResponse generateResponse(LLMRequest request) { - OllamaClient client = OllamaClientFactory.getClient(); + OllamaClient client = LLMClientFactory.getOllamaClient(); Map platformAllignedParams = getPlatformAllignedParams(request); CompletionResponse response = client diff --git a/src/main/java/com/r7b7/service/OpenAIService.java b/src/main/java/com/r7b7/service/OpenAIService.java index 4eeb1d8..2326335 100644 --- a/src/main/java/com/r7b7/service/OpenAIService.java +++ b/src/main/java/com/r7b7/service/OpenAIService.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import com.r7b7.client.OpenAIClient; -import com.r7b7.client.factory.OpenAIClientFactory; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Param; @@ -25,10 +25,11 @@ public OpenAIService(String apiKey, String model) { @Override public LLMResponse generateResponse(LLMRequest request) { CompletionResponse response = null; - OpenAIClient client = OpenAIClientFactory.getClient(); + OpenAIClient client = LLMClientFactory.getOpenAIClient(); Map platformAllignedParams = getPlatformAllignedParams(request); - response = client.generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey)); + response = client + .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey)); Map metadata = Map.of( "model", model, "provider", "openai"); diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties new file mode 100644 index 0000000..1f36d62 --- /dev/null +++ b/src/main/resources/application.properties @@ -0,0 +1,6 @@ +hospai.openai.url=https://google.com +hospai.anthropic.url=https://api.anthropic.com/v1/messages +hospai.anthropic.version=2023-06-01 +hospai.anthropic.maxTokens=1024 +hospai.ollama.url= +hospai.groq.url= \ No newline at end of file From f3ffd5cb99442673ba8e6445c5200c7b0d29391d Mon Sep 17 00:00:00 2001 From: Ruby K V Date: Thu, 5 Dec 2024 13:51:03 -0600 Subject: [PATCH 2/4] Code refactored,PromptBuilder added,test updated --- src/main/java/com/r7b7/App.java | 27 ------ .../r7b7/client/DefaultAnthropicClient.java | 40 +++----- .../com/r7b7/client/DefaultGroqClient.java | 50 +++++----- .../com/r7b7/client/DefaultOllamaClient.java | 43 ++++---- .../com/r7b7/client/DefaultOpenAIClient.java | 46 ++++----- ...ropicClient.java => IAnthropicClient.java} | 2 +- .../{GroqClient.java => IGroqClient.java} | 2 +- .../{OpenAIClient.java => IOllamaClient.java} | 2 +- .../{OllamaClient.java => IOpenAIClient.java} | 2 +- .../r7b7/client/factory/LLMClientFactory.java | 40 ++++---- .../java/com/r7b7/config/PropertyConfig.java | 12 ++- .../com/r7b7/entity/CompletionRequest.java | 3 +- .../com/r7b7/entity/CompletionResponse.java | 3 +- src/main/java/com/r7b7/entity/Param.java | 5 - src/main/java/com/r7b7/entity/Role.java | 2 +- .../java/com/r7b7/model/BaseLLMRequest.java | 9 +- .../java/com/r7b7/model/BaseLLMResponse.java | 25 ----- .../{LLMRequest.java => ILLMRequest.java} | 5 +- src/main/java/com/r7b7/model/LLMResponse.java | 10 -- .../com/r7b7/service/AnthropicService.java | 97 ++++++++++++------- .../java/com/r7b7/service/GroqService.java | 77 ++++++++------- .../java/com/r7b7/service/ILLMService.java | 16 +++ .../java/com/r7b7/service/LLMService.java | 11 --- .../com/r7b7/service/LLMServiceFactory.java | 4 +- .../java/com/r7b7/service/OllamaService.java | 74 ++++++++------ .../java/com/r7b7/service/OpenAIService.java | 77 ++++++++------- .../java/com/r7b7/service/PromptBuilder.java | 35 +++++++ .../java/com/r7b7/service/PromptEngine.java | 45 +++++---- .../java/com/r7b7/util/StringUtility.java | 4 +- src/main/resources/application.properties | 7 +- .../client/DefaultAnthropicClientTest.java | 41 ++++++-- .../r7b7/client/DefaultGroqClientTest.java | 44 ++++++--- .../r7b7/client/DefaultOllamaClientTest.java | 34 ++++--- .../r7b7/client/DefaultOpenAIClientTest.java | 43 +++++--- .../r7b7/service/AnthropicServiceTest.java | 88 +++++++++++------ .../com/r7b7/service/GroqServiceTest.java | 85 ++++++++++------ .../com/r7b7/service/OllamaServiceTest.java | 86 ++++++++++------ .../com/r7b7/service/OpenAIServiceTest.java | 86 ++++++++++------ .../com/r7b7/service/PromptEngineTest.java | 55 ++++++++--- 39 files changed, 775 insertions(+), 562 deletions(-) delete mode 100644 src/main/java/com/r7b7/App.java rename src/main/java/com/r7b7/client/{AnthropicClient.java => IAnthropicClient.java} (83%) rename src/main/java/com/r7b7/client/{GroqClient.java => IGroqClient.java} (85%) rename src/main/java/com/r7b7/client/{OpenAIClient.java => IOllamaClient.java} (84%) rename src/main/java/com/r7b7/client/{OllamaClient.java => IOpenAIClient.java} (84%) delete mode 100644 src/main/java/com/r7b7/entity/Param.java delete mode 100644 src/main/java/com/r7b7/model/BaseLLMResponse.java rename src/main/java/com/r7b7/model/{LLMRequest.java => ILLMRequest.java} (57%) delete mode 100644 src/main/java/com/r7b7/model/LLMResponse.java create mode 100644 src/main/java/com/r7b7/service/ILLMService.java delete mode 100644 src/main/java/com/r7b7/service/LLMService.java create mode 100644 src/main/java/com/r7b7/service/PromptBuilder.java diff --git a/src/main/java/com/r7b7/App.java b/src/main/java/com/r7b7/App.java deleted file mode 100644 index 88ac4bb..0000000 --- a/src/main/java/com/r7b7/App.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.r7b7; - -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.InputStream; -import java.util.Properties; - -public class App { - public static void main(String[] args) throws FileNotFoundException, IOException { - Properties properties = new Properties(); - try (InputStream input = App.class.getClassLoader().getResourceAsStream("application.properties")) { - if (input == null) { - System.out.println("Sorry, unable to find config.properties"); - return; - } - // Load properties file - properties.load(input); - - // Access properties - String propertyValue = properties.getProperty("hospai.openai.url"); - System.out.println("Property value: " + propertyValue); - } catch (Exception ex) { - ex.printStackTrace(); - } - } -} diff --git a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java index c5a1915..3b4555f 100644 --- a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java +++ b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java @@ -9,8 +9,6 @@ import java.util.Properties; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.r7b7.client.model.AnthropicResponse; import com.r7b7.client.model.Message; import com.r7b7.config.PropertyConfig; @@ -18,17 +16,15 @@ import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.ErrorResponse; -public class DefaultAnthropicClient implements AnthropicClient { +public class DefaultAnthropicClient implements IAnthropicClient { private String ANTHROPIC_API_URL; private String ANTHROPIC_VERSION; - private String MAX_TOKENS; public DefaultAnthropicClient() { try { Properties properties = PropertyConfig.loadConfig(); ANTHROPIC_API_URL = properties.getProperty("hospai.anthropic.url"); ANTHROPIC_VERSION = properties.getProperty("hospai.anthropic.version"); - MAX_TOKENS = properties.getProperty("hospai.anthropic.maxTokens"); } catch (Exception ex) { throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY"); } @@ -38,30 +34,14 @@ public DefaultAnthropicClient() { public CompletionResponse generateCompletion(CompletionRequest request) { try { ObjectMapper objectMapper = new ObjectMapper(); - ArrayNode arrayNode = objectMapper.valueToTree(request.messages()); - ObjectNode requestBody = objectMapper.createObjectNode(); - requestBody.put("model", request.model()); - requestBody.set("messages", arrayNode); - - if (null != request.params()) { - for (Map.Entry entry : request.params().entrySet()) { - Object value = entry.getValue(); - if (value instanceof String) { - requestBody.put(entry.getKey(), (String) value); - } else if (value instanceof Integer) { - requestBody.put(entry.getKey(), (Integer) value); - } - } - } else { - requestBody.put("max_tokens", MAX_TOKENS); - } + String jsonRequest = objectMapper.writeValueAsString(request.requestBody()); HttpRequest httpRequest = HttpRequest.newBuilder() .uri(URI.create(this.ANTHROPIC_API_URL)) .header("Content-Type", "application/json") .header("x-api-key", request.apiKey()) .header("anthropic-version", ANTHROPIC_VERSION) - .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString())) + .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)) .build(); HttpResponse response = HttpClient.newHttpClient().send(httpRequest, @@ -69,8 +49,8 @@ public CompletionResponse generateCompletion(CompletionRequest request) { if (response.statusCode() == 200) { return extractResponseText(response.body()); } else { - return new CompletionResponse(null, response, new ErrorResponse( - "Request sent to LLM failed with status code " + response, null)); + return new CompletionResponse(null, null, new ErrorResponse( + "Request sent to LLM failed: " + response.statusCode() + response.body(), null)); } } catch (Exception ex) { return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex)); @@ -82,13 +62,21 @@ private CompletionResponse extractResponseText(String responseBody) { List msgs = null; AnthropicResponse response = null; ErrorResponse error = null; + Map metadata = null; + try { ObjectMapper mapper = new ObjectMapper(); response = mapper.readValue(responseBody, AnthropicResponse.class); msgs = response.content().stream().map(content -> new Message(content.type(), content.text())).toList(); + metadata = Map.of( + "id", response.id(), + "model", response.model(), + "provider", "Anthropic", + "input_tokens", response.usage().inputTokens(), + "output_tokens", response.usage().outputTokens()); } catch (Exception ex) { error = new ErrorResponse("Exception occurred in extracting response", ex); } - return new CompletionResponse(msgs, response, error); + return new CompletionResponse(msgs, metadata, error); } } diff --git a/src/main/java/com/r7b7/client/DefaultGroqClient.java b/src/main/java/com/r7b7/client/DefaultGroqClient.java index feaddb1..963ff1e 100644 --- a/src/main/java/com/r7b7/client/DefaultGroqClient.java +++ b/src/main/java/com/r7b7/client/DefaultGroqClient.java @@ -6,54 +6,47 @@ import java.net.http.HttpResponse; import java.util.List; import java.util.Map; +import java.util.Properties; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.r7b7.client.model.Message; import com.r7b7.client.model.OpenAIResponse; +import com.r7b7.config.PropertyConfig; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.ErrorResponse; -public class DefaultGroqClient implements GroqClient { - private String GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"; +public class DefaultGroqClient implements IGroqClient { + private String GROQ_API_URL; - public DefaultGroqClient(){} + public DefaultGroqClient() { + try { + Properties properties = PropertyConfig.loadConfig(); + GROQ_API_URL = properties.getProperty("hospai.groq.url"); + } catch (Exception ex) { + throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY"); + } + } @Override public CompletionResponse generateCompletion(CompletionRequest request) { try { ObjectMapper objectMapper = new ObjectMapper(); - ArrayNode arrayNode = objectMapper.valueToTree(request.messages()); - ObjectNode requestBody = objectMapper.createObjectNode(); - requestBody.put("model", request.model()); - requestBody.set("messages", arrayNode); + String jsonRequest = objectMapper.writeValueAsString(request.requestBody()); - if (null != request.params()) { - for (Map.Entry entry : request.params().entrySet()) { - Object value = entry.getValue(); - if (value instanceof String) { - requestBody.put(entry.getKey(), (String) value); - } else if (value instanceof Integer) { - requestBody.put(entry.getKey(), (Integer) value); - } - } - } - HttpRequest httpRequest = HttpRequest.newBuilder() .uri(URI.create(this.GROQ_API_URL)) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + request.apiKey()) - .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString())) + .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)) .build(); HttpResponse response = HttpClient.newHttpClient().send(httpRequest, HttpResponse.BodyHandlers.ofString()); if (response.statusCode() == 200) { return extractResponseText(response.body()); } else { - return new CompletionResponse(null, response, new ErrorResponse( - "Request sent to LLM failed with status code " + response.statusCode(), null)); + return new CompletionResponse(null, null, new ErrorResponse( + "Request sent to LLM failed: " + response.statusCode() + response.body(), null)); } } catch (Exception ex) { return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex)); @@ -64,16 +57,23 @@ private CompletionResponse extractResponseText(String responseBody) { List msgs = null; OpenAIResponse response = null; ErrorResponse error = null; + Map metadata = null; try { ObjectMapper mapper = new ObjectMapper(); response = mapper.readValue(responseBody, OpenAIResponse.class); msgs = response.choices().stream() .map(choice -> new Message(choice.message().role(), choice.message().content())).toList(); + metadata = Map.of( + "id", response.id(), + "model", response.model(), + "provider", "Groq", + "prompt_tokens", response.usage().promptTokens(), + "completion_tokens", response.usage().completionTokens(), + "total_tokens", response.usage().totalTokens()); } catch (Exception ex) { error = new ErrorResponse("Exception occurred in extracting response", ex); } - return new CompletionResponse(msgs, response, error); + return new CompletionResponse(msgs, metadata, error); } - } diff --git a/src/main/java/com/r7b7/client/DefaultOllamaClient.java b/src/main/java/com/r7b7/client/DefaultOllamaClient.java index ed918f5..cacad98 100644 --- a/src/main/java/com/r7b7/client/DefaultOllamaClient.java +++ b/src/main/java/com/r7b7/client/DefaultOllamaClient.java @@ -4,43 +4,35 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Properties; import com.fasterxml.jackson.databind.ObjectMapper; import com.r7b7.client.model.Message; import com.r7b7.client.model.OllamaResponse; +import com.r7b7.config.PropertyConfig; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.ErrorResponse; -public class DefaultOllamaClient implements OllamaClient { - private String OLLAMA_API_URL = "http://localhost:11434/api/chat"; +public class DefaultOllamaClient implements IOllamaClient { + private String OLLAMA_API_URL; public DefaultOllamaClient() { + try { + Properties properties = PropertyConfig.loadConfig(); + OLLAMA_API_URL = properties.getProperty("hospai.ollama.url"); + } catch (Exception ex) { + throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY"); + } } @Override public CompletionResponse generateCompletion(CompletionRequest request) { try { ObjectMapper objectMapper = new ObjectMapper(); - - Map requestMap = new HashMap<>(); - requestMap.put("model", request.model()); - requestMap.put("messages", request.messages()); - requestMap.put("stream", false); - - Map optionsMap = new HashMap<>(); - if (null != request.params() && request.params().get("temperature") != null) { - optionsMap.put("temperature", request.params().get("temperature")); - } - if (null != request.params() && request.params().get("seed") != null) { - optionsMap.put("seed", request.params().get("seed")); - } - - requestMap.put("options", optionsMap); - String jsonRequest = objectMapper.writeValueAsString(requestMap); + String jsonRequest = objectMapper.writeValueAsString(request.requestBody()); HttpRequest httpRequest = HttpRequest.newBuilder() .uri(URI.create(this.OLLAMA_API_URL)) @@ -53,8 +45,8 @@ public CompletionResponse generateCompletion(CompletionRequest request) { if (response.statusCode() == 200) { return extractResponseText(response.body()); } else { - return new CompletionResponse(null, response, new ErrorResponse( - "Request sent to LLM failed with status code " + response.statusCode(), null)); + return new CompletionResponse(null, null, new ErrorResponse( + "Request sent to LLM failed: " + response.statusCode() + response.body(), null)); } } catch (Exception ex) { return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex)); @@ -65,13 +57,20 @@ private CompletionResponse extractResponseText(String responseBody) { List msgs = null; OllamaResponse response = null; ErrorResponse error = null; + Map metadata = null; + try { ObjectMapper mapper = new ObjectMapper(); response = mapper.readValue(responseBody, OllamaResponse.class); msgs = List.of(response.message()); + metadata = Map.of( + "model", response.model(), + "provider", "Ollama", + "total_duration", response.total_duration(), + "eval_duration", response.eval_duration()); } catch (Exception ex) { error = new ErrorResponse("Exception occurred in extracting response", ex); } - return new CompletionResponse(msgs, response, error); + return new CompletionResponse(msgs, metadata, error); } } diff --git a/src/main/java/com/r7b7/client/DefaultOpenAIClient.java b/src/main/java/com/r7b7/client/DefaultOpenAIClient.java index beeeb0e..0b13b82 100644 --- a/src/main/java/com/r7b7/client/DefaultOpenAIClient.java +++ b/src/main/java/com/r7b7/client/DefaultOpenAIClient.java @@ -6,55 +6,47 @@ import java.net.http.HttpResponse; import java.util.List; import java.util.Map; +import java.util.Properties; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.r7b7.client.model.Message; import com.r7b7.client.model.OpenAIResponse; +import com.r7b7.config.PropertyConfig; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.ErrorResponse; -public class DefaultOpenAIClient implements OpenAIClient { - private String OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"; +public class DefaultOpenAIClient implements IOpenAIClient { + private String OPENAI_API_URL; public DefaultOpenAIClient() { + try { + Properties properties = PropertyConfig.loadConfig(); + OPENAI_API_URL = properties.getProperty("hospai.openai.url"); + } catch (Exception ex) { + throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY"); + } } @Override public CompletionResponse generateCompletion(CompletionRequest request) { try { ObjectMapper objectMapper = new ObjectMapper(); - ArrayNode arrayNode = objectMapper.valueToTree(request.messages()); - ObjectNode requestBody = objectMapper.createObjectNode(); - requestBody.put("model", request.model()); - requestBody.set("messages", arrayNode); - - if (null != request.params()) { - for (Map.Entry entry : request.params().entrySet()) { - Object value = entry.getValue(); - if (value instanceof String) { - requestBody.put(entry.getKey(), (String) value); - } else if (value instanceof Integer) { - requestBody.put(entry.getKey(), (Integer) value); - } - } - } + String jsonRequest = objectMapper.writeValueAsString(request.requestBody()); HttpRequest httpRequest = HttpRequest.newBuilder() .uri(URI.create(this.OPENAI_API_URL)) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + request.apiKey()) - .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString())) + .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)) .build(); HttpResponse response = HttpClient.newHttpClient().send(httpRequest, HttpResponse.BodyHandlers.ofString()); if (response.statusCode() == 200) { return extractResponseText(response.body()); } else { - return new CompletionResponse(null, response, new ErrorResponse( - "Request sent to LLM failed with status code " + response.statusCode(), null)); + return new CompletionResponse(null, null, new ErrorResponse( + "Request sent to LLM failed: " + response.statusCode() + response.body(), null)); } } catch (Exception ex) { return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex)); @@ -65,15 +57,23 @@ private CompletionResponse extractResponseText(String responseBody) { List msgs = null; OpenAIResponse response = null; ErrorResponse error = null; + Map metadata = null; try { ObjectMapper mapper = new ObjectMapper(); response = mapper.readValue(responseBody, OpenAIResponse.class); msgs = response.choices().stream() .map(choice -> new Message(choice.message().role(), choice.message().content())).toList(); + metadata = Map.of( + "id", response.id(), + "model", response.model(), + "provider", "OpenAi", + "prompt_tokens", response.usage().promptTokens(), + "completion_tokens", response.usage().completionTokens(), + "total_tokens", response.usage().totalTokens()); } catch (Exception ex) { error = new ErrorResponse("Exception occurred in extracting response", ex); } - return new CompletionResponse(msgs, response, error); + return new CompletionResponse(msgs, metadata, error); } } diff --git a/src/main/java/com/r7b7/client/AnthropicClient.java b/src/main/java/com/r7b7/client/IAnthropicClient.java similarity index 83% rename from src/main/java/com/r7b7/client/AnthropicClient.java rename to src/main/java/com/r7b7/client/IAnthropicClient.java index 3aa4d6b..0ffab7a 100644 --- a/src/main/java/com/r7b7/client/AnthropicClient.java +++ b/src/main/java/com/r7b7/client/IAnthropicClient.java @@ -3,6 +3,6 @@ import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -public interface AnthropicClient { +public interface IAnthropicClient { CompletionResponse generateCompletion(CompletionRequest request); } diff --git a/src/main/java/com/r7b7/client/GroqClient.java b/src/main/java/com/r7b7/client/IGroqClient.java similarity index 85% rename from src/main/java/com/r7b7/client/GroqClient.java rename to src/main/java/com/r7b7/client/IGroqClient.java index b0125ab..88c3f66 100644 --- a/src/main/java/com/r7b7/client/GroqClient.java +++ b/src/main/java/com/r7b7/client/IGroqClient.java @@ -3,6 +3,6 @@ import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -public interface GroqClient { +public interface IGroqClient { CompletionResponse generateCompletion(CompletionRequest request); } diff --git a/src/main/java/com/r7b7/client/OpenAIClient.java b/src/main/java/com/r7b7/client/IOllamaClient.java similarity index 84% rename from src/main/java/com/r7b7/client/OpenAIClient.java rename to src/main/java/com/r7b7/client/IOllamaClient.java index 5dc1858..3a8924d 100644 --- a/src/main/java/com/r7b7/client/OpenAIClient.java +++ b/src/main/java/com/r7b7/client/IOllamaClient.java @@ -3,6 +3,6 @@ import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -public interface OpenAIClient { +public interface IOllamaClient { CompletionResponse generateCompletion(CompletionRequest request); } diff --git a/src/main/java/com/r7b7/client/OllamaClient.java b/src/main/java/com/r7b7/client/IOpenAIClient.java similarity index 84% rename from src/main/java/com/r7b7/client/OllamaClient.java rename to src/main/java/com/r7b7/client/IOpenAIClient.java index dac4f1b..77a449a 100644 --- a/src/main/java/com/r7b7/client/OllamaClient.java +++ b/src/main/java/com/r7b7/client/IOpenAIClient.java @@ -3,6 +3,6 @@ import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -public interface OllamaClient { +public interface IOpenAIClient { CompletionResponse generateCompletion(CompletionRequest request); } diff --git a/src/main/java/com/r7b7/client/factory/LLMClientFactory.java b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java index 6b00123..9207f2a 100644 --- a/src/main/java/com/r7b7/client/factory/LLMClientFactory.java +++ b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java @@ -1,31 +1,31 @@ package com.r7b7.client.factory; -import com.r7b7.client.AnthropicClient; +import com.r7b7.client.IAnthropicClient; import com.r7b7.client.DefaultAnthropicClient; import com.r7b7.client.DefaultGroqClient; import com.r7b7.client.DefaultOllamaClient; import com.r7b7.client.DefaultOpenAIClient; -import com.r7b7.client.GroqClient; -import com.r7b7.client.OllamaClient; -import com.r7b7.client.OpenAIClient; +import com.r7b7.client.IGroqClient; +import com.r7b7.client.IOllamaClient; +import com.r7b7.client.IOpenAIClient; public class LLMClientFactory { - private static AnthropicClient currentAnthropicClient; - private static GroqClient currentGroqClient; - private static OllamaClient currentOllamaClient; - private static OpenAIClient currentOpenAIClient; + private static IAnthropicClient currentAnthropicClient; + private static IGroqClient currentGroqClient; + private static IOllamaClient currentOllamaClient; + private static IOpenAIClient currentOpenAIClient; // Open AI Client - public static OpenAIClient createDefaultOpenAIClient() { + public static IOpenAIClient createDefaultOpenAIClient() { DefaultOpenAIClient client = new DefaultOpenAIClient(); return client; } - public static void setOpenAIClient(OpenAIClient client) { + public static void setOpenAIClient(IOpenAIClient client) { currentOpenAIClient = client; } - public static OpenAIClient getOpenAIClient() { + public static IOpenAIClient getOpenAIClient() { if (null == currentOpenAIClient) { currentOpenAIClient = createDefaultOpenAIClient(); } @@ -33,16 +33,16 @@ public static OpenAIClient getOpenAIClient() { } // Anthropic Client - public static AnthropicClient createDefaultAnthropicClient() { + public static IAnthropicClient createDefaultAnthropicClient() { DefaultAnthropicClient client = new DefaultAnthropicClient(); return client; } - public static void setAnthropicClient(AnthropicClient client) { + public static void setAnthropicClient(IAnthropicClient client) { currentAnthropicClient = client; } - public static AnthropicClient getAnthropicClient() { + public static IAnthropicClient getAnthropicClient() { if (null == currentAnthropicClient) { currentAnthropicClient = createDefaultAnthropicClient(); } @@ -50,16 +50,16 @@ public static AnthropicClient getAnthropicClient() { } // Groq Client - public static GroqClient createDefaultGroqClient() { + public static IGroqClient createDefaultGroqClient() { DefaultGroqClient client = new DefaultGroqClient(); return client; } - public static void setGroqClient(GroqClient client) { + public static void setGroqClient(IGroqClient client) { currentGroqClient = client; } - public static GroqClient getGroqClient() { + public static IGroqClient getGroqClient() { if (null == currentGroqClient) { currentGroqClient = createDefaultGroqClient(); } @@ -67,16 +67,16 @@ public static GroqClient getGroqClient() { } // Ollama Client - public static OllamaClient createDefaultOllamaClient() { + public static IOllamaClient createDefaultOllamaClient() { DefaultOllamaClient client = new DefaultOllamaClient(); return client; } - public static void setOllamaClient(OllamaClient client) { + public static void setOllamaClient(IOllamaClient client) { currentOllamaClient = client; } - public static OllamaClient getOllamaClient() { + public static IOllamaClient getOllamaClient() { if (null == currentOllamaClient) { currentOllamaClient = createDefaultOllamaClient(); } diff --git a/src/main/java/com/r7b7/config/PropertyConfig.java b/src/main/java/com/r7b7/config/PropertyConfig.java index 94c3157..4ac905f 100644 --- a/src/main/java/com/r7b7/config/PropertyConfig.java +++ b/src/main/java/com/r7b7/config/PropertyConfig.java @@ -4,14 +4,16 @@ import java.io.InputStream; import java.util.Properties; -import com.r7b7.App; - public class PropertyConfig { + private static Properties properties; public static Properties loadConfig() throws IOException { - Properties properties = new Properties(); - try (InputStream input = App.class.getClassLoader().getResourceAsStream("application.properties")) { - properties.load(input); + if (null == properties) { + properties = new Properties(); + try (InputStream input = PropertyConfig.class.getClassLoader() + .getResourceAsStream("application.properties")) { + properties.load(input); + } } return properties; } diff --git a/src/main/java/com/r7b7/entity/CompletionRequest.java b/src/main/java/com/r7b7/entity/CompletionRequest.java index a58716b..1b76f62 100644 --- a/src/main/java/com/r7b7/entity/CompletionRequest.java +++ b/src/main/java/com/r7b7/entity/CompletionRequest.java @@ -1,7 +1,6 @@ package com.r7b7.entity; -import java.util.List; import java.util.Map; -public record CompletionRequest(List messages, Map params, String model, String apiKey) { +public record CompletionRequest(Map requestBody, String apiKey) { } diff --git a/src/main/java/com/r7b7/entity/CompletionResponse.java b/src/main/java/com/r7b7/entity/CompletionResponse.java index 6825360..65c5ded 100644 --- a/src/main/java/com/r7b7/entity/CompletionResponse.java +++ b/src/main/java/com/r7b7/entity/CompletionResponse.java @@ -1,7 +1,8 @@ package com.r7b7.entity; import java.util.List; +import java.util.Map; -public record CompletionResponse (List messages, Object completeResponse, ErrorResponse error){ +public record CompletionResponse (List messages, Map metaData, ErrorResponse error){ } diff --git a/src/main/java/com/r7b7/entity/Param.java b/src/main/java/com/r7b7/entity/Param.java deleted file mode 100644 index 4691f60..0000000 --- a/src/main/java/com/r7b7/entity/Param.java +++ /dev/null @@ -1,5 +0,0 @@ -package com.r7b7.entity; - -public enum Param { - max_token, n, temperature, seed -} diff --git a/src/main/java/com/r7b7/entity/Role.java b/src/main/java/com/r7b7/entity/Role.java index 197a46c..7e1f5f3 100644 --- a/src/main/java/com/r7b7/entity/Role.java +++ b/src/main/java/com/r7b7/entity/Role.java @@ -1,5 +1,5 @@ package com.r7b7.entity; public enum Role { - user, assistant + user, assistant, system } diff --git a/src/main/java/com/r7b7/model/BaseLLMRequest.java b/src/main/java/com/r7b7/model/BaseLLMRequest.java index a8d85d8..d0612b2 100644 --- a/src/main/java/com/r7b7/model/BaseLLMRequest.java +++ b/src/main/java/com/r7b7/model/BaseLLMRequest.java @@ -4,13 +4,12 @@ import java.util.Map; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; -public class BaseLLMRequest implements LLMRequest { +public class BaseLLMRequest implements ILLMRequest { private final List messages; - private final Map parameters; + private final Map parameters; - public BaseLLMRequest(List messages, Map parameters) { + public BaseLLMRequest(List messages, Map parameters) { this.messages = messages; this.parameters = parameters; } @@ -21,7 +20,7 @@ public List getPrompt() { } @Override - public Map getParameters() { + public Map getParameters() { return parameters; } } diff --git a/src/main/java/com/r7b7/model/BaseLLMResponse.java b/src/main/java/com/r7b7/model/BaseLLMResponse.java deleted file mode 100644 index 033dc62..0000000 --- a/src/main/java/com/r7b7/model/BaseLLMResponse.java +++ /dev/null @@ -1,25 +0,0 @@ -package com.r7b7.model; - -import java.util.Map; - -import com.r7b7.entity.CompletionResponse; - -public class BaseLLMResponse implements LLMResponse { - private final CompletionResponse content; - private final Map metadata; - - public BaseLLMResponse(CompletionResponse content, Map metadata) { - this.content = content; - this.metadata = metadata; - } - - @Override - public CompletionResponse getContent() { - return content; - } - - @Override - public Map getMetadata() { - return metadata; - } -} diff --git a/src/main/java/com/r7b7/model/LLMRequest.java b/src/main/java/com/r7b7/model/ILLMRequest.java similarity index 57% rename from src/main/java/com/r7b7/model/LLMRequest.java rename to src/main/java/com/r7b7/model/ILLMRequest.java index 1f0b1eb..ccf1bd7 100644 --- a/src/main/java/com/r7b7/model/LLMRequest.java +++ b/src/main/java/com/r7b7/model/ILLMRequest.java @@ -4,9 +4,8 @@ import java.util.Map; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; -public interface LLMRequest { +public interface ILLMRequest { List getPrompt(); - Map getParameters(); + Map getParameters(); } diff --git a/src/main/java/com/r7b7/model/LLMResponse.java b/src/main/java/com/r7b7/model/LLMResponse.java deleted file mode 100644 index cea4217..0000000 --- a/src/main/java/com/r7b7/model/LLMResponse.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.r7b7.model; - -import java.util.Map; - -import com.r7b7.entity.CompletionResponse; - -public interface LLMResponse { - CompletionResponse getContent(); - Map getMetadata(); -} diff --git a/src/main/java/com/r7b7/service/AnthropicService.java b/src/main/java/com/r7b7/service/AnthropicService.java index 0daf3ac..39ad7b0 100644 --- a/src/main/java/com/r7b7/service/AnthropicService.java +++ b/src/main/java/com/r7b7/service/AnthropicService.java @@ -1,19 +1,21 @@ package com.r7b7.service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import com.r7b7.client.AnthropicClient; +import com.r7b7.client.IAnthropicClient; import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -import com.r7b7.entity.Param; -import com.r7b7.model.BaseLLMResponse; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.entity.Message; +import com.r7b7.entity.Role; +import com.r7b7.model.ILLMRequest; +import com.r7b7.util.StringUtility; -public class AnthropicService implements LLMService { +public class AnthropicService implements ILLMService { private final String apiKey; private final String model; @@ -23,38 +25,65 @@ public AnthropicService(String apiKey, String model) { } @Override - public LLMResponse generateResponse(LLMRequest request) { - AnthropicClient client = LLMClientFactory.getAnthropicClient(); - Map platformAllignedParams = null; - - platformAllignedParams = getPlatformAllignedParams(request); - CompletionResponse response = client.generateCompletion( - new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey)); - - Map metadata = Map.of( - "model", model, - "provider", "anthropic"); - return new BaseLLMResponse(response, metadata); + public CompletionResponse generateResponse(ILLMRequest request) { + IAnthropicClient client = LLMClientFactory.getAnthropicClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + String systemMessage = getSystemMessage(request); + if (!StringUtility.isNullOrEmpty(systemMessage)) { + requestMap.put("system", systemMessage); + } + requestMap.put("messages", request.getPrompt()); + // set mandatory param if not set explicitly + requestMap.put("max_tokens", 1024); + if (null != request.getParameters()) { + for (Map.Entry entry : request.getParameters().entrySet()) { + requestMap.put(entry.getKey(), entry.getValue()); + } + } + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; } @Override - public CompletableFuture generateResponseAsync(LLMRequest request) { + public CompletableFuture generateResponseAsync(ILLMRequest request) { return CompletableFuture.supplyAsync(() -> generateResponse(request)); } - private Map getPlatformAllignedParams(LLMRequest request) { - Map platformAllignedParams = null; - if (null != request.getParameters()) { - Map keyMapping = Map.of( - Param.max_token, "max_tokens", - Param.temperature, "temperature"); - - platformAllignedParams = request.getParameters().entrySet().stream() - .filter(entry -> keyMapping.containsKey(entry.getKey())) - .collect(Collectors.toMap( - entry -> keyMapping.get(entry.getKey()), - Map.Entry::getValue)); - } - return platformAllignedParams; + @Override + public CompletionResponse generateResponse(String inputQuery) { + IAnthropicClient client = LLMClientFactory.getAnthropicClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.user, inputQuery)); + requestMap.put("messages", prompt); + // set mandatory param + requestMap.put("max_tokens", 1024); + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; + } + + @Override + public CompletableFuture generateResponseAsync(String inputQuery) { + return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery)); + } + + private String getSystemMessage(ILLMRequest request) { + String systemMessage = request.getPrompt().stream() + .filter(msg -> msg.role() == Role.system) + .findFirst() + .map(msg -> { + request.getPrompt().remove(msg); + return msg.content(); + }) + .orElse(null); + return systemMessage; } } diff --git a/src/main/java/com/r7b7/service/GroqService.java b/src/main/java/com/r7b7/service/GroqService.java index 3eeac65..c376e88 100644 --- a/src/main/java/com/r7b7/service/GroqService.java +++ b/src/main/java/com/r7b7/service/GroqService.java @@ -1,19 +1,20 @@ package com.r7b7.service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import com.r7b7.client.GroqClient; +import com.r7b7.client.IGroqClient; import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -import com.r7b7.entity.Param; -import com.r7b7.model.BaseLLMResponse; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.entity.Message; +import com.r7b7.entity.Role; +import com.r7b7.model.ILLMRequest; -public class GroqService implements LLMService { +public class GroqService implements ILLMService { private final String apiKey; private final String model; @@ -23,39 +24,45 @@ public GroqService(String apiKey, String model) { } @Override - public LLMResponse generateResponse(LLMRequest request) { - CompletionResponse response = null; - GroqClient client = LLMClientFactory.getGroqClient(); - Map platformAllignedParams = getPlatformAllignedParams(request); - - response = client - .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey)); - Map metadata = Map.of( - "model", model, - "provider", "grok"); - return new BaseLLMResponse(response, metadata); + public CompletionResponse generateResponse(ILLMRequest request) { + IGroqClient client = LLMClientFactory.getGroqClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + requestMap.put("messages", request.getPrompt()); + if (null != request.getParameters()) { + for (Map.Entry entry : request.getParameters().entrySet()) { + requestMap.put(entry.getKey(), entry.getValue()); + } + } + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; } @Override - public CompletableFuture generateResponseAsync(LLMRequest request) { + public CompletableFuture generateResponseAsync(ILLMRequest request) { return CompletableFuture.supplyAsync(() -> generateResponse(request)); } - private Map getPlatformAllignedParams(LLMRequest request) { - Map platformAllignedParams = null; - if (null != request.getParameters()) { - Map keyMapping = Map.of( - Param.max_token, "max_tokens", - Param.n, "n", - Param.temperature, "temperature", - Param.seed, "seed"); - - platformAllignedParams = request.getParameters().entrySet().stream() - .filter(entry -> keyMapping.containsKey(entry.getKey())) - .collect(Collectors.toMap( - entry -> keyMapping.get(entry.getKey()), - Map.Entry::getValue)); - } - return platformAllignedParams; + @Override + public CompletionResponse generateResponse(String inputQuery) { + IGroqClient client = LLMClientFactory.getGroqClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.user, inputQuery)); + requestMap.put("messages", prompt); + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; + } + + @Override + public CompletableFuture generateResponseAsync(String inputQuery) { + return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery)); } } diff --git a/src/main/java/com/r7b7/service/ILLMService.java b/src/main/java/com/r7b7/service/ILLMService.java new file mode 100644 index 0000000..501dc50 --- /dev/null +++ b/src/main/java/com/r7b7/service/ILLMService.java @@ -0,0 +1,16 @@ +package com.r7b7.service; + +import java.util.concurrent.CompletableFuture; + +import com.r7b7.entity.CompletionResponse; +import com.r7b7.model.ILLMRequest; + +public interface ILLMService { + CompletionResponse generateResponse(ILLMRequest request); + + CompletionResponse generateResponse(String inputQuery); + + CompletableFuture generateResponseAsync(ILLMRequest request); + + CompletableFuture generateResponseAsync(String inputQuery); +} diff --git a/src/main/java/com/r7b7/service/LLMService.java b/src/main/java/com/r7b7/service/LLMService.java deleted file mode 100644 index 0922606..0000000 --- a/src/main/java/com/r7b7/service/LLMService.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.r7b7.service; - -import java.util.concurrent.CompletableFuture; - -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; - -public interface LLMService { - LLMResponse generateResponse(LLMRequest request); - CompletableFuture generateResponseAsync(LLMRequest request); -} diff --git a/src/main/java/com/r7b7/service/LLMServiceFactory.java b/src/main/java/com/r7b7/service/LLMServiceFactory.java index bca90c0..641d4ed 100644 --- a/src/main/java/com/r7b7/service/LLMServiceFactory.java +++ b/src/main/java/com/r7b7/service/LLMServiceFactory.java @@ -3,7 +3,7 @@ import com.r7b7.entity.Provider; public class LLMServiceFactory { - public static LLMService createService(Provider provider, String apiKey, String model) { + public static ILLMService createService(Provider provider, String apiKey, String model) { return switch (provider) { case Provider.OPENAI -> new OpenAIService(apiKey, model); case Provider.ANTHROPIC -> new AnthropicService(apiKey, model); @@ -13,7 +13,7 @@ public static LLMService createService(Provider provider, String apiKey, String }; } - public static LLMService createService(Provider provider, String model) { + public static ILLMService createService(Provider provider, String model) { return createService(provider, null, model); } } diff --git a/src/main/java/com/r7b7/service/OllamaService.java b/src/main/java/com/r7b7/service/OllamaService.java index cc3fea3..03d499b 100644 --- a/src/main/java/com/r7b7/service/OllamaService.java +++ b/src/main/java/com/r7b7/service/OllamaService.java @@ -1,19 +1,20 @@ package com.r7b7.service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import com.r7b7.client.OllamaClient; +import com.r7b7.client.IOllamaClient; import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -import com.r7b7.entity.Param; -import com.r7b7.model.BaseLLMResponse; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.entity.Message; +import com.r7b7.entity.Role; +import com.r7b7.model.ILLMRequest; -public class OllamaService implements LLMService { +public class OllamaService implements ILLMService { private final String model; public OllamaService(String model) { @@ -21,37 +22,48 @@ public OllamaService(String model) { } @Override - public LLMResponse generateResponse(LLMRequest request) { - OllamaClient client = LLMClientFactory.getOllamaClient(); - Map platformAllignedParams = getPlatformAllignedParams(request); + public CompletionResponse generateResponse(ILLMRequest request) { + IOllamaClient client = LLMClientFactory.getOllamaClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + requestMap.put("messages", request.getPrompt()); - CompletionResponse response = client - .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, null)); - Map metadata = Map.of( - "model", model, - "provider", "Ollama"); + Map optionsMap = new HashMap<>(); + if (null != request.getParameters()) { + for (Map.Entry entry : request.getParameters().entrySet()) { + optionsMap.put(entry.getKey(), entry.getValue()); + } + } + requestMap.put("options", optionsMap); + // override disabled features if set dynamically + requestMap.put("stream", false); - return new BaseLLMResponse(response, metadata); + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, null)); + return response; } @Override - public CompletableFuture generateResponseAsync(LLMRequest request) { + public CompletableFuture generateResponseAsync(ILLMRequest request) { return CompletableFuture.supplyAsync(() -> generateResponse(request)); } - private Map getPlatformAllignedParams(LLMRequest request) { - Map platformAllignedParams = null; - if (null != request.getParameters()) { - Map keyMapping = Map.of( - Param.seed, "seed", - Param.temperature, "temperature"); - - platformAllignedParams = request.getParameters().entrySet().stream() - .filter(entry -> keyMapping.containsKey(entry.getKey())) - .collect(Collectors.toMap( - entry -> keyMapping.get(entry.getKey()), - Map.Entry::getValue)); - } - return platformAllignedParams; + @Override + public CompletionResponse generateResponse(String inputQuery) { + IOllamaClient client = LLMClientFactory.getOllamaClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.user, inputQuery)); + requestMap.put("messages", prompt); + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, null)); + return response; + } + + @Override + public CompletableFuture generateResponseAsync(String inputQuery) { + return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery)); } } diff --git a/src/main/java/com/r7b7/service/OpenAIService.java b/src/main/java/com/r7b7/service/OpenAIService.java index 2326335..2a74aa3 100644 --- a/src/main/java/com/r7b7/service/OpenAIService.java +++ b/src/main/java/com/r7b7/service/OpenAIService.java @@ -1,19 +1,20 @@ package com.r7b7.service; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import com.r7b7.client.OpenAIClient; +import com.r7b7.client.IOpenAIClient; import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; -import com.r7b7.entity.Param; -import com.r7b7.model.BaseLLMResponse; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.entity.Message; +import com.r7b7.entity.Role; +import com.r7b7.model.ILLMRequest; -public class OpenAIService implements LLMService { +public class OpenAIService implements ILLMService { private final String apiKey; private final String model; @@ -23,38 +24,46 @@ public OpenAIService(String apiKey, String model) { } @Override - public LLMResponse generateResponse(LLMRequest request) { - CompletionResponse response = null; - OpenAIClient client = LLMClientFactory.getOpenAIClient(); - Map platformAllignedParams = getPlatformAllignedParams(request); - - response = client - .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey)); - Map metadata = Map.of( - "model", model, - "provider", "openai"); - return new BaseLLMResponse(response, metadata); + public CompletionResponse generateResponse(ILLMRequest request) { + IOpenAIClient client = LLMClientFactory.getOpenAIClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + requestMap.put("messages", request.getPrompt()); + if (null != request.getParameters()) { + for (Map.Entry entry : request.getParameters().entrySet()) { + requestMap.put(entry.getKey(), entry.getValue()); + } + } + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; } @Override - public CompletableFuture generateResponseAsync(LLMRequest request) { + public CompletableFuture generateResponseAsync(ILLMRequest request) { return CompletableFuture.supplyAsync(() -> generateResponse(request)); } - private Map getPlatformAllignedParams(LLMRequest request) { - Map platformAllignedParams = null; - if (null != request.getParameters()) { - Map keyMapping = Map.of( - Param.max_token, "max_completion_tokens", - Param.n, "n", - Param.temperature, "temperature"); - - platformAllignedParams = request.getParameters().entrySet().stream() - .filter(entry -> keyMapping.containsKey(entry.getKey())) - .collect(Collectors.toMap( - entry -> keyMapping.get(entry.getKey()), - Map.Entry::getValue)); - } - return platformAllignedParams; + @Override + public CompletionResponse generateResponse(String inputQuery) { + IOpenAIClient client = LLMClientFactory.getOpenAIClient(); + Map requestMap = new HashMap<>(); + requestMap.put("model", this.model); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.user, inputQuery)); + requestMap.put("messages", prompt); + + // override disabled features if set dynamically + requestMap.put("stream", false); + + CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey)); + return response; + } + + @Override + public CompletableFuture generateResponseAsync(String inputQuery) { + return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery)); } } diff --git a/src/main/java/com/r7b7/service/PromptBuilder.java b/src/main/java/com/r7b7/service/PromptBuilder.java new file mode 100644 index 0000000..197ccc0 --- /dev/null +++ b/src/main/java/com/r7b7/service/PromptBuilder.java @@ -0,0 +1,35 @@ +package com.r7b7.service; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.r7b7.entity.Message; + +public class PromptBuilder { + private List messages = new ArrayList<>(); + private Map params = new HashMap<>(); + + public PromptBuilder addMessage(Message message) { + messages.add(message); + return this; + } + + public PromptBuilder addParam(String key, Object value) { + params.put(key, value); + return this; + } + + public List getMessages() { + return messages; + } + + public Map getParams() { + return params; + } + + public PromptEngine build(ILLMService service) { + return new PromptEngine(service, params, messages); + } +} diff --git a/src/main/java/com/r7b7/service/PromptEngine.java b/src/main/java/com/r7b7/service/PromptEngine.java index c275e34..7c3d5c3 100644 --- a/src/main/java/com/r7b7/service/PromptEngine.java +++ b/src/main/java/com/r7b7/service/PromptEngine.java @@ -6,42 +6,41 @@ import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; -import com.r7b7.entity.Role; import com.r7b7.model.BaseLLMRequest; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.model.ILLMRequest; public class PromptEngine { - private final LLMService llmService; - private final Message assistantMessage; - private final Map params; + private final ILLMService llmService; + private final Map params; + private final List messages; - public PromptEngine(LLMService llmService) { - this(llmService, new Message(Role.assistant, "You are a helpful assistant"), null); + public PromptEngine(ILLMService llmService) { + this(llmService, null, null); } - public PromptEngine(LLMService llmService, Message assistantMessage, Map params) { + public PromptEngine(ILLMService llmService, Map params, List messages) { this.llmService = llmService; - this.assistantMessage = assistantMessage; this.params = params; + this.messages = messages; } - public CompletionResponse getResponse(String inputQuery) { - Message userMsg = new Message(Role.user, inputQuery); - List messages = List.of(assistantMessage, userMsg); - LLMRequest request = new BaseLLMRequest(messages, params); - LLMResponse response = llmService.generateResponse(request); - return response.getContent(); + public CompletionResponse sendQuery() { + ILLMRequest request = new BaseLLMRequest(this.messages, this.params); + CompletionResponse response = this.llmService.generateResponse(request); + return response; } - public CompletableFuture getResponseAsync(String inputQuery) { - Message userMsg = new Message(Role.user, inputQuery); - List messages = List.of(assistantMessage, userMsg); - LLMRequest request = new BaseLLMRequest(messages, params); + public CompletableFuture sendQueryAsync() { + ILLMRequest request = new BaseLLMRequest(this.messages, this.params); + return this.llmService.generateResponseAsync(request); + } - return llmService.generateResponseAsync(request) - .thenApply(LLMResponse::getContent); + public CompletionResponse sendQuery(String inputQuery) { + CompletionResponse response = this.llmService.generateResponse(inputQuery); + return response; } + public CompletableFuture sendQueryAsync(String inputQuery) { + return this.llmService.generateResponseAsync(inputQuery); + } } diff --git a/src/main/java/com/r7b7/util/StringUtility.java b/src/main/java/com/r7b7/util/StringUtility.java index ddd5703..28ef137 100644 --- a/src/main/java/com/r7b7/util/StringUtility.java +++ b/src/main/java/com/r7b7/util/StringUtility.java @@ -3,7 +3,7 @@ import java.net.URI; public final class StringUtility { - + public static boolean isNullOrEmpty(String str) { return str == null || str.isEmpty(); } @@ -13,7 +13,7 @@ public static boolean isValidHttpOrHttpsUrl(String url) { return false; } try { - URI.create(url).toURL(); + URI.create(url).toURL(); return true; } catch (Exception e) { return false; diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 1f36d62..0ce3686 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,6 +1,5 @@ -hospai.openai.url=https://google.com +hospai.openai.url=https://api.openai.com/v1/chat/completions hospai.anthropic.url=https://api.anthropic.com/v1/messages hospai.anthropic.version=2023-06-01 -hospai.anthropic.maxTokens=1024 -hospai.ollama.url= -hospai.groq.url= \ No newline at end of file +hospai.ollama.url=http://localhost:11434/api/chat +hospai.groq.url=https://api.groq.com/openai/v1/chat/completions \ No newline at end of file diff --git a/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java b/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java index 2fc3291..9ff4f22 100644 --- a/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java +++ b/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java @@ -11,7 +11,10 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,12 +44,16 @@ public void setUp() { @Test public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null, - "test-model", "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); - when(mockResponse.body()).thenReturn("{\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}]}"); + when(mockResponse.body()).thenReturn("{\"id\": \"msg_01KVZJSawLDxgY8LvDm2w9KP\",\"model\": \"claude-3-5-sonnet-20241022\",\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}], \"usage\":{\"input_tokens\": 48,\"output_tokens\": 70}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -63,11 +70,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null, - "test-model", "api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); + HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); - when(mockResponse.body()).thenReturn("{\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}]}"); + when(mockResponse.body()).thenReturn("{\"id\": \"msg_01KVZJSawLDxgY8LvDm2w9KP\",\"model\": \"claude-3-5-sonnet-20241022\",\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}], \"usage\":{\"input_tokens\": 48,\"output_tokens\": 70}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -83,8 +95,13 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); + HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(401); @@ -104,8 +121,12 @@ public void testGenerateCompletion_InvalidApiKey() throws IOException, Interrupt public void testGenerateCompletion_HandleException() throws IOException, InterruptedException { // Arrange - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); // Mock the HttpClient to throw an exception when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) diff --git a/src/test/java/com/r7b7/client/DefaultGroqClientTest.java b/src/test/java/com/r7b7/client/DefaultGroqClientTest.java index 88b6fdd..e8b9537 100644 --- a/src/test/java/com/r7b7/client/DefaultGroqClientTest.java +++ b/src/test/java/com/r7b7/client/DefaultGroqClientTest.java @@ -11,7 +11,10 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,13 +43,17 @@ public void setUp() { @Test public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); + HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}"); + "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -63,13 +70,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}"); + "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -85,9 +95,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(401); @@ -107,9 +120,12 @@ public void testGenerateCompletion_InvalidApiKey() throws IOException, Interrupt public void testGenerateCompletion_HandleException() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenThrow(new IOException("Mocked IOException")); diff --git a/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java b/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java index dcfbc0d..98c0d74 100644 --- a/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java +++ b/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java @@ -11,7 +11,10 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,13 +43,16 @@ public void setUp() { @Test public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", null); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"}}"); + "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"},\"total_duration\":5191566416,\"eval_duration\":4799921000}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -62,13 +68,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte @Test public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", null); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"}}"); + "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"},\"total_duration\":5191566416,\"eval_duration\":4799921000}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -83,9 +92,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt @Test public void testGenerateCompletion_HandleException() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", null); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenThrow(new IOException("Mocked IOException")); diff --git a/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java b/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java index 43b0a07..ec84eaa 100644 --- a/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java +++ b/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java @@ -11,7 +11,10 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,13 +43,16 @@ public void setUp() { @Test public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}"); + "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -63,13 +69,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(200); when(mockResponse.body()).thenReturn( - "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}"); + "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}"); when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(mockResponse); mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient); @@ -85,9 +94,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException { try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); HttpResponse mockResponse = mock(HttpResponse.class); when(mockResponse.statusCode()).thenReturn(401); @@ -108,9 +120,12 @@ public void testGenerateCompletion_HandleException() throws IOException, Interru try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) { // Arrange - CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), - null, - "test-model", "invalid-api-key"); + Map requestMap = new HashMap<>(); + requestMap.put("model", "test-model"); + List prompt = new ArrayList<>(); + prompt.add(new Message(Role.system, "You are a helpful assistant")); + requestMap.put("messages", prompt); + CompletionRequest request = new CompletionRequest(requestMap, "api-key"); // Mock the HttpClient to throw an exception when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) diff --git a/src/test/java/com/r7b7/service/AnthropicServiceTest.java b/src/test/java/com/r7b7/service/AnthropicServiceTest.java index ab6105f..dea8065 100644 --- a/src/test/java/com/r7b7/service/AnthropicServiceTest.java +++ b/src/test/java/com/r7b7/service/AnthropicServiceTest.java @@ -8,6 +8,7 @@ import static org.mockito.Mockito.verify; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,20 +19,18 @@ import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; -import com.r7b7.client.AnthropicClient; -import com.r7b7.client.factory.AnthropicClientFactory; +import com.r7b7.client.IAnthropicClient; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; import com.r7b7.entity.Role; import com.r7b7.model.BaseLLMRequest; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.model.ILLMRequest; public class AnthropicServiceTest { @Mock - private AnthropicClient mockClient; + private IAnthropicClient mockClient; @InjectMocks private AnthropicService anthropicService; @@ -47,19 +46,19 @@ public void setUp() { @Test public void testGenerateResponse_Success() { - try (MockedStatic mockedStatic = mockStatic(AnthropicClientFactory.class)) { - LLMRequest request = createMockLLMRequest("Test prompt"); + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + ILLMRequest request = createMockLLMRequest("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> AnthropicClientFactory.getClient()) + mockedStatic.when(() -> LLMClientFactory.getAnthropicClient()) .thenReturn(mockClient); - LLMResponse response = anthropicService.generateResponse(request); + CompletionResponse response = anthropicService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("anthropic", metadata.get("provider")); @@ -69,20 +68,20 @@ public void testGenerateResponse_Success() { @Test public void testGenerateResponse_WithParams_Success() { - try (MockedStatic mockedStatic = mockStatic(AnthropicClientFactory.class)) { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { - LLMRequest request = createMockLLMRequestWithParams("Test prompt"); + ILLMRequest request = createMockLLMRequestWithParams("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> AnthropicClientFactory.getClient()) + mockedStatic.when(() -> LLMClientFactory.getAnthropicClient()) .thenReturn(mockClient); - LLMResponse response = anthropicService.generateResponse(request); + CompletionResponse response = anthropicService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("anthropic", metadata.get("provider")); @@ -90,10 +89,33 @@ public void testGenerateResponse_WithParams_Success() { } } - private LLMRequest createMockLLMRequest(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - LLMRequest request = new BaseLLMRequest(messages, null); + @Test + public void testGenerateResponseForSingleQuery_Success() { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); + doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); + mockedStatic.when(() -> LLMClientFactory.getAnthropicClient()) + .thenReturn(mockClient); + + CompletionResponse response = anthropicService.generateResponse("Single query"); + + assertNotNull(response); + assertEquals("test content", response.messages().get(0).content()); + + Map metadata = response.metaData(); + assertEquals(TEST_MODEL, metadata.get("model")); + assertEquals("anthropic", metadata.get("provider")); + + verify(mockClient).generateCompletion(any(CompletionRequest.class)); + } + } + + private ILLMRequest createMockLLMRequest(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.system, "You are a helpful assistant")); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + ILLMRequest request = new BaseLLMRequest(messages, null); return request; } @@ -101,18 +123,24 @@ private CompletionResponse createMockCompletionResponse(String content) { List messages = new ArrayList<>(); com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content"); messages.add(msg); - CompletionResponse response = new CompletionResponse(messages, null, null); + Map metaData = new HashMap<>(); + metaData.put("model", TEST_MODEL); + metaData.put("provider", "anthropic"); + CompletionResponse response = new CompletionResponse(messages, metaData, null); return response; } - private LLMRequest createMockLLMRequestWithParams(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - Map params = Map.of( - Param.temperature, 0.7, - Param.max_token, 1000); + private ILLMRequest createMockLLMRequestWithParams(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.system, "You are a helpful assistant")); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + + Map params = Map.of( + "temperature", 0.7, + "max_token", 1000); - LLMRequest request = new BaseLLMRequest(messages, params); + ILLMRequest request = new BaseLLMRequest(messages, params); return request; } } diff --git a/src/test/java/com/r7b7/service/GroqServiceTest.java b/src/test/java/com/r7b7/service/GroqServiceTest.java index 57ba2ca..a3f5bca 100644 --- a/src/test/java/com/r7b7/service/GroqServiceTest.java +++ b/src/test/java/com/r7b7/service/GroqServiceTest.java @@ -8,6 +8,7 @@ import static org.mockito.Mockito.verify; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,20 +19,18 @@ import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; -import com.r7b7.client.GroqClient; -import com.r7b7.client.factory.GroqClientFactory; +import com.r7b7.client.IGroqClient; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; import com.r7b7.entity.Role; import com.r7b7.model.BaseLLMRequest; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.model.ILLMRequest; public class GroqServiceTest { @Mock - private GroqClient mockClient; + private IGroqClient mockClient; @InjectMocks private GroqService groqService; @@ -47,18 +46,18 @@ public void setUp() { @Test public void testGenerateResponse_Success() { - try (MockedStatic mockedStatic = mockStatic(GroqClientFactory.class);) { - LLMRequest request = createMockLLMRequest("Test prompt"); + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) { + ILLMRequest request = createMockLLMRequest("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> GroqClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient); - LLMResponse response = groqService.generateResponse(request); + CompletionResponse response = groqService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("grok", metadata.get("provider")); @@ -68,18 +67,18 @@ public void testGenerateResponse_Success() { @Test public void testGenerateResponse_WithParams_Success() { - try (MockedStatic mockedStatic = mockStatic(GroqClientFactory.class);) { - LLMRequest request = createMockLLMRequestWithParams("Test prompt"); + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) { + ILLMRequest request = createMockLLMRequestWithParams("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> GroqClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient); - LLMResponse response = groqService.generateResponse(request); + CompletionResponse response = groqService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("grok", metadata.get("provider")); @@ -87,10 +86,31 @@ public void testGenerateResponse_WithParams_Success() { } } - private LLMRequest createMockLLMRequest(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - LLMRequest request = new BaseLLMRequest(messages, null); + @Test + public void testGenerateResponseWithSingleQuery_Success() { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) { + CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); + doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); + mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient); + + CompletionResponse response = groqService.generateResponse("Single Query"); + + assertNotNull(response); + assertEquals("test content", response.messages().get(0).content()); + + Map metadata = response.metaData(); + assertEquals(TEST_MODEL, metadata.get("model")); + assertEquals("grok", metadata.get("provider")); + + verify(mockClient).generateCompletion(any(CompletionRequest.class)); + } + } + + private ILLMRequest createMockLLMRequest(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + ILLMRequest request = new BaseLLMRequest(messages, null); return request; } @@ -98,18 +118,23 @@ private CompletionResponse createMockCompletionResponse(String content) { List messages = new ArrayList<>(); com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content"); messages.add(msg); - CompletionResponse response = new CompletionResponse(messages, null, null); + Map metaData = new HashMap<>(); + metaData.put("model", TEST_MODEL); + metaData.put("provider", "grok"); + CompletionResponse response = new CompletionResponse(messages, metaData, null); return response; } - private LLMRequest createMockLLMRequestWithParams(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - Map params = Map.of( - Param.temperature, 0.7, - Param.max_token, 1000); + private ILLMRequest createMockLLMRequestWithParams(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + + Map params = Map.of( + "temperature", 0.7, + "max_token", 1000); - LLMRequest request = new BaseLLMRequest(messages, params); + ILLMRequest request = new BaseLLMRequest(messages, params); return request; } } diff --git a/src/test/java/com/r7b7/service/OllamaServiceTest.java b/src/test/java/com/r7b7/service/OllamaServiceTest.java index bb07c1f..b090f73 100644 --- a/src/test/java/com/r7b7/service/OllamaServiceTest.java +++ b/src/test/java/com/r7b7/service/OllamaServiceTest.java @@ -8,6 +8,7 @@ import static org.mockito.Mockito.verify; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,20 +19,18 @@ import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; -import com.r7b7.client.OllamaClient; -import com.r7b7.client.factory.OllamaClientFactory; +import com.r7b7.client.IOllamaClient; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; import com.r7b7.entity.Role; import com.r7b7.model.BaseLLMRequest; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.model.ILLMRequest; public class OllamaServiceTest { @Mock - private OllamaClient mockClient; + private IOllamaClient mockClient; @InjectMocks private OllamaService ollamaService; @@ -46,18 +45,18 @@ public void setUp() { @Test public void testGenerateResponse_Success() { - try (MockedStatic mockedStatic = mockStatic(OllamaClientFactory.class)) { - LLMRequest request = createMockLLMRequest("Test prompt"); + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + ILLMRequest request = createMockLLMRequest("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> OllamaClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient); - LLMResponse response = ollamaService.generateResponse(request); + CompletionResponse response = ollamaService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("Ollama", metadata.get("provider")); @@ -67,19 +66,19 @@ public void testGenerateResponse_Success() { @Test public void testGenerateResponse_WithParams_Success() { - try (MockedStatic mockedStatic = mockStatic(OllamaClientFactory.class)) { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { - LLMRequest request = createMockLLMRequestWithParams("Test prompt"); + ILLMRequest request = createMockLLMRequestWithParams("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> OllamaClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient); - LLMResponse response = ollamaService.generateResponse(request); + CompletionResponse response = ollamaService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("Ollama", metadata.get("provider")); @@ -87,10 +86,32 @@ public void testGenerateResponse_WithParams_Success() { } } - private LLMRequest createMockLLMRequest(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - LLMRequest request = new BaseLLMRequest(messages, null); + @Test + public void testGenerateResponseWithSingleQuery_Success() { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); + doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); + mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient); + + CompletionResponse response = ollamaService.generateResponse("Single Query"); + + assertNotNull(response); + assertEquals("test content", response.messages().get(0).content()); + + Map metadata = response.metaData(); + assertEquals(TEST_MODEL, metadata.get("model")); + assertEquals("Ollama", metadata.get("provider")); + + verify(mockClient).generateCompletion(any(CompletionRequest.class)); + } + } + + private ILLMRequest createMockLLMRequest(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + + ILLMRequest request = new BaseLLMRequest(messages, null); return request; } @@ -98,18 +119,23 @@ private CompletionResponse createMockCompletionResponse(String content) { List messages = new ArrayList<>(); com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content"); messages.add(msg); - CompletionResponse response = new CompletionResponse(messages, null, null); + Map metaData = new HashMap<>(); + metaData.put("model", TEST_MODEL); + metaData.put("provider", "Ollama"); + CompletionResponse response = new CompletionResponse(messages, metaData, null); return response; } - private LLMRequest createMockLLMRequestWithParams(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - Map params = Map.of( - Param.temperature, 0.7, - Param.max_token, 1000); + private ILLMRequest createMockLLMRequestWithParams(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + + Map params = Map.of( + "temperature", 0.7, + "max_token", 1000); - LLMRequest request = new BaseLLMRequest(messages, params); + ILLMRequest request = new BaseLLMRequest(messages, params); return request; } } diff --git a/src/test/java/com/r7b7/service/OpenAIServiceTest.java b/src/test/java/com/r7b7/service/OpenAIServiceTest.java index e5691c3..752a2dc 100644 --- a/src/test/java/com/r7b7/service/OpenAIServiceTest.java +++ b/src/test/java/com/r7b7/service/OpenAIServiceTest.java @@ -8,6 +8,7 @@ import static org.mockito.Mockito.verify; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,20 +19,18 @@ import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; -import com.r7b7.client.OpenAIClient; -import com.r7b7.client.factory.OpenAIClientFactory; +import com.r7b7.client.IOpenAIClient; +import com.r7b7.client.factory.LLMClientFactory; import com.r7b7.entity.CompletionRequest; import com.r7b7.entity.CompletionResponse; import com.r7b7.entity.Message; -import com.r7b7.entity.Param; import com.r7b7.entity.Role; import com.r7b7.model.BaseLLMRequest; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.model.ILLMRequest; public class OpenAIServiceTest { @Mock - private OpenAIClient mockClient; + private IOpenAIClient mockClient; @InjectMocks private OpenAIService openAIService; @@ -47,18 +46,18 @@ public void setUp() { @Test public void testGenerateResponse_Success() { - try (MockedStatic mockedStatic = mockStatic(OpenAIClientFactory.class)) { - LLMRequest request = createMockLLMRequest("Test prompt"); + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + ILLMRequest request = createMockLLMRequest("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> OpenAIClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient); - LLMResponse response = openAIService.generateResponse(request); + CompletionResponse response = openAIService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("openai", metadata.get("provider")); @@ -68,19 +67,19 @@ public void testGenerateResponse_Success() { @Test public void testGenerateResponse_WithParams_Success() { - try (MockedStatic mockedStatic = mockStatic(OpenAIClientFactory.class)) { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { - LLMRequest request = createMockLLMRequestWithParams("Test prompt"); + ILLMRequest request = createMockLLMRequestWithParams("Test prompt"); CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); - mockedStatic.when(() -> OpenAIClientFactory.getClient()).thenReturn(mockClient); + mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient); - LLMResponse response = openAIService.generateResponse(request); + CompletionResponse response = openAIService.generateResponse(request); assertNotNull(response); - assertEquals("test content", response.getContent().messages().get(0).content()); + assertEquals("test content", response.messages().get(0).content()); - Map metadata = response.getMetadata(); + Map metadata = response.metaData(); assertEquals(TEST_MODEL, metadata.get("model")); assertEquals("openai", metadata.get("provider")); @@ -88,10 +87,31 @@ public void testGenerateResponse_WithParams_Success() { } } - private LLMRequest createMockLLMRequest(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - LLMRequest request = new BaseLLMRequest(messages, null); + @Test + public void testGenerateResponseWithSingleQuery_Success() { + try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) { + CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response"); + doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any()); + mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient); + + CompletionResponse response = openAIService.generateResponse("single query"); + + assertNotNull(response); + assertEquals("test content", response.messages().get(0).content()); + + Map metadata = response.metaData(); + assertEquals(TEST_MODEL, metadata.get("model")); + assertEquals("openai", metadata.get("provider")); + + verify(mockClient).generateCompletion(any(CompletionRequest.class)); + } + } + + private ILLMRequest createMockLLMRequest(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + ILLMRequest request = new BaseLLMRequest(messages, null); return request; } @@ -99,18 +119,24 @@ private CompletionResponse createMockCompletionResponse(String content) { List messages = new ArrayList<>(); com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content"); messages.add(msg); - CompletionResponse response = new CompletionResponse(messages, null, null); + Map metaData = new HashMap<>(); + metaData.put("model", TEST_MODEL); + metaData.put("provider", "openai"); + CompletionResponse response = new CompletionResponse(messages, metaData, null); return response; } - private LLMRequest createMockLLMRequestWithParams(String prompt) { - List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"), - new Message(Role.user, prompt)); - Map params = Map.of( - Param.temperature, 0.7, - Param.max_token, 1000); + private ILLMRequest createMockLLMRequestWithParams(String prompt) { + List messages = new ArrayList<>(); + messages.add(new Message(Role.system, "You are a helpful assistant")); + messages.add(new Message(Role.assistant, "You are a helpful assistant")); + messages.add(new Message(Role.user, prompt)); + + Map params = Map.of( + "temperature", 0.7, + "max_token", 1000); - LLMRequest request = new BaseLLMRequest(messages, params); + ILLMRequest request = new BaseLLMRequest(messages, params); return request; } } diff --git a/src/test/java/com/r7b7/service/PromptEngineTest.java b/src/test/java/com/r7b7/service/PromptEngineTest.java index 5b4cf96..f2b2bd7 100644 --- a/src/test/java/com/r7b7/service/PromptEngineTest.java +++ b/src/test/java/com/r7b7/service/PromptEngineTest.java @@ -2,7 +2,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -12,35 +11,65 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import com.r7b7.entity.CompletionResponse; -import com.r7b7.model.LLMRequest; -import com.r7b7.model.LLMResponse; +import com.r7b7.entity.Message; +import com.r7b7.entity.Role; +import com.r7b7.model.ILLMRequest; public class PromptEngineTest { - private LLMService mockLlmService; + @Mock + private ILLMService mockLlmService; + + @InjectMocks private PromptEngine promptEngine; @BeforeEach void setUp() { - mockLlmService = Mockito.mock(LLMService.class); - promptEngine = new PromptEngine(mockLlmService); + MockitoAnnotations.openMocks(this); + + // mockLlmService = Mockito.mock(ILLMService.class); } @Test - void testGetResponse() { + void testSendQuery_Text_Input() { String inputQuery = "What is the weather today?"; String expectedContent = "test content"; - LLMResponse mockResponse = mock(LLMResponse.class); - when(mockResponse.getContent()).thenReturn(createMockCompletionResponse("Test prompt")); - when(mockLlmService.generateResponse(any(LLMRequest.class))).thenReturn(mockResponse); + promptEngine = new PromptEngine(mockLlmService); + + when(mockLlmService.generateResponse(inputQuery)) + .thenReturn(createMockCompletionResponse("Test prompt")); + + CompletionResponse response = promptEngine.sendQuery(inputQuery); + + assertEquals(expectedContent, response.messages().get(0).content()); + verify(mockLlmService, times(1)).generateResponse(inputQuery); + } + + @Test + void testSendQuery_Builder_Input() { + String expectedContent = "test content"; + + when(mockLlmService.generateResponse(any(ILLMRequest.class))) + .thenReturn(createMockCompletionResponse("Test prompt")); + + PromptBuilder builder = new PromptBuilder() + .addMessage(new Message(Role.system, "Give output in consistent format")) + .addMessage(new Message(Role.user, "what's the stock symbol of ARCHER Aviation?")) + .addMessage(new Message(Role.assistant, "{\"company\":\"Archer\", \"symbol\":\"ACHR\"}")) + .addMessage(new Message(Role.user, "what's the stock symbol of Palantir technology?")) + .addParam("temperature", 0.7) + .addParam("max_tokens", 150); + promptEngine = builder.build(mockLlmService); - CompletionResponse response = promptEngine.getResponse(inputQuery); + CompletionResponse response = promptEngine.sendQuery(); assertEquals(expectedContent, response.messages().get(0).content()); - verify(mockLlmService, times(1)).generateResponse(any(LLMRequest.class)); + verify(mockLlmService, times(1)).generateResponse(any(ILLMRequest.class)); } private CompletionResponse createMockCompletionResponse(String content) { From 54193d7e2bf9ffccccaa6c0475538284df0a765b Mon Sep 17 00:00:00 2001 From: R7B7 Date: Thu, 5 Dec 2024 13:15:53 -0600 Subject: [PATCH 3/4] Added maven build workflow --- .github/workflows/maven.yml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .github/workflows/maven.yml diff --git a/.github/workflows/maven.yml b/.github/workflows/maven.yml new file mode 100644 index 0000000..8507375 --- /dev/null +++ b/.github/workflows/maven.yml @@ -0,0 +1,35 @@ +# This workflow will build a Java project with Maven, and cache/restore any dependencies to improve the workflow execution time +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-java-with-maven + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Java CI with Maven + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + java-version: '21' + distribution: 'temurin' + cache: maven + - name: Build with Maven + run: mvn -B package --file pom.xml + + # Optional: Uploads the full dependency graph to GitHub to improve the quality of Dependabot alerts this repository can receive + - name: Update dependency graph + uses: advanced-security/maven-dependency-submission-action@571e99aab1055c2e71a1e2309b9691de18d6b7d6 From 7ecd351a1ae8ae149c94e18b4e9cde65c3586601 Mon Sep 17 00:00:00 2001 From: R7B7 Date: Thu, 5 Dec 2024 13:21:31 -0600 Subject: [PATCH 4/4] Removed dependency graph step from workflow --- .github/workflows/maven.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/maven.yml b/.github/workflows/maven.yml index 8507375..769ab13 100644 --- a/.github/workflows/maven.yml +++ b/.github/workflows/maven.yml @@ -30,6 +30,3 @@ jobs: - name: Build with Maven run: mvn -B package --file pom.xml - # Optional: Uploads the full dependency graph to GitHub to improve the quality of Dependabot alerts this repository can receive - - name: Update dependency graph - uses: advanced-security/maven-dependency-submission-action@571e99aab1055c2e71a1e2309b9691de18d6b7d6