Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit scope of OpenAiRestApi providers #165

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,24 @@
import java.util.regex.Pattern;

import jakarta.annotation.Priority;
import jakarta.ws.rs.ConstrainedTo;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.RuntimeType;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.ext.MessageBodyWriter;
import jakarta.ws.rs.ext.Provider;
import jakarta.ws.rs.ext.ReaderInterceptor;
import jakarta.ws.rs.ext.ReaderInterceptorContext;
import jakarta.ws.rs.ext.WriterInterceptor;
import jakarta.ws.rs.ext.WriterInterceptorContext;

import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestQuery;
import org.jboss.resteasy.reactive.RestStreamElementType;
Expand Down Expand Up @@ -76,6 +74,10 @@
@ClientHeaderParam(name = "api-key", value = "{apiKey}") // used by AzureAI
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonReader.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiJacksonWriter.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiReaderInterceptor.class)
@RegisterProvider(OpenAiRestApi.OpenAiRestApiWriterInterceptor.class)
public interface OpenAiRestApi {

/**
Expand Down Expand Up @@ -181,57 +183,58 @@ public boolean test(SseEvent<String> event) {
}
}

/**
* We need a custom version of the Jackson provider because reading SSE values does not work properly with
* {@code @ClientObjectMapper} due to the lack of a complete context in those requests
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
@Priority(Priorities.USER + 100)
class OpenAiRestApiJacksonProvider extends AbstractJsonMessageBodyReader implements MessageBodyWriter<Object> {
@Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one
class OpenAiRestApiJacksonWriter implements MessageBodyWriter<Object> {

/**
* Normally this is not necessary, but if one uses the 'demo' key, then the response comes back as type text/html
* but the content is still JSON. Go figure...
*/
@Override
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
public boolean isWriteable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
return true;
}

@Override
public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
public void writeTo(Object o, Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, Object> httpHeaders, OutputStream entityStream)
throws IOException, WebApplicationException {
return ObjectMapperHolder.READER
.forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type))
.readValue(entityStream);
entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8));
}
}

@Priority(Priorities.USER - 100) // this priority ensures that our Reader has priority over the standard Jackson one
class OpenAiRestApiJacksonReader extends AbstractJsonMessageBodyReader {

/**
* Normally this is not necessary, but if one uses the 'demo' Langchain4j key, then the response comes back as type
* text/html
* but the content is still JSON.
*/
@Override
public boolean isWriteable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
return true;
}

/**
* We need a custom version of the Jackson provider because reading SSE values does not work properly with
* {@code @ClientObjectMapper} due to the lack of a complete context in those requests
*/
@Override
public void writeTo(Object o, Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, Object> httpHeaders, OutputStream entityStream)
public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations, MediaType mediaType,
MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
throws IOException, WebApplicationException {
entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8));
return ObjectMapperHolder.READER
.forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type))
.readValue(entityStream);
}
}

public static class ObjectMapperHolder {
public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
public class ObjectMapperHolder {
public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;

private static final ObjectReader READER = MAPPER.reader();
}
private static final ObjectReader READER = MAPPER.reader();
}

/**
* This method validates that the response is not empty, which happens when the API returns an error object
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
class OpenAiRestApiReaderInterceptor implements ReaderInterceptor {

@Override
Expand Down Expand Up @@ -272,8 +275,6 @@ private Object validateResponse(Object result) {
* The point of this is to properly set the {@code stream} value of the request
* so users don't have to remember to set it manually
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
class OpenAiRestApiWriterInterceptor implements WriterInterceptor {
@Override
public void aroundWriteTo(WriterInterceptorContext context) throws IOException, WebApplicationException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class JsonParsingTest {

@Test
void testChatCompletion() throws JsonProcessingException {
ObjectMapper mapperToUse = OpenAiRestApi.OpenAiRestApiJacksonProvider.ObjectMapperHolder.MAPPER;
ObjectMapper mapperToUse = OpenAiRestApi.ObjectMapperHolder.MAPPER;

ChatCompletionResponse chatCompletionResponse = mapperToUse.readValue(
"{\"id\":\"chatcmpl-8AAeH0Sdve2wfHWhFIXq1gFkkqoIU\",\"object\":\"chat.completion.chunk\",\"created\":1697434905,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"length\"}]}",
Expand Down
Loading