Get request (#255)
* Adding support for LLama API

* Adding support for GET request

* Refactoring llama endpoint
hemantDwivedi authored Oct 15, 2023
1 parent 1b63602 commit 24cf405
Showing 6 changed files with 277 additions and 83 deletions.
@@ -0,0 +1,47 @@
package com.edgechain.lib.endpoint.impl.llm;

import com.edgechain.lib.endpoint.Endpoint;
import com.edgechain.lib.request.ArkRequest;
import com.edgechain.lib.retrofit.Llama2Service;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
import com.edgechain.lib.rxjava.retry.RetryPolicy;
import io.reactivex.rxjava3.core.Observable;
import org.modelmapper.ModelMapper;
import retrofit2.Retrofit;

public class LLamaQuickstart extends Endpoint {

  private final Retrofit retrofit = RetrofitClientInstance.getInstance();
  private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);
  private final ModelMapper modelMapper = new ModelMapper();

  private String query;

  public LLamaQuickstart() {}

  public LLamaQuickstart(String url, RetryPolicy retryPolicy) {
    super(url, retryPolicy);
  }

  public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) {
    super(url, retryPolicy);
    this.query = query;
  }

  public String getQuery() {
    return query;
  }

  public void setQuery(String query) {
    this.query = query;
  }

  public Observable<String> chatCompletion(String query, ArkRequest arkRequest) {
    LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class);
    mapper.setQuery(query);
    return chatCompletion(mapper, arkRequest);
  }

  private Observable<String> chatCompletion(
      LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) {
    return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart));
  }
}
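For context, a minimal usage sketch of the new endpoint (not part of this commit); the base URL, the null retry policy, and the wrapper class are assumptions:

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.request.ArkRequest;
import io.reactivex.rxjava3.core.Observable;

public class LLamaQuickstartUsage {

  // arkRequest would normally come from the incoming HTTP request in a controller method.
  public Observable<String> ask(ArkRequest arkRequest) {
    // Hypothetical service URL; retry policy omitted for brevity.
    LLamaQuickstart endpoint = new LLamaQuickstart("http://localhost:8080/llama", null);
    return endpoint.chatCompletion("What is EdgeChains?", arkRequest);
  }
}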
@@ -0,0 +1,61 @@
package com.edgechain.lib.llama2;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.rxjava3.core.Observable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.List;

@Service
public class LLamaClient {

  @Autowired private ObjectMapper objectMapper;

  private final Logger logger = LoggerFactory.getLogger(getClass());
  private final RestTemplate restTemplate = new RestTemplate();

  public EdgeChain<List<String>> createChatCompletion(
      LLamaCompletionRequest request, LLamaQuickstart endpoint) {
    return new EdgeChain<>(
        Observable.create(
            emitter -> {
              try {
                logger.info("Logging ChatCompletion....");

                logger.info("==============REQUEST DATA================");
                logger.info(request.toString());

                // Create headers
                HttpHeaders headers = new HttpHeaders();
                headers.setContentType(MediaType.APPLICATION_JSON);
                HttpEntity<LLamaCompletionRequest> entity = new HttpEntity<>(request, headers);

                String response =
                    restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

                List<String> chatCompletionResponse =
                    objectMapper.readValue(response, new TypeReference<>() {});
                emitter.onNext(chatCompletionResponse);
                emitter.onComplete();

              } catch (final Exception e) {
                emitter.onError(e);
              }
            }),
        endpoint);
  }
}
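A sketch of how this client might be exercised from application code; the Spring wiring and the request values are assumptions, not code from the commit:

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.llama2.LLamaClient;
import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class LLamaClientUsage {

  @Autowired private LLamaClient llamaClient;

  public EdgeChain<List<String>> complete(LLamaQuickstart endpoint) {
    LLamaCompletionRequest request =
        LLamaCompletionRequest.builder()
            .textInputs("Summarise EdgeChains in one sentence")
            .returnFullText(false)
            .topK(5)
            .build();
    // POSTs to endpoint.getUrl() and parses the JSON response body into a List<String>.
    return llamaClient.createChatCompletion(request, endpoint);
  }
}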
@@ -1,5 +1,6 @@
package com.edgechain.lib.llama2;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
@@ -14,51 +15,77 @@
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.Collections;
import java.util.List;
import java.util.Map;

@Service
public class Llama2Client {
  @Autowired private ObjectMapper objectMapper;

  private final Logger logger = LoggerFactory.getLogger(getClass());
  private final RestTemplate restTemplate = new RestTemplate();

  public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion(
      Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) {
    return new EdgeChain<>(
        Observable.create(
            emitter -> {
              try {
                logger.info("Logging ChatCompletion....");

                logger.info("==============REQUEST DATA================");
                logger.info(request.toString());

                // Create headers
                HttpHeaders headers = new HttpHeaders();
                headers.setContentType(MediaType.APPLICATION_JSON);
                HttpEntity<Llama2ChatCompletionRequest> entity = new HttpEntity<>(request, headers);

                String response =
                    restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

                List<Llama2ChatCompletionResponse> chatCompletionResponse =
                    objectMapper.readValue(response, new TypeReference<>() {});
                emitter.onNext(chatCompletionResponse);
                emitter.onComplete();

              } catch (final Exception e) {
                emitter.onError(e);
              }
            }),
        endpoint);
  }

  public EdgeChain<String> createGetChatCompletion(LLamaQuickstart endpoint) {
    return new EdgeChain<>(
        Observable.create(
            emitter -> {
              try {
                // Create headers
                HttpHeaders headers = new HttpHeaders();
                headers.set("User-Agent", "insomnia/8.2.0");
                HttpEntity<?> entity = new HttpEntity<>(headers);

                Map<String, String> param = Collections.singletonMap("query", endpoint.getQuery());

                String endpointUrl = endpoint.getUrl() + "?query={query}";

                ResponseEntity<String> response =
                    restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param);

                logger.info("\nRESPONSE DATA {}\n", response.getBody());

                emitter.onNext(response.getBody());
                emitter.onComplete();

              } catch (final Exception e) {
                emitter.onError(e);
              }
            }),
        endpoint);
  }
}
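To illustrate how the {query} placeholder in createGetChatCompletion is expanded, a standalone sketch (the URL and query text are assumptions; RestTemplate URL-encodes the substituted value):

import org.springframework.web.client.RestTemplate;

import java.util.Collections;
import java.util.Map;

public class QueryParamSketch {

  public static void main(String[] args) {
    RestTemplate restTemplate = new RestTemplate();
    Map<String, String> params = Collections.singletonMap("query", "what is edgechains?");
    // {query} is filled in (and encoded) from the params map, mirroring the GET call above.
    String body =
        restTemplate.getForObject("http://localhost:8080/llama?query={query}", String.class, params);
    System.out.println(body);
  }
}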
@@ -0,0 +1,85 @@
package com.edgechain.lib.llama2.request;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.StringJoiner;

public class LLamaCompletionRequest {

  @JsonProperty("text_inputs")
  private String textInputs;

  @JsonProperty("return_full_text")
  private Boolean returnFullText;

  @JsonProperty("top_k")
  private Integer topK;

  public LLamaCompletionRequest() {}

  public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) {
    this.textInputs = textInputs;
    this.returnFullText = returnFullText;
    this.topK = topK;
  }

  @Override
  public String toString() {
    return new StringJoiner(", ", LLamaCompletionRequest.class.getSimpleName() + "{", "}")
        .add("\"text_inputs:\"" + textInputs)
        .add("\"return_full_text:\"" + returnFullText)
        .add("\"top_k:\"" + topK)
        .toString();
  }

  public static LlamaSupportChatCompletionRequestBuilder builder() {
    return new LlamaSupportChatCompletionRequestBuilder();
  }

  public String getTextInputs() {
    return textInputs;
  }

  public void setTextInputs(String textInputs) {
    this.textInputs = textInputs;
  }

  public Boolean getReturnFullText() {
    return returnFullText;
  }

  public void setReturnFullText(Boolean returnFullText) {
    this.returnFullText = returnFullText;
  }

  public Integer getTopK() {
    return topK;
  }

  public void setTopK(Integer topK) {
    this.topK = topK;
  }

  public static class LlamaSupportChatCompletionRequestBuilder {

    private String textInputs;
    private Boolean returnFullText;
    private Integer topK;

    private LlamaSupportChatCompletionRequestBuilder() {}

    public LlamaSupportChatCompletionRequestBuilder textInputs(String textInputs) {
      this.textInputs = textInputs;
      return this;
    }

    public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFullText) {
      this.returnFullText = returnFullText;
      return this;
    }

    public LlamaSupportChatCompletionRequestBuilder topK(Integer topK) {
      this.topK = topK;
      return this;
    }

    public LLamaCompletionRequest build() {
      return new LLamaCompletionRequest(textInputs, returnFullText, topK);
    }
  }
}
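As an illustration of the wire format implied by the @JsonProperty annotations (not part of the commit), serializing a built request with Jackson yields snake_case field names:

import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.fasterxml.jackson.databind.ObjectMapper;

public class RequestJsonSketch {

  public static void main(String[] args) throws Exception {
    LLamaCompletionRequest request =
        LLamaCompletionRequest.builder()
            .textInputs("Explain retrieval augmented generation")
            .returnFullText(false)
            .topK(3)
            .build();
    // Prints (property order may vary):
    // {"text_inputs":"Explain retrieval augmented generation","return_full_text":false,"top_k":3}
    System.out.println(new ObjectMapper().writeValueAsString(request));
  }
}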
@@ -1,14 +1,20 @@
package com.edgechain.lib.retrofit;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.request.ArkRequest;
import io.reactivex.rxjava3.core.Single;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;
import retrofit2.http.Query;

import java.util.List;

public interface Llama2Service {

  @POST(value = "llama/chat-completion")
  Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint);

  @POST(value = "llama/chat-completion")
  Single<String> llamaCompletion(@Body LLamaQuickstart lLamaQuickstart);
}
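For reference, a sketch (not from the commit) of reaching the new llamaCompletion call through the shared Retrofit client; the endpoint values and the blocking subscription are assumptions for illustration:

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.retrofit.Llama2Service;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
import io.reactivex.rxjava3.core.Single;
import retrofit2.Retrofit;

public class Llama2ServiceSketch {

  public static void main(String[] args) {
    Retrofit retrofit = RetrofitClientInstance.getInstance();
    Llama2Service service = retrofit.create(Llama2Service.class);

    // Hypothetical values; the LLamaQuickstart fields are serialized as the POST body.
    LLamaQuickstart endpoint =
        new LLamaQuickstart("http://localhost:8080/llama", null, "what is edgechains?");
    Single<String> completion = service.llamaCompletion(endpoint);
    System.out.println(completion.blockingGet());
  }
}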