From 8b4789d5f079162579c7f5c1d5ce90bed6a61f07 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Sat, 16 Dec 2023 13:42:18 +0200 Subject: [PATCH] Limit scope of OpenAiRestApi providers The previous way they were declared meant that they would affect other REST Clients too. Fixes: #164 --- .../langchain4j/openai/OpenAiRestApi.java | 67 ++++++++++--------- .../openai/test/JsonParsingTest.java | 2 +- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index ffdad9b53..1d21f9cfc 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -15,26 +15,24 @@ import java.util.regex.Pattern; import jakarta.annotation.Priority; -import jakarta.ws.rs.ConstrainedTo; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.Priorities; import jakarta.ws.rs.Produces; -import jakarta.ws.rs.RuntimeType; import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MultivaluedMap; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.ext.MessageBodyWriter; -import jakarta.ws.rs.ext.Provider; import jakarta.ws.rs.ext.ReaderInterceptor; import jakarta.ws.rs.ext.ReaderInterceptorContext; import jakarta.ws.rs.ext.WriterInterceptor; import jakarta.ws.rs.ext.WriterInterceptorContext; import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.jboss.logging.Logger; import org.jboss.resteasy.reactive.RestQuery; import org.jboss.resteasy.reactive.RestStreamElementType; @@ -76,6 +74,10 @@ @ClientHeaderParam(name = "api-key", value = "{apiKey}") // used by AzureAI @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) +@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonReader.class) +@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonWriter.class) +@RegisterProvider(OpenAiRestApi.OpenAiRestApiReaderInterceptor.class) +@RegisterProvider(OpenAiRestApi.OpenAiRestApiWriterInterceptor.class) public interface OpenAiRestApi { /** @@ -181,57 +183,58 @@ public boolean test(SseEvent event) { } } - /** - * We need a custom version of the Jackson provider because reading SSE values does not work properly with - * {@code @ClientObjectMapper} due to the lack of a complete context in those requests - */ - @Provider - @ConstrainedTo(RuntimeType.CLIENT) - @Priority(Priorities.USER + 100) - class OpenAiRestApiJacksonProvider extends AbstractJsonMessageBodyReader implements MessageBodyWriter { + @Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one + class OpenAiRestApiJacksonWriter implements MessageBodyWriter { - /** - * Normally this is not necessary, but if one uses the 'demo' key, then the response comes back as type text/html - * but the content is still JSON. Go figure... - */ @Override - public boolean isReadable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { + public boolean isWriteable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { return true; } @Override - public Object readFrom(Class type, Type genericType, Annotation[] annotations, MediaType mediaType, - MultivaluedMap httpHeaders, InputStream entityStream) + public void writeTo(Object o, Class type, Type genericType, Annotation[] annotations, MediaType mediaType, + MultivaluedMap httpHeaders, OutputStream entityStream) throws IOException, WebApplicationException { - return ObjectMapperHolder.READER - .forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type)) - .readValue(entityStream); + entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8)); } + } + + @Priority(Priorities.USER - 100) // this priority ensures that our Reader has priority over the standard Jackson one + class OpenAiRestApiJacksonReader extends AbstractJsonMessageBodyReader { + /** + * Normally this is not necessary, but if one uses the 'demo' Langchain4j key, then the response comes back as type + * text/html + * but the content is still JSON. + */ @Override - public boolean isWriteable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { + public boolean isReadable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { return true; } + /** + * We need a custom version of the Jackson provider because reading SSE values does not work properly with + * {@code @ClientObjectMapper} due to the lack of a complete context in those requests + */ @Override - public void writeTo(Object o, Class type, Type genericType, Annotation[] annotations, MediaType mediaType, - MultivaluedMap httpHeaders, OutputStream entityStream) + public Object readFrom(Class type, Type genericType, Annotation[] annotations, MediaType mediaType, + MultivaluedMap httpHeaders, InputStream entityStream) throws IOException, WebApplicationException { - entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8)); + return ObjectMapperHolder.READER + .forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type)) + .readValue(entityStream); } + } - public static class ObjectMapperHolder { - public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER; + public class ObjectMapperHolder { + public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER; - private static final ObjectReader READER = MAPPER.reader(); - } + private static final ObjectReader READER = MAPPER.reader(); } /** * This method validates that the response is not empty, which happens when the API returns an error object */ - @Provider - @ConstrainedTo(RuntimeType.CLIENT) class OpenAiRestApiReaderInterceptor implements ReaderInterceptor { @Override @@ -272,8 +275,6 @@ private Object validateResponse(Object result) { * The point of this is to properly set the {@code stream} value of the request * so users don't have to remember to set it manually */ - @Provider - @ConstrainedTo(RuntimeType.CLIENT) class OpenAiRestApiWriterInterceptor implements WriterInterceptor { @Override public void aroundWriteTo(WriterInterceptorContext context) throws IOException, WebApplicationException { diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/JsonParsingTest.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/JsonParsingTest.java index 7b5c723c3..df577f3dc 100644 --- a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/JsonParsingTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/JsonParsingTest.java @@ -22,7 +22,7 @@ public class JsonParsingTest { @Test void testChatCompletion() throws JsonProcessingException { - ObjectMapper mapperToUse = OpenAiRestApi.OpenAiRestApiJacksonProvider.ObjectMapperHolder.MAPPER; + ObjectMapper mapperToUse = OpenAiRestApi.ObjectMapperHolder.MAPPER; ChatCompletionResponse chatCompletionResponse = mapperToUse.readValue( "{\"id\":\"chatcmpl-8AAeH0Sdve2wfHWhFIXq1gFkkqoIU\",\"object\":\"chat.completion.chunk\",\"created\":1697434905,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"length\"}]}",