Skip to content

Commit

Permalink
Merge pull request #109 from quarkiverse/proper-sse-event-streaming
Browse files Browse the repository at this point in the history
Use @SseEventFilter instead of deserialization hack
  • Loading branch information
geoand authored Dec 8, 2023
2 parents 685203f + 3658959 commit 5eb489b
Showing 1 changed file with 18 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -19,7 +20,6 @@
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.ProcessingException;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.RuntimeType;
import jakarta.ws.rs.WebApplicationException;
Expand All @@ -38,12 +38,13 @@
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestQuery;
import org.jboss.resteasy.reactive.RestStreamElementType;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.exc.MismatchedInputException;

import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
Expand Down Expand Up @@ -99,6 +100,7 @@ CompletionResponse blockingCompletion(CompletionRequest request, @NotBody String
// Streaming completion call: POSTs the CompletionRequest to "chat/completions" and emits every received
// stream element as a CompletionResponse. Stream elements are declared as JSON, and events rejected by
// DoneFilter (the "[DONE]" terminal marker) are filtered out.
// NOTE(review): this is the same path that streamingChatCompletion uses; OpenAI documents the non-chat
// completions endpoint separately — confirm that "chat/completions" is really intended here.
// NOTE(review): apiKey is marked @NotBody, so presumably it is not serialized into the request body —
// verify how it is actually transmitted.
@Path("chat/completions")
@POST
@RestStreamElementType(MediaType.APPLICATION_JSON)
@SseEventFilter(DoneFilter.class)
Multi<CompletionResponse> streamingCompletion(CompletionRequest request, @NotBody String apiKey,
// apiVersion is bound to the "api-version" query parameter
@RestQuery("api-version") String apiVersion);

Expand All @@ -124,6 +126,7 @@ ChatCompletionResponse blockingChatCompletion(ChatCompletionRequest request, @No
// Streaming chat completion call: POSTs the ChatCompletionRequest to "chat/completions" and emits every
// received stream element as a ChatCompletionResponse. Stream elements are declared as JSON, and events
// rejected by DoneFilter (the "[DONE]" terminal marker) are filtered out.
// NOTE(review): apiKey is marked @NotBody, so presumably it is not serialized into the request body —
// verify how it is actually transmitted.
@Path("chat/completions")
@POST
@RestStreamElementType(MediaType.APPLICATION_JSON)
@SseEventFilter(DoneFilter.class)
Multi<ChatCompletionResponse> streamingChatCompletion(ChatCompletionRequest request, @NotBody String apiKey,
// apiVersion is bound to the "api-version" query parameter
@RestQuery("api-version") String apiVersion);

Expand Down Expand Up @@ -167,6 +170,17 @@ static RuntimeException toException(Response response) {
return null;
}

/**
 * {@link SseEventFilter} predicate that rejects the terminal {@code [DONE]} event sent by OpenAI.
 * That event is not a valid JSON event, so it must not be processed like the others.
 */
class DoneFilter implements Predicate<SseEvent<String>> {

    /** Exact data payload of the terminal event. */
    private static final String TERMINAL_MARKER = "[DONE]";

    /**
     * @param event the incoming server-sent event
     * @return {@code false} when the event data is exactly {@code [DONE]}, {@code true} otherwise
     */
    @Override
    public boolean test(SseEvent<String> event) {
        final String payload = event.data();
        if (TERMINAL_MARKER.equals(payload)) {
            return false;
        }
        return true;
    }
}

/**
* We need a custom version of the Jackson provider because reading SSE values does not work properly with
* {@code @ClientObjectMapper} due to the lack of a complete context in those requests
Expand Down Expand Up @@ -205,35 +219,15 @@ public static class ObjectMapperHolder {
}

/**
* This method does two things:
* <p>
* First, it returns {@code null} instead of throwing an exception when the last streaming API result comes back.
* This result is a "[DONE]" message, so it cannot map onto the domain.
* Second, it validates that the response is not empty, which happens when the API returns an error object
* This method validates that the response is not empty, which happens when the API returns an error object
*/
@Provider
@ConstrainedTo(RuntimeType.CLIENT)
class OpenAiRestApiReaderInterceptor implements ReaderInterceptor {

// Reads the entity by delegating to context.proceed() and passing the result through validateResponse.
// NOTE(review): this listing is a rendered diff without +/- markers. The try/catch block and the single
// trailing return look like the removed and the added version of the same body respectively — confirm
// against the real file before reading them as one method.
@Override
public Object aroundReadFrom(ReaderInterceptorContext context) throws IOException, WebApplicationException {
try {
return validateResponse(context.proceed());
} catch (ProcessingException e) {
// Only a Jackson MismatchedInputException wrapped by the ProcessingException is a candidate for being swallowed
Throwable cause = e.getCause();
if (cause instanceof MismatchedInputException) {
Class<?> targetType = ((MismatchedInputException) cause).getTargetType();
// Restrict the special case to the two response builder types
if (ChatCompletionResponse.Builder.class.equals(targetType)
|| CompletionResponse.Builder.class.equals(targetType)) {
// A failure message mentioning "DONE" or "JsonToken.START_ARRAY" yields null instead of an exception
if (cause.getMessage().contains("DONE") || cause.getMessage()
.contains("JsonToken.START_ARRAY")) {
return null;
}
}
}

// Every other failure is rethrown unchanged
throw e;
}
return validateResponse(context.proceed());
}

/**
Expand Down

0 comments on commit 5eb489b

Please sign in to comment.