
Commit

Llama Implementation. (#248)
* Implemented Llama 2 support: added endpoint, client, completion request, response, and controller
hemantDwivedi authored Oct 6, 2023
1 parent f6f813d commit e24c63b
Showing 6 changed files with 407 additions and 0 deletions.
com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java
@@ -0,0 +1,170 @@
package com.edgechain.lib.endpoint.impl.llm;

import com.edgechain.lib.endpoint.Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.openai.response.ChatCompletionResponse;
import com.edgechain.lib.request.ArkRequest;
import com.edgechain.lib.retrofit.Llama2Service;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
import com.edgechain.lib.rxjava.retry.RetryPolicy;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.reactivex.rxjava3.core.Observable;
import org.json.JSONObject;
import org.modelmapper.ModelMapper;
import retrofit2.Retrofit;

import java.util.List;
import java.util.Objects;

public class Llama2Endpoint extends Endpoint {

  private final Retrofit retrofit = RetrofitClientInstance.getInstance();
  private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);

  private final ModelMapper modelMapper = new ModelMapper();

  private String inputs;
  private JSONObject parameters;
  private Double temperature;

  @JsonProperty("top_k")
  private Integer topK;

  @JsonProperty("top_p")
  private Double topP;

  @JsonProperty("do_sample")
  private Boolean doSample;

  @JsonProperty("max_new_tokens")
  private Integer maxNewTokens;

  @JsonProperty("repetition_penalty")
  private Double repetitionPenalty;

  private List<String> stop;
  private String chainName;
  private String callIdentifier;

  public Llama2Endpoint() {}

  public Llama2Endpoint(
      String url,
      RetryPolicy retryPolicy,
      Double temperature,
      Integer topK,
      Double topP,
      Boolean doSample,
      Integer maxNewTokens,
      Double repetitionPenalty,
      List<String> stop) {
    super(url, retryPolicy);
    this.temperature = temperature;
    this.topK = topK;
    this.topP = topP;
    this.doSample = doSample;
    this.maxNewTokens = maxNewTokens;
    this.repetitionPenalty = repetitionPenalty;
    this.stop = stop;
  }

  public Llama2Endpoint(String url, RetryPolicy retryPolicy) {
    super(url, retryPolicy);
    // Defaults applied when only the URL and retry policy are supplied
    this.temperature = 0.7;
    this.maxNewTokens = 512;
  }

  public String getInputs() {
    return inputs;
  }

  public void setInputs(String inputs) {
    this.inputs = inputs;
  }

  public JSONObject getParameters() {
    return parameters;
  }

  public void setParameters(JSONObject parameters) {
    this.parameters = parameters;
  }

  public Double getTemperature() {
    return temperature;
  }

  public void setTemperature(Double temperature) {
    this.temperature = temperature;
  }

  public Integer getTopK() {
    return topK;
  }

  public void setTopK(Integer topK) {
    this.topK = topK;
  }

  public Double getTopP() {
    return topP;
  }

  public void setTopP(Double topP) {
    this.topP = topP;
  }

  public Boolean getDoSample() {
    return doSample;
  }

  public void setDoSample(Boolean doSample) {
    this.doSample = doSample;
  }

  public Integer getMaxNewTokens() {
    return maxNewTokens;
  }

  public void setMaxNewTokens(Integer maxNewTokens) {
    this.maxNewTokens = maxNewTokens;
  }

  public Double getRepetitionPenalty() {
    return repetitionPenalty;
  }

  public void setRepetitionPenalty(Double repetitionPenalty) {
    this.repetitionPenalty = repetitionPenalty;
  }

  public List<String> getStop() {
    return stop;
  }

  public void setStop(List<String> stop) {
    this.stop = stop;
  }

  public String getChainName() {
    return chainName;
  }

  public void setChainName(String chainName) {
    this.chainName = chainName;
  }

  public String getCallIdentifier() {
    return callIdentifier;
  }

  public void setCallIdentifier(String callIdentifier) {
    this.callIdentifier = callIdentifier;
  }

  public Observable<List<Llama2ChatCompletionResponse>> chatCompletion(
      String inputs, String chainName, ArkRequest arkRequest) {

    // Copy this endpoint's configuration so the original instance stays untouched,
    // then attach the prompt and chain name to the copy.
    Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class);
    mapper.setInputs(inputs);
    mapper.setChainName(chainName);
    return chatCompletion(mapper, arkRequest);
  }

  private Observable<List<Llama2ChatCompletionResponse>> chatCompletion(
      Llama2Endpoint mapper, ArkRequest arkRequest) {

    // Record the originating request URI (if any) for traceability, then delegate
    // to the Retrofit service.
    if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI());
    else mapper.setCallIdentifier("URI wasn't provided");

    return Observable.fromSingle(this.llama2Service.chatCompletion(mapper));
  }
}
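A minimal usage sketch, not part of this commit, showing how a caller might drive the endpoint above. The inference URL is a placeholder, the retryPolicy and arkRequest arguments are assumed to be supplied by the surrounding application, and only the constructor and chatCompletion signatures from the diff are relied on.

import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.request.ArkRequest;
import com.edgechain.lib.rxjava.retry.RetryPolicy;

import java.util.List;

public class Llama2UsageSketch {

  // Hypothetical helper: the URL and chain name below are placeholders.
  public List<Llama2ChatCompletionResponse> ask(RetryPolicy retryPolicy, ArkRequest arkRequest) {
    Llama2Endpoint endpoint =
        new Llama2Endpoint("https://<your-llama2-inference-host>/generate", retryPolicy);

    // chatCompletion copies this endpoint via ModelMapper, attaches the prompt,
    // chain name and request URI, then delegates to the Retrofit Llama2Service.
    return endpoint
        .chatCompletion("What is EdgeChains?", "llama2-chain", arkRequest)
        .blockingFirst();
  }
}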
com/edgechain/lib/llama2/Llama2Client.java
@@ -0,0 +1,67 @@
package com.edgechain.lib.llama2;


import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import com.edgechain.lib.utils.JsonUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.rxjava3.core.Observable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.List;
import java.util.Objects;

@Service
public class Llama2Client {

  @Autowired private ObjectMapper objectMapper;

  private final Logger logger = LoggerFactory.getLogger(getClass());
  private final RestTemplate restTemplate = new RestTemplate();

  public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion(
      Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) {
    return new EdgeChain<>(
        Observable.create(
            emitter -> {
              try {
                logger.info("Logging ChatCompletion....");
                logger.info("==============REQUEST DATA================");
                logger.info(request.toString());

                // Create JSON headers and wrap the request body
                HttpHeaders headers = new HttpHeaders();
                headers.setContentType(MediaType.APPLICATION_JSON);
                HttpEntity<Llama2ChatCompletionRequest> entity = new HttpEntity<>(request, headers);

                // POST the request to the configured endpoint URL and parse the
                // returned array into Llama2ChatCompletionResponse objects
                String response = restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

                List<Llama2ChatCompletionResponse> chatCompletionResponse =
                    objectMapper.readValue(
                        response, new TypeReference<List<Llama2ChatCompletionResponse>>() {});

                emitter.onNext(chatCompletionResponse);
                emitter.onComplete();

              } catch (final Exception e) {
                emitter.onError(e);
              }
            }),
        endpoint);
  }
}
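A short sketch, assumed rather than taken from this commit, of driving the client directly. The llama2Client and llama2Endpoint variables stand in for an injected Llama2Client bean and a configured Llama2Endpoint; the prompt and parameters are placeholders.

// Assumed context: llama2Client is an injected Llama2Client bean and
// llama2Endpoint is a configured Llama2Endpoint (both hypothetical here).
Llama2ChatCompletionRequest request = new Llama2ChatCompletionRequest();
request.setInputs("Summarize EdgeChains in one sentence.");
request.setParameters(new JSONObject().put("max_new_tokens", 128).put("temperature", 0.7));

EdgeChain<List<Llama2ChatCompletionResponse>> chain =
    llama2Client.createChatCompletion(request, llama2Endpoint);

// Nothing has been sent yet: the POST to endpoint.getUrl() runs only once the
// EdgeChain's underlying Observable is subscribed.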
com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java
@@ -0,0 +1,80 @@
package com.edgechain.lib.llama2.request;

import com.edgechain.lib.openai.request.ChatCompletionRequest;
import com.edgechain.lib.openai.request.ChatMessage;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.json.JSONObject;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

public class Llama2ChatCompletionRequest {

  private String inputs;
  private JSONObject parameters;

  public Llama2ChatCompletionRequest() {}

  public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) {
    this.inputs = inputs;
    this.parameters = parameters;
  }

  @Override
  public String toString() {
    return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]")
        .add("\"inputs\": " + inputs)
        .add("\"parameters\": " + parameters)
        .toString();
  }

  public static Llama2ChatCompletionRequestBuilder builder() {
    return new Llama2ChatCompletionRequestBuilder();
  }

  public String getInputs() {
    return inputs;
  }

  public void setInputs(String inputs) {
    this.inputs = inputs;
  }

  public JSONObject getParameters() {
    return parameters;
  }

  public void setParameters(JSONObject parameters) {
    this.parameters = parameters;
  }

  public static class Llama2ChatCompletionRequestBuilder {

    private String inputs;
    private JSONObject parameters;

    private Llama2ChatCompletionRequestBuilder() {}

    public Llama2ChatCompletionRequestBuilder inputs(String inputs) {
      this.inputs = inputs;
      return this;
    }

    public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) {
      this.parameters = parameters;
      return this;
    }

    public Llama2ChatCompletionRequest build() {
      return new Llama2ChatCompletionRequest(inputs, parameters);
    }
  }
}
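A small, illustrative example of the builder above. The parameter names mirror the fields declared on Llama2Endpoint in this commit; whether the downstream inference server honors them is an assumption.

Llama2ChatCompletionRequest request =
    Llama2ChatCompletionRequest.builder()
        .inputs("Write a haiku about Java.")
        .parameters(
            new JSONObject()
                .put("temperature", 0.7)
                .put("top_p", 0.95)
                .put("max_new_tokens", 512))
        .build();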
com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java
@@ -0,0 +1,18 @@
package com.edgechain.lib.llama2.response;

import com.fasterxml.jackson.annotation.JsonProperty;

public class Llama2ChatCompletionResponse {

  @JsonProperty("generated_text")
  private String generatedText;

  public Llama2ChatCompletionResponse() {}

  public String getGeneratedText() {
    return generatedText;
  }

  public void setGeneratedText(String generatedText) {
    this.generatedText = generatedText;
  }
}
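As a quick reference, a sketch of how this mapping behaves on a Hugging Face style text-generation response, using the same Jackson pattern as Llama2Client; the sample JSON is illustrative only.

import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.List;

public class Llama2ResponseParseSketch {
  public static void main(String[] args) throws Exception {
    ObjectMapper mapper = new ObjectMapper();

    // Illustrative payload in the shape the Llama2Client expects back:
    // a JSON array of objects carrying a "generated_text" field.
    String json = "[{\"generated_text\": \"Hello from Llama 2!\"}]";

    List<Llama2ChatCompletionResponse> parsed =
        mapper.readValue(json, new TypeReference<List<Llama2ChatCompletionResponse>>() {});

    // @JsonProperty("generated_text") maps the snake_case field onto generatedText.
    System.out.println(parsed.get(0).getGeneratedText()); // prints: Hello from Llama 2!
  }
}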
com/edgechain/lib/retrofit/Llama2Service.java
@@ -0,0 +1,14 @@
package com.edgechain.lib.retrofit;

import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import io.reactivex.rxjava3.core.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;

import java.util.List;

public interface Llama2Service {

  @POST(value = "llama2/chat-completion")
  Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint);
}
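For completeness, a hedged sketch of consuming the interface above. It assumes, as Llama2Endpoint already does, that RetrofitClientInstance.getInstance() returns the shared Retrofit instance whose base URL points at the local EdgeChains server, where the controller added in this commit (not rendered above) serves the llama2/chat-completion route.

import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.retrofit.Llama2Service;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;

import io.reactivex.rxjava3.core.Single;
import retrofit2.Retrofit;

import java.util.List;

public class Llama2ServiceSketch {

  public Single<List<Llama2ChatCompletionResponse>> call(Llama2Endpoint endpoint) {
    Retrofit retrofit = RetrofitClientInstance.getInstance();
    Llama2Service service = retrofit.create(Llama2Service.class);

    // Serializes the endpoint (inputs, parameters, sampling settings) as the
    // request body and POSTs it to "llama2/chat-completion".
    return service.chatCompletion(endpoint);
  }
}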

0 comments on commit e24c63b
