From 365895993af7239ee791f00bde8fdd214480104c Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Fri, 8 Dec 2023 10:15:36 +0200 Subject: [PATCH] Use @SseEventFilter instead of deserialization hack --- .../langchain4j/openai/OpenAiRestApi.java | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index 42005d21c..e4860d2ff 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -10,6 +10,7 @@ import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -19,7 +20,6 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.Priorities; -import jakarta.ws.rs.ProcessingException; import jakarta.ws.rs.Produces; import jakarta.ws.rs.RuntimeType; import jakarta.ws.rs.WebApplicationException; @@ -38,12 +38,13 @@ import org.jboss.logging.Logger; import org.jboss.resteasy.reactive.RestQuery; import org.jboss.resteasy.reactive.RestStreamElementType; +import org.jboss.resteasy.reactive.client.SseEvent; +import org.jboss.resteasy.reactive.client.SseEventFilter; import org.jboss.resteasy.reactive.client.api.ClientLogger; import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; -import com.fasterxml.jackson.databind.exc.MismatchedInputException; import dev.ai4j.openai4j.OpenAiHttpException; import dev.ai4j.openai4j.chat.ChatCompletionRequest; @@ -99,6 +100,7 @@ CompletionResponse blockingCompletion(CompletionRequest request, @NotBody String @Path("chat/completions") @POST @RestStreamElementType(MediaType.APPLICATION_JSON) + @SseEventFilter(DoneFilter.class) Multi streamingCompletion(CompletionRequest request, @NotBody String apiKey, @RestQuery("api-version") String apiVersion); @@ -124,6 +126,7 @@ ChatCompletionResponse blockingChatCompletion(ChatCompletionRequest request, @No @Path("chat/completions") @POST @RestStreamElementType(MediaType.APPLICATION_JSON) + @SseEventFilter(DoneFilter.class) Multi streamingChatCompletion(ChatCompletionRequest request, @NotBody String apiKey, @RestQuery("api-version") String apiVersion); @@ -167,6 +170,17 @@ static RuntimeException toException(Response response) { return null; } + /** + * Ensures that the terminal event sent by OpenAI is not processed (as it is not a valid json event) + */ + class DoneFilter implements Predicate> { + + @Override + public boolean test(SseEvent event) { + return !"[DONE]".equals(event.data()); + } + } + /** * We need a custom version of the Jackson provider because reading SSE values does not work properly with * {@code @ClientObjectMapper} due to the lack of a complete context in those requests @@ -205,11 +219,7 @@ public static class ObjectMapperHolder { } /** - * This method does two things: - *

- * First, it returns {@code null} instead of throwing an exception when last streaming API result comes back and. - * This result is a "[DONE]" message, so it cannot map onto the domain. - * Second, it validates that the response is not empty, which happens when the API returns an error object + * This method validates that the response is not empty, which happens when the API returns an error object */ @Provider @ConstrainedTo(RuntimeType.CLIENT) @@ -217,23 +227,7 @@ class OpenAiRestApiReaderInterceptor implements ReaderInterceptor { @Override public Object aroundReadFrom(ReaderInterceptorContext context) throws IOException, WebApplicationException { - try { - return validateResponse(context.proceed()); - } catch (ProcessingException e) { - Throwable cause = e.getCause(); - if (cause instanceof MismatchedInputException) { - Class targetType = ((MismatchedInputException) cause).getTargetType(); - if (ChatCompletionResponse.Builder.class.equals(targetType) - || CompletionResponse.Builder.class.equals(targetType)) { - if (cause.getMessage().contains("DONE") || cause.getMessage() - .contains("JsonToken.START_ARRAY")) { - return null; - } - } - } - - throw e; - } + return validateResponse(context.proceed()); } /**