diff --git a/temporal-sdk/build.gradle b/temporal-sdk/build.gradle index aef5da968..df71ca34f 100644 --- a/temporal-sdk/build.gradle +++ b/temporal-sdk/build.gradle @@ -42,11 +42,14 @@ dependencies { api project(':temporal-serviceclient') api group: 'com.google.code.gson', name: 'gson', version: '2.8.6' api group: 'io.micrometer', name: 'micrometer-core', version: '1.6.0' + api group: 'io.opentracing', name: 'opentracing-api', version: '0.33.0' implementation group: 'com.google.guava', name: 'guava', version: '30.0-jre' implementation group: 'com.cronutils', name: 'cron-utils', version: '9.1.1' implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.11.3' implementation group: 'com.fasterxml.jackson.datatype', name: 'jackson-datatype-jsr310', version: '2.11.3' + implementation group: 'io.opentracing', name: 'opentracing-util', version: '0.33.0' + if (!JavaVersion.current().isJava8()) { implementation 'javax.annotation:javax.annotation-api:1.3.2' } @@ -54,6 +57,7 @@ dependencies { testImplementation group: 'ch.qos.logback', name: 'logback-classic', version: '1.2.3' testImplementation group: 'com.googlecode.junit-toolbox', name: 'junit-toolbox', version: '2.4' testImplementation group: 'junit', name: 'junit', version: '4.13.1' + testImplementation group: 'io.opentracing', name: 'opentracing-mock', version: '0.33.0' } configurations.all { diff --git a/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java b/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java index ee618f192..35918849f 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java +++ b/temporal-sdk/src/main/java/io/temporal/common/context/ContextPropagator.java @@ -23,8 +23,12 @@ import java.util.Map; /** - * Context Propagators are used to propagate information from workflow to activity, workflow to - * child workflow, and workflow to child thread (using {@link io.temporal.workflow.Async}). + * Context Propagators are used to propagate information from workflow stub to workflow, workflow to + * activity, workflow to child workflow, and workflow to child thread (using {@link + * io.temporal.workflow.Async}). + * + *

It is important to note that all threads share one ContextPropagator instance, so your + * implementation must be thread-safe and store any state in ThreadLocal variables. * *

A sample ContextPropagator that copies all {@link org.slf4j.MDC} entries starting * with a given prefix along the code path looks like this: @@ -126,4 +130,27 @@ public interface ContextPropagator { /** Sets the current context */ void setCurrentContext(Object context); + + /** + * This is a lifecycle method, called after the context has been propagated to the + * workflow/activity thread but the workflow/activity has not yet started. + */ + default void setUp() { + // No-op + } + + /** This is a lifecycle method, called after the workflow/activity has completed. */ + default void finish() { + // No-op + } + + /** + * This is a lifecycle method, called when the workflow/activity finishes by throwing an unhandled + * exception. {@link #finish()} is called after this method. + * + * @param t The unhandled exception that caused the workflow/activity to terminate + */ + default void onError(Throwable t) { + // No-op + } } diff --git a/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java b/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java new file mode 100644 index 000000000..8d4b155c8 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/common/context/OpenTracingContextPropagator.java @@ -0,0 +1,207 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.common.context; + +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.SpanContext; +import io.opentracing.Tracer; +import io.opentracing.log.Fields; +import io.opentracing.propagation.Format; +import io.opentracing.propagation.TextMap; +import io.opentracing.tag.Tags; +import io.opentracing.util.GlobalTracer; +import io.temporal.api.common.v1.Payload; +import io.temporal.common.converter.DataConverter; +import io.temporal.internal.logging.LoggerTag; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +/** Support for OpenTracing spans */ +public class OpenTracingContextPropagator implements ContextPropagator { + + private static final Logger log = LoggerFactory.getLogger(OpenTracingContextPropagator.class); + + private static final ThreadLocal currentOpenTracingSpanContext = new ThreadLocal<>(); + private static final ThreadLocal currentOpenTracingSpan = new ThreadLocal<>(); + private static final ThreadLocal currentOpenTracingScope = new ThreadLocal<>(); + + public static SpanContext getCurrentOpenTracingSpanContext() { + return currentOpenTracingSpanContext.get(); + } + + public static void setCurrentOpenTracingSpanContext(SpanContext ctx) { + if (ctx != null) { + currentOpenTracingSpanContext.set(ctx); + } + } + + @Override + public String getName() { + return "OpenTracing"; + } + + @Override + public Map serializeContext(Object context) { + Map serializedContext = new HashMap<>(); + Map contextMap = (Map) context; + if (contextMap != null) { + for (Map.Entry entry : contextMap.entrySet()) { + serializedContext.put( + entry.getKey(), DataConverter.getDefaultInstance().toPayload(entry.getValue()).get()); + } + } + return serializedContext; + } + + @Override + public Object deserializeContext(Map context) { + Map contextMap = new HashMap<>(); + for (Map.Entry entry : context.entrySet()) { + contextMap.put( + entry.getKey(), + DataConverter.getDefaultInstance() + .fromPayload(entry.getValue(), String.class, String.class)); + } + return contextMap; + } + + @Override + public Object getCurrentContext() { + Tracer currentTracer = GlobalTracer.get(); + Span currentSpan = currentTracer.scopeManager().activeSpan(); + if (currentSpan != null) { + HashMapTextMap contextTextMap = new HashMapTextMap(); + currentTracer.inject(currentSpan.context(), Format.Builtin.TEXT_MAP, contextTextMap); + return contextTextMap.getBackingMap(); + } else { + return null; + } + } + + @Override + public void setCurrentContext(Object context) { + Tracer currentTracer = GlobalTracer.get(); + Map contextAsMap = (Map) context; + if (contextAsMap != null) { + HashMapTextMap contextTextMap = new HashMapTextMap(contextAsMap); + setCurrentOpenTracingSpanContext( + currentTracer.extract(Format.Builtin.TEXT_MAP, contextTextMap)); + } + } + + @Override + public void setUp() { + Tracer openTracingTracer = GlobalTracer.get(); + Tracer.SpanBuilder builder = openTracingTracer.buildSpan("cadence.workflow"); + + if (MDC.getCopyOfContextMap().containsKey(LoggerTag.WORKFLOW_TYPE)) { + builder.withTag("resource.name", MDC.get(LoggerTag.WORKFLOW_TYPE)); + } else { + builder.withTag("resource.name", MDC.get(LoggerTag.ACTIVITY_TYPE)); + } + + if (getCurrentOpenTracingSpanContext() != null) { + builder.asChildOf(getCurrentOpenTracingSpanContext()); + } + Span span = builder.start(); + openTracingTracer.activateSpan(span); + currentOpenTracingSpan.set(span); + Scope scope = openTracingTracer.activateSpan(span); + currentOpenTracingScope.set(scope); + } + + @Override + public void onError(Throwable t) { + Span span = currentOpenTracingSpan.get(); + if (span != null) { + Tags.ERROR.set(span, true); + Map errorData = new HashMap<>(); + errorData.put(Fields.EVENT, "error"); + if (t != null) { + errorData.put(Fields.ERROR_OBJECT, t); + errorData.put(Fields.MESSAGE, t.getMessage()); + } + span.log(errorData); + } + } + + @Override + public void finish() { + Scope currentScope = currentOpenTracingScope.get(); + Span currentSpan = currentOpenTracingSpan.get(); + if (currentScope != null) { + currentScope.close(); + } + if (currentSpan != null) { + currentSpan.finish(); + } + currentOpenTracingScope.remove(); + currentOpenTracingSpan.remove(); + currentOpenTracingSpanContext.remove(); + } + + /** Just check for other instances of the same class */ + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (this == obj) { + return true; + } + return this.getClass().equals(obj.getClass()); + } + + @Override + public int hashCode() { + return this.getClass().hashCode(); + } + + private class HashMapTextMap implements TextMap { + private final HashMap backingMap = new HashMap<>(); + + public HashMapTextMap() { + // Noop + } + + public HashMapTextMap(Map spanData) { + backingMap.putAll(spanData); + } + + @Override + public Iterator> iterator() { + return backingMap.entrySet().iterator(); + } + + @Override + public void put(String key, String value) { + backingMap.put(key, value); + } + + public HashMap getBackingMap() { + return backingMap; + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java new file mode 100644 index 000000000..ea948b2c1 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/AbstractContextThreadLocal.java @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.internal.context; + +import io.temporal.common.context.ContextPropagator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** This class holds the current set of context propagators */ +public abstract class AbstractContextThreadLocal { + + private static final Logger log = LoggerFactory.getLogger(AbstractContextThreadLocal.class); + + /** + * Returns the context propagators for the current thread + * + * @return + */ + protected abstract List getPropagatorsForThread(); + + /** Sets the context propagators for this thread */ + public abstract void setContextPropagators(List contextPropagators); + + public List getContextPropagators() { + return getPropagatorsForThread(); + } + + public Map getCurrentContextForPropagation() { + Map contextData = new HashMap<>(); + for (ContextPropagator propagator : getPropagatorsForThread()) { + contextData.put(propagator.getName(), propagator.getCurrentContext()); + } + return contextData; + } + + /** + * Injects the context data into the thread for each configured context propagator + * + * @param contextData The context data received from the server + */ + public void propagateContextToCurrentThread(Map contextData) { + if (contextData == null || contextData.isEmpty()) { + return; + } + for (ContextPropagator propagator : getPropagatorsForThread()) { + if (contextData.containsKey(propagator.getName())) { + propagator.setCurrentContext(contextData.get(propagator.getName())); + } + } + } + + /** Calls {@link ContextPropagator#setUp()} for each propagator */ + public void setUpContextPropagators() { + for (ContextPropagator propagator : getPropagatorsForThread()) { + try { + propagator.setUp(); + } catch (Throwable t) { + // Don't let an error in one propagator block the others + log.error("Error calling setUp() on a contextpropagator", t); + } + } + } + + /** + * Calls {@link ContextPropagator#onError(Throwable)} for each propagator + * + * @param t The Throwable that caused the workflow/activity to finish + */ + public void onErrorContextPropagators(Throwable t) { + for (ContextPropagator propagator : getPropagatorsForThread()) { + try { + propagator.onError(t); + } catch (Throwable t1) { + // Don't let an error in one propagator block the others + log.error("Error calling onError() on a contextpropagator", t1); + } + } + } + + /** Calls {@link ContextPropagator#finish()} for each propagator */ + public void finishContextPropagators() { + for (ContextPropagator propagator : getPropagatorsForThread()) { + try { + propagator.finish(); + } catch (Throwable t) { + // Don't let an error in one propagator block the others + log.error("Error calling finish() on a contextpropagator", t); + } + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java new file mode 100644 index 000000000..03b2e6359 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextActivityThreadLocal.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.internal.context; + +import io.temporal.common.context.ContextPropagator; +import java.util.ArrayList; +import java.util.List; + +public class ContextActivityThreadLocal extends AbstractContextThreadLocal { + + private static final ThreadLocal> contextPropagators = + ThreadLocal.withInitial(() -> new ArrayList<>()); + + public static ContextActivityThreadLocal getInstance() { + return new ContextActivityThreadLocal(); + } + + @Override + public List getPropagatorsForThread() { + return contextPropagators.get(); + } + + @Override + public void setContextPropagators(List contextPropagators) { + if (contextPropagators == null || contextPropagators.size() == 0) { + return; + } + + this.contextPropagators.set(contextPropagators); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java new file mode 100644 index 000000000..777ed8007 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextPropagatorUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.internal.context; + +import io.temporal.api.common.v1.Header; +import io.temporal.api.common.v1.Payload; +import io.temporal.common.context.ContextPropagator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** Common methods for dealing with context */ +public class ContextPropagatorUtils { + + public static Map extractContextsFromHeaders( + List contextPropagators, Header headers) { + + if (contextPropagators == null || contextPropagators.isEmpty()) { + return new HashMap<>(); + } + + if (headers == null) { + return new HashMap<>(); + } + + Map headerData = new HashMap<>(); + for (Map.Entry pair : headers.getFieldsMap().entrySet()) { + headerData.put(pair.getKey(), pair.getValue()); + } + + Map contextData = new HashMap<>(); + for (ContextPropagator propagator : contextPropagators) { + + // Only send the context propagator the fields that belong to them + // Change the map from MyPropagator:foo -> bar to foo -> bar + Map filteredData = + headerData.entrySet().stream() + .filter(e -> e.getKey().startsWith(propagator.getName())) + .collect( + Collectors.toMap( + e -> e.getKey().substring(propagator.getName().length() + 1), + Map.Entry::getValue)); + contextData.put(propagator.getName(), propagator.deserializeContext(filteredData)); + } + + return contextData; + } + + public static Map extractContextsAndConvertToBytes( + List contextPropagators) { + if (contextPropagators == null) { + return null; + } + Map result = new HashMap<>(); + for (ContextPropagator propagator : contextPropagators) { + // Get the serialized context from the propagator + Map serializedContext = + propagator.serializeContext(propagator.getCurrentContext()); + // Namespace each entry in case of overlaps, so foo -> bar becomes MyPropagator:foo -> bar + Map namespacedSerializedContext = + serializedContext.entrySet().stream() + .collect( + Collectors.toMap( + e -> propagator.getName() + ":" + e.getKey(), Map.Entry::getValue)); + result.putAll(namespacedSerializedContext); + } + return result; + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java deleted file mode 100644 index 39044b773..000000000 --- a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. - * - * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Modifications copyright (C) 2017 Uber Technologies, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License is - * located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package io.temporal.internal.context; - -import io.temporal.common.context.ContextPropagator; -import io.temporal.workflow.WorkflowThreadLocal; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; - -/** This class holds the current set of context propagators */ -public class ContextThreadLocal { - - private static final WorkflowThreadLocal> contextPropagators = - WorkflowThreadLocal.withInitial( - new Supplier>() { - @Override - public List get() { - return new ArrayList<>(); - } - }); - - /** Sets the list of context propagators for the thread */ - public static void setContextPropagators(List propagators) { - if (propagators == null || propagators.isEmpty()) { - return; - } - contextPropagators.set(propagators); - } - - public static List getContextPropagators() { - return contextPropagators.get(); - } - - public static Map getCurrentContextForPropagation() { - Map contextData = new HashMap<>(); - for (ContextPropagator propagator : contextPropagators.get()) { - contextData.put(propagator.getName(), propagator.getCurrentContext()); - } - return contextData; - } - - public static void propagateContextToCurrentThread(Map contextData) { - if (contextData == null || contextData.isEmpty()) { - return; - } - for (ContextPropagator propagator : contextPropagators.get()) { - if (contextData.containsKey(propagator.getName())) { - propagator.setCurrentContext(contextData.get(propagator.getName())); - } - } - } -} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java new file mode 100644 index 000000000..562670614 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextWorkflowThreadLocal.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.internal.context; + +import io.temporal.common.context.ContextPropagator; +import io.temporal.workflow.WorkflowThreadLocal; +import java.util.ArrayList; +import java.util.List; + +public class ContextWorkflowThreadLocal extends AbstractContextThreadLocal { + + private static final WorkflowThreadLocal> contextPropagators = + WorkflowThreadLocal.withInitial(() -> new ArrayList<>()); + + public static ContextWorkflowThreadLocal getInstance() { + return new ContextWorkflowThreadLocal(); + } + + @Override + public List getPropagatorsForThread() { + return contextPropagators.get(); + } + + @Override + public void setContextPropagators(List contextPropagators) { + if (contextPropagators == null || contextPropagators.size() == 0) { + return; + } + + this.contextPropagators.set(contextPropagators); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java b/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java index 335b1a41b..87de9a662 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/replay/WorkflowContext.java @@ -21,7 +21,6 @@ import com.google.protobuf.util.Timestamps; import io.temporal.api.command.v1.ContinueAsNewWorkflowExecutionCommandAttributes; -import io.temporal.api.common.v1.Header; import io.temporal.api.common.v1.Payload; import io.temporal.api.common.v1.SearchAttributes; import io.temporal.api.common.v1.WorkflowExecution; @@ -29,8 +28,8 @@ import io.temporal.api.history.v1.WorkflowExecutionStartedEventAttributes; import io.temporal.common.context.ContextPropagator; import io.temporal.internal.common.ProtobufTimeUtils; +import io.temporal.internal.context.ContextPropagatorUtils; import java.time.Duration; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -161,26 +160,8 @@ public List getContextPropagators() { /** Returns a map of propagated context objects, keyed by propagator name */ Map getPropagatedContexts() { - if (contextPropagators == null || contextPropagators.isEmpty()) { - return new HashMap<>(); - } - - Header headers = startedAttributes.getHeader(); - if (headers == null) { - return new HashMap<>(); - } - - Map headerData = new HashMap<>(); - for (Map.Entry pair : headers.getFieldsMap().entrySet()) { - headerData.put(pair.getKey(), pair.getValue()); - } - - Map contextData = new HashMap<>(); - for (ContextPropagator propagator : contextPropagators) { - contextData.put(propagator.getName(), propagator.deserializeContext(headerData)); - } - - return contextData; + return ContextPropagatorUtils.extractContextsFromHeaders( + contextPropagators, startedAttributes.getHeader()); } void mergeSearchAttributes(SearchAttributes searchAttributes) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java index d485416e9..1e388963d 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java @@ -35,7 +35,7 @@ import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptorBase; import io.temporal.failure.CanceledFailure; import io.temporal.internal.common.CheckedExceptionWrapper; -import io.temporal.internal.context.ContextThreadLocal; +import io.temporal.internal.context.ContextWorkflowThreadLocal; import io.temporal.internal.replay.ExecuteActivityParameters; import io.temporal.internal.replay.ExecuteLocalActivityParameters; import io.temporal.internal.replay.ReplayWorkflowContext; @@ -96,6 +96,8 @@ private NamedRunnable(String name, Runnable runnable) { private static final Logger log = LoggerFactory.getLogger(DeterministicRunnerImpl.class); static final String WORKFLOW_ROOT_THREAD_NAME = "workflow-method"; private static final ThreadLocal currentThreadThreadLocal = new ThreadLocal<>(); + private final ContextWorkflowThreadLocal currentWorkflowThreadLocal = + ContextWorkflowThreadLocal.getInstance(); private final Lock lock = new ReentrantLock(); private final ExecutorService threadPool; @@ -528,7 +530,7 @@ void setRunnerLocal(RunnerLocalInternal key, T value) { */ private Map getPropagatedContexts() { if (currentThreadThreadLocal.get() != null) { - return ContextThreadLocal.getCurrentContextForPropagation(); + return currentWorkflowThreadLocal.getCurrentContextForPropagation(); } else { return workflowContext.getContext().getPropagatedContexts(); } @@ -536,7 +538,7 @@ private Map getPropagatedContexts() { private List getContextPropagators() { if (currentThreadThreadLocal.get() != null) { - return ContextThreadLocal.getContextPropagators(); + return currentWorkflowThreadLocal.getContextPropagators(); } else { return workflowContext.getContext().getContextPropagators(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java index b915ce1e6..64c4c376a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java @@ -32,7 +32,6 @@ import io.temporal.api.common.v1.ActivityType; import io.temporal.api.common.v1.Header; import io.temporal.api.common.v1.Memo; -import io.temporal.api.common.v1.Payload; import io.temporal.api.common.v1.Payloads; import io.temporal.api.common.v1.RetryPolicy; import io.temporal.api.common.v1.SearchAttributes; @@ -55,6 +54,7 @@ import io.temporal.internal.common.InternalUtils; import io.temporal.internal.common.OptionsUtils; import io.temporal.internal.common.ProtobufTimeUtils; +import io.temporal.internal.context.ContextPropagatorUtils; import io.temporal.internal.metrics.MetricsType; import io.temporal.internal.replay.ChildWorkflowTaskFailedException; import io.temporal.internal.replay.ExecuteActivityParameters; @@ -298,7 +298,8 @@ private ExecuteActivityParameters constructExecuteActivityParameters( if (propagators == null) { propagators = this.contextPropagators; } - Header header = toHeaderGrpc(extractContextsAndConvertToBytes(propagators)); + Header header = + toHeaderGrpc(ContextPropagatorUtils.extractContextsAndConvertToBytes(propagators)); if (header != null) { attributes.setHeader(header); } @@ -416,7 +417,8 @@ private Promise> executeChildWorkflow( attributes.setRetryPolicy(toRetryPolicy(retryOptions)); } attributes.setCronSchedule(OptionsUtils.safeGet(options.getCronSchedule())); - Header header = toHeaderGrpc(extractContextsAndConvertToBytes(propagators)); + Header header = + toHeaderGrpc(ContextPropagatorUtils.extractContextsAndConvertToBytes(propagators)); if (header != null) { attributes.setHeader(header); } @@ -459,18 +461,6 @@ private Promise> executeChildWorkflow( return result; } - private Map extractContextsAndConvertToBytes( - List contextPropagators) { - if (contextPropagators == null) { - return null; - } - Map result = new HashMap<>(); - for (ContextPropagator propagator : contextPropagators) { - result.putAll(propagator.serializeContext(propagator.getCurrentContext())); - } - return result; - } - private RuntimeException mapChildWorkflowException(Exception failure) { if (failure == null) { return null; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java index 69128aeea..5e5632f81 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowStubImpl.java @@ -54,7 +54,6 @@ import io.temporal.client.WorkflowServiceException; import io.temporal.client.WorkflowStub; import io.temporal.common.RetryOptions; -import io.temporal.common.context.ContextPropagator; import io.temporal.failure.CanceledFailure; import io.temporal.failure.FailureConverter; import io.temporal.internal.common.CheckedExceptionWrapper; @@ -63,10 +62,9 @@ import io.temporal.internal.common.StatusUtils; import io.temporal.internal.common.WorkflowExecutionFailedException; import io.temporal.internal.common.WorkflowExecutionUtils; +import io.temporal.internal.context.ContextPropagatorUtils; import io.temporal.internal.external.GenericWorkflowClientExternal; import java.lang.reflect.Type; -import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -234,7 +232,8 @@ private StartWorkflowExecutionRequest newStartWorkflowExecutionRequest( .putAllIndexedFields(convertMemoFromObjectToBytes(o.getSearchAttributes()))); } if (o.getContextPropagators() != null && !o.getContextPropagators().isEmpty()) { - Map context = extractContextsAndConvertToBytes(o.getContextPropagators()); + Map context = + ContextPropagatorUtils.extractContextsAndConvertToBytes(o.getContextPropagators()); request.setHeader(Header.newBuilder().putAllFields(context)); } return request.build(); @@ -248,18 +247,6 @@ private Map convertSearchAttributesFromObjectToBytes(Map extractContextsAndConvertToBytes( - List contextPropagators) { - if (contextPropagators == null) { - return null; - } - Map result = new HashMap<>(); - for (ContextPropagator propagator : contextPropagators) { - result.putAll(propagator.serializeContext(propagator.getCurrentContext())); - } - return result; - } - @Override public WorkflowExecution start(Object... args) { if (!options.isPresent()) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java index 89a097023..a5f983dca 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java @@ -22,7 +22,7 @@ import com.google.common.util.concurrent.RateLimiter; import io.temporal.common.context.ContextPropagator; import io.temporal.failure.CanceledFailure; -import io.temporal.internal.context.ContextThreadLocal; +import io.temporal.internal.context.ContextWorkflowThreadLocal; import io.temporal.internal.logging.LoggerTag; import io.temporal.internal.metrics.MetricsType; import io.temporal.internal.replay.ReplayWorkflowContext; @@ -98,8 +98,12 @@ public void run() { MDC.put(LoggerTag.NAMESPACE, replayWorkflowContext.getNamespace()); // Repopulate the context(s) - ContextThreadLocal.setContextPropagators(this.contextPropagators); - ContextThreadLocal.propagateContextToCurrentThread(this.propagatedContexts); + ContextWorkflowThreadLocal contextWorkflowThreadLocal = + ContextWorkflowThreadLocal.getInstance(); + contextWorkflowThreadLocal.setContextPropagators(this.contextPropagators); + contextWorkflowThreadLocal.propagateContextToCurrentThread(this.propagatedContexts); + contextWorkflowThreadLocal.setUpContextPropagators(); + try { // initialYield blocks thread until the first runUntilBlocked is called. // Otherwise r starts executing without control of the sync. @@ -107,20 +111,25 @@ public void run() { cancellationScope.run(); } catch (DestroyWorkflowThreadError e) { if (!threadContext.isDestroyRequested()) { + contextWorkflowThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } } catch (Error e) { + contextWorkflowThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } catch (CanceledFailure e) { if (!isCancelRequested()) { + contextWorkflowThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } if (log.isDebugEnabled()) { log.debug(String.format("Workflow thread \"%s\" run canceled", name)); } } catch (Throwable e) { + contextWorkflowThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } finally { + contextWorkflowThreadLocal.finishContextPropagators(); DeterministicRunnerImpl.setCurrentThreadInternal(null); threadContext.setStatus(Status.DONE); thread.setName(originalName); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java index c821e4ab6..8d726cb7a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java @@ -25,7 +25,6 @@ import com.uber.m3.tally.Stopwatch; import com.uber.m3.util.Duration; import com.uber.m3.util.ImmutableMap; -import io.temporal.api.common.v1.Payload; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.failure.v1.CanceledFailureInfo; import io.temporal.api.failure.v1.Failure; @@ -33,18 +32,18 @@ import io.temporal.api.workflowservice.v1.RespondActivityTaskCanceledRequest; import io.temporal.api.workflowservice.v1.RespondActivityTaskCompletedRequest; import io.temporal.api.workflowservice.v1.RespondActivityTaskFailedRequest; -import io.temporal.common.context.ContextPropagator; +import io.temporal.failure.FailureConverter; import io.temporal.internal.common.GrpcRetryer; import io.temporal.internal.common.ProtobufTimeUtils; import io.temporal.internal.common.RpcRetryOptions; +import io.temporal.internal.context.ContextActivityThreadLocal; +import io.temporal.internal.context.ContextPropagatorUtils; import io.temporal.internal.logging.LoggerTag; import io.temporal.internal.metrics.MetricsType; import io.temporal.internal.replay.FailureWrapperException; import io.temporal.internal.worker.ActivityTaskHandler.Result; import io.temporal.serviceclient.MetricsTag; import io.temporal.serviceclient.WorkflowServiceStubs; -import java.util.HashMap; -import java.util.Map; import java.util.Objects; import java.util.concurrent.TimeUnit; import org.slf4j.MDC; @@ -185,7 +184,13 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception { MDC.put(LoggerTag.WORKFLOW_ID, task.getWorkflowExecution().getWorkflowId()); MDC.put(LoggerTag.RUN_ID, task.getWorkflowExecution().getRunId()); - propagateContext(task); + ContextActivityThreadLocal contextActivityThreadLocal = + ContextActivityThreadLocal.getInstance(); + contextActivityThreadLocal.setContextPropagators(options.getContextPropagators()); + contextActivityThreadLocal.propagateContextToCurrentThread( + ContextPropagatorUtils.extractContextsFromHeaders( + options.getContextPropagators(), task.getHeader())); + contextActivityThreadLocal.setUpContextPropagators(); try { Stopwatch sw = metricsScope.timer(MetricsType.ACTIVITY_EXEC_LATENCY).start(); @@ -215,7 +220,10 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception { new Result(task.getActivityId(), null, null, canceledRequest.build(), null), metricsScope); } + contextActivityThreadLocal.onErrorContextPropagators( + FailureConverter.failureToException(failure, options.getDataConverter())); } finally { + contextActivityThreadLocal.finishContextPropagators(); MDC.remove(LoggerTag.ACTIVITY_ID); MDC.remove(LoggerTag.ACTIVITY_TYPE); MDC.remove(LoggerTag.WORKFLOW_ID); @@ -223,23 +231,6 @@ public void handle(PollActivityTaskQueueResponse task) throws Exception { } } - void propagateContext(PollActivityTaskQueueResponse response) { - if (options.getContextPropagators() == null || options.getContextPropagators().isEmpty()) { - return; - } - - if (!response.hasHeader()) { - return; - } - Map headerData = new HashMap<>(); - for (Map.Entry entry : response.getHeader().getFieldsMap().entrySet()) { - headerData.put(entry.getKey(), entry.getValue()); - } - for (ContextPropagator propagator : options.getContextPropagators()) { - propagator.setCurrentContext(propagator.deserializeContext(headerData)); - } - } - @Override public Throwable wrapFailure(PollActivityTaskQueueResponse task, Throwable failure) { WorkflowExecution execution = task.getWorkflowExecution(); diff --git a/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java b/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java new file mode 100644 index 000000000..03af0ac48 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/common/context/ContextTest.java @@ -0,0 +1,503 @@ +/* + * Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package io.temporal.common.context; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.Tracer; +import io.opentracing.mock.MockSpan; +import io.opentracing.mock.MockTracer; +import io.opentracing.util.GlobalTracer; +import io.opentracing.util.ThreadLocalScopeManager; +import io.temporal.activity.ActivityOptions; +import io.temporal.api.common.v1.Payload; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.client.WorkflowOptions; +import io.temporal.common.converter.DataConverter; +import io.temporal.failure.ApplicationFailure; +import io.temporal.internal.testing.WorkflowTestingTest; +import io.temporal.testing.TestEnvironmentOptions; +import io.temporal.testing.TestWorkflowEnvironment; +import io.temporal.worker.Worker; +import io.temporal.workflow.Async; +import io.temporal.workflow.ChildWorkflowOptions; +import io.temporal.workflow.Promise; +import io.temporal.workflow.Workflow; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.MDC; + +public class ContextTest { + + private static final String TASK_QUEUE = "test-workflow"; + + private TestWorkflowEnvironment testEnvironment; + private MockTracer mockTracer = + new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP); + + @Before + public void setUp() { + TestEnvironmentOptions options = + TestEnvironmentOptions.newBuilder() + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setContextPropagators( + Arrays.asList( + new TestContextPropagator(), new OpenTracingContextPropagator())) + .build()) + .build(); + testEnvironment = TestWorkflowEnvironment.newInstance(options); + GlobalTracer.registerIfAbsent(mockTracer); + } + + @After + public void tearDown() { + testEnvironment.close(); + mockTracer.reset(); + } + + public static class TestContextPropagator implements ContextPropagator { + + @Override + public String getName() { + return this.getClass().getName(); + } + + @Override + public Map serializeContext(Object context) { + String testKey = (String) context; + if (testKey != null) { + return Collections.singletonMap( + "test", DataConverter.getDefaultInstance().toPayload(testKey).get()); + } else { + return Collections.emptyMap(); + } + } + + @Override + public Object deserializeContext(Map context) { + if (context.containsKey("test")) { + return DataConverter.getDefaultInstance() + .fromPayload(context.get("test"), String.class, String.class); + + } else { + return null; + } + } + + @Override + public Object getCurrentContext() { + return MDC.get("test"); + } + + @Override + public void setCurrentContext(Object context) { + MDC.put("test", String.valueOf(context)); + } + } + + public static class ContextPropagationWorkflowImpl implements WorkflowTestingTest.TestWorkflow { + + @Override + public String workflow1(String input) { + // The test value should be in the MDC + return MDC.get("test"); + } + } + + @Test + public void testWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("testing123", result); + } + + public static class ContextPropagationParentWorkflowImpl + implements WorkflowTestingTest.ParentWorkflow { + + @Override + public String workflow(String input) { + // Get the MDC value + String mdcValue = MDC.get("test"); + + // Fire up a child workflow + ChildWorkflowOptions options = + ChildWorkflowOptions.newBuilder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.ChildWorkflow child = + Workflow.newChildWorkflowStub(WorkflowTestingTest.ChildWorkflow.class, options); + + String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId()); + return result; + } + + @Override + public void signal(String value) {} + } + + public static class ContextPropagationChildWorkflowImpl + implements WorkflowTestingTest.ChildWorkflow { + + @Override + public String workflow(String input, String parentId) { + String mdcValue = MDC.get("test"); + return input + mdcValue; + } + } + + @Test + public void testChildWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes( + ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.ParentWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.ParentWorkflow.class, options); + String result = workflow.workflow("input1"); + assertEquals("testing123testing123", result); + } + + public static class ContextPropagationThreadWorkflowImpl + implements WorkflowTestingTest.TestWorkflow { + + @Override + public String workflow1(String input) { + Promise asyncPromise = Async.function(this::async); + return asyncPromise.get(); + } + + private String async() { + return "async" + MDC.get("test"); + } + } + + @Test + public void testThreadContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .setTaskQueue(TASK_QUEUE) + .build(); + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("asynctesting123", result); + } + + public static class ContextActivityImpl implements WorkflowTestingTest.TestActivity { + @Override + public String activity1(String input) { + return "activity" + MDC.get("test"); + } + } + + public static class ContextPropagationActivityWorkflowImpl + implements WorkflowTestingTest.TestWorkflow { + @Override + public String workflow1(String input) { + ActivityOptions options = + ActivityOptions.newBuilder() + .setScheduleToCloseTimeout(Duration.ofSeconds(5)) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.TestActivity activity = + Workflow.newActivityStub( + WorkflowTestingTest.TestActivity.class, + ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofHours(1)).build()); + + return activity.activity1("foo"); + } + } + + @Test + public void testActivityContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class); + worker.registerActivitiesImplementations(new ContextActivityImpl()); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("activitytesting123", result); + } + + public static class DefaultContextPropagationActivityWorkflowImpl + implements WorkflowTestingTest.TestWorkflow { + @Override + public String workflow1(String input) { + ActivityOptions options = + ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build(); + WorkflowTestingTest.TestActivity activity = + Workflow.newActivityStub(WorkflowTestingTest.TestActivity.class, options); + return activity.activity1("foo"); + } + } + + @Test + public void testDefaultActivityContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class); + worker.registerActivitiesImplementations(new ContextActivityImpl()); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("activitytesting123", result); + } + + public static class DefaultContextPropagationParentWorkflowImpl + implements WorkflowTestingTest.ParentWorkflow { + + @Override + public String workflow(String input) { + // Get the MDC value + String mdcValue = MDC.get("test"); + + // Fire up a child workflow + ChildWorkflowOptions options = ChildWorkflowOptions.newBuilder().build(); + WorkflowTestingTest.ChildWorkflow child = + Workflow.newChildWorkflowStub(WorkflowTestingTest.ChildWorkflow.class, options); + + String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId()); + return result; + } + + @Override + public void signal(String value) {} + } + + @Test + public void testDefaultChildWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes( + DefaultContextPropagationParentWorkflowImpl.class, + ContextPropagationChildWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + WorkflowTestingTest.ParentWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.ParentWorkflow.class, options); + String result = workflow.workflow("input1"); + assertEquals("testing123testing123", result); + } + + public static class OpenTracingContextPropagationWorkflowImpl + implements WorkflowTestingTest.TestWorkflow { + @Override + public String workflow1(String input) { + Tracer tracer = GlobalTracer.get(); + Span activeSpan = tracer.scopeManager().activeSpan(); + MockSpan mockSpan = (MockSpan) activeSpan; + assertNotNull(activeSpan); + assertEquals("TestWorkflow", mockSpan.tags().get("resource.name")); + assertNotEquals(0, mockSpan.parentId()); + if ("fail".equals(input)) { + throw ApplicationFailure.newFailure("fail", "fail"); + } else { + return activeSpan.getBaggageItem("foo"); + } + } + } + + public static class OpenTracingContextPropagationWithActivityWorkflowImpl + implements WorkflowTestingTest.TestWorkflow { + @Override + public String workflow1(String input) { + ActivityOptions options = + ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build(); + WorkflowTestingTest.TestActivity activity = + Workflow.newActivityStub(WorkflowTestingTest.TestActivity.class, options); + return activity.activity1(input); + } + } + + public static class OpenTracingContextPropagationActivityImpl + implements WorkflowTestingTest.TestActivity { + + @Override + public String activity1(String input) { + Tracer tracer = GlobalTracer.get(); + Span activeSpan = tracer.scopeManager().activeSpan(); + MockSpan mockSpan = (MockSpan) activeSpan; + assertNotNull(activeSpan); + assertEquals("Activity1", mockSpan.tags().get("resource.name")); + assertNotEquals(0, mockSpan.parentId()); + if ("fail".equals(input)) { + throw ApplicationFailure.newFailure("fail", "fail"); + } else { + return activeSpan.getBaggageItem("foo"); + } + } + } + + @Test + public void testOpenTracingContextPropagation() { + Tracer tracer = GlobalTracer.get(); + Span span = tracer.buildSpan("testContextPropagationSuccess").start(); + + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(OpenTracingContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators( + Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator())) + .build(); + + try (Scope scope = tracer.scopeManager().activate(span)) { + + Span activeSpan = tracer.scopeManager().activeSpan(); + activeSpan.setBaggageItem("foo", "bar"); + + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + assertEquals("bar", workflow.workflow1("input1")); + + } finally { + span.finish(); + } + } + + @Test + public void testOpenTracingContextPropagationWithFailure() { + Tracer tracer = GlobalTracer.get(); + Span span = tracer.buildSpan("testContextPropagationFailure").start(); + + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes(OpenTracingContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators( + Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator())) + .build(); + + try (Scope scope = tracer.scopeManager().activate(span)) { + + Span activeSpan = tracer.scopeManager().activeSpan(); + activeSpan.setBaggageItem("foo", "bar"); + + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + try { + workflow.workflow1("fail"); + fail("Unreachable"); + } catch (ApplicationFailure e) { + // Expected + assertEquals("fail", e.getMessage()); + } catch (Exception e) { + e.printStackTrace(); + } + + } finally { + span.finish(); + } + } + + @Test + public void testOpenTracingContextPropagationToActivity() { + Tracer tracer = GlobalTracer.get(); + Span span = tracer.buildSpan("testContextPropagationWithActivity").start(); + + Worker worker = testEnvironment.newWorker(TASK_QUEUE); + worker.registerWorkflowImplementationTypes( + OpenTracingContextPropagationWithActivityWorkflowImpl.class); + worker.registerActivitiesImplementations(new OpenTracingContextPropagationActivityImpl()); + testEnvironment.start(); + WorkflowClient client = testEnvironment.getWorkflowClient(); + WorkflowOptions options = + WorkflowOptions.newBuilder() + .setTaskQueue(TASK_QUEUE) + .setContextPropagators( + Arrays.asList(new TestContextPropagator(), new OpenTracingContextPropagator())) + .build(); + + try (Scope scope = tracer.scopeManager().activate(span)) { + + Span activeSpan = tracer.scopeManager().activeSpan(); + activeSpan.setBaggageItem("foo", "bar"); + + WorkflowTestingTest.TestWorkflow workflow = + client.newWorkflowStub(WorkflowTestingTest.TestWorkflow.class, options); + assertEquals("bar", workflow.workflow1("input1")); + + } finally { + span.finish(); + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java b/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java index 35c793aee..e583990e8 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/testing/WorkflowTestingTest.java @@ -27,7 +27,6 @@ import io.temporal.activity.Activity; import io.temporal.activity.ActivityInterface; import io.temporal.activity.ActivityOptions; -import io.temporal.api.common.v1.Payload; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.TimeoutType; import io.temporal.api.workflow.v1.WorkflowExecutionInfo; @@ -42,8 +41,6 @@ import io.temporal.client.WorkflowOptions; import io.temporal.client.WorkflowStub; import io.temporal.common.RetryOptions; -import io.temporal.common.context.ContextPropagator; -import io.temporal.common.converter.DataConverter; import io.temporal.failure.ActivityFailure; import io.temporal.failure.ApplicationFailure; import io.temporal.failure.CanceledFailure; @@ -53,16 +50,13 @@ import io.temporal.testing.TestWorkflowEnvironment; import io.temporal.worker.Worker; import io.temporal.workflow.Async; -import io.temporal.workflow.ChildWorkflowOptions; import io.temporal.workflow.Promise; import io.temporal.workflow.SignalMethod; import io.temporal.workflow.Workflow; import io.temporal.workflow.WorkflowInterface; import io.temporal.workflow.WorkflowMethod; import java.time.Duration; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -76,7 +70,6 @@ import org.junit.runner.Description; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.slf4j.MDC; public class WorkflowTestingTest { private static final Logger log = LoggerFactory.getLogger(WorkflowTestingTest.class); @@ -98,10 +91,7 @@ protected void failed(Throwable e, Description description) { public void setUp() { TestEnvironmentOptions options = TestEnvironmentOptions.newBuilder() - .setWorkflowClientOptions( - WorkflowClientOptions.newBuilder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build()) + .setWorkflowClientOptions(WorkflowClientOptions.newBuilder().build()) .build(); testEnvironment = TestWorkflowEnvironment.newInstance(options); } @@ -874,257 +864,4 @@ public void testMockedChildSimulatedTimeout() { assertTrue(e.getCause().getCause() instanceof TimeoutFailure); } } - - public static class TestContextPropagator implements ContextPropagator { - - @Override - public String getName() { - return this.getClass().getName(); - } - - @Override - public Map serializeContext(Object context) { - String testKey = (String) context; - if (testKey != null) { - return Collections.singletonMap( - "test", DataConverter.getDefaultInstance().toPayload(testKey).get()); - } else { - return Collections.emptyMap(); - } - } - - @Override - public Object deserializeContext(Map context) { - if (context.containsKey("test")) { - return DataConverter.getDefaultInstance() - .fromPayload(context.get("test"), String.class, String.class); - - } else { - return null; - } - } - - @Override - public Object getCurrentContext() { - return MDC.get("test"); - } - - @Override - public void setCurrentContext(Object context) { - MDC.put("test", String.valueOf(context)); - } - } - - public static class ContextPropagationWorkflowImpl implements TestWorkflow { - - @Override - public String workflow1(String input) { - // The test value should be in the MDC - return MDC.get("test"); - } - } - - @Test - public void testWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setTaskQueue(TASK_QUEUE) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("testing123", result); - } - - public static class ContextPropagationParentWorkflowImpl implements ParentWorkflow { - - @Override - public String workflow(String input) { - // Get the MDC value - String mdcValue = MDC.get("test"); - - // Fire up a child workflow - ChildWorkflowOptions options = - ChildWorkflowOptions.newBuilder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); - - String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId()); - return result; - } - - @Override - public void signal(String value) {} - } - - public static class ContextPropagationChildWorkflowImpl implements ChildWorkflow { - - @Override - public String workflow(String input, String parentId) { - String mdcValue = MDC.get("test"); - return input + mdcValue; - } - } - - @Test - public void testChildWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes( - ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setTaskQueue(TASK_QUEUE) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); - String result = workflow.workflow("input1"); - assertEquals("testing123testing123", result); - } - - public static class ContextPropagationThreadWorkflowImpl implements TestWorkflow { - - @Override - public String workflow1(String input) { - Promise asyncPromise = Async.function(this::async); - return asyncPromise.get(); - } - - private String async() { - return "async" + MDC.get("test"); - } - } - - @Test - public void testThreadContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .setTaskQueue(TASK_QUEUE) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("asynctesting123", result); - } - - public static class ContextActivityImpl implements TestActivity { - @Override - public String activity1(String input) { - return "activity" + MDC.get("test"); - } - } - - public static class ContextPropagationActivityWorkflowImpl implements TestWorkflow { - @Override - public String workflow1(String input) { - ActivityOptions options = - ActivityOptions.newBuilder() - .setScheduleToCloseTimeout(Duration.ofSeconds(5)) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestActivity activity = - Workflow.newActivityStub( - TestActivity.class, - ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofHours(1)).build()); - - return activity.activity1("foo"); - } - } - - @Test - public void testActivityContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class); - worker.registerActivitiesImplementations(new ContextActivityImpl()); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setTaskQueue(TASK_QUEUE) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("activitytesting123", result); - } - - public static class DefaultContextPropagationActivityWorkflowImpl implements TestWorkflow { - @Override - public String workflow1(String input) { - ActivityOptions options = - ActivityOptions.newBuilder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build(); - TestActivity activity = Workflow.newActivityStub(TestActivity.class, options); - return activity.activity1("foo"); - } - } - - @Test - public void testDefaultActivityContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class); - worker.registerActivitiesImplementations(new ContextActivityImpl()); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setTaskQueue(TASK_QUEUE) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("activitytesting123", result); - } - - public static class DefaultContextPropagationParentWorkflowImpl implements ParentWorkflow { - - @Override - public String workflow(String input) { - // Get the MDC value - String mdcValue = MDC.get("test"); - - // Fire up a child workflow - ChildWorkflowOptions options = ChildWorkflowOptions.newBuilder().build(); - ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); - - String result = child.workflow(mdcValue, Workflow.getInfo().getWorkflowId()); - return result; - } - - @Override - public void signal(String value) {} - } - - @Test - public void testDefaultChildWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_QUEUE); - worker.registerWorkflowImplementationTypes( - DefaultContextPropagationParentWorkflowImpl.class, - ContextPropagationChildWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.getWorkflowClient(); - WorkflowOptions options = - WorkflowOptions.newBuilder() - .setTaskQueue(TASK_QUEUE) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); - String result = workflow.workflow("input1"); - assertEquals("testing123testing123", result); - } }