From 6cae3f82a2177c390dbdd20e4cd2936bc65064ab Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Sat, 9 Dec 2023 18:21:13 -0800 Subject: [PATCH] add new array plugin abort phase Signed-off-by: Paul Dittamo --- .../manager/impl/task_execution_manager.go | 5 +- .../impl/task_execution_manager_test.go | 53 ++----- .../go/tasks/plugins/array/core/state.go | 8 +- .../go/tasks/plugins/array/core/state_test.go | 6 +- .../go/tasks/plugins/array/k8s/executor.go | 7 +- .../go/tasks/plugins/array/k8s/management.go | 40 +++-- .../plugins/array/k8s/management_test.go | 141 ++++++++++++------ 7 files changed, 148 insertions(+), 112 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index 13e2e71c6c4..82b872dbd42 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -174,9 +174,8 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req } currentPhase := core.TaskExecution_Phase(core.TaskExecution_Phase_value[taskExecutionModel.Phase]) - if common.IsTaskExecutionTerminal(currentPhase) && - (taskExecutionModel.Phase != request.Event.Phase.String() || taskExecutionModel.PhaseVersion >= request.Event.PhaseVersion) { - // Only update terminate execution if phase matches and it's a newer version + if common.IsTaskExecutionTerminal(currentPhase) { + // Cannot update a terminal execution. curPhase := request.Event.Phase.String() errorMsg := fmt.Sprintf("invalid phase change from %v to %v for task execution %v", taskExecutionModel.Phase, request.Event.Phase, taskExecutionID) logger.Warnf(ctx, errorMsg) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 9a6dac838a3..8fd80196476 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -425,51 +425,18 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { Phase: core.TaskExecution_SUCCEEDED.String(), }, nil }) + taskEventRequest.Event.Phase = core.TaskExecution_RUNNING + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - t.Run("CreateExecutionEvent_NonTerminalPhase", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_RUNNING - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - }) - - t.Run("CreateExecutionEvent_DifferentTerminalPhase", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_FAILED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - }) - - t.Run("CreateExecutionEvent_SameTerminalPhase_OldVersion", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.AlreadyExists) - }) + assert.Nil(t, resp) - t.Run("CreateExecutionEvent_SameTerminalPhase_NewVersion", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(1) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) - _, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, err) - }) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) } func TestCreateTaskEvent_PhaseVersionChange(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index c03f4a13840..4dac0080060 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -35,6 +35,7 @@ const ( PhaseAssembleFinalError PhaseRetryableFailure PhasePermanentFailure + PhaseAbortSubTasks ) type State struct { @@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhaseAssembleFinalError: fallthrough + case PhaseAbortSubTasks: + fallthrough + case PhaseWriteToDiscoveryThenFail: fallthrough @@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus if totalCount < minSuccesses { logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } // No chance to reach the required success numbers. if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses { logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]", minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } if totalWaitingForResources > 0 { diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go index 01b5b415286..969c98df20e 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state_test.go +++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go @@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) { }{ { "FailOnTooFewTasks", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{}, }, { @@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailOnTooManyPermanentFailures", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhasePermanentFailure: 1, core.PhaseSuccess: 9, @@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailedToRetry", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhaseSuccess: 5, core.PhasePermanentFailure: 5, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 3e324391b85..3407c459055 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) + case arrayCore.PhaseAbortSubTasks: + nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + case arrayCore.PhaseAssembleFinalOutput: nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState) @@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) } func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { @@ -165,7 +168,7 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) } func (e Executor) Start(ctx context.Context) error { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 2c805a1f05a..3d5eb2ef605 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -324,20 +324,37 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return newState, externalResources, nil } +func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + + _, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, GetConfig(), terminateFunction, currentState) + if err != nil { + return err + } + + taskInfo := &core.TaskInfo{ + ExternalResources: externalResources, + } + phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) + err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) + + return err +} + // TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include // aborting and finalizing active k8s resources. func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, - terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) + externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) if err != nil { - return err + return currentState, externalResources, err } else if taskTemplate == nil { - return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") } messageCollector := errorcollector.NewErrorMessageCollector() - externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] retryAttempt := uint64(0) @@ -348,7 +365,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { - return err + return currentState, externalResources, err } isAbortedSubtask := false @@ -371,18 +388,9 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube }) } - taskInfo := &core.TaskInfo{ - ExternalResources: externalResources, - } - phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) - err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) - if err != nil { - return err - } - if messageCollector.Length() > 0 { - return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) + return currentState, externalResources, fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) } - return nil + return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 327d87431a5..8d2696d02be 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -531,7 +531,7 @@ func TestCheckSubTasksState(t *testing.T) { // validate results assert.Nil(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail.String(), p.String()) + assert.Equal(t, arrayCore.PhaseAbortSubTasks.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount) for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { assert.Equal(t, core.PhasePermanentFailure, core.Phases[subtaskPhaseIndex]) @@ -539,7 +539,7 @@ func TestCheckSubTasksState(t *testing.T) { }) } -func TestTerminateSubtasks(t *testing.T) { +func TestTerminateSubTasksOnAbort(t *testing.T) { ctx := context.Background() subtaskCount := 3 config := Config{ @@ -554,9 +554,9 @@ func TestTerminateSubtasks(t *testing.T) { kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) - compactArray.SetItem(0, 8) // PhasePermanentFailure - compactArray.SetItem(1, 0) // PhaseUndefined - compactArray.SetItem(2, 5) // PhaseRunning + for i := 0; i < subtaskCount; i++ { + compactArray.SetItem(i, 5) + } currentState := &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -566,68 +566,123 @@ func TestTerminateSubtasks(t *testing.T) { ArrayStatus: arraystatus.ArrayStatus{ Detailed: compactArray, }, - IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), } - t.Run("TerminateSubtasks", func(t *testing.T) { - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + t.Run("SuccessfulTermination", func(t *testing.T) { eventRecorder := mocks.EventsRecorder{} eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) tCtx := getMockTaskExecutionContext(ctx, 0) - tCtx.OnResourceManager().Return(&resourceManager) tCtx.OnEventsRecorder().Return(&eventRecorder) - terminateCounter := 0 mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { - terminateCounter++ return nil } - err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - assert.Equal(t, 1, terminateCounter) - assert.Nil(t, err) - - args := eventRecorder.Calls[0].Arguments - phaseInfo, ok := args.Get(1).(core.PhaseInfo) - assert.True(t, ok) + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - externalResources := phaseInfo.Info().ExternalResources - assert.Len(t, externalResources, subtaskCount) - - assert.False(t, externalResources[0].IsAbortedSubtask) - assert.False(t, externalResources[1].IsAbortedSubtask) - assert.True(t, externalResources[2].IsAbortedSubtask) + assert.Nil(t, err) + eventRecorder.AssertCalled(t, "RecordRaw", mock.Anything, mock.Anything) }) - t.Run("TerminateSubtasksWithFailure", func(t *testing.T) { - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + t.Run("TerminationWithError", func(t *testing.T) { eventRecorder := mocks.EventsRecorder{} eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) tCtx := getMockTaskExecutionContext(ctx, 0) - tCtx.OnResourceManager().Return(&resourceManager) tCtx.OnEventsRecorder().Return(&eventRecorder) - terminateCounter := 0 mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { - terminateCounter++ - return fmt.Errorf("error") + return fmt.Errorf("termination error") } - err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + assert.NotNil(t, err) - assert.Equal(t, 1, terminateCounter) + eventRecorder.AssertNotCalled(t, "RecordRaw", mock.Anything, mock.Anything) + }) +} + +func TestTerminateSubTasks(t *testing.T) { + ctx := context.Background() + subtaskCount := 3 + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + tests := []struct { + name string + initialPhaseIndices []int + expectedAbortCount int + terminateError error + }{ + { + name: "AllSubTasksRunning", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: nil, + }, + { + name: "MixedSubTaskStates", + initialPhaseIndices: []int{8, 0, 5}, + expectedAbortCount: 1, + terminateError: nil, + }, + { + name: "TerminateFunctionFails", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: fmt.Errorf("error"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i, phaseIdx := range test.initialPhaseIndices { + compactArray.SetItem(i, bitarray.Item(phaseIdx)) + } + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + PhaseVersion: 0, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: compactArray, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), + } - args := eventRecorder.Calls[0].Arguments - phaseInfo, ok := args.Get(1).(core.PhaseInfo) - assert.True(t, ok) + tCtx := getMockTaskExecutionContext(ctx, 0) + terminateCounter := 0 + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + terminateCounter++ + return test.terminateError + } - externalResources := phaseInfo.Info().ExternalResources - assert.Len(t, externalResources, subtaskCount) + nextState, externalResources, err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - assert.False(t, externalResources[0].IsAbortedSubtask) - assert.False(t, externalResources[1].IsAbortedSubtask) - assert.False(t, externalResources[2].IsAbortedSubtask) - }) + assert.Equal(t, test.expectedAbortCount, terminateCounter) + + if test.terminateError != nil { + assert.NotNil(t, err) + return + } + + assert.Nil(t, err) + assert.Equal(t, uint32(1), nextState.PhaseVersion) + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, nextState.CurrentPhase) + assert.Len(t, externalResources, subtaskCount) + for _, externalResource := range externalResources { + assert.True(t, externalResource.IsAbortedSubtask || core.Phases[externalResource.Phase].IsTerminal() || core.Phases[externalResource.Phase] == core.PhaseUndefined) + } + }) + } }