diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index f1e2ef64fc..7dcdef4749 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -205,10 +205,16 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // initialize ArrayNode state - maxAttempts := int(config.GetConfig().NodeConfig.DefaultMaxAttempts) - subNodeSpec := *arrayNode.GetSubNodeSpec() - if subNodeSpec.GetRetryStrategy() != nil && subNodeSpec.GetRetryStrategy().MinAttempts != nil { - maxAttempts = *subNodeSpec.GetRetryStrategy().MinAttempts + maxSystemFailuresValue := int(config.GetConfig().NodeConfig.MaxNodeRetriesOnSystemFailures) + maxAttemptsValue := int(config.GetConfig().NodeConfig.DefaultMaxAttempts) + if nCtx.Node().GetRetryStrategy() != nil && nCtx.Node().GetRetryStrategy().MinAttempts != nil && *nCtx.Node().GetRetryStrategy().MinAttempts != 1 { + maxAttemptsValue = *nCtx.Node().GetRetryStrategy().MinAttempts + } + + if config.GetConfig().NodeConfig.IgnoreRetryCause { + maxSystemFailuresValue = maxAttemptsValue + } else { + maxAttemptsValue += maxSystemFailuresValue } for _, item := range []struct { @@ -219,8 +225,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // defined as an `iota` so it is impossible to programmatically get largest value {arrayReference: &arrayNodeState.SubNodePhases, maxValue: int(v1alpha1.NodePhaseRecovered)}, {arrayReference: &arrayNodeState.SubNodeTaskPhases, maxValue: len(core.Phases) - 1}, - {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttempts}, - {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxAttempts}, + {arrayReference: &arrayNodeState.SubNodeRetryAttempts, maxValue: maxAttemptsValue}, + {arrayReference: &arrayNodeState.SubNodeSystemFailures, maxValue: maxSystemFailuresValue}, } { *item.arrayReference, err = 
bitarray.NewCompactArray(uint(size), bitarray.Item(item.maxValue))
diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go
index b2e85c0979..fbb5ae875c 100644
--- a/flytepropeller/pkg/controller/nodes/array/handler_test.go
+++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go
@@ -654,6 +654,159 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) {
 	}
 }
 
+// TestHandleArrayNodePhaseExecutingSubNodeFailures checks how many attempts a
+// failing subnode gets before the ArrayNode reaches the Failing phase, for user
+// and system failures, with and without IgnoreRetryCause.
+func TestHandleArrayNodePhaseExecutingSubNodeFailures(t *testing.T) {
+	ctx := context.Background()
+
+	inputValues := map[string][]int64{
+		"foo": {1},
+		"bar": {2},
+	}
+	literalMap := convertMapToArrayLiterals(inputValues)
+
+	// cap on Handle calls per case so that a retry-accounting regression fails the
+	// test instead of spinning forever
+	const maxHandleCalls = 1000
+
+	tests := []struct {
+		name               string
+		defaultMaxAttempts int32
+		maxSystemFailures  int64
+		ignoreRetryCause   bool
+		transition         handler.Transition
+		expectedAttempts   int
+	}{
+		{
+			name:               "UserFailure",
+			defaultMaxAttempts: 3,
+			maxSystemFailures:  10,
+			ignoreRetryCause:   false,
+			transition: handler.DoTransition(handler.TransitionTypeEphemeral,
+				handler.PhaseInfoRetryableFailure(idlcore.ExecutionError_USER, "", "", &handler.ExecutionInfo{})),
+			expectedAttempts: 3,
+		},
+		{
+			name:               "SystemFailure",
+			defaultMaxAttempts: 3,
+			maxSystemFailures:  10,
+			ignoreRetryCause:   false,
+			transition: handler.DoTransition(handler.TransitionTypeEphemeral,
+				handler.PhaseInfoRetryableFailure(idlcore.ExecutionError_SYSTEM, "", "", &handler.ExecutionInfo{})),
+			expectedAttempts: 11,
+		},
+		{
+			name:               "UserFailureIgnoreRetryCause",
+			defaultMaxAttempts: 3,
+			maxSystemFailures:  10,
+			ignoreRetryCause:   true,
+			transition: handler.DoTransition(handler.TransitionTypeEphemeral,
+				handler.PhaseInfoRetryableFailure(idlcore.ExecutionError_USER, "", "", &handler.ExecutionInfo{})),
+			expectedAttempts: 3,
+		},
+		{
+			name:               "SystemFailureIgnoreRetryCause",
+			defaultMaxAttempts: 3,
+			maxSystemFailures:  10,
+			ignoreRetryCause:   true,
+			transition: handler.DoTransition(handler.TransitionTypeEphemeral,
+				handler.PhaseInfoRetryableFailure(idlcore.ExecutionError_SYSTEM, "", "", &handler.ExecutionInfo{})),
+			expectedAttempts: 3,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// the node config is process-wide, so put the previous values back once
+			// this case finishes
+			prevMaxAttempts := config.GetConfig().NodeConfig.DefaultMaxAttempts
+			prevMaxSystemFailures := config.GetConfig().NodeConfig.MaxNodeRetriesOnSystemFailures
+			prevIgnoreRetryCause := config.GetConfig().NodeConfig.IgnoreRetryCause
+			defer func() {
+				config.GetConfig().NodeConfig.DefaultMaxAttempts = prevMaxAttempts
+				config.GetConfig().NodeConfig.MaxNodeRetriesOnSystemFailures = prevMaxSystemFailures
+				config.GetConfig().NodeConfig.IgnoreRetryCause = prevIgnoreRetryCause
+			}()
+
+			config.GetConfig().NodeConfig.DefaultMaxAttempts = test.defaultMaxAttempts
+			config.GetConfig().NodeConfig.MaxNodeRetriesOnSystemFailures = test.maxSystemFailures
+			config.GetConfig().NodeConfig.IgnoreRetryCause = test.ignoreRetryCause
+
+			// create NodeExecutionContext
+			scope := promutils.NewTestScope()
+			dataStore, err := storage.NewDataStore(&storage.Config{
+				Type: storage.TypeMemory,
+			}, scope)
+			if !assert.NoError(t, err) {
+				return
+			}
+			eventRecorder := newBufferedEventRecorder()
+			arrayNodeState := &handler.ArrayNodeState{
+				Phase: v1alpha1.ArrayNodePhaseNone,
+			}
+			nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState)
+
+			// initialize ArrayNodeHandler
+			nodeHandler := &mocks.NodeHandler{}
+			nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+			nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil)
+			nodeHandler.OnFinalizeRequired().Return(false)
+			nodeHandler.OnHandleMatch(mock.Anything, mock.Anything).Return(test.transition, nil)
+
+			arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope)
+			if !assert.NoError(t, err) {
+				return
+			}
+
+			// evaluate node to transition to Executing
+			_, err = arrayNodeHandler.Handle(ctx, nCtx)
+			if !assert.NoError(t, err) {
+				return
+			}
+			if !assert.Equal(t, v1alpha1.ArrayNodePhaseExecuting, arrayNodeState.Phase) {
+				return
+			}
+
+			for i := 0; i < len(arrayNodeState.SubNodePhases.GetItems()); i++ {
+				arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(v1alpha1.NodePhaseRunning))
+			}
+
+			for i := 0; i < len(arrayNodeState.SubNodeTaskPhases.GetItems()); i++ {
+				arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(core.PhaseRunning))
+			}
+
+			// evaluate node until failure
+			attempts := 1
+			reachedFailing := false
+			for call := 0; call < maxHandleCalls; call++ {
+				nCtx = createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState)
+				_, err = arrayNodeHandler.Handle(ctx, nCtx)
+				if !assert.NoError(t, err) {
+					return
+				}
+
+				if arrayNodeState.Phase == v1alpha1.ArrayNodePhaseFailing {
+					reachedFailing = true
+					break
+				}
+
+				// failing a task requires two calls to Handle, the first to return a
+				// RetryableFailure and the second to abort. therefore, we only increment the
+				// number of attempts once in this loop.
+				if arrayNodeState.SubNodePhases.GetItem(0) == bitarray.Item(v1alpha1.NodePhaseRetryableFailure) {
+					attempts++
+				}
+			}
+
+			if !assert.True(t, reachedFailing, "ArrayNode did not reach the Failing phase within %d Handle calls", maxHandleCalls) {
+				return
+			}
+			assert.Equal(t, test.expectedAttempts, attempts)
+		})
+	}
+}
+
 func TestHandleArrayNodePhaseSucceeding(t *testing.T) {
 	ctx := context.Background()
 	scope := promutils.NewTestScope()