diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 3f8c2a0914..7847b161fc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -36,6 +36,8 @@ const ( PhasePermanentFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache + // Indicates the task has been aborted + PhaseAborted ) var Phases = []Phase{ @@ -49,11 +51,12 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, + PhaseAborted, } // Returns true if the given phase is failure, retryable failure or success func (p Phase) IsTerminal() bool { - return p.IsFailure() || p.IsSuccess() + return p.IsFailure() || p.IsSuccess() || p.IsAborted() } func (p Phase) IsFailure() bool { @@ -64,6 +67,10 @@ func (p Phase) IsSuccess() bool { return p == PhaseSuccess } +func (p Phase) IsAborted() bool { + return p == PhaseAborted +} + func (p Phase) IsWaitingForResources() bool { return p == PhaseWaitingForResources } @@ -257,10 +264,18 @@ func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info) } +func PhaseInfoSystemFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo { + return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info, true) +} + func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info) } +func PhaseInfoFailureWithCleanup(code, reason string, info *TaskInfo) PhaseInfo { + return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info, true) +} + func PhaseInfoRetryableFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhaseRetryableFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_USER}, info) } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go index 99688fe36d..6c7a858be8 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go @@ -78,6 +78,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c case arrayCore.PhaseAssembleFinalOutput: pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State) + + case arrayCore.PhaseAbortSubTasks: + fallthrough case arrayCore.PhaseWriteToDiscoveryThenFail: pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version+1) diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go index d40edd41de..9de13c6413 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -248,7 +248,7 @@ func TestCheckSubTasksState(t *testing.T) { newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ - CurrentPhase: 
arrayCore.PhaseWriteToDiscoveryThenFail, + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 2, OriginalArraySize: 2, OriginalMinSuccesses: 2, @@ -264,6 +264,6 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p) + assert.Equal(t, arrayCore.PhaseAbortSubTasks, p) }) } diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index f601ec375b..a540359b0a 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -35,6 +35,7 @@ const ( PhaseAssembleFinalError PhaseRetryableFailure PhasePermanentFailure + PhaseAbortSubTasks ) type State struct { @@ -204,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhaseAssembleFinalError: fallthrough + case PhaseAbortSubTasks: + fallthrough + case PhaseWriteToDiscoveryThenFail: fallthrough @@ -259,14 +263,14 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus if totalCount < minSuccesses { logger.Infof(ctx, "Array failed because totalCount[%v] < minSuccesses[%v]", totalCount, minSuccesses) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } // No chance to reach the required success numbers. if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses { logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]", minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures) - return PhaseWriteToDiscoveryThenFail + return PhaseAbortSubTasks } if totalWaitingForResources > 0 { diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go index 01b5b41528..969c98df20 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state_test.go +++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go @@ -291,7 +291,7 @@ func TestSummaryToPhase(t *testing.T) { }{ { "FailOnTooFewTasks", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{}, }, { @@ -304,7 +304,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailOnTooManyPermanentFailures", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhasePermanentFailure: 1, core.PhaseSuccess: 9, @@ -335,7 +335,7 @@ func TestSummaryToPhase(t *testing.T) { }, { "FailedToRetry", - PhaseWriteToDiscoveryThenFail, + PhaseAbortSubTasks, map[core.Phase]int64{ core.PhaseSuccess: 5, core.PhasePermanentFailure: 5, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 3e324391b8..f664392fd9 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -117,6 +117,9 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c nextState, externalResources, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) + case arrayCore.PhaseAbortSubTasks: + nextState, externalResources, err = TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + case 
arrayCore.PhaseAssembleFinalOutput: nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState) @@ -156,7 +159,7 @@ func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) err return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) + return TerminateSubTasksOnAbort(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) } func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { @@ -165,7 +168,8 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + _, _, err := TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) + return err } func (e Executor) Start(ctx context.Context) error { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index fdf9d4a182..510f202e1a 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -324,16 +324,37 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return newState, externalResources, nil } +func TerminateSubTasksOnAbort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + + _, externalResources, err := TerminateSubTasks(ctx, tCtx, kubeClient, GetConfig(), terminateFunction, currentState) + if err != nil { + return err + } + + taskInfo := &core.TaskInfo{ + ExternalResources: externalResources, + } + executionErr := &idlCore.ExecutionError{ + Code: "ArraySubtasksAborted", + Message: "Array subtasks were aborted", + } + phaseInfo := core.PhaseInfoFailed(core.PhaseAborted, executionErr, taskInfo) + + return tCtx.EventsRecorder().RecordRaw(ctx, phaseInfo) +} + // TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include // aborting and finalizing active k8s resources. 
func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, - terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) (*arrayCore.State, []*core.ExternalResource, error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) + externalResources := make([]*core.ExternalResource, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) if err != nil { - return err + return currentState, externalResources, err } else if taskTemplate == nil { - return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + return currentState, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") } messageCollector := errorcollector.NewErrorMessageCollector() @@ -353,18 +374,25 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) stCtx, err := NewSubTaskExecutionContext(ctx, tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { - return err + return currentState, externalResources, err } err = terminateFunction(ctx, stCtx, config, kubeClient) if err != nil { messageCollector.Collect(childIdx, err.Error()) + } else { + externalResources = append(externalResources, &core.ExternalResource{ + ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), + Index: uint32(originalIdx), + RetryAttempt: uint32(retryAttempt), + Phase: core.PhaseAborted, + }) } } if messageCollector.Length() > 0 { - return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) + return currentState, externalResources, fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) } - return nil + return currentState.SetPhase(arrayCore.PhaseWriteToDiscoveryThenFail, currentState.PhaseVersion+1), externalResources, nil } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 5bead4954d..ab26028b09 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -531,10 +531,160 @@ func TestCheckSubTasksState(t *testing.T) { // validate results assert.Nil(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail.String(), p.String()) + assert.Equal(t, arrayCore.PhaseAbortSubTasks.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount) for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { assert.Equal(t, core.PhasePermanentFailure, core.Phases[subtaskPhaseIndex]) } }) } + +func TestTerminateSubTasksOnAbort(t *testing.T) { + ctx := context.Background() + subtaskCount := 3 + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i := 0; i < subtaskCount; i++ { + compactArray.SetItem(i, 5) + } + + currentState := &arrayCore.State{ + 
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: compactArray, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), + } + + t.Run("SuccessfulTermination", func(t *testing.T) { + eventRecorder := mocks.EventsRecorder{} + eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnEventsRecorder().Return(&eventRecorder) + + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + return nil + } + + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + + assert.Nil(t, err) + eventRecorder.AssertCalled(t, "RecordRaw", mock.Anything, mock.Anything) + }) + + t.Run("TerminationWithError", func(t *testing.T) { + eventRecorder := mocks.EventsRecorder{} + eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil) + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnEventsRecorder().Return(&eventRecorder) + + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + return fmt.Errorf("termination error") + } + + err := TerminateSubTasksOnAbort(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + + assert.NotNil(t, err) + eventRecorder.AssertNotCalled(t, "RecordRaw", mock.Anything, mock.Anything) + }) +} + +func TestTerminateSubTasks(t *testing.T) { + ctx := context.Background() + subtaskCount := 3 + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + tests := []struct { + name string + initialPhaseIndices []int + expectedAbortCount int + terminateError error + }{ + { + name: "AllSubTasksRunning", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: nil, + }, + { + name: "MixedSubTaskStates", + initialPhaseIndices: []int{8, 0, 5}, + expectedAbortCount: 1, + terminateError: nil, + }, + { + name: "TerminateFunctionFails", + initialPhaseIndices: []int{5, 5, 5}, + expectedAbortCount: 3, + terminateError: fmt.Errorf("error"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i, phaseIdx := range test.initialPhaseIndices { + compactArray.SetItem(i, bitarray.Item(phaseIdx)) + } + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + PhaseVersion: 0, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: compactArray, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), + } + + tCtx := getMockTaskExecutionContext(ctx, 0) + terminateCounter := 0 + mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error { + terminateCounter++ + return test.terminateError + } + + nextState, externalResources, err := 
TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState) + + assert.Equal(t, test.expectedAbortCount, terminateCounter) + + if test.terminateError != nil { + assert.NotNil(t, err) + return + } + + assert.Nil(t, err) + assert.Equal(t, uint32(1), nextState.PhaseVersion) + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, nextState.CurrentPhase) + assert.Len(t, externalResources, terminateCounter) + + for _, externalResource := range externalResources { + phase := core.Phases[externalResource.Phase] + assert.True(t, phase.IsAborted()) + } + }) + } +} diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index 10a83cd2fc..1afc986287 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -112,9 +112,10 @@ var ( EventConfig: EventConfig{ RawOutputPolicy: RawOutputPolicyReference, }, - ClusterID: "propeller", - CreateFlyteWorkflowCRD: false, - ArrayNodeEventVersion: 0, + ClusterID: "propeller", + CreateFlyteWorkflowCRD: false, + ArrayNodeEventVersion: 0, + NodeExecutionWorkerCount: 8, } ) @@ -155,6 +156,7 @@ type Config struct { ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` ArrayNodeEventVersion int `json:"array-node-event-version" pflag:",ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new"` + NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index 8e9c71bcdb..07a4fba742 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -108,5 +108,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "cluster-id"), defaultConfig.ClusterID, "Unique cluster id running this flytepropeller instance with which to annotate execution events") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "create-flyteworkflow-crd"), defaultConfig.CreateFlyteWorkflowCRD, "Enable creation of the FlyteWorkflow CRD on startup") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-event-version"), defaultConfig.ArrayNodeEventVersion, "ArrayNode eventing version. 
0 => legacy (drop-in replacement for maptask), 1 => new") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes") return cmdFlags } diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index f48d01ebea..54da9e9fe1 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -911,4 +911,18 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_node-execution-worker-count", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("node-execution-worker-count", testValue) + if vInt, err := cmdFlags.GetInt("node-execution-worker-count"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.NodeExecutionWorkerCount) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/controller/nodes/array/execution_context.go b/flytepropeller/pkg/controller/nodes/array/execution_context.go index 2191b9c7d2..4fb5a8a214 100644 --- a/flytepropeller/pkg/controller/nodes/array/execution_context.go +++ b/flytepropeller/pkg/controller/nodes/array/execution_context.go @@ -14,35 +14,24 @@ const ( type arrayExecutionContext struct { executors.ExecutionContext - executionConfig v1alpha1.ExecutionConfig - currentParallelism *uint32 + executionConfig v1alpha1.ExecutionConfig } func (a *arrayExecutionContext) GetExecutionConfig() v1alpha1.ExecutionConfig { return a.executionConfig } -func (a *arrayExecutionContext) CurrentParallelism() uint32 { - return *a.currentParallelism -} - -func (a *arrayExecutionContext) IncrementParallelism() uint32 { - *a.currentParallelism = *a.currentParallelism + 1 - return *a.currentParallelism -} - -func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int, currentParallelism *uint32, maxParallelism uint32) *arrayExecutionContext { +func newArrayExecutionContext(executionContext executors.ExecutionContext, subNodeIndex int) *arrayExecutionContext { executionConfig := executionContext.GetExecutionConfig() if executionConfig.EnvironmentVariables == nil { executionConfig.EnvironmentVariables = make(map[string]string) } executionConfig.EnvironmentVariables[JobIndexVarName] = FlyteK8sArrayIndexVarName executionConfig.EnvironmentVariables[FlyteK8sArrayIndexVarName] = strconv.Itoa(subNodeIndex) - executionConfig.MaxParallelism = maxParallelism + executionConfig.MaxParallelism = 0 // hardcoded to 0 because parallelism is handled by the array node return &arrayExecutionContext{ - ExecutionContext: executionContext, - executionConfig: executionConfig, - currentParallelism: currentParallelism, + ExecutionContext: executionContext, + executionConfig: executionConfig, } } diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 00a9fc747e..f1e2ef64fc 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -42,11 +42,13 @@ var ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { - eventConfig *config.EventConfig - metrics metrics - nodeExecutor interfaces.Node - pluginStateBytesNotStarted []byte - pluginStateBytesStarted []byte + eventConfig *config.EventConfig + 
gatherOutputsRequestChannel chan *gatherOutputsRequest + metrics metrics + nodeExecutionRequestChannel chan *nodeExecutionRequest + nodeExecutor interfaces.Node + pluginStateBytesNotStarted []byte + pluginStateBytesStarted []byte } // metrics encapsulates the prometheus metrics for this handler @@ -70,7 +72,6 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut messageCollector := errorcollector.NewErrorMessageCollector() switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing: - currentParallelism := uint32(0) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) @@ -81,7 +82,7 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut // create array contexts arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err := - a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism, eventRecorder) + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, eventRecorder) if err != nil { return err } @@ -124,7 +125,6 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe messageCollector := errorcollector.NewErrorMessageCollector() switch arrayNodeState.Phase { case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing, v1alpha1.ArrayNodePhaseSucceeding: - currentParallelism := uint32(0) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) @@ -135,7 +135,7 @@ func (a *arrayNodeHandler) Finalize(ctx context.Context, nCtx interfaces.NodeExe // create array contexts arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, _, err := - a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism, eventRecorder) + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, eventRecorder) if err != nil { return err } @@ -242,8 +242,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu arrayNodeState.Phase = v1alpha1.ArrayNodePhaseExecuting case v1alpha1.ArrayNodePhaseExecuting: // process array node subNodes - currentParallelism := uint32(0) - messageCollector := errorcollector.NewErrorMessageCollector() + currentParallelism := int(arrayNode.GetParallelism()) + if currentParallelism == 0 { + currentParallelism = len(arrayNodeState.SubNodePhases.GetItems()) + } + + nodeExecutionRequests := make([]*nodeExecutionRequest, 0, currentParallelism) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(i)) @@ -254,43 +258,97 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu } // create array contexts + subNodeEventRecorder := newArrayEventRecorder(nCtx.EventsRecorder()) arrayNodeExecutor, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec, subNodeStatus, err := - a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, ¤tParallelism, eventRecorder) + a.buildArrayNodeContext(ctx, nCtx, &arrayNodeState, arrayNode, i, subNodeEventRecorder) if err != nil { return handler.UnknownTransition, err } - // execute subNode - _, err = arrayNodeExecutor.RecursiveNodeHandler(ctx, arrayExecutionContext, arrayDAGStructure, arrayNodeLookup, subNodeSpec) - if err != nil { - return handler.UnknownTransition, 
err + nodeExecutionRequest := &nodeExecutionRequest{ + ctx: ctx, + index: i, + nodePhase: nodePhase, + taskPhase: taskPhase, + nodeExecutor: arrayNodeExecutor, + executionContext: arrayExecutionContext, + dagStructure: arrayDAGStructure, + nodeLookup: arrayNodeLookup, + subNodeSpec: subNodeSpec, + subNodeStatus: subNodeStatus, + arrayEventRecorder: subNodeEventRecorder, + responseChannel: make(chan struct { + interfaces.NodeStatus + error + }, 1), + } + + nodeExecutionRequests = append(nodeExecutionRequests, nodeExecutionRequest) + a.nodeExecutionRequestChannel <- nodeExecutionRequest + + // TODO - this is a naive implementation of parallelism, if we want to support more + // complex subNodes (ie. dynamics / subworkflows) we need to revisit this so that + // parallelism is handled during subNode evaluations. + currentParallelism-- + if currentParallelism == 0 { + break + } + } + + workerErrorCollector := errorcollector.NewErrorMessageCollector() + subNodeFailureCollector := errorcollector.NewErrorMessageCollector() + for i, nodeExecutionRequest := range nodeExecutionRequests { + nodeExecutionResponse := <-nodeExecutionRequest.responseChannel + if nodeExecutionResponse.error != nil { + workerErrorCollector.Collect(i, nodeExecutionResponse.error.Error()) + continue } + index := nodeExecutionRequest.index + subNodeStatus := nodeExecutionRequest.subNodeStatus + // capture subNode error if exists - if subNodeStatus.Error != nil { - messageCollector.Collect(i, subNodeStatus.Error.Message) + if nodeExecutionRequest.subNodeStatus.Error != nil { + subNodeFailureCollector.Collect(index, subNodeStatus.Error.Message) } - // process events - eventRecorder.process(ctx, nCtx, i, subNodeStatus.GetAttempts()) + // process events by copying from internal event recorder + if arrayEventRecorder, ok := nodeExecutionRequest.arrayEventRecorder.(*externalResourcesEventRecorder); ok { + for _, event := range arrayEventRecorder.taskEvents { + if err := eventRecorder.RecordTaskEvent(ctx, event, a.eventConfig); err != nil { + return handler.UnknownTransition, err + } + } + for _, event := range arrayEventRecorder.nodeEvents { + if err := eventRecorder.RecordNodeEvent(ctx, event, a.eventConfig); err != nil { + return handler.UnknownTransition, err + } + } + } + eventRecorder.process(ctx, nCtx, index, subNodeStatus.GetAttempts()) // update subNode state - arrayNodeState.SubNodePhases.SetItem(i, uint64(subNodeStatus.GetPhase())) + arrayNodeState.SubNodePhases.SetItem(index, uint64(subNodeStatus.GetPhase())) if subNodeStatus.GetTaskNodeStatus() == nil { // resetting task phase because during retries we clear the GetTaskNodeStatus - arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(0)) + arrayNodeState.SubNodeTaskPhases.SetItem(index, uint64(0)) } else { - arrayNodeState.SubNodeTaskPhases.SetItem(i, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) + arrayNodeState.SubNodeTaskPhases.SetItem(index, uint64(subNodeStatus.GetTaskNodeStatus().GetPhase())) } - arrayNodeState.SubNodeRetryAttempts.SetItem(i, uint64(subNodeStatus.GetAttempts())) - arrayNodeState.SubNodeSystemFailures.SetItem(i, uint64(subNodeStatus.GetSystemFailures())) + arrayNodeState.SubNodeRetryAttempts.SetItem(index, uint64(subNodeStatus.GetAttempts())) + arrayNodeState.SubNodeSystemFailures.SetItem(index, uint64(subNodeStatus.GetSystemFailures())) // increment task phase version if subNode phase or task phase changed - if subNodeStatus.GetPhase() != nodePhase || subNodeStatus.GetTaskNodeStatus().GetPhase() != taskPhase { + if subNodeStatus.GetPhase() != 
nodeExecutionRequest.nodePhase || subNodeStatus.GetTaskNodeStatus().GetPhase() != nodeExecutionRequest.taskPhase { incrementTaskPhaseVersion = true } } + // if any workers failed then return the error + if workerErrorCollector.Length() > 0 { + return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength)) + } + // process phases of subNodes to determine overall `ArrayNode` phase successCount := 0 failedCount := 0 @@ -321,7 +379,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // if there is a failing node set the error message if it has not been previous set if failingCount > 0 && arrayNodeState.Error == nil { arrayNodeState.Error = &idlcore.ExecutionError{ - Message: messageCollector.Summary(events.MaxErrorMessageLength), + Message: subNodeFailureCollector.Summary(events.MaxErrorMessageLength), } } @@ -349,33 +407,51 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu nil, )), nil case v1alpha1.ArrayNodePhaseSucceeding: - outputLiterals := make(map[string]*idlcore.Literal) + gatherOutputsRequests := make([]*gatherOutputsRequest, 0, len(arrayNodeState.SubNodePhases.GetItems())) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) + gatherOutputsRequest := &gatherOutputsRequest{ + ctx: ctx, + responseChannel: make(chan struct { + literalMap map[string]*idlcore.Literal + error + }, 1), + } if nodePhase != v1alpha1.NodePhaseSucceeded { // retrieve output variables from task template - var outputVariables map[string]*idlcore.Variable + outputLiterals := make(map[string]*idlcore.Literal) task, err := nCtx.ExecutionContext().GetTask(*arrayNode.GetSubNodeSpec().TaskRef) if err != nil { // Should never happen - return handler.UnknownTransition, err + gatherOutputsRequest.responseChannel <- struct { + literalMap map[string]*idlcore.Literal + error + }{nil, err} + continue } if task.CoreTask() != nil && task.CoreTask().Interface != nil && task.CoreTask().Interface.Outputs != nil { - outputVariables = task.CoreTask().Interface.Outputs.Variables + for name := range task.CoreTask().Interface.Outputs.Variables { + outputLiterals[name] = nilLiteral + } } - // append nil literal for all output variables - for name := range outputVariables { - appendLiteral(name, nilLiteral, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems())) - } + gatherOutputsRequest.responseChannel <- struct { + literalMap map[string]*idlcore.Literal + error + }{outputLiterals, nil} } else { // initialize subNode reader - currentAttempt := uint32(arrayNodeState.SubNodeRetryAttempts.GetItem(i)) - subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, strconv.Itoa(i), strconv.Itoa(int(currentAttempt))) + currentAttempt := int(arrayNodeState.SubNodeRetryAttempts.GetItem(i)) + subDataDir, subOutputDir, err := constructOutputReferences(ctx, nCtx, + strconv.Itoa(i), strconv.Itoa(currentAttempt)) if err != nil { - return handler.UnknownTransition, err + gatherOutputsRequest.responseChannel <- struct { + literalMap map[string]*idlcore.Literal + error + }{nil, err} + continue } // checkpoint paths are not computed here because this function is only called when writing @@ -383,22 +459,33 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu outputPaths := ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), subOutputDir, ioutils.NewRawOutputPaths(ctx, subDataDir), "") 
reader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes()) - // read outputs - outputs, executionErr, err := reader.Read(ctx) - if err != nil { - return handler.UnknownTransition, err - } else if executionErr != nil { - return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), - "execution error ArrayNode output, bad state: %s", executionErr.String()) - } + gatherOutputsRequest.reader = &reader + a.gatherOutputsRequestChannel <- gatherOutputsRequest + } - // copy individual subNode output literals into a collection of output literals - for name, literal := range outputs.GetLiterals() { - appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems())) - } + gatherOutputsRequests = append(gatherOutputsRequests, gatherOutputsRequest) + } + + outputLiterals := make(map[string]*idlcore.Literal) + workerErrorCollector := errorcollector.NewErrorMessageCollector() + for i, gatherOutputsRequest := range gatherOutputsRequests { + outputResponse := <-gatherOutputsRequest.responseChannel + if outputResponse.error != nil { + workerErrorCollector.Collect(i, outputResponse.error.Error()) + continue + } + + // append literal for all output variables + for name, literal := range outputResponse.literalMap { + appendLiteral(name, literal, outputLiterals, len(arrayNodeState.SubNodePhases.GetItems())) } } + // if any workers failed then return the error + if workerErrorCollector.Length() > 0 { + return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength)) + } + outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } @@ -460,6 +547,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // Setup handles any initialization requirements for this handler func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) error { + // start workers + for i := 0; i < config.GetConfig().NodeExecutionWorkerCount; i++ { + worker := worker{ + gatherOutputsRequestChannel: a.gatherOutputsRequestChannel, + nodeExecutionRequestChannel: a.nodeExecutionRequestChannel, + } + + go func() { + worker.run() + }() + } + return nil } @@ -478,11 +577,13 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ - eventConfig: eventConfig, - metrics: newMetrics(arrayScope), - nodeExecutor: nodeExecutor, - pluginStateBytesNotStarted: pluginStateBytesNotStarted, - pluginStateBytesStarted: pluginStateBytesStarted, + eventConfig: eventConfig, + gatherOutputsRequestChannel: make(chan *gatherOutputsRequest), + metrics: newMetrics(arrayScope), + nodeExecutionRequestChannel: make(chan *nodeExecutionRequest), + nodeExecutor: nodeExecutor, + pluginStateBytesNotStarted: pluginStateBytesNotStarted, + pluginStateBytesStarted: pluginStateBytesStarted, }, nil } @@ -491,7 +592,7 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr // but need many different execution details, for example setting input values as a singular item rather than a collection, // injecting environment variables for flytekit maptask execution, aggregating eventing so that rather than tracking state for // each subnode individually it sends a single event for the whole ArrayNode, and many more. 
-func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, currentParallelism *uint32, eventRecorder arrayEventRecorder) ( +func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx interfaces.NodeExecutionContext, arrayNodeState *handler.ArrayNodeState, arrayNode v1alpha1.ExecutableArrayNode, subNodeIndex int, eventRecorder arrayEventRecorder) ( interfaces.Node, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, *v1alpha1.NodeSpec, *v1alpha1.NodeStatus, error) { nodePhase := v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(subNodeIndex)) @@ -556,12 +657,10 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter if err != nil { return nil, nil, nil, nil, nil, nil, err } - arrayExecutionContext := newArrayExecutionContext( - executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), - subNodeIndex, currentParallelism, arrayNode.GetParallelism()) + arrayExecutionContext := newArrayExecutionContext(executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), subNodeIndex) arrayNodeExecutionContextBuilder := newArrayNodeExecutionContextBuilder(a.nodeExecutor.GetNodeExecutionContextBuilder(), - subNodeID, subNodeIndex, subNodeStatus, inputReader, currentParallelism, arrayNode.GetParallelism(), eventRecorder) + subNodeID, subNodeIndex, subNodeStatus, inputReader, eventRecorder) arrayNodeExecutor := a.nodeExecutor.WithNodeExecutionContextBuilder(arrayNodeExecutionContextBuilder) return arrayNodeExecutor, arrayExecutionContext, &arrayNodeLookup, &arrayNodeLookup, &subNodeSpec, subNodeStatus, nil diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index f9086218c2..b2e85c0979 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -62,7 +62,13 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter assert.NoError(t, err) // return ArrayNodeHandler - return New(nodeExecutor, eventConfig, scope) + arrayNodeHandler, err := New(nodeExecutor, eventConfig, scope) + if err != nil { + return nil, err + } + + err = arrayNodeHandler.Setup(ctx, nil) + return arrayNodeHandler, err } func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder interfaces.EventRecorder, outputVariables []string, @@ -496,11 +502,10 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { }, subNodeTransitions: []handler.Transition{ handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), - handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, expectedTransitionPhase: handler.EPhaseRunning, - expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_QUEUED}, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING}, }, { name: "AllSubNodesSuccedeed", diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go index 17d46d2944..6ef7bb01c1 100644 --- a/flytepropeller/pkg/controller/nodes/array/node_execution_context.go 
+++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context.go @@ -104,8 +104,10 @@ func (a *arrayNodeExecutionContext) TaskReader() interfaces.TaskReader { return a.taskReader } -func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, eventRecorder arrayEventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus, currentParallelism *uint32, maxParallelism uint32) *arrayNodeExecutionContext { - arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex, currentParallelism, maxParallelism) +func newArrayNodeExecutionContext(nodeExecutionContext interfaces.NodeExecutionContext, inputReader io.InputReader, + eventRecorder arrayEventRecorder, subNodeIndex int, nodeStatus *v1alpha1.NodeStatus) *arrayNodeExecutionContext { + + arrayExecutionContext := newArrayExecutionContext(nodeExecutionContext.ExecutionContext(), subNodeIndex) return &arrayNodeExecutionContext{ NodeExecutionContext: nodeExecutionContext, eventRecorder: eventRecorder, diff --git a/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go index 6d0cfd3bfb..b66ae4a54d 100644 --- a/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go +++ b/flytepropeller/pkg/controller/nodes/array/node_execution_context_builder.go @@ -10,14 +10,12 @@ import ( ) type arrayNodeExecutionContextBuilder struct { - nCtxBuilder interfaces.NodeExecutionContextBuilder - subNodeID v1alpha1.NodeID - subNodeIndex int - subNodeStatus *v1alpha1.NodeStatus - inputReader io.InputReader - currentParallelism *uint32 - maxParallelism uint32 - eventRecorder arrayEventRecorder + nCtxBuilder interfaces.NodeExecutionContextBuilder + subNodeID v1alpha1.NodeID + subNodeIndex int + subNodeStatus *v1alpha1.NodeStatus + inputReader io.InputReader + eventRecorder arrayEventRecorder } func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context.Context, executionContext executors.ExecutionContext, @@ -31,23 +29,21 @@ func (a *arrayNodeExecutionContextBuilder) BuildNodeExecutionContext(ctx context if currentNodeID == a.subNodeID { // overwrite NodeExecutionContext for ArrayNode execution - nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus, a.currentParallelism, a.maxParallelism) + nCtx = newArrayNodeExecutionContext(nCtx, a.inputReader, a.eventRecorder, a.subNodeIndex, a.subNodeStatus) } return nCtx, nil } -func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, subNodeIndex int, subNodeStatus *v1alpha1.NodeStatus, - inputReader io.InputReader, currentParallelism *uint32, maxParallelism uint32, eventRecorder arrayEventRecorder) interfaces.NodeExecutionContextBuilder { +func newArrayNodeExecutionContextBuilder(nCtxBuilder interfaces.NodeExecutionContextBuilder, subNodeID v1alpha1.NodeID, subNodeIndex int, + subNodeStatus *v1alpha1.NodeStatus, inputReader io.InputReader, eventRecorder arrayEventRecorder) interfaces.NodeExecutionContextBuilder { return &arrayNodeExecutionContextBuilder{ - nCtxBuilder: nCtxBuilder, - subNodeID: subNodeID, - subNodeIndex: subNodeIndex, - subNodeStatus: subNodeStatus, - inputReader: inputReader, - currentParallelism: currentParallelism, - maxParallelism: maxParallelism, - eventRecorder: eventRecorder, + nCtxBuilder: nCtxBuilder, + subNodeID: subNodeID, + subNodeIndex: 
subNodeIndex, + subNodeStatus: subNodeStatus, + inputReader: inputReader, + eventRecorder: eventRecorder, } } diff --git a/flytepropeller/pkg/controller/nodes/array/worker.go b/flytepropeller/pkg/controller/nodes/array/worker.go new file mode 100644 index 0000000000..b5b5db49da --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/array/worker.go @@ -0,0 +1,105 @@ +package array + +import ( + "context" + "fmt" + "runtime/debug" + + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" + "github.com/flyteorg/flyte/flytestdlib/logger" +) + +// nodeExecutionRequest is a request to execute an ArrayNode subNode +type nodeExecutionRequest struct { + ctx context.Context + index int + nodePhase v1alpha1.NodePhase + taskPhase int + nodeExecutor interfaces.Node + executionContext executors.ExecutionContext + dagStructure executors.DAGStructure + nodeLookup executors.NodeLookup + subNodeSpec *v1alpha1.NodeSpec + subNodeStatus *v1alpha1.NodeStatus + arrayEventRecorder arrayEventRecorder + responseChannel chan struct { + interfaces.NodeStatus + error + } +} + +// gatherOutputsRequest is a request to read outputs from an ArrayNode subNode +type gatherOutputsRequest struct { + ctx context.Context + reader *ioutils.RemoteFileOutputReader + responseChannel chan struct { + literalMap map[string]*idlcore.Literal + error + } +} + +// worker is an entity that is used to parallelize I/O bound operations for ArrayNode execution +type worker struct { + gatherOutputsRequestChannel chan *gatherOutputsRequest + nodeExecutionRequestChannel chan *nodeExecutionRequest +} + +// run starts the main handle loop for the worker +func (w *worker) run() { + for { + select { + case nodeExecutionRequest := <-w.nodeExecutionRequestChannel: + var nodeStatus interfaces.NodeStatus + var err error + func() { + defer func() { + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when executing ArrayNode subNode, Stack: [%s]", string(stack)) + logger.Errorf(nodeExecutionRequest.ctx, err.Error()) + } + }() + + // execute RecursiveNodeHandler on the subNode + nodeStatus, err = nodeExecutionRequest.nodeExecutor.RecursiveNodeHandler(nodeExecutionRequest.ctx, nodeExecutionRequest.executionContext, + nodeExecutionRequest.dagStructure, nodeExecutionRequest.nodeLookup, nodeExecutionRequest.subNodeSpec) + }() + + nodeExecutionRequest.responseChannel <- struct { + interfaces.NodeStatus + error + }{nodeStatus, err} + case gatherOutputsRequest := <-w.gatherOutputsRequestChannel: + var literalMap map[string]*idlcore.Literal + var err error + func() { + defer func() { + if r := recover(); r != nil { + stack := debug.Stack() + err = fmt.Errorf("panic when executing ArrayNode subNode, Stack: [%s]", string(stack)) + logger.Errorf(gatherOutputsRequest.ctx, err.Error()) + } + }() + + // read outputs + outputs, executionErr, gatherErr := gatherOutputsRequest.reader.Read(gatherOutputsRequest.ctx) + if gatherErr != nil { + err = gatherErr + } else if executionErr != nil { + err = fmt.Errorf("%s", executionErr.String()) + } else { + literalMap = outputs.GetLiterals() + } + }() + + gatherOutputsRequest.responseChannel <- struct { + literalMap map[string]*idlcore.Literal + error + }{literalMap, err} + } + } +} diff --git 
a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 72dba11a50..5e4139296f 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -753,6 +753,7 @@ func (t *Handler) ValidateOutput(ctx context.Context, nodeID v1alpha1.NodeID, i func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { taskNodeState := nCtx.NodeStateReader().GetTaskNodeState() currentPhase := taskNodeState.PluginPhase + currentPhaseVersion := taskNodeState.PluginPhaseVersion logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase) if currentPhase.IsTerminal() && !(currentPhase.IsFailure() && taskNodeState.CleanupOnFailure) { @@ -790,8 +791,40 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext logger.Errorf(ctx, "Abort failed when calling plugin abort.") return err } - taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() + evRecorder := nCtx.EventsRecorder() + logger.Debugf(ctx, "Sending buffered Task events.") + for _, ev := range tCtx.ber.GetAll(ctx) { + evInfo, err := ToTaskExecutionEvent(ToTaskExecutionEventInputs{ + TaskExecContext: tCtx, + InputReader: nCtx.InputReader(), + EventConfig: t.eventConfig, + OutputWriter: tCtx.ow, + Info: ev.WithVersion(currentPhaseVersion + 1), + NodeExecutionMetadata: nCtx.NodeExecutionMetadata(), + ExecContext: nCtx.ExecutionContext(), + TaskType: ttype, + PluginID: p.GetID(), + ResourcePoolInfo: tCtx.rm.GetResourcePoolInfo(), + ClusterID: t.clusterID, + }) + if err != nil { + return err + } + if currentPhase.IsFailure() { + evInfo.Phase = core.TaskExecution_FAILED + } else { + evInfo.Phase = core.TaskExecution_ABORTED + } + if err := evRecorder.RecordTaskEvent(ctx, evInfo, t.eventConfig); err != nil { + logger.Errorf(ctx, "Event recording failed for Plugin [%s], eventPhase [%s], error :%s", p.GetID(), evInfo.Phase.String(), err.Error()) + // Check for idempotency + // Check for terminate state error + return err + } + } + + taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() nodeExecutionID, err := getParentNodeExecIDForTask(&taskExecID, nCtx.ExecutionContext()) if err != nil { return err diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go index e026419dc0..242c1334ce 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer.go @@ -44,6 +44,8 @@ func ToTaskEventPhase(p pluginCore.Phase) core.TaskExecution_Phase { return core.TaskExecution_FAILED case pluginCore.PhaseRetryableFailure: return core.TaskExecution_FAILED + case pluginCore.PhaseAborted: + return core.TaskExecution_ABORTED case pluginCore.PhaseNotReady: fallthrough case pluginCore.PhaseUndefined: @@ -117,13 +119,14 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio metadata.ExternalResources = make([]*event.ExternalResourceInfo, len(externalResources)) for idx, e := range input.Info.Info().ExternalResources { + phase := ToTaskEventPhase(e.Phase) metadata.ExternalResources[idx] = &event.ExternalResourceInfo{ ExternalId: e.ExternalID, CacheStatus: e.CacheStatus, Index: e.Index, Logs: e.Logs, RetryAttempt: e.RetryAttempt, - Phase: ToTaskEventPhase(e.Phase), + Phase: phase, } } }
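Reviewer note: the core pattern behind the new `worker.go` and the `Setup` change is a small request/response worker pool — a shared unbuffered request channel, a capacity-1 response channel per request, and per-request panic recovery. The sketch below is illustrative only; all `demo*` names are invented for this note and are not part of the Flyte codebase, and it assumes the default `node-execution-worker-count` of 8 introduced in the config change above.

```go
package main

import (
	"fmt"
	"runtime/debug"
)

// demoResponse stands in for the anonymous response structs used in worker.go.
type demoResponse struct {
	result string
	err    error
}

// demoRequest mirrors the shape of nodeExecutionRequest / gatherOutputsRequest:
// the caller owns a one-slot response channel so workers never block on send.
type demoRequest struct {
	index           int
	responseChannel chan demoResponse
}

// demoWorker drains the shared request channel, recovering from panics so a
// single bad request cannot take a worker down.
func demoWorker(requests <-chan *demoRequest) {
	for req := range requests {
		var result string
		var err error
		func() {
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("panic: %v, stack: [%s]", r, string(debug.Stack()))
				}
			}()
			// stand-in for RecursiveNodeHandler / output reading
			result = fmt.Sprintf("evaluated subNode %d", req.index)
		}()

		req.responseChannel <- demoResponse{result: result, err: err}
	}
}

func main() {
	requests := make(chan *demoRequest)
	for i := 0; i < 8; i++ { // analogous to node-execution-worker-count (default 8)
		go demoWorker(requests)
	}

	// fan out one request per subNode, then collect responses in submission order
	pending := make([]*demoRequest, 0, 4)
	for i := 0; i < 4; i++ {
		req := &demoRequest{index: i, responseChannel: make(chan demoResponse, 1)}
		pending = append(pending, req)
		requests <- req
	}
	for _, req := range pending {
		resp := <-req.responseChannel
		if resp.err != nil {
			fmt.Println("worker error:", resp.err)
			continue
		}
		fmt.Println(resp.result)
	}
}
```

Keeping the request channel unbuffered applies back-pressure on the handler, while the one-slot response channels let a worker reply and move on even if the handler has not yet reached that request in its collection loop.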