This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Parallelize Node Evaluations #590

Closed
wants to merge 1 commit into from
85 changes: 85 additions & 0 deletions pkg/controller/executors/execution_context.go
@@ -1,6 +1,7 @@
package executors

import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)

@@ -33,12 +34,18 @@ type ControlFlow interface {
	IncrementParallelism() uint32
}

type NodeExecutor interface {
	AddNodeFuture(func(chan<- NodeExecutionResult))
	Wait() (NodeStatus, error)
}

type ExecutionContext interface {
	ImmutableExecutionContext
	TaskDetailsGetter
	SubWorkflowGetter
	ParentInfoGetter
	ControlFlow
	NodeExecutor
}

type execContext struct {
@@ -47,6 +54,7 @@ type execContext struct {
	TaskDetailsGetter
	SubWorkflowGetter
	parentInfo ImmutableParentInfo
	NodeExecutor
}

func (e execContext) GetParentInfo() ImmutableParentInfo {
@@ -81,6 +89,82 @@ func (c *controlFlow) IncrementParallelism() uint32 {
	return c.v
}

type NodeExecutionResult struct {
	Err        error
	NodeStatus NodeStatus
}

type nodeExecutor struct {
	nodeFutures []<-chan NodeExecutionResult
}

func (n *nodeExecutor) AddNodeFuture(f func(chan<- NodeExecutionResult)) {
	nodeFuture := make(chan NodeExecutionResult, 1)
	go f(nodeFuture)
	n.nodeFutures = append(n.nodeFutures, nodeFuture)
}

func (n *nodeExecutor) Wait() (NodeStatus, error) {
	if len(n.nodeFutures) == 0 {
		return NodeStatusComplete, nil
	}

	// If any downstream node has failed, fail all
	// Else if all have succeeded, then success
	// Else if any one is still running, then downstream is still running
	allCompleted := true
	partialNodeCompletion := false
	onFailurePolicy := v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY)
	//onFailurePolicy := execContext.GetOnFailurePolicy() // TODO @hamersaw - need access to this
	stateOnComplete := NodeStatusComplete
	for _, nodeFuture := range n.nodeFutures {
		nodeExecutionResult := <-nodeFuture
		state := nodeExecutionResult.NodeStatus
		err := nodeExecutionResult.Err
		if err != nil { // TODO @hamersaw - do we want to fail right away? or wait until all nodes are done?
			return NodeStatusUndefined, err
		}

		if state.HasFailed() || state.HasTimedOut() {
			// TODO @hamersaw - Debug?
			//logger.Debugf(ctx, "Some downstream node has failed. Failed: [%v]. TimedOut: [%v]. Error: [%s]", state.HasFailed(), state.HasTimedOut(), state.Err)
			if onFailurePolicy == v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) {
				// If the failure policy allows other nodes to continue running, do not exit the loop.
				// Keep track of the last failed state in the loop since it'll be the one to return.
				// TODO: If multiple nodes fail (which this mode allows), consolidate/summarize failure states in one.
				stateOnComplete = state
			} else {
				return state, nil
			}
		} else if !state.IsComplete() {
			// A Failed/TimedOut node is implicitly considered "complete"; this means none of the downstream nodes from
			// that node will ever be allowed to run.
			// This else block, therefore, deals with all other states. IsComplete will return true if and only if this
			// node as well as all of its downstream nodes have finished executing with success statuses. Otherwise we
			// mark this node's state as not completed to ensure we will visit it again later.
			allCompleted = false
		}

		if state.PartiallyComplete() {
			// This implies that one of the downstream nodes has just succeeded and the workflow is ready for propagation.
			// We do not propagate in the current cycle to make it possible to store the state between transitions.
			partialNodeCompletion = true
		}
	}

	if allCompleted {
		// TODO @hamersaw - Debug?
		//logger.Debugf(ctx, "All downstream nodes completed")
		return stateOnComplete, nil
	}

	if partialNodeCompletion {
		return NodeStatusSuccess, nil
	}

	return NodeStatusPending, nil
}

func NewExecutionContextWithTasksGetter(prevExecContext ExecutionContext, taskGetter TaskDetailsGetter) ExecutionContext {
	return NewExecutionContext(prevExecContext, taskGetter, prevExecContext, prevExecContext.GetParentInfo(), prevExecContext)
}
@@ -100,6 +184,7 @@ func NewExecutionContext(immExecContext ImmutableExecutionContext, tasksGetter T
		SubWorkflowGetter: workflowGetter,
		parentInfo: parentInfo,
		ControlFlow: flow,
		NodeExecutor: &nodeExecutor{},
}
}

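For reference, the NodeExecutor added above is essentially a list of single-shot futures: AddNodeFuture launches the work on its own goroutine with a buffered result channel, and Wait later drains every channel and folds the results into one status. An illustrative, self-contained sketch of the same pattern (stand-in names, not code from this PR):

package main

import "fmt"

// result stands in for executors.NodeExecutionResult.
type result struct {
	status string
	err    error
}

// executor stands in for the nodeExecutor type added above.
type executor struct {
	futures []<-chan result
}

// addFuture mirrors AddNodeFuture: run f on its own goroutine and remember
// the channel it will eventually write its single result to.
// NOTE: if futures can be added from multiple goroutines, the append would need a mutex.
func (e *executor) addFuture(f func(chan<- result)) {
	ch := make(chan result, 1) // buffered so the goroutine never blocks on send
	go f(ch)
	e.futures = append(e.futures, ch)
}

// wait mirrors Wait: block on every future and fold the results.
func (e *executor) wait() (string, error) {
	for _, ch := range e.futures {
		r := <-ch
		if r.err != nil {
			return "undefined", r.err
		}
		fmt.Println("node finished:", r.status)
	}
	return "complete", nil
}

func main() {
	e := &executor{}
	for _, n := range []string{"n0", "n1", "n2"} {
		n := n
		e.addFuture(func(ch chan<- result) {
			ch <- result{status: n + " succeeded"} // the node evaluation would happen here
		})
	}
	s, err := e.wait()
	fmt.Println(s, err)
}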
2 changes: 1 addition & 1 deletion pkg/controller/executors/node.go
@@ -74,7 +74,7 @@ type Node interface {
	// - 1. It finds a blocking node (not ready, or running)
	// - 2. A node fails and hence the workflow will fail
	// - 3. The final/end node has completed and the workflow should be stopped
	RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error)
	RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error

	// This aborts the given node. If the given node is complete then it recursively finds the running nodes and aborts them
	AbortHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error
5 changes: 4 additions & 1 deletion pkg/controller/nodes/branch/handler.go
@@ -145,7 +145,10 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node
	if err != nil {
		return handler.UnknownTransition, err
	}
	downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode)
	if err := b.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dag, nCtx.ContextualNodeLookup(), branchTakenNode); err != nil {
		return handler.UnknownTransition, err
	}
	downstreamStatus, err := execContext.Wait()
	if err != nil {
		return handler.UnknownTransition, err
	}
5 changes: 4 additions & 1 deletion pkg/controller/nodes/dynamic/dynamic_workflow.go
@@ -267,7 +267,10 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflow(ctx context.Context, nC
func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup,
	nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) {

	state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode())
	if err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()); err != nil {
		return handler.UnknownTransition, prevState, err
	}
	state, err := execContext.Wait()
	if err != nil {
		return handler.UnknownTransition, prevState, err
	}
126 changes: 96 additions & 30 deletions pkg/controller/nodes/executor.go
@@ -822,6 +822,41 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur

// The search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from
// the currentNode. Visiting a node means invoking the RecursiveNodeHandler on it. A visit may be partial, complete or may result in a failure.
func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error {
	logger.Debugf(ctx, "Handling downstream Nodes")
	// This node is success. Handle all downstream nodes
	downstreamNodes, err := dag.FromNode(currentNode.GetID())
	if err != nil {
		logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err)
		/*return executors.NodeStatusFailed(&core.ExecutionError{
			Code: errors.BadSpecificationError,
			Message: fmt.Sprintf("failed to retrieve downstream nodes for [%s]", currentNode.GetID()),
			Kind: core.ExecutionError_SYSTEM,
		}), nil*/
		return errors2.Errorf("failed to retrieve downstream nodes for [%s]", currentNode.GetID())
	}
	for _, downstreamNodeName := range downstreamNodes {
		downstreamNode, ok := nl.GetNode(downstreamNodeName)
		if !ok {
			/*return executors.NodeStatusFailed(&core.ExecutionError{
				Code: errors.BadSpecificationError,
				Message: fmt.Sprintf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID()),
				Kind: core.ExecutionError_SYSTEM,
			}), nil*/
			return errors2.Errorf("failed to retrieve downstream node [%s] for [%s]", downstreamNodeName, currentNode.GetID())
		}

		if err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil {
			return err
		}
	}

	return nil
}

// TODO @hamersaw - move to nodeExecutor.wait
/*// The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from
// the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure.
func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) {
logger.Debugf(ctx, "Handling downstream Nodes")
// This node is success. Handle all downstream nodes
@@ -896,7 +931,7 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executo
}

return executors.NodeStatusPending, nil
}
}*/

func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) {
	startNode := dag.StartNode()
@@ -968,8 +1003,7 @@ func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.Executab
// nested within other nodes. The system follows an actor model, where the parent nodes control the execution of nested nodes
// The recursive node-handler uses a modified depth-first type of algorithm to execute non-blocked nodes.
func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext,
	dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (
	executors.NodeStatus, error) {
	dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error {

	currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID())
	nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID())
@@ -991,31 +1025,48 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe
		// This is an optimization to avoid creating the nodeContext object in case the node has already been looked at.
		// If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created
		if nodeStatus.IsDirty() {
			return executors.NodeStatusRunning, nil
			return nil
		}

		if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) {
			return executors.NodeStatusRunning, nil
			return nil
		}

		nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
		if err != nil {
			// NodeExecution creation failure is a permanent fail / system error.
			// Should a system failure always return an err?
			return executors.NodeStatusFailed(&core.ExecutionError{
				Code: "InternalError",
				Message: err.Error(),
				Kind: core.ExecutionError_SYSTEM,
			}), nil
		}
		execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) {
			nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
			if err != nil {
				// NodeExecution creation failure is a permanent fail / system error.
				// Should a system failure always return an err?
				nodeFuture <- executors.NodeExecutionResult{
					NodeStatus: executors.NodeStatusFailed(&core.ExecutionError{
						Code: "InternalError",
						Message: err.Error(),
						Kind: core.ExecutionError_SYSTEM,
					}),
					Err: err,
				}
				return
			}

		// Now depending on the node type decide
		h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind())
		if err != nil {
			return executors.NodeStatusUndefined, err
		}
			// Now depending on the node type decide
			h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind())
			if err != nil {
				nodeFuture <- executors.NodeExecutionResult{
					NodeStatus: executors.NodeStatusUndefined,
					Err: err,
				}
				return
			}

		return c.handleNode(currentNodeCtx, dag, nCtx, h)
			// TODO @hamersaw - remove
			//return c.handleNode(currentNodeCtx, dag, nCtx, h)
			nodeStatus, err := c.handleNode(currentNodeCtx, dag, nCtx, h)
			nodeFuture <- executors.NodeExecutionResult{
				NodeStatus: nodeStatus,
				Err: err,
			}
		})
		return nil

		// TODO we can optimize skip state handling by iterating down the graph and marking all as skipped
		// Currently we treat either Skip or Success the same way. In this approach only one node will be skipped
@@ -1025,23 +1076,38 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe
		return c.handleDownstream(ctx, execContext, dag, nl, currentNode)
	} else if nodePhase == v1alpha1.NodePhaseFailed {
		logger.Debugf(currentNodeCtx, "Node has failed, traversing downstream.")
		_, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode)
		if err != nil {
			return executors.NodeStatusUndefined, err
		if err := c.handleDownstream(ctx, execContext, dag, nl, currentNode); err != nil {
			return err
		}

		return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil
		// TODO @hamersaw - remove
		//return executors.NodeStatusFailed(nodeStatus.GetExecutionError()), nil
		execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) {
			nodeFuture <- executors.NodeExecutionResult{
				NodeStatus: executors.NodeStatusFailed(nodeStatus.GetExecutionError()),
			}
		})
		return nil
	} else if nodePhase == v1alpha1.NodePhaseTimedOut {
		logger.Debugf(currentNodeCtx, "Node has timed out, traversing downstream.")
		_, err := c.handleDownstream(ctx, execContext, dag, nl, currentNode)
		if err != nil {
			return executors.NodeStatusUndefined, err
		if err := c.handleDownstream(ctx, execContext, dag, nl, currentNode); err != nil {
			return err
		}

		return executors.NodeStatusTimedOut, nil
		// TODO @hamersaw - remove
		//return executors.NodeStatusTimedOut, nil
		execContext.AddNodeFuture(func(nodeFuture chan<- executors.NodeExecutionResult) {
			nodeFuture <- executors.NodeExecutionResult{
				NodeStatus: executors.NodeStatusTimedOut,
			}
		})
		return nil
	}

	return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(),
	// TODO @hamersaw - remove
	//return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(),
	//"Should never reach here. Current Phase: %v", nodePhase)
	return errors.Errorf(errors.IllegalStateError, currentNode.GetID(),
		"Should never reach here. Current Phase: %v", nodePhase)
}

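Taken together, the executor changes turn the traversal into a fan-out: handling a node schedules a goroutine, that goroutine handles the node and then schedules its downstream nodes, and the caller joins everything at a single wait point. A rough, self-contained sketch of that control-flow shape (it deliberately ignores node phases, upstream-dependency gating, and status aggregation; the toy DAG and the walker type are illustrative, not code from this PR):

package main

import (
	"fmt"
	"sync"
)

// A toy DAG: node -> downstream nodes.
var downstream = map[string][]string{
	"start": {"a", "b"},
	"a":     {"end"},
	"b":     {"end"},
}

// walker mimics the shape of the parallelized RecursiveNodeHandler: visiting a
// node schedules its evaluation on a goroutine, and that evaluation schedules
// the downstream nodes in turn.
type walker struct {
	wg      sync.WaitGroup
	mu      sync.Mutex
	visited map[string]bool
}

func (w *walker) visit(node string) {
	w.mu.Lock()
	if w.visited[node] {
		w.mu.Unlock()
		return // already scheduled by another parent
	}
	w.visited[node] = true
	w.mu.Unlock()

	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		fmt.Println("evaluating", node) // stand-in for handleNode
		for _, d := range downstream[node] {
			w.visit(d) // stand-in for handleDownstream -> RecursiveNodeHandler
		}
	}()
}

func main() {
	w := &walker{visited: map[string]bool{}}
	w.visit("start")
	w.wg.Wait() // stand-in for ExecutionContext.Wait
	fmt.Println("all nodes evaluated")
}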
10 changes: 8 additions & 2 deletions pkg/controller/nodes/subworkflow/subworkflow.go
@@ -71,7 +71,10 @@ func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler
		return handler.UnknownTransition, err
	}
	execContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo)
	state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.StartNode())
	if err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.StartNode()); err != nil {
		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
	}
	state, err := execContext.Wait()
	if err != nil {
		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
	}
@@ -150,7 +153,10 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context,
	if err != nil {
		return handler.UnknownTransition, err
	}
	state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode())
	if err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode()); err != nil {
		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
	}
	state, err := execContext.Wait()
	if err != nil {
		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
	}
11 changes: 9 additions & 2 deletions pkg/controller/workflow/executor.go
@@ -144,7 +144,11 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha
			Message: "Start node not found"}), nil
	}
	execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow())
	state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, startNode)
	err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, startNode)
	if err != nil {
		return StatusRunning, err
	}
	state, err := execcontext.Wait()
	if err != nil {
		return StatusRunning, err
	}
@@ -171,7 +175,10 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl
	execErr := executionErrorOrDefault(w.GetExecutionStatus().GetExecutionError(), w.GetExecutionStatus().GetMessage())
	errorNode := w.GetOnFailureNode()
	execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow())
	state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, errorNode)
	if err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, w, errorNode); err != nil {
		return StatusFailureNode(execErr), err
	}
	state, err := execcontext.Wait()
	if err != nil {
		return StatusFailureNode(execErr), err
	}
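The state that these call sites read back from Wait() follows the aggregation rules added in execution_context.go. Reduced to a pure function over the collected results, the default FAIL_IMMEDIATELY behaviour looks roughly like this (stand-in types; the FAIL_AFTER_EXECUTABLE_NODES_COMPLETE branch, which keeps draining and remembers the last failed state, is left out):

package main

import "fmt"

type status string

const (
	statusComplete status = "complete"
	statusSuccess  status = "success" // partial completion: propagate on the next round
	statusPending  status = "pending"
	statusFailed   status = "failed"
)

// nodeResult is a stand-in for the fields Wait inspects on each NodeStatus.
type nodeResult struct {
	failed   bool // HasFailed() || HasTimedOut()
	complete bool // IsComplete()
	partial  bool // PartiallyComplete()
}

// fold mirrors the aggregation in Wait for the FAIL_IMMEDIATELY policy:
// the first failure wins; otherwise complete only if every node is complete;
// otherwise success if anything partially completed; else pending.
func fold(results []nodeResult) status {
	allCompleted := true
	partial := false
	for _, r := range results {
		if r.failed {
			return statusFailed
		}
		if !r.complete {
			allCompleted = false
		}
		if r.partial {
			partial = true
		}
	}
	if allCompleted {
		return statusComplete
	}
	if partial {
		return statusSuccess
	}
	return statusPending
}

func main() {
	fmt.Println(fold([]nodeResult{{complete: true}, {complete: true}}))  // complete
	fmt.Println(fold([]nodeResult{{complete: true}, {partial: true}}))   // success
	fmt.Println(fold([]nodeResult{{complete: false}, {complete: true}})) // pending
	fmt.Println(fold([]nodeResult{{failed: true}, {complete: true}}))    // failed
}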