Skip to content

Commit

Permalink
Change branch compilation to avoid adding downstream deps (flyteorg#264)
Browse files Browse the repository at this point in the history
* Change branch compilation to avoid adding downstream deps

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Fix deepcopy generated bug

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Add checks to make sure compiled workflows remain the same

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Additional unit test to test branch spec generation

Signed-off-by: Ketan Umare <[email protected]>

* Fix compiler unit tests

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Fixed unit test

Signed-off-by: Ketan Umare <[email protected]>

* Fix one unit test and shorten json tags for connections

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Use GetConnections() instead of DeprecatedConnections

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Put back branch node upstream check in CanExecute predicate for backward compatibility

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* lint

Signed-off-by: Haytham Abuelfutuh <[email protected]>

Co-authored-by: Ketan Umare <[email protected]>
  • Loading branch information
EngHabu and kumare3 authored May 28, 2021
1 parent f48fdff commit f06e98d
Show file tree
Hide file tree
Showing 45 changed files with 1,304 additions and 256 deletions.
41 changes: 35 additions & 6 deletions pkg/apis/flyteworkflow/v1alpha1/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,13 @@ func (in *Inputs) DeepCopyInto(out *Inputs) {
// Once we figure out the autogenerate story we can replace this
}

type Connections struct {
// Deprecated: Please use Connections instead
type DeprecatedConnections struct {
DownstreamEdges map[NodeID][]NodeID
UpstreamEdges map[NodeID][]NodeID
}

func (in *Connections) UnmarshalJSON(b []byte) error {
func (in *DeprecatedConnections) UnmarshalJSON(b []byte) error {
in.DownstreamEdges = map[NodeID][]NodeID{}
err := json.Unmarshal(b, &in.DownstreamEdges)
if err != nil {
Expand All @@ -204,10 +205,23 @@ func (in *Connections) UnmarshalJSON(b []byte) error {
return nil
}

func (in *Connections) MarshalJSON() ([]byte, error) {
func (in *DeprecatedConnections) MarshalJSON() ([]byte, error) {
return json.Marshal(in.DownstreamEdges)
}

// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DeprecatedConnections) DeepCopyInto(out *DeprecatedConnections) {
*out = *in
// We do not manipulate the object, so its ok
// Once we figure out the autogenerate story we can replace this
}

// Connections keep track of downstream and upstream dependencies (including data and execution dependencies).
type Connections struct {
Downstream map[NodeID][]NodeID `json:"downstream"`
Upstream map[NodeID][]NodeID `json:"upstream"`
}

// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Connections) DeepCopyInto(out *Connections) {
*out = *in
Expand All @@ -223,7 +237,13 @@ type WorkflowSpec struct {
// Defines the set of connections (both data dependencies and execution dependencies) that the graph is
// formed of. The execution engine will respect and follow these connections as it determines which nodes
// can and should be executed.
Connections Connections `json:"connections"`
// Deprecated: Please use Connections
DeprecatedConnections DeprecatedConnections `json:"connections"`

// Defines the set of connections (both data dependencies and execution dependencies) that the graph is
// formed of. The execution engine will respect and follow these connections as it determines which nodes
// can and should be executed.
Connections Connections `json:"edges"`

// Defines a single node to execute in case the system determined the Workflow has failed.
OnFailure *NodeSpec `json:"onFailure,omitempty"`
Expand Down Expand Up @@ -257,7 +277,7 @@ func (in *WorkflowSpec) ToNode(name NodeID) ([]NodeID, error) {
if _, ok := in.Nodes[name]; !ok {
return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID)
}
upstreamNodes := in.Connections.UpstreamEdges[name]
upstreamNodes := in.GetConnections().Upstream[name]
return upstreamNodes, nil
}

Expand All @@ -266,7 +286,7 @@ func (in *WorkflowSpec) FromNode(name NodeID) ([]NodeID, error) {
return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID)
}

downstreamNodes := in.Connections.DownstreamEdges[name]
downstreamNodes := in.GetConnections().Downstream[name]
return downstreamNodes, nil
}

Expand All @@ -284,6 +304,15 @@ func (in *WorkflowSpec) GetNode(nodeID NodeID) (ExecutableNode, bool) {
}

func (in *WorkflowSpec) GetConnections() *Connections {
// For backward compatibility, if the new Connections field is not yet populated then copy the connections from the
// deprecated field. This will happen in one of two cases:
// 1. If an old Admin generated the CRD
// 2. If new propeller is deployed and is unmarshalling an old CRD.
if len(in.Connections.Upstream) == 0 && len(in.Connections.Downstream) == 0 {
in.Connections.Upstream = in.DeprecatedConnections.UpstreamEdges
in.Connections.Downstream = in.DeprecatedConnections.DownstreamEdges
}

return &in.Connections
}

Expand Down
12 changes: 7 additions & 5 deletions pkg/apis/flyteworkflow/v1alpha1/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
func TestMarshalUnmarshal_Connections(t *testing.T) {
r, err := ioutil.ReadFile("testdata/connections.json")
assert.NoError(t, err)
o := v1alpha1.Connections{}
o := v1alpha1.DeprecatedConnections{}
err = json.Unmarshal(r, &o)
assert.NoError(t, err)
assert.Equal(t, map[v1alpha1.NodeID][]v1alpha1.NodeID{
Expand Down Expand Up @@ -42,10 +42,12 @@ func TestWorkflowSpec(t *testing.T) {
assert.NoError(t, err)
w := &v1alpha1.FlyteWorkflow{}
err = json.Unmarshal(j, w)
assert.NoError(t, err)
if !assert.NoError(t, err) {
t.FailNow()
}

assert.NotNil(t, w.WorkflowSpec)
assert.Nil(t, w.GetOnFailureNode())
assert.Equal(t, 7, len(w.Connections.DownstreamEdges))
assert.Equal(t, 8, len(w.Connections.UpstreamEdges))

assert.Equal(t, 7, len(w.GetConnections().Downstream))
assert.Equal(t, 8, len(w.GetConnections().Upstream))
}
13 changes: 13 additions & 0 deletions pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/compiler/common/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type WorkflowBuilder interface {
Workflow
StoreCompiledSubWorkflow(id WorkflowID, compiledWorkflow *core.CompiledWorkflow)
AddExecutionEdge(nodeFrom, nodeTo NodeID)
AddUpstreamEdge(nodeProvider, nodeDependent NodeID)
AddDownstreamEdge(nodeProvider, nodeDependent NodeID)
AddNode(n NodeBuilder, errs errors.CompileErrors) (node NodeBuilder, ok bool)
ValidateWorkflow(fg *core.CompiledWorkflow, errs errors.CompileErrors) (Workflow, bool)
NewNodeBuilder(n *core.Node) NodeBuilder
Expand Down
10 changes: 10 additions & 0 deletions pkg/compiler/common/mocks/workflow_builder.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions pkg/compiler/test/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"testing"

"github.com/go-test/deep"

"github.com/ghodss/yaml"

"github.com/flyteorg/flyteidl/clients/go/coreutils"
Expand Down Expand Up @@ -184,6 +186,10 @@ func TestBranches(t *testing.T) {
errors.SetConfig(errors.Config{IncludeSource: true})
assert.NoError(t, filepath.Walk("testdata/branch", func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
if filepath.Base(info.Name()) != "branch" {
return filepath.SkipDir
}

return nil
}

Expand All @@ -209,6 +215,29 @@ func TestBranches(t *testing.T) {
t.FailNow()
}

marshaler := jsonpb.Marshaler{}
rawStr, err := marshaler.MarshalToString(compiledWfc)
if !assert.NoError(t, err) {
t.Fail()
}

compiledFilePath := filepath.Join(filepath.Dir(path), "compiled", filepath.Base(path))
if *update {
err = ioutil.WriteFile(compiledFilePath, []byte(rawStr), os.ModePerm)
if !assert.NoError(t, err) {
t.Fail()
}
} else {
goldenRaw, err := ioutil.ReadFile(compiledFilePath)
if !assert.NoError(t, err) {
t.Fail()
}

if diff := deep.Equal(rawStr, string(goldenRaw)); diff != nil {
t.Errorf("Compiled() Diff = %v\r\n got = %v\r\n want = %v", diff, rawStr, string(goldenRaw))
}
}

inputs := map[string]interface{}{}
for varName, v := range compiledWfc.Primary.Template.Interface.Inputs.Variables {
inputs[varName] = coreutils.MustMakeDefaultLiteralForType(v.Type)
Expand All @@ -227,6 +256,24 @@ func TestBranches(t *testing.T) {
if assert.NoError(t, err) {
assert.NotEmpty(t, raw)
}

k8sObjectFilepath := filepath.Join(filepath.Dir(path), "k8s", filepath.Base(path))
if *update {
err = ioutil.WriteFile(k8sObjectFilepath, raw, os.ModePerm)
if !assert.NoError(t, err) {
t.Fail()
}
} else {
goldenRaw, err := ioutil.ReadFile(k8sObjectFilepath)
if !assert.NoError(t, err) {
t.Fail()
}

if diff := deep.Equal(string(raw), string(goldenRaw)); diff != nil {
t.Errorf("K8sObject() Diff = %v\r\n got = %v\r\n want = %v", diff, rawStr, string(goldenRaw))
}

}
}
})

Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
literals: {}
{}
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ tasks:
command:
- pyflyte-execute
config:
- key: testKey1
value: testValue1
- key: testKey2
value: testValue2
- key: testKey3
value: testValue3
- key: testKey1
value: testValue1
- key: testKey2
Expand Down
1 change: 1 addition & 0 deletions pkg/compiler/test/testdata/branch/compiled/success_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"primary":{"template":{"id":{"resourceType":"WORKFLOW","project":"project","domain":"domain","name":"test_serialization.my_wf","version":"version"},"metadata":{},"interface":{"inputs":{"variables":{"a":{"type":{"simple":"INTEGER"},"description":"a"}}},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"nodes":[{"id":"start-node"},{"id":"end-node","inputs":[{"var":"out_0","binding":{"promise":{"nodeId":"node-1","var":"out_0"}}}]},{"id":"node-0","metadata":{"name":"test_serialization.t3","retries":{},"interruptible":false},"inputs":[{"var":"a","binding":{"promise":{"nodeId":"start-node","var":"a"}}}],"taskNode":{"referenceId":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t3","version":"version"}}},{"id":"node-1","metadata":{"name":"test1","retries":{},"interruptible":false},"inputs":[{"var":"out_0","binding":{"promise":{"nodeId":"node-0","var":"out_0"}}}],"upstreamNodeIds":["node-0"],"branchNode":{"ifElse":{"case":{"condition":{"comparison":{"leftValue":{"var":"out_0"},"rightValue":{"primitive":{"integer":"1"}}}},"thenNode":{"id":"node-1-branchnode-0","metadata":{"name":"test_serialization.t2","retries":{},"interruptible":false},"taskNode":{"referenceId":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t2","version":"version"}}}},"error":{"failedNodeId":"test1","message":"Unable to choose branch"}}}}],"outputs":[{"var":"out_0","binding":{"promise":{"nodeId":"node-1","var":"out_0"}}}],"metadataDefaults":{}},"connections":{"downstream":{"node-0":{"ids":["node-1"]},"node-1":{"ids":["end-node"]},"start-node":{"ids":["node-0"]}},"upstream":{"end-node":{"ids":["node-1"]},"node-0":{"ids":["start-node"]},"node-1":{"ids":["node-0"]}}}},"tasks":[{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t1","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t1","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}},{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t2","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t2","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}},{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t3","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{"variables":{"a":{"type":{"simple":"INTEGER"},"description":"a"}}},"outputs":{"variables":{"out_0":{"type":{"simple":"INTEGER"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t3","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}}]}
1 change: 1 addition & 0 deletions pkg/compiler/test/testdata/branch/compiled/success_2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"primary":{"template":{"id":{"resourceType":"WORKFLOW","project":"project","domain":"domain","name":"test_serialization.my_wf","version":"version"},"metadata":{},"interface":{"inputs":{"variables":{"a":{"type":{"simple":"INTEGER"},"description":"a"}}},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"nodes":[{"id":"start-node"},{"id":"end-node","inputs":[{"var":"out_0","binding":{"promise":{"nodeId":"node-2","var":"out_0"}}}]},{"id":"node-0","metadata":{"name":"flytekit.annotated.task.test_serialization.t3","retries":{},"interruptible":false},"inputs":[{"var":"a","binding":{"promise":{"nodeId":"start-node","var":"a"}}}],"taskNode":{"referenceId":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t3","version":"version"}}},{"id":"node-1","metadata":{"name":"flytekit.annotated.task.test_serialization.t3","retries":{},"interruptible":false},"inputs":[{"var":"a","binding":{"promise":{"nodeId":"start-node","var":"a"}}}],"taskNode":{"referenceId":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t3","version":"version"}}},{"id":"node-2","metadata":{"name":"test1","retries":{},"interruptible":false},"inputs":[{"var":"node-0.out_0","binding":{"promise":{"nodeId":"node-0","var":"out_0"}}},{"var":"node-1.out_0","binding":{"promise":{"nodeId":"node-1","var":"out_0"}}}],"upstreamNodeIds":["node-0","node-1"],"branchNode":{"ifElse":{"case":{"condition":{"comparison":{"leftValue":{"var":"node-0.out_0"},"rightValue":{"var":"node-1.out_0"}}},"thenNode":{"id":"node-2-branchnode-0","metadata":{"name":"flytekit.annotated.task.test_serialization.t2","retries":{},"interruptible":false},"taskNode":{"referenceId":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t2","version":"version"}}}},"error":{"failedNodeId":"test1","message":"Unable to choose branch"}}}}],"outputs":[{"var":"out_0","binding":{"promise":{"nodeId":"node-2","var":"out_0"}}}],"metadataDefaults":{}},"connections":{"downstream":{"node-0":{"ids":["node-2"]},"node-1":{"ids":["node-2"]},"node-2":{"ids":["end-node"]},"start-node":{"ids":["node-0","node-1"]}},"upstream":{"end-node":{"ids":["node-2"]},"node-0":{"ids":["start-node"]},"node-1":{"ids":["start-node"]},"node-2":{"ids":["node-0","node-1"]}}}},"tasks":[{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t1","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t1","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}},{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t2","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{},"outputs":{"variables":{"out_0":{"type":{"simple":"STRING"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t2","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}},{"template":{"id":{"resourceType":"TASK","project":"project","domain":"domain","name":"test_serialization.t3","version":"version"},"type":"python-task","metadata":{"runtime":{"type":"FLYTE_SDK","version":"1.2.3","flavor":"python"},"retries":{},"interruptible":false},"interface":{"inputs":{"variables":{"a":{"type":{"simple":"INTEGER"},"description":"a"}}},"outputs":{"variables":{"out_0":{"type":{"simple":"INTEGER"},"description":"out_0"}}}},"container":{"image":"image","args":["pyflyte-execute","--task-module","test_serialization","--task-name","t3","--inputs","{{.input}}","--output-prefix","{{.outputPrefix}}","--raw-output-data-prefix","{{.rawOutputDataPrefix}}"],"resources":{},"config":[{"key":"testKey1","value":"testValue1"},{"key":"testKey2","value":"testValue2"},{"key":"testKey3","value":"testValue3"}]}}}]}
Loading

0 comments on commit f06e98d

Please sign in to comment.