-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #844 from dastrobu/streamingChatLanguageModelSupplier
Add `streamingChatLanguageModelSupplier` property to `@RegisterAiService`
- Loading branch information
Showing
10 changed files
with
338 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
.../src/test/java/io/quarkiverse/langchain4j/test/BlockingChatLanguageModelSupplierTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package io.quarkiverse.langchain4j.test; | ||
|
||
import static io.restassured.RestAssured.get; | ||
import static org.hamcrest.Matchers.equalTo; | ||
|
||
import java.util.function.Supplier; | ||
|
||
import jakarta.ws.rs.GET; | ||
import jakarta.ws.rs.Path; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
public class BlockingChatLanguageModelSupplierTest { | ||
@Path("/test") | ||
static class MyResource { | ||
|
||
private final MyService service; | ||
|
||
MyResource(MyService service) { | ||
this.service = service; | ||
} | ||
|
||
@GET | ||
public String blocking() { | ||
return service.chat("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?"); | ||
} | ||
} | ||
|
||
public static class MyModelSupplier implements Supplier<ChatLanguageModel> { | ||
@Override | ||
public ChatLanguageModel get() { | ||
return (messages) -> new Response<>(new AiMessage("42")); | ||
} | ||
} | ||
|
||
@RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) | ||
interface MyService { | ||
String chat(String msg); | ||
} | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(MyResource.class, MyService.class, MyModelSupplier.class)); | ||
|
||
@Test | ||
public void testCall() { | ||
get("test") | ||
.then() | ||
.statusCode(200) | ||
.body(equalTo("42")); | ||
} | ||
} |
88 changes: 88 additions & 0 deletions
88
...va/io/quarkiverse/langchain4j/test/StreamingAndBlockingChatLanguageModelSupplierTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package io.quarkiverse.langchain4j.test; | ||
|
||
import static io.restassured.RestAssured.get; | ||
import static org.hamcrest.Matchers.equalTo; | ||
|
||
import java.util.function.Supplier; | ||
|
||
import jakarta.ws.rs.GET; | ||
import jakarta.ws.rs.Path; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
import io.smallrye.mutiny.Multi; | ||
|
||
public class StreamingAndBlockingChatLanguageModelSupplierTest { | ||
@Path("/test") | ||
static class MyResource { | ||
|
||
private final MyService service; | ||
|
||
MyResource(MyService service) { | ||
this.service = service; | ||
} | ||
|
||
@GET | ||
@Path("/blocking") | ||
public String blocking() { | ||
return service.blocking("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?"); | ||
} | ||
|
||
@GET | ||
@Path("/streaming") | ||
public Multi<String> streaming() { | ||
return service.streaming("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?"); | ||
} | ||
} | ||
|
||
public static class MyModelSupplier implements Supplier<ChatLanguageModel> { | ||
@Override | ||
public ChatLanguageModel get() { | ||
return (messages) -> new Response<>(new AiMessage("42")); | ||
} | ||
} | ||
|
||
public static class MyStreamingModelSupplier implements Supplier<StreamingChatLanguageModel> { | ||
@Override | ||
public StreamingChatLanguageModel get() { | ||
return (messages, handler) -> { | ||
handler.onNext("4"); | ||
handler.onNext("2"); | ||
handler.onComplete(new Response<>(new AiMessage(""))); | ||
}; | ||
} | ||
} | ||
|
||
@RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, streamingChatLanguageModelSupplier = MyStreamingModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) | ||
interface MyService { | ||
Multi<String> streaming(String msg); | ||
|
||
String blocking(String msg); | ||
} | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(MyResource.class, MyService.class, MyModelSupplier.class, MyStreamingModelSupplier.class)); | ||
|
||
@Test | ||
public void testCalls() { | ||
get("test/blocking") | ||
.then() | ||
.statusCode(200) | ||
.body(equalTo("42")); | ||
get("test/streaming") | ||
.then() | ||
.statusCode(200) | ||
.body(equalTo("42")); | ||
} | ||
} |
Oops, something went wrong.