Merge pull request #372 from andreadimaio/main
Enable ModerationModel for the BAM module
geoand authored Mar 15, 2024
2 parents e713beb + 1105e94 commit 0d4ee28
Showing 45 changed files with 1,815 additions and 419 deletions.
@@ -2,6 +2,7 @@

import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MODERATION_MODEL;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;

import java.util.List;
@@ -15,8 +16,10 @@
import io.quarkiverse.langchain4j.bam.runtime.config.LangChain4jBamConfig;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
@@ -28,7 +31,6 @@
public class BamProcessor {

private static final String FEATURE = "langchain4j-bam";

private static final String PROVIDER = "bam";

@BuildStep
@@ -39,6 +41,7 @@ FeatureBuildItem feature() {
@BuildStep
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
BuildProducer<ModerationModelProviderCandidateBuildItem> moderationProducer,
LangChain4jBamBuildConfig config) {

if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
@@ -48,6 +51,10 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
}

if (config.moderationModel().enabled().isEmpty() || config.moderationModel().enabled().get()) {
moderationProducer.produce(new ModerationModelProviderCandidateBuildItem(PROVIDER));
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@@ -56,6 +63,7 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
void generateBeans(BamRecorder recorder,
List<SelectedChatModelProviderBuildItem> selectedChatItem,
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
List<SelectedModerationModelProviderBuildItem> selectedModeration,
LangChain4jBamConfig config,
BuildProducer<SyntheticBeanBuildItem> beanProducer) {

@@ -97,6 +105,20 @@ void generateBeans(BamRecorder recorder,
beanProducer.produce(builder.done());
}
}

for (var selected : selectedModeration) {
if (PROVIDER.equals(selected.getProvider())) {
String modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(MODERATION_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.moderationModel(config, modelName));
addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}
}

private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) {
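The loop added above registers moderation beans the same way the existing chat and embedding loops do: for each selected BAM moderation provider it produces an @ApplicationScoped synthetic ModerationModel bean, attaching a qualifier via addQualifierIfNecessary when a named model is configured. A minimal sketch of how application code could consume the resulting default bean, assuming langchain4j's core dev.langchain4j.model.moderation API; the ContentChecker class below is illustrative and not part of this commit:

import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.output.Response;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

// Hypothetical consumer of the synthetic ModerationModel bean produced by
// the build step above; shown only to illustrate what the new wiring enables.
@ApplicationScoped
public class ContentChecker {

    @Inject
    ModerationModel moderationModel; // default (unnamed) BAM moderation model

    public boolean isFlagged(String text) {
        // langchain4j core API: moderate(String) returns a Response wrapping a Moderation
        Response<Moderation> response = moderationModel.moderate(text);
        return response.content().flagged();
    }
}

For a named model, the injection point would additionally need the qualifier that addQualifierIfNecessary attaches to the bean.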
@@ -13,9 +13,4 @@ public interface ChatModelBuildConfig {
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();

/**
* Embedding model related settings
*/
EmbeddingModelBuildConfig embeddingModel();
}
@@ -18,4 +18,9 @@ public interface LangChain4jBamBuildConfig {
* Embedding model related settings
*/
EmbeddingModelBuildConfig embeddingModel();

/**
* Moderation model related settings
*/
ModerationModelBuildConfig moderationModel();
}
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.bam.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface ModerationModelBuildConfig {

/**
* Whether the model should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
@@ -2,6 +2,7 @@

import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;

import java.util.List;

@@ -18,6 +19,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tomakehurst.wiremock.WireMockServer;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
@@ -65,15 +72,16 @@ interface NewAIService {
@Inject
NewAIService service;

@Inject
ChatLanguageModel chatModel;

@Inject
LangChain4jBamConfig langchain4jBamConfig;

@Test
void chat() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

var modelId = config.chatModel().modelId();

var parameters = Parameters.builder()
.decodingMethod(config.chatModel().decodingMethod())
.temperature(config.chatModel().temperature())
@@ -87,7 +95,8 @@ void chat() throws Exception {

var body = new TextGenerationRequest(modelId, messages, parameters);

mockServers.mockBuilder(200)
mockServers
.mockBuilder(WireMockUtil.URL_CHAT_API, 200)
.body(mapper.writeValueAsString(body))
.response("""
{
Expand All @@ -106,4 +115,100 @@ void chat() throws Exception {

assertEquals("AI Response", service.chat("Hello"));
}

@Test
void chat_test_generate_1() throws Exception {
var config = langchain4jBamConfig.defaultConfig();
var modelId = config.chatModel().modelId();
var parameters = Parameters.builder()
.decodingMethod(config.chatModel().decodingMethod())
.temperature(config.chatModel().temperature())
.minNewTokens(config.chatModel().minNewTokens())
.maxNewTokens(config.chatModel().maxNewTokens())
.build();

List<Message> messages = List.of(
new Message("user", "Hello"));

var body = new TextGenerationRequest(modelId, messages, parameters);

mockServers
.mockBuilder(WireMockUtil.URL_CHAT_API, 200)
.body(mapper.writeValueAsString(body))
.response("""
{
"results": [
{
"generated_token_count": 20,
"input_token_count": 146,
"stop_reason": "max_tokens",
"seed": 40268626,
"generated_text": "AI Response"
}
]
}
""")
.build();

assertEquals("AI Response", chatModel.generate("Hello"));
}

@Test
void chat_test_generate_2() throws Exception {
var config = langchain4jBamConfig.defaultConfig();
var modelId = config.chatModel().modelId();
var parameters = Parameters.builder()
.decodingMethod(config.chatModel().decodingMethod())
.temperature(config.chatModel().temperature())
.minNewTokens(config.chatModel().minNewTokens())
.maxNewTokens(config.chatModel().maxNewTokens())
.build();

List<Message> messages = List.of(
new Message("system", "This is a systemMessage"),
new Message("user", "This is a userMessage"),
new Message("assistant", "This is a assistantMessage"));

var body = new TextGenerationRequest(modelId, messages, parameters);

mockServers
.mockBuilder(WireMockUtil.URL_CHAT_API, 200)
.body(mapper.writeValueAsString(body))
.response("""
{
"results": [
{
"generated_token_count": 20,
"input_token_count": 146,
"stop_reason": "max_tokens",
"seed": 40268626,
"generated_text": "AI Response"
}
]
}
""")
.build();

var expected = Response.from(AiMessage.from("AI Response"), new TokenUsage(146, 20, 166), FinishReason.LENGTH);
assertEquals(expected, chatModel.generate(List.of(
new dev.langchain4j.data.message.SystemMessage("This is a systemMessage"),
new dev.langchain4j.data.message.UserMessage("This is a userMessage"),
new dev.langchain4j.data.message.AiMessage("This is a assistantMessage"))));
assertEquals(expected, chatModel.generate(
new dev.langchain4j.data.message.SystemMessage("This is a systemMessage"),
new dev.langchain4j.data.message.UserMessage("This is a userMessage"),
new dev.langchain4j.data.message.AiMessage("This is a assistantMessage")));
}

@Test
void chat_test_tool_specification() throws Exception {

assertThrowsExactly(
IllegalArgumentException.class,
() -> chatModel.generate(List.of(), ToolSpecification.builder().build()));

assertThrowsExactly(
IllegalArgumentException.class,
() -> chatModel.generate(List.of(), List.of(ToolSpecification.builder().build())));
}
}
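The tests above cover the chat model path; a comparable smoke test for the new moderation wiring could simply assert that a ModerationModel is injectable once the BAM provider is selected. The sketch below is illustrative and not part of this diff: it assumes the same test harness and configuration as the class above, and a full request-level test would also need a WireMock stub for BAM's moderation endpoint, whose payload shape is not shown in this excerpt.

import static org.junit.jupiter.api.Assertions.assertNotNull;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;

import dev.langchain4j.model.moderation.ModerationModel;

// Hypothetical test class name; mirrors the injection style used above.
class ModerationModelInjectionTest {

    @Inject
    ModerationModel moderationModel;

    @Test
    void moderation_model_is_injectable() {
        // The synthetic bean registered by BamProcessor should be resolvable.
        assertNotNull(moderationModel);
    }
}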