Merge pull request #774 from andreadimaio/main
Polish watsonx code
geoand authored Jul 24, 2024
2 parents 16200ce + f9c1122 commit b52c841
Showing 16 changed files with 523 additions and 518 deletions.
@@ -12,6 +12,7 @@
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

@@ -39,7 +40,7 @@ public class AiChatServiceTest {
LangChain4jWatsonxConfig langchain4jWatsonConfig;

@Inject
ChatLanguageModel model;
ChatLanguageModel chatModel;

static WireMockUtil mockServers;

@@ -70,6 +71,12 @@ static void afterAll() {
iamServer.stop();
}

@BeforeEach
void beforeEach() {
watsonxServer.resetAll();
iamServer.resetAll();
}
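(WireMockServer.resetAll() clears all registered stub mappings and the recorded request journal, so each test starts from a clean slate instead of inheriting stubs set up by earlier tests; the same @BeforeEach reset is added to the other test classes in this commit.)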

@RegisterAiService
@Singleton
interface NewAIService {
@@ -109,21 +116,7 @@ void chat() throws Exception {

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
.body(mapper.writeValueAsString(body))
.response("""
{
"model_id": "meta-llama/llama-2-70b-chat",
"created_at": "2024-01-21T17:06:14.052Z",
"results": [
{
"generated_text": "AI Response",
"generated_token_count": 5,
"input_token_count": 50,
"stop_reason": "eos_token",
"seed": 2123876088
}
]
}
""")
.response(WireMockUtil.RESPONSE_WATSONX_CHAT_API)
.build();

assertEquals("AI Response", service.chat("Hello"));
(next file)
@@ -13,6 +13,7 @@
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

@@ -70,6 +71,12 @@ static void afterAll() {
iamServer.stop();
}

@BeforeEach
void beforeEach() {
watsonxServer.resetAll();
iamServer.resetAll();
}

@Inject
EmbeddingModel embeddingModel;

@@ -180,22 +187,7 @@ private List<Float> mockEmbeddingServer(String input) throws Exception {

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
.body(mapper.writeValueAsString(request))
.response("""
{
"model_id": "%s",
"results": [
{
"embedding": [
-0.006929283,
-0.005336422,
-0.024047505
]
}
],
"created_at": "2024-02-21T17:32:28Z",
"input_token_count": 10
}
""".formatted(WireMockUtil.DEFAULT_EMBEDDING_MODEL))
.response(WireMockUtil.RESPONSE_WATSONX_EMBEDDING_API)
.build();

return List.of(-0.006929283f, -0.005336422f, -0.024047505f);
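Note: the same refactoring as in the chat test: the inline embedding payload moves into WireMockUtil. Based on the JSON removed above, the constant presumably keeps the %s placeholder for the model id and is filled via .formatted(...) where needed (as in the AllPropertiesTest change later in this diff), roughly:

    // Hypothetical reconstruction, inferred from the inline JSON removed above.
    // The %s placeholder is expected to be filled with a model id, e.g.
    // WireMockUtil.RESPONSE_WATSONX_EMBEDDING_API.formatted(modelId).
    public static final String RESPONSE_WATSONX_EMBEDDING_API = """
            {
                "model_id": "%s",
                "results": [
                    {
                        "embedding": [
                            -0.006929283,
                            -0.005336422,
                            -0.024047505
                        ]
                    }
                ],
                "created_at": "2024-02-21T17:32:28Z",
                "input_token_count": 10
            }
            """;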
(next file)
@@ -4,7 +4,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.time.Duration;
import java.util.Date;
@@ -18,21 +18,25 @@
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tomakehurst.wiremock.WireMockServer;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
import io.quarkiverse.langchain4j.watsonx.bean.Parameters;
import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig;
import io.quarkus.test.QuarkusUnitTest;
@@ -55,6 +59,9 @@ public class AllPropertiesTest {
@Inject
EmbeddingModel embeddingModel;

@Inject
TokenCountEstimator tokenCountEstimator;

static WireMockUtil mockServers;

@RegisterExtension
@@ -84,6 +91,7 @@ public class AllPropertiesTest {
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.truncate-input-tokens", "0")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-joiner", "@")
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model")
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@BeforeAll
@@ -99,30 +107,36 @@ static void beforeAll() {
mockServers = new WireMockUtil(watsonxServer, iamServer);
}

@BeforeEach
void beforeEach() {
watsonxServer.resetAll();
iamServer.resetAll();
mockServers.mockIAMBuilder(200)
.grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType())
.response(WireMockUtil.BEARER_TOKEN, new Date())
.build();
}

@AfterAll
static void afterAll() {
watsonxServer.stop();
iamServer.stop();
}

static Parameters parameters;

static {
parameters = Parameters.builder()
.minNewTokens(10)
.maxNewTokens(200)
.decodingMethod("greedy")
.lengthPenalty(new LengthPenalty(1.1, 0))
.randomSeed(2)
.stopSequences(List.of("\n", "\n\n"))
.temperature(1.5)
.topK(90)
.topP(0.5)
.repetitionPenalty(2.0)
.truncateInputTokens(0)
.includeStopSequence(false)
.build();
}
static Parameters parameters = Parameters.builder()
.minNewTokens(10)
.maxNewTokens(200)
.decodingMethod("greedy")
.lengthPenalty(new LengthPenalty(1.1, 0))
.randomSeed(2)
.stopSequences(List.of("\n", "\n\n"))
.temperature(1.5)
.topK(90)
.topP(0.5)
.repetitionPenalty(2.0)
.truncateInputTokens(0)
.includeStopSequence(false)
.build();

@Test
void check_config() throws Exception {
@@ -152,140 +166,82 @@ void check_config() throws Exception {
assertEquals(0, config.chatModel().truncateInputTokens().get());
assertEquals(false, config.chatModel().includeStopSequence().get());
assertEquals("@", config.chatModel().promptJoiner().get());
assertEquals("my_super_embedding_model", config.embeddingModel().modelId());
}

@Test
void check_chat_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.chatModel().modelId();
String projectId = config.projectId();
var parameters = Parameters.builder()
.minNewTokens(10)
.maxNewTokens(200)
.decodingMethod("greedy")
.lengthPenalty(new LengthPenalty(1.1, 0))
.randomSeed(2)
.stopSequences(List.of("\n", "\n\n"))
.temperature(1.5)
.topK(90)
.topP(0.5)
.repetitionPenalty(2.0)
.truncateInputTokens(0)
.includeStopSequence(false)
.build();

TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters);

mockServers.mockIAMBuilder(200)
.grantType(config.iam().grantType())
.response(WireMockUtil.BEARER_TOKEN, new Date())
.build();

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200, "aaaa-mm-dd")
.body(mapper.writeValueAsString(body))
.response("""
{
"model_id": "meta-llama/llama-2-70b-chat",
"created_at": "2024-01-21T17:06:14.052Z",
"results": [
{
"generated_text": "Response!",
"generated_token_count": 5,
"input_token_count": 50,
"stop_reason": "eos_token",
"seed": 2123876088
}
]
}
""")
.response(WireMockUtil.RESPONSE_WATSONX_CHAT_API)
.build();

assertEquals("Response!", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"),
assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"),
dev.langchain4j.data.message.UserMessage.from("UserMessage")).content().text());
}

@Test
void check_chat_streaming_model_config() throws Exception {
void check_embedding_model() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.chatModel().modelId();
String modelId = config.embeddingModel().modelId();
String projectId = config.projectId();
var parameters = Parameters.builder()
.minNewTokens(10)
.maxNewTokens(200)
.decodingMethod("greedy")
.lengthPenalty(new LengthPenalty(1.1, 0))
.randomSeed(2)
.stopSequences(List.of("\n", "\n\n"))
.temperature(1.5)
.topK(90)
.topP(0.5)
.repetitionPenalty(2.0)
.truncateInputTokens(0)
.includeStopSequence(false)
.build();

TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters);
EmbeddingRequest request = new EmbeddingRequest(modelId, projectId,
List.of("Embedding THIS!"));

mockServers.mockIAMBuilder(200)
.grantType(config.iam().grantType())
.response(WireMockUtil.BEARER_TOKEN, new Date())
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200, "aaaa-mm-dd")
.body(mapper.writeValueAsString(request))
.response(WireMockUtil.RESPONSE_WATSONX_EMBEDDING_API.formatted(modelId))
.build();

String eventStreamResponse = """
id: 1
event: message
data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.162Z","results":[{"generated_text":"","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]}
Response<Embedding> response = embeddingModel.embed("Embedding THIS!");
assertNotNull(response);
assertNotNull(response.content());
}

id: 2
event: message
data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.203Z","results":[{"generated_text":". ","generated_token_count":2,"input_token_count":0,"stop_reason":"not_finished"}]}
@Test
void check_token_count_estimator() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.chatModel().modelId();
String projectId = config.projectId();

id: 3
event: message
data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.223Z","results":[{"generated_text":"I'","generated_token_count":3,"input_token_count":0,"stop_reason":"not_finished"}]}
var body = new TokenizationRequest(modelId, "test", projectId);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200, "aaaa-mm-dd")
.body(mapper.writeValueAsString(body))
.response(WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API.formatted(modelId))
.build();

id: 4
event: message
data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.243Z","results":[{"generated_text":"m ","generated_token_count":4,"input_token_count":0,"stop_reason":"not_finished"}]}
assertEquals(11, tokenCountEstimator.estimateTokenCount("test"));
}
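Note: RESPONSE_WATSONX_TOKENIZER_API is not shown anywhere in this diff. The only hard requirement this test imposes is that the mocked tokenization response report a token count of 11; assuming it mirrors the watsonx.ai tokenization response shape, a plausible sketch is:

    // Hypothetical sketch only; the real constant lives in WireMockUtil,
    // outside this diff. The assertion above merely requires token_count
    // to be 11 after the model id is filled in via .formatted(modelId).
    public static final String RESPONSE_WATSONX_TOKENIZER_API = """
            {
                "model_id": "%s",
                "result": {
                    "token_count": 11
                }
            }
            """;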

id: 5
event: message
data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.262Z","results":[{"generated_text":"a beginner","generated_token_count":5,"input_token_count":0,"stop_reason":"max_tokens"}]}
@Test
void check_chat_streaming_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.chatModel().modelId();
String projectId = config.projectId();

id: 5
event: close
data: {}}
""";
TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters);

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200, "aaaa-mm-dd")
.body(mapper.writeValueAsString(body))
.responseMediaType(MediaType.SERVER_SENT_EVENTS)
.response(eventStreamResponse)
.response(WireMockUtil.RESPONSE_WATSONX_STREAMING_API)
.build();

var messages = List.of(
dev.langchain4j.data.message.SystemMessage.from("SystemMessage"),
dev.langchain4j.data.message.UserMessage.from("UserMessage"));

var streamingResponse = new AtomicReference<AiMessage>();
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
}

@Override
public void onError(Throwable error) {
fail("Streaming failed: %s".formatted(error.getMessage()), error);
}

@Override
public void onComplete(Response<AiMessage> response) {
streamingResponse.set(response.content());
}
});

await()
.atMost(Duration.ofMinutes(1))
streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse));

await().atMost(Duration.ofMinutes(1))
.pollInterval(Duration.ofSeconds(2))
.until(() -> streamingResponse.get() != null);
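Note: the anonymous StreamingResponseHandler previously defined inline in this test (removed above) is now obtained from WireMockUtil.streamingResponseHandler(...), which is consistent with the Assertions.fail import being dropped at the top of this file. The helper's definition is outside this diff; given the handler it replaces, it presumably amounts to the same logic factored into a reusable method, roughly:

    // Hypothetical reconstruction of the shared helper, modeled on the
    // anonymous handler it replaces: ignore intermediate tokens, fail the
    // test on error, and capture the final AiMessage once the stream
    // completes. (Assumes a static import of
    // org.junit.jupiter.api.Assertions.fail and the langchain4j types
    // already imported by this test.)
    public static StreamingResponseHandler<AiMessage> streamingResponseHandler(AtomicReference<AiMessage> reference) {
        return new StreamingResponseHandler<>() {
            @Override
            public void onNext(String token) {
                // no-op: these tests only inspect the completed response
            }

            @Override
            public void onError(Throwable error) {
                fail("Streaming failed: %s".formatted(error.getMessage()), error);
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                reference.set(response.content());
            }
        };
    }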
