-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Adding support for LLama API * Adding support for GET request * Refactoring llama endpoint
- Loading branch information
1 parent
1b63602
commit 24cf405
Showing
6 changed files
with
277 additions
and
83 deletions.
There are no files selected for viewing
47 changes: 47 additions & 0 deletions
47
...ring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package com.edgechain.lib.endpoint.impl.llm; | ||
|
||
import com.edgechain.lib.endpoint.Endpoint; | ||
import com.edgechain.lib.request.ArkRequest; | ||
import com.edgechain.lib.retrofit.Llama2Service; | ||
import com.edgechain.lib.retrofit.client.RetrofitClientInstance; | ||
import com.edgechain.lib.rxjava.retry.RetryPolicy; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.modelmapper.ModelMapper; | ||
import retrofit2.Retrofit; | ||
|
||
public class LLamaQuickstart extends Endpoint { | ||
private final Retrofit retrofit = RetrofitClientInstance.getInstance(); | ||
private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); | ||
private final ModelMapper modelMapper = new ModelMapper(); | ||
private String query; | ||
public LLamaQuickstart() { | ||
} | ||
|
||
public LLamaQuickstart(String url, RetryPolicy retryPolicy) { | ||
super(url, retryPolicy); | ||
} | ||
|
||
public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) { | ||
super(url, retryPolicy); | ||
this.query = query; | ||
} | ||
|
||
public String getQuery() { | ||
return query; | ||
} | ||
|
||
public void setQuery(String query) { | ||
this.query = query; | ||
} | ||
|
||
|
||
public Observable<String> chatCompletion(String query, ArkRequest arkRequest) { | ||
LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class); | ||
mapper.setQuery(query); | ||
return chatCompletion(mapper, arkRequest); | ||
} | ||
|
||
private Observable<String> chatCompletion(LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) { | ||
return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart)); | ||
} | ||
} |
61 changes: 61 additions & 0 deletions
61
FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package com.edgechain.lib.llama2; | ||
|
||
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; | ||
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; | ||
import com.edgechain.lib.llama2.request.LLamaCompletionRequest; | ||
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.beans.factory.annotation.Autowired; | ||
import org.springframework.http.HttpEntity; | ||
import org.springframework.http.HttpHeaders; | ||
import org.springframework.http.MediaType; | ||
import org.springframework.stereotype.Service; | ||
import org.springframework.web.client.RestTemplate; | ||
|
||
import java.util.List; | ||
|
||
@Service | ||
public class LLamaClient { | ||
@Autowired private ObjectMapper objectMapper; | ||
private final Logger logger = LoggerFactory.getLogger(getClass()); | ||
private final RestTemplate restTemplate = new RestTemplate(); | ||
|
||
public EdgeChain<List<String>> createChatCompletion( | ||
LLamaCompletionRequest request, LLamaQuickstart endpoint) { | ||
return new EdgeChain<>( | ||
Observable.create( | ||
emitter -> { | ||
try { | ||
|
||
logger.info("Logging ChatCompletion...."); | ||
|
||
logger.info("==============REQUEST DATA================"); | ||
logger.info(request.toString()); | ||
|
||
// Create headers | ||
HttpHeaders headers = new HttpHeaders(); | ||
headers.setContentType(MediaType.APPLICATION_JSON); | ||
HttpEntity<LLamaCompletionRequest> entity = new HttpEntity<>(request, headers); | ||
// | ||
String response = | ||
restTemplate.postForObject(endpoint.getUrl(), entity, String.class); | ||
|
||
List<String> chatCompletionResponse = | ||
objectMapper.readValue( | ||
response, new TypeReference<>() {}); | ||
emitter.onNext(chatCompletionResponse); | ||
emitter.onComplete(); | ||
|
||
} catch (final Exception e) { | ||
emitter.onError(e); | ||
} | ||
}), | ||
endpoint); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
.../edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package com.edgechain.lib.llama2.request; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
import java.util.StringJoiner; | ||
|
||
public class LLamaCompletionRequest { | ||
@JsonProperty("text_inputs") | ||
private String textInputs; | ||
@JsonProperty("return_full_text") | ||
private Boolean returnFullText; | ||
@JsonProperty("top_k") | ||
private Integer topK; | ||
public LLamaCompletionRequest() {} | ||
|
||
public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) { | ||
this.textInputs = textInputs; | ||
this.returnFullText = returnFullText; | ||
this.topK = topK; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return new StringJoiner(", ", LLamaCompletionRequest.class.getSimpleName() + "{", "}") | ||
.add("\"text_inputs:\"" + textInputs) | ||
.add("\"return_full_text:\"" + returnFullText) | ||
.add("\"top_k:\"" + topK) | ||
.toString(); | ||
} | ||
|
||
public static LlamaSupportChatCompletionRequestBuilder builder() { | ||
return new LlamaSupportChatCompletionRequestBuilder(); | ||
} | ||
|
||
public String getTextInputs() { | ||
return textInputs; | ||
} | ||
|
||
public void setTextInputs(String textInputs) { | ||
this.textInputs = textInputs; | ||
} | ||
|
||
public Boolean getReturnFullText() { | ||
return returnFullText; | ||
} | ||
|
||
public void setReturnFullText(Boolean returnFullText) { | ||
this.returnFullText = returnFullText; | ||
} | ||
|
||
public Integer getTopK() { | ||
return topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
public static class LlamaSupportChatCompletionRequestBuilder { | ||
private String textInputs; | ||
private Boolean returnFullText; | ||
private Integer topK; | ||
|
||
private LlamaSupportChatCompletionRequestBuilder() {} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder textInputs(String textInputs) { | ||
this.textInputs = textInputs; | ||
return this; | ||
} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFullText) { | ||
this.returnFullText = returnFullText; | ||
return this; | ||
} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder topK(Integer topK){ | ||
this.topK = topK; | ||
return this; | ||
} | ||
|
||
public LLamaCompletionRequest build() { | ||
return new LLamaCompletionRequest(textInputs, returnFullText, topK); | ||
} | ||
} | ||
} |
10 changes: 8 additions & 2 deletions
10
FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,20 @@ | ||
package com.edgechain.lib.retrofit; | ||
|
||
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; | ||
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.request.ArkRequest; | ||
import io.reactivex.rxjava3.core.Single; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.GET; | ||
import retrofit2.http.POST; | ||
import retrofit2.http.Query; | ||
|
||
import java.util.List; | ||
|
||
public interface Llama2Service { | ||
@POST(value = "llama2/chat-completion") | ||
@POST(value = "llama/chat-completion") | ||
Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint); | ||
} | ||
@POST(value = "llama/chat-completion") | ||
Single<String> llamaCompletion(@Body LLamaQuickstart lLamaQuickstart); | ||
} |
Oops, something went wrong.