Skip to content

Commit

Permalink
add new array plugin abort phase
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt committed Dec 10, 2023
1 parent 7060198 commit 6cae3f8
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 112 deletions.
5 changes: 2 additions & 3 deletions flyteadmin/pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,8 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req
}

currentPhase := core.TaskExecution_Phase(core.TaskExecution_Phase_value[taskExecutionModel.Phase])
if common.IsTaskExecutionTerminal(currentPhase) &&
(taskExecutionModel.Phase != request.Event.Phase.String() || taskExecutionModel.PhaseVersion >= request.Event.PhaseVersion) {
// Only update terminate execution if phase matches and it's a newer version
if common.IsTaskExecutionTerminal(currentPhase) {
// Cannot update a terminal execution.
curPhase := request.Event.Phase.String()
errorMsg := fmt.Sprintf("invalid phase change from %v to %v for task execution %v", taskExecutionModel.Phase, request.Event.Phase, taskExecutionID)
logger.Warnf(ctx, errorMsg)
Expand Down
53 changes: 10 additions & 43 deletions flyteadmin/pkg/manager/impl/task_execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,51 +425,18 @@ func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) {
Phase: core.TaskExecution_SUCCEEDED.String(),
}, nil
})
taskEventRequest.Event.Phase = core.TaskExecution_RUNNING
taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil)
resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest)

t.Run("CreateExecutionEvent_NonTerminalPhase", func(t *testing.T) {
taskEventRequest.Event.Phase = core.TaskExecution_RUNNING
taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil)
resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest)
assert.Nil(t, resp)
adminError := err.(flyteAdminErrors.FlyteAdminError)
assert.Equal(t, adminError.Code(), codes.FailedPrecondition)
details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason)
assert.True(t, ok)
_, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState)
assert.True(t, ok)
})

t.Run("CreateExecutionEvent_DifferentTerminalPhase", func(t *testing.T) {
taskEventRequest.Event.Phase = core.TaskExecution_FAILED
taskEventRequest.Event.PhaseVersion = uint32(0)
taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil)
resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest)
assert.Nil(t, resp)
adminError := err.(flyteAdminErrors.FlyteAdminError)
assert.Equal(t, adminError.Code(), codes.FailedPrecondition)
details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason)
assert.True(t, ok)
_, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState)
assert.True(t, ok)
})

t.Run("CreateExecutionEvent_SameTerminalPhase_OldVersion", func(t *testing.T) {
taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED
taskEventRequest.Event.PhaseVersion = uint32(0)
taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil)
resp, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest)
assert.Nil(t, resp)
adminError := err.(flyteAdminErrors.FlyteAdminError)
assert.Equal(t, adminError.Code(), codes.AlreadyExists)
})
assert.Nil(t, resp)

t.Run("CreateExecutionEvent_SameTerminalPhase_NewVersion", func(t *testing.T) {
taskEventRequest.Event.Phase = core.TaskExecution_SUCCEEDED
taskEventRequest.Event.PhaseVersion = uint32(1)
taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, &mockPublisher, &mockPublisher)
_, err := taskExecManager.CreateTaskExecutionEvent(context.Background(), taskEventRequest)
assert.Nil(t, err)
})
adminError := err.(flyteAdminErrors.FlyteAdminError)
assert.Equal(t, adminError.Code(), codes.FailedPrecondition)
details, ok := adminError.GRPCStatus().Details()[0].(*admin.EventFailureReason)
assert.True(t, ok)
_, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState)
assert.True(t, ok)
}

func TestCreateTaskEvent_PhaseVersionChange(t *testing.T) {
Expand Down
8 changes: 6 additions & 2 deletions flyteplugins/go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
PhaseAssembleFinalError
PhaseRetryableFailure
PhasePermanentFailure
PhaseAbortSubTasks
)

type State struct {
Expand Down Expand Up @@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
case PhaseAssembleFinalError:
fallthrough

case PhaseAbortSubTasks:
fallthrough

case PhaseWriteToDiscoveryThenFail:
fallthrough

Expand Down Expand Up @@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus

if totalCount < minSuccesses {
logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses)
return PhaseWriteToDiscoveryThenFail
return PhaseAbortSubTasks
}

// No chance to reach the required success numbers.
if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses {
logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]",
minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures)
return PhaseWriteToDiscoveryThenFail
return PhaseAbortSubTasks
}

if totalWaitingForResources > 0 {
Expand Down
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) {
}{
{
"FailOnTooFewTasks",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{},
},
{
Expand All @@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) {
},
{
"FailOnTooManyPermanentFailures",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{
core.PhasePermanentFailure: 1,
core.PhaseSuccess: 9,
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) {
},
{
"FailedToRetry",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhasePermanentFailure: 5,
Expand Down
7 changes: 5 additions & 2 deletions flyteplugins/go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig,
tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState)

case arrayCore.PhaseAbortSubTasks:
nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)

case arrayCore.PhaseAssembleFinalOutput:
nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState)

Expand Down Expand Up @@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
}

func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
Expand All @@ -165,7 +168,7 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext)
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
}

func (e Executor) Start(ctx context.Context) error {
Expand Down
40 changes: 24 additions & 16 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,20 +324,37 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return newState, externalResources, nil
}

func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {

_, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, GetConfig(), terminateFunction, currentState)
if err != nil {
return err
}

taskInfo := &core.TaskInfo{
ExternalResources: externalResources,
}
phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo)
err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo)

return err
}

// TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include
// aborting and finalizing active k8s resources.
func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {
terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) {

taskTemplate, err := tCtx.TaskReader().Read(ctx)
externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems()))
if err != nil {
return err
return currentState, externalResources, err
} else if taskTemplate == nil {
return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}

messageCollector := errorcollector.NewErrorMessageCollector()
externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems()))
for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() {
existingPhase := core.Phases[existingPhaseIdx]
retryAttempt := uint64(0)
Expand All @@ -348,7 +365,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
if err != nil {
return err
return currentState, externalResources, err
}

isAbortedSubtask := false
Expand All @@ -371,18 +388,9 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
})
}

taskInfo := &core.TaskInfo{
ExternalResources: externalResources,
}
phaseInfo := core.PhaseInfoFailureWithCleanup(core.PhasePermanentFailure.String(), "Array subtasks were aborted", taskInfo)
err = tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo)
if err != nil {
return err
}

if messageCollector.Length() > 0 {
return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength))
return currentState, externalResources, fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength))
}

return nil
return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil
}
Loading

0 comments on commit 6cae3f8

Please sign in to comment.