Skip to content

Commit

Permalink
Merge pull request #1 from r7b7/refactor-client-factory-branch
Browse files Browse the repository at this point in the history
Refactor client factory branch
  • Loading branch information
r7b7 authored Dec 5, 2024
2 parents 497cf04 + 7ecd351 commit f5cb4c6
Show file tree
Hide file tree
Showing 44 changed files with 884 additions and 613 deletions.
Binary file added .DS_Store
Binary file not shown.
16 changes: 11 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
<version>2.18.1</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.11.3</version>
<scope>test</scope>
</dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.11.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand All @@ -55,5 +55,11 @@
<version>3.0.0</version>
</plugin>
</plugins>
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
</build>
</project>
57 changes: 28 additions & 29 deletions src/main/java/com/r7b7/client/DefaultAnthropicClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,51 @@
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.r7b7.client.model.AnthropicResponse;
import com.r7b7.client.model.Message;
import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;

public class DefaultAnthropicClient implements AnthropicClient {
private String ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages";
private String ANTHROPIC_VERSION = "2023-06-01";
private Integer MAX_TOKENS = 1024;
public class DefaultAnthropicClient implements IAnthropicClient {
private String ANTHROPIC_API_URL;
private String ANTHROPIC_VERSION;

public DefaultAnthropicClient() {}
public DefaultAnthropicClient() {
try {
Properties properties = PropertyConfig.loadConfig();
ANTHROPIC_API_URL = properties.getProperty("hospai.anthropic.url");
ANTHROPIC_VERSION = properties.getProperty("hospai.anthropic.version");
} catch (Exception ex) {
throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY");
}
}

@Override
public CompletionResponse generateCompletion(CompletionRequest request){
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
ArrayNode arrayNode = objectMapper.valueToTree(request.messages());
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("model", request.model());
requestBody.set("messages", arrayNode);

if (null != request.params()) {
for (Map.Entry<String, Object> entry : request.params().entrySet()) {
Object value = entry.getValue();
if (value instanceof String) {
requestBody.put(entry.getKey(), (String) value);
} else if (value instanceof Integer) {
requestBody.put(entry.getKey(), (Integer) value);
}
}
} else {
requestBody.put("max_tokens", MAX_TOKENS);
}
String jsonRequest = objectMapper.writeValueAsString(request.requestBody());

HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.ANTHROPIC_API_URL))
.header("Content-Type", "application/json")
.header("x-api-key", request.apiKey())
.header("anthropic-version", ANTHROPIC_VERSION)
.POST(HttpRequest.BodyPublishers.ofString(requestBody.toString()))
.POST(HttpRequest.BodyPublishers.ofString(jsonRequest))
.build();

HttpResponse<String> response = HttpClient.newHttpClient().send(httpRequest,
HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
return new CompletionResponse(null, response, new ErrorResponse(
"Request sent to LLM failed with status code " + response, null));
return new CompletionResponse(null, null, new ErrorResponse(
"Request sent to LLM failed: " + response.statusCode() + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
Expand All @@ -71,13 +62,21 @@ private CompletionResponse extractResponseText(String responseBody) {
List<Message> msgs = null;
AnthropicResponse response = null;
ErrorResponse error = null;
Map<String, Object> metadata = null;

try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, AnthropicResponse.class);
msgs = response.content().stream().map(content -> new Message(content.type(), content.text())).toList();
metadata = Map.of(
"id", response.id(),
"model", response.model(),
"provider", "Anthropic",
"input_tokens", response.usage().inputTokens(),
"output_tokens", response.usage().outputTokens());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
return new CompletionResponse(msgs, response, error);
return new CompletionResponse(msgs, metadata, error);
}
}
50 changes: 25 additions & 25 deletions src/main/java/com/r7b7/client/DefaultGroqClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,47 @@
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.r7b7.client.model.Message;
import com.r7b7.client.model.OpenAIResponse;
import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;

public class DefaultGroqClient implements GroqClient {
private String GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions";
public class DefaultGroqClient implements IGroqClient {
private String GROQ_API_URL;

public DefaultGroqClient(){}
public DefaultGroqClient() {
try {
Properties properties = PropertyConfig.loadConfig();
GROQ_API_URL = properties.getProperty("hospai.groq.url");
} catch (Exception ex) {
throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY");
}
}

@Override
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();
ArrayNode arrayNode = objectMapper.valueToTree(request.messages());
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("model", request.model());
requestBody.set("messages", arrayNode);
String jsonRequest = objectMapper.writeValueAsString(request.requestBody());

if (null != request.params()) {
for (Map.Entry<String, Object> entry : request.params().entrySet()) {
Object value = entry.getValue();
if (value instanceof String) {
requestBody.put(entry.getKey(), (String) value);
} else if (value instanceof Integer) {
requestBody.put(entry.getKey(), (Integer) value);
}
}
}

HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.GROQ_API_URL))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + request.apiKey())
.POST(HttpRequest.BodyPublishers.ofString(requestBody.toString()))
.POST(HttpRequest.BodyPublishers.ofString(jsonRequest))
.build();
HttpResponse<String> response = HttpClient.newHttpClient().send(httpRequest,
HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
return new CompletionResponse(null, response, new ErrorResponse(
"Request sent to LLM failed with status code " + response.statusCode(), null));
return new CompletionResponse(null, null, new ErrorResponse(
"Request sent to LLM failed: " + response.statusCode() + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
Expand All @@ -64,16 +57,23 @@ private CompletionResponse extractResponseText(String responseBody) {
List<Message> msgs = null;
OpenAIResponse response = null;
ErrorResponse error = null;
Map<String, Object> metadata = null;

try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OpenAIResponse.class);
msgs = response.choices().stream()
.map(choice -> new Message(choice.message().role(), choice.message().content())).toList();
metadata = Map.of(
"id", response.id(),
"model", response.model(),
"provider", "Groq",
"prompt_tokens", response.usage().promptTokens(),
"completion_tokens", response.usage().completionTokens(),
"total_tokens", response.usage().totalTokens());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
return new CompletionResponse(msgs, response, error);
return new CompletionResponse(msgs, metadata, error);
}

}
43 changes: 21 additions & 22 deletions src/main/java/com/r7b7/client/DefaultOllamaClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,35 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.r7b7.client.model.Message;
import com.r7b7.client.model.OllamaResponse;
import com.r7b7.config.PropertyConfig;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.ErrorResponse;

public class DefaultOllamaClient implements OllamaClient {
private String OLLAMA_API_URL = "http://localhost:11434/api/chat";
public class DefaultOllamaClient implements IOllamaClient {
private String OLLAMA_API_URL;

public DefaultOllamaClient() {
try {
Properties properties = PropertyConfig.loadConfig();
OLLAMA_API_URL = properties.getProperty("hospai.ollama.url");
} catch (Exception ex) {
throw new IllegalStateException("Critical configuration missing: CRITICAL_PROPERTY");
}
}

@Override
public CompletionResponse generateCompletion(CompletionRequest request) {
try {
ObjectMapper objectMapper = new ObjectMapper();

Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", request.model());
requestMap.put("messages", request.messages());
requestMap.put("stream", false);

Map<String, Object> optionsMap = new HashMap<>();
if (null != request.params() && request.params().get("temperature") != null) {
optionsMap.put("temperature", request.params().get("temperature"));
}
if (null != request.params() && request.params().get("seed") != null) {
optionsMap.put("seed", request.params().get("seed"));
}

requestMap.put("options", optionsMap);
String jsonRequest = objectMapper.writeValueAsString(requestMap);
String jsonRequest = objectMapper.writeValueAsString(request.requestBody());

HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(this.OLLAMA_API_URL))
Expand All @@ -53,8 +45,8 @@ public CompletionResponse generateCompletion(CompletionRequest request) {
if (response.statusCode() == 200) {
return extractResponseText(response.body());
} else {
return new CompletionResponse(null, response, new ErrorResponse(
"Request sent to LLM failed with status code " + response.statusCode(), null));
return new CompletionResponse(null, null, new ErrorResponse(
"Request sent to LLM failed: " + response.statusCode() + response.body(), null));
}
} catch (Exception ex) {
return new CompletionResponse(null, null, new ErrorResponse("Request processing failed", ex));
Expand All @@ -65,13 +57,20 @@ private CompletionResponse extractResponseText(String responseBody) {
List<Message> msgs = null;
OllamaResponse response = null;
ErrorResponse error = null;
Map<String, Object> metadata = null;

try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OllamaResponse.class);
msgs = List.of(response.message());
metadata = Map.of(
"model", response.model(),
"provider", "Ollama",
"total_duration", response.total_duration(),
"eval_duration", response.eval_duration());
} catch (Exception ex) {
error = new ErrorResponse("Exception occurred in extracting response", ex);
}
return new CompletionResponse(msgs, response, error);
return new CompletionResponse(msgs, metadata, error);
}
}
Loading

0 comments on commit f5cb4c6

Please sign in to comment.