diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
index 6590aaa04a..1dba95acd6 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -346,6 +346,8 @@ type MutableNodeStatus interface {
 	GetArrayNodeStatus() MutableArrayNodeStatus
 	GetOrCreateArrayNodeStatus() MutableArrayNodeStatus
 	ClearArrayNodeStatus()
+
+	ClearExecutionError()
 }
 
 type ExecutionTimeInfo interface {
diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go
index cdf3f1b6ab..b0eb198303 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go
@@ -30,6 +30,11 @@ func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() {
 	_m.Called()
 }
 
+// ClearExecutionError provides a mock function with given fields:
+func (_m *ExecutableNodeStatus) ClearExecutionError() {
+	_m.Called()
+}
+
 // ClearGateNodeStatus provides a mock function with given fields:
 func (_m *ExecutableNodeStatus) ClearGateNodeStatus() {
 	_m.Called()
diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go
index 3f103bc2ec..328b732463 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go
@@ -28,6 +28,11 @@ func (_m *MutableNodeStatus) ClearDynamicNodeStatus() {
 	_m.Called()
 }
 
+// ClearExecutionError provides a mock function with given fields:
+func (_m *MutableNodeStatus) ClearExecutionError() {
+	_m.Called()
+}
+
 // ClearGateNodeStatus provides a mock function with given fields:
 func (_m *MutableNodeStatus) ClearGateNodeStatus() {
 	_m.Called()
diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go
index aab034224d..98a8941b13 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go
@@ -477,6 +477,10 @@ func (in *NodeStatus) ClearArrayNodeStatus() {
 	in.SetDirty()
 }
 
+func (in *NodeStatus) ClearExecutionError() {
+	in.Error = nil
+}
+
 func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time {
 	return in.LastUpdatedAt
 }
@@ -632,6 +636,7 @@ func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason st
 	if p == NodePhaseSucceeded || p == NodePhaseSkipped || !enableCRDebugMetadata {
 		// Clear most status related fields after reaching a terminal state. This keeps the CR state small to avoid
 		// etcd size limits. Importantly we keep Phase, StoppedAt and Error which will be needed further.
+		// The error is still needed at this point, but it is cleared as soon as it is no longer required, because errors can be very large.
 		in.Message = ""
 		in.QueuedAt = nil
 		in.StartedAt = nil
diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go
index 698883fb48..dbf0e69a4b 100644
--- a/flytepropeller/pkg/controller/config/config.go
+++ b/flytepropeller/pkg/controller/config/config.go
@@ -214,7 +214,7 @@ type NodeConfig struct {
 	InterruptibleFailureThreshold int32 `json:"interruptible-failure-threshold" pflag:"1,number of failures for a node to be still considered interruptible. Negative numbers are treated as complementary (ex. -1 means last attempt is non-interruptible).'"`
 	DefaultMaxAttempts            int32 `json:"default-max-attempts" pflag:"3,Default maximum number of attempts for a node"`
 	IgnoreRetryCause              bool  `json:"ignore-retry-cause" pflag:",Ignore retry cause and count all attempts toward a node's max attempts"`
-	EnableCRDebugMetadata         bool  `json:"enable-cr-debug-metadata" pflag:",Collapse node on any terminal state, not just successful terminations. This is useful to reduce the size of workflow state in etcd."`
+	EnableCRDebugMetadata         bool  `json:"enable-cr-debug-metadata" pflag:",By default, node state is cleared once flytepropeller no longer needs it, which keeps the workflow state in etcd small. Enable this setting to retain that state for debugging purposes."`
 }
 
 // DefaultDeadlines contains default values for timeouts
diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go
index c60f724ee2..321754a6b8 100755
--- a/flytepropeller/pkg/controller/config/config_flags.go
+++ b/flytepropeller/pkg/controller/config/config_flags.go
@@ -96,7 +96,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
 	cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "node-config.interruptible-failure-threshold"), defaultConfig.NodeConfig.InterruptibleFailureThreshold, "number of failures for a node to be still considered interruptible. Negative numbers are treated as complementary (ex. -1 means last attempt is non-interruptible).'")
 	cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "node-config.default-max-attempts"), defaultConfig.NodeConfig.DefaultMaxAttempts, "Default maximum number of attempts for a node")
 	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "node-config.ignore-retry-cause"), defaultConfig.NodeConfig.IgnoreRetryCause, "Ignore retry cause and count all attempts toward a node's max attempts")
-	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "node-config.enable-cr-debug-metadata"), defaultConfig.NodeConfig.EnableCRDebugMetadata, "Collapse node on any terminal state, not just successful terminations. This is useful to reduce the size of workflow state in etcd.")
+	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "node-config.enable-cr-debug-metadata"), defaultConfig.NodeConfig.EnableCRDebugMetadata, "By default, node state is cleared once flytepropeller no longer needs it, which keeps the workflow state in etcd small. Enable this setting to retain that state for debugging purposes.")
 	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-streak-length"), defaultConfig.MaxStreakLength, "Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled.")
 	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "event-config.raw-output-policy"), defaultConfig.EventConfig.RawOutputPolicy, "How output data should be passed along in execution events.")
 	cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "event-config.fallback-to-output-reference"), defaultConfig.EventConfig.FallbackToOutputReference, "Whether output data should be sent by reference when it is too large to be sent inline in execution events.")
diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go
index 6ddde14c71..442e113447 100644
--- a/flytepropeller/pkg/controller/nodes/executor.go
+++ b/flytepropeller/pkg/controller/nodes/executor.go
@@ -277,6 +277,7 @@ func (c *recursiveNodeExecutor) handleDownstream(ctx context.Context, execContex
 	partialNodeCompletion := false
 	onFailurePolicy := execContext.GetOnFailurePolicy()
 	stateOnComplete := interfaces.NodeStatusComplete
+	var executableNodeStatusOnComplete v1alpha1.ExecutableNodeStatus
 	for _, downstreamNodeName := range downstreamNodes {
 		downstreamNode, ok := nl.GetNode(downstreamNodeName)
 		if !ok {
@@ -298,6 +299,10 @@ func (c *recursiveNodeExecutor) handleDownstream(ctx context.Context, execContex
 			// If the failure policy allows other nodes to continue running, do not exit the loop,
 			// Keep track of the last failed state in the loop since it'll be the one to return.
 			// TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one.
+			if executableNodeStatusOnComplete != nil {
+				c.nodeExecutor.Clear(executableNodeStatusOnComplete)
+			}
+			executableNodeStatusOnComplete = nl.GetNodeExecutionStatus(ctx, downstreamNode.GetID())
 			stateOnComplete = state
 		} else {
 			return state, nil
@@ -863,6 +868,12 @@ func (c *nodeExecutor) execute(ctx context.Context, h interfaces.NodeHandler, nC
 	return phase, nil
 }
 
+func (c *nodeExecutor) Clear(executableNodeStatus v1alpha1.ExecutableNodeStatus) {
+	if !c.enableCRDebugMetadata {
+		executableNodeStatus.ClearExecutionError()
+	}
+}
+
 func (c *nodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCtx interfaces.NodeExecutionContext, reason string, finalTransition bool) error {
 	logger.Debugf(ctx, "Calling aborting & finalize")
 	if err := h.Abort(ctx, nCtx, reason); err != nil {
diff --git a/flytepropeller/pkg/controller/nodes/interfaces/handler.go b/flytepropeller/pkg/controller/nodes/interfaces/handler.go
index c5fe9d6321..0f5f6384ab 100644
--- a/flytepropeller/pkg/controller/nodes/interfaces/handler.go
+++ b/flytepropeller/pkg/controller/nodes/interfaces/handler.go
@@ -7,6 +7,7 @@ import (
+	"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
 	"github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors"
 	"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler"
 	"github.com/flyteorg/flyte/flytestdlib/promutils"
 )
 
 //go:generate mockery -all -case=underscore
@@ -16,6 +17,7 @@ type NodeExecutor interface {
 	HandleNode(ctx context.Context, dag executors.DAGStructure, nCtx NodeExecutionContext, h NodeHandler) (NodeStatus, error)
 	Abort(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext, reason string, finalTransition bool) error
 	Finalize(ctx context.Context, h NodeHandler, nCtx NodeExecutionContext) error
+	Clear(executableNodeStatus v1alpha1.ExecutableNodeStatus)
 }
 
 // Interface that should be implemented for a node type.
diff --git a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go
index e01d13d7d4..f3b8c007d1 100644
--- a/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go
+++ b/flytepropeller/pkg/controller/nodes/interfaces/mocks/node_executor.go
@@ -9,6 +9,8 @@ import (
 	interfaces "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces"
 
 	mock "github.com/stretchr/testify/mock"
+
+	v1alpha1 "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
 )
 
 // NodeExecutor is an autogenerated mock type for the NodeExecutor type
@@ -48,6 +50,11 @@ func (_m *NodeExecutor) Abort(ctx context.Context, h interfaces.NodeHandler, nCt
 	return r0
 }
 
+// Clear provides a mock function with given fields: executableNodeStatus
+func (_m *NodeExecutor) Clear(executableNodeStatus v1alpha1.ExecutableNodeStatus) {
+	_m.Called(executableNodeStatus)
+}
+
 type NodeExecutor_Finalize struct {
 	*mock.Call
 }
diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go
index cc9910abc3..9d2ffe6e62 100644
--- a/flytepropeller/pkg/controller/workflow/executor_test.go
+++ b/flytepropeller/pkg/controller/workflow/executor_test.go
@@ -72,6 +72,10 @@ type fakeRemoteWritePlugin struct {
 	t assert.TestingT
 }
 
+type fakeNodeExecContext interface {
+	Node() v1alpha1.ExecutableNode
+}
+
 func (f fakeRemoteWritePlugin) Handle(ctx context.Context, tCtx pluginCore.TaskExecutionContext) (pluginCore.Transition, error) {
 	logger.Infof(ctx, "----------------------------------------------------------------------------------------------")
 	logger.Infof(ctx, "Handle called for %s", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
@@ -224,6 +228,26 @@ func createTaskExecutorErrorInCheck(t assert.TestingT) pluginCore.PluginEntry {
 	}
 }
 
+func CountFailedNodes(nodeStatuses map[v1alpha1.NodeID]*v1alpha1.NodeStatus) int {
+	count := 0
+	for _, v := range nodeStatuses {
+		if v.Phase == v1alpha1.NodePhaseFailed {
+			count++
+		}
+	}
+	return count
+}
+
+func CountNodesWithErrors(nodeStatuses map[v1alpha1.NodeID]*v1alpha1.NodeStatus) int {
+	count := 0
+	for _, v := range nodeStatuses {
+		if v.Error != nil {
+			count++
+		}
+	}
+	return count
+}
+
 func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) {
 	ctx := context.Background()
 	scope := testScope.NewSubScope("12")
@@ -496,7 +520,17 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {
 
 	h := &nodemocks.NodeHandler{}
 	h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
-	h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil)
+
+	// The mock handler marks the start node as succeeded and every other node as failed
+	startNodeMatcher := mock.MatchedBy(func(nodeExecContext fakeNodeExecContext) bool {
+		return nodeExecContext.Node().IsStartNode()
+	})
+	h.OnHandleMatch(mock.Anything, startNodeMatcher).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil)
+	h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(
+		handler.TransitionTypeEphemeral,
+		handler.PhaseInfoFailureErr(&core.ExecutionError{Code: "code", Message: "message", ErrorUri: "uri"}, nil)), nil,
+	)
+
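+	// testify resolves overlapping expectations in registration order: the
+	// start node matches startNodeMatcher above, while every other node falls
+	// through to the catch-all expectation and returns the failure transition.
+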
 	h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil)
 	h.OnFinalizeRequired().Return(false)
 
@@ -504,46 +538,66 @@
 	handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 	handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil)
 
-	nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient,
-		maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope())
-	assert.NoError(t, err)
-	executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope())
-	assert.NoError(t, err)
-
-	assert.NoError(t, executor.Initialize(ctx))
+	tests := []struct {
+		name                         string
+		onFailurePolicy              v1alpha1.WorkflowOnFailurePolicy
+		enableCRDebugMetadata        bool
+		expectedRoundsToFail         int
+		expectedNodesWithErrorsCount int
+		expectedFailedNodesCount     int
+	}{
+		{"failImmediately", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), false, 6, 1, 1},
+		{"failImmediately enableCRDebugMetadata", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), true, 6, 1, 1},
+		{"failAfterExecutableNodesComplete", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), false, 12, 1, 2},
+		{"failAfterExecutableNodesComplete enableCRDebugMetadata", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), true, 12, 2, 2},
+	}
 
 	wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml")
-	if assert.NoError(t, err) {
-		w := &v1alpha1.FlyteWorkflow{
-			RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}},
-		}
-		if assert.NoError(t, json.Unmarshal(wJSON, w)) {
-			// For benchmark workflow, we will run into the first failure on round 6
-
-			roundsToFail := 8
-			for i := 0; i < roundsToFail; i++ {
-				err := executor.HandleFlyteWorkflow(ctx, w)
-				assert.Nil(t, err, "Round [%v]", i)
-				fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String())
-				walkAndPrint(w.Connections, w.Status.NodeStatus)
-				for _, v := range w.Status.NodeStatus {
-					// Reset dirty manually for tests.
-					v.ResetDirty()
-				}
-				fmt.Printf("\n")
-
-				if i == roundsToFail-1 {
-					assert.Equal(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase)
-				} else if i == roundsToFail-2 {
-					assert.Equal(t, v1alpha1.WorkflowPhaseHandlingFailureNode, w.Status.Phase)
-				} else {
-					assert.NotEqual(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase, "For Round [%v] got phase [%v]", i, w.Status.Phase.String())
+	assert.NoError(t, err)
+	for _, test := range tests {
+
+		t.Run(test.name, func(t *testing.T) {
+			nodeConfig := config.GetConfig().NodeConfig
+			nodeConfig.EnableCRDebugMetadata = test.enableCRDebugMetadata
+			nodeExec, err := nodes.NewExecutor(ctx, nodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient,
+				maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope())
+			assert.NoError(t, err)
+			executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope())
+			assert.NoError(t, err)
+			assert.NoError(t, executor.Initialize(ctx))
+
+			w := &v1alpha1.FlyteWorkflow{
+				RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}},
+				WorkflowSpec:        &v1alpha1.WorkflowSpec{OnFailurePolicy: test.onFailurePolicy},
+			}
+			if assert.NoError(t, json.Unmarshal(wJSON, w)) {
+
+				for i := 0; i < test.expectedRoundsToFail; i++ {
+					t.Run(fmt.Sprintf("Round[%d]", i), func(t *testing.T) {
+						err := executor.HandleFlyteWorkflow(ctx, w)
+						assert.Nil(t, err, "Round [%v]", i)
+						fmt.Printf("Round[%d] Workflow[%v]\n", i, w.Status.Phase.String())
+						walkAndPrint(w.Connections, w.Status.NodeStatus)
+						for _, v := range w.Status.NodeStatus {
+							// Reset dirty manually for tests.
+							v.ResetDirty()
+						}
+						fmt.Printf("\n")
+
+						if i == test.expectedRoundsToFail-1 {
+							assert.Equal(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase)
+						} else if i == test.expectedRoundsToFail-2 {
+							assert.Equal(t, v1alpha1.WorkflowPhaseHandlingFailureNode, w.Status.Phase)
+						} else {
+							assert.NotEqual(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase, "For Round [%v] got phase [%v]", i, w.Status.Phase.String())
+						}
+					})
 				}
-
+				assert.Equal(t, test.expectedFailedNodesCount, CountFailedNodes(w.Status.NodeStatus))
+				assert.Equal(t, test.expectedNodesWithErrorsCount, CountNodesWithErrors(w.Status.NodeStatus))
+				assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message)
 			}
-
-			assert.Equal(t, v1alpha1.WorkflowPhaseFailed.String(), w.Status.Phase.String(), "Message: [%v]", w.Status.Message)
-		}
+		})
 	}
 	assert.True(t, recordedRunning)
 	assert.True(t, recordedFailing)
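
Taken together: under FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, handleDownstream now tracks the status of the most recently failed downstream node and clears the execution error of the previously tracked one, so at most one failed node retains its potentially large error in the CR while the workflow keeps running; the last failure's error survives until the workflow itself terminates. A minimal sketch of the resulting call pattern, using only calls that appear above; the helper name is invented, and nl is assumed to be the executors.NodeLookup available in handleDownstream:

// clearStaleFailure drops the execution error recorded on a previously failed
// node. With node-config.enable-cr-debug-metadata at its default (false),
// NodeExecutor.Clear invokes ClearExecutionError, which sets NodeStatus.Error
// to nil and shrinks the FlyteWorkflow object before the next etcd write.
func clearStaleFailure(ctx context.Context, exec interfaces.NodeExecutor, nl executors.NodeLookup, prevFailedNodeID v1alpha1.NodeID) {
	exec.Clear(nl.GetNodeExecutionStatus(ctx, prevFailedNodeID))
}

Operators who prefer to keep every failed node's error in the CR for debugging can opt out of this cleanup by enabling node-config.enable-cr-debug-metadata, as described in the config changes above.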