diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index ff871c96e1ad9..6580b0e0085ef 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -161,7 +161,7 @@ public StoredContext stashContext() { ); } - final Map transientHeaders = propagateTransients(context.transientHeaders); + final Map transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext); if (!transientHeaders.isEmpty()) { threadContextStruct = threadContextStruct.putTransient(transientHeaders); } @@ -182,7 +182,7 @@ public StoredContext stashContext() { public Writeable captureAsWriteable() { final ThreadContextStruct context = threadLocal.get(); return out -> { - final Map propagatedHeaders = propagateHeaders(context.transientHeaders); + final Map propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext); context.writeTo(out, defaultHeader, propagatedHeaders); }; } @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); boolean transientHeadersModified = false; - final Map transientHeaders = propagateTransients(originalContext.transientHeaders); + final Map transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext); if (!transientHeaders.isEmpty()) { newTransientHeaders.putAll(transientHeaders); transientHeadersModified = true; @@ -322,7 +322,7 @@ public Supplier wrapRestorable(StoredContext storedContext) { @Override public void writeTo(StreamOutput out) throws IOException { final ThreadContextStruct context = threadLocal.get(); - final Map propagatedHeaders = propagateHeaders(context.transientHeaders); + final Map propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext); context.writeTo(out, defaultHeader, propagatedHeaders); } @@ -534,17 +534,7 @@ boolean isDefaultContext() { * by the system itself rather than by a user action. */ public void markAsSystemContext() { - ThreadContextStruct threadContextStruct = threadLocal.get(); - final Map transients = new HashMap<>(); - propagators.forEach(p -> transients.putAll(p.transientsForSystemContext(threadContextStruct.transientHeaders))); - ThreadContextStruct newThreadContextStruct = new ThreadContextStruct( - threadContextStruct.requestHeaders, - threadContextStruct.responseHeaders, - transients, - threadContextStruct.persistentHeaders, - threadContextStruct.isSystemContext - ); - threadLocal.set(newThreadContextStruct.setSystemContext()); + threadLocal.set(threadLocal.get().setSystemContext(propagators)); } /** @@ -583,15 +573,15 @@ public static Map buildDefaultHeaders(Settings settings) { } } - private Map propagateTransients(Map source) { + private Map propagateTransients(Map source, boolean isSystemContext) { final Map transients = new HashMap<>(); - propagators.forEach(p -> transients.putAll(p.transients(source))); + propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext))); return transients; } - private Map propagateHeaders(Map source) { + private Map propagateHeaders(Map source, boolean isSystemContext) { final Map headers = new HashMap<>(); - propagators.forEach(p -> headers.putAll(p.headers(source))); + propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext))); return headers; } @@ -613,11 +603,13 @@ private static final class ThreadContextStruct { // saving current warning headers' size not to recalculate the size with every new warning header private final long warningHeadersSize; - private ThreadContextStruct setSystemContext() { + private ThreadContextStruct setSystemContext(final List propagators) { if (isSystemContext) { return this; } - return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true); + final Map transients = new HashMap<>(); + propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true))); + return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true); } private ThreadContextStruct( diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java index 73346cdca256b..d6b65b4c0bdf9 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java @@ -10,7 +10,6 @@ import org.opensearch.common.annotation.PublicApi; -import java.util.HashMap; import java.util.Map; /** @@ -23,24 +22,19 @@ public interface ThreadContextStatePropagator { /** * Returns the list of transient headers that needs to be propagated from current context to new thread context. - * @param source current context transient headers + * + * @param source current context transient headers + * @param isSystemContext if the propagation is for system context. * @return the list of transient headers that needs to be propagated from current context to new thread context */ - Map transients(Map source); - - /** - * Returns the list of transient headers that need to be propagated to the child system context. - * @param source current context transient headers - * @return the list of transient headers that needs to be propagated from current context to new thread context - */ - default Map transientsForSystemContext(Map source) { - return new HashMap<>(); - } + Map transients(Map source, boolean isSystemContext); /** * Returns the list of request headers that needs to be propagated from current context to request. - * @param source current context headers + * + * @param source current context headers + * @param isSystemContext if the propagation is for system context. * @return the list of request headers that needs to be propagated from current context to request */ - Map headers(Map source); + Map headers(Map source, boolean isSystemContext); } diff --git a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java index 73b9123df3cc3..b7ed5e620b73b 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java @@ -21,7 +21,7 @@ */ public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator { @Override - public Map transients(Map source) { + public Map transients(Map source, boolean isSystemContext) { final Map transients = new HashMap<>(); if (source.containsKey(TASK_ID)) { @@ -32,12 +32,7 @@ public Map transients(Map source) { } @Override - public Map transientsForSystemContext(Map source) { - return transients(source); - } - - @Override - public Map headers(Map source) { + public Map headers(Map source, boolean isSystemContext) { return Collections.emptyMap(); } } diff --git a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java index 863f56d9fbe94..756d364cfb80d 100644 --- a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java +++ b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java @@ -50,21 +50,19 @@ public void put(String key, Span span) { } @Override - public Map transients(Map source) { + public Map transients(Map source, boolean isSystemContext) { final Map transients = new HashMap<>(); - - if (source.containsKey(CURRENT_SPAN)) { + if (isSystemContext == false && source.containsKey(CURRENT_SPAN)) { final SpanReference current = (SpanReference) source.get(CURRENT_SPAN); if (current != null) { transients.put(CURRENT_SPAN, new SpanReference(current.getSpan())); } } - return transients; } @Override - public Map headers(Map source) { + public Map headers(Map source, boolean isSystemContext) { final Map headers = new HashMap<>(); if (source.containsKey(CURRENT_SPAN)) {