Merge pull request #799 from quarkiverse/ollama-observability
Introduce observability into Ollama ChatLanguageModel
geoand authored Jul 31, 2024
2 parents 874dc37 + 3bf35b6 commit 1a828a3
Showing 4 changed files with 134 additions and 10 deletions.
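
This commit wires CDI bean discovery into the Ollama chat model: every bean implementing langchain4j's ChatModelListener is collected when the ChatLanguageModel bean is created and is then invoked around each chat call. As a minimal sketch of what this enables (not part of the commit; the callback and accessor names follow the langchain4j listener API imported in the diff below, and the class name is hypothetical), an application could register a logging listener:

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.logging.Logger;

import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;

// Discovered as a CDI bean and passed to OllamaChatLanguageModel
// by the processor/recorder changes in this commit.
@ApplicationScoped
public class LoggingChatModelListener implements ChatModelListener {

    private static final Logger log = Logger.getLogger(LoggingChatModelListener.class);

    @Override
    public void onRequest(ChatModelRequestContext context) {
        log.infof("Ollama request to model %s", context.request().model());
    }

    @Override
    public void onResponse(ChatModelResponseContext context) {
        log.infof("Ollama token usage: %s", context.response().tokenUsage());
    }

    @Override
    public void onError(ChatModelErrorContext context) {
        log.error("Ollama chat call failed", context.error());
    }
}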
File 1 of 4: deployment processor (generateBeans)

@@ -9,8 +9,12 @@
 import jakarta.enterprise.context.ApplicationScoped;
 
 import org.jboss.jandex.AnnotationInstance;
+import org.jboss.jandex.ClassType;
+import org.jboss.jandex.ParameterizedType;
+import org.jboss.jandex.Type;
 
 import io.quarkiverse.langchain4j.ModelName;
+import io.quarkiverse.langchain4j.deployment.DotNames;
 import io.quarkiverse.langchain4j.deployment.devservice.Langchain4jDevServicesEnabled;
 import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
 import io.quarkiverse.langchain4j.deployment.items.DevServicesChatModelRequiredBuildItem;
@@ -110,7 +114,9 @@ void generateBeans(OllamaRecorder recorder,
                     .setRuntimeInit()
                     .defaultBean()
                     .scope(ApplicationScoped.class)
-                    .supplier(recorder.chatModel(config, fixedRuntimeConfig, configName));
+                    .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
+                            new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
+                    .createWith(recorder.chatModel(config, fixedRuntimeConfig, configName));
             addQualifierIfNecessary(builder, configName);
             beanProducer.produce(builder.done());
 
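The processor change swaps the no-argument supplier() for createWith() plus a declared injection point of type Instance&lt;ChatModelListener&gt;, so that bean creation can see injected references. In plain CDI terms, the synthetic bean now behaves roughly like the producer sketched below. This is an illustration only: the producer class, the baseUrl()/model() setter names, and the "llama3" tag are assumptions; only listeners() appears in this diff.

import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.Produces;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
import io.quarkus.arc.DefaultBean;

public class OllamaChatModelProducer {

    @Produces
    @ApplicationScoped
    @DefaultBean
    ChatLanguageModel chatModel(Instance<ChatModelListener> listeners) {
        // All ChatModelListener beans resolved by CDI end up on the builder,
        // mirroring what the synthetic bean's creation function does.
        return OllamaChatLanguageModel.builder()
                .baseUrl("http://localhost:11434")
                .model("llama3")
                .listeners(listeners.stream().collect(Collectors.toList()))
                .build();
    }
}

Declaring the injection point at build time, rather than doing a programmatic lookup at runtime, lets ArC validate and wire the listener beans during augmentation.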
File 2 of 4: OllamaChatLanguageModel

@@ -9,6 +9,11 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.jboss.logging.Logger;
+import org.jetbrains.annotations.NotNull;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 
@@ -17,23 +22,33 @@
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
+import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
+import dev.langchain4j.model.chat.listener.ChatModelResponse;
+import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
 
 public class OllamaChatLanguageModel implements ChatLanguageModel {
 
+    private static final Logger log = Logger.getLogger(OllamaChatLanguageModel.class);
+
     private final OllamaClient client;
     private final String model;
     private final String format;
     private final Options options;
+    private final List<ChatModelListener> listeners;
 
     private OllamaChatLanguageModel(Builder builder) {
         client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
                 builder.configName);
         model = builder.model;
         format = builder.format;
         options = builder.options;
+        this.listeners = builder.listeners;
     }
 
     public static Builder builder() {
@@ -64,11 +79,62 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                 .stream(false)
                 .build();
 
-        ChatResponse response = client.chat(request);
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(requestContext);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
 
+        try {
+            ChatResponse chatResponse = client.chat(request);
+            Response<AiMessage> response = toResponse(chatResponse);
+
+            ChatModelResponse modelListenerResponse = createModelListenerResponse(
+                    null,
+                    chatResponse.model(),
+                    response);
+            ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                    modelListenerResponse,
+                    modelListenerRequest,
+                    attributes);
+            listeners.forEach(listener -> {
+                try {
+                    listener.onResponse(responseContext);
+                } catch (Exception e) {
+                    log.warn("Exception while calling model listener", e);
+                }
+            });
+
+            return response;
+        } catch (RuntimeException e) {
+            ChatModelErrorContext errorContext = new ChatModelErrorContext(
+                    e,
+                    modelListenerRequest,
+                    null,
+                    attributes);
+
+            listeners.forEach(listener -> {
+                try {
+                    listener.onError(errorContext);
+                } catch (Exception e2) {
+                    log.warn("Exception while calling model listener", e2);
+                }
+            });
+
+            throw e;
+        }
+    }
+
+    private static @NotNull Response<AiMessage> toResponse(ChatResponse response) {
+        Response<AiMessage> result;
         List<ToolCall> toolCalls = response.message().toolCalls();
         if ((toolCalls == null) || toolCalls.isEmpty()) {
-            return Response.from(
+            result = Response.from(
                     AiMessage.from(response.message().content()),
                     new TokenUsage(response.promptEvalCount(), response.evalCount()));
         } else {
@@ -86,12 +152,45 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                             .build());
                 }
 
-                return Response.from(aiMessage(toolExecutionRequests),
+                result = Response.from(aiMessage(toolExecutionRequests),
                         new TokenUsage(response.promptEvalCount(), response.evalCount()));
             } catch (JsonProcessingException e) {
                 throw new RuntimeException("Unable to parse tool call response", e);
             }
         }
+        return result;
     }
 
+    private ChatModelRequest createModelListenerRequest(ChatRequest request,
+            List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications) {
+        Options options = request.options();
+        var builder = ChatModelRequest.builder()
+                .model(request.model())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications);
+        if (options != null) {
+            builder.temperature(options.temperature())
+                    .topP(options.topP())
+                    .maxTokens(options.numPredict());
+        }
+        return builder.build();
+    }
+
+    private ChatModelResponse createModelListenerResponse(String responseId,
+            String responseModel,
+            Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+
+        return ChatModelResponse.builder()
+                .id(responseId)
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
+
     public static final class Builder {
@@ -104,6 +203,7 @@ public static final class Builder {
         private boolean logRequests = false;
         private boolean logResponses = false;
         private String configName;
+        private List<ChatModelListener> listeners = Collections.emptyList();
 
         private Builder() {
         }
@@ -148,6 +248,11 @@ public Builder configName(String configName) {
             return this;
         }
 
+        public Builder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
         public OllamaChatLanguageModel build() {
             return new OllamaChatLanguageModel(this);
         }
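Note that the same attributes map, the per-call ConcurrentHashMap created in generate(), is passed to the request, response, and error contexts, so a listener can correlate its callbacks across one chat call. A sketch, assuming the attributes() accessor of the upstream listener contexts; the class and key names are hypothetical:

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.logging.Logger;

import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;

@ApplicationScoped
public class TimingChatModelListener implements ChatModelListener {

    private static final Logger log = Logger.getLogger(TimingChatModelListener.class);
    private static final String START_KEY = "timing.start-nanos";

    @Override
    public void onRequest(ChatModelRequestContext context) {
        // Stored in the per-call attributes map created in generate()
        context.attributes().put(START_KEY, System.nanoTime());
    }

    @Override
    public void onResponse(ChatModelResponseContext context) {
        Object start = context.attributes().get(START_KEY);
        if (start instanceof Long startNanos) {
            log.infof("Ollama chat completed in %d ms", (System.nanoTime() - startNanos) / 1_000_000);
        }
    }
}

Listener failures are deliberately swallowed: each callback is wrapped in try/catch and only logged, so a misbehaving listener cannot fail the chat call itself.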
File 3 of 4: OllamaRecorder

@@ -1,12 +1,18 @@
 package io.quarkiverse.langchain4j.ollama.runtime;
 
 import java.time.Duration;
+import java.util.function.Function;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
+import jakarta.enterprise.inject.Instance;
+import jakarta.enterprise.util.TypeLiteral;
+
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.DisabledChatLanguageModel;
 import dev.langchain4j.model.chat.DisabledStreamingChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
 import dev.langchain4j.model.embedding.DisabledEmbeddingModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
@@ -18,14 +24,19 @@
 import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaConfig;
 import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaFixedRuntimeConfig;
 import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
+import io.quarkus.arc.SyntheticCreationalContext;
 import io.quarkus.runtime.annotations.Recorder;
 
 @Recorder
 public class OllamaRecorder {
 
     private static final String DEFAULT_BASE_URL = "http://localhost:11434";
 
-    public Supplier<ChatLanguageModel> chatModel(LangChain4jOllamaConfig runtimeConfig,
+    private static final TypeLiteral<Instance<ChatModelListener>> CHAT_MODEL_LISTENER_TYPE_LITERAL = new TypeLiteral<>() {
+    };
+
+    public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel> chatModel(
+            LangChain4jOllamaConfig runtimeConfig,
             LangChain4jOllamaFixedRuntimeConfig fixedRuntimeConfig, String configName) {
         LangChain4jOllamaConfig.OllamaConfig ollamaConfig = correspondingOllamaConfig(runtimeConfig, configName);
         LangChain4jOllamaFixedRuntimeConfig.OllamaConfig ollamaFixedConfig = correspondingOllamaFixedConfig(fixedRuntimeConfig,
@@ -58,16 +69,18 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jOllamaConf
                     .configName(NamedConfigUtil.isDefault(configName) ? null : configName)
                     .options(optionsBuilder.build());
 
-            return new Supplier<>() {
+            return new Function<>() {
                 @Override
-                public ChatLanguageModel get() {
+                public ChatLanguageModel apply(SyntheticCreationalContext<ChatLanguageModel> context) {
+                    builder.listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream()
+                            .collect(Collectors.toList()));
                     return builder.build();
                 }
             };
         } else {
-            return new Supplier<>() {
+            return new Function<>() {
                 @Override
-                public ChatLanguageModel get() {
+                public ChatLanguageModel apply(SyntheticCreationalContext<ChatLanguageModel> context) {
                     return new DisabledChatLanguageModel();
                 }
             };
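The recorder now returns a Function over SyntheticCreationalContext instead of a Supplier, because only the context gives access to the bean's injected references; the TypeLiteral preserves the generic type Instance&lt;ChatModelListener&gt; for that lookup. Conceptually, the set of listeners it collects is the same one a programmatic CDI lookup would yield, as in this hypothetical helper (illustration only, not code from the commit):

import java.util.List;
import java.util.stream.Collectors;

import jakarta.enterprise.inject.spi.CDI;

import dev.langchain4j.model.chat.listener.ChatModelListener;

public class ListenerLookup {

    // Roughly the same set of beans the recorder obtains through
    // context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL)
    static List<ChatModelListener> allListeners() {
        return CDI.current().select(ChatModelListener.class)
                .stream()
                .collect(Collectors.toList());
    }
}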
File 4 of 4: OllamaRecorder test (disabledChatModel)

@@ -30,7 +30,7 @@ void setupMocks() {
 
     @Test
     void disabledChatModel() {
-        assertThat(recorder.chatModel(config, fixedConfig, NamedConfigUtil.DEFAULT_NAME).get())
+        assertThat(recorder.chatModel(config, fixedConfig, NamedConfigUtil.DEFAULT_NAME).apply(null))
                 .isNotNull()
                 .isExactlyInstanceOf(DisabledChatLanguageModel.class);
     }
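Passing null for the SyntheticCreationalContext is safe in this test only because the disabled branch never dereferences the context; the enabled branch would use it to resolve the listener beans.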