diff --git a/temporal-sdk/build.gradle b/temporal-sdk/build.gradle
index aef5da968..df71ca34f 100644
--- a/temporal-sdk/build.gradle
+++ b/temporal-sdk/build.gradle
@@ -42,11 +42,14 @@ dependencies {
api project(':temporal-serviceclient')
api group: 'com.google.code.gson', name: 'gson', version: '2.8.6'
api group: 'io.micrometer', name: 'micrometer-core', version: '1.6.0'
+ api group: 'io.opentracing', name: 'opentracing-api', version: '0.33.0'
implementation group: 'com.google.guava', name: 'guava', version: '30.0-jre'
implementation group: 'com.cronutils', name: 'cron-utils', version: '9.1.1'
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.11.3'
implementation group: 'com.fasterxml.jackson.datatype', name: 'jackson-datatype-jsr310', version: '2.11.3'
+ implementation group: 'io.opentracing', name: 'opentracing-util', version: '0.33.0'
+
if (!JavaVersion.current().isJava8()) {
implementation 'javax.annotation:javax.annotation-api:1.3.2'
}
@@ -54,6 +57,7 @@ dependencies {
testImplementation group: 'ch.qos.logback', name: 'logback-classic', version: '1.2.3'
testImplementation group: 'com.googlecode.junit-toolbox', name: 'junit-toolbox', version: '2.4'
testImplementation group: 'junit', name: 'junit', version: '4.13.1'
+ testImplementation group: 'io.opentracing', name: 'opentracing-mock', version: '0.33.0'
}
configurations.all {
diff --git a/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java b/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java
index ee618f192..35918849f 100644
--- a/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java
+++ b/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java
@@ -23,8 +23,12 @@
import java.util.Map;
/**
- * Context Propagators are used to propagate information from workflow to activity, workflow to
- * child workflow, and workflow to child thread (using {@link io.temporal.workflow.Async}).
+ * Context Propagators are used to propagate information from workflow stub to workflow, workflow to
+ * activity, workflow to child workflow, and workflow to child thread (using {@link
+ * io.temporal.workflow.Async}).
+ *
+ *
It is important to note that all threads share one ContextPropagator instance, so your
+ * implementation must be thread-safe and store any state in ThreadLocal variables.
*
*
A sample ContextPropagator
that copies all {@link org.slf4j.MDC} entries starting
* with a given prefix along the code path looks like this:
@@ -126,4 +130,27 @@ public interface ContextPropagator {
/** Sets the current context */
void setCurrentContext(Object context);
+
+ /**
+ * This is a lifecycle method, called after the context has been propagated to the
+ * workflow/activity thread but the workflow/activity has not yet started.
+ */
+ default void setUp() {
+ // No-op
+ }
+
+ /** This is a lifecycle method, called after the workflow/activity has completed. */
+ default void finish() {
+ // No-op
+ }
+
+ /**
+ * This is a lifecycle method, called when the workflow/activity finishes by throwing an unhandled
+ * exception. {@link #finish()} is called after this method.
+ *
+ * @param t The unhandled exception that caused the workflow/activity to terminate
+ */
+ default void onError(Throwable t) {
+ // No-op
+ }
}
diff --git a/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java b/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java
new file mode 100644
index 000000000..8d4b155c8
--- /dev/null
+++ b/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java
@@ -0,0 +1,207 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.common.context;
+
+import io.opentracing.Scope;
+import io.opentracing.Span;
+import io.opentracing.SpanContext;
+import io.opentracing.Tracer;
+import io.opentracing.log.Fields;
+import io.opentracing.propagation.Format;
+import io.opentracing.propagation.TextMap;
+import io.opentracing.tag.Tags;
+import io.opentracing.util.GlobalTracer;
+import io.temporal.api.common.v1.Payload;
+import io.temporal.common.converter.DataConverter;
+import io.temporal.internal.logging.LoggerTag;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.slf4j.MDC;
+
+/** Support for OpenTracing spans */
+public class OpenTracingContextPropagator implements ContextPropagator {
+
+ private static final Logger log = LoggerFactory.getLogger(OpenTracingContextPropagator.class);
+
+ private static final ThreadLocal currentOpenTracingSpanContext = new ThreadLocal<>();
+ private static final ThreadLocal currentOpenTracingSpan = new ThreadLocal<>();
+ private static final ThreadLocal currentOpenTracingScope = new ThreadLocal<>();
+
+ public static SpanContext getCurrentOpenTracingSpanContext() {
+ return currentOpenTracingSpanContext.get();
+ }
+
+ public static void setCurrentOpenTracingSpanContext(SpanContext ctx) {
+ if (ctx != null) {
+ currentOpenTracingSpanContext.set(ctx);
+ }
+ }
+
+ @Override
+ public String getName() {
+ return "OpenTracing";
+ }
+
+ @Override
+ public Map serializeContext(Object context) {
+ Map serializedContext = new HashMap<>();
+ Map contextMap = (Map) context;
+ if (contextMap != null) {
+ for (Map.Entry entry : contextMap.entrySet()) {
+ serializedContext.put(
+ entry.getKey(), DataConverter.getDefaultInstance().toPayload(entry.getValue()).get());
+ }
+ }
+ return serializedContext;
+ }
+
+ @Override
+ public Object deserializeContext(Map context) {
+ Map contextMap = new HashMap<>();
+ for (Map.Entry entry : context.entrySet()) {
+ contextMap.put(
+ entry.getKey(),
+ DataConverter.getDefaultInstance()
+ .fromPayload(entry.getValue(), String.class, String.class));
+ }
+ return contextMap;
+ }
+
+ @Override
+ public Object getCurrentContext() {
+ Tracer currentTracer = GlobalTracer.get();
+ Span currentSpan = currentTracer.scopeManager().activeSpan();
+ if (currentSpan != null) {
+ HashMapTextMap contextTextMap = new HashMapTextMap();
+ currentTracer.inject(currentSpan.context(), Format.Builtin.TEXT_MAP, contextTextMap);
+ return contextTextMap.getBackingMap();
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public void setCurrentContext(Object context) {
+ Tracer currentTracer = GlobalTracer.get();
+ Map contextAsMap = (Map) context;
+ if (contextAsMap != null) {
+ HashMapTextMap contextTextMap = new HashMapTextMap(contextAsMap);
+ setCurrentOpenTracingSpanContext(
+ currentTracer.extract(Format.Builtin.TEXT_MAP, contextTextMap));
+ }
+ }
+
+ @Override
+ public void setUp() {
+ Tracer openTracingTracer = GlobalTracer.get();
+ Tracer.SpanBuilder builder = openTracingTracer.buildSpan("cadence.workflow");
+
+ if (MDC.getCopyOfContextMap().containsKey(LoggerTag.WORKFLOW_TYPE)) {
+ builder.withTag("resource.name", MDC.get(LoggerTag.WORKFLOW_TYPE));
+ } else {
+ builder.withTag("resource.name", MDC.get(LoggerTag.ACTIVITY_TYPE));
+ }
+
+ if (getCurrentOpenTracingSpanContext() != null) {
+ builder.asChildOf(getCurrentOpenTracingSpanContext());
+ }
+ Span span = builder.start();
+ openTracingTracer.activateSpan(span);
+ currentOpenTracingSpan.set(span);
+ Scope scope = openTracingTracer.activateSpan(span);
+ currentOpenTracingScope.set(scope);
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ Span span = currentOpenTracingSpan.get();
+ if (span != null) {
+ Tags.ERROR.set(span, true);
+ Map errorData = new HashMap<>();
+ errorData.put(Fields.EVENT, "error");
+ if (t != null) {
+ errorData.put(Fields.ERROR_OBJECT, t);
+ errorData.put(Fields.MESSAGE, t.getMessage());
+ }
+ span.log(errorData);
+ }
+ }
+
+ @Override
+ public void finish() {
+ Scope currentScope = currentOpenTracingScope.get();
+ Span currentSpan = currentOpenTracingSpan.get();
+ if (currentScope != null) {
+ currentScope.close();
+ }
+ if (currentSpan != null) {
+ currentSpan.finish();
+ }
+ currentOpenTracingScope.remove();
+ currentOpenTracingSpan.remove();
+ currentOpenTracingSpanContext.remove();
+ }
+
+ /** Just check for other instances of the same class */
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (this == obj) {
+ return true;
+ }
+ return this.getClass().equals(obj.getClass());
+ }
+
+ @Override
+ public int hashCode() {
+ return this.getClass().hashCode();
+ }
+
+ private class HashMapTextMap implements TextMap {
+ private final HashMap backingMap = new HashMap<>();
+
+ public HashMapTextMap() {
+ // Noop
+ }
+
+ public HashMapTextMap(Map spanData) {
+ backingMap.putAll(spanData);
+ }
+
+ @Override
+ public Iterator> iterator() {
+ return backingMap.entrySet().iterator();
+ }
+
+ @Override
+ public void put(String key, String value) {
+ backingMap.put(key, value);
+ }
+
+ public HashMap getBackingMap() {
+ return backingMap;
+ }
+ }
+}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java
new file mode 100644
index 000000000..ea948b2c1
--- /dev/null
+++ b/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.internal.context;
+
+import io.temporal.common.context.ContextPropagator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** This class holds the current set of context propagators */
+public abstract class AbstractContextThreadLocal {
+
+ private static final Logger log = LoggerFactory.getLogger(AbstractContextThreadLocal.class);
+
+ /**
+ * Returns the context propagators for the current thread
+ *
+ * @return
+ */
+ protected abstract List getPropagatorsForThread();
+
+ /** Sets the context propagators for this thread */
+ public abstract void setContextPropagators(List contextPropagators);
+
+ public List getContextPropagators() {
+ return getPropagatorsForThread();
+ }
+
+ public Map getCurrentContextForPropagation() {
+ Map contextData = new HashMap<>();
+ for (ContextPropagator propagator : getPropagatorsForThread()) {
+ contextData.put(propagator.getName(), propagator.getCurrentContext());
+ }
+ return contextData;
+ }
+
+ /**
+ * Injects the context data into the thread for each configured context propagator
+ *
+ * @param contextData The context data received from the server
+ */
+ public void propagateContextToCurrentThread(Map contextData) {
+ if (contextData == null || contextData.isEmpty()) {
+ return;
+ }
+ for (ContextPropagator propagator : getPropagatorsForThread()) {
+ if (contextData.containsKey(propagator.getName())) {
+ propagator.setCurrentContext(contextData.get(propagator.getName()));
+ }
+ }
+ }
+
+ /** Calls {@link ContextPropagator#setUp()} for each propagator */
+ public void setUpContextPropagators() {
+ for (ContextPropagator propagator : getPropagatorsForThread()) {
+ try {
+ propagator.setUp();
+ } catch (Throwable t) {
+ // Don't let an error in one propagator block the others
+ log.error("Error calling setUp() on a contextpropagator", t);
+ }
+ }
+ }
+
+ /**
+ * Calls {@link ContextPropagator#onError(Throwable)} for each propagator
+ *
+ * @param t The Throwable that caused the workflow/activity to finish
+ */
+ public void onErrorContextPropagators(Throwable t) {
+ for (ContextPropagator propagator : getPropagatorsForThread()) {
+ try {
+ propagator.onError(t);
+ } catch (Throwable t1) {
+ // Don't let an error in one propagator block the others
+ log.error("Error calling onError() on a contextpropagator", t1);
+ }
+ }
+ }
+
+ /** Calls {@link ContextPropagator#finish()} for each propagator */
+ public void finishContextPropagators() {
+ for (ContextPropagator propagator : getPropagatorsForThread()) {
+ try {
+ propagator.finish();
+ } catch (Throwable t) {
+ // Don't let an error in one propagator block the others
+ log.error("Error calling finish() on a contextpropagator", t);
+ }
+ }
+ }
+}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java
new file mode 100644
index 000000000..03b2e6359
--- /dev/null
+++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.internal.context;
+
+import io.temporal.common.context.ContextPropagator;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ContextActivityThreadLocal extends AbstractContextThreadLocal {
+
+ private static final ThreadLocal> contextPropagators =
+ ThreadLocal.withInitial(() -> new ArrayList<>());
+
+ public static ContextActivityThreadLocal getInstance() {
+ return new ContextActivityThreadLocal();
+ }
+
+ @Override
+ public List getPropagatorsForThread() {
+ return contextPropagators.get();
+ }
+
+ @Override
+ public void setContextPropagators(List contextPropagators) {
+ if (contextPropagators == null || contextPropagators.size() == 0) {
+ return;
+ }
+
+ this.contextPropagators.set(contextPropagators);
+ }
+}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java
new file mode 100644
index 000000000..777ed8007
--- /dev/null
+++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.internal.context;
+
+import io.temporal.api.common.v1.Header;
+import io.temporal.api.common.v1.Payload;
+import io.temporal.common.context.ContextPropagator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/** Common methods for dealing with context */
+public class ContextPropagatorUtils {
+
+ public static Map extractContextsFromHeaders(
+ List contextPropagators, Header headers) {
+
+ if (contextPropagators == null || contextPropagators.isEmpty()) {
+ return new HashMap<>();
+ }
+
+ if (headers == null) {
+ return new HashMap<>();
+ }
+
+ Map headerData = new HashMap<>();
+ for (Map.Entry pair : headers.getFieldsMap().entrySet()) {
+ headerData.put(pair.getKey(), pair.getValue());
+ }
+
+ Map contextData = new HashMap<>();
+ for (ContextPropagator propagator : contextPropagators) {
+
+ // Only send the context propagator the fields that belong to them
+ // Change the map from MyPropagator:foo -> bar to foo -> bar
+ Map filteredData =
+ headerData.entrySet().stream()
+ .filter(e -> e.getKey().startsWith(propagator.getName()))
+ .collect(
+ Collectors.toMap(
+ e -> e.getKey().substring(propagator.getName().length() + 1),
+ Map.Entry::getValue));
+ contextData.put(propagator.getName(), propagator.deserializeContext(filteredData));
+ }
+
+ return contextData;
+ }
+
+ public static Map extractContextsAndConvertToBytes(
+ List contextPropagators) {
+ if (contextPropagators == null) {
+ return null;
+ }
+ Map result = new HashMap<>();
+ for (ContextPropagator propagator : contextPropagators) {
+ // Get the serialized context from the propagator
+ Map serializedContext =
+ propagator.serializeContext(propagator.getCurrentContext());
+ // Namespace each entry in case of overlaps, so foo -> bar becomes MyPropagator:foo -> bar
+ Map namespacedSerializedContext =
+ serializedContext.entrySet().stream()
+ .collect(
+ Collectors.toMap(
+ e -> propagator.getName() + ":" + e.getKey(), Map.Entry::getValue));
+ result.putAll(namespacedSerializedContext);
+ }
+ return result;
+ }
+}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java
deleted file mode 100644
index 39044b773..000000000
--- a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
- *
- * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Modifications copyright (C) 2017 Uber Technologies, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not
- * use this file except in compliance with the License. A copy of the License is
- * located at
- *
- * http://aws.amazon.com/apache2.0
- *
- * or in the "license" file accompanying this file. This file is distributed on
- * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- */
-
-package io.temporal.internal.context;
-
-import io.temporal.common.context.ContextPropagator;
-import io.temporal.workflow.WorkflowThreadLocal;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Supplier;
-
-/** This class holds the current set of context propagators */
-public class ContextThreadLocal {
-
- private static final WorkflowThreadLocal> contextPropagators =
- WorkflowThreadLocal.withInitial(
- new Supplier>() {
- @Override
- public List get() {
- return new ArrayList<>();
- }
- });
-
- /** Sets the list of context propagators for the thread */
- public static void setContextPropagators(List propagators) {
- if (propagators == null || propagators.isEmpty()) {
- return;
- }
- contextPropagators.set(propagators);
- }
-
- public static List getContextPropagators() {
- return contextPropagators.get();
- }
-
- public static Map getCurrentContextForPropagation() {
- Map contextData = new HashMap<>();
- for (ContextPropagator propagator : contextPropagators.get()) {
- contextData.put(propagator.getName(), propagator.getCurrentContext());
- }
- return contextData;
- }
-
- public static void propagateContextToCurrentThread(Map contextData) {
- if (contextData == null || contextData.isEmpty()) {
- return;
- }
- for (ContextPropagator propagator : contextPropagators.get()) {
- if (contextData.containsKey(propagator.getName())) {
- propagator.setCurrentContext(contextData.get(propagator.getName()));
- }
- }
- }
-}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java
new file mode 100644
index 000000000..562670614
--- /dev/null
+++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.internal.context;
+
+import io.temporal.common.context.ContextPropagator;
+import io.temporal.workflow.WorkflowThreadLocal;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ContextWorkflowThreadLocal extends AbstractContextThreadLocal {
+
+ private static final WorkflowThreadLocal> contextPropagators =
+ WorkflowThreadLocal.withInitial(() -> new ArrayList<>());
+
+ public static ContextWorkflowThreadLocal getInstance() {
+ return new ContextWorkflowThreadLocal();
+ }
+
+ @Override
+ public List getPropagatorsForThread() {
+ return contextPropagators.get();
+ }
+
+ @Override
+ public void setContextPropagators(List contextPropagators) {
+ if (contextPropagators == null || contextPropagators.size() == 0) {
+ return;
+ }
+
+ this.contextPropagators.set(contextPropagators);
+ }
+}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java b/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java
index 335b1a41b..87de9a662 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java
@@ -21,7 +21,6 @@
import com.google.protobuf.util.Timestamps;
import io.temporal.api.command.v1.ContinueAsNewWorkflowExecutionCommandAttributes;
-import io.temporal.api.common.v1.Header;
import io.temporal.api.common.v1.Payload;
import io.temporal.api.common.v1.SearchAttributes;
import io.temporal.api.common.v1.WorkflowExecution;
@@ -29,8 +28,8 @@
import io.temporal.api.history.v1.WorkflowExecutionStartedEventAttributes;
import io.temporal.common.context.ContextPropagator;
import io.temporal.internal.common.ProtobufTimeUtils;
+import io.temporal.internal.context.ContextPropagatorUtils;
import java.time.Duration;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -161,26 +160,8 @@ public List getContextPropagators() {
/** Returns a map of propagated context objects, keyed by propagator name */
Map getPropagatedContexts() {
- if (contextPropagators == null || contextPropagators.isEmpty()) {
- return new HashMap<>();
- }
-
- Header headers = startedAttributes.getHeader();
- if (headers == null) {
- return new HashMap<>();
- }
-
- Map headerData = new HashMap<>();
- for (Map.Entry pair : headers.getFieldsMap().entrySet()) {
- headerData.put(pair.getKey(), pair.getValue());
- }
-
- Map contextData = new HashMap<>();
- for (ContextPropagator propagator : contextPropagators) {
- contextData.put(propagator.getName(), propagator.deserializeContext(headerData));
- }
-
- return contextData;
+ return ContextPropagatorUtils.extractContextsFromHeaders(
+ contextPropagators, startedAttributes.getHeader());
}
void mergeSearchAttributes(SearchAttributes searchAttributes) {
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java
index d485416e9..1e388963d 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java
@@ -35,7 +35,7 @@
import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptorBase;
import io.temporal.failure.CanceledFailure;
import io.temporal.internal.common.CheckedExceptionWrapper;
-import io.temporal.internal.context.ContextThreadLocal;
+import io.temporal.internal.context.ContextWorkflowThreadLocal;
import io.temporal.internal.replay.ExecuteActivityParameters;
import io.temporal.internal.replay.ExecuteLocalActivityParameters;
import io.temporal.internal.replay.ReplayWorkflowContext;
@@ -96,6 +96,8 @@ private NamedRunnable(String name, Runnable runnable) {
private static final Logger log = LoggerFactory.getLogger(DeterministicRunnerImpl.class);
static final String WORKFLOW_ROOT_THREAD_NAME = "workflow-method";
private static final ThreadLocal currentThreadThreadLocal = new ThreadLocal<>();
+ private final ContextWorkflowThreadLocal currentWorkflowThreadLocal =
+ ContextWorkflowThreadLocal.getInstance();
private final Lock lock = new ReentrantLock();
private final ExecutorService threadPool;
@@ -528,7 +530,7 @@ void setRunnerLocal(RunnerLocalInternal key, T value) {
*/
private Map getPropagatedContexts() {
if (currentThreadThreadLocal.get() != null) {
- return ContextThreadLocal.getCurrentContextForPropagation();
+ return currentWorkflowThreadLocal.getCurrentContextForPropagation();
} else {
return workflowContext.getContext().getPropagatedContexts();
}
@@ -536,7 +538,7 @@ private Map getPropagatedContexts() {
private List getContextPropagators() {
if (currentThreadThreadLocal.get() != null) {
- return ContextThreadLocal.getContextPropagators();
+ return currentWorkflowThreadLocal.getContextPropagators();
} else {
return workflowContext.getContext().getContextPropagators();
}
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java
index b915ce1e6..64c4c376a 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java
@@ -32,7 +32,6 @@
import io.temporal.api.common.v1.ActivityType;
import io.temporal.api.common.v1.Header;
import io.temporal.api.common.v1.Memo;
-import io.temporal.api.common.v1.Payload;
import io.temporal.api.common.v1.Payloads;
import io.temporal.api.common.v1.RetryPolicy;
import io.temporal.api.common.v1.SearchAttributes;
@@ -55,6 +54,7 @@
import io.temporal.internal.common.InternalUtils;
import io.temporal.internal.common.OptionsUtils;
import io.temporal.internal.common.ProtobufTimeUtils;
+import io.temporal.internal.context.ContextPropagatorUtils;
import io.temporal.internal.metrics.MetricsType;
import io.temporal.internal.replay.ChildWorkflowTaskFailedException;
import io.temporal.internal.replay.ExecuteActivityParameters;
@@ -298,7 +298,8 @@ private ExecuteActivityParameters constructExecuteActivityParameters(
if (propagators == null) {
propagators = this.contextPropagators;
}
- Header header = toHeaderGrpc(extractContextsAndConvertToBytes(propagators));
+ Header header =
+ toHeaderGrpc(ContextPropagatorUtils.extractContextsAndConvertToBytes(propagators));
if (header != null) {
attributes.setHeader(header);
}
@@ -416,7 +417,8 @@ private Promise> executeChildWorkflow(
attributes.setRetryPolicy(toRetryPolicy(retryOptions));
}
attributes.setCronSchedule(OptionsUtils.safeGet(options.getCronSchedule()));
- Header header = toHeaderGrpc(extractContextsAndConvertToBytes(propagators));
+ Header header =
+ toHeaderGrpc(ContextPropagatorUtils.extractContextsAndConvertToBytes(propagators));
if (header != null) {
attributes.setHeader(header);
}
@@ -459,18 +461,6 @@ private Promise> executeChildWorkflow(
return result;
}
- private Map extractContextsAndConvertToBytes(
- List contextPropagators) {
- if (contextPropagators == null) {
- return null;
- }
- Map result = new HashMap<>();
- for (ContextPropagator propagator : contextPropagators) {
- result.putAll(propagator.serializeContext(propagator.getCurrentContext()));
- }
- return result;
- }
-
private RuntimeException mapChildWorkflowException(Exception failure) {
if (failure == null) {
return null;
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java
index 69128aeea..5e5632f81 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java
@@ -54,7 +54,6 @@
import io.temporal.client.WorkflowServiceException;
import io.temporal.client.WorkflowStub;
import io.temporal.common.RetryOptions;
-import io.temporal.common.context.ContextPropagator;
import io.temporal.failure.CanceledFailure;
import io.temporal.failure.FailureConverter;
import io.temporal.internal.common.CheckedExceptionWrapper;
@@ -63,10 +62,9 @@
import io.temporal.internal.common.StatusUtils;
import io.temporal.internal.common.WorkflowExecutionFailedException;
import io.temporal.internal.common.WorkflowExecutionUtils;
+import io.temporal.internal.context.ContextPropagatorUtils;
import io.temporal.internal.external.GenericWorkflowClientExternal;
import java.lang.reflect.Type;
-import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
@@ -234,7 +232,8 @@ private StartWorkflowExecutionRequest newStartWorkflowExecutionRequest(
.putAllIndexedFields(convertMemoFromObjectToBytes(o.getSearchAttributes())));
}
if (o.getContextPropagators() != null && !o.getContextPropagators().isEmpty()) {
- Map context = extractContextsAndConvertToBytes(o.getContextPropagators());
+ Map context =
+ ContextPropagatorUtils.extractContextsAndConvertToBytes(o.getContextPropagators());
request.setHeader(Header.newBuilder().putAllFields(context));
}
return request.build();
@@ -248,18 +247,6 @@ private Map convertSearchAttributesFromObjectToBytes(Map extractContextsAndConvertToBytes(
- List contextPropagators) {
- if (contextPropagators == null) {
- return null;
- }
- Map result = new HashMap<>();
- for (ContextPropagator propagator : contextPropagators) {
- result.putAll(propagator.serializeContext(propagator.getCurrentContext()));
- }
- return result;
- }
-
@Override
public WorkflowExecution start(Object... args) {
if (!options.isPresent()) {
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java
index 89a097023..a5f983dca 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java
@@ -22,7 +22,7 @@
import com.google.common.util.concurrent.RateLimiter;
import io.temporal.common.context.ContextPropagator;
import io.temporal.failure.CanceledFailure;
-import io.temporal.internal.context.ContextThreadLocal;
+import io.temporal.internal.context.ContextWorkflowThreadLocal;
import io.temporal.internal.logging.LoggerTag;
import io.temporal.internal.metrics.MetricsType;
import io.temporal.internal.replay.ReplayWorkflowContext;
@@ -98,8 +98,12 @@ public void run() {
MDC.put(LoggerTag.NAMESPACE, replayWorkflowContext.getNamespace());
// Repopulate the context(s)
- ContextThreadLocal.setContextPropagators(this.contextPropagators);
- ContextThreadLocal.propagateContextToCurrentThread(this.propagatedContexts);
+ ContextWorkflowThreadLocal contextWorkflowThreadLocal =
+ ContextWorkflowThreadLocal.getInstance();
+ contextWorkflowThreadLocal.setContextPropagators(this.contextPropagators);
+ contextWorkflowThreadLocal.propagateContextToCurrentThread(this.propagatedContexts);
+ contextWorkflowThreadLocal.setUpContextPropagators();
+
try {
// initialYield blocks thread until the first runUntilBlocked is called.
// Otherwise r starts executing without control of the sync.
@@ -107,20 +111,25 @@ public void run() {
cancellationScope.run();
} catch (DestroyWorkflowThreadError e) {
if (!threadContext.isDestroyRequested()) {
+ contextWorkflowThreadLocal.onErrorContextPropagators(e);
threadContext.setUnhandledException(e);
}
} catch (Error e) {
+ contextWorkflowThreadLocal.onErrorContextPropagators(e);
threadContext.setUnhandledException(e);
} catch (CanceledFailure e) {
if (!isCancelRequested()) {
+ contextWorkflowThreadLocal.onErrorContextPropagators(e);
threadContext.setUnhandledException(e);
}
if (log.isDebugEnabled()) {
log.debug(String.format("Workflow thread \"%s\" run canceled", name));
}
} catch (Throwable e) {
+ contextWorkflowThreadLocal.onErrorContextPropagators(e);
threadContext.setUnhandledException(e);
} finally {
+ contextWorkflowThreadLocal.finishContextPropagators();
DeterministicRunnerImpl.setCurrentThreadInternal(null);
threadContext.setStatus(Status.DONE);
thread.setName(originalName);
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java
index c821e4ab6..8d726cb7a 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java
@@ -25,7 +25,6 @@
import com.uber.m3.tally.Stopwatch;
import com.uber.m3.util.Duration;
import com.uber.m3.util.ImmutableMap;
-import io.temporal.api.common.v1.Payload;
import io.temporal.api.common.v1.WorkflowExecution;
import io.temporal.api.failure.v1.CanceledFailureInfo;
import io.temporal.api.failure.v1.Failure;
@@ -33,18 +32,18 @@
import io.temporal.api.workflowservice.v1.RespondActivityTaskCanceledRequest;
import io.temporal.api.workflowservice.v1.RespondActivityTaskCompletedRequest;
import io.temporal.api.workflowservice.v1.RespondActivityTaskFailedRequest;
-import io.temporal.common.context.ContextPropagator;
+import io.temporal.failure.FailureConverter;
import io.temporal.internal.common.GrpcRetryer;
import io.temporal.internal.common.ProtobufTimeUtils;
import io.temporal.internal.common.RpcRetryOptions;
+import io.temporal.internal.context.ContextActivityThreadLocal;
+import io.temporal.internal.context.ContextPropagatorUtils;
import io.temporal.internal.logging.LoggerTag;
import io.temporal.internal.metrics.MetricsType;
import io.temporal.internal.replay.FailureWrapperException;
import io.temporal.internal.worker.ActivityTaskHandler.Result;
import io.temporal.serviceclient.MetricsTag;
import io.temporal.serviceclient.WorkflowServiceStubs;
-import java.util.HashMap;
-import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.slf4j.MDC;
@@ -185,7 +184,13 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception {
MDC.put(LoggerTag.WORKFLOW_ID, task.getWorkflowExecution().getWorkflowId());
MDC.put(LoggerTag.RUN_ID, task.getWorkflowExecution().getRunId());
- propagateContext(task);
+ ContextActivityThreadLocal contextActivityThreadLocal =
+ ContextActivityThreadLocal.getInstance();
+ contextActivityThreadLocal.setContextPropagators(options.getContextPropagators());
+ contextActivityThreadLocal.propagateContextToCurrentThread(
+ ContextPropagatorUtils.extractContextsFromHeaders(
+ options.getContextPropagators(), task.getHeader()));
+ contextActivityThreadLocal.setUpContextPropagators();
try {
Stopwatch sw = metricsScope.timer(MetricsType.ACTIVITY_EXEC_LATENCY).start();
@@ -215,7 +220,10 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception {
new Result(task.getActivityId(), null, null, canceledRequest.build(), null),
metricsScope);
}
+ contextActivityThreadLocal.onErrorContextPropagators(
+ FailureConverter.failureToException(failure, options.getDataConverter()));
} finally {
+ contextActivityThreadLocal.finishContextPropagators();
MDC.remove(LoggerTag.ACTIVITY_ID);
MDC.remove(LoggerTag.ACTIVITY_TYPE);
MDC.remove(LoggerTag.WORKFLOW_ID);
@@ -223,23 +231,6 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception {
}
}
- void propagateContext(PollActivityTaskQueueResponse response) {
- if (options.getContextPropagators() == null || options.getContextPropagators().isEmpty()) {
- return;
- }
-
- if (!response.hasHeader()) {
- return;
- }
- Map headerData = new HashMap<>();
- for (Map.Entry entry : response.getHeader().getFieldsMap().entrySet()) {
- headerData.put(entry.getKey(), entry.getValue());
- }
- for (ContextPropagator propagator : options.getContextPropagators()) {
- propagator.setCurrentContext(propagator.deserializeContext(headerData));
- }
- }
-
@Override
public Throwable wrapFailure(PollActivityTaskQueueResponse task, Throwable failure) {
WorkflowExecution execution = task.getWorkflowExecution();
diff --git a/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java b/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java
new file mode 100644
index 000000000..03af0ac48
--- /dev/null
+++ b/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java
@@ -0,0 +1,503 @@
+/*
+ * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License is
+ * located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package io.temporal.common.context;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+
+import io.opentracing.Scope;
+import io.opentracing.Span;
+import io.opentracing.Tracer;
+import io.opentracing.mock.MockSpan;
+import io.opentracing.mock.MockTracer;
+import io.opentracing.util.GlobalTracer;
+import io.opentracing.util.ThreadLocalScopeManager;
+import io.temporal.activity.ActivityOptions;
+import io.temporal.api.common.v1.Payload;
+import io.temporal.client.WorkflowClient;
+import io.temporal.client.WorkflowClientOptions;
+import io.temporal.client.WorkflowOptions;
+import io.temporal.common.converter.DataConverter;
+import io.temporal.failure.ApplicationFailure;
+import io.temporal.internal.testing.WorkflowTestingTest;
+import io.temporal.testing.TestEnvironmentOptions;
+import io.temporal.testing.TestWorkflowEnvironment;
+import io.temporal.worker.Worker;
+import io.temporal.workflow.Async;
+import io.temporal.workflow.ChildWorkflowOptions;
+import io.temporal.workflow.Promise;
+import io.temporal.workflow.Workflow;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.MDC;
+
+public class ContextTest {
+
+ private static final String TASK_QUEUE = "test-workflow";
+
+ private TestWorkflowEnvironment testEnvironment;
+ private MockTracer mockTracer =
+ new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP);
+
+ @Before
+ public void setUp() {
+ TestEnvironmentOptions options =
+ TestEnvironmentOptions.newBuilder()
+ .setWorkflowClientOptions(
+ WorkflowClientOptions.newBuilder()
+ .setContextPropagators(
+ Arrays.asList(
+ new TestContextPropagator(), new OpenTracingContextPropagator()))
+ .build())
+ .build();
+ testEnvironment = TestWorkflowEnvironment.newInstance(options);
+ GlobalTracer.registerIfAbsent(mockTracer);
+ }
+
+ @After
+ public void tearDown() {
+ testEnvironment.close();
+ mockTracer.reset();
+ }
+
+ public static class TestContextPropagator implements ContextPropagator {
+
+ @Override
+ public String getName() {
+ return this.getClass().getName();
+ }
+
+ @Override
+ public Map serializeContext(Object context) {
+ String testKey = (String) context;
+ if (testKey != null) {
+ return Collections.singletonMap(
+ "test", DataConverter.getDefaultInstance().toPayload(testKey).get());
+ } else {
+ return Collections.emptyMap();
+ }
+ }
+
+ @Override
+ public Object deserializeContext(Map context) {
+ if (context.containsKey("test")) {
+ return DataConverter.getDefaultInstance()
+ .fromPayload(context.get("test"), String.class, String.class);
+
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public Object getCurrentContext() {
+ return MDC.get("test");
+ }
+
+ @Override
+ public void setCurrentContext(Object context) {
+ MDC.put("test", String.valueOf(context));
+ }
+ }
+
+ public static class ContextPropagationWorkflowImpl implements WorkflowTestingTest.TestWorkflow {
+
+ @Override
+ public String workflow1(String input) {
+ // The test value should be in the MDC
+ return MDC.get("test");
+ }
+ }
+
+ @Test
+ public void testWorkflowContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class);
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ String result = workflow.workflow1("input1");
+ assertEquals("testing123", result);
+ }
+
+ public static class ContextPropagationParentWorkflowImpl
+ implements WorkflowTestingTest.ParentWorkflow {
+
+ @Override
+ public String workflow(String input) {
+ // Get the MDC value
+ String mdcValue = MDC.get("test");
+
+ // Fire up a child workflow
+ ChildWorkflowOptions options =
+ ChildWorkflowOptions.newBuilder()
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.ChildWorkflow child =
+ Workflow.newChildWorkflowStub(WorkflowTestingTest.ChildWorkflow.class, options);
+
+ String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId());
+ return result;
+ }
+
+ @Override
+ public void signal(String value) {}
+ }
+
+ public static class ContextPropagationChildWorkflowImpl
+ implements WorkflowTestingTest.ChildWorkflow {
+
+ @Override
+ public String workflow(String input, String parentId) {
+ String mdcValue = MDC.get("test");
+ return input + mdcValue;
+ }
+ }
+
+ @Test
+ public void testChildWorkflowContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(
+ ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class);
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.ParentWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.ParentWorkflow.class, options);
+ String result = workflow.workflow("input1");
+ assertEquals("testing123testing123", result);
+ }
+
+ public static class ContextPropagationThreadWorkflowImpl
+ implements WorkflowTestingTest.TestWorkflow {
+
+ @Override
+ public String workflow1(String input) {
+ Promise asyncPromise = Async.function(this::async);
+ return asyncPromise.get();
+ }
+
+ private String async() {
+ return "async" + MDC.get("test");
+ }
+ }
+
+ @Test
+ public void testThreadContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class);
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .setTaskQueue(TASK_QUEUE)
+ .build();
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ String result = workflow.workflow1("input1");
+ assertEquals("asynctesting123", result);
+ }
+
+ public static class ContextActivityImpl implements WorkflowTestingTest.TestActivity {
+ @Override
+ public String activity1(String input) {
+ return "activity" + MDC.get("test");
+ }
+ }
+
+ public static class ContextPropagationActivityWorkflowImpl
+ implements WorkflowTestingTest.TestWorkflow {
+ @Override
+ public String workflow1(String input) {
+ ActivityOptions options =
+ ActivityOptions.newBuilder()
+ .setScheduleToCloseTimeout(Duration.ofSeconds(5))
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.TestActivity activity =
+ Workflow.newActivityStub(
+ WorkflowTestingTest.TestActivity.class,
+ ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofHours(1)).build());
+
+ return activity.activity1("foo");
+ }
+ }
+
+ @Test
+ public void testActivityContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class);
+ worker.registerActivitiesImplementations(new ContextActivityImpl());
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ String result = workflow.workflow1("input1");
+ assertEquals("activitytesting123", result);
+ }
+
+ public static class DefaultContextPropagationActivityWorkflowImpl
+ implements WorkflowTestingTest.TestWorkflow {
+ @Override
+ public String workflow1(String input) {
+ ActivityOptions options =
+ ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build();
+ WorkflowTestingTest.TestActivity activity =
+ Workflow.newActivityStub(WorkflowTestingTest.TestActivity.class, options);
+ return activity.activity1("foo");
+ }
+ }
+
+ @Test
+ public void testDefaultActivityContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class);
+ worker.registerActivitiesImplementations(new ContextActivityImpl());
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ String result = workflow.workflow1("input1");
+ assertEquals("activitytesting123", result);
+ }
+
+ public static class DefaultContextPropagationParentWorkflowImpl
+ implements WorkflowTestingTest.ParentWorkflow {
+
+ @Override
+ public String workflow(String input) {
+ // Get the MDC value
+ String mdcValue = MDC.get("test");
+
+ // Fire up a child workflow
+ ChildWorkflowOptions options = ChildWorkflowOptions.newBuilder().build();
+ WorkflowTestingTest.ChildWorkflow child =
+ Workflow.newChildWorkflowStub(WorkflowTestingTest.ChildWorkflow.class, options);
+
+ String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId());
+ return result;
+ }
+
+ @Override
+ public void signal(String value) {}
+ }
+
+ @Test
+ public void testDefaultChildWorkflowContextPropagation() {
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(
+ DefaultContextPropagationParentWorkflowImpl.class,
+ ContextPropagationChildWorkflowImpl.class);
+ testEnvironment.start();
+ MDC.put("test", "testing123");
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
+ .build();
+ WorkflowTestingTest.ParentWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.ParentWorkflow.class, options);
+ String result = workflow.workflow("input1");
+ assertEquals("testing123testing123", result);
+ }
+
+ public static class OpenTracingContextPropagationWorkflowImpl
+ implements WorkflowTestingTest.TestWorkflow {
+ @Override
+ public String workflow1(String input) {
+ Tracer tracer = GlobalTracer.get();
+ Span activeSpan = tracer.scopeManager().activeSpan();
+ MockSpan mockSpan = (MockSpan) activeSpan;
+ assertNotNull(activeSpan);
+ assertEquals("TestWorkflow", mockSpan.tags().get("resource.name"));
+ assertNotEquals(0, mockSpan.parentId());
+ if ("fail".equals(input)) {
+ throw ApplicationFailure.newFailure("fail", "fail");
+ } else {
+ return activeSpan.getBaggageItem("foo");
+ }
+ }
+ }
+
+ public static class OpenTracingContextPropagationWithActivityWorkflowImpl
+ implements WorkflowTestingTest.TestWorkflow {
+ @Override
+ public String workflow1(String input) {
+ ActivityOptions options =
+ ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build();
+ WorkflowTestingTest.TestActivity activity =
+ Workflow.newActivityStub(WorkflowTestingTest.TestActivity.class, options);
+ return activity.activity1(input);
+ }
+ }
+
+ public static class OpenTracingContextPropagationActivityImpl
+ implements WorkflowTestingTest.TestActivity {
+
+ @Override
+ public String activity1(String input) {
+ Tracer tracer = GlobalTracer.get();
+ Span activeSpan = tracer.scopeManager().activeSpan();
+ MockSpan mockSpan = (MockSpan) activeSpan;
+ assertNotNull(activeSpan);
+ assertEquals("Activity1", mockSpan.tags().get("resource.name"));
+ assertNotEquals(0, mockSpan.parentId());
+ if ("fail".equals(input)) {
+ throw ApplicationFailure.newFailure("fail", "fail");
+ } else {
+ return activeSpan.getBaggageItem("foo");
+ }
+ }
+ }
+
+ @Test
+ public void testOpenTracingContextPropagation() {
+ Tracer tracer = GlobalTracer.get();
+ Span span = tracer.buildSpan("testContextPropagationSuccess").start();
+
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(OpenTracingContextPropagationWorkflowImpl.class);
+ testEnvironment.start();
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(
+ Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator()))
+ .build();
+
+ try (Scope scope = tracer.scopeManager().activate(span)) {
+
+ Span activeSpan = tracer.scopeManager().activeSpan();
+ activeSpan.setBaggageItem("foo", "bar");
+
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ assertEquals("bar", workflow.workflow1("input1"));
+
+ } finally {
+ span.finish();
+ }
+ }
+
+ @Test
+ public void testOpenTracingContextPropagationWithFailure() {
+ Tracer tracer = GlobalTracer.get();
+ Span span = tracer.buildSpan("testContextPropagationFailure").start();
+
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(OpenTracingContextPropagationWorkflowImpl.class);
+ testEnvironment.start();
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(
+ Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator()))
+ .build();
+
+ try (Scope scope = tracer.scopeManager().activate(span)) {
+
+ Span activeSpan = tracer.scopeManager().activeSpan();
+ activeSpan.setBaggageItem("foo", "bar");
+
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ try {
+ workflow.workflow1("fail");
+ fail("Unreachable");
+ } catch (ApplicationFailure e) {
+ // Expected
+ assertEquals("fail", e.getMessage());
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ } finally {
+ span.finish();
+ }
+ }
+
+ @Test
+ public void testOpenTracingContextPropagationToActivity() {
+ Tracer tracer = GlobalTracer.get();
+ Span span = tracer.buildSpan("testContextPropagationWithActivity").start();
+
+ Worker worker = testEnvironment.newWorker(TASK_QUEUE);
+ worker.registerWorkflowImplementationTypes(
+ OpenTracingContextPropagationWithActivityWorkflowImpl.class);
+ worker.registerActivitiesImplementations(new OpenTracingContextPropagationActivityImpl());
+ testEnvironment.start();
+ WorkflowClient client = testEnvironment.getWorkflowClient();
+ WorkflowOptions options =
+ WorkflowOptions.newBuilder()
+ .setTaskQueue(TASK_QUEUE)
+ .setContextPropagators(
+ Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator()))
+ .build();
+
+ try (Scope scope = tracer.scopeManager().activate(span)) {
+
+ Span activeSpan = tracer.scopeManager().activeSpan();
+ activeSpan.setBaggageItem("foo", "bar");
+
+ WorkflowTestingTest.TestWorkflow workflow =
+ client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options);
+ assertEquals("bar", workflow.workflow1("input1"));
+
+ } finally {
+ span.finish();
+ }
+ }
+}
diff --git a/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java b/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java
index 35c793aee..e583990e8 100644
--- a/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java
@@ -27,7 +27,6 @@
import io.temporal.activity.Activity;
import io.temporal.activity.ActivityInterface;
import io.temporal.activity.ActivityOptions;
-import io.temporal.api.common.v1.Payload;
import io.temporal.api.common.v1.WorkflowExecution;
import io.temporal.api.enums.v1.TimeoutType;
import io.temporal.api.workflow.v1.WorkflowExecutionInfo;
@@ -42,8 +41,6 @@
import io.temporal.client.WorkflowOptions;
import io.temporal.client.WorkflowStub;
import io.temporal.common.RetryOptions;
-import io.temporal.common.context.ContextPropagator;
-import io.temporal.common.converter.DataConverter;
import io.temporal.failure.ActivityFailure;
import io.temporal.failure.ApplicationFailure;
import io.temporal.failure.CanceledFailure;
@@ -53,16 +50,13 @@
import io.temporal.testing.TestWorkflowEnvironment;
import io.temporal.worker.Worker;
import io.temporal.workflow.Async;
-import io.temporal.workflow.ChildWorkflowOptions;
import io.temporal.workflow.Promise;
import io.temporal.workflow.SignalMethod;
import io.temporal.workflow.Workflow;
import io.temporal.workflow.WorkflowInterface;
import io.temporal.workflow.WorkflowMethod;
import java.time.Duration;
-import java.util.Collections;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
@@ -76,7 +70,6 @@
import org.junit.runner.Description;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.slf4j.MDC;
public class WorkflowTestingTest {
private static final Logger log = LoggerFactory.getLogger(WorkflowTestingTest.class);
@@ -98,10 +91,7 @@ protected void failed(Throwable e, Description description) {
public void setUp() {
TestEnvironmentOptions options =
TestEnvironmentOptions.newBuilder()
- .setWorkflowClientOptions(
- WorkflowClientOptions.newBuilder()
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build())
+ .setWorkflowClientOptions(WorkflowClientOptions.newBuilder().build())
.build();
testEnvironment = TestWorkflowEnvironment.newInstance(options);
}
@@ -874,257 +864,4 @@ public void testMockedChildSimulatedTimeout() {
assertTrue(e.getCause().getCause() instanceof TimeoutFailure);
}
}
-
- public static class TestContextPropagator implements ContextPropagator {
-
- @Override
- public String getName() {
- return this.getClass().getName();
- }
-
- @Override
- public Map serializeContext(Object context) {
- String testKey = (String) context;
- if (testKey != null) {
- return Collections.singletonMap(
- "test", DataConverter.getDefaultInstance().toPayload(testKey).get());
- } else {
- return Collections.emptyMap();
- }
- }
-
- @Override
- public Object deserializeContext(Map context) {
- if (context.containsKey("test")) {
- return DataConverter.getDefaultInstance()
- .fromPayload(context.get("test"), String.class, String.class);
-
- } else {
- return null;
- }
- }
-
- @Override
- public Object getCurrentContext() {
- return MDC.get("test");
- }
-
- @Override
- public void setCurrentContext(Object context) {
- MDC.put("test", String.valueOf(context));
- }
- }
-
- public static class ContextPropagationWorkflowImpl implements TestWorkflow {
-
- @Override
- public String workflow1(String input) {
- // The test value should be in the MDC
- return MDC.get("test");
- }
- }
-
- @Test
- public void testWorkflowContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class);
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setTaskQueue(TASK_QUEUE)
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options);
- String result = workflow.workflow1("input1");
- assertEquals("testing123", result);
- }
-
- public static class ContextPropagationParentWorkflowImpl implements ParentWorkflow {
-
- @Override
- public String workflow(String input) {
- // Get the MDC value
- String mdcValue = MDC.get("test");
-
- // Fire up a child workflow
- ChildWorkflowOptions options =
- ChildWorkflowOptions.newBuilder()
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options);
-
- String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId());
- return result;
- }
-
- @Override
- public void signal(String value) {}
- }
-
- public static class ContextPropagationChildWorkflowImpl implements ChildWorkflow {
-
- @Override
- public String workflow(String input, String parentId) {
- String mdcValue = MDC.get("test");
- return input + mdcValue;
- }
- }
-
- @Test
- public void testChildWorkflowContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(
- ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class);
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setTaskQueue(TASK_QUEUE)
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options);
- String result = workflow.workflow("input1");
- assertEquals("testing123testing123", result);
- }
-
- public static class ContextPropagationThreadWorkflowImpl implements TestWorkflow {
-
- @Override
- public String workflow1(String input) {
- Promise asyncPromise = Async.function(this::async);
- return asyncPromise.get();
- }
-
- private String async() {
- return "async" + MDC.get("test");
- }
- }
-
- @Test
- public void testThreadContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class);
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .setTaskQueue(TASK_QUEUE)
- .build();
- TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options);
- String result = workflow.workflow1("input1");
- assertEquals("asynctesting123", result);
- }
-
- public static class ContextActivityImpl implements TestActivity {
- @Override
- public String activity1(String input) {
- return "activity" + MDC.get("test");
- }
- }
-
- public static class ContextPropagationActivityWorkflowImpl implements TestWorkflow {
- @Override
- public String workflow1(String input) {
- ActivityOptions options =
- ActivityOptions.newBuilder()
- .setScheduleToCloseTimeout(Duration.ofSeconds(5))
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- TestActivity activity =
- Workflow.newActivityStub(
- TestActivity.class,
- ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofHours(1)).build());
-
- return activity.activity1("foo");
- }
- }
-
- @Test
- public void testActivityContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class);
- worker.registerActivitiesImplementations(new ContextActivityImpl());
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setTaskQueue(TASK_QUEUE)
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options);
- String result = workflow.workflow1("input1");
- assertEquals("activitytesting123", result);
- }
-
- public static class DefaultContextPropagationActivityWorkflowImpl implements TestWorkflow {
- @Override
- public String workflow1(String input) {
- ActivityOptions options =
- ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build();
- TestActivity activity = Workflow.newActivityStub(TestActivity.class, options);
- return activity.activity1("foo");
- }
- }
-
- @Test
- public void testDefaultActivityContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class);
- worker.registerActivitiesImplementations(new ContextActivityImpl());
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setTaskQueue(TASK_QUEUE)
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options);
- String result = workflow.workflow1("input1");
- assertEquals("activitytesting123", result);
- }
-
- public static class DefaultContextPropagationParentWorkflowImpl implements ParentWorkflow {
-
- @Override
- public String workflow(String input) {
- // Get the MDC value
- String mdcValue = MDC.get("test");
-
- // Fire up a child workflow
- ChildWorkflowOptions options = ChildWorkflowOptions.newBuilder().build();
- ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options);
-
- String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId());
- return result;
- }
-
- @Override
- public void signal(String value) {}
- }
-
- @Test
- public void testDefaultChildWorkflowContextPropagation() {
- Worker worker = testEnvironment.newWorker(TASK_QUEUE);
- worker.registerWorkflowImplementationTypes(
- DefaultContextPropagationParentWorkflowImpl.class,
- ContextPropagationChildWorkflowImpl.class);
- testEnvironment.start();
- MDC.put("test", "testing123");
- WorkflowClient client = testEnvironment.getWorkflowClient();
- WorkflowOptions options =
- WorkflowOptions.newBuilder()
- .setTaskQueue(TASK_QUEUE)
- .setContextPropagators(Collections.singletonList(new TestContextPropagator()))
- .build();
- ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options);
- String result = workflow.workflow("input1");
- assertEquals("testing123testing123", result);
- }
}