Skip to content

Commit

Permalink
Add support failure node (#4308)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Ketan Umare <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
3 people authored and pvditt committed Dec 13, 2023
1 parent f99e8e9 commit e9f2cee
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 30 deletions.
2 changes: 1 addition & 1 deletion charts/flyte-core/values-eks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ configmap:
propeller:
resourcemanager:
type: noop
# Note: By default resource manager is disable for propeller, Please use `type: redis` to enaable
# Note: By default resource manager is disabled for propeller, Please use `type: redis` to enable
# type: redis
# resourceMaxQuota: 10000
# redis:
Expand Down
3 changes: 3 additions & 0 deletions flytepropeller/pkg/compiler/requirements.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ func updateWorkflowRequirements(workflow *core.WorkflowTemplate, subWfs common.W
for _, node := range workflow.Nodes {
updateNodeRequirements(node, subWfs, taskIds, workflowIds, followSubworkflows, errs)
}
if workflow.FailureNode != nil {
updateNodeRequirements(workflow.FailureNode, subWfs, taskIds, workflowIds, followSubworkflows, errs)
}
}

func updateNodeRequirements(node *flyteNode, subWfs common.WorkflowIndex, taskIds, workflowIds common.IdentifierSet,
Expand Down
6 changes: 6 additions & 0 deletions flytepropeller/pkg/compiler/workflow_compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ func (w workflowBuilder) ValidateWorkflow(fg *flyteWorkflow, errs errors.Compile
wf.AddEdges(n, c.EdgeDirectionBidirectional, errs.NewScope())
}

if fg.Template.FailureNode != nil {
failureNode := fg.Template.FailureNode
v.ValidateNode(&wf, wf.GetOrCreateNodeBuilder(failureNode), false, errs.NewScope())
wf.AddEdges(wf.GetOrCreateNodeBuilder(failureNode), c.EdgeDirectionUpstream, errs.NewScope())
}

// Add execution edges for orphan nodes that don't have any inward/outward edges.
for nodeID := range wf.Nodes {
if nodeID == c.StartNodeID || nodeID == c.EndNodeID {
Expand Down
84 changes: 84 additions & 0 deletions flytepropeller/pkg/compiler/workflow_compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,90 @@ func ExampleCompileWorkflow_basic() {
// Compile Errors: <nil>
}

func TestCompileWorkflowWithFailureNode(t *testing.T) {
inputWorkflow := &core.WorkflowTemplate{
Id: &core.Identifier{Name: "repo"},
Interface: &core.TypedInterface{
Inputs: createEmptyVariableMap(),
Outputs: createEmptyVariableMap(),
},
Nodes: []*core.Node{
{
Id: "FirstNode",
Target: &core.Node_TaskNode{
TaskNode: &core.TaskNode{
Reference: &core.TaskNode_ReferenceId{
ReferenceId: &core.Identifier{Name: "task_123"},
},
},
},
},
},
FailureNode: &core.Node{
Id: "FailureNode",
Target: &core.Node_TaskNode{
TaskNode: &core.TaskNode{
Reference: &core.TaskNode_ReferenceId{
ReferenceId: &core.Identifier{Name: "cleanup"},
},
},
},
},
}

// Detect what other workflows/tasks does this coreWorkflow reference
subWorkflows := make([]*core.WorkflowTemplate, 0)
reqs, err := GetRequirements(inputWorkflow, subWorkflows)
assert.Nil(t, err)
assert.Equal(t, reqs.taskIds, []common.Identifier{{Name: "cleanup"}, {Name: "task_123"}})

// Replace with logic to satisfy the requirements
workflows := make([]common.InterfaceProvider, 0)
tasks := []*core.TaskTemplate{
{
Id: &core.Identifier{Name: "task_123"},
Interface: &core.TypedInterface{
Inputs: createEmptyVariableMap(),
Outputs: createEmptyVariableMap(),
},
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: "image://",
Command: []string{"cmd"},
Args: []string{"args"},
},
},
},
{
Id: &core.Identifier{Name: "cleanup"},
Interface: &core.TypedInterface{
Inputs: createEmptyVariableMap(),
Outputs: createEmptyVariableMap(),
},
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: "image://",
Command: []string{"cmd"},
Args: []string{"args"},
},
},
},
}

compiledTasks := make([]*core.CompiledTask, 0, len(tasks))
for _, task := range tasks {
compiledTask, err := CompileTask(task)
assert.Nil(t, err)

compiledTasks = append(compiledTasks, compiledTask)
}

output, errs := CompileWorkflow(inputWorkflow, subWorkflows, compiledTasks, workflows)
assert.Equal(t, output.Primary.Template.FailureNode.Id, "FailureNode")
assert.NotNil(t, output.Primary.Template.FailureNode.GetTaskNode())
assert.Nil(t, errs)
}

func ExampleCompileWorkflow_inputsOutputsBinding() {
inputWorkflow := &core.WorkflowTemplate{
Id: &core.Identifier{Name: "repo"},
Expand Down
44 changes: 44 additions & 0 deletions flytepropeller/pkg/controller/executors/failure_node_lookup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package executors

import (
"context"

"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
)

type FailureNodeLookup struct {
NodeLookup
FailureNode v1alpha1.ExecutableNode
FailureNodeStatus v1alpha1.ExecutableNodeStatus
}

func (f FailureNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) {
if nodeID == v1alpha1.StartNodeID {
return f.NodeLookup.GetNode(nodeID)
}
return f.FailureNode, true
}

func (f FailureNodeLookup) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus {
if id == v1alpha1.StartNodeID {
return f.NodeLookup.GetNodeExecutionStatus(ctx, id)
}
return f.FailureNodeStatus
}

func (f FailureNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
// The upstream node of the failure node is always the start node
return []v1alpha1.NodeID{v1alpha1.StartNodeID}, nil
}

func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
return nil, nil
}

func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus) NodeLookup {
return FailureNodeLookup{
NodeLookup: nodeLookup,
FailureNode: failureNode,
FailureNodeStatus: failureNodeStatus,
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package executors

import (
"context"
"testing"

"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks"
"github.com/stretchr/testify/assert"
)

type nl struct {
NodeLookup
}

type en struct {
v1alpha1.ExecutableNode
}

type ns struct {
v1alpha1.ExecutableNodeStatus
}

func TestNewFailureNodeLookup(t *testing.T) {
nl := nl{}
en := en{}
ns := ns{}
nodeLoopUp := NewFailureNodeLookup(nl, en, ns)
assert.NotNil(t, nl)
typed := nodeLoopUp.(FailureNodeLookup)
assert.Equal(t, nl, typed.NodeLookup)
assert.Equal(t, en, typed.FailureNode)
assert.Equal(t, ns, typed.FailureNodeStatus)
}

func TestNewTestFailureNodeLookup(t *testing.T) {
n := &mocks.ExecutableNode{}
ns := &mocks.ExecutableNodeStatus{}
failureNodeID := "fn1"
nl := NewTestNodeLookup(
map[string]v1alpha1.ExecutableNode{v1alpha1.StartNodeID: n, failureNodeID: n},
map[string]v1alpha1.ExecutableNodeStatus{v1alpha1.StartNodeID: ns, failureNodeID: ns},
)

assert.NotNil(t, nl)

failureNodeLookup := NewFailureNodeLookup(nl, n, ns)
r, ok := failureNodeLookup.GetNode(v1alpha1.StartNodeID)
assert.True(t, ok)
assert.Equal(t, n, r)
assert.Equal(t, ns, failureNodeLookup.GetNodeExecutionStatus(context.TODO(), v1alpha1.StartNodeID))

r, ok = failureNodeLookup.GetNode(failureNodeID)
assert.True(t, ok)
assert.Equal(t, n, r)
assert.Equal(t, ns, failureNodeLookup.GetNodeExecutionStatus(context.TODO(), failureNodeID))

nodeIDs, err := failureNodeLookup.ToNode(failureNodeID)
assert.Equal(t, len(nodeIDs), 1)
assert.Equal(t, nodeIDs[0], v1alpha1.StartNodeID)
assert.Nil(t, err)

nodeIDs, err = failureNodeLookup.FromNode(failureNodeID)
assert.Nil(t, nodeIDs)
assert.Nil(t, err)
}
2 changes: 1 addition & 1 deletion flytepropeller/pkg/controller/nodes/predicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func CanExecute(ctx context.Context, dag executors.DAGStructure, nl executors.No

upstreamNodes, err := dag.ToNode(nodeID)
if err != nil {
return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node")
return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node {%v}", nodeID)
}

skipped := false
Expand Down
10 changes: 5 additions & 5 deletions flytepropeller/pkg/controller/nodes/subworkflow/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeEx
errors.BadSpecificationError, errMsg, nil)), nil
}

updateNodeStateFn := func(transition handler.Transition, newPhase v1alpha1.WorkflowNodePhase, err error) (handler.Transition, error) {
updateNodeStateFn := func(transition handler.Transition, workflowNodeState handler.WorkflowNodeState, err error) (handler.Transition, error) {
if err != nil {
return transition, err
}

workflowNodeState := handler.WorkflowNodeState{Phase: newPhase}
err = nCtx.NodeStateWriter().PutWorkflowNodeState(workflowNodeState)
if err != nil {
logger.Errorf(ctx, "Failed to store WorkflowNodeState, err :%s", err.Error())
Expand All @@ -75,10 +74,10 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeEx

if wfNode.GetSubWorkflowRef() != nil {
trns, err := w.subWfHandler.StartSubWorkflow(ctx, nCtx)
return updateNodeStateFn(trns, v1alpha1.WorkflowNodePhaseExecuting, err)
return updateNodeStateFn(trns, handler.WorkflowNodeState{Phase: v1alpha1.WorkflowNodePhaseExecuting}, err)
} else if wfNode.GetLaunchPlanRefID() != nil {
trns, err := w.lpHandler.StartLaunchPlan(ctx, nCtx)
return updateNodeStateFn(trns, v1alpha1.WorkflowNodePhaseExecuting, err)
return updateNodeStateFn(trns, handler.WorkflowNodeState{Phase: v1alpha1.WorkflowNodePhaseExecuting}, err)
}

return invalidWFNodeError()
Expand All @@ -95,8 +94,9 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeEx
}

if wfNode.GetSubWorkflowRef() != nil {
originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error
trns, err := w.subWfHandler.HandleFailingSubWorkflow(ctx, nCtx)
return updateNodeStateFn(trns, workflowPhase, err)
return updateNodeStateFn(trns, handler.WorkflowNodeState{Phase: workflowPhase, Error: originalError}, err)
} else if wfNode.GetLaunchPlanRefID() != nil {
// There is no failure node for launch plans, terminate immediately.
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailureErr(wfNodeState.Error, nil)), nil
Expand Down
16 changes: 10 additions & 6 deletions flytepropeller/pkg/controller/nodes/subworkflow/subworkflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx
func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) {
// The current node would end up becoming the parent for the sub workflow nodes.
// This is done to track the lineage. For level zero, the CreateParentInfo will return nil
newParentInfo, err := common.CreateParentInfo(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID(), nCtx.CurrentAttempt())
execContext, err := s.getExecutionContextForDownstream(nCtx)
if err != nil {
return handler.UnknownTransition, err
}
execContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo)
state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.StartNode())
if err != nil {
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
Expand Down Expand Up @@ -143,17 +142,22 @@ func (s *subworkflowHandler) getExecutionContextForDownstream(nCtx interfaces.No

func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx interfaces.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) {
originalError := nCtx.NodeStateReader().GetWorkflowNodeState().Error
if subworkflow.GetOnFailureNode() != nil {
if failureNode := subworkflow.GetOnFailureNode(); failureNode != nil {
execContext, err := s.getExecutionContextForDownstream(nCtx)
if err != nil {
return handler.UnknownTransition, err
}
state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, subworkflow, nl, subworkflow.GetOnFailureNode())
status := nCtx.NodeStatus()
subworkflowNodeLookup := executors.NewNodeLookup(subworkflow, status, subworkflow)
failureNodeStatus := status.GetNodeExecutionStatus(ctx, failureNode.GetID())
failureNodeLookup := executors.NewFailureNodeLookup(subworkflowNodeLookup, failureNode, failureNodeStatus)

state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, execContext, failureNodeLookup, failureNodeLookup, failureNode)
if err != nil {
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err
}

if state.NodePhase == interfaces.NodePhaseRunning {
if state.NodePhase == interfaces.NodePhaseQueued || state.NodePhase == interfaces.NodePhaseRunning {
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil
}

Expand All @@ -168,7 +172,7 @@ func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context,
return handler.UnknownTransition, err
}

return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailingErr(originalError, nil)), nil
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil
}

// When handling the failure node succeeds, the final status will still be failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ func TestGetSubWorkflow(t *testing.T) {
assert.Equal(t, swf, w)
})

t.Run("subworkflow with failure node", func(t *testing.T) {

wfNode := &coreMocks.ExecutableWorkflowNode{}
x := "x"
wfNode.OnGetSubWorkflowRef().Return(&x)

node := &coreMocks.ExecutableNode{}
node.OnGetWorkflowNode().Return(wfNode)

ectx := &execMocks.ExecutionContext{}

wfFailureNode := &coreMocks.ExecutableWorkflowNode{}
y := "y"
wfFailureNode.OnGetSubWorkflowRef().Return(&y)
failureNode := &coreMocks.ExecutableNode{}
failureNode.OnGetWorkflowNode().Return(wfFailureNode)

swf := &coreMocks.ExecutableSubWorkflow{}
swf.OnGetOnFailureNode().Return(failureNode)
ectx.OnFindSubWorkflow("x").Return(swf)

nCtx := &mocks.NodeExecutionContext{}
nCtx.OnNode().Return(node)
nCtx.OnExecutionContext().Return(ectx)

w, err := GetSubWorkflow(ctx, nCtx)
assert.NoError(t, err)
assert.Equal(t, swf, w)
})

t.Run("missing-subworkflow", func(t *testing.T) {

wfNode := &coreMocks.ExecutableWorkflowNode{}
Expand Down
Loading

0 comments on commit e9f2cee

Please sign in to comment.