From 27d998afc01ed17ed7dfee421f995a42727d3974 Mon Sep 17 00:00:00 2001
From: Quinn Klassen <klassenq@gmail.com>
Date: Thu, 24 Oct 2024 09:18:27 -0700
Subject: [PATCH] Add failure_reason to nexus_task_execution_failed (#2274)

Add failure_reason to nexus_task_execution_failed
---
 .../internal/nexus/NexusTaskHandlerImpl.java  |   4 +-
 .../temporal/internal/worker/NexusWorker.java |  31 +-
 .../nexus/OperationFailMetricTest.java        | 303 ++++++++++++++++++
 .../nexus/SyncClientOperationTest.java        |   7 +-
 .../workflow/nexus/SyncOperationFailTest.java |  24 ++
 5 files changed, 362 insertions(+), 7 deletions(-)
 create mode 100644 temporal-sdk/src/test/java/io/temporal/workflow/nexus/OperationFailMetricTest.java

diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java
index 0714d472e..5ef69329c 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java
@@ -223,7 +223,9 @@ private void convertKnownFailures(Throwable e) {
     if (failure instanceof Error) {
       throw (Error) failure;
     }
-    throw new RuntimeException(failure);
+    throw failure instanceof RuntimeException
+        ? (RuntimeException) failure
+        : new RuntimeException(failure);
   }
 
   private OperationStartResult<HandlerResultContent> startOperation(
diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java
index 13d0e7b0a..d4a904a4a 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java
@@ -21,6 +21,7 @@
 package io.temporal.internal.worker;
 
 import static io.temporal.serviceclient.MetricsTag.METRICS_TAGS_CALL_OPTIONS_KEY;
+import static io.temporal.serviceclient.MetricsTag.TASK_FAILURE_TYPE;
 
 import com.google.protobuf.ByteString;
 import com.uber.m3.tally.Scope;
@@ -40,6 +41,7 @@
 import io.temporal.worker.MetricsType;
 import io.temporal.worker.WorkerMetricsTag;
 import io.temporal.worker.tuning.*;
+import java.util.Collections;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
@@ -278,16 +280,35 @@ private void handleNexusTask(NexusTask task, Scope metricsScope) {
       Stopwatch sw = metricsScope.timer(MetricsType.NEXUS_EXEC_LATENCY).start();
       try {
         result = handler.handle(task, metricsScope);
-        if (result.getHandlerError() != null
-            || (result.getResponse().hasStartOperation()
-                && result.getResponse().getStartOperation().hasOperationError())) {
-          metricsScope.counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER).inc(1);
+        if (result.getHandlerError() != null) {
+          metricsScope
+              .tagged(
+                  Collections.singletonMap(
+                      TASK_FAILURE_TYPE,
+                      "handler_error_" + result.getHandlerError().getErrorType()))
+              .counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER)
+              .inc(1);
+        } else if (result.getResponse().hasStartOperation()
+            && result.getResponse().getStartOperation().hasOperationError()) {
+          String operationState =
+              result.getResponse().getStartOperation().getOperationError().getOperationState();
+          metricsScope
+              .tagged(Collections.singletonMap(TASK_FAILURE_TYPE, "operation_" + operationState))
+              .counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER)
+              .inc(1);
         }
       } catch (TimeoutException e) {
         log.warn("Nexus task timed out while processing", e);
-        metricsScope.counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER).inc(1);
+        metricsScope
+            .tagged(Collections.singletonMap(TASK_FAILURE_TYPE, "timeout"))
+            .counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER)
+            .inc(1);
         return;
       } catch (Throwable e) {
+        metricsScope
+            .tagged(Collections.singletonMap(TASK_FAILURE_TYPE, "internal_sdk_error"))
+            .counter(MetricsType.NEXUS_EXEC_FAILED_COUNTER)
+            .inc(1);
         // handler.handle if expected to never throw an exception and return result
         // that can be used for a workflow callback if this method throws, it's a bug.
         log.error("[BUG] Code that expected to never throw an exception threw an exception", e);
diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/OperationFailMetricTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/OperationFailMetricTest.java
new file mode 100644
index 000000000..3b25c45bd
--- /dev/null
+++ b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/OperationFailMetricTest.java
@@ -0,0 +1,303 @@
+/*
+ * Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this material except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.temporal.workflow.nexus;
+
+import static io.temporal.testing.internal.SDKTestWorkflowRule.NAMESPACE;
+
+import com.google.common.collect.ImmutableMap;
+import com.uber.m3.tally.RootScopeBuilder;
+import io.nexusrpc.OperationUnsuccessfulException;
+import io.nexusrpc.handler.OperationHandler;
+import io.nexusrpc.handler.OperationHandlerException;
+import io.nexusrpc.handler.OperationImpl;
+import io.nexusrpc.handler.ServiceImpl;
+import io.temporal.api.common.v1.WorkflowExecution;
+import io.temporal.client.WorkflowExecutionAlreadyStarted;
+import io.temporal.client.WorkflowFailedException;
+import io.temporal.common.reporter.TestStatsReporter;
+import io.temporal.failure.ApplicationFailure;
+import io.temporal.serviceclient.MetricsTag;
+import io.temporal.testUtils.Eventually;
+import io.temporal.testing.internal.SDKTestWorkflowRule;
+import io.temporal.worker.MetricsType;
+import io.temporal.worker.WorkerMetricsTag;
+import io.temporal.workflow.*;
+import io.temporal.workflow.shared.TestNexusServices;
+import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
+import java.time.Duration;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+
+public class OperationFailMetricTest {
+  private final TestStatsReporter reporter = new TestStatsReporter();
+
+  @Rule
+  public SDKTestWorkflowRule testWorkflowRule =
+      SDKTestWorkflowRule.newBuilder()
+          .setWorkflowTypes(TestNexus.class)
+          .setNexusServiceImplementation(new TestNexusServiceImpl())
+          .setMetricsScope(
+              new RootScopeBuilder()
+                  .reporter(reporter)
+                  .reportEvery(com.uber.m3.util.Duration.ofMillis(10)))
+          .build();
+
+  private ImmutableMap.Builder<String, String> getBaseTags() {
+    return ImmutableMap.<String, String>builder()
+        .putAll(MetricsTag.defaultTags(NAMESPACE))
+        .put(MetricsTag.WORKER_TYPE, WorkerMetricsTag.WorkerType.NEXUS_WORKER.getValue())
+        .put(MetricsTag.TASK_QUEUE, testWorkflowRule.getTaskQueue());
+  }
+
+  private ImmutableMap.Builder<String, String> getOperationTags() {
+    return getBaseTags()
+        .put(MetricsTag.NEXUS_SERVICE, "TestNexusService1")
+        .put(MetricsTag.NEXUS_OPERATION, "operation");
+  }
+
+  /** A handler that throws OperationUnsuccessfulException is counted as operation_failed. */
+  @Test
+  public void failOperationMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(WorkflowFailedException.class, () -> stub.execute("fail"));
+
+    // Tag sets are built once up front; the polling lambda below only reads them.
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    Map<String, String> failureTags =
+        getOperationTags().put(MetricsTag.TASK_FAILURE_TYPE, "operation_failed").buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, 1);
+        });
+  }
+
+  /** An OperationHandlerException of type BAD_REQUEST is tagged handler_error_BAD_REQUEST. */
+  @Test
+  public void failHandlerBadRequestMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(WorkflowFailedException.class, () -> stub.execute("handlererror"));
+
+    // Tag sets are built once up front; the polling lambda below only reads them.
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    ImmutableMap.Builder<String, String> failureBuilder = getOperationTags();
+    failureBuilder.put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_BAD_REQUEST");
+    Map<String, String> failureTags = failureBuilder.buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, 1);
+        });
+  }
+
+  /** WorkflowExecutionAlreadyStarted thrown by a handler is tagged handler_error_BAD_REQUEST. */
+  @Test
+  public void failHandlerAlreadyStartedMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(
+        WorkflowFailedException.class, () -> stub.execute("already-started"));
+
+    // Tag sets are built once up front; the polling lambda below only reads them.
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    ImmutableMap.Builder<String, String> failureBuilder = getOperationTags();
+    failureBuilder.put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_BAD_REQUEST");
+    Map<String, String> failureTags = failureBuilder.buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, 1);
+        });
+  }
+
+  /** A retryable ApplicationFailure is counted at least once as handler_error_INTERNAL. */
+  @Test
+  public void failHandlerRetryableApplicationFailureMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(
+        WorkflowFailedException.class, () -> stub.execute("retryable-application-failure"));
+
+    // Tag sets are built once up front; the polling lambda below only reads them.
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    ImmutableMap.Builder<String, String> failureBuilder = getOperationTags();
+    failureBuilder.put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_INTERNAL");
+    Map<String, String> failureTags = failureBuilder.buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(
+              MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, c -> c >= 1);
+        });
+  }
+
+  /** A non-retryable ApplicationFailure is counted once as handler_error_BAD_REQUEST. */
+  @Test
+  public void failHandlerNonRetryableApplicationFailureMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(
+        WorkflowFailedException.class,
+        () -> stub.execute("non-retryable-application-failure"));
+
+    // Tag sets are built once up front; the polling lambda below only reads them.
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    ImmutableMap.Builder<String, String> failureBuilder = getOperationTags();
+    failureBuilder.put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_BAD_REQUEST");
+    Map<String, String> failureTags = failureBuilder.buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, 1);
+        });
+  }
+
+  /** The "sleep" operation (11s handler sleep) is expected to surface as failure type timeout. */
+  @Test(timeout = 20000)
+  public void failHandlerSleepMetrics() throws InterruptedException {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(WorkflowFailedException.class, () -> stub.execute("sleep"));
+
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    Map<String, String> failureTags =
+        getOperationTags().put(MetricsTag.TASK_FAILURE_TYPE, "timeout").buildKeepingLast();
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, 1);
+        });
+  }
+
+  /** An Error thrown by the handler is counted at least once as handler_error_INTERNAL. */
+  @Test
+  public void failHandlerErrorMetrics() {
+    TestWorkflow1 stub = testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    Assert.assertThrows(WorkflowFailedException.class, () -> stub.execute("error"));
+
+    Map<String, String> pollTags = getBaseTags().buildKeepingLast();
+    Map<String, String> operationTags = getOperationTags().buildKeepingLast();
+    ImmutableMap.Builder<String, String> failureBuilder = getOperationTags();
+    failureBuilder.put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_INTERNAL");
+    Map<String, String> failureTags = failureBuilder.buildKeepingLast();
+
+    Eventually.assertEventually(
+        Duration.ofSeconds(3),
+        () -> {
+          reporter.assertTimer(MetricsType.NEXUS_SCHEDULE_TO_START_LATENCY, pollTags);
+          reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
+          reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
+          reporter.assertCounter(
+              MetricsType.NEXUS_EXEC_FAILED_COUNTER, failureTags, c -> c >= 1);
+        });
+  }
+
+  /** Workflow that forwards its input to the Nexus operation and returns the result. */
+  public static class TestNexus implements TestWorkflow1 {
+    @Override
+    public String execute(String operation) {
+      // Every operation call gets a 10 second schedule-to-close timeout.
+      NexusOperationOptions operationOptions =
+          NexusOperationOptions.newBuilder()
+              .setScheduleToCloseTimeout(Duration.ofSeconds(10))
+              .build();
+      NexusServiceOptions serviceOptions =
+          NexusServiceOptions.newBuilder().setOperationOptions(operationOptions).build();
+      return Workflow.newNexusServiceStub(TestNexusServices.TestNexusService1.class, serviceOptions)
+          .operation(operation);
+    }
+  }
+
+  /** Test service: the input selects the failure; a repeated request id always fails. */
+  @ServiceImpl(service = TestNexusServices.TestNexusService1.class)
+  public class TestNexusServiceImpl {
+    private final Map<String, Integer> invocationCount = new ConcurrentHashMap<>();
+
+    @OperationImpl
+    public OperationHandler<String, String> operation() {
+      return OperationHandler.sync(
+          (ctx, details, operation) -> {
+            // merge() bumps the counter atomically instead of a separate read and write.
+            int attempts = invocationCount.merge(details.getRequestId(), 1, Integer::sum);
+            if (attempts > 1) {
+              throw new OperationUnsuccessfulException("exceeded invocation count");
+            }
+            switch (operation) {
+              case "success":
+                return operation;
+              case "fail":
+                throw new OperationUnsuccessfulException("fail");
+              case "handlererror":
+                throw new OperationHandlerException(
+                    OperationHandlerException.ErrorType.BAD_REQUEST, "handlererror");
+              case "already-started":
+                throw new WorkflowExecutionAlreadyStarted(
+                    WorkflowExecution.getDefaultInstance(), "TestWorkflowType", null);
+              case "retryable-application-failure":
+                throw ApplicationFailure.newFailure("fail", "TestFailure");
+              case "non-retryable-application-failure":
+                throw ApplicationFailure.newNonRetryableFailure("fail", "TestFailure");
+              case "sleep":
+                try {
+                  Thread.sleep(11000);
+                } catch (InterruptedException e) {
+                  Thread.currentThread().interrupt(); // preserve the interrupt status
+                  throw new RuntimeException(e);
+                }
+                return operation;
+              case "error":
+                throw new Error("error");
+              default:
+                // Should never happen
+                Assert.fail();
+            }
+            return operation;
+          });
+    }
+  }
+}
diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncClientOperationTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncClientOperationTest.java
index 875f4ac83..adbb2c055 100644
--- a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncClientOperationTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncClientOperationTest.java
@@ -106,7 +106,12 @@ public void syncClientOperationFail() {
             .buildKeepingLast();
     reporter.assertTimer(MetricsType.NEXUS_EXEC_LATENCY, operationTags);
     reporter.assertTimer(MetricsType.NEXUS_TASK_E2E_LATENCY, operationTags);
-    reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, operationTags, 1);
+    Map<String, String> execFailedTags =
+        ImmutableMap.<String, String>builder()
+            .putAll(operationTags)
+            .put(MetricsTag.TASK_FAILURE_TYPE, "handler_error_BAD_REQUEST")
+            .buildKeepingLast();
+    reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, execFailedTags, 1);
   }
 
   @WorkflowInterface
diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncOperationFailTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncOperationFailTest.java
index b6ec13642..24ad56bc6 100644
--- a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncOperationFailTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/SyncOperationFailTest.java
@@ -20,27 +20,41 @@
 
 package io.temporal.workflow.nexus;
 
+import static io.temporal.testing.internal.SDKTestWorkflowRule.NAMESPACE;
+
+import com.google.common.collect.ImmutableMap;
+import com.uber.m3.tally.RootScopeBuilder;
 import io.nexusrpc.OperationUnsuccessfulException;
 import io.nexusrpc.handler.OperationHandler;
 import io.nexusrpc.handler.OperationImpl;
 import io.nexusrpc.handler.ServiceImpl;
 import io.temporal.client.WorkflowFailedException;
+import io.temporal.common.reporter.TestStatsReporter;
 import io.temporal.failure.ApplicationFailure;
 import io.temporal.failure.NexusOperationFailure;
+import io.temporal.serviceclient.MetricsTag;
 import io.temporal.testing.internal.SDKTestWorkflowRule;
+import io.temporal.worker.MetricsType;
+import io.temporal.worker.WorkerMetricsTag;
 import io.temporal.workflow.*;
 import io.temporal.workflow.shared.TestNexusServices;
 import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
 import java.time.Duration;
+import java.util.Map;
 import org.junit.*;
 
 public class SyncOperationFailTest {
+  private final TestStatsReporter reporter = new TestStatsReporter();
 
   @Rule
   public SDKTestWorkflowRule testWorkflowRule =
       SDKTestWorkflowRule.newBuilder()
           .setWorkflowTypes(TestNexus.class)
           .setNexusServiceImplementation(new TestNexusServiceImpl())
+          .setMetricsScope(
+              new RootScopeBuilder()
+                  .reporter(reporter)
+                  .reportEvery(com.uber.m3.util.Duration.ofMillis(10)))
           .build();
 
   @Test
@@ -54,6 +68,16 @@ public void failSyncOperation() {
     Assert.assertTrue(nexusFailure.getCause() instanceof ApplicationFailure);
     ApplicationFailure applicationFailure = (ApplicationFailure) nexusFailure.getCause();
     Assert.assertEquals("failed to call operation", applicationFailure.getOriginalMessage());
+    Map<String, String> execFailedTags =
+        ImmutableMap.<String, String>builder()
+            .putAll(MetricsTag.defaultTags(NAMESPACE))
+            .put(MetricsTag.WORKER_TYPE, WorkerMetricsTag.WorkerType.NEXUS_WORKER.getValue())
+            .put(MetricsTag.TASK_QUEUE, testWorkflowRule.getTaskQueue())
+            .put(MetricsTag.NEXUS_SERVICE, "TestNexusService1")
+            .put(MetricsTag.NEXUS_OPERATION, "operation")
+            .put(MetricsTag.TASK_FAILURE_TYPE, "operation_failed")
+            .buildKeepingLast();
+    reporter.assertCounter(MetricsType.NEXUS_EXEC_FAILED_COUNTER, execFailedTags, 4);
   }
 
   public static class TestNexus implements TestWorkflow1 {