diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go
index 3f8c2a0914..7847b161fc 100644
--- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go
+++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go
@@ -36,6 +36,8 @@ const (
 	PhasePermanentFailure
 	// Indicates the task is waiting for the cache to be populated so it can reuse results
 	PhaseWaitingForCache
+	// Indicates the task has been aborted
+	PhaseAborted
 )
 
 var Phases = []Phase{
@@ -49,11 +51,12 @@ var Phases = []Phase{
 	PhaseRetryableFailure,
 	PhasePermanentFailure,
 	PhaseWaitingForCache,
+	PhaseAborted,
 }
 
-// Returns true if the given phase is failure, retryable failure or success
+// IsTerminal returns true if the given phase is failure, retryable failure, success or aborted
 func (p Phase) IsTerminal() bool {
-	return p.IsFailure() || p.IsSuccess()
+	return p.IsFailure() || p.IsSuccess() || p.IsAborted()
 }
 
 func (p Phase) IsFailure() bool {
@@ -64,6 +67,11 @@ func (p Phase) IsSuccess() bool {
 	return p == PhaseSuccess
 }
 
+// IsAborted returns true if the phase is PhaseAborted.
+func (p Phase) IsAborted() bool {
+	return p == PhaseAborted
+}
+
 func (p Phase) IsWaitingForResources() bool {
 	return p == PhaseWaitingForResources
 }
@@ -257,10 +265,20 @@ func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo {
 	return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info)
 }
 
+// PhaseInfoSystemFailureWithCleanup builds a permanent SYSTEM-kind failure, passing true as phaseInfoFailed's last argument (presumably the cleanup-on-failure flag).
+func PhaseInfoSystemFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo {
+	return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info, true)
+}
+
 func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo {
 	return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info)
 }
 
+// PhaseInfoFailureWithCleanup builds a permanent USER-kind failure, passing true as phaseInfoFailed's last argument (presumably the cleanup-on-failure flag).
+func PhaseInfoFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo {
+	return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info, true)
+}
+
 func PhaseInfoRetryableFailure(code, reason string, info *TaskInfo) PhaseInfo {
 	return PhaseInfoFailed(PhaseRetryableFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info)
 }
diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
index 99688fe36d..6c7a858be8 100644
--- a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
+++ b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
@@ -78,6 +78,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
 
 	case arrayCore.PhaseAssembleFinalOutput:
 		pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State)
+
+	case arrayCore.PhaseAbortSubTasks:
+		fallthrough
 
 	case arrayCore.PhaseWriteToDiscoveryThenFail:
 		pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version+1)
diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go
index d40edd41de..9de13c6413 100644
--- a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go
+++ b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go
@@ -248,7 +248,7 @@ func TestCheckSubTasksState(t *testing.T) {
 
 		newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
 			State: &arrayCore.State{
-				CurrentPhase:         arrayCore.PhaseWriteToDiscoveryThenFail,
+				CurrentPhase:         arrayCore.PhaseCheckingSubTaskExecutions,
 				ExecutionArraySize:   2,
 				OriginalArraySize:    2,
 				OriginalMinSuccesses: 2,
@@ -264,6 +264,6 @@ func TestCheckSubTasksState(t *testing.T) {
 
 		assert.NoError(t, err)
 		p, _ := newState.GetPhase()
-		assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p)
+		assert.Equal(t, arrayCore.PhaseAbortSubTasks, p)
 	})
 }
diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go
index f601ec375b..a540359b0a 100644
--- a/flyteplugins/go/tasks/plugins/array/core/state.go
+++ b/flyteplugins/go/tasks/plugins/array/core/state.go
@@ -35,6 +35,7 @@ const (
 	PhaseAssembleFinalError
 	PhaseRetryableFailure
 	PhasePermanentFailure
+	PhaseAbortSubTasks
 )
 
 type State struct {
@@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
 	case PhaseAssembleFinalError:
 		fallthrough
 
+	case PhaseAbortSubTasks:
+		fallthrough
+
 	case PhaseWriteToDiscoveryThenFail:
 		fallthrough
 
@@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus
 
 	if totalCount < minSuccesses {
 		logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses)
-		return PhaseWriteToDiscoveryThenFail
+		return PhaseAbortSubTasks
 	}
 
 	// No chance to reach the required success numbers.
 	if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses {
 		logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]",
 			minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures)
-		return PhaseWriteToDiscoveryThenFail
+		return PhaseAbortSubTasks
 	}
 
 	if totalWaitingForResources > 0 {
diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go
index 01b5b41528..969c98df20 100644
--- a/flyteplugins/go/tasks/plugins/array/core/state_test.go
+++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go
@@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) {
 	}{
 		{
 			"FailOnTooFewTasks",
-			PhaseWriteToDiscoveryThenFail,
+			PhaseAbortSubTasks,
 			map[core.Phase]int64{},
 		},
 		{
@@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) {
 		},
 		{
 			"FailOnTooManyPermanentFailures",
-			PhaseWriteToDiscoveryThenFail,
+			PhaseAbortSubTasks,
 			map[core.Phase]int64{
 				core.PhasePermanentFailure: 1,
 				core.PhaseSuccess:          9,
@@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) {
 		},
 		{
 			"FailedToRetry",
-			PhaseWriteToDiscoveryThenFail,
+			PhaseAbortSubTasks,
 			map[core.Phase]int64{
 				core.PhaseSuccess:          5,
 				core.PhasePermanentFailure: 5,
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go
index 3e324391b8..f664392fd9 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go
@@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
 		nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig,
 			tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState)
 
+	case arrayCore.PhaseAbortSubTasks:
+		nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
+
 	case arrayCore.PhaseAssembleFinalOutput:
 		nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState)
 
@@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err
 		return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
 	}
 
-	return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
+	return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
 }
 
 func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
@@ -165,7 +168,8 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext)
 		return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
 	}
 
-	return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
+	_, _, err := TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
+	return err
 }
 
 func (e Executor) Start(ctx context.Context) error {
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go
index fdf9d4a182..510f202e1a 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/management.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go
@@ -324,16 +324,39 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
 	return newState, externalResources, nil
 }
 
+// TerminateSubTasksOnAbort terminates all subtasks via TerminateSubTasks and, if that succeeds,
+// records a raw PhaseAborted event carrying the terminated subtasks as external resources.
+func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
+	terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {
+
+	_, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, config, terminateFunction, currentState)
+	if err != nil {
+		return err
+	}
+
+	taskInfo := &core.TaskInfo{
+		ExternalResources: externalResources,
+	}
+	executionErr := &idlCore.ExecutionError{
+		Code:    "ArraySubtasksAborted",
+		Message: "Array subtasks were aborted",
+	}
+	phaseInfo := core.PhaseInfoFailed(core.PhaseAborted, executionErr, taskInfo)
+
+	return tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo)
+}
+
 // TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include
 // aborting and finalizing active k8s resources.
 func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
-	terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {
+	terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) {
 
 	taskTemplate, err := tCtx.TaskReader().Read(ctx)
+	externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems()))
 	if err != nil {
-		return err
+		return currentState, externalResources, err
 	} else if taskTemplate == nil {
-		return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
+		return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
 	}
 
 	messageCollector := errorcollector.NewErrorMessageCollector()
@@ -353,18 +376,25 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
 		originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
 		stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
 		if err != nil {
-			return err
+			return currentState, externalResources, err
 		}
 
 		err = terminateFunction(ctx, stCtx, config, kubeClient)
 		if err != nil {
 			messageCollector.Collect(childIdx, err.Error())
+		} else {
+			externalResources = append(externalResources, &core.ExternalResource{
+				ExternalID:   stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
+				Index:        uint32(originalIdx),
+				RetryAttempt: uint32(retryAttempt),
+				Phase:        core.PhaseAborted,
+			})
 		}
 	}
 
 	if messageCollector.Length() > 0 {
-		return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength))
+		return currentState, externalResources, fmt.Errorf("%s", messageCollector.Summary(config.MaxErrorStringLength))
 	}
 
-	return nil
+	return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil
 }
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go
index 5bead4954d..ab26028b09 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go
@@ -531,10 +531,160 @@ func TestCheckSubTasksState(t *testing.T) {
 		// validate results
 		assert.Nil(t, err)
 		p, _ := newState.GetPhase()
-		assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail.String(), p.String())
+		assert.Equal(t, arrayCore.PhaseAbortSubTasks.String(), p.String())
 		resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount)
 		for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() {
 			assert.Equal(t, core.PhasePermanentFailure, core.Phases[subtaskPhaseIndex])
 		}
 	})
 }
+
+func TestTerminateSubTasksOnAbort(t *testing.T) {
+	ctx := context.Background()
+	subtaskCount := 3
+	config := Config{
+		MaxArrayJobSize: int64(subtaskCount * 10),
+		ResourceConfig: ResourceConfig{
+			PrimaryLabel: "p",
+			Limit:        subtaskCount,
+		},
+	}
+	kubeClient := mocks.KubeClient{}
+	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
+	kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())
+
+	compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount))
+	for i := 0; i < subtaskCount; i++ {
+		compactArray.SetItem(i, 5)
+	}
+
+	currentState := &arrayCore.State{
+		CurrentPhase:         arrayCore.PhaseCheckingSubTaskExecutions,
+		ExecutionArraySize:   subtaskCount,
+		OriginalArraySize:    int64(subtaskCount),
+		OriginalMinSuccesses: int64(subtaskCount),
+		ArrayStatus: arraystatus.ArrayStatus{
+			Detailed: compactArray,
+		},
+		IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)),
+	}
+
+	t.Run("SuccessfulTermination", func(t *testing.T) {
+		eventRecorder := mocks.EventsRecorder{}
+		eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil)
+		tCtx := getMockTaskExecutionContext(ctx, 0)
+		tCtx.OnEventsRecorder().Return(&eventRecorder)
+
+		mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error {
+			return nil
+		}
+
+		err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState)
+
+		assert.Nil(t, err)
+		eventRecorder.AssertCalled(t, "RecordRaw", mock.Anything, mock.Anything)
+	})
+
+	t.Run("TerminationWithError", func(t *testing.T) {
+		eventRecorder := mocks.EventsRecorder{}
+		eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil)
+		tCtx := getMockTaskExecutionContext(ctx, 0)
+		tCtx.OnEventsRecorder().Return(&eventRecorder)
+
+		mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error {
+			return fmt.Errorf("termination error")
+		}
+
+		err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState)
+
+		assert.NotNil(t, err)
+		eventRecorder.AssertNotCalled(t, "RecordRaw", mock.Anything, mock.Anything)
+	})
+}
+
+func TestTerminateSubTasks(t *testing.T) {
+	ctx := context.Background()
+	subtaskCount := 3
+	config := Config{
+		MaxArrayJobSize: int64(subtaskCount * 10),
+		ResourceConfig: ResourceConfig{
+			PrimaryLabel: "p",
+			Limit:        subtaskCount,
+		},
+	}
+	kubeClient := mocks.KubeClient{}
+	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
+	kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())
+
+	tests := []struct {
+		name                string
+		initialPhaseIndices []int
+		expectedAbortCount  int
+		terminateError      error
+	}{
+		{
+			name:                "AllSubTasksRunning",
+			initialPhaseIndices: []int{5, 5, 5},
+			expectedAbortCount:  3,
+			terminateError:      nil,
+		},
+		{
+			name:                "MixedSubTaskStates",
+			initialPhaseIndices: []int{8, 0, 5},
+			expectedAbortCount:  1,
+			terminateError:      nil,
+		},
+		{
+			name:                "TerminateFunctionFails",
+			initialPhaseIndices: []int{5, 5, 5},
+			expectedAbortCount:  3,
+			terminateError:      fmt.Errorf("error"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount))
+			for i, phaseIdx := range test.initialPhaseIndices {
+				compactArray.SetItem(i, bitarray.Item(phaseIdx))
+			}
+			currentState := &arrayCore.State{
+				CurrentPhase:         arrayCore.PhaseCheckingSubTaskExecutions,
+				PhaseVersion:         0,
+				ExecutionArraySize:   subtaskCount,
+				OriginalArraySize:    int64(subtaskCount),
+				OriginalMinSuccesses: int64(subtaskCount),
+				ArrayStatus: arraystatus.ArrayStatus{
+					Detailed: compactArray,
+				},
+				IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)),
+			}
+
+			tCtx := getMockTaskExecutionContext(ctx, 0)
+			terminateCounter := 0
+			mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error {
+				terminateCounter++
+				return test.terminateError
+			}
+
+			nextState, externalResources, err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState)
+
+			assert.Equal(t, test.expectedAbortCount, terminateCounter)
+
+			if test.terminateError != nil {
+				assert.NotNil(t, err)
+				return
+			}
+
+			assert.Nil(t, err)
+			assert.Equal(t, uint32(1), nextState.PhaseVersion)
+			assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, nextState.CurrentPhase)
+			assert.Len(t, externalResources, terminateCounter)
+
+			for _, externalResource := range externalResources {
+				phase := core.Phases[externalResource.Phase]
+				assert.True(t, phase.IsAborted())
+			}
+		})
+	}
+}
diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go
index 72dba11a50..5e4139296f 100644
--- a/flytepropeller/pkg/controller/nodes/task/handler.go
+++ b/flytepropeller/pkg/controller/nodes/task/handler.go
@@ -753,6 +753,7 @@ func (t *Handler) ValidateOutput(ctx context.Context, nodeID v1alpha1.NodeID, i
 func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error {
 	taskNodeState := nCtx.NodeStateReader().GetTaskNodeState()
 	currentPhase := taskNodeState.PluginPhase
+	currentPhaseVersion := taskNodeState.PluginPhaseVersion
 	logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase)
 
 	if currentPhase.IsTerminal() && !(currentPhase.IsFailure() && taskNodeState.CleanupOnFailure) {
@@ -790,8 +791,40 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext
 		logger.Errorf(ctx, "Abort failed when calling plugin abort.")
 		return err
 	}
-	taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
 
+	evRecorder := nCtx.EventsRecorder()
+	logger.Debugf(ctx, "Sending buffered Task events.")
+	for _, ev := range tCtx.ber.GetAll(ctx) {
+		evInfo, err := ToTaskExecutionEvent(ToTaskExecutionEventInputs{
+			TaskExecContext:       tCtx,
+			InputReader:           nCtx.InputReader(),
+			EventConfig:           t.eventConfig,
+			OutputWriter:          tCtx.ow,
+			Info:                  ev.WithVersion(currentPhaseVersion + 1),
+			NodeExecutionMetadata: nCtx.NodeExecutionMetadata(),
+			ExecContext:           nCtx.ExecutionContext(),
+			TaskType:              ttype,
+			PluginID:              p.GetID(),
+			ResourcePoolInfo:      tCtx.rm.GetResourcePoolInfo(),
+			ClusterID:             t.clusterID,
+		})
+		if err != nil {
+			return err
+		}
+		if currentPhase.IsFailure() {
+			evInfo.Phase = core.TaskExecution_FAILED
+		} else {
+			evInfo.Phase = core.TaskExecution_ABORTED
+		}
+		if err := evRecorder.RecordTaskEvent(ctx, evInfo, t.eventConfig); err != nil {
+			logger.Errorf(ctx, "Event recording failed for Plugin [%s], eventPhase [%s], error :%s", p.GetID(), evInfo.Phase.String(), err.Error())
+			// Check for idempotency
+			// Check for terminate state error
+			return err
+		}
+	}
+
+	taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
 	nodeExecutionID, err := getParentNodeExecIDForTask(&taskExecID, nCtx.ExecutionContext())
 	if err != nil {
 		return err
diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go
index e026419dc0..242c1334ce 100644
--- a/flytepropeller/pkg/controller/nodes/task/transformer.go
+++ b/flytepropeller/pkg/controller/nodes/task/transformer.go
@@ -44,6 +44,8 @@ func ToTaskEventPhase(p pluginCore.Phase) core.TaskExecution_Phase {
 		return core.TaskExecution_FAILED
 	case pluginCore.PhaseRetryableFailure:
 		return core.TaskExecution_FAILED
+	case pluginCore.PhaseAborted:
+		return core.TaskExecution_ABORTED
 	case pluginCore.PhaseNotReady:
 		fallthrough
 	case pluginCore.PhaseUndefined:
@@ -117,13 +119,14 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio
 
 		metadata.ExternalResources = make([]*event.ExternalResourceInfo, len(externalResources))
 		for idx, e := range input.Info.Info().ExternalResources {
+			phase := ToTaskEventPhase(e.Phase)
 			metadata.ExternalResources[idx] = &event.ExternalResourceInfo{
 				ExternalId:   e.ExternalID,
 				CacheStatus:  e.CacheStatus,
 				Index:        e.Index,
 				Logs:         e.Logs,
 				RetryAttempt: e.RetryAttempt,
-				Phase:        ToTaskEventPhase(e.Phase),
+				Phase:        phase,
 			}
 		}
 	}