From ff81b70f886b160997da8cb5c5c26fd06e8a127f Mon Sep 17 00:00:00 2001 From: Gagan Juneja Date: Wed, 6 Dec 2023 18:55:36 +0530 Subject: [PATCH] Add support for conditional Transient header propagation Signed-off-by: Gagan Juneja --- .../common/util/concurrent/ThreadContext.java | 31 ++++++++++++++++--- .../ThreadContextStatePropagator.java | 8 +++-- .../TaskThreadContextStatePropagator.java | 2 +- ...hreadContextBasedTracerContextStorage.java | 15 +++++---- .../org/opensearch/threadpool/ThreadPool.java | 10 +++--- 5 files changed, 48 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 3da21a6777456..2e5ac54291e2a 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -161,7 +161,7 @@ public StoredContext stashContext() { ); } - final Map transientHeaders = propagateTransients(context.transientHeaders); + final Map transientHeaders = propagateTransients(context.transientHeaders, false); if (!transientHeaders.isEmpty()) { threadContextStruct = threadContextStruct.putTransient(transientHeaders); } @@ -230,7 +230,11 @@ public StoredContext stashAndMergeHeaders(Map headers) { * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. */ public StoredContext newStoredContext(boolean preserveResponseHeaders) { - return newStoredContext(preserveResponseHeaders, Collections.emptyList()); + return newStoredContext(preserveResponseHeaders, Collections.emptyList(), false); + } + + public StoredContext newStoredContext(boolean preserveResponseHeaders, boolean forIndependentTask) { + return newStoredContext(preserveResponseHeaders, Collections.emptyList(), forIndependentTask); } /** @@ -241,11 +245,28 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders) { * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. */ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collection transientHeadersToClear) { + return newStoredContext(preserveResponseHeaders, transientHeadersToClear, false); + } + + /** + * Just like {@link #stashContext()} but no default context is set. Instead, the {@code transientHeadersToClear} argument can be used + * to clear specific transient headers in the new context. All headers (with the possible exception of {@code responseHeaders}) are + * restored by closing the returned {@link StoredContext}. + * + * It passes the forIndependentTask parameter to the transient header propagators to decide whether the header needs to + * be propagated or not. + * @param preserveResponseHeaders if set to true the response headers of the restore thread will be preserved. + */ + public StoredContext newStoredContext( + boolean preserveResponseHeaders, + Collection transientHeadersToClear, + boolean forIndependentTask + ) { final ThreadContextStruct originalContext = threadLocal.get(); final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); boolean transientHeadersModified = false; - final Map transientHeaders = propagateTransients(originalContext.transientHeaders); + final Map transientHeaders = propagateTransients(originalContext.transientHeaders, forIndependentTask); if (!transientHeaders.isEmpty()) { newTransientHeaders.putAll(transientHeaders); transientHeadersModified = true; @@ -573,9 +594,9 @@ public static Map buildDefaultHeaders(Settings settings) { } } - private Map propagateTransients(Map source) { + private Map propagateTransients(Map source, boolean forIndependentTask) { final Map transients = new HashMap<>(); - propagators.forEach(p -> transients.putAll(p.transients(source))); + propagators.forEach(p -> transients.putAll(p.transients(source, forIndependentTask))); return transients; } diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java index dac70b0e8124e..3385fe862fcbf 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java @@ -22,10 +22,14 @@ public interface ThreadContextStatePropagator { /** * Returns the list of transient headers that needs to be propagated from current context to new thread context. - * @param source current context transient headers + * + * @param source current context transient headers + * @param forIndependentTask Helps in deciding whether transient header needs to be propagated or not for the + * scenarios where the new independent/background/scheduled task is being spawned from the + * current thread's context. * @return the list of transient headers that needs to be propagated from current context to new thread context */ - Map transients(Map source); + Map transients(Map source, boolean forIndependentTask); /** * Returns the list of request headers that needs to be propagated from current context to request. diff --git a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java index ed111b34f048f..d0d0b0dd9f801 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java +++ b/server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java @@ -21,7 +21,7 @@ */ public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator { @Override - public Map transients(Map source) { + public Map transients(Map source, boolean forIndependentTask) { final Map transients = new HashMap<>(); if (source.containsKey(TASK_ID)) { diff --git a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java index 863f56d9fbe94..78fa6238357f9 100644 --- a/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java +++ b/server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java @@ -50,14 +50,17 @@ public void put(String key, Span span) { } @Override - public Map transients(Map source) { + public Map transients(Map source, boolean forIndependentTask) { final Map transients = new HashMap<>(); - - if (source.containsKey(CURRENT_SPAN)) { - final SpanReference current = (SpanReference) source.get(CURRENT_SPAN); - if (current != null) { - transients.put(CURRENT_SPAN, new SpanReference(current.getSpan())); + if (forIndependentTask == false) { + if (source.containsKey(CURRENT_SPAN)) { + final SpanReference current = (SpanReference) source.get(CURRENT_SPAN); + if (current != null) { + transients.put(CURRENT_SPAN, new SpanReference(current.getSpan())); + } } + } else { + transients.put(CURRENT_SPAN, null); } return transients; diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index c825ecc8abe9f..0ab3c8f518396 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -460,11 +460,13 @@ public ExecutorService executor(String name) { */ @Override public ScheduledCancellable schedule(Runnable command, TimeValue delay, String executor) { - command = threadContext.preserveContext(command); - if (!Names.SAME.equals(executor)) { - command = new ThreadedRunnable(command, executor(executor)); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(false, true)) { + command = threadContext.preserveContext(command); + if (!Names.SAME.equals(executor)) { + command = new ThreadedRunnable(command, executor(executor)); + } + return new ScheduledCancellableAdapter(scheduler.schedule(command, delay.millis(), TimeUnit.MILLISECONDS)); } - return new ScheduledCancellableAdapter(scheduler.schedule(command, delay.millis(), TimeUnit.MILLISECONDS)); } public void scheduleUnlessShuttingDown(TimeValue delay, String executor, Runnable command) {