Skip to content

Commit

Permalink
Adds unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Gagan Juneja <[email protected]>
  • Loading branch information
Gagan Juneja committed Jan 5, 2024
1 parent 1d0e11c commit 04867e8
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,73 @@ public void testMarkAsSystemContext() throws IOException {
assertFalse(threadContext.isSystemContext());
}

public void testSystemContextWithPropagator() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.registerThreadContextStatePropagator(createDummyPropagator(("test_transient_propagation_key")));
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", 1);
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("bar", threadContext.getHeader("foo"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
threadContext.markAsSystemContext();
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testSerializeSystemContext() throws IOException {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.registerThreadContextStatePropagator(createDummyPropagator(("test_transient_propagation_key")));
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", "test");
BytesStreamOutput out = new BytesStreamOutput();
BytesStreamOutput outFromSystemContext = new BytesStreamOutput();
threadContext.writeTo(out);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.markAsSystemContext();
threadContext.writeTo(outFromSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(outFromSystemContext.bytes().streamInput());
assertNull(threadContext.getHeader("test_transient_propagation_key"));
}
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(out.bytes().streamInput());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("test", threadContext.getHeader("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

private ThreadContextStatePropagator createDummyPropagator(final String key) {
return new ThreadContextStatePropagator() {
@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
Map<String, Object> transients = new HashMap<>();
if (isSystemContext == false && source.containsKey(key)) {
transients.put(key, source.get(key));
}
return transients;
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
Map<String, String> headers = new HashMap<>();
if (isSystemContext == false && source.containsKey(key)) {
headers.put(key, (String) source.get(key));
}
return headers;
}
};
}

public void testPutHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.tasks;

import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase {
private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator();

public void testTransient() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}

public void testTransientForSystemContext() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,20 @@ public void run() {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}

public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

0 comments on commit 04867e8

Please sign in to comment.