diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java
new file mode 100644
index 000000000..447e19e0f
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java
@@ -0,0 +1,217 @@
+package com.edgechain.lib.endpoint.impl.llm;
+
+import com.edgechain.lib.endpoint.Endpoint;
+import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
+import com.edgechain.lib.openai.response.ChatCompletionResponse;
+import com.edgechain.lib.request.ArkRequest;
+import com.edgechain.lib.retrofit.Llama2Service;
+import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
+import com.edgechain.lib.rxjava.retry.RetryPolicy;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.reactivex.rxjava3.core.Observable;
+import org.json.JSONObject;
+import org.modelmapper.ModelMapper;
+import retrofit2.Retrofit;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Endpoint describing a Llama 2 chat-completion call: target URL, retry policy and the
+ * generation settings (temperature, top-k, top-p, ...).
+ *
+ * <p>The same object is used as the request body of {@link Llama2Service#chatCompletion}; the
+ * {@code @JsonProperty} names are the keys Jackson reads and writes for those fields.
+ */
+public class Llama2Endpoint extends Endpoint {
+    private final Retrofit retrofit = RetrofitClientInstance.getInstance();
+    private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);
+
+    private final ModelMapper modelMapper = new ModelMapper();
+
+    private String inputs;
+    private JSONObject parameters;
+    private Double temperature;
+
+    @JsonProperty("top_k")
+    private Integer topK;
+
+    @JsonProperty("top_p")
+    private Double topP;
+
+    @JsonProperty("do_sample")
+    private Boolean doSample;
+
+    @JsonProperty("max_new_tokens")
+    private Integer maxNewTokens;
+
+    @JsonProperty("repetition_penalty")
+    private Double repetitionPenalty;
+
+    private List<String> stop;
+    private String chainName;
+    private String callIdentifier;
+
+    /** No-argument constructor; leaves every field declared in this class unset. */
+    public Llama2Endpoint() {
+    }
+
+    /**
+     * Creates an endpoint with explicit generation settings.
+     *
+     * @param url URL passed to the {@link Endpoint} base class
+     * @param retryPolicy retry policy passed to the {@link Endpoint} base class
+     * @param temperature sampling temperature
+     * @param topK value stored for {@code top_k}
+     * @param topP value stored for {@code top_p}
+     * @param doSample value stored for {@code do_sample}
+     * @param maxNewTokens value stored for {@code max_new_tokens}
+     * @param repetitionPenalty value stored for {@code repetition_penalty}
+     * @param stop stop sequences; may be {@code null}
+     */
+    public Llama2Endpoint(String url, RetryPolicy retryPolicy,
+                          Double temperature, Integer topK, Double topP,
+                          Boolean doSample, Integer maxNewTokens, Double repetitionPenalty,
+                          List<String> stop) {
+        super(url, retryPolicy);
+        this.temperature = temperature;
+        this.topK = topK;
+        this.topP = topP;
+        this.doSample = doSample;
+        this.maxNewTokens = maxNewTokens;
+        this.repetitionPenalty = repetitionPenalty;
+        this.stop = stop;
+    }
+
+    /**
+     * Creates an endpoint with a temperature of 0.7 and a max-new-tokens value of 512; the other
+     * generation settings stay unset.
+     *
+     * @param url URL passed to the {@link Endpoint} base class
+     * @param retryPolicy retry policy passed to the {@link Endpoint} base class
+     */
+    public Llama2Endpoint(String url, RetryPolicy retryPolicy) {
+        super(url, retryPolicy);
+        this.temperature = 0.7;
+        this.maxNewTokens = 512;
+    }
+
+    public String getInputs() {
+        return inputs;
+    }
+
+    public void setInputs(String inputs) {
+        this.inputs = inputs;
+    }
+
+    public JSONObject getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(JSONObject parameters) {
+        this.parameters = parameters;
+    }
+
+    public Double getTemperature() {
+        return temperature;
+    }
+
+    public void setTemperature(Double temperature) {
+        this.temperature = temperature;
+    }
+
+    public Integer getTopK() {
+        return topK;
+    }
+
+    public void setTopK(Integer topK) {
+        this.topK = topK;
+    }
+
+    public Double getTopP() {
+        return topP;
+    }
+
+    public void setTopP(Double topP) {
+        this.topP = topP;
+    }
+
+    public Boolean getDoSample() {
+        return doSample;
+    }
+
+    public void setDoSample(Boolean doSample) {
+        this.doSample = doSample;
+    }
+
+    public Integer getMaxNewTokens() {
+        return maxNewTokens;
+    }
+
+    public void setMaxNewTokens(Integer maxNewTokens) {
+        this.maxNewTokens = maxNewTokens;
+    }
+
+    public Double getRepetitionPenalty() {
+        return repetitionPenalty;
+    }
+
+    public void setRepetitionPenalty(Double repetitionPenalty) {
+        this.repetitionPenalty = repetitionPenalty;
+    }
+
+    public List<String> getStop() {
+        return stop;
+    }
+
+    public void setStop(List<String> stop) {
+        this.stop = stop;
+    }
+
+    public String getChainName() {
+        return chainName;
+    }
+
+    public void setChainName(String chainName) {
+        this.chainName = chainName;
+    }
+
+    public String getCallIdentifier() {
+        return callIdentifier;
+    }
+
+    public void setCallIdentifier(String callIdentifier) {
+        this.callIdentifier = callIdentifier;
+    }
+
+    /**
+     * Runs a chat completion for {@code inputs} using this endpoint's settings.
+     *
+     * <p>This endpoint is first mapped onto a fresh {@code Llama2Endpoint}, so the per-call
+     * values are set on that copy rather than on this instance.
+     *
+     * @param inputs prompt text to send
+     * @param chainName name stored on the copy as {@code chainName}
+     * @param arkRequest current request, used for the call identifier; may be {@code null}
+     * @return an observable wrapping the service call
+     */
+    public Observable<List<Llama2ChatCompletionResponse>> chatCompletion(
+            String inputs, String chainName, ArkRequest arkRequest) {
+        Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class);
+        mapper.setInputs(inputs);
+        mapper.setChainName(chainName);
+        return chatCompletion(mapper, arkRequest);
+    }
+
+    /** Stamps the call identifier on {@code mapper} and hands it to the Retrofit service. */
+    private Observable<List<Llama2ChatCompletionResponse>> chatCompletion(
+            Llama2Endpoint mapper, ArkRequest arkRequest) {
+        if (Objects.nonNull(arkRequest)) {
+            mapper.setCallIdentifier(arkRequest.getRequestURI());
+        } else {
+            mapper.setCallIdentifier("URI wasn't provided");
+        }
+
+        return Observable.fromSingle(this.llama2Service.chatCompletion(mapper));
+    }
+}
diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java
new file mode 100644
index 000000000..c8a73edf6
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java
@@ -0,0 +1,82 @@
+package com.edgechain.lib.llama2;
+
+
+import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
+import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
+import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
+import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
+import com.edgechain.lib.utils.JsonUtils;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.reactivex.rxjava3.core.Observable;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.*;
+import org.springframework.stereotype.Service;
+import org.springframework.web.client.RestTemplate;
+
+import java.util.List;
+import java.util.Objects;
+
+/** Client that posts Llama 2 chat-completion requests to an endpoint URL over HTTP. */
+@Service
+public class Llama2Client {
+    @Autowired
+    private ObjectMapper objectMapper;
+
+    private final Logger logger = LoggerFactory.getLogger(getClass());
+    private final RestTemplate restTemplate = new RestTemplate();
+
+    /**
+     * Builds a chain that POSTs {@code request} as JSON to {@code endpoint.getUrl()} and parses
+     * the reply into a list of {@link Llama2ChatCompletionResponse}.
+     *
+     * <p>The HTTP call runs inside the {@code Observable.create} callback. Any exception,
+     * including the one raised for a missing reply body, is passed to {@code emitter.onError}.
+     *
+     * @param request body to send
+     * @param endpoint endpoint supplying the target URL; also handed to the returned chain
+     * @return a chain that emits the parsed response list once and then completes
+     */
+    public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion(
+            Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) {
+        return new EdgeChain<>(
+                Observable.create(
+                        emitter -> {
+                            try {
+                                logger.info("Logging ChatCompletion....");
+                                logger.info("==============REQUEST DATA================");
+                                logger.info(request.toString());
+
+                                HttpHeaders headers = new HttpHeaders();
+                                headers.setContentType(MediaType.APPLICATION_JSON);
+                                // NOTE(review): the body is serialized by RestTemplate's converter and
+                                // request.getParameters() is an org.json.JSONObject - confirm it is
+                                // written as the intended JSON object on the wire.
+                                HttpEntity<Llama2ChatCompletionRequest> entity =
+                                        new HttpEntity<>(request, headers);
+
+                                String response =
+                                        restTemplate.postForObject(endpoint.getUrl(), entity, String.class);
+                                if (response == null) {
+                                    // postForObject may return null; fail with a clear message
+                                    // instead of letting the JSON parser reject a null argument.
+                                    throw new IllegalStateException(
+                                            "Llama2 endpoint returned an empty response body");
+                                }
+
+                                List<Llama2ChatCompletionResponse> chatCompletionResponse =
+                                        objectMapper.readValue(
+                                                response,
+                                                new TypeReference<List<Llama2ChatCompletionResponse>>() {});
+                                emitter.onNext(chatCompletionResponse);
+                                emitter.onComplete();
+                            } catch (final Exception e) {
+                                emitter.onError(e);
+                            }
+                        }),
+                endpoint);
+    }
+}
diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java
new file mode 100644
index 000000000..7dc22a280
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java
@@ -0,0 +1,82 @@
+package com.edgechain.lib.llama2.request;
+
+import com.edgechain.lib.openai.request.ChatCompletionRequest;
+import com.edgechain.lib.openai.request.ChatMessage;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.json.JSONObject;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.StringJoiner;
+
+/** Request body for a Llama 2 chat completion: the prompt text plus generation parameters. */
+public class Llama2ChatCompletionRequest {
+
+    private String inputs;
+    private JSONObject parameters;
+
+    public Llama2ChatCompletionRequest() {
+    }
+
+    public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) {
+        this.inputs = inputs;
+        this.parameters = parameters;
+    }
+
+    /**
+     * Returns a log-friendly rendering such as
+     * {@code Llama2ChatCompletionRequest[{"inputs":hi, "parameters":{"top_k":5}}]}.
+     */
+    @Override
+    public String toString() {
+        // The colon sits outside the quoted key so each entry reads as "key":value.
+        return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]")
+                .add("\"inputs\":" + inputs)
+                .add("\"parameters\":" + parameters)
+                .toString();
+    }
+
+    public static Llama2ChatCompletionRequestBuilder builder() {
+        return new Llama2ChatCompletionRequestBuilder();
+    }
+
+    public String getInputs() {
+        return inputs;
+    }
+
+    public void setInputs(String inputs) {
+        this.inputs = inputs;
+    }
+
+    public JSONObject getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(JSONObject parameters) {
+        this.parameters = parameters;
+    }
+
+    /** Fluent builder for {@link Llama2ChatCompletionRequest}. */
+    public static class Llama2ChatCompletionRequestBuilder {
+        private String inputs;
+        private JSONObject parameters;
+
+        private Llama2ChatCompletionRequestBuilder() {
+        }
+
+        public Llama2ChatCompletionRequestBuilder inputs(String inputs) {
+            this.inputs = inputs;
+            return this;
+        }
+
+        public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) {
+            this.parameters = parameters;
+            return this;
+        }
+
+        public Llama2ChatCompletionRequest build() {
+            return new Llama2ChatCompletionRequest(inputs, parameters);
+        }
+    }
+}
\ No newline at end of file
diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java
new file mode 100644
index 000000000..a563a37c5
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java
@@ -0,0 +1,19 @@
+package com.edgechain.lib.llama2.response;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/** One element of a Llama 2 chat-completion reply; carries the {@code generated_text} value. */
+public class Llama2ChatCompletionResponse {
+    @JsonProperty("generated_text")
+    private String generatedText;
+
+    public Llama2ChatCompletionResponse() {}
+
+    public String getGeneratedText() {
+        return generatedText;
+    }
+
+    public void setGeneratedText(String generatedText) {
+        this.generatedText = generatedText;
+    }
+}
diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java
new file mode 100644
index 000000000..bd66511d7
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java
@@ -0,0 +1,16 @@
+package com.edgechain.lib.retrofit;
+
+import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
+import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
+import io.reactivex.rxjava3.core.Single;
+import retrofit2.http.Body;
+import retrofit2.http.POST;
+
+import java.util.List;
+
+/** Retrofit binding for the Llama 2 chat-completion route. */
+public interface Llama2Service {
+    /** POSTs {@code llama2Endpoint} as the request body to {@code llama2/chat-completion}. */
+    @POST(value = "llama2/chat-completion")
+    Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint);
+}
diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java
new file mode 100644
index 000000000..05b744344
--- /dev/null
+++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java
@@ -0,0 +1,68 @@
+package com.edgechain.service.controllers.llama2;
+
+import com.edgechain.lib.configuration.WebConfiguration;
+import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
+import com.edgechain.lib.llama2.Llama2Client;
+import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
+import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
+import com.edgechain.lib.logger.services.ChatCompletionLogService;
+import com.edgechain.lib.logger.services.JsonnetLogService;
+import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
+import io.reactivex.rxjava3.core.Single;
+import org.json.JSONObject;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.core.env.Environment;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.util.Collections;
+import java.util.List;
+
+/** REST controller exposing the Llama 2 chat-completion route. */
+@RestController("Service Llama2Controller")
+@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama2")
+public class Llama2Controller {
+    @Autowired
+    private ChatCompletionLogService chatCompletionLogService;
+
+    @Autowired private JsonnetLogService jsonnetLogService;
+
+    @Autowired private Environment env;
+    @Autowired private Llama2Client llama2Client;
+
+    /**
+     * Turns the posted endpoint settings into a {@link Llama2ChatCompletionRequest} and runs it
+     * through {@link Llama2Client}.
+     *
+     * @param llama2Endpoint endpoint carrying the prompt, target URL and generation settings
+     * @return the client's response list as a {@link Single}
+     */
+    @PostMapping(value = "/chat-completion")
+    public Single<List<Llama2ChatCompletionResponse>> chatCompletion(
+            @RequestBody Llama2Endpoint llama2Endpoint) {
+        // org.json's put(key, null) removes the key, so unset settings are left out.
+        JSONObject parameters = new JSONObject();
+        parameters.put("do_sample", llama2Endpoint.getDoSample());
+        parameters.put("top_p", llama2Endpoint.getTopP());
+        parameters.put("temperature", llama2Endpoint.getTemperature());
+        parameters.put("top_k", llama2Endpoint.getTopK());
+        parameters.put("max_new_tokens", llama2Endpoint.getMaxNewTokens());
+        parameters.put("repetition_penalty", llama2Endpoint.getRepetitionPenalty());
+        parameters.put(
+                "stop",
+                llama2Endpoint.getStop() != null ? llama2Endpoint.getStop() : Collections.emptyList());
+
+        Llama2ChatCompletionRequest llama2ChatCompletionRequest =
+                Llama2ChatCompletionRequest.builder()
+                        .inputs(llama2Endpoint.getInputs())
+                        .parameters(parameters)
+                        .build();
+
+        EdgeChain<List<Llama2ChatCompletionResponse>> edgeChain =
+                llama2Client.createChatCompletion(llama2ChatCompletionRequest, llama2Endpoint);
+
+        return edgeChain.toSingle();
+    }
+}