diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..5bedb1c
Binary files /dev/null and b/.DS_Store differ
diff --git a/pom.xml b/pom.xml
index 09ec83f..ad20907 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,11 +25,11 @@
2.18.1
- org.junit.jupiter
- junit-jupiter-api
- 5.11.3
- test
-
+ org.junit.jupiter
+ junit-jupiter-api
+ 5.11.3
+ test
+
org.mockito
mockito-core
@@ -55,5 +55,11 @@
3.0.0
+
+
+ src/main/resources
+ true
+
+
\ No newline at end of file
diff --git a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java
index 52735db..3b4555f 100644
--- a/src/main/java/com/r7b7/client/DefaultAnthropicClient.java
+++ b/src/main/java/com/r7b7/client/DefaultAnthropicClient.java
@@ -6,51 +6,42 @@
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ArrayNode;
-import com.fasterxml.jackson.databind.node.ObjectNode;
import com.r7b7.client.model.AnthropicResponse;
import com.r7b7.client.model.Message;
+import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;
-public class DefaultAnthropicClient implements AnthropicClient {
- private String ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages";
- private String ANTHROPIC_VERSION = "2023-06-01";
- private Integer MAX_TOKENS = 1024;
+public class DefaultAnthropicClient implements IAnthropicClient {
+ private String ANTHROPIC_API_URL;
+ private String ANTHROPIC_VERSION;
- public DefaultAnthropicClient() {}
+ public DefaultAnthropicClient() {
+ try {
+ Properties properties = PropertyConfig.loadConfig();
+ ANTHROPIC_API_URL = properties.getProperty("hospai.anthropic.url");
+ ANTHROPIC_VERSION = properties.getProperty("hospai.anthropic.version");
+ } catch (Exception ex) {
+            throw new IllegalStateException("Failed to load Anthropic configuration (hospai.anthropic.url, hospai.anthropic.version)", ex);
+ }
+ }
@Override
- public CompletionResponse generateCompletion(CompletionRequest request){
+ public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
- ArrayNode arrayNode = objectMapper.valueToTree(request.messages());
- ObjectNode requestBody = objectMapper.createObjectNode();
- requestBody.put("model", request.model());
- requestBody.set("messages", arrayNode);
-
- if (null != request.params()) {
- for (Map.Entry entry : request.params().entrySet()) {
- Object value = entry.getValue();
- if (value instanceof String) {
- requestBody.put(entry.getKey(), (String) value);
- } else if (value instanceof Integer) {
- requestBody.put(entry.getKey(), (Integer) value);
- }
- }
- } else {
- requestBody.put("max_tokens", MAX_TOKENS);
- }
+ String jsonRequest = objectMapper.writeValueAsString(request.requestBody());
HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.ANTHROPIC_API_URL))
.header("Content-Type", "application/json")
.header("x-api-key", request.apiKey())
.header("anthropic-version", ANTHROPIC_VERSION)
- .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString()))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonRequest))
.build();
HttpResponse response = HttpClient.newHttpClient().send(httpRequest,
@@ -58,8 +49,8 @@ public CompletionResponse generateCompletion(CompletionRequest request){
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
- return new CompletionResponse(null, response, new ErrorResponse(
- "Request sent to LLM failed with status code " + response, null));
+ return new CompletionResponse(null, null, new ErrorResponse(
+                        "Request sent to LLM failed: status=" + response.statusCode() + ", body=" + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
@@ -71,13 +62,21 @@ private CompletionResponse extractResponseText(String responseBody) {
List msgs = null;
AnthropicResponse response = null;
ErrorResponse error = null;
+ Map metadata = null;
+
try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, AnthropicResponse.class);
msgs = response.content().stream().map(content -> new Message(content.type(), content.text())).toList();
+ metadata = Map.of(
+ "id", response.id(),
+ "model", response.model(),
+ "provider", "Anthropic",
+ "input_tokens", response.usage().inputTokens(),
+ "output_tokens", response.usage().outputTokens());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
- return new CompletionResponse(msgs, response, error);
+ return new CompletionResponse(msgs, metadata, error);
}
}
diff --git a/src/main/java/com/r7b7/client/DefaultGroqClient.java b/src/main/java/com/r7b7/client/DefaultGroqClient.java
index feaddb1..963ff1e 100644
--- a/src/main/java/com/r7b7/client/DefaultGroqClient.java
+++ b/src/main/java/com/r7b7/client/DefaultGroqClient.java
@@ -6,54 +6,47 @@
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ArrayNode;
-import com.fasterxml.jackson.databind.node.ObjectNode;
import com.r7b7.client.model.Message;
import com.r7b7.client.model.OpenAIResponse;
+import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;
-public class DefaultGroqClient implements GroqClient {
- private String GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions";
+public class DefaultGroqClient implements IGroqClient {
+ private String GROQ_API_URL;
- public DefaultGroqClient(){}
+ public DefaultGroqClient() {
+ try {
+ Properties properties = PropertyConfig.loadConfig();
+ GROQ_API_URL = properties.getProperty("hospai.groq.url");
+ } catch (Exception ex) {
+            throw new IllegalStateException("Failed to load Groq configuration (hospai.groq.url)", ex);
+ }
+ }
@Override
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
- ArrayNode arrayNode = objectMapper.valueToTree(request.messages());
- ObjectNode requestBody = objectMapper.createObjectNode();
- requestBody.put("model", request.model());
- requestBody.set("messages", arrayNode);
+ String jsonRequest = objectMapper.writeValueAsString(request.requestBody());
- if (null != request.params()) {
- for (Map.Entry entry : request.params().entrySet()) {
- Object value = entry.getValue();
- if (value instanceof String) {
- requestBody.put(entry.getKey(), (String) value);
- } else if (value instanceof Integer) {
- requestBody.put(entry.getKey(), (Integer) value);
- }
- }
- }
-
HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.GROQ_API_URL))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + request.apiKey())
- .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString()))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonRequest))
.build();
HttpResponse response = HttpClient.newHttpClient().send(httpRequest,
HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
- return new CompletionResponse(null, response, new ErrorResponse(
- "Request sent to LLM failed with status code " + response.statusCode(), null));
+ return new CompletionResponse(null, null, new ErrorResponse(
+                        "Request sent to LLM failed: status=" + response.statusCode() + ", body=" + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
@@ -64,16 +57,23 @@ private CompletionResponse extractResponseText(String responseBody) {
List msgs = null;
OpenAIResponse response = null;
ErrorResponse error = null;
+ Map metadata = null;
try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OpenAIResponse.class);
msgs = response.choices().stream()
.map(choice -> new Message(choice.message().role(), choice.message().content())).toList();
+ metadata = Map.of(
+ "id", response.id(),
+ "model", response.model(),
+ "provider", "Groq",
+ "prompt_tokens", response.usage().promptTokens(),
+ "completion_tokens", response.usage().completionTokens(),
+ "total_tokens", response.usage().totalTokens());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
- return new CompletionResponse(msgs, response, error);
+ return new CompletionResponse(msgs, metadata, error);
}
-
}
diff --git a/src/main/java/com/r7b7/client/DefaultOllamaClient.java b/src/main/java/com/r7b7/client/DefaultOllamaClient.java
index ed918f5..cacad98 100644
--- a/src/main/java/com/r7b7/client/DefaultOllamaClient.java
+++ b/src/main/java/com/r7b7/client/DefaultOllamaClient.java
@@ -4,43 +4,35 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.r7b7.client.model.Message;
import com.r7b7.client.model.OllamaResponse;
+import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;
-public class DefaultOllamaClient implements OllamaClient {
- private String OLLAMA_API_URL = "http://localhost:11434/api/chat";
+public class DefaultOllamaClient implements IOllamaClient {
+ private String OLLAMA_API_URL;
public DefaultOllamaClient() {
+ try {
+ Properties properties = PropertyConfig.loadConfig();
+ OLLAMA_API_URL = properties.getProperty("hospai.ollama.url");
+ } catch (Exception ex) {
+            throw new IllegalStateException("Failed to load Ollama configuration (hospai.ollama.url)", ex);
+ }
}
@Override
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
-
- Map requestMap = new HashMap<>();
- requestMap.put("model", request.model());
- requestMap.put("messages", request.messages());
- requestMap.put("stream", false);
-
- Map optionsMap = new HashMap<>();
- if (null != request.params() && request.params().get("temperature") != null) {
- optionsMap.put("temperature", request.params().get("temperature"));
- }
- if (null != request.params() && request.params().get("seed") != null) {
- optionsMap.put("seed", request.params().get("seed"));
- }
-
- requestMap.put("options", optionsMap);
- String jsonRequest = objectMapper.writeValueAsString(requestMap);
+ String jsonRequest = objectMapper.writeValueAsString(request.requestBody());
HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.OLLAMA_API_URL))
@@ -53,8 +45,8 @@ public CompletionResponse generateCompletion(CompletionRequest request) {
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
- return new CompletionResponse(null, response, new ErrorResponse(
- "Request sent to LLM failed with status code " + response.statusCode(), null));
+ return new CompletionResponse(null, null, new ErrorResponse(
+                        "Request sent to LLM failed: status=" + response.statusCode() + ", body=" + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
@@ -65,13 +57,20 @@ private CompletionResponse extractResponseText(String responseBody) {
List msgs = null;
OllamaResponse response = null;
ErrorResponse error = null;
+ Map metadata = null;
+
try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OllamaResponse.class);
msgs = List.of(response.message());
+ metadata = Map.of(
+ "model", response.model(),
+ "provider", "Ollama",
+ "total_duration", response.total_duration(),
+ "eval_duration", response.eval_duration());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
- return new CompletionResponse(msgs, response, error);
+ return new CompletionResponse(msgs, metadata, error);
}
}
diff --git a/src/main/java/com/r7b7/client/DefaultOpenAIClient.java b/src/main/java/com/r7b7/client/DefaultOpenAIClient.java
index beeeb0e..0b13b82 100644
--- a/src/main/java/com/r7b7/client/DefaultOpenAIClient.java
+++ b/src/main/java/com/r7b7/client/DefaultOpenAIClient.java
@@ -6,55 +6,47 @@
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ArrayNode;
-import com.fasterxml.jackson.databind.node.ObjectNode;
import com.r7b7.client.model.Message;
import com.r7b7.client.model.OpenAIResponse;
+import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;
-public class DefaultOpenAIClient implements OpenAIClient {
- private String OPENAI_API_URL = "https://api.openai.com/v1/chat/completions";
+public class DefaultOpenAIClient implements IOpenAIClient {
+ private String OPENAI_API_URL;
public DefaultOpenAIClient() {
+ try {
+ Properties properties = PropertyConfig.loadConfig();
+ OPENAI_API_URL = properties.getProperty("hospai.openai.url");
+ } catch (Exception ex) {
+            throw new IllegalStateException("Failed to load OpenAI configuration (hospai.openai.url)", ex);
+ }
}
@Override
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
- ArrayNode arrayNode = objectMapper.valueToTree(request.messages());
- ObjectNode requestBody = objectMapper.createObjectNode();
- requestBody.put("model", request.model());
- requestBody.set("messages", arrayNode);
-
- if (null != request.params()) {
- for (Map.Entry entry : request.params().entrySet()) {
- Object value = entry.getValue();
- if (value instanceof String) {
- requestBody.put(entry.getKey(), (String) value);
- } else if (value instanceof Integer) {
- requestBody.put(entry.getKey(), (Integer) value);
- }
- }
- }
+ String jsonRequest = objectMapper.writeValueAsString(request.requestBody());
HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.OPENAI_API_URL))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + request.apiKey())
- .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString()))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonRequest))
.build();
HttpResponse response = HttpClient.newHttpClient().send(httpRequest,
HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
- return new CompletionResponse(null, response, new ErrorResponse(
- "Request sent to LLM failed with status code " + response.statusCode(), null));
+ return new CompletionResponse(null, null, new ErrorResponse(
+                        "Request sent to LLM failed: status=" + response.statusCode() + ", body=" + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
@@ -65,15 +57,23 @@ private CompletionResponse extractResponseText(String responseBody) {
List msgs = null;
OpenAIResponse response = null;
ErrorResponse error = null;
+ Map metadata = null;
try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OpenAIResponse.class);
msgs = response.choices().stream()
.map(choice -> new Message(choice.message().role(), choice.message().content())).toList();
+ metadata = Map.of(
+ "id", response.id(),
+ "model", response.model(),
+ "provider", "OpenAi",
+ "prompt_tokens", response.usage().promptTokens(),
+ "completion_tokens", response.usage().completionTokens(),
+ "total_tokens", response.usage().totalTokens());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
- return new CompletionResponse(msgs, response, error);
+ return new CompletionResponse(msgs, metadata, error);
}
}
diff --git a/src/main/java/com/r7b7/client/AnthropicClient.java b/src/main/java/com/r7b7/client/IAnthropicClient.java
similarity index 83%
rename from src/main/java/com/r7b7/client/AnthropicClient.java
rename to src/main/java/com/r7b7/client/IAnthropicClient.java
index 3aa4d6b..0ffab7a 100644
--- a/src/main/java/com/r7b7/client/AnthropicClient.java
+++ b/src/main/java/com/r7b7/client/IAnthropicClient.java
@@ -3,6 +3,6 @@
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-public interface AnthropicClient {
+public interface IAnthropicClient {
CompletionResponse generateCompletion(CompletionRequest request);
}
diff --git a/src/main/java/com/r7b7/client/GroqClient.java b/src/main/java/com/r7b7/client/IGroqClient.java
similarity index 85%
rename from src/main/java/com/r7b7/client/GroqClient.java
rename to src/main/java/com/r7b7/client/IGroqClient.java
index b0125ab..88c3f66 100644
--- a/src/main/java/com/r7b7/client/GroqClient.java
+++ b/src/main/java/com/r7b7/client/IGroqClient.java
@@ -3,6 +3,6 @@
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-public interface GroqClient {
+public interface IGroqClient {
CompletionResponse generateCompletion(CompletionRequest request);
}
diff --git a/src/main/java/com/r7b7/client/OllamaClient.java b/src/main/java/com/r7b7/client/IOllamaClient.java
similarity index 84%
rename from src/main/java/com/r7b7/client/OllamaClient.java
rename to src/main/java/com/r7b7/client/IOllamaClient.java
index dac4f1b..3a8924d 100644
--- a/src/main/java/com/r7b7/client/OllamaClient.java
+++ b/src/main/java/com/r7b7/client/IOllamaClient.java
@@ -3,6 +3,6 @@
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-public interface OllamaClient {
+public interface IOllamaClient {
CompletionResponse generateCompletion(CompletionRequest request);
}
diff --git a/src/main/java/com/r7b7/client/OpenAIClient.java b/src/main/java/com/r7b7/client/IOpenAIClient.java
similarity index 84%
rename from src/main/java/com/r7b7/client/OpenAIClient.java
rename to src/main/java/com/r7b7/client/IOpenAIClient.java
index 5dc1858..77a449a 100644
--- a/src/main/java/com/r7b7/client/OpenAIClient.java
+++ b/src/main/java/com/r7b7/client/IOpenAIClient.java
@@ -3,6 +3,6 @@
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-public interface OpenAIClient {
+public interface IOpenAIClient {
CompletionResponse generateCompletion(CompletionRequest request);
}
diff --git a/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java b/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java
deleted file mode 100644
index 70ba673..0000000
--- a/src/main/java/com/r7b7/client/factory/AnthropicClientFactory.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.r7b7.client.factory;
-
-import com.r7b7.client.AnthropicClient;
-import com.r7b7.client.DefaultAnthropicClient;
-
-public class AnthropicClientFactory {
- private static AnthropicClient currentClient;
-
- public static AnthropicClient createDefaultClient() {
- DefaultAnthropicClient client = new DefaultAnthropicClient();
- return client;
- }
-
- public static void setClient(AnthropicClient client) {
- currentClient = client;
- }
-
- public static AnthropicClient getClient() {
- if (null == currentClient) {
- currentClient = createDefaultClient();
- }
- return currentClient;
- }
-}
diff --git a/src/main/java/com/r7b7/client/factory/GroqClientFactory.java b/src/main/java/com/r7b7/client/factory/GroqClientFactory.java
deleted file mode 100644
index 3d44887..0000000
--- a/src/main/java/com/r7b7/client/factory/GroqClientFactory.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.r7b7.client.factory;
-
-import com.r7b7.client.DefaultGroqClient;
-import com.r7b7.client.GroqClient;
-
-public class GroqClientFactory {
- private static GroqClient currentClient;
-
- public static GroqClient createDefaultClient() {
- DefaultGroqClient client = new DefaultGroqClient();
- return client;
- }
-
- public static void setClient(GroqClient client) {
- currentClient = client;
- }
-
- public static GroqClient getClient() {
- if (null == currentClient) {
- currentClient = createDefaultClient();
- }
- return currentClient;
- }
-}
diff --git a/src/main/java/com/r7b7/client/factory/LLMClientFactory.java b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java
new file mode 100644
index 0000000..9207f2a
--- /dev/null
+++ b/src/main/java/com/r7b7/client/factory/LLMClientFactory.java
@@ -0,0 +1,85 @@
+package com.r7b7.client.factory;
+
+import com.r7b7.client.IAnthropicClient;
+import com.r7b7.client.DefaultAnthropicClient;
+import com.r7b7.client.DefaultGroqClient;
+import com.r7b7.client.DefaultOllamaClient;
+import com.r7b7.client.DefaultOpenAIClient;
+import com.r7b7.client.IGroqClient;
+import com.r7b7.client.IOllamaClient;
+import com.r7b7.client.IOpenAIClient;
+
+public class LLMClientFactory {
+ private static IAnthropicClient currentAnthropicClient;
+ private static IGroqClient currentGroqClient;
+ private static IOllamaClient currentOllamaClient;
+ private static IOpenAIClient currentOpenAIClient;
+
+ // Open AI Client
+ public static IOpenAIClient createDefaultOpenAIClient() {
+ DefaultOpenAIClient client = new DefaultOpenAIClient();
+ return client;
+ }
+
+ public static void setOpenAIClient(IOpenAIClient client) {
+ currentOpenAIClient = client;
+ }
+
+ public static IOpenAIClient getOpenAIClient() {
+ if (null == currentOpenAIClient) {
+ currentOpenAIClient = createDefaultOpenAIClient();
+ }
+ return currentOpenAIClient;
+ }
+
+ // Anthropic Client
+ public static IAnthropicClient createDefaultAnthropicClient() {
+ DefaultAnthropicClient client = new DefaultAnthropicClient();
+ return client;
+ }
+
+ public static void setAnthropicClient(IAnthropicClient client) {
+ currentAnthropicClient = client;
+ }
+
+ public static IAnthropicClient getAnthropicClient() {
+ if (null == currentAnthropicClient) {
+ currentAnthropicClient = createDefaultAnthropicClient();
+ }
+ return currentAnthropicClient;
+ }
+
+ // Groq Client
+ public static IGroqClient createDefaultGroqClient() {
+ DefaultGroqClient client = new DefaultGroqClient();
+ return client;
+ }
+
+ public static void setGroqClient(IGroqClient client) {
+ currentGroqClient = client;
+ }
+
+ public static IGroqClient getGroqClient() {
+ if (null == currentGroqClient) {
+ currentGroqClient = createDefaultGroqClient();
+ }
+ return currentGroqClient;
+ }
+
+ // Ollama Client
+ public static IOllamaClient createDefaultOllamaClient() {
+ DefaultOllamaClient client = new DefaultOllamaClient();
+ return client;
+ }
+
+ public static void setOllamaClient(IOllamaClient client) {
+ currentOllamaClient = client;
+ }
+
+ public static IOllamaClient getOllamaClient() {
+ if (null == currentOllamaClient) {
+ currentOllamaClient = createDefaultOllamaClient();
+ }
+ return currentOllamaClient;
+ }
+}
diff --git a/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java b/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java
deleted file mode 100644
index 20165a0..0000000
--- a/src/main/java/com/r7b7/client/factory/OllamaClientFactory.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.r7b7.client.factory;
-
-import com.r7b7.client.DefaultOllamaClient;
-import com.r7b7.client.OllamaClient;
-
-public class OllamaClientFactory {
- private static OllamaClient currentClient;
-
- public static OllamaClient createDefaultClient() {
- DefaultOllamaClient client = new DefaultOllamaClient();
- return client;
- }
-
- public static void setClient(OllamaClient client) {
- currentClient = client;
- }
-
- public static OllamaClient getClient() {
- if (null == currentClient) {
- currentClient = createDefaultClient();
- }
- return currentClient;
- }
-}
diff --git a/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java b/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java
deleted file mode 100644
index 3fa4cb7..0000000
--- a/src/main/java/com/r7b7/client/factory/OpenAIClientFactory.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.r7b7.client.factory;
-
-import com.r7b7.client.DefaultOpenAIClient;
-import com.r7b7.client.OpenAIClient;
-
-public class OpenAIClientFactory {
- private static OpenAIClient currentClient;
-
- public static OpenAIClient createDefaultClient() {
- DefaultOpenAIClient client = new DefaultOpenAIClient();
- return client;
- }
-
- public static void setClient(OpenAIClient client) {
- currentClient = client;
- }
-
- public static OpenAIClient getClient() {
- if (null == currentClient) {
- currentClient = createDefaultClient();
- }
- return currentClient;
- }
-}
diff --git a/src/main/java/com/r7b7/config/PropertyConfig.java b/src/main/java/com/r7b7/config/PropertyConfig.java
new file mode 100644
index 0000000..4ac905f
--- /dev/null
+++ b/src/main/java/com/r7b7/config/PropertyConfig.java
@@ -0,0 +1,20 @@
+package com.r7b7.config;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Properties;
+
+public class PropertyConfig {
+ private static Properties properties;
+
+    public static synchronized Properties loadConfig() throws IOException {
+        if (null == properties) {
+            properties = new Properties();
+            try (InputStream input = PropertyConfig.class.getClassLoader().getResourceAsStream("application.properties")) {
+                if (null == input) throw new IOException("application.properties not found on classpath");
+                properties.load(input);
+            }
+        }
+        return properties;
+    }
+}
diff --git a/src/main/java/com/r7b7/entity/CompletionRequest.java b/src/main/java/com/r7b7/entity/CompletionRequest.java
index a58716b..1b76f62 100644
--- a/src/main/java/com/r7b7/entity/CompletionRequest.java
+++ b/src/main/java/com/r7b7/entity/CompletionRequest.java
@@ -1,7 +1,6 @@
package com.r7b7.entity;
-import java.util.List;
import java.util.Map;
-public record CompletionRequest(List messages, Map params, String model, String apiKey) {
+public record CompletionRequest(Map requestBody, String apiKey) {
}
diff --git a/src/main/java/com/r7b7/entity/CompletionResponse.java b/src/main/java/com/r7b7/entity/CompletionResponse.java
index 6825360..65c5ded 100644
--- a/src/main/java/com/r7b7/entity/CompletionResponse.java
+++ b/src/main/java/com/r7b7/entity/CompletionResponse.java
@@ -1,7 +1,8 @@
package com.r7b7.entity;
import java.util.List;
+import java.util.Map;
-public record CompletionResponse (List messages, Object completeResponse, ErrorResponse error){
+public record CompletionResponse (List messages, Map metaData, ErrorResponse error){
}
diff --git a/src/main/java/com/r7b7/entity/Param.java b/src/main/java/com/r7b7/entity/Param.java
deleted file mode 100644
index 4691f60..0000000
--- a/src/main/java/com/r7b7/entity/Param.java
+++ /dev/null
@@ -1,5 +0,0 @@
-package com.r7b7.entity;
-
-public enum Param {
- max_token, n, temperature, seed
-}
diff --git a/src/main/java/com/r7b7/entity/Role.java b/src/main/java/com/r7b7/entity/Role.java
index 197a46c..7e1f5f3 100644
--- a/src/main/java/com/r7b7/entity/Role.java
+++ b/src/main/java/com/r7b7/entity/Role.java
@@ -1,5 +1,5 @@
package com.r7b7.entity;
public enum Role {
- user, assistant
+ user, assistant, system
}
diff --git a/src/main/java/com/r7b7/model/BaseLLMRequest.java b/src/main/java/com/r7b7/model/BaseLLMRequest.java
index a8d85d8..d0612b2 100644
--- a/src/main/java/com/r7b7/model/BaseLLMRequest.java
+++ b/src/main/java/com/r7b7/model/BaseLLMRequest.java
@@ -4,13 +4,12 @@
import java.util.Map;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
-public class BaseLLMRequest implements LLMRequest {
+public class BaseLLMRequest implements ILLMRequest {
private final List messages;
- private final Map parameters;
+ private final Map parameters;
- public BaseLLMRequest(List messages, Map parameters) {
+ public BaseLLMRequest(List messages, Map parameters) {
this.messages = messages;
this.parameters = parameters;
}
@@ -21,7 +20,7 @@ public List getPrompt() {
}
@Override
- public Map getParameters() {
+ public Map getParameters() {
return parameters;
}
}
diff --git a/src/main/java/com/r7b7/model/BaseLLMResponse.java b/src/main/java/com/r7b7/model/BaseLLMResponse.java
deleted file mode 100644
index 033dc62..0000000
--- a/src/main/java/com/r7b7/model/BaseLLMResponse.java
+++ /dev/null
@@ -1,25 +0,0 @@
-package com.r7b7.model;
-
-import java.util.Map;
-
-import com.r7b7.entity.CompletionResponse;
-
-public class BaseLLMResponse implements LLMResponse {
- private final CompletionResponse content;
- private final Map metadata;
-
- public BaseLLMResponse(CompletionResponse content, Map metadata) {
- this.content = content;
- this.metadata = metadata;
- }
-
- @Override
- public CompletionResponse getContent() {
- return content;
- }
-
- @Override
- public Map getMetadata() {
- return metadata;
- }
-}
diff --git a/src/main/java/com/r7b7/model/LLMRequest.java b/src/main/java/com/r7b7/model/ILLMRequest.java
similarity index 57%
rename from src/main/java/com/r7b7/model/LLMRequest.java
rename to src/main/java/com/r7b7/model/ILLMRequest.java
index 1f0b1eb..ccf1bd7 100644
--- a/src/main/java/com/r7b7/model/LLMRequest.java
+++ b/src/main/java/com/r7b7/model/ILLMRequest.java
@@ -4,9 +4,8 @@
import java.util.Map;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
-public interface LLMRequest {
+public interface ILLMRequest {
List getPrompt();
- Map getParameters();
+ Map getParameters();
}
diff --git a/src/main/java/com/r7b7/model/LLMResponse.java b/src/main/java/com/r7b7/model/LLMResponse.java
deleted file mode 100644
index cea4217..0000000
--- a/src/main/java/com/r7b7/model/LLMResponse.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package com.r7b7.model;
-
-import java.util.Map;
-
-import com.r7b7.entity.CompletionResponse;
-
-public interface LLMResponse {
- CompletionResponse getContent();
- Map getMetadata();
-}
diff --git a/src/main/java/com/r7b7/service/AnthropicService.java b/src/main/java/com/r7b7/service/AnthropicService.java
index 53ff208..39ad7b0 100644
--- a/src/main/java/com/r7b7/service/AnthropicService.java
+++ b/src/main/java/com/r7b7/service/AnthropicService.java
@@ -1,19 +1,21 @@
package com.r7b7.service;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.stream.Collectors;
-import com.r7b7.client.AnthropicClient;
-import com.r7b7.client.factory.AnthropicClientFactory;
+import com.r7b7.client.IAnthropicClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-import com.r7b7.entity.Param;
-import com.r7b7.model.BaseLLMResponse;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.entity.Message;
+import com.r7b7.entity.Role;
+import com.r7b7.model.ILLMRequest;
+import com.r7b7.util.StringUtility;
-public class AnthropicService implements LLMService {
+public class AnthropicService implements ILLMService {
private final String apiKey;
private final String model;
@@ -23,38 +25,65 @@ public AnthropicService(String apiKey, String model) {
}
@Override
- public LLMResponse generateResponse(LLMRequest request) {
- AnthropicClient client = AnthropicClientFactory.getClient();
- Map platformAllignedParams = null;
-
- platformAllignedParams = getPlatformAllignedParams(request);
- CompletionResponse response = client.generateCompletion(
- new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey));
-
- Map metadata = Map.of(
- "model", model,
- "provider", "anthropic");
- return new BaseLLMResponse(response, metadata);
+ public CompletionResponse generateResponse(ILLMRequest request) {
+ IAnthropicClient client = LLMClientFactory.getAnthropicClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ String systemMessage = getSystemMessage(request);
+ if (!StringUtility.isNullOrEmpty(systemMessage)) {
+ requestMap.put("system", systemMessage);
+ }
+ requestMap.put("messages", request.getPrompt());
+ // set mandatory param if not set explicitly
+ requestMap.put("max_tokens", 1024);
+ if (null != request.getParameters()) {
+ for (Map.Entry entry : request.getParameters().entrySet()) {
+ requestMap.put(entry.getKey(), entry.getValue());
+ }
+ }
+ // override disabled features if set dynamically
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
}
@Override
- public CompletableFuture generateResponseAsync(LLMRequest request) {
+ public CompletableFuture generateResponseAsync(ILLMRequest request) {
return CompletableFuture.supplyAsync(() -> generateResponse(request));
}
- private Map getPlatformAllignedParams(LLMRequest request) {
- Map platformAllignedParams = null;
- if (null != request.getParameters()) {
- Map keyMapping = Map.of(
- Param.max_token, "max_tokens",
- Param.temperature, "temperature");
-
- platformAllignedParams = request.getParameters().entrySet().stream()
- .filter(entry -> keyMapping.containsKey(entry.getKey()))
- .collect(Collectors.toMap(
- entry -> keyMapping.get(entry.getKey()),
- Map.Entry::getValue));
- }
- return platformAllignedParams;
+ @Override
+ public CompletionResponse generateResponse(String inputQuery) {
+ IAnthropicClient client = LLMClientFactory.getAnthropicClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.user, inputQuery));
+ requestMap.put("messages", prompt);
+        // Anthropic requires max_tokens; default to 1024
+ requestMap.put("max_tokens", 1024);
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
+ }
+
+ @Override
+ public CompletableFuture generateResponseAsync(String inputQuery) {
+ return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery));
+ }
+
+ private String getSystemMessage(ILLMRequest request) {
+ String systemMessage = request.getPrompt().stream()
+ .filter(msg -> msg.role() == Role.system)
+ .findFirst()
+ .map(msg -> {
+ request.getPrompt().remove(msg);
+ return msg.content();
+ })
+ .orElse(null);
+ return systemMessage;
}
}
diff --git a/src/main/java/com/r7b7/service/GroqService.java b/src/main/java/com/r7b7/service/GroqService.java
index 3e49771..c376e88 100644
--- a/src/main/java/com/r7b7/service/GroqService.java
+++ b/src/main/java/com/r7b7/service/GroqService.java
@@ -1,19 +1,20 @@
package com.r7b7.service;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.stream.Collectors;
-import com.r7b7.client.GroqClient;
-import com.r7b7.client.factory.GroqClientFactory;
+import com.r7b7.client.IGroqClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-import com.r7b7.entity.Param;
-import com.r7b7.model.BaseLLMResponse;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.entity.Message;
+import com.r7b7.entity.Role;
+import com.r7b7.model.ILLMRequest;
-public class GroqService implements LLMService {
+public class GroqService implements ILLMService {
private final String apiKey;
private final String model;
@@ -23,39 +24,45 @@ public GroqService(String apiKey, String model) {
}
@Override
- public LLMResponse generateResponse(LLMRequest request) {
- CompletionResponse response = null;
- GroqClient client = GroqClientFactory.getClient();
- Map platformAllignedParams = getPlatformAllignedParams(request);
-
- response = client
- .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey));
- Map metadata = Map.of(
- "model", model,
- "provider", "grok");
- return new BaseLLMResponse(response, metadata);
+ public CompletionResponse generateResponse(ILLMRequest request) {
+ IGroqClient client = LLMClientFactory.getGroqClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ requestMap.put("messages", request.getPrompt());
+ if (null != request.getParameters()) {
+ for (Map.Entry entry : request.getParameters().entrySet()) {
+ requestMap.put(entry.getKey(), entry.getValue());
+ }
+ }
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
}
@Override
- public CompletableFuture generateResponseAsync(LLMRequest request) {
+ public CompletableFuture generateResponseAsync(ILLMRequest request) {
return CompletableFuture.supplyAsync(() -> generateResponse(request));
}
- private Map getPlatformAllignedParams(LLMRequest request) {
- Map platformAllignedParams = null;
- if (null != request.getParameters()) {
- Map keyMapping = Map.of(
- Param.max_token, "max_tokens",
- Param.n, "n",
- Param.temperature, "temperature",
- Param.seed, "seed");
-
- platformAllignedParams = request.getParameters().entrySet().stream()
- .filter(entry -> keyMapping.containsKey(entry.getKey()))
- .collect(Collectors.toMap(
- entry -> keyMapping.get(entry.getKey()),
- Map.Entry::getValue));
- }
- return platformAllignedParams;
+ @Override
+ public CompletionResponse generateResponse(String inputQuery) {
+ IGroqClient client = LLMClientFactory.getGroqClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.user, inputQuery));
+ requestMap.put("messages", prompt);
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
+ }
+
+ @Override
+ public CompletableFuture generateResponseAsync(String inputQuery) {
+ return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery));
}
}
diff --git a/src/main/java/com/r7b7/service/ILLMService.java b/src/main/java/com/r7b7/service/ILLMService.java
new file mode 100644
index 0000000..501dc50
--- /dev/null
+++ b/src/main/java/com/r7b7/service/ILLMService.java
@@ -0,0 +1,16 @@
+package com.r7b7.service;
+
+import java.util.concurrent.CompletableFuture;
+
+import com.r7b7.entity.CompletionResponse;
+import com.r7b7.model.ILLMRequest;
+
+public interface ILLMService {
+ CompletionResponse generateResponse(ILLMRequest request);
+
+ CompletionResponse generateResponse(String inputQuery);
+
+ CompletableFuture generateResponseAsync(ILLMRequest request);
+
+ CompletableFuture generateResponseAsync(String inputQuery);
+}
diff --git a/src/main/java/com/r7b7/service/LLMService.java b/src/main/java/com/r7b7/service/LLMService.java
deleted file mode 100644
index 0922606..0000000
--- a/src/main/java/com/r7b7/service/LLMService.java
+++ /dev/null
@@ -1,11 +0,0 @@
-package com.r7b7.service;
-
-import java.util.concurrent.CompletableFuture;
-
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
-
-public interface LLMService {
- LLMResponse generateResponse(LLMRequest request);
- CompletableFuture generateResponseAsync(LLMRequest request);
-}
diff --git a/src/main/java/com/r7b7/service/LLMServiceFactory.java b/src/main/java/com/r7b7/service/LLMServiceFactory.java
index bca90c0..641d4ed 100644
--- a/src/main/java/com/r7b7/service/LLMServiceFactory.java
+++ b/src/main/java/com/r7b7/service/LLMServiceFactory.java
@@ -3,7 +3,7 @@
import com.r7b7.entity.Provider;
public class LLMServiceFactory {
- public static LLMService createService(Provider provider, String apiKey, String model) {
+ public static ILLMService createService(Provider provider, String apiKey, String model) {
return switch (provider) {
case Provider.OPENAI -> new OpenAIService(apiKey, model);
case Provider.ANTHROPIC -> new AnthropicService(apiKey, model);
@@ -13,7 +13,7 @@ public static LLMService createService(Provider provider, String apiKey, String
};
}
- public static LLMService createService(Provider provider, String model) {
+ public static ILLMService createService(Provider provider, String model) {
return createService(provider, null, model);
}
}
diff --git a/src/main/java/com/r7b7/service/OllamaService.java b/src/main/java/com/r7b7/service/OllamaService.java
index 0c2e3a5..03d499b 100644
--- a/src/main/java/com/r7b7/service/OllamaService.java
+++ b/src/main/java/com/r7b7/service/OllamaService.java
@@ -1,19 +1,20 @@
package com.r7b7.service;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.stream.Collectors;
-import com.r7b7.client.OllamaClient;
-import com.r7b7.client.factory.OllamaClientFactory;
+import com.r7b7.client.IOllamaClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-import com.r7b7.entity.Param;
-import com.r7b7.model.BaseLLMResponse;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.entity.Message;
+import com.r7b7.entity.Role;
+import com.r7b7.model.ILLMRequest;
-public class OllamaService implements LLMService {
+public class OllamaService implements ILLMService {
private final String model;
public OllamaService(String model) {
@@ -21,37 +22,48 @@ public OllamaService(String model) {
}
@Override
- public LLMResponse generateResponse(LLMRequest request) {
- OllamaClient client = OllamaClientFactory.getClient();
- Map platformAllignedParams = getPlatformAllignedParams(request);
+ public CompletionResponse generateResponse(ILLMRequest request) {
+ IOllamaClient client = LLMClientFactory.getOllamaClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ requestMap.put("messages", request.getPrompt());
- CompletionResponse response = client
- .generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, null));
- Map metadata = Map.of(
- "model", model,
- "provider", "Ollama");
+ Map optionsMap = new HashMap<>();
+ if (null != request.getParameters()) {
+ for (Map.Entry entry : request.getParameters().entrySet()) {
+ optionsMap.put(entry.getKey(), entry.getValue());
+ }
+ }
+ requestMap.put("options", optionsMap);
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
- return new BaseLLMResponse(response, metadata);
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, null));
+ return response;
}
@Override
- public CompletableFuture generateResponseAsync(LLMRequest request) {
+ public CompletableFuture generateResponseAsync(ILLMRequest request) {
return CompletableFuture.supplyAsync(() -> generateResponse(request));
}
- private Map getPlatformAllignedParams(LLMRequest request) {
- Map platformAllignedParams = null;
- if (null != request.getParameters()) {
- Map keyMapping = Map.of(
- Param.seed, "seed",
- Param.temperature, "temperature");
-
- platformAllignedParams = request.getParameters().entrySet().stream()
- .filter(entry -> keyMapping.containsKey(entry.getKey()))
- .collect(Collectors.toMap(
- entry -> keyMapping.get(entry.getKey()),
- Map.Entry::getValue));
- }
- return platformAllignedParams;
+ @Override
+ public CompletionResponse generateResponse(String inputQuery) {
+ IOllamaClient client = LLMClientFactory.getOllamaClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.user, inputQuery));
+ requestMap.put("messages", prompt);
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, null));
+ return response;
+ }
+
+ @Override
+ public CompletableFuture generateResponseAsync(String inputQuery) {
+ return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery));
}
}
diff --git a/src/main/java/com/r7b7/service/OpenAIService.java b/src/main/java/com/r7b7/service/OpenAIService.java
index 4eeb1d8..2a74aa3 100644
--- a/src/main/java/com/r7b7/service/OpenAIService.java
+++ b/src/main/java/com/r7b7/service/OpenAIService.java
@@ -1,19 +1,20 @@
package com.r7b7.service;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.stream.Collectors;
-import com.r7b7.client.OpenAIClient;
-import com.r7b7.client.factory.OpenAIClientFactory;
+import com.r7b7.client.IOpenAIClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
-import com.r7b7.entity.Param;
-import com.r7b7.model.BaseLLMResponse;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.entity.Message;
+import com.r7b7.entity.Role;
+import com.r7b7.model.ILLMRequest;
-public class OpenAIService implements LLMService {
+public class OpenAIService implements ILLMService {
private final String apiKey;
private final String model;
@@ -23,37 +24,46 @@ public OpenAIService(String apiKey, String model) {
}
@Override
- public LLMResponse generateResponse(LLMRequest request) {
- CompletionResponse response = null;
- OpenAIClient client = OpenAIClientFactory.getClient();
- Map platformAllignedParams = getPlatformAllignedParams(request);
-
- response = client.generateCompletion(new CompletionRequest(request.getPrompt(), platformAllignedParams, model, apiKey));
- Map metadata = Map.of(
- "model", model,
- "provider", "openai");
- return new BaseLLMResponse(response, metadata);
+ public CompletionResponse generateResponse(ILLMRequest request) {
+ IOpenAIClient client = LLMClientFactory.getOpenAIClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ requestMap.put("messages", request.getPrompt());
+ if (null != request.getParameters()) {
+ for (Map.Entry entry : request.getParameters().entrySet()) {
+ requestMap.put(entry.getKey(), entry.getValue());
+ }
+ }
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
}
@Override
- public CompletableFuture generateResponseAsync(LLMRequest request) {
+ public CompletableFuture generateResponseAsync(ILLMRequest request) {
return CompletableFuture.supplyAsync(() -> generateResponse(request));
}
- private Map getPlatformAllignedParams(LLMRequest request) {
- Map platformAllignedParams = null;
- if (null != request.getParameters()) {
- Map keyMapping = Map.of(
- Param.max_token, "max_completion_tokens",
- Param.n, "n",
- Param.temperature, "temperature");
-
- platformAllignedParams = request.getParameters().entrySet().stream()
- .filter(entry -> keyMapping.containsKey(entry.getKey()))
- .collect(Collectors.toMap(
- entry -> keyMapping.get(entry.getKey()),
- Map.Entry::getValue));
- }
- return platformAllignedParams;
+ @Override
+ public CompletionResponse generateResponse(String inputQuery) {
+ IOpenAIClient client = LLMClientFactory.getOpenAIClient();
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", this.model);
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.user, inputQuery));
+ requestMap.put("messages", prompt);
+
+        // force non-streaming mode, overriding any user-supplied "stream" parameter
+ requestMap.put("stream", false);
+
+ CompletionResponse response = client.generateCompletion(new CompletionRequest(requestMap, this.apiKey));
+ return response;
+ }
+
+ @Override
+ public CompletableFuture generateResponseAsync(String inputQuery) {
+ return CompletableFuture.supplyAsync(() -> generateResponse(inputQuery));
}
}
diff --git a/src/main/java/com/r7b7/service/PromptBuilder.java b/src/main/java/com/r7b7/service/PromptBuilder.java
new file mode 100644
index 0000000..197ccc0
--- /dev/null
+++ b/src/main/java/com/r7b7/service/PromptBuilder.java
@@ -0,0 +1,35 @@
+package com.r7b7.service;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.r7b7.entity.Message;
+
+public class PromptBuilder {
+ private List messages = new ArrayList<>();
+ private Map params = new HashMap<>();
+
+ public PromptBuilder addMessage(Message message) {
+ messages.add(message);
+ return this;
+ }
+
+ public PromptBuilder addParam(String key, Object value) {
+ params.put(key, value);
+ return this;
+ }
+
+ public List getMessages() {
+ return messages;
+ }
+
+ public Map getParams() {
+ return params;
+ }
+
+ public PromptEngine build(ILLMService service) {
+ return new PromptEngine(service, params, messages);
+ }
+}
diff --git a/src/main/java/com/r7b7/service/PromptEngine.java b/src/main/java/com/r7b7/service/PromptEngine.java
index c275e34..7c3d5c3 100644
--- a/src/main/java/com/r7b7/service/PromptEngine.java
+++ b/src/main/java/com/r7b7/service/PromptEngine.java
@@ -6,42 +6,41 @@
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
-import com.r7b7.entity.Role;
import com.r7b7.model.BaseLLMRequest;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.model.ILLMRequest;
public class PromptEngine {
- private final LLMService llmService;
- private final Message assistantMessage;
- private final Map params;
+ private final ILLMService llmService;
+ private final Map params;
+ private final List messages;
- public PromptEngine(LLMService llmService) {
- this(llmService, new Message(Role.assistant, "You are a helpful assistant"), null);
+ public PromptEngine(ILLMService llmService) {
+ this(llmService, null, null);
}
- public PromptEngine(LLMService llmService, Message assistantMessage, Map params) {
+ public PromptEngine(ILLMService llmService, Map params, List messages) {
this.llmService = llmService;
- this.assistantMessage = assistantMessage;
this.params = params;
+ this.messages = messages;
}
- public CompletionResponse getResponse(String inputQuery) {
- Message userMsg = new Message(Role.user, inputQuery);
- List messages = List.of(assistantMessage, userMsg);
- LLMRequest request = new BaseLLMRequest(messages, params);
- LLMResponse response = llmService.generateResponse(request);
- return response.getContent();
+ public CompletionResponse sendQuery() {
+ ILLMRequest request = new BaseLLMRequest(this.messages, this.params);
+ CompletionResponse response = this.llmService.generateResponse(request);
+ return response;
}
- public CompletableFuture getResponseAsync(String inputQuery) {
- Message userMsg = new Message(Role.user, inputQuery);
- List messages = List.of(assistantMessage, userMsg);
- LLMRequest request = new BaseLLMRequest(messages, params);
+ public CompletableFuture sendQueryAsync() {
+ ILLMRequest request = new BaseLLMRequest(this.messages, this.params);
+ return this.llmService.generateResponseAsync(request);
+ }
- return llmService.generateResponseAsync(request)
- .thenApply(LLMResponse::getContent);
+ public CompletionResponse sendQuery(String inputQuery) {
+ CompletionResponse response = this.llmService.generateResponse(inputQuery);
+ return response;
}
+ public CompletableFuture sendQueryAsync(String inputQuery) {
+ return this.llmService.generateResponseAsync(inputQuery);
+ }
}
diff --git a/src/main/java/com/r7b7/util/StringUtility.java b/src/main/java/com/r7b7/util/StringUtility.java
index ddd5703..28ef137 100644
--- a/src/main/java/com/r7b7/util/StringUtility.java
+++ b/src/main/java/com/r7b7/util/StringUtility.java
@@ -3,7 +3,7 @@
import java.net.URI;
public final class StringUtility {
-
+
public static boolean isNullOrEmpty(String str) {
return str == null || str.isEmpty();
}
@@ -13,7 +13,7 @@ public static boolean isValidHttpOrHttpsUrl(String url) {
return false;
}
try {
- URI.create(url).toURL();
+ URI.create(url).toURL();
return true;
} catch (Exception e) {
return false;
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
new file mode 100644
index 0000000..0ce3686
--- /dev/null
+++ b/src/main/resources/application.properties
@@ -0,0 +1,5 @@
+hospai.openai.url=https://api.openai.com/v1/chat/completions
+hospai.anthropic.url=https://api.anthropic.com/v1/messages
+hospai.anthropic.version=2023-06-01
+hospai.ollama.url=http://localhost:11434/api/chat
+hospai.groq.url=https://api.groq.com/openai/v1/chat/completions
\ No newline at end of file
diff --git a/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java b/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java
index 2fc3291..9ff4f22 100644
--- a/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java
+++ b/src/test/java/com/r7b7/client/DefaultAnthropicClientTest.java
@@ -11,7 +11,10 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -41,12 +44,16 @@ public void setUp() {
@Test
public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null,
- "test-model", "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
- when(mockResponse.body()).thenReturn("{\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}]}");
+ when(mockResponse.body()).thenReturn("{\"id\": \"msg_01KVZJSawLDxgY8LvDm2w9KP\",\"model\": \"claude-3-5-sonnet-20241022\",\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}], \"usage\":{\"input_tokens\": 48,\"output_tokens\": 70}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -63,11 +70,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte
public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null,
- "test-model", "api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
+
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
- when(mockResponse.body()).thenReturn("{\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}]}");
+ when(mockResponse.body()).thenReturn("{\"id\": \"msg_01KVZJSawLDxgY8LvDm2w9KP\",\"model\": \"claude-3-5-sonnet-20241022\",\"content\": [{\"type\": \"assistant\", \"text\": \"Hi there!\"}], \"usage\":{\"input_tokens\": 48,\"output_tokens\": 70}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -83,8 +95,13 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt
public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
+
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(401);
@@ -104,8 +121,12 @@ public void testGenerateCompletion_InvalidApiKey() throws IOException, Interrupt
public void testGenerateCompletion_HandleException() throws IOException, InterruptedException {
// Arrange
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")), null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
// Mock the HttpClient to throw an exception
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
diff --git a/src/test/java/com/r7b7/client/DefaultGroqClientTest.java b/src/test/java/com/r7b7/client/DefaultGroqClientTest.java
index 88b6fdd..e8b9537 100644
--- a/src/test/java/com/r7b7/client/DefaultGroqClientTest.java
+++ b/src/test/java/com/r7b7/client/DefaultGroqClientTest.java
@@ -11,7 +11,10 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -40,13 +43,17 @@ public void setUp() {
@Test
public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
+
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}");
+ "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -63,13 +70,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte
public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}");
+ "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -85,9 +95,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt
public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(401);
@@ -107,9 +120,12 @@ public void testGenerateCompletion_InvalidApiKey() throws IOException, Interrupt
public void testGenerateCompletion_HandleException() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenThrow(new IOException("Mocked IOException"));
diff --git a/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java b/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java
index dcfbc0d..98c0d74 100644
--- a/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java
+++ b/src/test/java/com/r7b7/client/DefaultOllamaClientTest.java
@@ -11,7 +11,10 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -40,13 +43,16 @@ public void setUp() {
@Test
public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", null);
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"}}");
+ "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"},\"total_duration\":5191566416,\"eval_duration\":4799921000}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -62,13 +68,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte
@Test
public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", null);
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"}}");
+ "{\"model\":\"test\",\"message\":{\"role\": \"assistant\", \"content\": \"Hi there!\"},\"total_duration\":5191566416,\"eval_duration\":4799921000}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -83,9 +92,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt
@Test
public void testGenerateCompletion_HandleException() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", null);
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenThrow(new IOException("Mocked IOException"));
diff --git a/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java b/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java
index 43b0a07..ec84eaa 100644
--- a/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java
+++ b/src/test/java/com/r7b7/client/DefaultOpenAIClientTest.java
@@ -11,7 +11,10 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -40,13 +43,16 @@ public void setUp() {
@Test
public void testGenerateCompletion_ValidRequest() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}");
+ "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -63,13 +69,16 @@ public void testGenerateCompletion_ValidRequest() throws IOException, Interrupte
public void testGenerateCompletion_WithoutParams() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(200);
when(mockResponse.body()).thenReturn(
- "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}]}");
+ "{\"id\":\"xxx\",\"model\":\"test\",\"choices\":[{\"index\":0, \"message\":{\"type\": \"assistant\", \"text\": \"Hi there!\"}}],\"usage\":{\"prompt_tokens\": 55,\"completion_tokens\": 12,\"total_tokens\": 67}}");
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(mockResponse);
mockedStatic.when(HttpClient::newHttpClient).thenReturn(mockHttpClient);
@@ -85,9 +94,12 @@ public void testGenerateCompletion_WithoutParams() throws IOException, Interrupt
public void testGenerateCompletion_InvalidApiKey() throws IOException, InterruptedException {
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
HttpResponse mockResponse = mock(HttpResponse.class);
when(mockResponse.statusCode()).thenReturn(401);
@@ -108,9 +120,12 @@ public void testGenerateCompletion_HandleException() throws IOException, Interru
try (MockedStatic mockedStatic = mockStatic(HttpClient.class)) {
// Arrange
- CompletionRequest request = new CompletionRequest(List.of(new Message(Role.assistant, "Hello")),
- null,
- "test-model", "invalid-api-key");
+ Map requestMap = new HashMap<>();
+ requestMap.put("model", "test-model");
+ List prompt = new ArrayList<>();
+ prompt.add(new Message(Role.system, "You are a helpful assistant"));
+ requestMap.put("messages", prompt);
+ CompletionRequest request = new CompletionRequest(requestMap, "api-key");
// Mock the HttpClient to throw an exception
when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
diff --git a/src/test/java/com/r7b7/service/AnthropicServiceTest.java b/src/test/java/com/r7b7/service/AnthropicServiceTest.java
index ab6105f..dea8065 100644
--- a/src/test/java/com/r7b7/service/AnthropicServiceTest.java
+++ b/src/test/java/com/r7b7/service/AnthropicServiceTest.java
@@ -8,6 +8,7 @@
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -18,20 +19,18 @@
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
-import com.r7b7.client.AnthropicClient;
-import com.r7b7.client.factory.AnthropicClientFactory;
+import com.r7b7.client.IAnthropicClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
import com.r7b7.entity.Role;
import com.r7b7.model.BaseLLMRequest;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.model.ILLMRequest;
public class AnthropicServiceTest {
@Mock
- private AnthropicClient mockClient;
+ private IAnthropicClient mockClient;
@InjectMocks
private AnthropicService anthropicService;
@@ -47,19 +46,19 @@ public void setUp() {
@Test
public void testGenerateResponse_Success() {
- try (MockedStatic mockedStatic = mockStatic(AnthropicClientFactory.class)) {
- LLMRequest request = createMockLLMRequest("Test prompt");
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ ILLMRequest request = createMockLLMRequest("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> AnthropicClientFactory.getClient())
+ mockedStatic.when(() -> LLMClientFactory.getAnthropicClient())
.thenReturn(mockClient);
- LLMResponse response = anthropicService.generateResponse(request);
+ CompletionResponse response = anthropicService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("anthropic", metadata.get("provider"));
@@ -69,20 +68,20 @@ public void testGenerateResponse_Success() {
@Test
public void testGenerateResponse_WithParams_Success() {
- try (MockedStatic mockedStatic = mockStatic(AnthropicClientFactory.class)) {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
- LLMRequest request = createMockLLMRequestWithParams("Test prompt");
+ ILLMRequest request = createMockLLMRequestWithParams("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> AnthropicClientFactory.getClient())
+ mockedStatic.when(() -> LLMClientFactory.getAnthropicClient())
.thenReturn(mockClient);
- LLMResponse response = anthropicService.generateResponse(request);
+ CompletionResponse response = anthropicService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("anthropic", metadata.get("provider"));
@@ -90,10 +89,33 @@ public void testGenerateResponse_WithParams_Success() {
}
}
- private LLMRequest createMockLLMRequest(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- LLMRequest request = new BaseLLMRequest(messages, null);
+ @Test
+ public void testGenerateResponseForSingleQuery_Success() {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
+ doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
+ mockedStatic.when(() -> LLMClientFactory.getAnthropicClient())
+ .thenReturn(mockClient);
+
+ CompletionResponse response = anthropicService.generateResponse("Single query");
+
+ assertNotNull(response);
+ assertEquals("test content", response.messages().get(0).content());
+
+ Map metadata = response.metaData();
+ assertEquals(TEST_MODEL, metadata.get("model"));
+ assertEquals("anthropic", metadata.get("provider"));
+
+ verify(mockClient).generateCompletion(any(CompletionRequest.class));
+ }
+ }
+
+ private ILLMRequest createMockLLMRequest(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.system, "You are a helpful assistant"));
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+ ILLMRequest request = new BaseLLMRequest(messages, null);
return request;
}
@@ -101,18 +123,24 @@ private CompletionResponse createMockCompletionResponse(String content) {
List messages = new ArrayList<>();
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content");
messages.add(msg);
- CompletionResponse response = new CompletionResponse(messages, null, null);
+ Map metaData = new HashMap<>();
+ metaData.put("model", TEST_MODEL);
+ metaData.put("provider", "anthropic");
+ CompletionResponse response = new CompletionResponse(messages, metaData, null);
return response;
}
- private LLMRequest createMockLLMRequestWithParams(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- Map params = Map.of(
- Param.temperature, 0.7,
- Param.max_token, 1000);
+ private ILLMRequest createMockLLMRequestWithParams(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.system, "You are a helpful assistant"));
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+
+ Map params = Map.of(
+ "temperature", 0.7,
+ "max_token", 1000);
- LLMRequest request = new BaseLLMRequest(messages, params);
+ ILLMRequest request = new BaseLLMRequest(messages, params);
return request;
}
}
diff --git a/src/test/java/com/r7b7/service/GroqServiceTest.java b/src/test/java/com/r7b7/service/GroqServiceTest.java
index 57ba2ca..a3f5bca 100644
--- a/src/test/java/com/r7b7/service/GroqServiceTest.java
+++ b/src/test/java/com/r7b7/service/GroqServiceTest.java
@@ -8,6 +8,7 @@
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -18,20 +19,18 @@
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
-import com.r7b7.client.GroqClient;
-import com.r7b7.client.factory.GroqClientFactory;
+import com.r7b7.client.IGroqClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
import com.r7b7.entity.Role;
import com.r7b7.model.BaseLLMRequest;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.model.ILLMRequest;
public class GroqServiceTest {
@Mock
- private GroqClient mockClient;
+ private IGroqClient mockClient;
@InjectMocks
private GroqService groqService;
@@ -47,18 +46,18 @@ public void setUp() {
@Test
public void testGenerateResponse_Success() {
- try (MockedStatic mockedStatic = mockStatic(GroqClientFactory.class);) {
- LLMRequest request = createMockLLMRequest("Test prompt");
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) {
+ ILLMRequest request = createMockLLMRequest("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> GroqClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient);
- LLMResponse response = groqService.generateResponse(request);
+ CompletionResponse response = groqService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("grok", metadata.get("provider"));
@@ -68,18 +67,18 @@ public void testGenerateResponse_Success() {
@Test
public void testGenerateResponse_WithParams_Success() {
- try (MockedStatic mockedStatic = mockStatic(GroqClientFactory.class);) {
- LLMRequest request = createMockLLMRequestWithParams("Test prompt");
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) {
+ ILLMRequest request = createMockLLMRequestWithParams("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> GroqClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient);
- LLMResponse response = groqService.generateResponse(request);
+ CompletionResponse response = groqService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("grok", metadata.get("provider"));
@@ -87,10 +86,31 @@ public void testGenerateResponse_WithParams_Success() {
}
}
- private LLMRequest createMockLLMRequest(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- LLMRequest request = new BaseLLMRequest(messages, null);
+ @Test
+ public void testGenerateResponseWithSingleQuery_Success() {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class);) {
+ CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
+ doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
+ mockedStatic.when(() -> LLMClientFactory.getGroqClient()).thenReturn(mockClient);
+
+ CompletionResponse response = groqService.generateResponse("Single Query");
+
+ assertNotNull(response);
+ assertEquals("test content", response.messages().get(0).content());
+
+ Map metadata = response.metaData();
+ assertEquals(TEST_MODEL, metadata.get("model"));
+ assertEquals("grok", metadata.get("provider"));
+
+ verify(mockClient).generateCompletion(any(CompletionRequest.class));
+ }
+ }
+
+ private ILLMRequest createMockLLMRequest(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+ ILLMRequest request = new BaseLLMRequest(messages, null);
return request;
}
@@ -98,18 +118,23 @@ private CompletionResponse createMockCompletionResponse(String content) {
List messages = new ArrayList<>();
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content");
messages.add(msg);
- CompletionResponse response = new CompletionResponse(messages, null, null);
+ Map metaData = new HashMap<>();
+ metaData.put("model", TEST_MODEL);
+ metaData.put("provider", "grok");
+ CompletionResponse response = new CompletionResponse(messages, metaData, null);
return response;
}
- private LLMRequest createMockLLMRequestWithParams(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- Map params = Map.of(
- Param.temperature, 0.7,
- Param.max_token, 1000);
+ private ILLMRequest createMockLLMRequestWithParams(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+
+ Map params = Map.of(
+ "temperature", 0.7,
+ "max_token", 1000);
- LLMRequest request = new BaseLLMRequest(messages, params);
+ ILLMRequest request = new BaseLLMRequest(messages, params);
return request;
}
}
diff --git a/src/test/java/com/r7b7/service/OllamaServiceTest.java b/src/test/java/com/r7b7/service/OllamaServiceTest.java
index bb07c1f..b090f73 100644
--- a/src/test/java/com/r7b7/service/OllamaServiceTest.java
+++ b/src/test/java/com/r7b7/service/OllamaServiceTest.java
@@ -8,6 +8,7 @@
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -18,20 +19,18 @@
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
-import com.r7b7.client.OllamaClient;
-import com.r7b7.client.factory.OllamaClientFactory;
+import com.r7b7.client.IOllamaClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
import com.r7b7.entity.Role;
import com.r7b7.model.BaseLLMRequest;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.model.ILLMRequest;
public class OllamaServiceTest {
@Mock
- private OllamaClient mockClient;
+ private IOllamaClient mockClient;
@InjectMocks
private OllamaService ollamaService;
@@ -46,18 +45,18 @@ public void setUp() {
@Test
public void testGenerateResponse_Success() {
- try (MockedStatic mockedStatic = mockStatic(OllamaClientFactory.class)) {
- LLMRequest request = createMockLLMRequest("Test prompt");
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ ILLMRequest request = createMockLLMRequest("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> OllamaClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient);
- LLMResponse response = ollamaService.generateResponse(request);
+ CompletionResponse response = ollamaService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("Ollama", metadata.get("provider"));
@@ -67,19 +66,19 @@ public void testGenerateResponse_Success() {
@Test
public void testGenerateResponse_WithParams_Success() {
- try (MockedStatic mockedStatic = mockStatic(OllamaClientFactory.class)) {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
- LLMRequest request = createMockLLMRequestWithParams("Test prompt");
+ ILLMRequest request = createMockLLMRequestWithParams("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> OllamaClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient);
- LLMResponse response = ollamaService.generateResponse(request);
+ CompletionResponse response = ollamaService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("Ollama", metadata.get("provider"));
@@ -87,10 +86,32 @@ public void testGenerateResponse_WithParams_Success() {
}
}
- private LLMRequest createMockLLMRequest(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- LLMRequest request = new BaseLLMRequest(messages, null);
+ @Test
+ public void testGenerateResponseWithSingleQuery_Success() {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
+ doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
+ mockedStatic.when(() -> LLMClientFactory.getOllamaClient()).thenReturn(mockClient);
+
+ CompletionResponse response = ollamaService.generateResponse("Single Query");
+
+ assertNotNull(response);
+ assertEquals("test content", response.messages().get(0).content());
+
+ Map metadata = response.metaData();
+ assertEquals(TEST_MODEL, metadata.get("model"));
+ assertEquals("Ollama", metadata.get("provider"));
+
+ verify(mockClient).generateCompletion(any(CompletionRequest.class));
+ }
+ }
+
+ private ILLMRequest createMockLLMRequest(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+
+ ILLMRequest request = new BaseLLMRequest(messages, null);
return request;
}
@@ -98,18 +119,23 @@ private CompletionResponse createMockCompletionResponse(String content) {
List messages = new ArrayList<>();
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content");
messages.add(msg);
- CompletionResponse response = new CompletionResponse(messages, null, null);
+ Map metaData = new HashMap<>();
+ metaData.put("model", TEST_MODEL);
+ metaData.put("provider", "Ollama");
+ CompletionResponse response = new CompletionResponse(messages, metaData, null);
return response;
}
- private LLMRequest createMockLLMRequestWithParams(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- Map params = Map.of(
- Param.temperature, 0.7,
- Param.max_token, 1000);
+ private ILLMRequest createMockLLMRequestWithParams(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+
+ Map params = Map.of(
+ "temperature", 0.7,
+ "max_token", 1000);
- LLMRequest request = new BaseLLMRequest(messages, params);
+ ILLMRequest request = new BaseLLMRequest(messages, params);
return request;
}
}
diff --git a/src/test/java/com/r7b7/service/OpenAIServiceTest.java b/src/test/java/com/r7b7/service/OpenAIServiceTest.java
index e5691c3..752a2dc 100644
--- a/src/test/java/com/r7b7/service/OpenAIServiceTest.java
+++ b/src/test/java/com/r7b7/service/OpenAIServiceTest.java
@@ -8,6 +8,7 @@
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -18,20 +19,18 @@
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
-import com.r7b7.client.OpenAIClient;
-import com.r7b7.client.factory.OpenAIClientFactory;
+import com.r7b7.client.IOpenAIClient;
+import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
-import com.r7b7.entity.Param;
import com.r7b7.entity.Role;
import com.r7b7.model.BaseLLMRequest;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.model.ILLMRequest;
public class OpenAIServiceTest {
@Mock
- private OpenAIClient mockClient;
+ private IOpenAIClient mockClient;
@InjectMocks
private OpenAIService openAIService;
@@ -47,18 +46,18 @@ public void setUp() {
@Test
public void testGenerateResponse_Success() {
- try (MockedStatic mockedStatic = mockStatic(OpenAIClientFactory.class)) {
- LLMRequest request = createMockLLMRequest("Test prompt");
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ ILLMRequest request = createMockLLMRequest("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> OpenAIClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient);
- LLMResponse response = openAIService.generateResponse(request);
+ CompletionResponse response = openAIService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("openai", metadata.get("provider"));
@@ -68,19 +67,19 @@ public void testGenerateResponse_Success() {
@Test
public void testGenerateResponse_WithParams_Success() {
- try (MockedStatic mockedStatic = mockStatic(OpenAIClientFactory.class)) {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
- LLMRequest request = createMockLLMRequestWithParams("Test prompt");
+ ILLMRequest request = createMockLLMRequestWithParams("Test prompt");
CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
- mockedStatic.when(() -> OpenAIClientFactory.getClient()).thenReturn(mockClient);
+ mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient);
- LLMResponse response = openAIService.generateResponse(request);
+ CompletionResponse response = openAIService.generateResponse(request);
assertNotNull(response);
- assertEquals("test content", response.getContent().messages().get(0).content());
+ assertEquals("test content", response.messages().get(0).content());
- Map metadata = response.getMetadata();
+ Map metadata = response.metaData();
assertEquals(TEST_MODEL, metadata.get("model"));
assertEquals("openai", metadata.get("provider"));
@@ -88,10 +87,31 @@ public void testGenerateResponse_WithParams_Success() {
}
}
- private LLMRequest createMockLLMRequest(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- LLMRequest request = new BaseLLMRequest(messages, null);
+ @Test
+ public void testGenerateResponseWithSingleQuery_Success() {
+ try (MockedStatic mockedStatic = mockStatic(LLMClientFactory.class)) {
+ CompletionResponse mockCompletionResponse = createMockCompletionResponse("Test response");
+ doReturn(mockCompletionResponse).when(mockClient).generateCompletion(any());
+ mockedStatic.when(() -> LLMClientFactory.getOpenAIClient()).thenReturn(mockClient);
+
+ CompletionResponse response = openAIService.generateResponse("single query");
+
+ assertNotNull(response);
+ assertEquals("test content", response.messages().get(0).content());
+
+ Map metadata = response.metaData();
+ assertEquals(TEST_MODEL, metadata.get("model"));
+ assertEquals("openai", metadata.get("provider"));
+
+ verify(mockClient).generateCompletion(any(CompletionRequest.class));
+ }
+ }
+
+ private ILLMRequest createMockLLMRequest(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+ ILLMRequest request = new BaseLLMRequest(messages, null);
return request;
}
@@ -99,18 +119,24 @@ private CompletionResponse createMockCompletionResponse(String content) {
List messages = new ArrayList<>();
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content");
messages.add(msg);
- CompletionResponse response = new CompletionResponse(messages, null, null);
+ Map metaData = new HashMap<>();
+ metaData.put("model", TEST_MODEL);
+ metaData.put("provider", "openai");
+ CompletionResponse response = new CompletionResponse(messages, metaData, null);
return response;
}
- private LLMRequest createMockLLMRequestWithParams(String prompt) {
- List messages = List.of(new Message(Role.assistant, "You are a helpful assistant"),
- new Message(Role.user, prompt));
- Map params = Map.of(
- Param.temperature, 0.7,
- Param.max_token, 1000);
+ private ILLMRequest createMockLLMRequestWithParams(String prompt) {
+ List messages = new ArrayList<>();
+ messages.add(new Message(Role.system, "You are a helpful assistant"));
+ messages.add(new Message(Role.assistant, "You are a helpful assistant"));
+ messages.add(new Message(Role.user, prompt));
+
+ Map params = Map.of(
+ "temperature", 0.7,
+ "max_token", 1000);
- LLMRequest request = new BaseLLMRequest(messages, params);
+ ILLMRequest request = new BaseLLMRequest(messages, params);
return request;
}
}
diff --git a/src/test/java/com/r7b7/service/PromptEngineTest.java b/src/test/java/com/r7b7/service/PromptEngineTest.java
index 5b4cf96..f2b2bd7 100644
--- a/src/test/java/com/r7b7/service/PromptEngineTest.java
+++ b/src/test/java/com/r7b7/service/PromptEngineTest.java
@@ -2,7 +2,6 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -12,35 +11,65 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.mockito.Mockito;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
import com.r7b7.entity.CompletionResponse;
-import com.r7b7.model.LLMRequest;
-import com.r7b7.model.LLMResponse;
+import com.r7b7.entity.Message;
+import com.r7b7.entity.Role;
+import com.r7b7.model.ILLMRequest;
public class PromptEngineTest {
- private LLMService mockLlmService;
+ @Mock
+ private ILLMService mockLlmService;
+
+ @InjectMocks
private PromptEngine promptEngine;
@BeforeEach
void setUp() {
- mockLlmService = Mockito.mock(LLMService.class);
- promptEngine = new PromptEngine(mockLlmService);
+ MockitoAnnotations.openMocks(this);
+
+        // mocks are initialized by MockitoAnnotations.openMocks(this) above
}
@Test
- void testGetResponse() {
+ void testSendQuery_Text_Input() {
String inputQuery = "What is the weather today?";
String expectedContent = "test content";
- LLMResponse mockResponse = mock(LLMResponse.class);
- when(mockResponse.getContent()).thenReturn(createMockCompletionResponse("Test prompt"));
- when(mockLlmService.generateResponse(any(LLMRequest.class))).thenReturn(mockResponse);
+ promptEngine = new PromptEngine(mockLlmService);
+
+ when(mockLlmService.generateResponse(inputQuery))
+ .thenReturn(createMockCompletionResponse("Test prompt"));
+
+ CompletionResponse response = promptEngine.sendQuery(inputQuery);
+
+ assertEquals(expectedContent, response.messages().get(0).content());
+ verify(mockLlmService, times(1)).generateResponse(inputQuery);
+ }
+
+ @Test
+ void testSendQuery_Builder_Input() {
+ String expectedContent = "test content";
+
+ when(mockLlmService.generateResponse(any(ILLMRequest.class)))
+ .thenReturn(createMockCompletionResponse("Test prompt"));
+
+ PromptBuilder builder = new PromptBuilder()
+ .addMessage(new Message(Role.system, "Give output in consistent format"))
+ .addMessage(new Message(Role.user, "what's the stock symbol of ARCHER Aviation?"))
+ .addMessage(new Message(Role.assistant, "{\"company\":\"Archer\", \"symbol\":\"ACHR\"}"))
+ .addMessage(new Message(Role.user, "what's the stock symbol of Palantir technology?"))
+ .addParam("temperature", 0.7)
+ .addParam("max_tokens", 150);
+ promptEngine = builder.build(mockLlmService);
- CompletionResponse response = promptEngine.getResponse(inputQuery);
+ CompletionResponse response = promptEngine.sendQuery();
assertEquals(expectedContent, response.messages().get(0).content());
- verify(mockLlmService, times(1)).generateResponse(any(LLMRequest.class));
+ verify(mockLlmService, times(1)).generateResponse(any(ILLMRequest.class));
}
private CompletionResponse createMockCompletionResponse(String content) {