Skip to content

Commit

Permalink
Limit scope of OpenAiRestApi providers
Browse files Browse the repository at this point in the history
The previous way they were declared meant
that they would affect other REST Clients
too.

Fixes: #164
  • Loading branch information
geoand committed Dec 20, 2023
1 parent 289ddd8 commit 8b4789d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,24 @@
import java.util.regex.Pattern;

import jakarta.annotation.Priority;
import jakarta.ws.rs.ConstrainedTo;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.RuntimeType;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.ext.MessageBodyWriter;
import jakarta.ws.rs.ext.Provider;
import jakarta.ws.rs.ext.ReaderInterceptor;
import jakarta.ws.rs.ext.ReaderInterceptorContext;
import jakarta.ws.rs.ext.WriterInterceptor;
import jakarta.ws.rs.ext.WriterInterceptorContext;

import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestQuery;
import org.jboss.resteasy.reactive.RestStreamElementType;
Expand Down Expand Up @@ -76,6 +74,10 @@
@ClientHeaderParam(name = "api-key", value = "{apiKey}") // used by AzureAI
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonReader.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonWriter.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiReaderInterceptor.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiWriterInterceptor.class)
public interface OpenAiRestApi {

/**
Expand Down Expand Up @@ -181,57 +183,58 @@ public boolean test(SseEvent<String> event) {
}
}

/**
* We need a custom version of the Jackson provider because reading SSE values does not work properly with
* {@code @ClientObjectMapper} due to the lack of a complete context in those requests
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
@Priority(Priorities.USER + 100)
class OpenAiRestApiJacksonProvider extends AbstractJsonMessageBodyReader implements MessageBodyWriter<Object> {
@Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one
class OpenAiRestApiJacksonWriter implements MessageBodyWriter<Object> {

/**
* Normally this is not necessary, but if one uses the 'demo' key, then the response comes back as type text/html
* but the content is still JSON. Go figure...
*/
@Override
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
public boolean isWriteable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
return true;
}

@Override
public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
public void writeTo(Object o, Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, Object> httpHeaders, OutputStream entityStream)
throws IOException, WebApplicationException {
return ObjectMapperHolder.READER
.forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type))
.readValue(entityStream);
entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8));
}
}

@Priority(Priorities.USER - 100) // this priority ensures that our Reader has priority over the standard Jackson one
class OpenAiRestApiJacksonReader extends AbstractJsonMessageBodyReader {

/**
* Normally this is not necessary, but if one uses the 'demo' Langchain4j key, then the response comes back as type
* text/html
* but the content is still JSON.
*/
@Override
public boolean isWriteable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
return true;
}

/**
* We need a custom version of the Jackson provider because reading SSE values does not work properly with
* {@code @ClientObjectMapper} due to the lack of a complete context in those requests
*/
@Override
public void writeTo(Object o, Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, Object> httpHeaders, OutputStream entityStream)
public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
throws IOException, WebApplicationException {
entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8));
return ObjectMapperHolder.READER
.forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type))
.readValue(entityStream);
}
}

public static class ObjectMapperHolder {
public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
public class ObjectMapperHolder {
public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;

private static final ObjectReader READER = MAPPER.reader();
}
private static final ObjectReader READER = MAPPER.reader();
}

/**
* This method validates that the response is not empty, which happens when the API returns an error object
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
class OpenAiRestApiReaderInterceptor implements ReaderInterceptor {

@Override
Expand Down Expand Up @@ -272,8 +275,6 @@ private Object validateResponse(Object result) {
* The point of this is to properly set the {@code stream} value of the request
* so users don't have to remember to set it manually
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
class OpenAiRestApiWriterInterceptor implements WriterInterceptor {
@Override
public void aroundWriteTo(WriterInterceptorContext context) throws IOException, WebApplicationException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class JsonParsingTest {

@Test
void testChatCompletion() throws JsonProcessingException {
ObjectMapper mapperToUse = OpenAiRestApi.OpenAiRestApiJacksonProvider.ObjectMapperHolder.MAPPER;
ObjectMapper mapperToUse = OpenAiRestApi.ObjectMapperHolder.MAPPER;

ChatCompletionResponse chatCompletionResponse = mapperToUse.readValue(
"{\"id\":\"chatcmpl-8AAeH0Sdve2wfHWhFIXq1gFkkqoIU\",\"object\":\"chat.completion.chunk\",\"created\":1697434905,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"length\"}]}",
Expand Down

0 comments on commit 8b4789d

Please sign in to comment.