Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Fix loss of context in the inference API for streaming APIs #118999

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/118999.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 118999
summary: Fix loss of context in the inference API for streaming APIs
area: Machine Learning
type: bug
issues:
- 119000
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -341,31 +342,44 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
return inferInternal(endpoint, input, null, Map.of());
}

protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
protected Deque<ServerSentEvent> streamInferOnMockService(
String modelId,
TaskType taskType,
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
return callAsync(endpoint, input);
return callAsync(endpoint, input, responseConsumerCallback);
}

protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List<String> input)
throws Exception {
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
String modelId,
TaskType taskType,
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
return callAsyncUnified(endpoint, input, "user");
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
}

private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
throws Exception {
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input, null));

return execAsyncCall(request);
return execAsyncCall(request, responseConsumerCallback);
}

private Deque<ServerSentEvent> execAsyncCall(Request request) throws Exception {
private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
var latch = new CountDownLatch(1);
client().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
if (responseConsumerCallback != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a way to get the response so we can check the headers

responseConsumerCallback.accept(response);
}
latch.countDown();
}

Expand All @@ -378,11 +392,16 @@ public void onFailure(Exception exception) {
return responseConsumer.events();
}

private Deque<ServerSentEvent> callAsyncUnified(String endpoint, List<String> input, String role) throws Exception {
private Deque<ServerSentEvent> callAsyncUnified(
String endpoint,
List<String> input,
String role,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var request = new Request("POST", endpoint);

request.setJsonEntity(createUnifiedJsonBody(input, role));
return execAsyncCall(request);
return execAsyncCall(request, responseConsumerCallback);
}

private String createUnifiedJsonBody(List<String> input, String role) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -28,6 +29,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand All @@ -37,9 +39,15 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class InferenceCrudIT extends InferenceBaseRestTest {

private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
r.getHeader("X-elastic-product"),
is("Elasticsearch")
);

@SuppressWarnings("unchecked")
public void testCRUD() throws IOException {
for (int i = 0; i < 5; i++) {
Expand Down Expand Up @@ -442,7 +450,7 @@ public void testUnsupportedStream() throws Exception {
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));

try {
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()));
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null);
assertThat(events.size(), equalTo(2));
events.forEach(event -> {
switch (event.name()) {
Expand All @@ -469,7 +477,7 @@ public void testSupportedStream() throws Exception {

var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
try {
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, even if I remove my change this test still passes. I opted for keeping these tests just in case something really goes wrong in the future but they wouldn't have caught the original issue 😞


var expectedResponses = Stream.concat(
input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
Expand All @@ -496,7 +504,7 @@ public void testUnifiedCompletionInference() throws Exception {

var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
try {
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input);
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
var expectedResponses = expectedResultsIterator(input);
assertThat(events.size(), equalTo((input.size() + 1) * 2));
events.forEach(event -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
Expand All @@ -43,6 +44,7 @@
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
Expand All @@ -52,6 +54,7 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
}

public static class StreamingPlugin extends Plugin implements ActionPlugin {
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();

@Override
public Collection<?> createComponents(PluginServices services) {
threadPool.set(services.threadPool());
return Collections.emptyList();
}

@Override
public Collection<RestHandler> getRestHandlers(
Settings settings,
Expand All @@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
var publisher = new RandomPublisher(requestCount, withError);
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
}, new RestHandler() {
@Override
Expand All @@ -132,7 +143,7 @@ public List<Route> routes() {

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
}
}, new RestHandler() {
@Override
Expand All @@ -143,7 +154,7 @@ public List<Route> routes() {
@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
Expand Down Expand Up @@ -154,6 +155,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<HttpRequestSender.Factory> httpFactory = new SetOnce<>();
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
// not being initialized yet
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
Expand Down Expand Up @@ -197,9 +201,11 @@ public List<RestHandler> getRestHandlers(
Supplier<DiscoveryNodes> nodesInCluster,
Predicate<NodeFeature> clusterSupportsFeature
) {
assert serviceComponents.get() != null : "serviceComponents must be set before retrieving the rest handlers";

var availableRestActions = List.of(
new RestInferenceAction(),
new RestStreamInferenceAction(),
new RestStreamInferenceAction(threadPoolSetOnce),
new RestGetInferenceModelAction(),
new RestPutInferenceModelAction(),
new RestUpdateInferenceModelAction(),
Expand All @@ -208,7 +214,7 @@ public List<RestHandler> getRestHandlers(
new RestGetInferenceServicesAction()
);
List<RestHandler> conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
? List.of(new RestUnifiedCompletionInferenceAction())
? List.of(new RestUnifiedCompletionInferenceAction(threadPoolSetOnce))
: List.of();

return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList();
Expand All @@ -219,6 +225,7 @@ public Collection<?> createComponents(PluginServices services) {
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
threadPoolSetOnce.set(services.threadPool());

var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,30 @@

package org.elasticsearch.xpack.inference.rest;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;

@ServerlessScope(Scope.PUBLIC)
public class RestStreamInferenceAction extends BaseInferenceAction {
private final SetOnce<ThreadPool> threadPool;

public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
super();
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
public String getName() {
return "stream_inference_action";
Expand All @@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques

@Override
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
return new ServerSentEventsRestActionListener(channel);
return new ServerSentEventsRestActionListener(channel, threadPool);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,32 @@

package org.elasticsearch.xpack.inference.rest;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH;

@ServerlessScope(Scope.PUBLIC)
public class RestUnifiedCompletionInferenceAction extends BaseRestHandler {
private final SetOnce<ThreadPool> threadPool;

public RestUnifiedCompletionInferenceAction(SetOnce<ThreadPool> threadPool) {
super();
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
public String getName() {
return "unified_inference_action";
Expand All @@ -44,6 +54,10 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser);
}

return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel));
return channel -> client.execute(
UnifiedCompletionAction.INSTANCE,
request,
new ServerSentEventsRestActionListener(channel, threadPool)
);
}
}
Loading