From 3fb61f388115e3a2add975e6aa644f434a05b00e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 15 Nov 2023 19:20:29 -0800 Subject: [PATCH] wip Signed-off-by: Kevin Su --- .../executors/failure_node_lookup.go | 29 +++++++++---------- .../nodes/subworkflow/subworkflow.go | 10 +++++-- .../pkg/controller/workflow/executor.go | 16 ++++------ 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/flytepropeller/pkg/controller/executors/failure_node_lookup.go b/flytepropeller/pkg/controller/executors/failure_node_lookup.go index 15777d5582..a517ba7ead 100644 --- a/flytepropeller/pkg/controller/executors/failure_node_lookup.go +++ b/flytepropeller/pkg/controller/executors/failure_node_lookup.go @@ -6,41 +6,38 @@ import ( ) type FailureNodeLookup struct { - NodeSpec *v1alpha1.NodeSpec - NodeStatus v1alpha1.ExecutableNodeStatus - StartNode v1alpha1.ExecutableNode - StartNodeStatus v1alpha1.ExecutableNodeStatus + NodeLookup + FailureNode v1alpha1.ExecutableNode + FailureNodeStatus v1alpha1.ExecutableNodeStatus } func (f FailureNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { if nodeID == v1alpha1.StartNodeID { - return f.StartNode, true + return f.NodeLookup.GetNode(nodeID) } - return f.NodeSpec, true + return f.FailureNode, true } func (f FailureNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { if id == v1alpha1.StartNodeID { - return f.StartNodeStatus + return f.NodeLookup.GetNodeExecutionStatus(ctx, id) } - return f.NodeStatus + return f.FailureNodeStatus } func (f FailureNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + // The upstream node of the failure node is always the start node return []v1alpha1.NodeID{v1alpha1.StartNodeID}, nil } func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { - return nil, nil + return []v1alpha1.NodeID{v1alpha1.EndNodeID}, nil } -func NewFailureNodeLookup(nodeSpec *v1alpha1.NodeSpec, startNode v1alpha1.ExecutableNode, nodeStatusGetter v1alpha1.NodeStatusGetter) NodeLookup { - startNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), v1alpha1.StartNodeID) - errNodeStatus := nodeStatusGetter.GetNodeExecutionStatus(context.TODO(), nodeSpec.GetID()) +func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus) NodeLookup { return FailureNodeLookup{ - NodeSpec: nodeSpec, - NodeStatus: errNodeStatus, - StartNode: startNode, - StartNodeStatus: startNodeStatus, + NodeLookup: nodeLookup, + FailureNode: failureNode, + FailureNodeStatus: failureNodeStatus, } } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go index 10e48358dd..6ea22b9637 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go @@ -143,12 +143,18 @@ func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx interfaces.No func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error - if subworkflow.GetOnFailureNode() != nil { + failureNode := subworkflow.GetOnFailureNode() + if failureNode != nil { execContext, err := s.getExecutionContextForDownstream(nCtx) if err != nil { return handler.UnknownTransition, err } - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode()) + subNodeLookup := nCtx.ContextualNodeLookup() + // TODO: NodeStatus() is deprecated, how do we get the status of the failure node? + failureNodeStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, failureNode.GetID()) + failureNodeLookup := executors.NewFailureNodeLookup(subNodeLookup, failureNode, failureNodeStatus) + + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, failureNodeLookup, failureNodeLookup, failureNode) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } diff --git a/flytepropeller/pkg/controller/workflow/executor.go b/flytepropeller/pkg/controller/workflow/executor.go index f782c30d0a..40fb31e46d 100644 --- a/flytepropeller/pkg/controller/workflow/executor.go +++ b/flytepropeller/pkg/controller/workflow/executor.go @@ -169,22 +169,18 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { execErr := executionErrorOrDefault(w.GetExecutionStatus().GetExecutionError(), w.GetExecutionStatus().GetMessage()) - errorNode := w.GetOnFailureNode() - logger.Infof(ctx, "Handling FailureNode [%v]", errorNode) + failureNode := w.GetOnFailureNode() + logger.Infof(ctx, "Handling FailureNode [%v]", failureNode.GetID()) execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - // TODO: GetNodeExecutionStatus doesn't work. How do we get the error node status from CRD - failureNodeLookup := executors.NewFailureNodeLookup(errorNode.(*v1alpha1.NodeSpec), w.GetNode(v1alpha1.StartNodeID), w.GetExecutionStatus()) - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, w, failureNodeLookup, errorNode) - logger.Infof(ctx, "FailureNode [%v] finished with state [%v]", errorNode, state) - logger.Infof(ctx, "FailureNode [%v] finished with error [%v]", errorNode, err) + failureNodeStatus := w.GetExecutionStatus().GetNodeExecutionStatus(ctx, failureNode.GetID()) + failureNodeLookup := executors.NewFailureNodeLookup(w, failureNode, failureNodeStatus) + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, execcontext, failureNodeLookup, failureNodeLookup, failureNode) if err != nil { - logger.Infof(ctx, "test") return StatusFailureNode(execErr), err } if state.HasFailed() { - logger.Infof(ctx, "test1 [%v]", state.Err) return StatusFailed(state.Err), nil } @@ -195,8 +191,6 @@ func (c *workflowExecutor) handleFailureNode(ctx context.Context, w *v1alpha1.Fl Message: "FailureNode Timed-out"}), nil } - logger.Infof(ctx, "test2") - if state.PartiallyComplete() { // Re-enqueue the workflow c.enqueueWorkflow(w.GetK8sWorkflowID().String())