diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java index 447e19e0f..96ba9912d 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java @@ -2,7 +2,6 @@ import com.edgechain.lib.endpoint.Endpoint; import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; -import com.edgechain.lib.openai.response.ChatCompletionResponse; import com.edgechain.lib.request.ArkRequest; import com.edgechain.lib.retrofit.Llama2Service; import com.edgechain.lib.retrofit.client.RetrofitClientInstance; @@ -17,154 +16,165 @@ import java.util.Objects; public class Llama2Endpoint extends Endpoint { - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); - - private final ModelMapper modelMapper = new ModelMapper(); - - private String inputs; - private JSONObject parameters; - private Double temperature; - @JsonProperty("top_k") - private Integer topK; - @JsonProperty("top_p") - private Double topP; - - @JsonProperty("do_sample") - private Boolean doSample; - @JsonProperty("max_new_tokens") - private Integer maxNewTokens; - @JsonProperty("repetition_penalty") - private Double repetitionPenalty; - private List stop; - private String chainName; - private String callIdentifier; - - public Llama2Endpoint() { - } - - public Llama2Endpoint(String url, RetryPolicy retryPolicy, - Double temperature, Integer topK, Double topP, - Boolean doSample, Integer maxNewTokens, Double repetitionPenalty, - List stop) { - super(url, retryPolicy); - this.temperature = temperature; - this.topK = topK; - this.topP = topP; - this.doSample = doSample; - this.maxNewTokens = maxNewTokens; - this.repetitionPenalty = repetitionPenalty; - this.stop = stop; - } - - public Llama2Endpoint(String url, RetryPolicy retryPolicy) { - super(url, retryPolicy); - this.temperature = 0.7; - this.maxNewTokens = 512; - } - - public String getInputs() { - return inputs; - } - - public void setInputs(String inputs) { - this.inputs = inputs; - } - - public JSONObject getParameters() { - return parameters; - } - - public void setParameters(JSONObject parameters) { - this.parameters = parameters; - } - - public Double getTemperature() { - return temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - public Integer getTopK() { - return topK; - } - - public void setTopK(Integer topK) { - this.topK = topK; - } - - public Double getTopP() { - return topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public Boolean getDoSample() { - return doSample; - } - - public void setDoSample(Boolean doSample) { - this.doSample = doSample; - } - - public Integer getMaxNewTokens() { - return maxNewTokens; - } - - public void setMaxNewTokens(Integer maxNewTokens) { - this.maxNewTokens = maxNewTokens; - } - - public Double getRepetitionPenalty() { - return repetitionPenalty; - } - - public void setRepetitionPenalty(Double repetitionPenalty) { - this.repetitionPenalty = repetitionPenalty; - } - - public List getStop() { - return stop; - } - - public void setStop(List stop) { - this.stop = stop; - } - - public String getChainName() { - return chainName; - } - - public void setChainName(String chainName) { - this.chainName = chainName; - } - - public String getCallIdentifier() { - return callIdentifier; - } - - public void setCallIdentifier(String callIdentifier) { - this.callIdentifier = callIdentifier; - } - - public Observable> chatCompletion( - String inputs,String chainName, ArkRequest arkRequest) { + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); + + private final ModelMapper modelMapper = new ModelMapper(); + + private String inputs; + private JSONObject parameters; + private Double temperature; + + @JsonProperty("top_k") + private Integer topK; + + @JsonProperty("top_p") + private Double topP; + + @JsonProperty("do_sample") + private Boolean doSample; + + @JsonProperty("max_new_tokens") + private Integer maxNewTokens; + + @JsonProperty("repetition_penalty") + private Double repetitionPenalty; + + private List stop; + private String chainName; + private String callIdentifier; + + public Llama2Endpoint() {} + + public Llama2Endpoint( + String url, + RetryPolicy retryPolicy, + Double temperature, + Integer topK, + Double topP, + Boolean doSample, + Integer maxNewTokens, + Double repetitionPenalty, + List stop) { + super(url, retryPolicy); + this.temperature = temperature; + this.topK = topK; + this.topP = topP; + this.doSample = doSample; + this.maxNewTokens = maxNewTokens; + this.repetitionPenalty = repetitionPenalty; + this.stop = stop; + } + + public Llama2Endpoint(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + this.temperature = 0.7; + this.maxNewTokens = 512; + } + + public String getInputs() { + return inputs; + } + + public void setInputs(String inputs) { + this.inputs = inputs; + } + + public JSONObject getParameters() { + return parameters; + } + + public void setParameters(JSONObject parameters) { + this.parameters = parameters; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Boolean getDoSample() { + return doSample; + } + + public void setDoSample(Boolean doSample) { + this.doSample = doSample; + } + + public Integer getMaxNewTokens() { + return maxNewTokens; + } + + public void setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + } + + public Double getRepetitionPenalty() { + return repetitionPenalty; + } + + public void setRepetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public String getChainName() { + return chainName; + } + + public void setChainName(String chainName) { + this.chainName = chainName; + } + + public String getCallIdentifier() { + return callIdentifier; + } + + public void setCallIdentifier(String callIdentifier) { + this.callIdentifier = callIdentifier; + } + + public Observable> chatCompletion( + String inputs, String chainName, ArkRequest arkRequest) { - Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class); - mapper.setInputs(inputs); - mapper.setChainName(chainName); - return chatCompletion(mapper, arkRequest); - } + Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class); + mapper.setInputs(inputs); + mapper.setChainName(chainName); + return chatCompletion(mapper, arkRequest); + } - private Observable> chatCompletion(Llama2Endpoint mapper, ArkRequest arkRequest) { + private Observable> chatCompletion( + Llama2Endpoint mapper, ArkRequest arkRequest) { - if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); - else mapper.setCallIdentifier("URI wasn't provided"); + if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); + else mapper.setCallIdentifier("URI wasn't provided"); - return Observable.fromSingle(this.llama2Service.chatCompletion(mapper)); - } + return Observable.fromSingle(this.llama2Service.chatCompletion(mapper)); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java index c8a73edf6..e1a389199 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java @@ -1,15 +1,12 @@ package com.edgechain.lib.llama2; - import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; -import com.edgechain.lib.utils.JsonUtils; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.reactivex.rxjava3.core.Observable; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -18,50 +15,50 @@ import org.springframework.web.client.RestTemplate; import java.util.List; -import java.util.Objects; @Service public class Llama2Client { - @Autowired - private ObjectMapper objectMapper; - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final RestTemplate restTemplate = new RestTemplate(); - public EdgeChain> createChatCompletion( - Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - - logger.info("Logging ChatCompletion...."); - - logger.info("==============REQUEST DATA================"); - logger.info(request.toString()); + @Autowired private ObjectMapper objectMapper; + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); -// Llama2ChatCompletionRequest llamaRequest = new Llama2ChatCompletionRequest(); -// -// llamaRequest.setInputs(request.getInputs()); -// llamaRequest.setParameters(request.getParameters()); + public EdgeChain> createChatCompletion( + Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + logger.info("Logging ChatCompletion...."); + logger.info("==============REQUEST DATA================"); + logger.info(request.toString()); - // Create headers - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - HttpEntity entity = new HttpEntity<>(request, headers); -// - String response = restTemplate.postForObject(endpoint.getUrl(), entity, String.class); + // Llama2ChatCompletionRequest llamaRequest = new + // Llama2ChatCompletionRequest(); + // + // llamaRequest.setInputs(request.getInputs()); + // + // llamaRequest.setParameters(request.getParameters()); - List chatCompletionResponse = - objectMapper.readValue(response, new TypeReference>() {}); - emitter.onNext(chatCompletionResponse); - emitter.onComplete(); + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(request, headers); + // + String response = + restTemplate.postForObject(endpoint.getUrl(), entity, String.class); - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } + List chatCompletionResponse = + objectMapper.readValue( + response, new TypeReference>() {}); + emitter.onNext(chatCompletionResponse); + emitter.onComplete(); + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java index 7dc22a280..17f190e93 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java @@ -1,80 +1,67 @@ package com.edgechain.lib.llama2.request; -import com.edgechain.lib.openai.request.ChatCompletionRequest; -import com.edgechain.lib.openai.request.ChatMessage; -import com.fasterxml.jackson.annotation.JsonProperty; import org.json.JSONObject; -import java.util.Collections; -import java.util.List; -import java.util.Map; import java.util.StringJoiner; public class Llama2ChatCompletionRequest { - private String inputs; - private JSONObject parameters; + private String inputs; + private JSONObject parameters; - public Llama2ChatCompletionRequest() { - } + public Llama2ChatCompletionRequest() {} - public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) { - this.inputs = inputs; - this.parameters = parameters; - } + public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) { + this.inputs = inputs; + this.parameters = parameters; + } - @Override - public String toString() { - return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]") - .add("\"inputs:\"" + inputs) - .add("\"parameters:\"" + parameters) - .toString(); - } + @Override + public String toString() { + return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]") + .add("\"inputs:\"" + inputs) + .add("\"parameters:\"" + parameters) + .toString(); + } - public static Llama2ChatCompletionRequestBuilder builder() { - return new Llama2ChatCompletionRequestBuilder(); - } + public static Llama2ChatCompletionRequestBuilder builder() { + return new Llama2ChatCompletionRequestBuilder(); + } + public String getInputs() { + return inputs; + } - public String getInputs() { - return inputs; - } + public void setInputs(String inputs) { + this.inputs = inputs; + } - public void setInputs(String inputs) { - this.inputs = inputs; - } + public JSONObject getParameters() { + return parameters; + } - public JSONObject getParameters() { - return parameters; - } + public void setParameters(JSONObject parameters) { + this.parameters = parameters; + } - public void setParameters(JSONObject parameters) { - this.parameters = parameters; - } - - - public static class Llama2ChatCompletionRequestBuilder { - private String inputs; - private JSONObject parameters; - - private Llama2ChatCompletionRequestBuilder() { - } + public static class Llama2ChatCompletionRequestBuilder { + private String inputs; + private JSONObject parameters; + private Llama2ChatCompletionRequestBuilder() {} - public Llama2ChatCompletionRequestBuilder inputs(String inputs) { - this.inputs = inputs; - return this; - } + public Llama2ChatCompletionRequestBuilder inputs(String inputs) { + this.inputs = inputs; + return this; + } - public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) { - this.parameters = parameters; - return this; - } + public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) { + this.parameters = parameters; + return this; + } - public Llama2ChatCompletionRequest build() { - return new Llama2ChatCompletionRequest( - inputs, - parameters); - } + public Llama2ChatCompletionRequest build() { + return new Llama2ChatCompletionRequest(inputs, parameters); } -} \ No newline at end of file + } +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java index a563a37c5..033d78286 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java @@ -3,16 +3,16 @@ import com.fasterxml.jackson.annotation.JsonProperty; public class Llama2ChatCompletionResponse { - @JsonProperty("generated_text") - private String generatedText; + @JsonProperty("generated_text") + private String generatedText; - public Llama2ChatCompletionResponse() {} + public Llama2ChatCompletionResponse() {} - public String getGeneratedText() { - return generatedText; - } + public String getGeneratedText() { + return generatedText; + } - public void setGeneratedText(String generatedText) { - this.generatedText = generatedText; - } + public void setGeneratedText(String generatedText) { + this.generatedText = generatedText; + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java index bd66511d7..988a409ff 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java @@ -9,6 +9,6 @@ import java.util.List; public interface Llama2Service { - @POST(value = "llama2/chat-completion") - Single> chatCompletion(@Body Llama2Endpoint llama2Endpoint); + @POST(value = "llama2/chat-completion") + Single> chatCompletion(@Body Llama2Endpoint llama2Endpoint); } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java index 05b744344..f4cee2566 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java @@ -23,36 +23,41 @@ @RestController("Service Llama2Controller") @RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama2") public class Llama2Controller { - @Autowired - private ChatCompletionLogService chatCompletionLogService; + @Autowired private ChatCompletionLogService chatCompletionLogService; - @Autowired private JsonnetLogService jsonnetLogService; + @Autowired private JsonnetLogService jsonnetLogService; - @Autowired private Environment env; - @Autowired private Llama2Client llama2Client; + @Autowired private Environment env; + @Autowired private Llama2Client llama2Client; - @PostMapping(value = "/chat-completion") - public Single> chatCompletion(@RequestBody Llama2Endpoint llama2Endpoint) { + @PostMapping(value = "/chat-completion") + public Single> chatCompletion( + @RequestBody Llama2Endpoint llama2Endpoint) { - System.out.println("\nI'm in controller class\n"); + System.out.println("\nI'm in controller class\n"); - JSONObject parameters = new JSONObject(); - parameters.put("do_sample", llama2Endpoint.getDoSample()); - parameters.put("top_p", llama2Endpoint.getTopP()); - parameters.put("temperature", llama2Endpoint.getTemperature()); - parameters.put("top_k", llama2Endpoint.getTopK()); - parameters.put("max_new_tokens", llama2Endpoint.getMaxNewTokens()); - parameters.put("repetition_penalty", llama2Endpoint.getRepetitionPenalty()); - parameters.put("stop", llama2Endpoint.getStop() != null ? llama2Endpoint.getStop() : Collections.emptyList()); + JSONObject parameters = new JSONObject(); + parameters.put("do_sample", llama2Endpoint.getDoSample()); + parameters.put("top_p", llama2Endpoint.getTopP()); + parameters.put("temperature", llama2Endpoint.getTemperature()); + parameters.put("top_k", llama2Endpoint.getTopK()); + parameters.put("max_new_tokens", llama2Endpoint.getMaxNewTokens()); + parameters.put("repetition_penalty", llama2Endpoint.getRepetitionPenalty()); + parameters.put( + "stop", + llama2Endpoint.getStop() != null ? llama2Endpoint.getStop() : Collections.emptyList()); - System.out.println("\nI'm in controller class after json object\n"); + System.out.println("\nI'm in controller class after json object\n"); - Llama2ChatCompletionRequest llama2ChatCompletionRequest = - Llama2ChatCompletionRequest.builder().inputs(llama2Endpoint.getInputs()).parameters(parameters).build(); + Llama2ChatCompletionRequest llama2ChatCompletionRequest = + Llama2ChatCompletionRequest.builder() + .inputs(llama2Endpoint.getInputs()) + .parameters(parameters) + .build(); - EdgeChain> edgeChain = - llama2Client.createChatCompletion(llama2ChatCompletionRequest, llama2Endpoint); + EdgeChain> edgeChain = + llama2Client.createChatCompletion(llama2ChatCompletionRequest, llama2Endpoint); - return edgeChain.toSingle(); - } + return edgeChain.toSingle(); + } }