From ab27020fe19fc6d1b6214aa7697ca0ba78063f35 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Wed, 29 Nov 2023 14:08:42 -0800 Subject: [PATCH 01/16] add cleanup on failure for array task permanent failures Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/pluginmachinery/core/phase.go | 8 ++++++++ flyteplugins/go/tasks/plugins/array/core/state.go | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 3f8c2a0914..f87be03154 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -257,10 +257,18 @@ func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info) } +func PhaseInfoSystemFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo { + return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info, true) +} + func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info) } +func PhaseInfoFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo { + return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info, true) +} + func PhaseInfoRetryableFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhaseRetryableFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info) } diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index f601ec375b..858e1f0cef 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -222,9 +222,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhasePermanentFailure: if state.GetExecutionErr() != nil { - phaseInfo = core.PhaseInfoFailed(core.PhasePermanentFailure, state.GetExecutionErr(), nowTaskInfo) + phaseInfo = core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), state.GetReason(), nowTaskInfo) } else { - phaseInfo = core.PhaseInfoSystemFailure(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) + phaseInfo = core.PhaseInfoSystemFailureWithCleanup(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) } default: return phaseInfo, fmt.Errorf("failed to map custom state phase to core phase. State Phase [%v]", p) From c3114b46cc130a53ffb07bc42198fc584fe84a24 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 30 Nov 2023 10:59:46 -0800 Subject: [PATCH 02/16] persist subtask aborts to admin Signed-off-by: Paul Dittamo --- .../manager/impl/task_execution_manager.go | 5 ++- .../go/tasks/pluginmachinery/core/phase.go | 3 ++ .../go/tasks/plugins/array/k8s/management.go | 32 ++++++++++++--- .../pkg/controller/nodes/task/handler.go | 39 +++++++++++++++++-- .../pkg/controller/nodes/task/transformer.go | 2 + 5 files changed, 70 insertions(+), 11 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index 82b872dbd4..13e2e71c6c 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -174,8 +174,9 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req } currentPhase := core.TaskExecution_Phase(core.TaskExecution_Phase_value[taskExecutionModel.Phase]) - if common.IsTaskExecutionTerminal(currentPhase) { - // Cannot update a terminal execution. + if common.IsTaskExecutionTerminal(currentPhase) && + (taskExecutionModel.Phase != request.Event.Phase.String() || taskExecutionModel.PhaseVersion >= request.Event.PhaseVersion) { + // Only update terminate execution if phase matches and it's a newer version curPhase := request.Event.Phase.String() errorMsg := fmt.Sprintf("invalid phase change from %v to %v for task execution %v", taskExecutionModel.Phase, request.Event.Phase, taskExecutionID) logger.Warnf(ctx, errorMsg) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index f87be03154..cf48cd2af6 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -36,6 +36,8 @@ const ( PhasePermanentFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache + // Inidicates subtasks are needing to be aborted + PhaseSubTasksAborted ) var Phases = []Phase{ @@ -49,6 +51,7 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, + PhaseSubTasksAborted, } // Returns true if the given phase is failure, retryable failure or success diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index fdf9d4a182..0ae10925c5 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -337,6 +337,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube } messageCollector := errorcollector.NewErrorMessageCollector() + externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] retryAttempt := uint64(0) @@ -344,23 +345,42 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube // we can use RetryAttempts if it has been initialized, otherwise stay with default 0 retryAttempt = currentState.RetryAttempts.GetItem(childIdx) } - - // return immediately if subtask has completed or not yet started - if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { - continue - } - originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { return err } + // return immediately if subtask has completed or not yet started + if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { + // still write subtask to buffer to persist to admin + externalResources = append(externalResources, &core.ExternalResource{ + ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), + Index: uint32(originalIdx), + RetryAttempt: uint32(retryAttempt), + Phase: existingPhase, + }) + continue + } + err = terminateFunction(ctx, stCtx, config, kubeClient) if err != nil { messageCollector.Collect(childIdx, err.Error()) } + + externalResources = append(externalResources, &core.ExternalResource{ + ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), + Index: uint32(originalIdx), + RetryAttempt: uint32(retryAttempt), + Phase: core.PhaseSubTasksAborted, + }) + } + + taskInfo := &core.TaskInfo{ + ExternalResources: externalResources, } + phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhaseSubTasksAborted.String(), "Array subtasks were aborted", taskInfo) + err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) if messageCollector.Length() > 0 { return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 64bf012b80..538b728fab 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -3,10 +3,11 @@ package task import ( "context" "fmt" + eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" + "github.com/golang/protobuf/ptypes" "runtime/debug" "time" - "github.com/golang/protobuf/ptypes" regErrors "github.com/pkg/errors" "k8s.io/client-go/kubernetes" @@ -19,7 +20,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" pluginK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" controllerConfig "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" @@ -749,6 +749,7 @@ func (t *Handler) ValidateOutput(ctx context.Context, nodeID v1alpha1.NodeID, i func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { taskNodeState := nCtx.NodeStateReader().GetTaskNodeState() currentPhase := taskNodeState.PluginPhase + currentPhaseVersion := taskNodeState.PluginPhaseVersion logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase) if currentPhase.IsTerminal() && !(currentPhase.IsFailure() && taskNodeState.CleanupOnFailure) { @@ -786,12 +787,44 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext logger.Errorf(ctx, "Abort failed when calling plugin abort.") return err } - taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() evRecorder := nCtx.EventsRecorder() + logger.Debugf(ctx, "Sending buffered Task events.") + for _, ev := range tCtx.ber.GetAll(ctx) { + evInfo, err := ToTaskExecutionEvent(ToTaskExecutionEventInputs{ + TaskExecContext: tCtx, + InputReader: nCtx.InputReader(), + EventConfig: t.eventConfig, + OutputWriter: tCtx.ow, + Info: ev.WithVersion(currentPhaseVersion + 1), + NodeExecutionMetadata: nCtx.NodeExecutionMetadata(), + ExecContext: nCtx.ExecutionContext(), + TaskType: ttype, + PluginID: p.GetID(), + ResourcePoolInfo: tCtx.rm.GetResourcePoolInfo(), + ClusterID: t.clusterID, + }) + if err != nil { + return err + } + if currentPhase.IsFailure() { + evInfo.Phase = core.TaskExecution_FAILED + } else { + evInfo.Phase = core.TaskExecution_ABORTED + } + if err := evRecorder.RecordTaskEvent(ctx, evInfo, t.eventConfig); err != nil { + logger.Errorf(ctx, "Event recording failed for Plugin [%s], eventPhase [%s], error :%s", p.GetID(), evInfo.Phase.String(), err.Error()) + // Check for idempotency + // Check for terminate state error + return err + } + } + + taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() nodeExecutionID, err := getParentNodeExecIDForTask(&taskExecID, nCtx.ExecutionContext()) if err != nil { return err } + // TODO handle this call failing if phase is set to Failure - probably doesn't matter if err := evRecorder.RecordTaskEvent(ctx, &event.TaskExecutionEvent{ TaskId: taskExecID.TaskId, ParentNodeExecutionId: nodeExecutionID, diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go index e026419dc0..daa9bb5d9c 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer.go @@ -44,6 +44,8 @@ func ToTaskEventPhase(p pluginCore.Phase) core.TaskExecution_Phase { return core.TaskExecution_FAILED case pluginCore.PhaseRetryableFailure: return core.TaskExecution_FAILED + case pluginCore.PhaseSubTasksAborted: + return core.TaskExecution_ABORTED case pluginCore.PhaseNotReady: fallthrough case pluginCore.PhaseUndefined: From 7c0449e74f67a85daab640bfa111753de9182f3f Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 30 Nov 2023 13:09:56 -0800 Subject: [PATCH 03/16] add isAbortedSubtask field to externalResources Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/pluginmachinery/core/phase.go | 5 ++--- flyteplugins/go/tasks/plugins/array/k8s/management.go | 11 ++++++----- .../pkg/controller/nodes/task/transformer.go | 8 +++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index cf48cd2af6..7e9ad08c8e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -36,8 +36,6 @@ const ( PhasePermanentFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache - // Inidicates subtasks are needing to be aborted - PhaseSubTasksAborted ) var Phases = []Phase{ @@ -51,7 +49,6 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, - PhaseSubTasksAborted, } // Returns true if the given phase is failure, retryable failure or success @@ -85,6 +82,8 @@ type ExternalResource struct { RetryAttempt uint32 // Phase (if exists) associated with the external resource Phase Phase + // Indicates if external resource is a subtask getting aborted + IsAbortedSubtask bool } type ReasonInfo struct { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 0ae10925c5..2ea5fe45d7 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -369,17 +369,18 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube } externalResources = append(externalResources, &core.ExternalResource{ - ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), - Index: uint32(originalIdx), - RetryAttempt: uint32(retryAttempt), - Phase: core.PhaseSubTasksAborted, + ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), + Index: uint32(originalIdx), + RetryAttempt: uint32(retryAttempt), + Phase: existingPhase, + IsAbortedSubtask: true, }) } taskInfo := &core.TaskInfo{ ExternalResources: externalResources, } - phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhaseSubTasksAborted.String(), "Array subtasks were aborted", taskInfo) + phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) if messageCollector.Length() > 0 { diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go index daa9bb5d9c..135c3aa004 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer.go @@ -44,8 +44,6 @@ func ToTaskEventPhase(p pluginCore.Phase) core.TaskExecution_Phase { return core.TaskExecution_FAILED case pluginCore.PhaseRetryableFailure: return core.TaskExecution_FAILED - case pluginCore.PhaseSubTasksAborted: - return core.TaskExecution_ABORTED case pluginCore.PhaseNotReady: fallthrough case pluginCore.PhaseUndefined: @@ -119,13 +117,17 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio metadata.ExternalResources = make([]*event.ExternalResourceInfo, len(externalResources)) for idx, e := range input.Info.Info().ExternalResources { + phase := ToTaskEventPhase(e.Phase) + if e.IsAbortedSubtask { + phase = core.TaskExecution_ABORTED + } metadata.ExternalResources[idx] = &event.ExternalResourceInfo{ ExternalId: e.ExternalID, CacheStatus: e.CacheStatus, Index: e.Index, Logs: e.Logs, RetryAttempt: e.RetryAttempt, - Phase: ToTaskEventPhase(e.Phase), + Phase: phase, } } } From 24b7a596c937cfe5caee4e17495ae0130904c9fe Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 30 Nov 2023 13:45:52 -0800 Subject: [PATCH 04/16] lint Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/plugins/array/k8s/management.go | 3 +++ flytepropeller/pkg/controller/nodes/task/handler.go | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 2ea5fe45d7..d1ef18b2cf 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -382,6 +382,9 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube } phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) + if err != nil { + return err + } if messageCollector.Length() > 0 { return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 538b728fab..38e7ad03bb 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -3,11 +3,10 @@ package task import ( "context" "fmt" - eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" - "github.com/golang/protobuf/ptypes" "runtime/debug" "time" + "github.com/golang/protobuf/ptypes" regErrors "github.com/pkg/errors" "k8s.io/client-go/kubernetes" @@ -20,6 +19,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" pluginK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" controllerConfig "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" From dd6dd5b4c846a4102cfd79d6adbf3c7a5d05b881 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Sun, 3 Dec 2023 21:36:32 -0800 Subject: [PATCH 05/16] add tests for terminal updates Signed-off-by: Paul Dittamo --- .../impl/task_execution_manager_test.go | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 8fd8019647..28ea1a4ae7 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -425,18 +425,49 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { Phase: core.TaskExecution_SUCCEEDED.String(), }, nil }) + + // request w/ non terminal phase taskEventRequest.Event.Phase = core.TaskExecution_RUNNING taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) assert.Equal(t, adminError.Code(), codes.FailedPrecondition) details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) assert.True(t, ok) _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) assert.True(t, ok) + + // request w/ different terminal phase + taskEventRequest.Event.Phase = core.TaskExecution_FAILED + taskEventRequest.Event.PhaseVersion = uint32(0) + taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + assert.Nil(t, resp) + adminError = err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok = adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) + + // request w/ same terminal phase, not a later version + taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED + taskEventRequest.Event.PhaseVersion = uint32(0) + taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + assert.Nil(t, resp) + adminError = err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok = adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) + + // request w/ same terminal phase, later version + taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED + taskEventRequest.Event.PhaseVersion = uint32(1) + taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) + _, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.Nil(t, err) } func TestCreateTaskEvent_PhaseVersionChange(t *testing.T) { From bff847893e84f79412664607849f30c519eb9f60 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Mon, 4 Dec 2023 14:36:13 -0800 Subject: [PATCH 06/16] add unit tests for terminatesubtasks Signed-off-by: Paul Dittamo --- .../go/tasks/plugins/array/k8s/management.go | 26 ++---- .../plugins/array/k8s/management_test.go | 93 +++++++++++++++++++ .../pkg/controller/nodes/task/handler.go | 1 + 3 files changed, 104 insertions(+), 16 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index d1ef18b2cf..2c805a1f05 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -351,21 +351,15 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube return err } - // return immediately if subtask has completed or not yet started - if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { - // still write subtask to buffer to persist to admin - externalResources = append(externalResources, &core.ExternalResource{ - ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), - Index: uint32(originalIdx), - RetryAttempt: uint32(retryAttempt), - Phase: existingPhase, - }) - continue - } - - err = terminateFunction(ctx, stCtx, config, kubeClient) - if err != nil { - messageCollector.Collect(childIdx, err.Error()) + isAbortedSubtask := false + if !existingPhase.IsTerminal() && existingPhase != core.PhaseUndefined { + // only terminate subtask if it has completed or has not yet started + err = terminateFunction(ctx, stCtx, config, kubeClient) + if err != nil { + messageCollector.Collect(childIdx, err.Error()) + } else { + isAbortedSubtask = true + } } externalResources = append(externalResources, &core.ExternalResource{ @@ -373,7 +367,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube Index: uint32(originalIdx), RetryAttempt: uint32(retryAttempt), Phase: existingPhase, - IsAbortedSubtask: true, + IsAbortedSubtask: isAbortedSubtask, }) } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 5bead4954d..327d87431a 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -538,3 +538,96 @@ func TestCheckSubTasksState(t *testing.T) { } }) } + +func TestTerminateSubtasks(t *testing.T) { + ctx := context.Background() + subtaskCount := 3 + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + compactArray.SetItem(0, 8) // PhasePermanentFailure + compactArray.SetItem(1, 0) // PhaseUndefined + compactArray.SetItem(2, 5) // PhaseRunning + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: compactArray, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + } + + t.Run("TerminateSubtasks", func(t *testing.T) { + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + eventRecorder := mocks.EventsRecorder{} + eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + tCtx.OnEventsRecorder().Return(&eventRecorder) + + terminateCounter := 0 + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + terminateCounter++ + return nil + } + + err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + assert.Equal(t, 1, terminateCounter) + assert.Nil(t, err) + + args := eventRecorder.Calls[0].Arguments + phaseInfo, ok := args.Get(1).(core.PhaseInfo) + assert.True(t, ok) + + externalResources := phaseInfo.Info().ExternalResources + assert.Len(t, externalResources, subtaskCount) + + assert.False(t, externalResources[0].IsAbortedSubtask) + assert.False(t, externalResources[1].IsAbortedSubtask) + assert.True(t, externalResources[2].IsAbortedSubtask) + }) + + t.Run("TerminateSubtasksWithFailure", func(t *testing.T) { + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + eventRecorder := mocks.EventsRecorder{} + eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + tCtx.OnEventsRecorder().Return(&eventRecorder) + + terminateCounter := 0 + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + terminateCounter++ + return fmt.Errorf("error") + } + + err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + assert.NotNil(t, err) + assert.Equal(t, 1, terminateCounter) + + args := eventRecorder.Calls[0].Arguments + phaseInfo, ok := args.Get(1).(core.PhaseInfo) + assert.True(t, ok) + + externalResources := phaseInfo.Info().ExternalResources + assert.Len(t, externalResources, subtaskCount) + + assert.False(t, externalResources[0].IsAbortedSubtask) + assert.False(t, externalResources[1].IsAbortedSubtask) + assert.False(t, externalResources[2].IsAbortedSubtask) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 38e7ad03bb..8d88f092bc 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -787,6 +787,7 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext logger.Errorf(ctx, "Abort failed when calling plugin abort.") return err } + evRecorder := nCtx.EventsRecorder() logger.Debugf(ctx, "Sending buffered Task events.") for _, ev := range tCtx.ber.GetAll(ctx) { From d7897f1d07f34b86a1b073b7cba07e98c8c7748f Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Mon, 4 Dec 2023 15:03:08 -0800 Subject: [PATCH 07/16] lint + update unit test Signed-off-by: Paul Dittamo --- .../pkg/manager/impl/task_execution_manager_test.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 28ea1a4ae7..74b286b968 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -442,6 +442,7 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { taskEventRequest.Event.Phase = core.TaskExecution_FAILED taskEventRequest.Event.PhaseVersion = uint32(0) taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) assert.Nil(t, resp) adminError = err.(flyteAdminErrors.FlyteAdminError) assert.Equal(t, adminError.Code(), codes.FailedPrecondition) @@ -454,13 +455,10 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED taskEventRequest.Event.PhaseVersion = uint32(0) taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) assert.Nil(t, resp) adminError = err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok = adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) + assert.Equal(t, adminError.Code(), codes.AlreadyExists) // request w/ same terminal phase, later version taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED From 6dc472c2984b0b8e49fe7228f017648d6322e6e2 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Mon, 4 Dec 2023 17:09:08 -0800 Subject: [PATCH 08/16] delete comment Signed-off-by: Paul Dittamo --- flytepropeller/pkg/controller/nodes/task/handler.go | 1 - 1 file changed, 1 deletion(-) diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 8d88f092bc..eb760e9d5b 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -825,7 +825,6 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext if err != nil { return err } - // TODO handle this call failing if phase is set to Failure - probably doesn't matter if err := evRecorder.RecordTaskEvent(ctx, &event.TaskExecutionEvent{ TaskId: taskExecID.TaskId, ParentNodeExecutionId: nodeExecutionID, From 34c422ea3a3248b0a5146f888a56a3da384cd4d9 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Wed, 6 Dec 2023 13:45:03 -0800 Subject: [PATCH 09/16] clean up unit test Signed-off-by: Paul Dittamo --- .../impl/task_execution_manager_test.go | 84 ++++++++++--------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 74b286b968..7e256541e4 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -426,46 +426,52 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { }, nil }) - // request w/ non terminal phase - taskEventRequest.Event.Phase = core.TaskExecution_RUNNING - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - - // request w/ different terminal phase - taskEventRequest.Event.Phase = core.TaskExecution_FAILED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError = err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok = adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - - // request w/ same terminal phase, not a later version - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError = err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.AlreadyExists) + t.Run("CreateExecutionEvent_NonTerminalPhase", func(t *testing.T) { + taskEventRequest.Event.Phase = core.TaskExecution_RUNNING + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.Nil(t, resp) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) + }) - // request w/ same terminal phase, later version - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(1) - taskExecManager = NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) - _, err = taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, err) + t.Run("CreateExecutionEvent_DifferentTerminalPhase", func(t *testing.T) { + taskEventRequest.Event.Phase = core.TaskExecution_FAILED + taskEventRequest.Event.PhaseVersion = uint32(0) + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.Nil(t, resp) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) + }) + + t.Run("CreateExecutionEvent_SameTerminalPhase_OldVersion", func(t *testing.T) { + // request w/ same terminal phase, not a later version + taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED + taskEventRequest.Event.PhaseVersion = uint32(0) + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.Nil(t, resp) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.AlreadyExists) + }) + + t.Run("CreateExecutionEvent_SameTerminalPhase_NewVersion", func(t *testing.T) { + // request w/ same terminal phase, later version + taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED + taskEventRequest.Event.PhaseVersion = uint32(1) + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) + _, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) + assert.Nil(t, err) + }) } func TestCreateTaskEvent_PhaseVersionChange(t *testing.T) { From ede61f92d7de4cb650f4e5d4e9c4a002a4f68f0f Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Wed, 6 Dec 2023 13:47:31 -0800 Subject: [PATCH 10/16] delete comments Signed-off-by: Paul Dittamo --- flyteadmin/pkg/manager/impl/task_execution_manager_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 7e256541e4..9a6dac838a 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -454,7 +454,6 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { }) t.Run("CreateExecutionEvent_SameTerminalPhase_OldVersion", func(t *testing.T) { - // request w/ same terminal phase, not a later version taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED taskEventRequest.Event.PhaseVersion = uint32(0) taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) @@ -465,7 +464,6 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { }) t.Run("CreateExecutionEvent_SameTerminalPhase_NewVersion", func(t *testing.T) { - // request w/ same terminal phase, later version taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED taskEventRequest.Event.PhaseVersion = uint32(1) taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) From 48cb8c6c76837c6c53728b2ff2b2e362422df5f9 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Fri, 8 Dec 2023 01:16:44 -0800 Subject: [PATCH 11/16] set correct error message Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/plugins/array/core/state.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index 858e1f0cef..c03f4a1384 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -222,7 +222,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhasePermanentFailure: if state.GetExecutionErr() != nil { - phaseInfo = core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), state.GetReason(), nowTaskInfo) + phaseInfo = core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), state.GetExecutionErr().Message, nowTaskInfo) } else { phaseInfo = core.PhaseInfoSystemFailureWithCleanup(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) } From fadec719a2b42635a0622b8bcc93a0583fac10f1 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Sat, 9 Dec 2023 18:21:13 -0800 Subject: [PATCH 12/16] add new array plugin abort phase Signed-off-by: Paul Dittamo --- .../manager/impl/task_execution_manager.go | 5 +- .../impl/task_execution_manager_test.go | 53 ++----- .../go/tasks/plugins/array/core/state.go | 8 +- .../go/tasks/plugins/array/core/state_test.go | 6 +- .../go/tasks/plugins/array/k8s/executor.go | 7 +- .../go/tasks/plugins/array/k8s/management.go | 40 +++-- .../plugins/array/k8s/management_test.go | 141 ++++++++++++------ 7 files changed, 148 insertions(+), 112 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index 13e2e71c6c..82b872dbd4 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -174,9 +174,8 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req } currentPhase := core.TaskExecution_Phase(core.TaskExecution_Phase_value[taskExecutionModel.Phase]) - if common.IsTaskExecutionTerminal(currentPhase) && - (taskExecutionModel.Phase != request.Event.Phase.String() || taskExecutionModel.PhaseVersion >= request.Event.PhaseVersion) { - // Only update terminate execution if phase matches and it's a newer version + if common.IsTaskExecutionTerminal(currentPhase) { + // Cannot update a terminal execution. curPhase := request.Event.Phase.String() errorMsg := fmt.Sprintf("invalid phase change from %v to %v for task execution %v", taskExecutionModel.Phase, request.Event.Phase, taskExecutionID) logger.Warnf(ctx, errorMsg) diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index 9a6dac838a..8fd8019647 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -425,51 +425,18 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) { Phase: core.TaskExecution_SUCCEEDED.String(), }, nil }) + taskEventRequest.Event.Phase = core.TaskExecution_RUNNING + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - t.Run("CreateExecutionEvent_NonTerminalPhase", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_RUNNING - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - }) - - t.Run("CreateExecutionEvent_DifferentTerminalPhase", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_FAILED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.FailedPrecondition) - details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) - assert.True(t, ok) - _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) - assert.True(t, ok) - }) - - t.Run("CreateExecutionEvent_SameTerminalPhase_OldVersion", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(0) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) - resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, resp) - adminError := err.(flyteAdminErrors.FlyteAdminError) - assert.Equal(t, adminError.Code(), codes.AlreadyExists) - }) + assert.Nil(t, resp) - t.Run("CreateExecutionEvent_SameTerminalPhase_NewVersion", func(t *testing.T) { - taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED - taskEventRequest.Event.PhaseVersion = uint32(1) - taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher) - _, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest) - assert.Nil(t, err) - }) + adminError := err.(flyteAdminErrors.FlyteAdminError) + assert.Equal(t, adminError.Code(), codes.FailedPrecondition) + details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason) + assert.True(t, ok) + _, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState) + assert.True(t, ok) } func TestCreateTaskEvent_PhaseVersionChange(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index c03f4a1384..4dac008006 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -35,6 +35,7 @@ const ( PhaseAssembleFinalError PhaseRetryableFailure PhasePermanentFailure + PhaseAbortSubTasks ) type State struct { @@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhaseAssembleFinalError: fallthrough + case PhaseAbortSubTasks: + fallthrough + case PhaseWriteToDiscoveryThenFail: fallthrough @@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus if totalCount < minSuccesses { logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } // No chance to reach the required success numbers. if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses { logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]", minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } if totalWaitingForResources > 0 { diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go index 01b5b41528..969c98df20 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state_test.go +++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go @@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) { }{ { "FailOnTooFewTasks", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{}, }, { @@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailOnTooManyPermanentFailures", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhasePermanentFailure: 1, core.PhaseSuccess: 9, @@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailedToRetry", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhaseSuccess: 5, core.PhasePermanentFailure: 5, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 3e324391b8..3407c45905 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) + case arrayCore.PhaseAbortSubTasks: + nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + case arrayCore.PhaseAssembleFinalOutput: nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState) @@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) } func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { @@ -165,7 +168,7 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) } func (e Executor) Start(ctx context.Context) error { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 2c805a1f05..3d5eb2ef60 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -324,20 +324,37 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return newState, externalResources, nil } +func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + + _, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, GetConfig(), terminateFunction, currentState) + if err != nil { + return err + } + + taskInfo := &core.TaskInfo{ + ExternalResources: externalResources, + } + phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) + err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) + + return err +} + // TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include // aborting and finalizing active k8s resources. func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, - terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) + externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) if err != nil { - return err + return currentState, externalResources, err } else if taskTemplate == nil { - return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") } messageCollector := errorcollector.NewErrorMessageCollector() - externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] retryAttempt := uint64(0) @@ -348,7 +365,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { - return err + return currentState, externalResources, err } isAbortedSubtask := false @@ -371,18 +388,9 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube }) } - taskInfo := &core.TaskInfo{ - ExternalResources: externalResources, - } - phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) - err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) - if err != nil { - return err - } - if messageCollector.Length() > 0 { - return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) + return currentState, externalResources, fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) } - return nil + return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 327d87431a..8d2696d02b 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -531,7 +531,7 @@ func TestCheckSubTasksState(t *testing.T) { // validate results assert.Nil(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail.String(), p.String()) + assert.Equal(t, arrayCore.PhaseAbortSubTasks.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount) for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { assert.Equal(t, core.PhasePermanentFailure, core.Phases[subtaskPhaseIndex]) @@ -539,7 +539,7 @@ func TestCheckSubTasksState(t *testing.T) { }) } -func TestTerminateSubtasks(t *testing.T) { +func TestTerminateSubTasksOnAbort(t *testing.T) { ctx := context.Background() subtaskCount := 3 config := Config{ @@ -554,9 +554,9 @@ func TestTerminateSubtasks(t *testing.T) { kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) - compactArray.SetItem(0, 8) // PhasePermanentFailure - compactArray.SetItem(1, 0) // PhaseUndefined - compactArray.SetItem(2, 5) // PhaseRunning + for i := 0; i < subtaskCount; i++ { + compactArray.SetItem(i, 5) + } currentState := &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -566,68 +566,123 @@ func TestTerminateSubtasks(t *testing.T) { ArrayStatus: arraystatus.ArrayStatus{ Detailed: compactArray, }, - IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), } - t.Run("TerminateSubtasks", func(t *testing.T) { - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + t.Run("SuccessfulTermination", func(t *testing.T) { eventRecorder := mocks.EventsRecorder{} eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) tCtx := getMockTaskExecutionContext(ctx, 0) - tCtx.OnResourceManager().Return(&resourceManager) tCtx.OnEventsRecorder().Return(&eventRecorder) - terminateCounter := 0 mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { - terminateCounter++ return nil } - err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - assert.Equal(t, 1, terminateCounter) - assert.Nil(t, err) - - args := eventRecorder.Calls[0].Arguments - phaseInfo, ok := args.Get(1).(core.PhaseInfo) - assert.True(t, ok) + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - externalResources := phaseInfo.Info().ExternalResources - assert.Len(t, externalResources, subtaskCount) - - assert.False(t, externalResources[0].IsAbortedSubtask) - assert.False(t, externalResources[1].IsAbortedSubtask) - assert.True(t, externalResources[2].IsAbortedSubtask) + assert.Nil(t, err) + eventRecorder.AssertCalled(t, "RecordRaw", mock.Anything, mock.Anything) }) - t.Run("TerminateSubtasksWithFailure", func(t *testing.T) { - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + t.Run("TerminationWithError", func(t *testing.T) { eventRecorder := mocks.EventsRecorder{} eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) tCtx := getMockTaskExecutionContext(ctx, 0) - tCtx.OnResourceManager().Return(&resourceManager) tCtx.OnEventsRecorder().Return(&eventRecorder) - terminateCounter := 0 mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { - terminateCounter++ - return fmt.Errorf("error") + return fmt.Errorf("termination error") } - err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + assert.NotNil(t, err) - assert.Equal(t, 1, terminateCounter) + eventRecorder.AssertNotCalled(t, "RecordRaw", mock.Anything, mock.Anything) + }) +} + +func TestTerminateSubTasks(t *testing.T) { + ctx := context.Background() + subtaskCount := 3 + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + tests := []struct { + name string + initialPhaseIndices []int + expectedAbortCount int + terminateError error + }{ + { + name: "AllSubTasksRunning", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: nil, + }, + { + name: "MixedSubTaskStates", + initialPhaseIndices: []int{8, 0, 5}, + expectedAbortCount: 1, + terminateError: nil, + }, + { + name: "TerminateFunctionFails", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: fmt.Errorf("error"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i, phaseIdx := range test.initialPhaseIndices { + compactArray.SetItem(i, bitarray.Item(phaseIdx)) + } + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + PhaseVersion: 0, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: compactArray, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), + } - args := eventRecorder.Calls[0].Arguments - phaseInfo, ok := args.Get(1).(core.PhaseInfo) - assert.True(t, ok) + tCtx := getMockTaskExecutionContext(ctx, 0) + terminateCounter := 0 + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + terminateCounter++ + return test.terminateError + } - externalResources := phaseInfo.Info().ExternalResources - assert.Len(t, externalResources, subtaskCount) + nextState, externalResources, err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) - assert.False(t, externalResources[0].IsAbortedSubtask) - assert.False(t, externalResources[1].IsAbortedSubtask) - assert.False(t, externalResources[2].IsAbortedSubtask) - }) + assert.Equal(t, test.expectedAbortCount, terminateCounter) + + if test.terminateError != nil { + assert.NotNil(t, err) + return + } + + assert.Nil(t, err) + assert.Equal(t, uint32(1), nextState.PhaseVersion) + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, nextState.CurrentPhase) + assert.Len(t, externalResources, subtaskCount) + for _, externalResource := range externalResources { + assert.True(t, externalResource.IsAbortedSubtask || core.Phases[externalResource.Phase].IsTerminal() || core.Phases[externalResource.Phase] == core.PhaseUndefined) + } + }) + } } From ccd31ce8522325761d6dff08c540f9666e949787 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Sat, 9 Dec 2023 18:37:31 -0800 Subject: [PATCH 13/16] update awsbatch array unit test Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go index d40edd41de..9de13c6413 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -248,7 +248,7 @@ func TestCheckSubTasksState(t *testing.T) { newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ - CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail, + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 2, OriginalArraySize: 2, OriginalMinSuccesses: 2, @@ -264,6 +264,6 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p) + assert.Equal(t, arrayCore.PhaseAbortSubTasks, p) }) } From 08056f3d99c58ca233f4b42740115189e54cc1ce Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Thu, 14 Dec 2023 11:29:31 -0600 Subject: [PATCH 14/16] using Abort phase and cleaning up best-effort (#4600) * using Abort phase and cleaning up best-effort Signed-off-by: Daniel Rammer * removed phase updates on Finalize Signed-off-by: Daniel Rammer --------- Signed-off-by: Daniel Rammer --- .../go/tasks/pluginmachinery/core/phase.go | 11 +++-- .../go/tasks/plugins/array/core/state.go | 4 +- .../go/tasks/plugins/array/k8s/executor.go | 3 +- .../go/tasks/plugins/array/k8s/management.go | 42 ++++++++++--------- .../pkg/controller/nodes/task/transformer.go | 5 +-- 5 files changed, 36 insertions(+), 29 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 7e9ad08c8e..7847b161fc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -36,6 +36,8 @@ const ( PhasePermanentFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache + // Indicate the task has been aborted + PhaseAborted ) var Phases = []Phase{ @@ -49,11 +51,12 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, + PhaseAborted, } // Returns true if the given phase is failure, retryable failure or success func (p Phase) IsTerminal() bool { - return p.IsFailure() || p.IsSuccess() + return p.IsFailure() || p.IsSuccess() || p.IsAborted() } func (p Phase) IsFailure() bool { @@ -64,6 +67,10 @@ func (p Phase) IsSuccess() bool { return p == PhaseSuccess } +func (p Phase) IsAborted() bool { + return p == PhaseAborted +} + func (p Phase) IsWaitingForResources() bool { return p == PhaseWaitingForResources } @@ -82,8 +89,6 @@ type ExternalResource struct { RetryAttempt uint32 // Phase (if exists) associated with the external resource Phase Phase - // Indicates if external resource is a subtask getting aborted - IsAbortedSubtask bool } type ReasonInfo struct { diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index 4dac008006..a540359b0a 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -226,9 +226,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhasePermanentFailure: if state.GetExecutionErr() != nil { - phaseInfo = core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), state.GetExecutionErr().Message, nowTaskInfo) + phaseInfo = core.PhaseInfoFailed(core.PhasePermanentFailure, state.GetExecutionErr(), nowTaskInfo) } else { - phaseInfo = core.PhaseInfoSystemFailureWithCleanup(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) + phaseInfo = core.PhaseInfoSystemFailure(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) } default: return phaseInfo, fmt.Errorf("failed to map custom state phase to core phase. State Phase [%v]", p) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 3407c45905..f664392fd9 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -168,7 +168,8 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + _, _, err := TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + return err } func (e Executor) Start(ctx context.Context) error { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 3d5eb2ef60..510f202e1a 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -335,10 +335,13 @@ func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContex taskInfo := &core.TaskInfo{ ExternalResources: externalResources, } - phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo) - err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) + executionErr := &idlCore.ExecutionError{ + Code: "ArraySubtasksAborted", + Message: "Array subtasks were aborted", + } + phaseInfo := core.PhaseInfoFailed(core.PhaseAborted, executionErr, taskInfo) - return err + return tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) } // TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include @@ -362,30 +365,29 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube // we can use RetryAttempts if it has been initialized, otherwise stay with default 0 retryAttempt = currentState.RetryAttempts.GetItem(childIdx) } + + // return immediately if subtask has completed or not yet started + if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { + continue + } + originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { return currentState, externalResources, err } - isAbortedSubtask := false - if !existingPhase.IsTerminal() && existingPhase != core.PhaseUndefined { - // only terminate subtask if it has completed or has not yet started - err = terminateFunction(ctx, stCtx, config, kubeClient) - if err != nil { - messageCollector.Collect(childIdx, err.Error()) - } else { - isAbortedSubtask = true - } + err = terminateFunction(ctx, stCtx, config, kubeClient) + if err != nil { + messageCollector.Collect(childIdx, err.Error()) + } else { + externalResources = append(externalResources, &core.ExternalResource{ + ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), + Index: uint32(originalIdx), + RetryAttempt: uint32(retryAttempt), + Phase: core.PhaseAborted, + }) } - - externalResources = append(externalResources, &core.ExternalResource{ - ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), - Index: uint32(originalIdx), - RetryAttempt: uint32(retryAttempt), - Phase: existingPhase, - IsAbortedSubtask: isAbortedSubtask, - }) } if messageCollector.Length() > 0 { diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go index 135c3aa004..242c1334ce 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer.go @@ -44,6 +44,8 @@ func ToTaskEventPhase(p pluginCore.Phase) core.TaskExecution_Phase { return core.TaskExecution_FAILED case pluginCore.PhaseRetryableFailure: return core.TaskExecution_FAILED + case pluginCore.PhaseAborted: + return core.TaskExecution_ABORTED case pluginCore.PhaseNotReady: fallthrough case pluginCore.PhaseUndefined: @@ -118,9 +120,6 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio metadata.ExternalResources = make([]*event.ExternalResourceInfo, len(externalResources)) for idx, e := range input.Info.Info().ExternalResources { phase := ToTaskEventPhase(e.Phase) - if e.IsAbortedSubtask { - phase = core.TaskExecution_ABORTED - } metadata.ExternalResources[idx] = &event.ExternalResourceInfo{ ExternalId: e.ExternalID, CacheStatus: e.CacheStatus, From cc4c99b68c029772ab7735c6913accf1c00adb76 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 14 Dec 2023 09:36:16 -0800 Subject: [PATCH 15/16] update unit test Signed-off-by: Paul Dittamo --- flyteplugins/go/tasks/plugins/array/k8s/management_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 8d2696d02b..ab26028b09 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -679,9 +679,11 @@ func TestTerminateSubTasks(t *testing.T) { assert.Nil(t, err) assert.Equal(t, uint32(1), nextState.PhaseVersion) assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, nextState.CurrentPhase) - assert.Len(t, externalResources, subtaskCount) + assert.Len(t, externalResources, terminateCounter) + for _, externalResource := range externalResources { - assert.True(t, externalResource.IsAbortedSubtask || core.Phases[externalResource.Phase].IsTerminal() || core.Phases[externalResource.Phase] == core.PhaseUndefined) + phase := core.Phases[externalResource.Phase] + assert.True(t, phase.IsAborted()) } }) } From b34f6d35348841e261a24a929c509f5c5eeac92e Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Fri, 15 Dec 2023 08:58:36 -0600 Subject: [PATCH 16/16] handling PhaseAbortSubTasks in awsbatch plugin Signed-off-by: Daniel Rammer --- flyteplugins/go/tasks/plugins/array/awsbatch/executor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go index 99688fe36d..6c7a858be8 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go @@ -78,6 +78,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c case arrayCore.PhaseAssembleFinalOutput: pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State) + + case arrayCore.PhaseAbortSubTasks: + fallthrough case arrayCore.PhaseWriteToDiscoveryThenFail: pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version+1)