From b0f2b4f45b57d35da61d4c9a3dbeeb7b58299490 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 18 Dec 2024 13:03:30 -0500 Subject: [PATCH 1/5] Adding context preserving fix --- ...rverSentEventsRestActionListenerTests.java | 17 ++++++++++++++--- .../xpack/inference/InferencePlugin.java | 6 ++++-- .../rest/RestStreamInferenceAction.java | 11 ++++++++++- .../RestUnifiedCompletionInferenceAction.java | 15 ++++++++++++++- .../ServerSentEventsRestActionListener.java | 19 +++++++++++++++---- .../rest/RestStreamInferenceActionTests.java | 3 ++- ...UnifiedCompletionInferenceActionTests.java | 3 ++- 7 files changed, 61 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index ab3f466f3c11f..1e1f6934d95bf 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -17,6 +17,7 @@ import org.apache.http.nio.util.SimpleInputBuffer; import org.apache.http.protocol.HttpContext; import org.apache.http.util.EntityUtils; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; @@ -43,6 +44,7 @@ import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; @@ -52,6 +54,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; @@ -96,6 +99,14 @@ protected Collection> nodePlugins() { } public static class StreamingPlugin extends Plugin implements ActionPlugin { + private final SetOnce threadPool = new SetOnce<>(); + + @Override + public Collection createComponents(PluginServices services) { + threadPool.set(services.threadPool()); + return Collections.emptyList(); + } + @Override public Collection getRestHandlers( Settings settings, @@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c var publisher = new RandomPublisher(requestCount, withError); var inferenceServiceResults = new StreamingInferenceServiceResults(publisher); var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse); } }, new RestHandler() { @Override @@ -132,7 +143,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { - new ServerSentEventsRestActionListener(channel).onFailure(expectedException); + new ServerSentEventsRestActionListener(channel, threadPool.get()).onFailure(expectedException); } }, new RestHandler() { @Override @@ -143,7 +154,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse); } }); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 72fa840ad19b0..0d253e9385119 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -197,9 +197,11 @@ public List getRestHandlers( Supplier nodesInCluster, Predicate clusterSupportsFeature ) { + assert serviceComponents.get() != null : "serviceComponents must be set before retrieving the rest handlers"; + var availableRestActions = List.of( new RestInferenceAction(), - new RestStreamInferenceAction(), + new RestStreamInferenceAction(serviceComponents.get().threadPool()), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestUpdateInferenceModelAction(), @@ -208,7 +210,7 @@ public List getRestHandlers( new RestGetInferenceServicesAction() ); List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() - ? List.of(new RestUnifiedCompletionInferenceAction()) + ? List.of(new RestUnifiedCompletionInferenceAction(serviceComponents.get().threadPool())) : List.of(); return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 875c288da52bd..04061c3595761 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -11,9 +11,11 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.List; +import java.util.Objects; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH; @@ -21,6 +23,13 @@ @ServerlessScope(Scope.PUBLIC) public class RestStreamInferenceAction extends BaseInferenceAction { + private final ThreadPool threadPool; + + public RestStreamInferenceAction(ThreadPool threadPool) { + super(); + this.threadPool = Objects.requireNonNull(threadPool); + } + @Override public String getName() { return "stream_inference_action"; @@ -38,6 +47,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques @Override protected ActionListener listener(RestChannel channel) { - return new ServerSentEventsRestActionListener(channel); + return new ServerSentEventsRestActionListener(channel, threadPool); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java index 5c71b560a6b9d..194ee2e31e461 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -12,10 +12,12 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import java.io.IOException; import java.util.List; +import java.util.Objects; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; @@ -23,6 +25,13 @@ @ServerlessScope(Scope.PUBLIC) public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { + private final ThreadPool threadPool; + + public RestUnifiedCompletionInferenceAction(ThreadPool threadPool) { + super(); + this.threadPool = Objects.requireNonNull(threadPool); + } + @Override public String getName() { return "unified_inference_action"; @@ -44,6 +53,10 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); } - return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + return channel -> client.execute( + UnifiedCompletionAction.INSTANCE, + request, + new ServerSentEventsRestActionListener(channel, threadPool) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index bf94f072b6e04..784ea0e7a4b1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -13,6 +13,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.BytesStream; @@ -29,6 +30,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -38,6 +40,7 @@ import java.nio.charset.StandardCharsets; import java.util.Iterator; import java.util.Map; +import java.util.Objects; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; @@ -55,6 +58,7 @@ public class ServerSentEventsRestActionListener implements ActionListener nextBodyPartListener; - public ServerSentEventsRestActionListener(RestChannel channel) { - this(channel, channel.request()); + public ServerSentEventsRestActionListener(RestChannel channel, ThreadPool threadPool) { + this(channel, channel.request(), threadPool); } - public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) { + public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, ThreadPool threadPool) { this.channel = channel; this.params = params; + this.threadPool = Objects.requireNonNull(threadPool); } @Override @@ -99,7 +104,7 @@ protected void ensureOpen() { } private void initializeStream(InferenceAction.Response response) { - nextBodyPartListener = ActionListener.wrap(bodyPart -> { + ActionListener chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> { // this is the first response, so we need to send the RestResponse to open the stream // all subsequent bytes will be delivered through the nextBodyPartListener channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release)); @@ -115,6 +120,12 @@ private void initializeStream(InferenceAction.Response response) { ) ); }); + + nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext( + chunkedResponseBodyActionListener, + threadPool.getThreadContext() + ); + // subscribe will call onSubscribe, which requests the first chunk response.publisher().subscribe(subscriber); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index b999e2c9b72f0..26219f509be4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.junit.Before; @@ -25,7 +26,7 @@ public class RestStreamInferenceActionTests extends RestActionTestCase { @Before public void setUpAction() { - controller().registerHandler(new RestStreamInferenceAction()); + controller().registerHandler(new RestStreamInferenceAction(new TestThreadPool(getTestName()))); } public void testStreamIsTrue() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java index 5acfe67b175df..4c40129e856ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.junit.Before; @@ -30,7 +31,7 @@ public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCas @Before public void setUpAction() { - controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + controller().registerHandler(new RestUnifiedCompletionInferenceAction(new TestThreadPool(getTestName()))); } public void testStreamIsTrue() { From dc42a3d8c1f6aede6afd57d0d4a4b7c3b17a0b0e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:08:16 -0500 Subject: [PATCH 2/5] Update docs/changelog/118999.yaml --- docs/changelog/118999.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/118999.yaml diff --git a/docs/changelog/118999.yaml b/docs/changelog/118999.yaml new file mode 100644 index 0000000000000..2af79a95a7ee6 --- /dev/null +++ b/docs/changelog/118999.yaml @@ -0,0 +1,5 @@ +pr: 118999 +summary: Fix lose of context in the inference API for streaming APIs +area: Machine Learning +type: bug +issues: [] From 6623daae2948ce8a8d1ab5bf53bc4ec6df444b66 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:12:43 -0500 Subject: [PATCH 3/5] Update docs/changelog/118999.yaml --- docs/changelog/118999.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/changelog/118999.yaml b/docs/changelog/118999.yaml index 2af79a95a7ee6..01a0493c1ea5c 100644 --- a/docs/changelog/118999.yaml +++ b/docs/changelog/118999.yaml @@ -2,4 +2,5 @@ pr: 118999 summary: Fix lose of context in the inference API for streaming APIs area: Machine Learning type: bug -issues: [] +issues: + - 119000 From 6a986d87886dd8c4906605edcf5b880da4868ba5 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 19 Dec 2024 11:20:58 -0500 Subject: [PATCH 4/5] Using a setonce and adding a test --- .../inference/InferenceBaseRestTest.java | 39 ++++++++++++++----- .../xpack/inference/InferenceCrudIT.java | 14 +++++-- ...rverSentEventsRestActionListenerTests.java | 6 +-- .../xpack/inference/InferencePlugin.java | 9 ++++- .../rest/RestStreamInferenceAction.java | 5 ++- .../RestUnifiedCompletionInferenceAction.java | 5 ++- .../ServerSentEventsRestActionListener.java | 9 +++-- .../rest/RestStreamInferenceActionTests.java | 12 +++++- ...UnifiedCompletionInferenceActionTests.java | 11 +++++- 9 files changed, 82 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 5e6c4d53f4c58..cdc6d9b2dff5f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -341,31 +342,44 @@ protected Map infer(String modelId, List input) throws I return inferInternal(endpoint, input, null, Map.of()); } - protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { + protected Deque streamInferOnMockService( + String modelId, + TaskType taskType, + List input, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); - return callAsync(endpoint, input); + return callAsync(endpoint, input, responseConsumerCallback); } - protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) - throws Exception { + protected Deque unifiedCompletionInferOnMockService( + String modelId, + TaskType taskType, + List input, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); - return callAsyncUnified(endpoint, input, "user"); + return callAsyncUnified(endpoint, input, "user", responseConsumerCallback); } - private Deque callAsync(String endpoint, List input) throws Exception { + private Deque callAsync(String endpoint, List input, @Nullable Consumer responseConsumerCallback) + throws Exception { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input, null)); - return execAsyncCall(request); + return execAsyncCall(request, responseConsumerCallback); } - private Deque execAsyncCall(Request request) throws Exception { + private Deque execAsyncCall(Request request, @Nullable Consumer responseConsumerCallback) throws Exception { var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @Override public void onSuccess(Response response) { + if (responseConsumerCallback != null) { + responseConsumerCallback.accept(response); + } latch.countDown(); } @@ -378,11 +392,16 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } - private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + private Deque callAsyncUnified( + String endpoint, + List input, + String role, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var request = new Request("POST", endpoint); request.setJsonEntity(createUnifiedJsonBody(input, role)); - return execAsyncCall(request); + return execAsyncCall(request, responseConsumerCallback); } private String createUnifiedJsonBody(List input, String role) throws IOException { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index fc593a6a8b0fa..49fce930cd726 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; @@ -28,6 +29,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -37,9 +39,15 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; public class InferenceCrudIT extends InferenceBaseRestTest { + private static final Consumer VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat( + r.getHeader("X-elastic-product"), + is("Elasticsearch") + ); + @SuppressWarnings("unchecked") public void testCRUD() throws IOException { for (int i = 0; i < 5; i++) { @@ -442,7 +450,7 @@ public void testUnsupportedStream() throws Exception { assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); try { - var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID())); + var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null); assertThat(events.size(), equalTo(2)); events.forEach(event -> { switch (event.name()) { @@ -469,7 +477,7 @@ public void testSupportedStream() throws Exception { var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList(); try { - var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input); + var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); var expectedResponses = Stream.concat( input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), @@ -496,7 +504,7 @@ public void testUnifiedCompletionInference() throws Exception { var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList(); try { - var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); var expectedResponses = expectedResultsIterator(input); assertThat(events.size(), equalTo((input.size() + 1) * 2)); events.forEach(event -> { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index 1e1f6934d95bf..b993cf36cb875 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -133,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c var publisher = new RandomPublisher(requestCount, withError); var inferenceServiceResults = new StreamingInferenceServiceResults(publisher); var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher()); - new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }, new RestHandler() { @Override @@ -143,7 +143,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { - new ServerSentEventsRestActionListener(channel, threadPool.get()).onFailure(expectedException); + new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException); } }, new RestHandler() { @Override @@ -154,7 +154,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults()); - new ServerSentEventsRestActionListener(channel, threadPool.get()).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 0d253e9385119..f98a7ebdee34d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -43,6 +43,7 @@ import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; @@ -154,6 +155,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce httpFactory = new SetOnce<>(); private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); + // This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it + // not being initialized yet + private final SetOnce threadPoolSetOnce = new SetOnce<>(); private final SetOnce elasticInferenceServiceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); @@ -201,7 +205,7 @@ public List getRestHandlers( var availableRestActions = List.of( new RestInferenceAction(), - new RestStreamInferenceAction(serviceComponents.get().threadPool()), + new RestStreamInferenceAction(threadPoolSetOnce), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestUpdateInferenceModelAction(), @@ -210,7 +214,7 @@ public List getRestHandlers( new RestGetInferenceServicesAction() ); List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() - ? List.of(new RestUnifiedCompletionInferenceAction(serviceComponents.get().threadPool())) + ? List.of(new RestUnifiedCompletionInferenceAction(threadPoolSetOnce)) : List.of(); return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); @@ -221,6 +225,7 @@ public Collection createComponents(PluginServices services) { var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); var truncator = new Truncator(settings, services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); + threadPoolSetOnce.set(services.threadPool()); var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager); var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 04061c3595761..881af435b29b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.rest; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; @@ -23,9 +24,9 @@ @ServerlessScope(Scope.PUBLIC) public class RestStreamInferenceAction extends BaseInferenceAction { - private final ThreadPool threadPool; + private final SetOnce threadPool; - public RestStreamInferenceAction(ThreadPool threadPool) { + public RestStreamInferenceAction(SetOnce threadPool) { super(); this.threadPool = Objects.requireNonNull(threadPool); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java index 194ee2e31e461..51f1bc48c8306 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.rest; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; @@ -25,9 +26,9 @@ @ServerlessScope(Scope.PUBLIC) public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { - private final ThreadPool threadPool; + private final SetOnce threadPool; - public RestUnifiedCompletionInferenceAction(ThreadPool threadPool) { + public RestUnifiedCompletionInferenceAction(SetOnce threadPool) { super(); this.threadPool = Objects.requireNonNull(threadPool); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index 784ea0e7a4b1b..042c8b8a8346d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; @@ -58,7 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener threadPool; /** * A listener for the first part of the next entry to become available for transmission. @@ -70,11 +71,11 @@ public class ServerSentEventsRestActionListener implements ActionListener nextBodyPartListener; - public ServerSentEventsRestActionListener(RestChannel channel, ThreadPool threadPool) { + public ServerSentEventsRestActionListener(RestChannel channel, SetOnce threadPool) { this(channel, channel.request(), threadPool); } - public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, ThreadPool threadPool) { + public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce threadPool) { this.channel = channel; this.params = params; this.threadPool = Objects.requireNonNull(threadPool); @@ -123,7 +124,7 @@ private void initializeStream(InferenceAction.Response response) { nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext( chunkedResponseBodyActionListener, - threadPool.getThreadContext() + threadPool.get().getThreadContext() ); // subscribe will call onSubscribe, which requests the first chunk diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index 26219f509be4a..f67680ef6b625 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -13,8 +13,10 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.After; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -23,10 +25,18 @@ import static org.hamcrest.Matchers.instanceOf; public class RestStreamInferenceActionTests extends RestActionTestCase { + private final SetOnce threadPool = new SetOnce<>(); @Before public void setUpAction() { - controller().registerHandler(new RestStreamInferenceAction(new TestThreadPool(getTestName()))); + threadPool.set(new TestThreadPool(getTestName())); + controller().registerHandler(new RestStreamInferenceAction(threadPool)); + } + + @After + public void tearDownAction() { + terminate(threadPool.get()); + } public void testStreamIsTrue() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java index 4c40129e856ba..9dc23c890c14d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -18,8 +18,10 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.After; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -28,10 +30,17 @@ import static org.hamcrest.Matchers.instanceOf; public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + private final SetOnce threadPool = new SetOnce<>(); @Before public void setUpAction() { - controller().registerHandler(new RestUnifiedCompletionInferenceAction(new TestThreadPool(getTestName()))); + threadPool.set(new TestThreadPool(getTestName())); + controller().registerHandler(new RestUnifiedCompletionInferenceAction(threadPool)); + } + + @After + public void tearDownAction() { + terminate(threadPool.get()); } public void testStreamIsTrue() { From 118397b839e2a817dea76abc9b9c839d8f629be7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 19 Dec 2024 11:23:47 -0500 Subject: [PATCH 5/5] Updating the changelog --- docs/changelog/118999.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/118999.yaml b/docs/changelog/118999.yaml index 01a0493c1ea5c..0188cebbd7685 100644 --- a/docs/changelog/118999.yaml +++ b/docs/changelog/118999.yaml @@ -1,5 +1,5 @@ pr: 118999 -summary: Fix lose of context in the inference API for streaming APIs +summary: Fix loss of context in the inference API for streaming APIs area: Machine Learning type: bug issues: