From e866a0a6f7c6c876b10335b0eec8047e8f60335f Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Wed, 12 Jul 2023 16:18:14 -0500 Subject: [PATCH] parallelized node executions working Signed-off-by: Daniel Rammer --- pkg/controller/executors/execution_context.go | 85 ++++++++++++ pkg/controller/executors/node.go | 2 +- pkg/controller/nodes/branch/handler.go | 5 +- .../nodes/dynamic/dynamic_workflow.go | 5 +- pkg/controller/nodes/executor.go | 126 +++++++++++++----- .../nodes/subworkflow/subworkflow.go | 10 +- pkg/controller/workflow/executor.go | 11 +- 7 files changed, 207 insertions(+), 37 deletions(-) diff --git a/pkg/controller/executors/execution_context.go b/pkg/controller/executors/execution_context.go index 53c0bbbd1..dd4d70919 100644 --- a/pkg/controller/executors/execution_context.go +++ b/pkg/controller/executors/execution_context.go @@ -1,6 +1,7 @@ package executors import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) @@ -33,12 +34,18 @@ type ControlFlow interface { IncrementParallelism() uint32 } +type NodeExecutor interface { + AddNodeFuture(func(chan<- NodeExecutionResult)) + Wait() (NodeStatus, error) +} + type ExecutionContext interface { ImmutableExecutionContext TaskDetailsGetter SubWorkflowGetter ParentInfoGetter ControlFlow + NodeExecutor } type execContext struct { @@ -47,6 +54,7 @@ type execContext struct { TaskDetailsGetter SubWorkflowGetter parentInfo ImmutableParentInfo + NodeExecutor } func (e execContext) GetParentInfo() ImmutableParentInfo { @@ -81,6 +89,82 @@ func (c *controlFlow) IncrementParallelism() uint32 { return c.v } +type NodeExecutionResult struct { + Err error + NodeStatus NodeStatus +} + +type nodeExecutor struct { + nodeFutures []<-chan NodeExecutionResult +} + +func (n *nodeExecutor) AddNodeFuture(f func(chan<- NodeExecutionResult)) { + nodeFuture := make(chan NodeExecutionResult, 1) + go f(nodeFuture) + n.nodeFutures = append(n.nodeFutures, nodeFuture) +} + +func (n *nodeExecutor) Wait() (NodeStatus, error) { + if len(n.nodeFutures) == 0 { + return NodeStatusComplete, nil + } + + // If any downstream node is failed, fail, all + // Else if all are success then success + // Else if any one is running then Downstream is still running + allCompleted := true + partialNodeCompletion := false + onFailurePolicy := v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY) + //onFailurePolicy := execContext.GetOnFailurePolicy() // TODO @hamersaw - need access to this + stateOnComplete := NodeStatusComplete + for _, nodeFuture := range n.nodeFutures { + nodeExecutionResult := <-nodeFuture + state := nodeExecutionResult.NodeStatus + err := nodeExecutionResult.Err + if err != nil { // TODO @hamersaw - do we want to fail right away? or wait until all nodes are done? + return NodeStatusUndefined, err + } + + if state.HasFailed() || state.HasTimedOut() { + // TODO @hamersaw - Debug? + //logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err) + if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) { + // If the failure policy allows other nodes to continue running, do not exit the loop, + // Keep track of the last failed state in the loop since it'll be the one to return. + // TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one. + stateOnComplete = state + } else { + return state, nil + } + } else if !state.IsComplete() { + // A Failed/Timedout node is implicitly considered "complete" this means none of the downstream nodes from + // that node will ever be allowed to run. + // This else block, therefore, deals with all other states. IsComplete will return true if and only if this + // node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we + // mark this node's state as not completed to ensure we will visit it again later. + allCompleted = false + } + + if state.PartiallyComplete() { + // This implies that one of the downstream nodes has just succeeded and workflow is ready for propagation + // We do not propagate in current cycle to make it possible to store the state between transitions + partialNodeCompletion = true + } + } + + if allCompleted { + // TODO @hamersaw - Debug? + //logger.Debugf(ctx, "All downstream nodes completed") + return stateOnComplete, nil + } + + if partialNodeCompletion { + return NodeStatusSuccess, nil + } + + return NodeStatusPending, nil +} + func NewExecutionContextWithTasksGetter(prevExecContext ExecutionContext, taskGetter TaskDetailsGetter) ExecutionContext { return NewExecutionContext(prevExecContext, taskGetter, prevExecContext, prevExecContext.GetParentInfo(), prevExecContext) } @@ -100,6 +184,7 @@ func NewExecutionContext(immExecContext ImmutableExecutionContext, tasksGetter T SubWorkflowGetter: workflowGetter, parentInfo: parentInfo, ControlFlow: flow, + NodeExecutor: &nodeExecutor{}, } } diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go index a8f738c3e..e172172c2 100644 --- a/pkg/controller/executors/node.go +++ b/pkg/controller/executors/node.go @@ -74,7 +74,7 @@ type Node interface { // - 1. It finds a blocking node (not ready, or running) // - 2. A node fails and hence the workflow will fail // - 3. The final/end node has completed and the workflow should be stopped - RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error // This aborts the given node. If the given node is complete then it recursively finds the running nodes and aborts them AbortHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 109290b90..53b3dcca4 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -145,7 +145,10 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node if err != nil { return handler.UnknownTransition, err } - downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode) + if err := b.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode); err != nil { + return handler.UnknownTransition, err + } + downstreamStatus, err := execContext.Wait() if err != nil { return handler.UnknownTransition, err } diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index 6166d8722..50ae53ef0 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -267,7 +267,10 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nC func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup, nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { - state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()) + if err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()); err != nil { + return handler.UnknownTransition, prevState, err + } + state, err := execContext.Wait() if err != nil { return handler.UnknownTransition, prevState, err } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index c447b779c..ecda540f1 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -822,6 +822,41 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur // The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from // the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. +func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { + logger.Debugf(ctx, "Handling downstream Nodes") + // This node is success. Handle all downstream nodes + downstreamNodes, err := dag.FromNode(currentNode.GetID()) + if err != nil { + logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) + /*return executors.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil*/ + return errors2.Errorf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()) + } + for _, downstreamNodeName := range downstreamNodes { + downstreamNode, ok := nl.GetNode(downstreamNodeName) + if !ok { + /*return executors.NodeStatusFailed(&core.ExecutionError{ + Code: errors.BadSpecificationError, + Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()), + Kind: core.ExecutionError_SYSTEM, + }), nil*/ + return errors2.Errorf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()) + } + + if err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil { + return err + } + } + + return nil +} + +// TODO @hamersaw - move to nodeExecutor.wait +/*// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from +// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { logger.Debugf(ctx, "Handling downstream Nodes") // This node is success. Handle all downstream nodes @@ -896,7 +931,7 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo } return executors.NodeStatusPending, nil -} +}*/ func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { startNode := dag.StartNode() @@ -968,8 +1003,7 @@ func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.Executab // nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes // The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes. func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, - dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) ( - executors.NodeStatus, error) { + dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) @@ -991,31 +1025,48 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created if nodeStatus.IsDirty() { - return executors.NodeStatusRunning, nil + return nil } if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) { - return executors.NodeStatusRunning, nil + return nil } - nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) - if err != nil { - // NodeExecution creation failure is a permanent fail / system error. - // Should a system failure always return an err? - return executors.NodeStatusFailed(&core.ExecutionError{ - Code: "InternalError", - Message: err.Error(), - Kind: core.ExecutionError_SYSTEM, - }), nil - } + execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) { + nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) + if err != nil { + // NodeExecution creation failure is a permanent fail / system error. + // Should a system failure always return an err? + nodeFuture <- executors.NodeExecutionResult{ + NodeStatus: executors.NodeStatusFailed(&core.ExecutionError{ + Code: "InternalError", + Message: err.Error(), + Kind: core.ExecutionError_SYSTEM, + }), + Err: err, + } + return + } - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) - if err != nil { - return executors.NodeStatusUndefined, err - } + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) + if err != nil { + nodeFuture <- executors.NodeExecutionResult{ + NodeStatus: executors.NodeStatusUndefined, + Err: err, + } + return + } - return c.handleNode(currentNodeCtx, dag, nCtx, h) + // TODO @hamersaw - remove + //return c.handleNode(currentNodeCtx, dag, nCtx, h) + nodeStatus, err := c.handleNode(currentNodeCtx, dag, nCtx, h) + nodeFuture <- executors.NodeExecutionResult{ + NodeStatus: nodeStatus, + Err: err, + } + }) + return nil // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped @@ -1025,23 +1076,38 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe return c.handleDownstream(ctx, execContext, dag, nl, currentNode) } else if nodePhase == v1alpha1.NodePhaseFailed { logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return executors.NodeStatusUndefined, err + if err := c.handleDownstream(ctx, execContext, dag, nl, currentNode); err != nil { + return err } - return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + // TODO @hamersaw - remove + //return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil + execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) { + nodeFuture <- executors.NodeExecutionResult{ + NodeStatus: executors.NodeStatusFailed(nodeStatus.GetExecutionError()), + } + }) + return nil } else if nodePhase == v1alpha1.NodePhaseTimedOut { logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.") - _, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode) - if err != nil { - return executors.NodeStatusUndefined, err + if err := c.handleDownstream(ctx, execContext, dag, nl, currentNode); err != nil { + return err } - return executors.NodeStatusTimedOut, nil + // TODO @hamersaw - remove + //return executors.NodeStatusTimedOut, nil + execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) { + nodeFuture <- executors.NodeExecutionResult{ + NodeStatus: executors.NodeStatusTimedOut, + } + }) + return nil } - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), + // TODO @hamersaw - remove + //return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), + //"Should never reach here. Current Phase: %v", nodePhase) + return errors.Errorf(errors.IllegalStateError, currentNode.GetID(), "Should never reach here. Current Phase: %v", nodePhase) } diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 74beeaf79..5982ebba6 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -71,7 +71,10 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler return handler.UnknownTransition, err } execContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.StartNode()) + if err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.StartNode()); err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err + } + state, err := execContext.Wait() if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } @@ -150,7 +153,10 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, if err != nil { return handler.UnknownTransition, err } - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode()) + if err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode()); err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err + } + state, err := execContext.Wait() if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go index e3eac1e37..d5d4b2cb0 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -144,7 +144,11 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha Message: "Start node not found"}), nil } execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, startNode) + err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, startNode) + if err != nil { + return StatusRunning, err + } + state, err := execcontext.Wait() if err != nil { return StatusRunning, err } @@ -171,7 +175,10 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl execErr := executionErrorOrDefault(w.GetExecutionStatus().GetExecutionError(), w.GetExecutionStatus().GetMessage()) errorNode := w.GetOnFailureNode() execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, errorNode) + if err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, errorNode); err != nil { + return StatusFailureNode(execErr), err + } + state, err := execcontext.Wait() if err != nil { return StatusFailureNode(execErr), err }