Skip to content

Commit

Permalink
Bug/abort map task subtasks (#4506)
Browse files Browse the repository at this point in the history
* add cleanup on failure for array task permanent failures

Signed-off-by: Paul Dittamo <[email protected]>

* persist subtask aborts to admin

Signed-off-by: Paul Dittamo <[email protected]>

* add isAbortedSubtask field to externalResources

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* add tests for terminal updates

Signed-off-by: Paul Dittamo <[email protected]>

* add unit tests for terminatesubtasks

Signed-off-by: Paul Dittamo <[email protected]>

* lint + update unit test

Signed-off-by: Paul Dittamo <[email protected]>

* delete comment

Signed-off-by: Paul Dittamo <[email protected]>

* clean up unit test

Signed-off-by: Paul Dittamo <[email protected]>

* delete comments

Signed-off-by: Paul Dittamo <[email protected]>

* set correct error message

Signed-off-by: Paul Dittamo <[email protected]>

* add new array plugin abort phase

Signed-off-by: Paul Dittamo <[email protected]>

* update awsbatch array unit test

Signed-off-by: Paul Dittamo <[email protected]>

* using Abort phase and cleaning up best-effort (#4600)

* using Abort phase and cleaning up best-effort

Signed-off-by: Daniel Rammer <[email protected]>

* removed phase updates on Finalize

Signed-off-by: Daniel Rammer <[email protected]>

---------

Signed-off-by: Daniel Rammer <[email protected]>

* update unit test

Signed-off-by: Paul Dittamo <[email protected]>

* handling PhaseAbortSubTasks in awsbatch plugin

Signed-off-by: Daniel Rammer <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Daniel Rammer <[email protected]>
Co-authored-by: Paul Dittamo <[email protected]>
Co-authored-by: Dan Rammer <[email protected]>
  • Loading branch information
3 people authored Dec 15, 2023
1 parent 1bed29a commit 1699094
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 19 deletions.
17 changes: 16 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/core/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
PhasePermanentFailure
// Indicates the task is waiting for the cache to be populated so it can reuse results
PhaseWaitingForCache
// Indicates the task has been aborted
PhaseAborted
)

var Phases = []Phase{
Expand All @@ -49,11 +51,12 @@ var Phases = []Phase{
PhaseRetryableFailure,
PhasePermanentFailure,
PhaseWaitingForCache,
PhaseAborted,
}

// IsTerminal returns true if the given phase is failure, retryable failure, success or aborted.
func (p Phase) IsTerminal() bool {
	// The stale pre-abort return that used to sit above this line made the IsAborted check
	// unreachable; there must be exactly one return so aborted phases count as terminal.
	return p.IsFailure() || p.IsSuccess() || p.IsAborted()
}

func (p Phase) IsFailure() bool {
Expand All @@ -64,6 +67,10 @@ func (p Phase) IsSuccess() bool {
return p == PhaseSuccess
}

// IsAborted reports whether the phase equals PhaseAborted.
func (p Phase) IsAborted() bool {
	switch p {
	case PhaseAborted:
		return true
	default:
		return false
	}
}

// IsWaitingForResources reports whether the phase equals PhaseWaitingForResources.
func (p Phase) IsWaitingForResources() bool {
	switch p {
	case PhaseWaitingForResources:
		return true
	default:
		return false
	}
}
Expand Down Expand Up @@ -257,10 +264,18 @@ func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo {
return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info)
}

// PhaseInfoSystemFailureWithCleanup delegates to phaseInfoFailed with PhasePermanentFailure and an
// execution error of kind SYSTEM built from code and reason, passing true as the final argument.
// NOTE(review): going by this function's name, that final flag requests cleanup on failure —
// confirm against phaseInfoFailed.
func PhaseInfoSystemFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo {
	// A system failure must carry the SYSTEM error kind, matching PhaseInfoSystemFailure;
	// reporting it as USER would misattribute platform errors to user code.
	execErr := &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}
	return phaseInfoFailed(PhasePermanentFailure, execErr, info, true)
}

// PhaseInfoFailure delegates to PhaseInfoFailed with PhasePermanentFailure and an execution
// error of kind USER carrying the supplied code and reason.
func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo {
	userErr := &core.ExecutionError{
		Code:    code,
		Message: reason,
		Kind:    core.ExecutionError_USER,
	}
	return PhaseInfoFailed(PhasePermanentFailure, userErr, info)
}

// PhaseInfoFailureWithCleanup delegates to phaseInfoFailed with PhasePermanentFailure and an
// execution error of kind USER carrying the supplied code and reason, passing true as the
// final argument.
// NOTE(review): going by this function's name, that final flag requests cleanup on failure —
// confirm against phaseInfoFailed.
func PhaseInfoFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo {
	userErr := &core.ExecutionError{
		Code:    code,
		Message: reason,
		Kind:    core.ExecutionError_USER,
	}
	return phaseInfoFailed(PhasePermanentFailure, userErr, info, true)
}

// PhaseInfoRetryableFailure delegates to PhaseInfoFailed with PhaseRetryableFailure and an
// execution error of kind USER carrying the supplied code and reason.
func PhaseInfoRetryableFailure(code, reason string, info *TaskInfo) PhaseInfo {
	userErr := &core.ExecutionError{
		Code:    code,
		Message: reason,
		Kind:    core.ExecutionError_USER,
	}
	return PhaseInfoFailed(PhaseRetryableFailure, userErr, info)
}
Expand Down
3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State)

case arrayCore.PhaseAbortSubTasks:
fallthrough

case arrayCore.PhaseWriteToDiscoveryThenFail:
pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version+1)
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func TestCheckSubTasksState(t *testing.T) {

newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail,
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 2,
OriginalArraySize: 2,
OriginalMinSuccesses: 2,
Expand All @@ -264,6 +264,6 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p)
assert.Equal(t, arrayCore.PhaseAbortSubTasks, p)
})
}
8 changes: 6 additions & 2 deletions flyteplugins/go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
PhaseAssembleFinalError
PhaseRetryableFailure
PhasePermanentFailure
PhaseAbortSubTasks
)

type State struct {
Expand Down Expand Up @@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
case PhaseAssembleFinalError:
fallthrough

case PhaseAbortSubTasks:
fallthrough

case PhaseWriteToDiscoveryThenFail:
fallthrough

Expand Down Expand Up @@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus

if totalCount < minSuccesses {
logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses)
return PhaseWriteToDiscoveryThenFail
return PhaseAbortSubTasks
}

// No chance to reach the required success numbers.
if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses {
logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]",
minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures)
return PhaseWriteToDiscoveryThenFail
return PhaseAbortSubTasks
}

if totalWaitingForResources > 0 {
Expand Down
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) {
}{
{
"FailOnTooFewTasks",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{},
},
{
Expand All @@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) {
},
{
"FailOnTooManyPermanentFailures",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{
core.PhasePermanentFailure: 1,
core.PhaseSuccess: 9,
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) {
},
{
"FailedToRetry",
PhaseWriteToDiscoveryThenFail,
PhaseAbortSubTasks,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhasePermanentFailure: 5,
Expand Down
8 changes: 6 additions & 2 deletions flyteplugins/go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig,
tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState)

case arrayCore.PhaseAbortSubTasks:
nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)

case arrayCore.PhaseAssembleFinalOutput:
nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState)

Expand Down Expand Up @@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
}

func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
Expand All @@ -165,7 +168,8 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext)
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
_, _, err := TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
return err
}

func (e Executor) Start(ctx context.Context) error {
Expand Down
40 changes: 34 additions & 6 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,37 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return newState, externalResources, nil
}

// TerminateSubTasksOnAbort terminates the array's subtasks via TerminateSubTasks and, when that
// succeeds, records a raw PhaseAborted event carrying the external resources that
// TerminateSubTasks returned.
//
// config is forwarded to TerminateSubTasks; a nil config falls back to GetConfig(). An error from
// TerminateSubTasks is returned as-is and no event is recorded; otherwise the result of
// EventsRecorder().RecordRaw is returned.
func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
	terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {

	// Honor the caller-supplied config rather than silently replacing it with the global one;
	// only fall back to the global config when none was provided.
	if config == nil {
		config = GetConfig()
	}

	_, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, config, terminateFunction, currentState)
	if err != nil {
		return err
	}

	phaseInfo := core.PhaseInfoFailed(
		core.PhaseAborted,
		&idlCore.ExecutionError{
			Code:    "ArraySubtasksAborted",
			Message: "Array subtasks were aborted",
		},
		&core.TaskInfo{ExternalResources: externalResources},
	)

	return tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo)
}

// TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include
// aborting and finalizing active k8s resources.
func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config,
terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error {
terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) {

taskTemplate, err := tCtx.TaskReader().Read(ctx)
externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems()))
if err != nil {
return err
return currentState, externalResources, err
} else if taskTemplate == nil {
return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}

messageCollector := errorcollector.NewErrorMessageCollector()
Expand All @@ -353,18 +374,25 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
if err != nil {
return err
return currentState, externalResources, err
}

err = terminateFunction(ctx, stCtx, config, kubeClient)
if err != nil {
messageCollector.Collect(childIdx, err.Error())
} else {
externalResources = append(externalResources, &core.ExternalResource{
ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
Index: uint32(originalIdx),
RetryAttempt: uint32(retryAttempt),
Phase: core.PhaseAborted,
})
}
}

if messageCollector.Length() > 0 {
return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength))
return currentState, externalResources, fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength))
}

return nil
return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil
}
Loading

0 comments on commit 1699094

Please sign in to comment.