diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
index e3f19a8edb..08f4fd7c99 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
@@ -436,6 +436,7 @@ type ExecutableNode interface {
 	GetOutputAlias() []Alias
 	GetInputBindings() []*Binding
 	GetResources() *v1.ResourceRequirements
+	GetResourceExtensions() *core.ResourceExtensions
 	GetConfig() *v1.ConfigMap
 	GetRetryStrategy() *RetryStrategy
 	GetExecutionDeadline() *time.Duration
diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go
index 901148e820..1c03f6be70 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go
@@ -3,10 +3,11 @@
 package mocks
 
 import (
-	time "time"
-
+	core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
 	mock "github.com/stretchr/testify/mock"
 
+	time "time"
+
 	v1 "k8s.io/api/core/v1"
 
 	v1alpha1 "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
@@ -385,6 +386,40 @@ func (_m *ExecutableNode) GetOutputAlias() []v1alpha1.Alias {
 	return r0
 }
 
+type ExecutableNode_GetResourceExtensions struct {
+	*mock.Call
+}
+
+func (_m ExecutableNode_GetResourceExtensions) Return(_a0 *core.ResourceExtensions) *ExecutableNode_GetResourceExtensions {
+	return &ExecutableNode_GetResourceExtensions{Call: _m.Call.Return(_a0)}
+}
+
+func (_m *ExecutableNode) OnGetResourceExtensions() *ExecutableNode_GetResourceExtensions {
+	c_call := _m.On("GetResourceExtensions")
+	return &ExecutableNode_GetResourceExtensions{Call: c_call}
+}
+
+func (_m *ExecutableNode) OnGetResourceExtensionsMatch(matchers ...interface{}) *ExecutableNode_GetResourceExtensions {
+	c_call := _m.On("GetResourceExtensions", matchers...)
+	return &ExecutableNode_GetResourceExtensions{Call: c_call}
+}
+
+// GetResourceExtensions provides a mock function with given fields:
+func (_m *ExecutableNode) GetResourceExtensions() *core.ResourceExtensions {
+	ret := _m.Called()
+
+	var r0 *core.ResourceExtensions
+	if rf, ok := ret.Get(0).(func() *core.ResourceExtensions); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*core.ResourceExtensions)
+		}
+	}
+
+	return r0
+}
+
 type ExecutableNode_GetResources struct {
 	*mock.Call
 }
diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go
index 3ff76f3d53..436259d185 100644
--- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go
+++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go
@@ -92,6 +92,38 @@ func (in *NodeMetadata) DeepCopyInto(out *NodeMetadata) {
 	// Once we figure out the autogenerate story we can replace this
 }
 
+// ResourceExtensions wraps *core.ResourceExtensions with custom JSON
+// (un)marshaling: encoding goes through the package-level marshaler and
+// decoding through jsonpb.
+type ResourceExtensions struct {
+	*core.ResourceExtensions
+}
+
+// MarshalJSON encodes the embedded message with the package-level marshaler.
+func (in *ResourceExtensions) MarshalJSON() ([]byte, error) {
+	var buf bytes.Buffer
+	if err := marshaler.Marshal(&buf, in.ResourceExtensions); err != nil {
+		return nil, err
+	}
+
+	return buf.Bytes(), nil
+}
+
+// UnmarshalJSON replaces the embedded message with a freshly allocated one and
+// decodes b into it with jsonpb.Unmarshal.
+func (in *ResourceExtensions) UnmarshalJSON(b []byte) error {
+	in.ResourceExtensions = &core.ResourceExtensions{}
+	return jsonpb.Unmarshal(bytes.NewReader(b), in.ResourceExtensions)
+}
+
+// DeepCopyInto copies *in into *out by plain assignment, so both values share
+// the same embedded message pointer.
+func (in *ResourceExtensions) DeepCopyInto(out *ResourceExtensions) {
+	*out = *in
+	// We do not manipulate the object, so it's ok
+	// Once we figure out the autogenerate story we can replace this
+}
+
 type NodeSpec struct {
 	ID NodeID `json:"id"`
 	Name string `json:"name,omitempty"`
@@ -134,6 +166,10 @@ type NodeSpec struct {
 	// If not specified, the pod will be dispatched by default scheduler.
 	// +optional
 	SchedulerName string `json:"schedulerName,omitempty" protobuf:"bytes,19,opt,name=schedulerName"`
+	// If specified, includes overrides for extended resources to allocate to the
+	// node.
+	// +optional
+	ResourceExtensions *ResourceExtensions `json:"resourceExtensions,omitempty" protobuf:"bytes,20,opt,name=resourceExtensions"`
 	// If specified, the pod's tolerations.
 	// +optional
 	Tolerations []typesv1.Toleration `json:"tolerations,omitempty" protobuf:"bytes,22,opt,name=tolerations"`
@@ -182,6 +218,15 @@ func (in *NodeSpec) GetResources() *typesv1.ResourceRequirements {
 	return in.Resources
 }
 
+// GetResourceExtensions returns the wrapped core.ResourceExtensions message,
+// or nil when the ResourceExtensions field is unset.
+func (in *NodeSpec) GetResourceExtensions() *core.ResourceExtensions {
+	if in.ResourceExtensions == nil {
+		return nil
+	}
+	return in.ResourceExtensions.ResourceExtensions
+}
+
 func (in *NodeSpec) GetOutputAlias() []Alias {
 	return in.OutputAliases
 }
diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go
index 2af4b9626f..0e93d64c11 100644
--- a/flytepropeller/pkg/compiler/transformers/k8s/node.go
+++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go
@@ -48,6 +48,13 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 		}
 	}
 
+	var resourceExtensions *v1alpha1.ResourceExtensions
+	if resources != nil && resources.GetExtensions() != nil {
+		resourceExtensions = &v1alpha1.ResourceExtensions{
+			ResourceExtensions: resources.GetExtensions(),
+		}
+	}
+
 	res, err := flytek8s.ToK8sResourceRequirements(resources)
 	if err != nil {
 		errs.Collect(errors.NewWorkflowBuildError(err))
@@ -76,15 +83,16 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile
 	}
 
 	nodeSpec := &v1alpha1.NodeSpec{
-		ID:                n.GetId(),
-		Name:              name,
-		RetryStrategy:     computeRetryStrategy(n, task),
-		ExecutionDeadline: timeout,
-		Resources:         res,
-		OutputAliases:     toAliasValueArray(n.GetOutputAliases()),
-		InputBindings:     toBindingValueArray(n.GetInputs()),
-		ActiveDeadline:    activeDeadline,
-		Interruptible:     interruptible,
+		ID:                 n.GetId(),
+		Name:               name,
+		RetryStrategy:      computeRetryStrategy(n, task),
+		ExecutionDeadline:  timeout,
+		Resources:          res,
+		ResourceExtensions: resourceExtensions,
+		OutputAliases:      toAliasValueArray(n.GetOutputAliases()),
+		InputBindings:      toBindingValueArray(n.GetInputs()),
+		ActiveDeadline:     activeDeadline,
+		Interruptible:      interruptible,
 	}
 
 	switch v := n.GetTarget().(type) {
diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
index 16f7008c16..77b40af4c2 100644
--- a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
+++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go
@@ -117,6 +117,32 @@ func TestBuildNodeSpec(t *testing.T) {
 		assert.NotNil(t, spec.Resources)
 		assert.NotNil(t, spec.Resources.Requests.Cpu())
 		assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value())
+		assert.Nil(t, spec.GetResourceExtensions())
+	})
+
+	t.Run("node with resource extensions overrides", func(t *testing.T) {
+		expectedGpuDevice := "nvidia-tesla-t4"
+		n.Node.Target = &core.Node_TaskNode{
+			TaskNode: &core.TaskNode{
+				Reference: &core.TaskNode_ReferenceId{
+					ReferenceId: &core.Identifier{Name: "ref_2"},
+				},
+				Overrides: &core.TaskNodeOverrides{
+					Resources: &core.Resources{
+						Extensions: &core.ResourceExtensions{
+							GpuAccelerator: &core.GPUAccelerator{
+								Device: expectedGpuDevice,
+							},
+						},
+					},
+				},
+			},
+		}
+
+		spec := mustBuild(t, n, 1, errs.NewScope())
+		assert.NotNil(t, spec.GetResourceExtensions())
+		assert.NotNil(t, spec.GetResourceExtensions().GetGpuAccelerator())
+		assert.Equal(t, expectedGpuDevice, spec.GetResourceExtensions().GetGpuAccelerator().GetDevice())
 	})
 
 	t.Run("LaunchPlanRef", func(t *testing.T) {