From f9e165b6e2c0e2d2efecaff3c9423ce8a4982265 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Mon, 11 Apr 2022 16:30:38 -0500 Subject: [PATCH] Supporting interruptible for map tasks (#415) * added GetInterruptibleFailureThreshold function to nodeExecMetadata Signed-off-by: Daniel Rammer * fixed unit tests Signed-off-by: Daniel Rammer * fixed lint issue Signed-off-by: Daniel Rammer * Update to released flyteplugins Signed-off-by: Haytham Abuelfutuh Co-authored-by: Haytham Abuelfutuh --- flytepropeller/go.mod | 4 +-- flytepropeller/go.sum | 9 +++--- .../handler/mocks/node_execution_metadata.go | 32 +++++++++++++++++++ .../nodes/handler/node_exec_context.go | 1 + .../pkg/controller/nodes/node_exec_context.go | 17 +++++++--- .../nodes/node_exec_context_test.go | 2 +- 6 files changed, 52 insertions(+), 13 deletions(-) diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 17332f0169..4443dbabca 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -6,8 +6,8 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 - github.com/flyteorg/flyteidl v0.24.17 - github.com/flyteorg/flyteplugins v0.10.19 + github.com/flyteorg/flyteidl v0.24.19 + github.com/flyteorg/flyteplugins v0.10.23 github.com/flyteorg/flytestdlib v0.4.13 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 5401a8f887..5b62c2c856 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -236,11 +236,10 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.24.7/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteidl v0.24.17 h1:Xx70bJbuQGyvS8uAyU4AN74rot6KnzJ9r/L9gcCdEsU= -github.com/flyteorg/flyteidl v0.24.17/go.mod h1:vHSugApgS3hRITIafzQDU8DZD/W8wFRfFcgaFU35Dww= -github.com/flyteorg/flyteplugins v0.10.19 h1:9fY3aYXfjVR8jyb4omdWu9RW2FwcmAnld9PHnR0BLW8= -github.com/flyteorg/flyteplugins v0.10.19/go.mod h1:C2va2hfD7mBi24dXRhBi0GIKG4dzFhSR27GsCCFDzss= +github.com/flyteorg/flyteidl v0.24.19 h1:9PR0UVe2atWLot0X6dgyiXTMKbut28LJYl4HrcMHl7E= +github.com/flyteorg/flyteidl v0.24.19/go.mod h1:vHSugApgS3hRITIafzQDU8DZD/W8wFRfFcgaFU35Dww= +github.com/flyteorg/flyteplugins v0.10.23 h1:vRTcw+B9bjiCyVsdV6rDuTX4E9JMOy8ZEf9M71fKkeg= +github.com/flyteorg/flyteplugins v0.10.23/go.mod h1:12hTsHaGNKU9BVpTGcxtiL+Zrf5sfDXiDDsPvEO40CQ= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/flyteorg/flytestdlib v0.4.13 h1:TzgqhECRGfOHYH1A7rUwcKEEH2rTtPxGy+oYcif7iBw= github.com/flyteorg/flytestdlib v0.4.13/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs= diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go b/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go index c6dcc65b7f..a8980e0408 100644 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go +++ b/flytepropeller/pkg/controller/nodes/handler/mocks/node_execution_metadata.go @@ -51,6 +51,38 @@ func (_m *NodeExecutionMetadata) GetAnnotations() map[string]string { return r0 } +type NodeExecutionMetadata_GetInterruptibleFailureThreshold struct { + *mock.Call +} + +func (_m NodeExecutionMetadata_GetInterruptibleFailureThreshold) Return(_a0 uint32) *NodeExecutionMetadata_GetInterruptibleFailureThreshold { + return &NodeExecutionMetadata_GetInterruptibleFailureThreshold{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutionMetadata) OnGetInterruptibleFailureThreshold() *NodeExecutionMetadata_GetInterruptibleFailureThreshold { + c := _m.On("GetInterruptibleFailureThreshold") + return &NodeExecutionMetadata_GetInterruptibleFailureThreshold{Call: c} +} + +func (_m *NodeExecutionMetadata) OnGetInterruptibleFailureThresholdMatch(matchers ...interface{}) *NodeExecutionMetadata_GetInterruptibleFailureThreshold { + c := _m.On("GetInterruptibleFailureThreshold", matchers...) + return &NodeExecutionMetadata_GetInterruptibleFailureThreshold{Call: c} +} + +// GetInterruptibleFailureThreshold provides a mock function with given fields: +func (_m *NodeExecutionMetadata) GetInterruptibleFailureThreshold() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + type NodeExecutionMetadata_GetK8sServiceAccount struct { *mock.Call } diff --git a/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go b/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go index 95178c48bd..117358dabc 100644 --- a/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/handler/node_exec_context.go @@ -39,6 +39,7 @@ type NodeExecutionMetadata interface { GetK8sServiceAccount() string GetSecurityContext() core.SecurityContext IsInterruptible() bool + GetInterruptibleFailureThreshold() uint32 } type NodeExecutionContext interface { diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context.go b/flytepropeller/pkg/controller/nodes/node_exec_context.go index fb89eaa433..ebd1b40caa 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context.go @@ -25,9 +25,10 @@ const NodeInterruptibleLabel = "interruptible" type nodeExecMetadata struct { v1alpha1.Meta - nodeExecID *core.NodeExecutionIdentifier - interrutptible bool - nodeLabels map[string]string + nodeExecID *core.NodeExecutionIdentifier + interrutptible bool + interruptibleFailureThreshold uint32 + nodeLabels map[string]string } func (e nodeExecMetadata) GetNodeExecutionID() *core.NodeExecutionIdentifier { @@ -46,6 +47,10 @@ func (e nodeExecMetadata) IsInterruptible() bool { return e.interrutptible } +func (e nodeExecMetadata) GetInterruptibleFailureThreshold() uint32 { + return e.interruptibleFailureThreshold +} + func (e nodeExecMetadata) GetLabels() map[string]string { return e.nodeLabels } @@ -136,7 +141,7 @@ func (e nodeExecContext) MaxDatasetSizeBytes() int64 { } func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, - node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, + node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, interruptibleFailureThreshold uint32, maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { @@ -146,7 +151,8 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext NodeId: node.GetID(), ExecutionId: execContext.GetExecutionID().WorkflowExecutionIdentifier, }, - interrutptible: interruptible, + interrutptible: interruptible, + interruptibleFailureThreshold: interruptibleFailureThreshold, } // Copy the wf labels before adding node specific labels. @@ -235,6 +241,7 @@ func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNod ), ), interruptible, + c.interruptibleFailureThreshold, c.maxDatasetSizeBytes, &taskEventRecorder{TaskEventRecorder: c.taskRecorder}, tr, diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context_test.go b/flytepropeller/pkg/controller/nodes/node_exec_context_test.go index 31e6dfe791..b675b476ab 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context_test.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context_test.go @@ -57,7 +57,7 @@ func Test_NodeContext(t *testing.T) { s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) p := parentInfo{} execContext := executors.NewExecutionContext(w1, nil, nil, p, nil) - nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, n, nil, nil, false, 0, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) + nCtx := newNodeExecContext(context.TODO(), s, execContext, w1, n, nil, nil, false, 0, 2, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) assert.Equal(t, "id", nCtx.NodeExecutionMetadata().GetLabels()["node-id"]) assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) assert.Equal(t, "task-name", nCtx.NodeExecutionMetadata().GetLabels()["task-name"])