From 71f1fabe149fd0777edf44502ace4a8f0911feeb Mon Sep 17 00:00:00 2001 From: Gagan Juneja Date: Tue, 9 Jan 2024 22:54:02 +0530 Subject: [PATCH] Add support for conditional Transient header propagation (#11490) * Clear transient header from system context Signed-off-by: Gagan Juneja * Clear transient header from system context Signed-off-by: Gagan Juneja * Adds changelog Signed-off-by: Gagan Juneja * Update CHANGELOG.md Co-authored-by: Andriy Redko Signed-off-by: Gagan Juneja * Adds unit tests Signed-off-by: Gagan Juneja * Refactor code Signed-off-by: Gagan Juneja * Refactor code Signed-off-by: Gagan Juneja * Refactor code Signed-off-by: Gagan Juneja * Supress warning Signed-off-by: Gagan Juneja * Refactor code Signed-off-by: Gagan Juneja --------- Signed-off-by: Gagan Juneja Signed-off-by: Gagan Juneja Co-authored-by: Gagan Juneja Co-authored-by: Andriy Redko --- CHANGELOG.md | 1 + .../common/util/concurrent/ThreadContext.java | 24 ++++--- .../ThreadContextStatePropagator.java | 30 ++++++++- .../TaskThreadContextStatePropagator.java | 13 ++++ ...hreadContextBasedTracerContextStorage.java | 19 +++++- .../transport/TransportService.java | 17 ++--- .../util/concurrent/ThreadContextTests.java | 67 +++++++++++++++++++ ...TaskThreadContextStatePropagatorTests.java | 34 ++++++++++ ...ContextBasedTracerContextStorageTests.java | 16 +++++ 9 files changed, 193 insertions(+), 28 deletions(-) create mode 100644 server/src/test/java/org/opensearch/tasks/TaskThreadContextStatePropagatorTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 45ab53b0d4246..6da048c65d8d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -214,6 +214,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417)) - Fix Automatic addition of protocol broken in #11512 ([#11609](https://github.com/opensearch-project/OpenSearch/pull/11609)) - Fix issue when calling Delete PIT endpoint and no PITs exist ([#11711](https://github.com/opensearch-project/OpenSearch/pull/11711)) +- Fix tracing context propagation for local transport instrumentation ([#11490](https://github.com/opensearch-project/OpenSearch/pull/11490)) ### Security diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 3da21a6777456..6580b0e0085ef 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -161,7 +161,7 @@ public StoredContext stashContext() { ); } - final Map transientHeaders = propagateTransients(context.transientHeaders); + final Map transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext); if (!transientHeaders.isEmpty()) { threadContextStruct = threadContextStruct.putTransient(transientHeaders); } @@ -182,7 +182,7 @@ public StoredContext stashContext() { public Writeable captureAsWriteable() { final ThreadContextStruct context = threadLocal.get(); return out -> { - final Map propagatedHeaders = propagateHeaders(context.transientHeaders); + final Map propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext); context.writeTo(out, defaultHeader, propagatedHeaders); }; } @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); boolean transientHeadersModified = false; - final Map transientHeaders = propagateTransients(originalContext.transientHeaders); + final Map transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext); if (!transientHeaders.isEmpty()) { newTransientHeaders.putAll(transientHeaders); transientHeadersModified = true; @@ -322,7 +322,7 @@ public Supplier wrapRestorable(StoredContext storedContext) { @Override public void writeTo(StreamOutput out) throws IOException { final ThreadContextStruct context = threadLocal.get(); - final Map propagatedHeaders = propagateHeaders(context.transientHeaders); + final Map propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext); context.writeTo(out, defaultHeader, propagatedHeaders); } @@ -534,7 +534,7 @@ boolean isDefaultContext() { * by the system itself rather than by a user action. */ public void markAsSystemContext() { - threadLocal.set(threadLocal.get().setSystemContext()); + threadLocal.set(threadLocal.get().setSystemContext(propagators)); } /** @@ -573,15 +573,15 @@ public static Map buildDefaultHeaders(Settings settings) { } } - private Map propagateTransients(Map source) { + private Map propagateTransients(Map source, boolean isSystemContext) { final Map transients = new HashMap<>(); - propagators.forEach(p -> transients.putAll(p.transients(source))); + propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext))); return transients; } - private Map propagateHeaders(Map source) { + private Map propagateHeaders(Map source, boolean isSystemContext) { final Map headers = new HashMap<>(); - propagators.forEach(p -> headers.putAll(p.headers(source))); + propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext))); return headers; } @@ -603,11 +603,13 @@ private static final class ThreadContextStruct { // saving current warning headers' size not to recalculate the size with every new warning header private final long warningHeadersSize; - private ThreadContextStruct setSystemContext() { + private ThreadContextStruct setSystemContext(final List propagators) { if (isSystemContext) { return this; } - return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true); + final Map transients = new HashMap<>(); + propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true))); + return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true); } private ThreadContextStruct( diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java index dac70b0e8124e..e8c12ae13d5eb 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java @@ -22,15 +22,41 @@ public interface ThreadContextStatePropagator { /** * Returns the list of transient headers that needs to be propagated from current context to new thread context. - * @param source current context transient headers + * + * @param source current context transient headers * @return the list of transient headers that needs to be propagated from current context to new thread context */ + @Deprecated(since = "2.12.0", forRemoval = true) Map transients(Map source); + /** + * Returns the list of transient headers that needs to be propagated from current context to new thread context. + * + * @param source current context transient headers + * @param isSystemContext if the propagation is for system context. + * @return the list of transient headers that needs to be propagated from current context to new thread context + */ + default Map transients(Map source, boolean isSystemContext) { + return transients(source); + }; + /** * Returns the list of request headers that needs to be propagated from current context to request. - * @param source current context headers + * + * @param source current context headers * @return the list of request headers that needs to be propagated from current context to request */ + @Deprecated(since = "2.12.0", forRemoval = true) Map headers(Map source); + + /** + * Returns the list of request headers that needs to be propagated from current context to request. + * + * @param source current context headers + * @param isSystemContext if the propagation is for system context. + * @return the list of request headers that needs to be propagated from current context to request + */ + default Map headers(Map source, boolean isSystemContext) { + return headers(source); + } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java index ed111b34f048f..99559e45aaaee 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java @@ -20,7 +20,9 @@ * Propagates TASK_ID across thread contexts */ public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator { + @Override + @SuppressWarnings("removal") public Map transients(Map source) { final Map transients = new HashMap<>(); @@ -32,7 +34,18 @@ public Map transients(Map source) { } @Override + public Map transients(Map source, boolean isSystemContext) { + return transients(source); + } + + @Override + @SuppressWarnings("removal") public Map headers(Map source) { return Collections.emptyMap(); } + + @Override + public Map headers(Map source, boolean isSystemContext) { + return headers(source); + } } diff --git a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java index 863f56d9fbe94..908164d1935a7 100644 --- a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java +++ b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java @@ -12,6 +12,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.concurrent.ThreadContextStatePropagator; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -50,20 +51,29 @@ public void put(String key, Span span) { } @Override + @SuppressWarnings("removal") public Map transients(Map source) { final Map transients = new HashMap<>(); - if (source.containsKey(CURRENT_SPAN)) { final SpanReference current = (SpanReference) source.get(CURRENT_SPAN); if (current != null) { transients.put(CURRENT_SPAN, new SpanReference(current.getSpan())); } } - return transients; } @Override + public Map transients(Map source, boolean isSystemContext) { + if (isSystemContext == true) { + return Collections.emptyMap(); + } else { + return transients(source); + } + } + + @Override + @SuppressWarnings("removal") public Map headers(Map source) { final Map headers = new HashMap<>(); @@ -77,6 +87,11 @@ public Map headers(Map source) { return headers; } + @Override + public Map headers(Map source, boolean isSystemContext) { + return headers(source); + } + Span getCurrentSpan(String key) { SpanReference currentSpanRef = threadContext.getTransient(key); return (currentSpanRef == null) ? null : currentSpanRef.getSpan(); diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index a1697b1898eeb..d50266d8c9e4a 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -868,19 +868,10 @@ public final void sendRequest( final TransportRequestOptions options, final TransportResponseHandler handler ) { - if (connection == localNodeConnection) { - // See please https://github.com/opensearch-project/OpenSearch/issues/10291 - sendRequestAsync(connection, action, request, options, handler); - } else { - final Span span = tracer.startSpan(SpanBuilder.from(action, connection)); - try (SpanScope spanScope = tracer.withSpanInScope(span)) { - TransportResponseHandler traceableTransportResponseHandler = TraceableTransportResponseHandler.create( - handler, - span, - tracer - ); - sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler); - } + final Span span = tracer.startSpan(SpanBuilder.from(action, connection)); + try (SpanScope spanScope = tracer.withSpanInScope(span)) { + TransportResponseHandler traceableTransportResponseHandler = TraceableTransportResponseHandler.create(handler, span, tracer); + sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler); } } diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index a0531c76bf897..10669ca1a805b 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -44,6 +44,8 @@ import java.util.Map; import java.util.function.Supplier; +import org.mockito.Mockito; + import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -740,6 +742,71 @@ public void testMarkAsSystemContext() throws IOException { assertFalse(threadContext.isSystemContext()); } + public void testSystemContextWithPropagator() { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + Map transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map headerMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test"); + ThreadContext threadContext = new ThreadContext(build); + ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class); + Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap()); + Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap); + + Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap); + Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap); + threadContext.registerThreadContextStatePropagator(mockPropagator); + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("test_transient_propagation_key", 1); + assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key")); + assertEquals("bar", threadContext.getHeader("foo")); + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + threadContext.markAsSystemContext(); + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("test_transient_propagation_key")); + assertEquals("1", threadContext.getHeader("default")); + } + + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key")); + assertEquals("1", threadContext.getHeader("default")); + } + + public void testSerializeSystemContext() throws IOException { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + Map transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map headerMap = Collections.singletonMap("test_transient_propagation_key", "test"); + Map headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test"); + ThreadContext threadContext = new ThreadContext(build); + ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class); + Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap()); + Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap); + + Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap); + Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap); + threadContext.registerThreadContextStatePropagator(mockPropagator); + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("test_transient_propagation_key", "test"); + BytesStreamOutput out = new BytesStreamOutput(); + BytesStreamOutput outFromSystemContext = new BytesStreamOutput(); + threadContext.writeTo(out); + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertEquals("test", threadContext.getTransient("test_transient_propagation_key")); + threadContext.markAsSystemContext(); + threadContext.writeTo(outFromSystemContext); + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("test_transient_propagation_key")); + threadContext.readHeaders(outFromSystemContext.bytes().streamInput()); + assertNull(threadContext.getHeader("test_transient_propagation_key")); + } + assertEquals("test", threadContext.getTransient("test_transient_propagation_key")); + threadContext.readHeaders(out.bytes().streamInput()); + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals("test", threadContext.getHeader("test_transient_propagation_key")); + assertEquals("1", threadContext.getHeader("default")); + } + public void testPutHeaders() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); diff --git a/server/src/test/java/org/opensearch/tasks/TaskThreadContextStatePropagatorTests.java b/server/src/test/java/org/opensearch/tasks/TaskThreadContextStatePropagatorTests.java new file mode 100644 index 0000000000000..bfa0d566aabd7 --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskThreadContextStatePropagatorTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase { + private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator(); + + public void testTransient() { + Map transientHeader = new HashMap<>(); + transientHeader.put(TASK_ID, "t_1"); + Map transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false); + assertEquals("t_1", transientPropagatedHeader.get(TASK_ID)); + } + + public void testTransientForSystemContext() { + Map transientHeader = new HashMap<>(); + transientHeader.put(TASK_ID, "t_1"); + Map transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true); + assertEquals("t_1", transientPropagatedHeader.get(TASK_ID)); + } +} diff --git a/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java b/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java index ee816aa5f596d..bf11bcaf39a96 100644 --- a/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java +++ b/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java @@ -252,4 +252,20 @@ public void run() { assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue()))); assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue())); } + + public void testSpanNotPropagatedToChildSystemThreadContext() { + final Span span = tracer.startSpan(SpanCreationContext.internal().name("test")); + + try (SpanScope scope = tracer.withSpanInScope(span)) { + try (StoredContext ignored = threadContext.stashContext()) { + assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue()))); + assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span)); + threadContext.markAsSystemContext(); + assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue())); + } + } + + assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue()))); + assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue())); + } }