Skip to content

Commit

Permalink
fix huggingface generate text
Browse files Browse the repository at this point in the history
update openapi.json

Signed-off-by: jitokim <[email protected]>
  • Loading branch information
jitokim committed Nov 13, 2024
1 parent b4e0a45 commit ac97e35
Show file tree
Hide file tree
Showing 2 changed files with 1,668 additions and 174 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
import org.springframework.ai.huggingface.invoker.ApiClient;
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
import org.springframework.ai.huggingface.model.CompatGenerateRequest;
import org.springframework.ai.huggingface.model.GenerateParameters;
import org.springframework.ai.huggingface.model.GenerateRequest;
import org.springframework.ai.huggingface.model.GenerateResponse;

/**
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
* Endpoints for text generation.
*
* @author Mark Pollack
* @author Jihoon Kim
*/
public class HuggingfaceChatModel implements ChatModel {

Expand Down Expand Up @@ -89,22 +90,24 @@ public HuggingfaceChatModel(final String apiToken, String basePath) {
*/
@Override
public ChatResponse call(Prompt prompt) {
GenerateRequest generateRequest = new GenerateRequest();
generateRequest.setInputs(prompt.getContents());
CompatGenerateRequest compatGenerateRequest = new CompatGenerateRequest();
compatGenerateRequest.setInputs(prompt.getContents());
GenerateParameters generateParameters = new GenerateParameters();
// TODO - need to expose API to set parameters per call.
generateParameters.setMaxNewTokens(this.maxNewTokens);
generateRequest.setParameters(generateParameters);
GenerateResponse generateResponse = this.textGenApi.generate(generateRequest);
String generatedText = generateResponse.getGeneratedText();
compatGenerateRequest.setParameters(generateParameters);
List<GenerateResponse> generateResponses = this.textGenApi.compatGenerate(compatGenerateRequest);
List<Generation> generations = new ArrayList<>();
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
new TypeReference<Map<String, Object>>() {

});
Generation generation = new Generation(generatedText, detailsMap);
generations.add(generation);
for (GenerateResponse generateResponse : generateResponses) {
String generatedText = generateResponse.getGeneratedText();
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
new TypeReference<Map<String, Object>>() {

});
Generation generation = new Generation(generatedText, detailsMap);
generations.add(generation);
}
return new ChatResponse(generations);
}

Expand Down
Loading

0 comments on commit ac97e35

Please sign in to comment.