Skip to content

Commit

Permalink
Extend Hugging Face configuration with doSample, top-p, top-k and rep…
Browse files Browse the repository at this point in the history
…etition penalty
  • Loading branch information
cescoffier committed Nov 22, 2023
1 parent 9143bef commit 87fce09
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import java.net.URL;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalInt;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
Expand Down Expand Up @@ -34,9 +37,13 @@ public class QuarkusHuggingFaceChatModel implements ChatLanguageModel {
private final Integer maxNewTokens;
private final Boolean returnFullText;
private final Boolean waitForModel;
private final Optional<Boolean> doSample;
private final OptionalDouble topP;
private final OptionalInt topK;
private final OptionalDouble repetitionPenalty;

private QuarkusHuggingFaceChatModel(Builder builder) {
this.client = CLIENT_FACTORY.create(new HuggingFaceClientFactory.Input() {
this.client = CLIENT_FACTORY.create(builder, new HuggingFaceClientFactory.Input() {
@Override
public String apiKey() {
return builder.accessToken;
Expand All @@ -56,6 +63,10 @@ public Duration timeout() {
this.maxNewTokens = builder.maxNewTokens;
this.returnFullText = builder.returnFullText;
this.waitForModel = builder.waitForModel;
this.doSample = builder.doSample;
this.topP = builder.topP;
this.topK = builder.topK;
this.repetitionPenalty = builder.repetitionPenalty;
}

public static Builder builder() {
Expand All @@ -65,15 +76,23 @@ public static Builder builder() {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {

Parameters.Builder builder = Parameters.builder()
.temperature(temperature)
.maxNewTokens(maxNewTokens)
.returnFullText(returnFullText);

doSample.ifPresent(builder::doSample);
topK.ifPresent(builder::topK);
topP.ifPresent(builder::topP);
repetitionPenalty.ifPresent(builder::repetitionPenalty);

Parameters parameters = builder
.build();
TextGenerationRequest request = TextGenerationRequest.builder()
.inputs(messages.stream()
.map(ChatMessage::text)
.collect(joining("\n")))
.parameters(Parameters.builder()
.temperature(temperature)
.maxNewTokens(maxNewTokens)
.returnFullText(returnFullText)
.build())
.parameters(parameters)
.options(Options.builder()
.waitForModel(waitForModel)
.build())
Expand Down Expand Up @@ -103,6 +122,14 @@ public static final class Builder {
private Boolean returnFullText;
private Boolean waitForModel = true;
private URI url;
private Optional<Boolean> doSample;

private OptionalInt topK;
private OptionalDouble topP;

private OptionalDouble repetitionPenalty;
public boolean logResponses;
public boolean logRequests;

public Builder accessToken(String accessToken) {
this.accessToken = accessToken;
Expand Down Expand Up @@ -143,8 +170,38 @@ public Builder waitForModel(Boolean waitForModel) {
return this;
}

public Builder doSample(Optional<Boolean> doSample) {
this.doSample = doSample;
return this;
}

public Builder topK(OptionalInt topK) {
this.topK = topK;
return this;
}

public Builder topP(OptionalDouble topP) {
this.topP = topP;
return this;
}

public Builder repetitionPenalty(OptionalDouble repetitionPenalty) {
this.repetitionPenalty = repetitionPenalty;
return this;
}

public QuarkusHuggingFaceChatModel build() {
return new QuarkusHuggingFaceChatModel(this);
}

public Builder logRequests(boolean logRequests) {
this.logRequests = logRequests;
return this;
}

public Builder logResponses(boolean logResponses) {
this.logResponses = logResponses;
return this;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
package io.quarkiverse.langchain4j.huggingface;

import static java.util.stream.Collectors.joining;
import static java.util.stream.StreamSupport.stream;

import java.net.URI;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.api.LoggingScope;

import dev.langchain4j.model.huggingface.client.EmbeddingRequest;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.client.TextGenerationRequest;
import dev.langchain4j.model.huggingface.client.TextGenerationResponse;
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;

public class QuarkusHuggingFaceClientFactory implements HuggingFaceClientFactory {

Expand All @@ -18,12 +32,21 @@ public HuggingFaceClient create(Input input) {
throw new UnsupportedOperationException("Should not be called");
}

public HuggingFaceClient create(Input input, URI url) {
HuggingFaceRestApi restApi = QuarkusRestClientBuilder.newBuilder()
public HuggingFaceClient create(QuarkusHuggingFaceChatModel.Builder config, Input input, URI url) {
QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder()
.baseUri(url)
.connectTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS)
.readTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS)
.readTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS);

if (config != null && (config.logRequests || config.logResponses)) {
builder.loggingScope(LoggingScope.REQUEST_RESPONSE);
builder.clientLogger(new HuggingFaceClientLogger(config.logRequests,
config.logResponses));
}

HuggingFaceRestApi restApi = builder
.build(HuggingFaceRestApi.class);

return new QuarkusHuggingFaceClient(restApi, input.apiKey());
}

Expand Down Expand Up @@ -61,4 +84,116 @@ public List<float[]> embed(EmbeddingRequest request) {
return restApi.embed(request, token);
}
}

/**
* Introduce a custom logger as the stock one logs at the DEBUG level by default...
*/
class HuggingFaceClientLogger implements ClientLogger {
private static final Logger log = Logger.getLogger(HuggingFaceClientLogger.class);

private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s*sk-)(\\w{2})(\\w+)(\\w{2})");

private final boolean logRequests;
private final boolean logResponses;

public HuggingFaceClientLogger(boolean logRequests, boolean logResponses) {
this.logRequests = logRequests;
this.logResponses = logResponses;
}

@Override
public void setBodySize(int bodySize) {
// ignore
}

@Override
public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) {
if (!logRequests || !log.isInfoEnabled()) {
return;
}
try {
log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s",
request.getMethod(),
request.absoluteURI(),
inOneLine(request.headers()),
bodyToString(body));
} catch (Exception e) {
log.warn("Failed to log request", e);
}
}

@Override
public void logResponse(HttpClientResponse response, boolean redirect) {
if (!logResponses || !log.isInfoEnabled()) {
return;
}
response.bodyHandler(new Handler<>() {
@Override
public void handle(Buffer body) {
try {
log.infof(
"Response:\n- status code: %s\n- headers: %s\n- body: %s",
response.statusCode(),
inOneLine(response.headers()),
bodyToString(body));
} catch (Exception e) {
log.warn("Failed to log response", e);
}
}
});
}

private String bodyToString(Buffer body) {
if (body == null) {
return "";
}
return body.toString();
}

private String inOneLine(MultiMap headers) {

return stream(headers.spliterator(), false)
.map(header -> {
String headerKey = header.getKey();
String headerValue = header.getValue();
if (headerKey.equals("Authorization")) {
headerValue = maskAuthorizationHeaderValue(headerValue);
} else if (headerKey.equals("api-key")) {
headerValue = maskApiKeyHeaderValue(headerValue);
}
return String.format("[%s: %s]", headerKey, headerValue);
})
.collect(joining(", "));
}

private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) {
try {

Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue);

StringBuilder sb = new StringBuilder();
while (matcher.find()) {
matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4));
}
matcher.appendTail(sb);

return sb.toString();
} catch (Exception e) {
return "Failed to mask the API key.";
}
}

private static String maskApiKeyHeaderValue(String apiKeyHeaderValue) {
try {
if (apiKeyHeaderValue.length() <= 4) {
return apiKeyHeaderValue;
}
return apiKeyHeaderValue.substring(0, 2)
+ "..."
+ apiKeyHeaderValue.substring(apiKeyHeaderValue.length() - 2);
} catch (Exception e) {
return "Failed to mask the API key.";
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class QuarkusHuggingFaceEmbeddingModel implements EmbeddingModel {
private final boolean waitForModel;

private QuarkusHuggingFaceEmbeddingModel(Builder builder) {
this.client = CLIENT_FACTORY.create(new HuggingFaceClientFactory.Input() {
this.client = CLIENT_FACTORY.create(null, new HuggingFaceClientFactory.Input() {
@Override
public String apiKey() {
return builder.accessToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ public Supplier<?> chatModel(Langchain4jHuggingFaceConfig runtimeConfig) {
.url(url)
.timeout(runtimeConfig.timeout())
.temperature(chatModelConfig.temperature())
.waitForModel(chatModelConfig.waitForModel());
.waitForModel(chatModelConfig.waitForModel())
.doSample(chatModelConfig.doSample())
.topP(chatModelConfig.topP())
.topK(chatModelConfig.topK())
.repetitionPenalty(chatModelConfig.repetitionPenalty())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses());

if (apiKeyOpt.isPresent()) {
builder.accessToken(apiKeyOpt.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.net.URL;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalInt;

import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -50,4 +52,27 @@ public interface ChatModelConfig {
*/
@WithDefault("true")
Boolean waitForModel();

/**
* Whether or not to use sampling ; use greedy decoding otherwise.
*/
Optional<Boolean> doSample();

/**
* The number of highest probability vocabulary tokens to keep for top-k-filtering.
*/
OptionalInt topK();

/**
* If set to less than {@code 1}, only the most probable tokens with probabilities that add up to {@code top_p} or
* higher are kept for generation.
*/
OptionalDouble topP();

/**
* The parameter for repetition penalty. 1.0 means no penalty.
* See <a href="https://arxiv.org/pdf/1909.05858.pdf">this paper</a> for more details.
*/
OptionalDouble repetitionPenalty();

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,16 @@ public interface Langchain4jHuggingFaceConfig {
* Embedding model related settings
*/
EmbeddingModelConfig embeddingModel();

/**
* Whether the OpenAI client should log requests
*/
@WithDefault("false")
Boolean logRequests();

/**
* Whether the OpenAI client should log responses
*/
@WithDefault("false")
Boolean logResponses();
}

0 comments on commit 87fce09

Please sign in to comment.