Skip to content

Commit

Permalink
Clear transient header from system context
Browse files Browse the repository at this point in the history
Signed-off-by: Gagan Juneja <[email protected]>
  • Loading branch information
Gagan Juneja committed Jan 4, 2024
1 parent db6395a commit 495fad3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public StoredContext stashContext() {
);
}

final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext);
if (!transientHeaders.isEmpty()) {
threadContextStruct = threadContextStruct.putTransient(transientHeaders);
}
Expand All @@ -182,7 +182,7 @@ public StoredContext stashContext() {
public Writeable captureAsWriteable() {
final ThreadContextStruct context = threadLocal.get();
return out -> {
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
};
}
Expand Down Expand Up @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);

boolean transientHeadersModified = false;
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext);
if (!transientHeaders.isEmpty()) {
newTransientHeaders.putAll(transientHeaders);
transientHeadersModified = true;
Expand Down Expand Up @@ -322,7 +322,7 @@ public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
@Override
public void writeTo(StreamOutput out) throws IOException {
final ThreadContextStruct context = threadLocal.get();
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
}

Expand Down Expand Up @@ -534,17 +534,7 @@ boolean isDefaultContext() {
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
ThreadContextStruct threadContextStruct = threadLocal.get();
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transientsForSystemContext(threadContextStruct.transientHeaders)));
ThreadContextStruct newThreadContextStruct = new ThreadContextStruct(
threadContextStruct.requestHeaders,
threadContextStruct.responseHeaders,
transients,
threadContextStruct.persistentHeaders,
threadContextStruct.isSystemContext
);
threadLocal.set(newThreadContextStruct.setSystemContext());
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

/**
Expand Down Expand Up @@ -583,15 +573,15 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
}
}

private Map<String, Object> propagateTransients(Map<String, Object> source) {
private Map<String, Object> propagateTransients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(source)));
propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext)));
return transients;
}

private Map<String, String> propagateHeaders(Map<String, Object> source) {
private Map<String, String> propagateHeaders(Map<String, Object> source, boolean isSystemContext) {
final Map<String, String> headers = new HashMap<>();
propagators.forEach(p -> headers.putAll(p.headers(source)));
propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext)));
return headers;
}

Expand All @@ -613,11 +603,13 @@ private static final class ThreadContextStruct {
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;

private ThreadContextStruct setSystemContext() {
private ThreadContextStruct setSystemContext(final List<ThreadContextStatePropagator> propagators) {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true)));
return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true);
}

private ThreadContextStruct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.opensearch.common.annotation.PublicApi;

import java.util.HashMap;
import java.util.Map;

/**
Expand All @@ -23,24 +22,19 @@
public interface ThreadContextStatePropagator {
/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
* @param source current context transient headers
*
* @param source current context transient headers
* @param isSystemContext if the propagation is for system context.
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
Map<String, Object> transients(Map<String, Object> source);

/**
* Returns the list of transient headers that need to be propagated to the child system context.
* @param source current context transient headers
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
default Map<String, Object> transientsForSystemContext(Map<String, Object> source) {
return new HashMap<>();
}
Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext);

/**
* Returns the list of request headers that needs to be propagated from current context to request.
* @param source current context headers
*
* @param source current context headers
* @param isSystemContext if the propagation is for system context.
* @return the list of request headers that needs to be propagated from current context to request
*/
Map<String, String> headers(Map<String, Object> source);
Map<String, String> headers(Map<String, Object> source, boolean isSystemContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
*/
public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator {
@Override
public Map<String, Object> transients(Map<String, Object> source) {
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();

if (source.containsKey(TASK_ID)) {
Expand All @@ -32,12 +32,7 @@ public Map<String, Object> transients(Map<String, Object> source) {
}

@Override
public Map<String, Object> transientsForSystemContext(Map<String, Object> source) {
return transients(source);
}

@Override
public Map<String, String> headers(Map<String, Object> source) {
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,19 @@ public void put(String key, Span span) {
}

@Override
public Map<String, Object> transients(Map<String, Object> source) {
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();

if (source.containsKey(CURRENT_SPAN)) {
if (isSystemContext == false && source.containsKey(CURRENT_SPAN)) {
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
if (current != null) {
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
}
}

return transients;
}

@Override
public Map<String, String> headers(Map<String, Object> source) {
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
final Map<String, String> headers = new HashMap<>();

if (source.containsKey(CURRENT_SPAN)) {
Expand Down

0 comments on commit 495fad3

Please sign in to comment.