From ccf8c2aca2ef8c7348498dec6dae9ed767b3b533 Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Wed, 12 Jun 2024 09:34:17 +0800 Subject: [PATCH] [Feature] Refactor distributed job using common ReplicaSpec (#5355) * feat(proto): Define CommonReplicaSpec in common.proto Resolves: flyteorg/flyte#4408 Signed-off-by: Chi-Sheng Liu * chore(proto): Generate new clients corresponding to proto changes Resolves: flyteorg/flyte#4408 Signed-off-by: Chi-Sheng Liu * feat(replica-spec): Update corresponding golang files based on protobuf changes Resolves: flyteorg/flyte#4408 Signed-off-by: Chi-Sheng Liu --------- Signed-off-by: Chi-Sheng Liu --- .../gen/pb-es/flyteidl/plugins/common_pb.ts | 98 +++ .../flyteidl/plugins/kubeflow/common_pb.ts | 26 - .../pb-es/flyteidl/plugins/kubeflow/mpi_pb.ts | 24 +- .../flyteidl/plugins/kubeflow/pytorch_pb.ts | 26 +- .../plugins/kubeflow/tensorflow_pb.ts | 26 +- .../gen/pb-go/flyteidl/plugins/common.pb.go | 257 +++++++ .../flyteidl/plugins/kubeflow/common.pb.go | 158 ++--- .../pb-go/flyteidl/plugins/kubeflow/mpi.pb.go | 113 ++-- .../flyteidl/plugins/kubeflow/pytorch.pb.go | 112 ++-- .../plugins/kubeflow/tensorflow.pb.go | 117 ++-- .../flyteidl/plugins/common.swagger.json | 46 ++ .../pb_python/flyteidl/plugins/common_pb2.py | 30 + .../pb_python/flyteidl/plugins/common_pb2.pyi | 28 + .../flyteidl/plugins/common_pb2_grpc.py | 4 + .../flyteidl/plugins/kubeflow/common_pb2.py | 16 +- .../flyteidl/plugins/kubeflow/common_pb2.pyi | 15 +- .../flyteidl/plugins/kubeflow/mpi_pb2.py | 18 +- .../flyteidl/plugins/kubeflow/mpi_pb2.pyi | 9 +- .../flyteidl/plugins/kubeflow/pytorch_pb2.py | 18 +- .../flyteidl/plugins/kubeflow/pytorch_pb2.pyi | 9 +- .../plugins/kubeflow/tensorflow_pb2.py | 18 +- .../plugins/kubeflow/tensorflow_pb2.pyi | 9 +- .../gen/pb_rust/flyteidl.plugins.kubeflow.rs | 63 +- flyteidl/gen/pb_rust/flyteidl.plugins.rs | 45 ++ flyteidl/protos/flyteidl/plugins/common.proto | 27 + .../flyteidl/plugins/kubeflow/common.proto | 11 +- .../flyteidl/plugins/kubeflow/mpi.proto | 18 +- .../flyteidl/plugins/kubeflow/pytorch.proto | 18 +- .../plugins/kubeflow/tensorflow.proto | 16 +- .../k8s/kfoperators/common/common_operator.go | 58 +- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 345 ++++++---- .../k8s/kfoperators/pytorch/pytorch_test.go | 534 +++++++++------ .../k8s/kfoperators/tensorflow/tensorflow.go | 11 +- .../kfoperators/tensorflow/tensorflow_test.go | 634 +++++++++++------- 34 files changed, 2015 insertions(+), 942 deletions(-) create mode 100644 flyteidl/gen/pb-es/flyteidl/plugins/common_pb.ts create mode 100644 flyteidl/gen/pb-go/flyteidl/plugins/common.pb.go create mode 100644 flyteidl/gen/pb-go/gateway/flyteidl/plugins/common.swagger.json create mode 100644 flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.py create mode 100644 flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.pyi create mode 100644 flyteidl/gen/pb_python/flyteidl/plugins/common_pb2_grpc.py create mode 100644 flyteidl/protos/flyteidl/plugins/common.proto diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/common_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/common_pb.ts new file mode 100644 index 0000000000..37949692d1 --- /dev/null +++ b/flyteidl/gen/pb-es/flyteidl/plugins/common_pb.ts @@ -0,0 +1,98 @@ +// @generated by protoc-gen-es v1.7.2 with parameter "target=ts" +// @generated from file flyteidl/plugins/common.proto (package flyteidl.plugins, syntax proto3) +/* eslint-disable */ +// @ts-nocheck + +import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; +import { Message, proto3 } from "@bufbuild/protobuf"; +import { Resources } from "../core/tasks_pb.js"; + +/** + * @generated from enum flyteidl.plugins.RestartPolicy + */ +export enum RestartPolicy { + /** + * @generated from enum value: RESTART_POLICY_NEVER = 0; + */ + NEVER = 0, + + /** + * @generated from enum value: RESTART_POLICY_ON_FAILURE = 1; + */ + ON_FAILURE = 1, + + /** + * @generated from enum value: RESTART_POLICY_ALWAYS = 2; + */ + ALWAYS = 2, +} +// Retrieve enum metadata with: proto3.getEnumType(RestartPolicy) +proto3.util.setEnumType(RestartPolicy, "flyteidl.plugins.RestartPolicy", [ + { no: 0, name: "RESTART_POLICY_NEVER" }, + { no: 1, name: "RESTART_POLICY_ON_FAILURE" }, + { no: 2, name: "RESTART_POLICY_ALWAYS" }, +]); + +/** + * @generated from message flyteidl.plugins.CommonReplicaSpec + */ +export class CommonReplicaSpec extends Message { + /** + * Number of replicas + * + * @generated from field: int32 replicas = 1; + */ + replicas = 0; + + /** + * Image used for the replica group + * + * @generated from field: string image = 2; + */ + image = ""; + + /** + * Resources required for the replica group + * + * @generated from field: flyteidl.core.Resources resources = 3; + */ + resources?: Resources; + + /** + * RestartPolicy determines whether pods will be restarted when they exit + * + * @generated from field: flyteidl.plugins.RestartPolicy restart_policy = 4; + */ + restartPolicy = RestartPolicy.NEVER; + + constructor(data?: PartialMessage) { + super(); + proto3.util.initPartial(data, this); + } + + static readonly runtime: typeof proto3 = proto3; + static readonly typeName = "flyteidl.plugins.CommonReplicaSpec"; + static readonly fields: FieldList = proto3.util.newFieldList(() => [ + { no: 1, name: "replicas", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, + { no: 2, name: "image", kind: "scalar", T: 9 /* ScalarType.STRING */ }, + { no: 3, name: "resources", kind: "message", T: Resources }, + { no: 4, name: "restart_policy", kind: "enum", T: proto3.getEnumType(RestartPolicy) }, + ]); + + static fromBinary(bytes: Uint8Array, options?: Partial): CommonReplicaSpec { + return new CommonReplicaSpec().fromBinary(bytes, options); + } + + static fromJson(jsonValue: JsonValue, options?: Partial): CommonReplicaSpec { + return new CommonReplicaSpec().fromJson(jsonValue, options); + } + + static fromJsonString(jsonString: string, options?: Partial): CommonReplicaSpec { + return new CommonReplicaSpec().fromJsonString(jsonString, options); + } + + static equals(a: CommonReplicaSpec | PlainMessage | undefined, b: CommonReplicaSpec | PlainMessage | undefined): boolean { + return proto3.util.equals(CommonReplicaSpec, a, b); + } +} + diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/common_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/common_pb.ts index aec23a4da5..8ab5d03372 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/common_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/common_pb.ts @@ -6,32 +6,6 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3 } from "@bufbuild/protobuf"; -/** - * @generated from enum flyteidl.plugins.kubeflow.RestartPolicy - */ -export enum RestartPolicy { - /** - * @generated from enum value: RESTART_POLICY_NEVER = 0; - */ - NEVER = 0, - - /** - * @generated from enum value: RESTART_POLICY_ON_FAILURE = 1; - */ - ON_FAILURE = 1, - - /** - * @generated from enum value: RESTART_POLICY_ALWAYS = 2; - */ - ALWAYS = 2, -} -// Retrieve enum metadata with: proto3.getEnumType(RestartPolicy) -proto3.util.setEnumType(RestartPolicy, "flyteidl.plugins.kubeflow.RestartPolicy", [ - { no: 0, name: "RESTART_POLICY_NEVER" }, - { no: 1, name: "RESTART_POLICY_ON_FAILURE" }, - { no: 2, name: "RESTART_POLICY_ALWAYS" }, -]); - /** * @generated from enum flyteidl.plugins.kubeflow.CleanPodPolicy */ diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/mpi_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/mpi_pb.ts index 89ff16b82b..9364f1c082 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/mpi_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/mpi_pb.ts @@ -5,8 +5,9 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3 } from "@bufbuild/protobuf"; -import { RestartPolicy, RunPolicy } from "./common_pb.js"; +import { RunPolicy } from "./common_pb.js"; import { Resources } from "../../core/tasks_pb.js"; +import { CommonReplicaSpec, RestartPolicy } from "../common_pb.js"; /** * Proto for plugin that enables distributed training using https://github.com/kubeflow/mpi-operator @@ -82,30 +83,35 @@ export class DistributedMPITrainingTask extends Message { /** + * 1~4 deprecated. Use common instead. * Number of replicas * - * @generated from field: int32 replicas = 1; + * @generated from field: int32 replicas = 1 [deprecated = true]; + * @deprecated */ replicas = 0; /** * Image used for the replica group * - * @generated from field: string image = 2; + * @generated from field: string image = 2 [deprecated = true]; + * @deprecated */ image = ""; /** * Resources required for the replica group * - * @generated from field: flyteidl.core.Resources resources = 3; + * @generated from field: flyteidl.core.Resources resources = 3 [deprecated = true]; + * @deprecated */ resources?: Resources; /** * Restart policy determines whether pods will be restarted when they exit * - * @generated from field: flyteidl.plugins.kubeflow.RestartPolicy restart_policy = 4; + * @generated from field: flyteidl.plugins.RestartPolicy restart_policy = 4 [deprecated = true]; + * @deprecated */ restartPolicy = RestartPolicy.NEVER; @@ -116,6 +122,13 @@ export class DistributedMPITrainingReplicaSpec extends Message) { super(); proto3.util.initPartial(data, this); @@ -129,6 +142,7 @@ export class DistributedMPITrainingReplicaSpec extends Message): DistributedMPITrainingReplicaSpec { diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/pytorch_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/pytorch_pb.ts index 2dd38a56ba..fc5a17a460 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/pytorch_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/pytorch_pb.ts @@ -5,8 +5,9 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3 } from "@bufbuild/protobuf"; -import { RestartPolicy, RunPolicy } from "./common_pb.js"; +import { RunPolicy } from "./common_pb.js"; import { Resources } from "../../core/tasks_pb.js"; +import { CommonReplicaSpec, RestartPolicy } from "../common_pb.js"; /** * Custom proto for torch elastic config for distributed training using @@ -144,33 +145,45 @@ export class DistributedPyTorchTrainingTask extends Message { /** + * 1~4 deprecated. Use common instead. * Number of replicas * - * @generated from field: int32 replicas = 1; + * @generated from field: int32 replicas = 1 [deprecated = true]; + * @deprecated */ replicas = 0; /** * Image used for the replica group * - * @generated from field: string image = 2; + * @generated from field: string image = 2 [deprecated = true]; + * @deprecated */ image = ""; /** * Resources required for the replica group * - * @generated from field: flyteidl.core.Resources resources = 3; + * @generated from field: flyteidl.core.Resources resources = 3 [deprecated = true]; + * @deprecated */ resources?: Resources; /** - * RestartPolicy determines whether pods will be restarted when they exit + * Restart policy determines whether pods will be restarted when they exit * - * @generated from field: flyteidl.plugins.kubeflow.RestartPolicy restart_policy = 4; + * @generated from field: flyteidl.plugins.RestartPolicy restart_policy = 4 [deprecated = true]; + * @deprecated */ restartPolicy = RestartPolicy.NEVER; + /** + * The common replica spec + * + * @generated from field: flyteidl.plugins.CommonReplicaSpec common = 5; + */ + common?: CommonReplicaSpec; + constructor(data?: PartialMessage) { super(); proto3.util.initPartial(data, this); @@ -183,6 +196,7 @@ export class DistributedPyTorchTrainingReplicaSpec extends Message): DistributedPyTorchTrainingReplicaSpec { diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/tensorflow_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/tensorflow_pb.ts index 356385d858..5b9c001e30 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/tensorflow_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/kubeflow/tensorflow_pb.ts @@ -5,8 +5,9 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3 } from "@bufbuild/protobuf"; -import { RestartPolicy, RunPolicy } from "./common_pb.js"; +import { RunPolicy } from "./common_pb.js"; import { Resources } from "../../core/tasks_pb.js"; +import { CommonReplicaSpec, RestartPolicy } from "../common_pb.js"; /** * Proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator @@ -88,33 +89,45 @@ export class DistributedTensorflowTrainingTask extends Message { /** + * 1~4 deprecated. Use common instead. * Number of replicas * - * @generated from field: int32 replicas = 1; + * @generated from field: int32 replicas = 1 [deprecated = true]; + * @deprecated */ replicas = 0; /** * Image used for the replica group * - * @generated from field: string image = 2; + * @generated from field: string image = 2 [deprecated = true]; + * @deprecated */ image = ""; /** * Resources required for the replica group * - * @generated from field: flyteidl.core.Resources resources = 3; + * @generated from field: flyteidl.core.Resources resources = 3 [deprecated = true]; + * @deprecated */ resources?: Resources; /** - * RestartPolicy Determines whether pods will be restarted when they exit + * Restart policy determines whether pods will be restarted when they exit * - * @generated from field: flyteidl.plugins.kubeflow.RestartPolicy restart_policy = 4; + * @generated from field: flyteidl.plugins.RestartPolicy restart_policy = 4 [deprecated = true]; + * @deprecated */ restartPolicy = RestartPolicy.NEVER; + /** + * The common replica spec + * + * @generated from field: flyteidl.plugins.CommonReplicaSpec common = 5; + */ + common?: CommonReplicaSpec; + constructor(data?: PartialMessage) { super(); proto3.util.initPartial(data, this); @@ -127,6 +140,7 @@ export class DistributedTensorflowTrainingReplicaSpec extends Message): DistributedTensorflowTrainingReplicaSpec { diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/common.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/common.pb.go new file mode 100644 index 0000000000..f92acdc2ee --- /dev/null +++ b/flyteidl/gen/pb-go/flyteidl/plugins/common.pb.go @@ -0,0 +1,257 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc (unknown) +// source: flyteidl/plugins/common.proto + +package plugins + +import ( + core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RestartPolicy int32 + +const ( + RestartPolicy_RESTART_POLICY_NEVER RestartPolicy = 0 + RestartPolicy_RESTART_POLICY_ON_FAILURE RestartPolicy = 1 + RestartPolicy_RESTART_POLICY_ALWAYS RestartPolicy = 2 +) + +// Enum value maps for RestartPolicy. +var ( + RestartPolicy_name = map[int32]string{ + 0: "RESTART_POLICY_NEVER", + 1: "RESTART_POLICY_ON_FAILURE", + 2: "RESTART_POLICY_ALWAYS", + } + RestartPolicy_value = map[string]int32{ + "RESTART_POLICY_NEVER": 0, + "RESTART_POLICY_ON_FAILURE": 1, + "RESTART_POLICY_ALWAYS": 2, + } +) + +func (x RestartPolicy) Enum() *RestartPolicy { + p := new(RestartPolicy) + *p = x + return p +} + +func (x RestartPolicy) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RestartPolicy) Descriptor() protoreflect.EnumDescriptor { + return file_flyteidl_plugins_common_proto_enumTypes[0].Descriptor() +} + +func (RestartPolicy) Type() protoreflect.EnumType { + return &file_flyteidl_plugins_common_proto_enumTypes[0] +} + +func (x RestartPolicy) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RestartPolicy.Descriptor instead. +func (RestartPolicy) EnumDescriptor() ([]byte, []int) { + return file_flyteidl_plugins_common_proto_rawDescGZIP(), []int{0} +} + +type CommonReplicaSpec struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Number of replicas + Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` + // Image used for the replica group + Image string `protobuf:"bytes,2,opt,name=image,proto3" json:"image,omitempty"` + // Resources required for the replica group + Resources *core.Resources `protobuf:"bytes,3,opt,name=resources,proto3" json:"resources,omitempty"` + // RestartPolicy determines whether pods will be restarted when they exit + RestartPolicy RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.RestartPolicy" json:"restart_policy,omitempty"` +} + +func (x *CommonReplicaSpec) Reset() { + *x = CommonReplicaSpec{} + if protoimpl.UnsafeEnabled { + mi := &file_flyteidl_plugins_common_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CommonReplicaSpec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CommonReplicaSpec) ProtoMessage() {} + +func (x *CommonReplicaSpec) ProtoReflect() protoreflect.Message { + mi := &file_flyteidl_plugins_common_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CommonReplicaSpec.ProtoReflect.Descriptor instead. +func (*CommonReplicaSpec) Descriptor() ([]byte, []int) { + return file_flyteidl_plugins_common_proto_rawDescGZIP(), []int{0} +} + +func (x *CommonReplicaSpec) GetReplicas() int32 { + if x != nil { + return x.Replicas + } + return 0 +} + +func (x *CommonReplicaSpec) GetImage() string { + if x != nil { + return x.Image + } + return "" +} + +func (x *CommonReplicaSpec) GetResources() *core.Resources { + if x != nil { + return x.Resources + } + return nil +} + +func (x *CommonReplicaSpec) GetRestartPolicy() RestartPolicy { + if x != nil { + return x.RestartPolicy + } + return RestartPolicy_RESTART_POLICY_NEVER +} + +var File_flyteidl_plugins_common_proto protoreflect.FileDescriptor + +var file_flyteidl_plugins_common_proto_rawDesc = []byte{ + 0x0a, 0x1d, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, + 0x6e, 0x73, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x10, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x73, 0x1a, 0x19, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, 0x6f, 0x72, 0x65, + 0x2f, 0x74, 0x61, 0x73, 0x6b, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc5, 0x01, 0x0a, + 0x11, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x53, 0x70, + 0x65, 0x63, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x14, + 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, + 0x6d, 0x61, 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, + 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, 0x46, 0x0a, 0x0e, + 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, + 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x2a, 0x63, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x45, 0x53, 0x54, 0x41, 0x52, 0x54, + 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x4e, 0x45, 0x56, 0x45, 0x52, 0x10, 0x00, 0x12, + 0x1d, 0x0a, 0x19, 0x52, 0x45, 0x53, 0x54, 0x41, 0x52, 0x54, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, + 0x59, 0x5f, 0x4f, 0x4e, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x55, 0x52, 0x45, 0x10, 0x01, 0x12, 0x19, + 0x0a, 0x15, 0x52, 0x45, 0x53, 0x54, 0x41, 0x52, 0x54, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, + 0x5f, 0x41, 0x4c, 0x57, 0x41, 0x59, 0x53, 0x10, 0x02, 0x42, 0xc3, 0x01, 0x0a, 0x14, 0x63, 0x6f, + 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, + 0x6e, 0x73, 0x42, 0x0b, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, + 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, + 0xa2, 0x02, 0x03, 0x46, 0x50, 0x58, 0xaa, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xca, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, + 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xe2, 0x02, 0x1c, 0x46, + 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, + 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x11, 0x46, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_flyteidl_plugins_common_proto_rawDescOnce sync.Once + file_flyteidl_plugins_common_proto_rawDescData = file_flyteidl_plugins_common_proto_rawDesc +) + +func file_flyteidl_plugins_common_proto_rawDescGZIP() []byte { + file_flyteidl_plugins_common_proto_rawDescOnce.Do(func() { + file_flyteidl_plugins_common_proto_rawDescData = protoimpl.X.CompressGZIP(file_flyteidl_plugins_common_proto_rawDescData) + }) + return file_flyteidl_plugins_common_proto_rawDescData +} + +var file_flyteidl_plugins_common_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_flyteidl_plugins_common_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_flyteidl_plugins_common_proto_goTypes = []interface{}{ + (RestartPolicy)(0), // 0: flyteidl.plugins.RestartPolicy + (*CommonReplicaSpec)(nil), // 1: flyteidl.plugins.CommonReplicaSpec + (*core.Resources)(nil), // 2: flyteidl.core.Resources +} +var file_flyteidl_plugins_common_proto_depIdxs = []int32{ + 2, // 0: flyteidl.plugins.CommonReplicaSpec.resources:type_name -> flyteidl.core.Resources + 0, // 1: flyteidl.plugins.CommonReplicaSpec.restart_policy:type_name -> flyteidl.plugins.RestartPolicy + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_flyteidl_plugins_common_proto_init() } +func file_flyteidl_plugins_common_proto_init() { + if File_flyteidl_plugins_common_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_flyteidl_plugins_common_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CommonReplicaSpec); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_flyteidl_plugins_common_proto_rawDesc, + NumEnums: 1, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_flyteidl_plugins_common_proto_goTypes, + DependencyIndexes: file_flyteidl_plugins_common_proto_depIdxs, + EnumInfos: file_flyteidl_plugins_common_proto_enumTypes, + MessageInfos: file_flyteidl_plugins_common_proto_msgTypes, + }.Build() + File_flyteidl_plugins_common_proto = out.File + file_flyteidl_plugins_common_proto_rawDesc = nil + file_flyteidl_plugins_common_proto_goTypes = nil + file_flyteidl_plugins_common_proto_depIdxs = nil +} diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/common.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/common.pb.go index ba03051e1e..685ffa716b 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/common.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/common.pb.go @@ -4,9 +4,10 @@ // protoc (unknown) // source: flyteidl/plugins/kubeflow/common.proto -package plugins +package kubeflow import ( + plugins "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -20,54 +21,18 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type RestartPolicy int32 +// Symbols defined in public import of flyteidl/plugins/common.proto. -const ( - RestartPolicy_RESTART_POLICY_NEVER RestartPolicy = 0 - RestartPolicy_RESTART_POLICY_ON_FAILURE RestartPolicy = 1 - RestartPolicy_RESTART_POLICY_ALWAYS RestartPolicy = 2 -) - -// Enum value maps for RestartPolicy. -var ( - RestartPolicy_name = map[int32]string{ - 0: "RESTART_POLICY_NEVER", - 1: "RESTART_POLICY_ON_FAILURE", - 2: "RESTART_POLICY_ALWAYS", - } - RestartPolicy_value = map[string]int32{ - "RESTART_POLICY_NEVER": 0, - "RESTART_POLICY_ON_FAILURE": 1, - "RESTART_POLICY_ALWAYS": 2, - } -) +type RestartPolicy = plugins.RestartPolicy -func (x RestartPolicy) Enum() *RestartPolicy { - p := new(RestartPolicy) - *p = x - return p -} +const RestartPolicy_RESTART_POLICY_NEVER = plugins.RestartPolicy_RESTART_POLICY_NEVER +const RestartPolicy_RESTART_POLICY_ON_FAILURE = plugins.RestartPolicy_RESTART_POLICY_ON_FAILURE +const RestartPolicy_RESTART_POLICY_ALWAYS = plugins.RestartPolicy_RESTART_POLICY_ALWAYS -func (x RestartPolicy) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} +var RestartPolicy_name = plugins.RestartPolicy_name +var RestartPolicy_value = plugins.RestartPolicy_value -func (RestartPolicy) Descriptor() protoreflect.EnumDescriptor { - return file_flyteidl_plugins_kubeflow_common_proto_enumTypes[0].Descriptor() -} - -func (RestartPolicy) Type() protoreflect.EnumType { - return &file_flyteidl_plugins_kubeflow_common_proto_enumTypes[0] -} - -func (x RestartPolicy) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use RestartPolicy.Descriptor instead. -func (RestartPolicy) EnumDescriptor() ([]byte, []int) { - return file_flyteidl_plugins_kubeflow_common_proto_rawDescGZIP(), []int{0} -} +type CommonReplicaSpec = plugins.CommonReplicaSpec type CleanPodPolicy int32 @@ -102,11 +67,11 @@ func (x CleanPodPolicy) String() string { } func (CleanPodPolicy) Descriptor() protoreflect.EnumDescriptor { - return file_flyteidl_plugins_kubeflow_common_proto_enumTypes[1].Descriptor() + return file_flyteidl_plugins_kubeflow_common_proto_enumTypes[0].Descriptor() } func (CleanPodPolicy) Type() protoreflect.EnumType { - return &file_flyteidl_plugins_kubeflow_common_proto_enumTypes[1] + return &file_flyteidl_plugins_kubeflow_common_proto_enumTypes[0] } func (x CleanPodPolicy) Number() protoreflect.EnumNumber { @@ -115,7 +80,7 @@ func (x CleanPodPolicy) Number() protoreflect.EnumNumber { // Deprecated: Use CleanPodPolicy.Descriptor instead. func (CleanPodPolicy) EnumDescriptor() ([]byte, []int) { - return file_flyteidl_plugins_kubeflow_common_proto_rawDescGZIP(), []int{1} + return file_flyteidl_plugins_kubeflow_common_proto_rawDescGZIP(), []int{0} } type RunPolicy struct { @@ -201,51 +166,47 @@ var file_flyteidl_plugins_kubeflow_common_proto_rawDesc = []byte{ 0x6e, 0x73, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x19, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, - 0x6c, 0x6f, 0x77, 0x22, 0xfa, 0x01, 0x0a, 0x09, 0x52, 0x75, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x12, 0x53, 0x0a, 0x10, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x5f, 0x70, 0x6f, 0x64, 0x5f, 0x70, - 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, - 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, 0x64, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, 0x64, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x3b, 0x0a, 0x1a, 0x74, 0x74, 0x6c, 0x5f, 0x73, 0x65, - 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, 0x66, 0x69, 0x6e, 0x69, - 0x73, 0x68, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x17, 0x74, 0x74, 0x6c, 0x53, - 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x41, 0x66, 0x74, 0x65, 0x72, 0x46, 0x69, 0x6e, 0x69, 0x73, - 0x68, 0x65, 0x64, 0x12, 0x36, 0x0a, 0x17, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x64, 0x65, - 0x61, 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x05, 0x52, 0x15, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x44, 0x65, 0x61, 0x64, - 0x6c, 0x69, 0x6e, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x62, - 0x61, 0x63, 0x6b, 0x6f, 0x66, 0x66, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x0c, 0x62, 0x61, 0x63, 0x6b, 0x6f, 0x66, 0x66, 0x4c, 0x69, 0x6d, 0x69, 0x74, - 0x2a, 0x63, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x45, 0x53, 0x54, 0x41, 0x52, 0x54, 0x5f, 0x50, 0x4f, 0x4c, - 0x49, 0x43, 0x59, 0x5f, 0x4e, 0x45, 0x56, 0x45, 0x52, 0x10, 0x00, 0x12, 0x1d, 0x0a, 0x19, 0x52, - 0x45, 0x53, 0x54, 0x41, 0x52, 0x54, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x4f, 0x4e, - 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x55, 0x52, 0x45, 0x10, 0x01, 0x12, 0x19, 0x0a, 0x15, 0x52, 0x45, - 0x53, 0x54, 0x41, 0x52, 0x54, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x41, 0x4c, 0x57, - 0x41, 0x59, 0x53, 0x10, 0x02, 0x2a, 0x60, 0x0a, 0x0e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, - 0x64, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x18, 0x0a, 0x14, 0x43, 0x4c, 0x45, 0x41, 0x4e, - 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x4e, 0x4f, 0x4e, 0x45, 0x10, - 0x00, 0x12, 0x1b, 0x0a, 0x17, 0x43, 0x4c, 0x45, 0x41, 0x4e, 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, - 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x52, 0x55, 0x4e, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x17, - 0x0a, 0x13, 0x43, 0x4c, 0x45, 0x41, 0x4e, 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, - 0x59, 0x5f, 0x41, 0x4c, 0x4c, 0x10, 0x02, 0x42, 0xf1, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, - 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, - 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x42, 0x0b, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, - 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, - 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, - 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, - 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, - 0x2e, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, - 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, - 0x77, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, - 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, - 0x73, 0x3a, 0x3a, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x6c, 0x6f, 0x77, 0x1a, 0x1d, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0xfa, 0x01, 0x0a, 0x09, 0x52, 0x75, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x12, 0x53, 0x0a, 0x10, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x5f, 0x70, 0x6f, 0x64, 0x5f, 0x70, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x66, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, + 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, 0x64, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, 0x64, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x3b, 0x0a, 0x1a, 0x74, 0x74, 0x6c, 0x5f, 0x73, 0x65, 0x63, + 0x6f, 0x6e, 0x64, 0x73, 0x5f, 0x61, 0x66, 0x74, 0x65, 0x72, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, + 0x68, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x17, 0x74, 0x74, 0x6c, 0x53, 0x65, + 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x41, 0x66, 0x74, 0x65, 0x72, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, + 0x65, 0x64, 0x12, 0x36, 0x0a, 0x17, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x64, 0x65, 0x61, + 0x64, 0x6c, 0x69, 0x6e, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x15, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x44, 0x65, 0x61, 0x64, 0x6c, + 0x69, 0x6e, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x62, 0x61, + 0x63, 0x6b, 0x6f, 0x66, 0x66, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x62, 0x61, 0x63, 0x6b, 0x6f, 0x66, 0x66, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x2a, + 0x60, 0x0a, 0x0e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x50, 0x6f, 0x64, 0x50, 0x6f, 0x6c, 0x69, 0x63, + 0x79, 0x12, 0x18, 0x0a, 0x14, 0x43, 0x4c, 0x45, 0x41, 0x4e, 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, + 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x4e, 0x4f, 0x4e, 0x45, 0x10, 0x00, 0x12, 0x1b, 0x0a, 0x17, 0x43, + 0x4c, 0x45, 0x41, 0x4e, 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x52, + 0x55, 0x4e, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x43, 0x4c, 0x45, 0x41, + 0x4e, 0x50, 0x4f, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x5f, 0x41, 0x4c, 0x4c, 0x10, + 0x02, 0x42, 0xfa, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, + 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, + 0x6c, 0x6f, 0x77, 0x42, 0x0b, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, + 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, + 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x73, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, + 0xaa, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x73, 0x2e, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, + 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, + 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, + 0x66, 0x6c, 0x6f, 0x77, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0xea, 0x02, 0x1b, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x3a, 0x3a, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x50, 0x00, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -260,15 +221,14 @@ func file_flyteidl_plugins_kubeflow_common_proto_rawDescGZIP() []byte { return file_flyteidl_plugins_kubeflow_common_proto_rawDescData } -var file_flyteidl_plugins_kubeflow_common_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_flyteidl_plugins_kubeflow_common_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_flyteidl_plugins_kubeflow_common_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_flyteidl_plugins_kubeflow_common_proto_goTypes = []interface{}{ - (RestartPolicy)(0), // 0: flyteidl.plugins.kubeflow.RestartPolicy - (CleanPodPolicy)(0), // 1: flyteidl.plugins.kubeflow.CleanPodPolicy - (*RunPolicy)(nil), // 2: flyteidl.plugins.kubeflow.RunPolicy + (CleanPodPolicy)(0), // 0: flyteidl.plugins.kubeflow.CleanPodPolicy + (*RunPolicy)(nil), // 1: flyteidl.plugins.kubeflow.RunPolicy } var file_flyteidl_plugins_kubeflow_common_proto_depIdxs = []int32{ - 1, // 0: flyteidl.plugins.kubeflow.RunPolicy.clean_pod_policy:type_name -> flyteidl.plugins.kubeflow.CleanPodPolicy + 0, // 0: flyteidl.plugins.kubeflow.RunPolicy.clean_pod_policy:type_name -> flyteidl.plugins.kubeflow.CleanPodPolicy 1, // [1:1] is the sub-list for method output_type 1, // [1:1] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name @@ -300,7 +260,7 @@ func file_flyteidl_plugins_kubeflow_common_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_flyteidl_plugins_kubeflow_common_proto_rawDesc, - NumEnums: 2, + NumEnums: 1, NumMessages: 1, NumExtensions: 0, NumServices: 0, diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/mpi.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/mpi.pb.go index 4dcea4912a..ee2bf87273 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/mpi.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/mpi.pb.go @@ -4,10 +4,11 @@ // protoc (unknown) // source: flyteidl/plugins/kubeflow/mpi.proto -package plugins +package kubeflow import ( core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + plugins "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -105,16 +106,27 @@ type DistributedMPITrainingReplicaSpec struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // 1~4 deprecated. Use common instead. // Number of replicas + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` // Image used for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. Image string `protobuf:"bytes,2,opt,name=image,proto3" json:"image,omitempty"` // Resources required for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. Resources *core.Resources `protobuf:"bytes,3,opt,name=resources,proto3" json:"resources,omitempty"` // Restart policy determines whether pods will be restarted when they exit - RestartPolicy RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.kubeflow.RestartPolicy" json:"restart_policy,omitempty"` + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. + RestartPolicy plugins.RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.RestartPolicy" json:"restart_policy,omitempty"` // MPI sometimes requires different command set for different replica groups Command []string `protobuf:"bytes,5,rep,name=command,proto3" json:"command,omitempty"` + // The common replica spec + Common *plugins.CommonReplicaSpec `protobuf:"bytes,6,opt,name=common,proto3" json:"common,omitempty"` } func (x *DistributedMPITrainingReplicaSpec) Reset() { @@ -149,6 +161,7 @@ func (*DistributedMPITrainingReplicaSpec) Descriptor() ([]byte, []int) { return file_flyteidl_plugins_kubeflow_mpi_proto_rawDescGZIP(), []int{1} } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. func (x *DistributedMPITrainingReplicaSpec) GetReplicas() int32 { if x != nil { return x.Replicas @@ -156,6 +169,7 @@ func (x *DistributedMPITrainingReplicaSpec) GetReplicas() int32 { return 0 } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. func (x *DistributedMPITrainingReplicaSpec) GetImage() string { if x != nil { return x.Image @@ -163,6 +177,7 @@ func (x *DistributedMPITrainingReplicaSpec) GetImage() string { return "" } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. func (x *DistributedMPITrainingReplicaSpec) GetResources() *core.Resources { if x != nil { return x.Resources @@ -170,11 +185,12 @@ func (x *DistributedMPITrainingReplicaSpec) GetResources() *core.Resources { return nil } -func (x *DistributedMPITrainingReplicaSpec) GetRestartPolicy() RestartPolicy { +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/mpi.proto. +func (x *DistributedMPITrainingReplicaSpec) GetRestartPolicy() plugins.RestartPolicy { if x != nil { return x.RestartPolicy } - return RestartPolicy_RESTART_POLICY_NEVER + return plugins.RestartPolicy(0) } func (x *DistributedMPITrainingReplicaSpec) GetCommand() []string { @@ -184,6 +200,13 @@ func (x *DistributedMPITrainingReplicaSpec) GetCommand() []string { return nil } +func (x *DistributedMPITrainingReplicaSpec) GetCommon() *plugins.CommonReplicaSpec { + if x != nil { + return x.Common + } + return nil +} + var File_flyteidl_plugins_kubeflow_mpi_proto protoreflect.FileDescriptor var file_flyteidl_plugins_kubeflow_mpi_proto_rawDesc = []byte{ @@ -216,38 +239,42 @@ var file_flyteidl_plugins_kubeflow_mpi_proto_rawDesc = []byte{ 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x75, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x09, 0x72, 0x75, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x6c, 0x6f, 0x74, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x73, 0x6c, 0x6f, 0x74, 0x73, 0x22, - 0xf8, 0x01, 0x0a, 0x21, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x64, 0x4d, + 0xbc, 0x02, 0x0a, 0x21, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x64, 0x4d, 0x50, 0x49, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, - 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, - 0x72, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, - 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, - 0x72, 0x63, 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, - 0x4f, 0x0a, 0x0e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x28, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, - 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x18, 0x05, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x42, 0xee, 0x01, 0x0a, 0x1d, 0x63, - 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, - 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x42, 0x08, 0x4d, 0x70, - 0x69, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, - 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, - 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, - 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, - 0x2e, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, - 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, - 0x77, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, - 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, - 0x73, 0x3a, 0x3a, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x61, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1e, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x02, 0x18, 0x01, 0x52, 0x08, 0x72, 0x65, 0x70, + 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x18, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, + 0x3a, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x42, 0x02, 0x18, 0x01, + 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, 0x4a, 0x0a, 0x0e, 0x72, + 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, + 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, + 0x6c, 0x69, 0x63, 0x79, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, + 0x6e, 0x64, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, + 0x64, 0x12, 0x3b, 0x0a, 0x06, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x23, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, + 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x52, 0x06, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x42, 0xf7, + 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, + 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, + 0x42, 0x08, 0x4d, 0x70, 0x69, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, + 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, + 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6b, 0x75, 0x62, 0x65, + 0x66, 0x6c, 0x6f, 0x77, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, 0x46, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x4b, 0x75, + 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, + 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x5c, 0x47, + 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, 0x46, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x3a, 0x3a, + 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -268,19 +295,21 @@ var file_flyteidl_plugins_kubeflow_mpi_proto_goTypes = []interface{}{ (*DistributedMPITrainingReplicaSpec)(nil), // 1: flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec (*RunPolicy)(nil), // 2: flyteidl.plugins.kubeflow.RunPolicy (*core.Resources)(nil), // 3: flyteidl.core.Resources - (RestartPolicy)(0), // 4: flyteidl.plugins.kubeflow.RestartPolicy + (plugins.RestartPolicy)(0), // 4: flyteidl.plugins.RestartPolicy + (*plugins.CommonReplicaSpec)(nil), // 5: flyteidl.plugins.CommonReplicaSpec } var file_flyteidl_plugins_kubeflow_mpi_proto_depIdxs = []int32{ 1, // 0: flyteidl.plugins.kubeflow.DistributedMPITrainingTask.worker_replicas:type_name -> flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec 1, // 1: flyteidl.plugins.kubeflow.DistributedMPITrainingTask.launcher_replicas:type_name -> flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec 2, // 2: flyteidl.plugins.kubeflow.DistributedMPITrainingTask.run_policy:type_name -> flyteidl.plugins.kubeflow.RunPolicy 3, // 3: flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec.resources:type_name -> flyteidl.core.Resources - 4, // 4: flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.kubeflow.RestartPolicy - 5, // [5:5] is the sub-list for method output_type - 5, // [5:5] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 4, // 4: flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.RestartPolicy + 5, // 5: flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpec.common:type_name -> flyteidl.plugins.CommonReplicaSpec + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_flyteidl_plugins_kubeflow_mpi_proto_init() } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/pytorch.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/pytorch.pb.go index 4dbabf4ae3..7107835725 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/pytorch.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/pytorch.pb.go @@ -4,10 +4,11 @@ // protoc (unknown) // source: flyteidl/plugins/kubeflow/pytorch.proto -package plugins +package kubeflow import ( core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + plugins "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -185,14 +186,25 @@ type DistributedPyTorchTrainingReplicaSpec struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // 1~4 deprecated. Use common instead. // Number of replicas + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` // Image used for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. Image string `protobuf:"bytes,2,opt,name=image,proto3" json:"image,omitempty"` // Resources required for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. Resources *core.Resources `protobuf:"bytes,3,opt,name=resources,proto3" json:"resources,omitempty"` - // RestartPolicy determines whether pods will be restarted when they exit - RestartPolicy RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.kubeflow.RestartPolicy" json:"restart_policy,omitempty"` + // Restart policy determines whether pods will be restarted when they exit + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. + RestartPolicy plugins.RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.RestartPolicy" json:"restart_policy,omitempty"` + // The common replica spec + Common *plugins.CommonReplicaSpec `protobuf:"bytes,5,opt,name=common,proto3" json:"common,omitempty"` } func (x *DistributedPyTorchTrainingReplicaSpec) Reset() { @@ -227,6 +239,7 @@ func (*DistributedPyTorchTrainingReplicaSpec) Descriptor() ([]byte, []int) { return file_flyteidl_plugins_kubeflow_pytorch_proto_rawDescGZIP(), []int{2} } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. func (x *DistributedPyTorchTrainingReplicaSpec) GetReplicas() int32 { if x != nil { return x.Replicas @@ -234,6 +247,7 @@ func (x *DistributedPyTorchTrainingReplicaSpec) GetReplicas() int32 { return 0 } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. func (x *DistributedPyTorchTrainingReplicaSpec) GetImage() string { if x != nil { return x.Image @@ -241,6 +255,7 @@ func (x *DistributedPyTorchTrainingReplicaSpec) GetImage() string { return "" } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. func (x *DistributedPyTorchTrainingReplicaSpec) GetResources() *core.Resources { if x != nil { return x.Resources @@ -248,11 +263,19 @@ func (x *DistributedPyTorchTrainingReplicaSpec) GetResources() *core.Resources { return nil } -func (x *DistributedPyTorchTrainingReplicaSpec) GetRestartPolicy() RestartPolicy { +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/pytorch.proto. +func (x *DistributedPyTorchTrainingReplicaSpec) GetRestartPolicy() plugins.RestartPolicy { if x != nil { return x.RestartPolicy } - return RestartPolicy_RESTART_POLICY_NEVER + return plugins.RestartPolicy(0) +} + +func (x *DistributedPyTorchTrainingReplicaSpec) GetCommon() *plugins.CommonReplicaSpec { + if x != nil { + return x.Common + } + return nil } var File_flyteidl_plugins_kubeflow_pytorch_proto protoreflect.FileDescriptor @@ -303,37 +326,42 @@ var file_flyteidl_plugins_kubeflow_pytorch_proto_rawDesc = []byte{ 0x0b, 0x32, 0x28, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x45, 0x6c, 0x61, 0x73, 0x74, 0x69, 0x63, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x65, 0x6c, 0x61, - 0x73, 0x74, 0x69, 0x63, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xe2, 0x01, 0x0a, 0x25, 0x44, + 0x73, 0x74, 0x69, 0x63, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xa6, 0x02, 0x0a, 0x25, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x64, 0x50, 0x79, 0x54, 0x6f, 0x72, 0x63, 0x68, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, - 0x53, 0x70, 0x65, 0x63, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, 0x4f, - 0x0a, 0x0e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x28, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, - 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x42, - 0xf2, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, - 0x77, 0x42, 0x0c, 0x50, 0x79, 0x74, 0x6f, 0x72, 0x63, 0x68, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, - 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, - 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, - 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, - 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, - 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, - 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xe2, 0x02, - 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, - 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x3a, 0x3a, 0x4b, 0x75, 0x62, 0x65, - 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x53, 0x70, 0x65, 0x63, 0x12, 0x1e, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x02, 0x18, 0x01, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, + 0x69, 0x63, 0x61, 0x73, 0x12, 0x18, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x3a, + 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, + 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, 0x4a, 0x0a, 0x0e, 0x72, 0x65, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x3b, 0x0a, 0x06, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, + 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x52, 0x06, 0x63, 0x6f, 0x6d, + 0x6d, 0x6f, 0x6e, 0x42, 0xfb, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, + 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x42, 0x0c, 0x50, 0x79, 0x74, 0x6f, 0x72, 0x63, 0x68, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, + 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, + 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xa2, 0x02, 0x03, + 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, + 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, + 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, + 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, + 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, + 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x3a, 0x3a, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, + 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -355,7 +383,8 @@ var file_flyteidl_plugins_kubeflow_pytorch_proto_goTypes = []interface{}{ (*DistributedPyTorchTrainingReplicaSpec)(nil), // 2: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec (*RunPolicy)(nil), // 3: flyteidl.plugins.kubeflow.RunPolicy (*core.Resources)(nil), // 4: flyteidl.core.Resources - (RestartPolicy)(0), // 5: flyteidl.plugins.kubeflow.RestartPolicy + (plugins.RestartPolicy)(0), // 5: flyteidl.plugins.RestartPolicy + (*plugins.CommonReplicaSpec)(nil), // 6: flyteidl.plugins.CommonReplicaSpec } var file_flyteidl_plugins_kubeflow_pytorch_proto_depIdxs = []int32{ 2, // 0: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingTask.worker_replicas:type_name -> flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec @@ -363,12 +392,13 @@ var file_flyteidl_plugins_kubeflow_pytorch_proto_depIdxs = []int32{ 3, // 2: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingTask.run_policy:type_name -> flyteidl.plugins.kubeflow.RunPolicy 0, // 3: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingTask.elastic_config:type_name -> flyteidl.plugins.kubeflow.ElasticConfig 4, // 4: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec.resources:type_name -> flyteidl.core.Resources - 5, // 5: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.kubeflow.RestartPolicy - 6, // [6:6] is the sub-list for method output_type - 6, // [6:6] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name + 5, // 5: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.RestartPolicy + 6, // 6: flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpec.common:type_name -> flyteidl.plugins.CommonReplicaSpec + 7, // [7:7] is the sub-list for method output_type + 7, // [7:7] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_flyteidl_plugins_kubeflow_pytorch_proto_init() } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go index ef6ec1899b..8fb9878d56 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go @@ -4,10 +4,11 @@ // protoc (unknown) // source: flyteidl/plugins/kubeflow/tensorflow.proto -package plugins +package kubeflow import ( core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + plugins "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -113,14 +114,25 @@ type DistributedTensorflowTrainingReplicaSpec struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // 1~4 deprecated. Use common instead. // Number of replicas + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` // Image used for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. Image string `protobuf:"bytes,2,opt,name=image,proto3" json:"image,omitempty"` // Resources required for the replica group + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. Resources *core.Resources `protobuf:"bytes,3,opt,name=resources,proto3" json:"resources,omitempty"` - // RestartPolicy Determines whether pods will be restarted when they exit - RestartPolicy RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.kubeflow.RestartPolicy" json:"restart_policy,omitempty"` + // Restart policy determines whether pods will be restarted when they exit + // + // Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. + RestartPolicy plugins.RestartPolicy `protobuf:"varint,4,opt,name=restart_policy,json=restartPolicy,proto3,enum=flyteidl.plugins.RestartPolicy" json:"restart_policy,omitempty"` + // The common replica spec + Common *plugins.CommonReplicaSpec `protobuf:"bytes,5,opt,name=common,proto3" json:"common,omitempty"` } func (x *DistributedTensorflowTrainingReplicaSpec) Reset() { @@ -155,6 +167,7 @@ func (*DistributedTensorflowTrainingReplicaSpec) Descriptor() ([]byte, []int) { return file_flyteidl_plugins_kubeflow_tensorflow_proto_rawDescGZIP(), []int{1} } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. func (x *DistributedTensorflowTrainingReplicaSpec) GetReplicas() int32 { if x != nil { return x.Replicas @@ -162,6 +175,7 @@ func (x *DistributedTensorflowTrainingReplicaSpec) GetReplicas() int32 { return 0 } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. func (x *DistributedTensorflowTrainingReplicaSpec) GetImage() string { if x != nil { return x.Image @@ -169,6 +183,7 @@ func (x *DistributedTensorflowTrainingReplicaSpec) GetImage() string { return "" } +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. func (x *DistributedTensorflowTrainingReplicaSpec) GetResources() *core.Resources { if x != nil { return x.Resources @@ -176,11 +191,19 @@ func (x *DistributedTensorflowTrainingReplicaSpec) GetResources() *core.Resource return nil } -func (x *DistributedTensorflowTrainingReplicaSpec) GetRestartPolicy() RestartPolicy { +// Deprecated: Marked as deprecated in flyteidl/plugins/kubeflow/tensorflow.proto. +func (x *DistributedTensorflowTrainingReplicaSpec) GetRestartPolicy() plugins.RestartPolicy { if x != nil { return x.RestartPolicy } - return RestartPolicy_RESTART_POLICY_NEVER + return plugins.RestartPolicy(0) +} + +func (x *DistributedTensorflowTrainingReplicaSpec) GetCommon() *plugins.CommonReplicaSpec { + if x != nil { + return x.Common + } + return nil } var File_flyteidl_plugins_kubeflow_tensorflow_proto protoreflect.FileDescriptor @@ -228,38 +251,42 @@ var file_flyteidl_plugins_kubeflow_tensorflow_proto_rawDesc = []byte{ 0x73, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x64, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x52, 0x11, 0x65, 0x76, 0x61, 0x6c, 0x75, 0x61, 0x74, 0x6f, - 0x72, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x22, 0xe5, 0x01, 0x0a, 0x28, 0x44, 0x69, + 0x72, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x22, 0xa9, 0x02, 0x0a, 0x28, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x64, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x66, 0x6c, 0x6f, 0x77, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x69, - 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, - 0x12, 0x4f, 0x0a, 0x0e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, - 0x63, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x28, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, - 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, - 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, - 0x63, 0x79, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x42, 0xf5, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, 0x75, 0x62, 0x65, 0x66, - 0x6c, 0x6f, 0x77, 0x42, 0x0f, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x66, 0x6c, 0x6f, 0x77, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, - 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, - 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, - 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, 0x46, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x4b, - 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, - 0x6c, 0x6f, 0x77, 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, - 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x5c, - 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, 0x46, 0x6c, - 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x3a, - 0x3a, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1e, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, + 0x61, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x42, 0x02, 0x18, 0x01, 0x52, 0x08, 0x72, 0x65, + 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x18, 0x0a, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x05, 0x69, 0x6d, 0x61, 0x67, 0x65, + 0x12, 0x3a, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, + 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x42, 0x02, 0x18, + 0x01, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x12, 0x4a, 0x0a, 0x0e, + 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, + 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x52, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x50, + 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x3b, 0x0a, 0x06, 0x63, 0x6f, 0x6d, 0x6d, + 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, + 0x6f, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x53, 0x70, 0x65, 0x63, 0x52, 0x06, 0x63, + 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x42, 0xfe, 0x01, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6b, + 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x42, 0x0f, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x66, + 0x6c, 0x6f, 0x77, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, + 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, + 0x6f, 0x77, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x4b, 0xaa, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x4b, 0x75, 0x62, 0x65, + 0x66, 0x6c, 0x6f, 0x77, 0xca, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, + 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, + 0xe2, 0x02, 0x25, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x73, 0x5c, 0x4b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x5c, 0x47, 0x50, 0x42, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1b, 0x46, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x3a, 0x3a, 0x4b, 0x75, + 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -278,9 +305,10 @@ var file_flyteidl_plugins_kubeflow_tensorflow_proto_msgTypes = make([]protoimpl. var file_flyteidl_plugins_kubeflow_tensorflow_proto_goTypes = []interface{}{ (*DistributedTensorflowTrainingTask)(nil), // 0: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask (*DistributedTensorflowTrainingReplicaSpec)(nil), // 1: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec - (*RunPolicy)(nil), // 2: flyteidl.plugins.kubeflow.RunPolicy - (*core.Resources)(nil), // 3: flyteidl.core.Resources - (RestartPolicy)(0), // 4: flyteidl.plugins.kubeflow.RestartPolicy + (*RunPolicy)(nil), // 2: flyteidl.plugins.kubeflow.RunPolicy + (*core.Resources)(nil), // 3: flyteidl.core.Resources + (plugins.RestartPolicy)(0), // 4: flyteidl.plugins.RestartPolicy + (*plugins.CommonReplicaSpec)(nil), // 5: flyteidl.plugins.CommonReplicaSpec } var file_flyteidl_plugins_kubeflow_tensorflow_proto_depIdxs = []int32{ 1, // 0: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.worker_replicas:type_name -> flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec @@ -289,12 +317,13 @@ var file_flyteidl_plugins_kubeflow_tensorflow_proto_depIdxs = []int32{ 2, // 3: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.run_policy:type_name -> flyteidl.plugins.kubeflow.RunPolicy 1, // 4: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas:type_name -> flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec 3, // 5: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec.resources:type_name -> flyteidl.core.Resources - 4, // 6: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.kubeflow.RestartPolicy - 7, // [7:7] is the sub-list for method output_type - 7, // [7:7] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name + 4, // 6: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec.restart_policy:type_name -> flyteidl.plugins.RestartPolicy + 5, // 7: flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec.common:type_name -> flyteidl.plugins.CommonReplicaSpec + 8, // [8:8] is the sub-list for method output_type + 8, // [8:8] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name } func init() { file_flyteidl_plugins_kubeflow_tensorflow_proto_init() } diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/plugins/common.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/plugins/common.swagger.json new file mode 100644 index 0000000000..4e9dc31dbe --- /dev/null +++ b/flyteidl/gen/pb-go/gateway/flyteidl/plugins/common.swagger.json @@ -0,0 +1,46 @@ +{ + "swagger": "2.0", + "info": { + "title": "flyteidl/plugins/common.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": {}, + "definitions": { + "googlerpcStatus": { + "type": "object", + "properties": { + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/definitions/protobufAny" + } + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "@type": { + "type": "string", + "description": "A URL/resource name that uniquely identifies the type of the serialized\nprotocol buffer message. This string must contain at least\none \"/\" character. The last segment of the URL's path must represent\nthe fully qualified name of the type (as in\n`path/google.protobuf.Duration`). The name should be in a canonical form\n(e.g., leading \".\" is not accepted).\n\nIn practice, teams usually precompile into the binary all types that they\nexpect it to use in the context of Any. However, for URLs which use the\nscheme `http`, `https`, or no scheme, one can optionally set up a type\nserver that maps type URLs to message definitions as follows:\n\n* If no scheme is provided, `https` is assumed.\n* An HTTP GET on the URL must yield a [google.protobuf.Type][]\n value in binary format, or produce an error.\n* Applications are allowed to cache lookup results based on the\n URL, or have them precompiled into a binary to avoid any\n lookup. Therefore, binary compatibility needs to be preserved\n on changes to types. (Use versioned type names to manage\n breaking changes.)\n\nNote: this functionality is not currently available in the official\nprotobuf release, and it is not used for type URLs beginning with\ntype.googleapis.com. As of May 2023, there are no widely used type server\nimplementations and no plans to implement one.\n\nSchemes other than `http`, `https` (or the empty scheme) might be\nused with implementation specific semantics." + } + }, + "additionalProperties": {}, + "description": "`Any` contains an arbitrary serialized protocol buffer message along with a\nURL that describes the type of the serialized message.\n\nProtobuf library provides support to pack/unpack Any values in the form\nof utility functions or additional generated methods of the Any type.\n\nExample 1: Pack and unpack a message in C++.\n\n Foo foo = ...;\n Any any;\n any.PackFrom(foo);\n ...\n if (any.UnpackTo(\u0026foo)) {\n ...\n }\n\nExample 2: Pack and unpack a message in Java.\n\n Foo foo = ...;\n Any any = Any.pack(foo);\n ...\n if (any.is(Foo.class)) {\n foo = any.unpack(Foo.class);\n }\n // or ...\n if (any.isSameTypeAs(Foo.getDefaultInstance())) {\n foo = any.unpack(Foo.getDefaultInstance());\n }\n\n Example 3: Pack and unpack a message in Python.\n\n foo = Foo(...)\n any = Any()\n any.Pack(foo)\n ...\n if any.Is(Foo.DESCRIPTOR):\n any.Unpack(foo)\n ...\n\n Example 4: Pack and unpack a message in Go\n\n foo := \u0026pb.Foo{...}\n any, err := anypb.New(foo)\n if err != nil {\n ...\n }\n ...\n foo := \u0026pb.Foo{}\n if err := any.UnmarshalTo(foo); err != nil {\n ...\n }\n\nThe pack methods provided by protobuf library will by default use\n'type.googleapis.com/full.type.name' as the type URL and the unpack\nmethods only use the fully qualified type name after the last '/'\nin the type URL, for example \"foo.bar.com/x/y.z\" will yield type\nname \"y.z\".\n\nJSON\n====\nThe JSON representation of an `Any` value uses the regular\nrepresentation of the deserialized, embedded message, with an\nadditional field `@type` which contains the type URL. Example:\n\n package google.profile;\n message Person {\n string first_name = 1;\n string last_name = 2;\n }\n\n {\n \"@type\": \"type.googleapis.com/google.profile.Person\",\n \"firstName\": \u003cstring\u003e,\n \"lastName\": \u003cstring\u003e\n }\n\nIf the embedded message type is well-known and has a custom JSON\nrepresentation, that representation will be embedded adding a field\n`value` which holds the custom JSON in addition to the `@type`\nfield. Example (for message [google.protobuf.Duration][]):\n\n {\n \"@type\": \"type.googleapis.com/google.protobuf.Duration\",\n \"value\": \"1.212s\"\n }" + } + } +} diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.py new file mode 100644 index 0000000000..5abdb07e50 --- /dev/null +++ b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flyteidl/plugins/common.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lyteidl/plugins/common.proto\x12\x10\x66lyteidl.plugins\x1a\x19\x66lyteidl/core/tasks.proto\"\xc5\x01\n\x11\x43ommonReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12\x46\n\x0erestart_policy\x18\x04 \x01(\x0e\x32\x1f.flyteidl.plugins.RestartPolicyR\rrestartPolicy*c\n\rRestartPolicy\x12\x18\n\x14RESTART_POLICY_NEVER\x10\x00\x12\x1d\n\x19RESTART_POLICY_ON_FAILURE\x10\x01\x12\x19\n\x15RESTART_POLICY_ALWAYS\x10\x02\x42\xc3\x01\n\x14\x63om.flyteidl.pluginsB\x0b\x43ommonProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flyteidl.plugins.common_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\013CommonProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' + _globals['_RESTARTPOLICY']._serialized_start=278 + _globals['_RESTARTPOLICY']._serialized_end=377 + _globals['_COMMONREPLICASPEC']._serialized_start=79 + _globals['_COMMONREPLICASPEC']._serialized_end=276 +# @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.pyi new file mode 100644 index 0000000000..24115b7166 --- /dev/null +++ b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2.pyi @@ -0,0 +1,28 @@ +from flyteidl.core import tasks_pb2 as _tasks_pb2 +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class RestartPolicy(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + RESTART_POLICY_NEVER: _ClassVar[RestartPolicy] + RESTART_POLICY_ON_FAILURE: _ClassVar[RestartPolicy] + RESTART_POLICY_ALWAYS: _ClassVar[RestartPolicy] +RESTART_POLICY_NEVER: RestartPolicy +RESTART_POLICY_ON_FAILURE: RestartPolicy +RESTART_POLICY_ALWAYS: RestartPolicy + +class CommonReplicaSpec(_message.Message): + __slots__ = ["replicas", "image", "resources", "restart_policy"] + REPLICAS_FIELD_NUMBER: _ClassVar[int] + IMAGE_FIELD_NUMBER: _ClassVar[int] + RESOURCES_FIELD_NUMBER: _ClassVar[int] + RESTART_POLICY_FIELD_NUMBER: _ClassVar[int] + replicas: int + image: str + resources: _tasks_pb2.Resources + restart_policy: RestartPolicy + def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[RestartPolicy, str]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2_grpc.py b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2_grpc.py new file mode 100644 index 0000000000..2daafffebf --- /dev/null +++ b/flyteidl/gen/pb_python/flyteidl/plugins/common_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.py index d09a4aba41..3cba1b71ec 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.py @@ -11,9 +11,11 @@ _sym_db = _symbol_database.Default() +from flyteidl.plugins import common_pb2 as flyteidl_dot_plugins_dot_common__pb2 +from flyteidl.plugins.common_pb2 import * -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&flyteidl/plugins/kubeflow/common.proto\x12\x19\x66lyteidl.plugins.kubeflow\"\xfa\x01\n\tRunPolicy\x12S\n\x10\x63lean_pod_policy\x18\x01 \x01(\x0e\x32).flyteidl.plugins.kubeflow.CleanPodPolicyR\x0e\x63leanPodPolicy\x12;\n\x1attl_seconds_after_finished\x18\x02 \x01(\x05R\x17ttlSecondsAfterFinished\x12\x36\n\x17\x61\x63tive_deadline_seconds\x18\x03 \x01(\x05R\x15\x61\x63tiveDeadlineSeconds\x12#\n\rbackoff_limit\x18\x04 \x01(\x05R\x0c\x62\x61\x63koffLimit*c\n\rRestartPolicy\x12\x18\n\x14RESTART_POLICY_NEVER\x10\x00\x12\x1d\n\x19RESTART_POLICY_ON_FAILURE\x10\x01\x12\x19\n\x15RESTART_POLICY_ALWAYS\x10\x02*`\n\x0e\x43leanPodPolicy\x12\x18\n\x14\x43LEANPOD_POLICY_NONE\x10\x00\x12\x1b\n\x17\x43LEANPOD_POLICY_RUNNING\x10\x01\x12\x17\n\x13\x43LEANPOD_POLICY_ALL\x10\x02\x42\xf1\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0b\x43ommonProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&flyteidl/plugins/kubeflow/common.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x1d\x66lyteidl/plugins/common.proto\"\xfa\x01\n\tRunPolicy\x12S\n\x10\x63lean_pod_policy\x18\x01 \x01(\x0e\x32).flyteidl.plugins.kubeflow.CleanPodPolicyR\x0e\x63leanPodPolicy\x12;\n\x1attl_seconds_after_finished\x18\x02 \x01(\x05R\x17ttlSecondsAfterFinished\x12\x36\n\x17\x61\x63tive_deadline_seconds\x18\x03 \x01(\x05R\x15\x61\x63tiveDeadlineSeconds\x12#\n\rbackoff_limit\x18\x04 \x01(\x05R\x0c\x62\x61\x63koffLimit*`\n\x0e\x43leanPodPolicy\x12\x18\n\x14\x43LEANPOD_POLICY_NONE\x10\x00\x12\x1b\n\x17\x43LEANPOD_POLICY_RUNNING\x10\x01\x12\x17\n\x13\x43LEANPOD_POLICY_ALL\x10\x02\x42\xfa\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0b\x43ommonProtoP\x01ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::KubeflowP\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -21,11 +23,9 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\013CommonProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' - _globals['_RESTARTPOLICY']._serialized_start=322 - _globals['_RESTARTPOLICY']._serialized_end=421 - _globals['_CLEANPODPOLICY']._serialized_start=423 - _globals['_CLEANPODPOLICY']._serialized_end=519 - _globals['_RUNPOLICY']._serialized_start=70 - _globals['_RUNPOLICY']._serialized_end=320 + DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\013CommonProtoP\001ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + _globals['_CLEANPODPOLICY']._serialized_start=353 + _globals['_CLEANPODPOLICY']._serialized_end=449 + _globals['_RUNPOLICY']._serialized_start=101 + _globals['_RUNPOLICY']._serialized_end=351 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.pyi index 916c3344b2..484c1472b7 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/common_pb2.pyi @@ -1,24 +1,21 @@ +from flyteidl.plugins import common_pb2 as _common_pb2 from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union +from flyteidl.plugins.common_pb2 import CommonReplicaSpec +from flyteidl.plugins.common_pb2 import RestartPolicy DESCRIPTOR: _descriptor.FileDescriptor - -class RestartPolicy(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - RESTART_POLICY_NEVER: _ClassVar[RestartPolicy] - RESTART_POLICY_ON_FAILURE: _ClassVar[RestartPolicy] - RESTART_POLICY_ALWAYS: _ClassVar[RestartPolicy] +RESTART_POLICY_NEVER: _common_pb2.RestartPolicy +RESTART_POLICY_ON_FAILURE: _common_pb2.RestartPolicy +RESTART_POLICY_ALWAYS: _common_pb2.RestartPolicy class CleanPodPolicy(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = [] CLEANPOD_POLICY_NONE: _ClassVar[CleanPodPolicy] CLEANPOD_POLICY_RUNNING: _ClassVar[CleanPodPolicy] CLEANPOD_POLICY_ALL: _ClassVar[CleanPodPolicy] -RESTART_POLICY_NEVER: RestartPolicy -RESTART_POLICY_ON_FAILURE: RestartPolicy -RESTART_POLICY_ALWAYS: RestartPolicy CLEANPOD_POLICY_NONE: CleanPodPolicy CLEANPOD_POLICY_RUNNING: CleanPodPolicy CLEANPOD_POLICY_ALL: CleanPodPolicy diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.py index 4ed3f22be7..49078c7372 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.py @@ -13,9 +13,13 @@ from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 from flyteidl.plugins.kubeflow import common_pb2 as flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2 +try: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl_dot_plugins_dot_common__pb2 +except AttributeError: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl.plugins.common_pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#flyteidl/plugins/kubeflow/mpi.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xc9\x02\n\x1a\x44istributedMPITrainingTask\x12\x65\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32<.flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpecR\x0eworkerReplicas\x12i\n\x11launcher_replicas\x18\x02 \x01(\x0b\x32<.flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpecR\x10launcherReplicas\x12\x43\n\nrun_policy\x18\x03 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12\x14\n\x05slots\x18\x04 \x01(\x05R\x05slots\"\xf8\x01\n!DistributedMPITrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicy\x12\x18\n\x07\x63ommand\x18\x05 \x03(\tR\x07\x63ommandB\xee\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x08MpiProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#flyteidl/plugins/kubeflow/mpi.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xc9\x02\n\x1a\x44istributedMPITrainingTask\x12\x65\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32<.flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpecR\x0eworkerReplicas\x12i\n\x11launcher_replicas\x18\x02 \x01(\x0b\x32<.flyteidl.plugins.kubeflow.DistributedMPITrainingReplicaSpecR\x10launcherReplicas\x12\x43\n\nrun_policy\x18\x03 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12\x14\n\x05slots\x18\x04 \x01(\x05R\x05slots\"\xbc\x02\n!DistributedMPITrainingReplicaSpec\x12\x1e\n\x08replicas\x18\x01 \x01(\x05\x42\x02\x18\x01R\x08replicas\x12\x18\n\x05image\x18\x02 \x01(\tB\x02\x18\x01R\x05image\x12:\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesB\x02\x18\x01R\tresources\x12J\n\x0erestart_policy\x18\x04 \x01(\x0e\x32\x1f.flyteidl.plugins.RestartPolicyB\x02\x18\x01R\rrestartPolicy\x12\x18\n\x07\x63ommand\x18\x05 \x03(\tR\x07\x63ommand\x12;\n\x06\x63ommon\x18\x06 \x01(\x0b\x32#.flyteidl.plugins.CommonReplicaSpecR\x06\x63ommonB\xf7\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x08MpiProtoP\x01ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,9 +27,17 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\010MpiProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\010MpiProtoP\001ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['replicas']._options = None + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['replicas']._serialized_options = b'\030\001' + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['image']._options = None + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['image']._serialized_options = b'\030\001' + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['resources']._options = None + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['resources']._serialized_options = b'\030\001' + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['restart_policy']._options = None + _DISTRIBUTEDMPITRAININGREPLICASPEC.fields_by_name['restart_policy']._serialized_options = b'\030\001' _globals['_DISTRIBUTEDMPITRAININGTASK']._serialized_start=134 _globals['_DISTRIBUTEDMPITRAININGTASK']._serialized_end=463 _globals['_DISTRIBUTEDMPITRAININGREPLICASPEC']._serialized_start=466 - _globals['_DISTRIBUTEDMPITRAININGREPLICASPEC']._serialized_end=714 + _globals['_DISTRIBUTEDMPITRAININGREPLICASPEC']._serialized_end=782 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.pyi index 6258542993..03fb4a4924 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/mpi_pb2.pyi @@ -1,5 +1,6 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 from flyteidl.plugins.kubeflow import common_pb2 as _common_pb2 +from flyteidl.plugins import common_pb2 as _common_pb2_1 from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -20,15 +21,17 @@ class DistributedMPITrainingTask(_message.Message): def __init__(self, worker_replicas: _Optional[_Union[DistributedMPITrainingReplicaSpec, _Mapping]] = ..., launcher_replicas: _Optional[_Union[DistributedMPITrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ..., slots: _Optional[int] = ...) -> None: ... class DistributedMPITrainingReplicaSpec(_message.Message): - __slots__ = ["replicas", "image", "resources", "restart_policy", "command"] + __slots__ = ["replicas", "image", "resources", "restart_policy", "command", "common"] REPLICAS_FIELD_NUMBER: _ClassVar[int] IMAGE_FIELD_NUMBER: _ClassVar[int] RESOURCES_FIELD_NUMBER: _ClassVar[int] RESTART_POLICY_FIELD_NUMBER: _ClassVar[int] COMMAND_FIELD_NUMBER: _ClassVar[int] + COMMON_FIELD_NUMBER: _ClassVar[int] replicas: int image: str resources: _tasks_pb2.Resources - restart_policy: _common_pb2.RestartPolicy + restart_policy: _common_pb2_1.RestartPolicy command: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2.RestartPolicy, str]] = ..., command: _Optional[_Iterable[str]] = ...) -> None: ... + common: _common_pb2_1.CommonReplicaSpec + def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2_1.RestartPolicy, str]] = ..., command: _Optional[_Iterable[str]] = ..., common: _Optional[_Union[_common_pb2_1.CommonReplicaSpec, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.py index 46f574228a..06bf2b934b 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.py @@ -13,9 +13,13 @@ from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 from flyteidl.plugins.kubeflow import common_pb2 as flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2 +try: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl_dot_plugins_dot_common__pb2 +except AttributeError: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl.plugins.common_pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'flyteidl/plugins/kubeflow/pytorch.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xc1\x01\n\rElasticConfig\x12!\n\x0crdzv_backend\x18\x01 \x01(\tR\x0brdzvBackend\x12!\n\x0cmin_replicas\x18\x02 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x03 \x01(\x05R\x0bmaxReplicas\x12$\n\x0enproc_per_node\x18\x04 \x01(\x05R\x0cnprocPerNode\x12!\n\x0cmax_restarts\x18\x05 \x01(\x05R\x0bmaxRestarts\"\x8c\x03\n\x1e\x44istributedPyTorchTrainingTask\x12i\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32@.flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpecR\x0eworkerReplicas\x12i\n\x0fmaster_replicas\x18\x02 \x01(\x0b\x32@.flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpecR\x0emasterReplicas\x12\x43\n\nrun_policy\x18\x03 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12O\n\x0e\x65lastic_config\x18\x04 \x01(\x0b\x32(.flyteidl.plugins.kubeflow.ElasticConfigR\relasticConfig\"\xe2\x01\n%DistributedPyTorchTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf2\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0cPytorchProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'flyteidl/plugins/kubeflow/pytorch.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xc1\x01\n\rElasticConfig\x12!\n\x0crdzv_backend\x18\x01 \x01(\tR\x0brdzvBackend\x12!\n\x0cmin_replicas\x18\x02 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x03 \x01(\x05R\x0bmaxReplicas\x12$\n\x0enproc_per_node\x18\x04 \x01(\x05R\x0cnprocPerNode\x12!\n\x0cmax_restarts\x18\x05 \x01(\x05R\x0bmaxRestarts\"\x8c\x03\n\x1e\x44istributedPyTorchTrainingTask\x12i\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32@.flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpecR\x0eworkerReplicas\x12i\n\x0fmaster_replicas\x18\x02 \x01(\x0b\x32@.flyteidl.plugins.kubeflow.DistributedPyTorchTrainingReplicaSpecR\x0emasterReplicas\x12\x43\n\nrun_policy\x18\x03 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12O\n\x0e\x65lastic_config\x18\x04 \x01(\x0b\x32(.flyteidl.plugins.kubeflow.ElasticConfigR\relasticConfig\"\xa6\x02\n%DistributedPyTorchTrainingReplicaSpec\x12\x1e\n\x08replicas\x18\x01 \x01(\x05\x42\x02\x18\x01R\x08replicas\x12\x18\n\x05image\x18\x02 \x01(\tB\x02\x18\x01R\x05image\x12:\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesB\x02\x18\x01R\tresources\x12J\n\x0erestart_policy\x18\x04 \x01(\x0e\x32\x1f.flyteidl.plugins.RestartPolicyB\x02\x18\x01R\rrestartPolicy\x12;\n\x06\x63ommon\x18\x05 \x01(\x0b\x32#.flyteidl.plugins.CommonReplicaSpecR\x06\x63ommonB\xfb\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0cPytorchProtoP\x01ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,11 +27,19 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\014PytorchProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\014PytorchProtoP\001ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['replicas']._options = None + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['replicas']._serialized_options = b'\030\001' + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['image']._options = None + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['image']._serialized_options = b'\030\001' + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['resources']._options = None + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['resources']._serialized_options = b'\030\001' + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['restart_policy']._options = None + _DISTRIBUTEDPYTORCHTRAININGREPLICASPEC.fields_by_name['restart_policy']._serialized_options = b'\030\001' _globals['_ELASTICCONFIG']._serialized_start=138 _globals['_ELASTICCONFIG']._serialized_end=331 _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_start=334 _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_end=730 _globals['_DISTRIBUTEDPYTORCHTRAININGREPLICASPEC']._serialized_start=733 - _globals['_DISTRIBUTEDPYTORCHTRAININGREPLICASPEC']._serialized_end=959 + _globals['_DISTRIBUTEDPYTORCHTRAININGREPLICASPEC']._serialized_end=1027 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.pyi index ee6599ad82..3c06df0964 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/pytorch_pb2.pyi @@ -1,5 +1,6 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 from flyteidl.plugins.kubeflow import common_pb2 as _common_pb2 +from flyteidl.plugins import common_pb2 as _common_pb2_1 from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union @@ -33,13 +34,15 @@ class DistributedPyTorchTrainingTask(_message.Message): def __init__(self, worker_replicas: _Optional[_Union[DistributedPyTorchTrainingReplicaSpec, _Mapping]] = ..., master_replicas: _Optional[_Union[DistributedPyTorchTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ..., elastic_config: _Optional[_Union[ElasticConfig, _Mapping]] = ...) -> None: ... class DistributedPyTorchTrainingReplicaSpec(_message.Message): - __slots__ = ["replicas", "image", "resources", "restart_policy"] + __slots__ = ["replicas", "image", "resources", "restart_policy", "common"] REPLICAS_FIELD_NUMBER: _ClassVar[int] IMAGE_FIELD_NUMBER: _ClassVar[int] RESOURCES_FIELD_NUMBER: _ClassVar[int] RESTART_POLICY_FIELD_NUMBER: _ClassVar[int] + COMMON_FIELD_NUMBER: _ClassVar[int] replicas: int image: str resources: _tasks_pb2.Resources - restart_policy: _common_pb2.RestartPolicy - def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2.RestartPolicy, str]] = ...) -> None: ... + restart_policy: _common_pb2_1.RestartPolicy + common: _common_pb2_1.CommonReplicaSpec + def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2_1.RestartPolicy, str]] = ..., common: _Optional[_Union[_common_pb2_1.CommonReplicaSpec, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py index f0c086f9e7..57768cd91d 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py @@ -13,9 +13,13 @@ from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 from flyteidl.plugins.kubeflow import common_pb2 as flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2 +try: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl_dot_plugins_dot_common__pb2 +except AttributeError: + flyteidl_dot_plugins_dot_common__pb2 = flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2.flyteidl.plugins.common_pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\x9c\x04\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12r\n\x12\x65valuator_replicas\x18\x05 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x11\x65valuatorReplicas\"\xe5\x01\n(DistributedTensorflowTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf5\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\x9c\x04\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12r\n\x12\x65valuator_replicas\x18\x05 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x11\x65valuatorReplicas\"\xa9\x02\n(DistributedTensorflowTrainingReplicaSpec\x12\x1e\n\x08replicas\x18\x01 \x01(\x05\x42\x02\x18\x01R\x08replicas\x12\x18\n\x05image\x18\x02 \x01(\tB\x02\x18\x01R\x05image\x12:\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesB\x02\x18\x01R\tresources\x12J\n\x0erestart_policy\x18\x04 \x01(\x0e\x32\x1f.flyteidl.plugins.RestartPolicyB\x02\x18\x01R\rrestartPolicy\x12;\n\x06\x63ommon\x18\x05 \x01(\x0b\x32#.flyteidl.plugins.CommonReplicaSpecR\x06\x63ommonB\xfe\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,9 +27,17 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\017TensorflowProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\017TensorflowProtoP\001ZFgithub.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['replicas']._options = None + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['replicas']._serialized_options = b'\030\001' + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['image']._options = None + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['image']._serialized_options = b'\030\001' + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['resources']._options = None + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['resources']._serialized_options = b'\030\001' + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['restart_policy']._options = None + _DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC.fields_by_name['restart_policy']._serialized_options = b'\030\001' _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_start=141 _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=681 _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_start=684 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=913 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=981 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi index 4a999f70e8..44e492b624 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi @@ -1,5 +1,6 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 from flyteidl.plugins.kubeflow import common_pb2 as _common_pb2 +from flyteidl.plugins import common_pb2 as _common_pb2_1 from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union @@ -21,13 +22,15 @@ class DistributedTensorflowTrainingTask(_message.Message): def __init__(self, worker_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., ps_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., chief_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ..., evaluator_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ...) -> None: ... class DistributedTensorflowTrainingReplicaSpec(_message.Message): - __slots__ = ["replicas", "image", "resources", "restart_policy"] + __slots__ = ["replicas", "image", "resources", "restart_policy", "common"] REPLICAS_FIELD_NUMBER: _ClassVar[int] IMAGE_FIELD_NUMBER: _ClassVar[int] RESOURCES_FIELD_NUMBER: _ClassVar[int] RESTART_POLICY_FIELD_NUMBER: _ClassVar[int] + COMMON_FIELD_NUMBER: _ClassVar[int] replicas: int image: str resources: _tasks_pb2.Resources - restart_policy: _common_pb2.RestartPolicy - def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2.RestartPolicy, str]] = ...) -> None: ... + restart_policy: _common_pb2_1.RestartPolicy + common: _common_pb2_1.CommonReplicaSpec + def __init__(self, replicas: _Optional[int] = ..., image: _Optional[str] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., restart_policy: _Optional[_Union[_common_pb2_1.RestartPolicy, str]] = ..., common: _Optional[_Union[_common_pb2_1.CommonReplicaSpec, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs index 96d46653da..9eebb7bc9e 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs @@ -18,35 +18,6 @@ pub struct RunPolicy { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] -pub enum RestartPolicy { - Never = 0, - OnFailure = 1, - Always = 2, -} -impl RestartPolicy { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - RestartPolicy::Never => "RESTART_POLICY_NEVER", - RestartPolicy::OnFailure => "RESTART_POLICY_ON_FAILURE", - RestartPolicy::Always => "RESTART_POLICY_ALWAYS", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "RESTART_POLICY_NEVER" => Some(Self::Never), - "RESTART_POLICY_ON_FAILURE" => Some(Self::OnFailure), - "RESTART_POLICY_ALWAYS" => Some(Self::Always), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] pub enum CleanPodPolicy { CleanpodPolicyNone = 0, CleanpodPolicyRunning = 1, @@ -97,21 +68,29 @@ pub struct DistributedMpiTrainingTask { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedMpiTrainingReplicaSpec { + /// 1~4 deprecated. Use common instead. /// Number of replicas + #[deprecated] #[prost(int32, tag="1")] pub replicas: i32, /// Image used for the replica group + #[deprecated] #[prost(string, tag="2")] pub image: ::prost::alloc::string::String, /// Resources required for the replica group + #[deprecated] #[prost(message, optional, tag="3")] pub resources: ::core::option::Option, /// Restart policy determines whether pods will be restarted when they exit - #[prost(enumeration="RestartPolicy", tag="4")] + #[deprecated] + #[prost(enumeration="super::RestartPolicy", tag="4")] pub restart_policy: i32, /// MPI sometimes requires different command set for different replica groups #[prost(string, repeated, tag="5")] pub command: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// The common replica spec + #[prost(message, optional, tag="6")] + pub common: ::core::option::Option, } /// Custom proto for torch elastic config for distributed training using /// @@ -151,18 +130,26 @@ pub struct DistributedPyTorchTrainingTask { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedPyTorchTrainingReplicaSpec { + /// 1~4 deprecated. Use common instead. /// Number of replicas + #[deprecated] #[prost(int32, tag="1")] pub replicas: i32, /// Image used for the replica group + #[deprecated] #[prost(string, tag="2")] pub image: ::prost::alloc::string::String, /// Resources required for the replica group + #[deprecated] #[prost(message, optional, tag="3")] pub resources: ::core::option::Option, - /// RestartPolicy determines whether pods will be restarted when they exit - #[prost(enumeration="RestartPolicy", tag="4")] + /// Restart policy determines whether pods will be restarted when they exit + #[deprecated] + #[prost(enumeration="super::RestartPolicy", tag="4")] pub restart_policy: i32, + /// The common replica spec + #[prost(message, optional, tag="5")] + pub common: ::core::option::Option, } /// Proto for plugin that enables distributed training using #[allow(clippy::derive_partial_eq_without_eq)] @@ -189,17 +176,25 @@ pub struct DistributedTensorflowTrainingTask { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedTensorflowTrainingReplicaSpec { + /// 1~4 deprecated. Use common instead. /// Number of replicas + #[deprecated] #[prost(int32, tag="1")] pub replicas: i32, /// Image used for the replica group + #[deprecated] #[prost(string, tag="2")] pub image: ::prost::alloc::string::String, /// Resources required for the replica group + #[deprecated] #[prost(message, optional, tag="3")] pub resources: ::core::option::Option, - /// RestartPolicy Determines whether pods will be restarted when they exit - #[prost(enumeration="RestartPolicy", tag="4")] + /// Restart policy determines whether pods will be restarted when they exit + #[deprecated] + #[prost(enumeration="super::RestartPolicy", tag="4")] pub restart_policy: i32, + /// The common replica spec + #[prost(message, optional, tag="5")] + pub common: ::core::option::Option, } // @@protoc_insertion_point(module) diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index 0903d1c71a..28c2f77e97 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -33,6 +33,51 @@ pub mod array_job { MinSuccessRatio(f32), } } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommonReplicaSpec { + /// Number of replicas + #[prost(int32, tag="1")] + pub replicas: i32, + /// Image used for the replica group + #[prost(string, tag="2")] + pub image: ::prost::alloc::string::String, + /// Resources required for the replica group + #[prost(message, optional, tag="3")] + pub resources: ::core::option::Option, + /// RestartPolicy determines whether pods will be restarted when they exit + #[prost(enumeration="RestartPolicy", tag="4")] + pub restart_policy: i32, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum RestartPolicy { + Never = 0, + OnFailure = 1, + Always = 2, +} +impl RestartPolicy { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + RestartPolicy::Never => "RESTART_POLICY_NEVER", + RestartPolicy::OnFailure => "RESTART_POLICY_ON_FAILURE", + RestartPolicy::Always => "RESTART_POLICY_ALWAYS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "RESTART_POLICY_NEVER" => Some(Self::Never), + "RESTART_POLICY_ON_FAILURE" => Some(Self::OnFailure), + "RESTART_POLICY_ALWAYS" => Some(Self::Always), + _ => None, + } + } +} /// Custom Proto for Dask Plugin. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/flyteidl/protos/flyteidl/plugins/common.proto b/flyteidl/protos/flyteidl/plugins/common.proto new file mode 100644 index 0000000000..15f31cf2d2 --- /dev/null +++ b/flyteidl/protos/flyteidl/plugins/common.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package flyteidl.plugins; + +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; + +import "flyteidl/core/tasks.proto"; + +enum RestartPolicy { + RESTART_POLICY_NEVER = 0; + RESTART_POLICY_ON_FAILURE = 1; + RESTART_POLICY_ALWAYS = 2; +} + +message CommonReplicaSpec { + // Number of replicas + int32 replicas = 1; + + // Image used for the replica group + string image = 2; + + // Resources required for the replica group + core.Resources resources = 3; + + // RestartPolicy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto index 6795dca11b..37655caf3d 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto @@ -2,14 +2,9 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; - -enum RestartPolicy { - RESTART_POLICY_NEVER = 0; - RESTART_POLICY_ON_FAILURE = 1; - RESTART_POLICY_ALWAYS = 2; -} +import public "flyteidl/plugins/common.proto"; enum CleanPodPolicy { CLEANPOD_POLICY_NONE = 0; @@ -30,4 +25,4 @@ message RunPolicy { // Number of retries before marking this job failed. int32 backoff_limit = 4; -} \ No newline at end of file +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto index 6eda161f92..b98e8aad99 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -26,18 +26,22 @@ message DistributedMPITrainingTask { // Replica specification for distributed MPI training message DistributedMPITrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; - + core.Resources resources = 3 [deprecated = true]; + // Restart policy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + RestartPolicy restart_policy = 4 [deprecated = true]; // MPI sometimes requires different command set for different replica groups repeated string command = 5; -} \ No newline at end of file + + // The common replica spec + CommonReplicaSpec common = 6; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto index bd3ddbdf97..0433384e75 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -35,15 +35,19 @@ message DistributedPyTorchTrainingTask { } message DistributedPyTorchTrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; - - // RestartPolicy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + core.Resources resources = 3 [deprecated = true]; + + // Restart policy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4 [deprecated = true]; + + // The common replica spec + CommonReplicaSpec common = 5; } diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto index 789666b989..251526f7e0 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -28,15 +28,19 @@ message DistributedTensorflowTrainingTask { } message DistributedTensorflowTrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; + core.Resources resources = 3 [deprecated = true]; - // RestartPolicy Determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + // Restart policy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4 [deprecated = true]; + + // The common replica spec + CommonReplicaSpec common = 5; } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 7e4b32e25e..9d2e4a5aec 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -294,6 +294,7 @@ type kfDistributedReplicaSpec interface { GetImage() string GetResources() *core.Resources GetRestartPolicy() kfplugins.RestartPolicy + GetCommon() *kfplugins.CommonReplicaSpec } type allowsCommandOverride interface { @@ -301,9 +302,29 @@ type allowsCommandOverride interface { } func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) { + var replicas int32 + var image string + var resources *core.Resources + var restartPolicy kfplugins.RestartPolicy + + // replicas, image, resources, restartPolicy are deprecated since the common replica spec is introduced. + // Therefore, if the common replica spec is set, use that to get the common fields + common := rs.GetCommon() + if common != nil { + replicas = common.GetReplicas() + image = common.GetImage() + resources = common.GetResources() + restartPolicy = common.GetRestartPolicy() + } else { + replicas = rs.GetReplicas() + image = rs.GetImage() + resources = rs.GetResources() + restartPolicy = rs.GetRestartPolicy() + } + taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{} - if rs != nil && rs.GetResources() != nil { - resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources()) + if resources != nil { + resources, err := flytek8s.ToK8sResourceRequirements(resources) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) } @@ -321,26 +342,23 @@ func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExe replicaSpec.Replicas = &replicas } - if rs != nil { - var command []string - if v, ok := rs.(allowsCommandOverride); ok { - command = v.GetCommand() - } - if err := OverrideContainerSpec( - &replicaSpec.Template.Spec, - primaryContainerName, - rs.GetImage(), - command, - ); err != nil { - return nil, err - } + var command []string + if v, ok := rs.(allowsCommandOverride); ok { + command = v.GetCommand() + } + if err := OverrideContainerSpec( + &replicaSpec.Template.Spec, + primaryContainerName, + image, + command, + ); err != nil { + return nil, err + } - replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy()) + replicaSpec.RestartPolicy = ParseRestartPolicy(restartPolicy) - if !isMaster { - replicas := rs.GetReplicas() - replicaSpec.Replicas = &replicas - } + if !isMaster { + replicaSpec.Replicas = &replicas } return replicaSpec, nil diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 3c823510ac..900091f78a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -594,129 +594,191 @@ func TestReplicaCounts(t *testing.T) { func TestBuildResourceMPIV1(t *testing.T) { launcherCommand := []string{"python", "launcher.py"} workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} - taskConfig := &kfplugins.DistributedMPITrainingTask{ - LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ - Image: testImage, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "250Mi"}, + taskConfigs := []*kfplugins.DistributedMPITrainingTask{ + { + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "500Mi"}, + Command: launcherCommand, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, + Command: workerCommand, }, - Command: launcherCommand, + Slots: int32(1), }, - WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + { + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + Command: launcherCommand, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, }, + Command: workerCommand, }, - Command: workerCommand, + Slots: int32(1), }, - Slots: int32(1), } - launcherResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("250Mi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("500Mi"), - }, - } + for _, taskConfig := range taskConfigs { + launcherResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("250Mi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("500Mi"), + }, + } - workerResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - } + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + } - mpiResourceHandler := mpiOperatorResourceHandler{} + mpiResourceHandler := mpiOperatorResourceHandler{} - taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - mpiJob, ok := resource.(*kubeflowv1.MPIJob) - assert.True(t, ok) - assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) - assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) - assert.Equal(t, *launcherResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Resources) - assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) - assert.Equal(t, launcherCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) - assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *launcherResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, launcherCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) + } } func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} - taskConfig := &kfplugins.DistributedMPITrainingTask{ - WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + taskConfigs := []*kfplugins.DistributedMPITrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, + }, + Slots: int32(1), + }, + { + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, }, + Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, }, - Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, + Slots: int32(1), }, - Slots: int32(1), } - workerResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - } + for _, taskConfig := range taskConfigs { + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + } - mpiResourceHandler := mpiOperatorResourceHandler{} + mpiResourceHandler := mpiOperatorResourceHandler{} - taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - mpiJob, ok := resource.(*kubeflowv1.MPIJob) - assert.True(t, ok) - assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) - assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) - assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) - assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) - assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) + } } func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { @@ -733,50 +795,87 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { }, })) - taskConfig := &kfplugins.DistributedMPITrainingTask{ - LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "250Mi"}, + taskConfigs := []*kfplugins.DistributedMPITrainingTask{ + { + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, }, }, }, - WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + { + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, }, }, }, } - mpiResourceHandler := mpiOperatorResourceHandler{} + for _, taskConfig := range taskConfigs { + mpiResourceHandler := mpiOperatorResourceHandler{} - taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - mpiJob, ok := resource.(*kubeflowv1.MPIJob) - assert.True(t, ok) + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) - assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration) - assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + } } func TestGetReplicaCount(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index aea096e4ba..f40f80e1f7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -727,187 +727,264 @@ func TestReplicaCounts(t *testing.T) { } func TestBuildResourcePytorchV1(t *testing.T) { - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Image: testImageMaster, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "250Mi"}, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Image: testImageMaster, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "500Mi"}, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, }, - RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, }, - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + { + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Image: testImageMaster, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, }, }, }, } - masterResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("250Mi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("500Mi"), - }, - } + for _, taskConfig := range taskConfigs { + masterResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("250Mi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("500Mi"), + }, + } - workerResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - } + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + } - pytorchResourceHandler := pytorchOperatorResourceHandler{} + pytorchResourceHandler := pytorchOperatorResourceHandler{} - taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.NoError(t, err) - assert.NotNil(t, res) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.NoError(t, err) + assert.NotNil(t, res) - pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) - assert.True(t, ok) + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) - assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) - assert.Equal(t, testImageMaster, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) - assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + assert.Equal(t, testImageMaster, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) - assert.Equal(t, *masterResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) - assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, *masterResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) - assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) - assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) - assert.Nil(t, pytorchJob.Spec.RunPolicy.CleanPodPolicy) - assert.Nil(t, pytorchJob.Spec.RunPolicy.BackoffLimit) - assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) - assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) + assert.Nil(t, pytorchJob.Spec.RunPolicy.CleanPodPolicy) + assert.Nil(t, pytorchJob.Spec.RunPolicy.BackoffLimit) + assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) + assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) - assert.Nil(t, pytorchJob.Spec.ElasticPolicy) + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) + } } func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 100, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + BackoffLimit: 100, + ActiveDeadlineSeconds: 1000, + TtlSecondsAfterFinished: 10000, + }, }, - RunPolicy: &kfplugins.RunPolicy{ - CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, - BackoffLimit: 100, - ActiveDeadlineSeconds: 1000, - TtlSecondsAfterFinished: 10000, + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + }, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + BackoffLimit: 100, + ActiveDeadlineSeconds: 1000, + TtlSecondsAfterFinished: 10000, + }, }, } - pytorchResourceHandler := pytorchOperatorResourceHandler{} - taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) - taskTemplate.TaskTypeVersion = 1 + for _, taskConfig := range taskConfigs { + pytorchResourceHandler := pytorchOperatorResourceHandler{} - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.NoError(t, err) - assert.NotNil(t, res) + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) + taskTemplate.TaskTypeVersion = 1 - pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) - assert.True(t, ok) - assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) - assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy) - assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit) - assert.Equal(t, int64(1000), *pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) - assert.Equal(t, int32(10000), *pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit) + assert.Equal(t, int64(1000), *pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) + assert.Equal(t, int32(10000), *pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) + } } func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, }, }, - } - // Master Replica should use resource from task override if not set - taskOverrideResourceRequirements := &corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1000m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("100m"), - corev1.ResourceMemory: resource.MustParse("512Mi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + }, + }, }, } - workerResourceRequirements := &corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - } + for _, taskConfig := range taskConfigs { + // Master Replica should use resource from task override if not set + taskOverrideResourceRequirements := &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } - pytorchResourceHandler := pytorchOperatorResourceHandler{} + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + } - taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) - taskTemplate.TaskTypeVersion = 1 + pytorchResourceHandler := pytorchOperatorResourceHandler{} - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.NoError(t, err) - assert.NotNil(t, res) + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) + taskTemplate.TaskTypeVersion = 1 - pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) - assert.True(t, ok) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.NoError(t, err) + assert.NotNil(t, res) - assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) - assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) - assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) - assert.Equal(t, *taskOverrideResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) - assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) - assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) - assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + assert.Equal(t, *taskOverrideResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) - assert.Nil(t, pytorchJob.Spec.ElasticPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) + } } func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { @@ -924,101 +1001,164 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { }, })) - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "250Mi"}, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, }, }, }, - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + { + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "250Mi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "500Mi"}, + }, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, }, }, }, } - pytorchResourceHandler := pytorchOperatorResourceHandler{} + for _, taskConfig := range taskConfigs { + pytorchResourceHandler := pytorchOperatorResourceHandler{} - taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.NoError(t, err) - assert.NotNil(t, res) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.NoError(t, err) + assert.NotNil(t, res) - pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) - assert.True(t, ok) + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) - assert.NotContains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Tolerations, gpuToleration) - assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + assert.NotContains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + } } func TestBuildResourcePytorchV1WithElastic(t *testing.T) { - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 2, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 2, + }, + ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}, + }, + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 2, + }, + }, + ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}, }, - ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}, } - taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) - taskTemplate.TaskTypeVersion = 1 - pytorchResourceHandler := pytorchOperatorResourceHandler{} - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.NoError(t, err) - assert.NotNil(t, resource) + for _, taskConfig := range taskConfigs { + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) + taskTemplate.TaskTypeVersion = 1 - pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) - assert.True(t, ok) - assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) - assert.NotNil(t, pytorchJob.Spec.ElasticPolicy) - assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas) - assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas) - assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode) - assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend) + pytorchResourceHandler := pytorchOperatorResourceHandler{} + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.NoError(t, err) + assert.NotNil(t, resource) - assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs)) - assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker) + pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.NotNil(t, pytorchJob.Spec.ElasticPolicy) + assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas) + assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas) + assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode) + assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend) + + assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs)) + assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker) - var hasContainerWithDefaultPytorchName = false + var hasContainerWithDefaultPytorchName = false - for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers { - if container.Name == kubeflowv1.PytorchJobDefaultContainerName { - hasContainerWithDefaultPytorchName = true + for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers { + if container.Name == kubeflowv1.PytorchJobDefaultContainerName { + hasContainerWithDefaultPytorchName = true + } } - } - assert.True(t, hasContainerWithDefaultPytorchName) + assert.True(t, hasContainerWithDefaultPytorchName) + } } func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) { - taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ - WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ - Replicas: 0, + taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 0, + }, + }, + { + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 0, + }, + }, }, } - pytorchResourceHandler := pytorchOperatorResourceHandler{} - taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) - taskTemplate.TaskTypeVersion = 1 - _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) - assert.Error(t, err) + + for _, taskConfig := range taskConfigs { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) + taskTemplate.TaskTypeVersion = 1 + _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + assert.Error(t, err) + } } func TestParseElasticConfig(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index db5fe6a83a..d69fd30b01 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -100,9 +100,18 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task for t, cfg := range replicaSpecCfgMap { // Short circuit if replica set has no replicas to avoid unnecessarily // generating pod specs - if cfg.GetReplicas() <= 0 { + var replicas int32 + // replicas is deprecated since the common replica spec is introduced. + // Therefore, if the common replica spec is set, use that to get the common fields + if cfg.GetCommon() != nil { + replicas = cfg.GetCommon().GetReplicas() + } else { + replicas = cfg.GetReplicas() + } + if replicas <= 0 { continue } + rs, err := common.ToReplicaSpecWithOverrides(ctx, taskCtx, cfg, kubeflowv1.TFJobDefaultContainerName, false) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error()) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 80e95871d1..a85ce8f875 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -460,55 +460,81 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { } v0TaskTemplate := dummyTensorFlowTaskTemplate("v0", dummyTensorFlowCustomObj(100, 50, 1, 1)) - v1TaskTemplate := dummyTensorFlowTaskTemplate("v1", &kfplugins.DistributedTensorflowTrainingTask{ - ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 1, - }, - WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 100, - }, - PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 50, - }, - EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 1, - }, - }) - v1TaskTemplate.TaskTypeVersion = 1 - testConfigs := []struct { - name string - taskTemplate *core.TaskTemplate - }{ - {"v0", v0TaskTemplate}, - {"v1", v1TaskTemplate}, - } - - for _, tCfg := range testConfigs { - for _, f := range fixtures { - t.Run(tCfg.name+" "+f.name, func(t *testing.T) { - taskTemplate := *tCfg.taskTemplate - taskTemplate.ExtendedResources = f.extendedResourcesBase - tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride) - r, err := tensorflowResourceHandler.BuildResource(context.TODO(), taskContext) - assert.NoError(t, err) - assert.NotNil(t, r) - tensorflowJob, ok := r.(*kubeflowv1.TFJob) - assert.True(t, ok) - - for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { - assert.EqualValues( - t, - f.expectedNsr, - replicaSpec.Template.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) - assert.EqualValues( - t, - f.expectedTol, - replicaSpec.Template.Spec.Tolerations, - ) - } - }) + v1TaskTemplates := []*core.TaskTemplate{ + dummyTensorFlowTaskTemplate("v1", &kfplugins.DistributedTensorflowTrainingTask{ + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + }, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 50, + }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + }, + }), + dummyTensorFlowTaskTemplate("v1", &kfplugins.DistributedTensorflowTrainingTask{ + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 1, + }, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + }, + }, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 50, + }, + }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 1, + }, + }, + }), + } + for _, v1TaskTemplate := range v1TaskTemplates { + v1TaskTemplate.TaskTypeVersion = 1 + testConfigs := []struct { + name string + taskTemplate *core.TaskTemplate + }{ + {"v0", v0TaskTemplate}, + {"v1", v1TaskTemplate}, + } + + for _, tCfg := range testConfigs { + for _, f := range fixtures { + t.Run(tCfg.name+" "+f.name, func(t *testing.T) { + taskTemplate := *tCfg.taskTemplate + taskTemplate.ExtendedResources = f.extendedResourcesBase + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride) + r, err := tensorflowResourceHandler.BuildResource(context.TODO(), taskContext) + assert.NoError(t, err) + assert.NotNil(t, r) + tensorflowJob, ok := r.(*kubeflowv1.TFJob) + assert.True(t, ok) + + for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + assert.EqualValues( + t, + f.expectedNsr, + replicaSpec.Template.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, + ) + assert.EqualValues( + t, + f.expectedTol, + replicaSpec.Template.Spec.Tolerations, + ) + } + }) + } } } } @@ -638,208 +664,307 @@ func TestReplicaCounts(t *testing.T) { } func TestBuildResourceTensorFlowV1(t *testing.T) { - taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ - ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 1, - Image: testImage, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + taskConfigs := []*kfplugins.DistributedTensorflowTrainingTask{ + { + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, }, }, - RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, - }, - WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 50, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + ActiveDeadlineSeconds: int32(100), }, }, - PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 50, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + { + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, }, }, - }, - EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 1, - Image: testImage, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 50, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, }, }, - RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, - }, - RunPolicy: &kfplugins.RunPolicy{ - CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, - ActiveDeadlineSeconds: int32(100), + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + ActiveDeadlineSeconds: int32(100), + }, }, } + for _, taskConfig := range taskConfigs { - resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ - kubeflowv1.TFJobReplicaTypeChief: { - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - }, - }, - kubeflowv1.TFJobReplicaTypeWorker: { - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - }, - kubeflowv1.TFJobReplicaTypePS: { - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeChief: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, }, - }, - kubeflowv1.TFJobReplicaTypeEval: { - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), + kubeflowv1.TFJobReplicaTypePS: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), + kubeflowv1.TFJobReplicaTypeEval: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, }, - }, - } + } - tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) + taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - tensorflowJob, ok := resource.(*kubeflowv1.TFJob) - assert.True(t, ok) - assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) - assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) - assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) - assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) - for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { - var hasContainerWithDefaultTensorFlowName = false + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false - for _, container := range replicaSpec.Template.Spec.Containers { - if container.Name == kubeflowv1.TFJobDefaultContainerName { - hasContainerWithDefaultTensorFlowName = true - assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } } - } - assert.True(t, hasContainerWithDefaultTensorFlowName) + assert.True(t, hasContainerWithDefaultTensorFlowName) + } + assert.Equal(t, commonOp.CleanPodPolicyAll, *tensorflowJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int64(100), *tensorflowJob.Spec.RunPolicy.ActiveDeadlineSeconds) } - assert.Equal(t, commonOp.CleanPodPolicyAll, *tensorflowJob.Spec.RunPolicy.CleanPodPolicy) - assert.Equal(t, int64(100), *tensorflowJob.Spec.RunPolicy.ActiveDeadlineSeconds) } func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { - taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ - WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + taskConfigs := []*kfplugins.DistributedTensorflowTrainingTask{ + { + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + }, + }, + { + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, }, }, }, } - resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ - kubeflowv1.TFJobReplicaTypeWorker: { - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("1024m"), - corev1.ResourceMemory: resource.MustParse("1Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), - }, - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2048m"), - corev1.ResourceMemory: resource.MustParse("2Gi"), - flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + for _, taskConfig := range taskConfigs { + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, }, - }, - } + } - tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) - taskTemplate.TaskTypeVersion = 1 + taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) + taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - tensorflowJob, ok := resource.(*kubeflowv1.TFJob) - assert.True(t, ok) - assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) - assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief]) - assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS]) + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief]) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS]) - for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { - var hasContainerWithDefaultTensorFlowName = false + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false - for _, container := range replicaSpec.Template.Spec.Containers { - if container.Name == kubeflowv1.TFJobDefaultContainerName { - hasContainerWithDefaultTensorFlowName = true - assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } } - } - assert.True(t, hasContainerWithDefaultTensorFlowName) + assert.True(t, hasContainerWithDefaultTensorFlowName) + } } } @@ -857,49 +982,88 @@ func TestBuildResourceTensorFlowV1ResourceTolerations(t *testing.T) { }, })) - taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ - ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 1, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "250m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, + taskConfigs := []*kfplugins.DistributedTensorflowTrainingTask{ + { + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "500m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, }, }, }, - WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ - Replicas: 100, - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "1024m"}, - {Name: core.Resources_MEMORY, Value: "1Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + { + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 1, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, }, - Limits: []*core.Resources_ResourceEntry{ - {Name: core.Resources_CPU, Value: "2048m"}, - {Name: core.Resources_MEMORY, Value: "2Gi"}, - {Name: core.Resources_GPU, Value: "1"}, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Common: &kfplugins.CommonReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, }, }, }, } - tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + for _, taskConfig := range taskConfigs { - taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) - taskTemplate.TaskTypeVersion = 1 + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) - assert.NoError(t, err) - assert.NotNil(t, resource) + taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) + taskTemplate.TaskTypeVersion = 1 - tensorflowJob, ok := resource.(*kubeflowv1.TFJob) - assert.True(t, ok) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) - assert.NotContains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Template.Spec.Tolerations, gpuToleration) - assert.Contains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + + assert.NotContains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Template.Spec.Tolerations, gpuToleration) + assert.Contains(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) + } }