diff --git a/flyteidl/protos/flyteidl/plugins/common.proto b/flyteidl/protos/flyteidl/plugins/common.proto new file mode 100644 index 00000000000..15f31cf2d22 --- /dev/null +++ b/flyteidl/protos/flyteidl/plugins/common.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package flyteidl.plugins; + +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; + +import "flyteidl/core/tasks.proto"; + +enum RestartPolicy { + RESTART_POLICY_NEVER = 0; + RESTART_POLICY_ON_FAILURE = 1; + RESTART_POLICY_ALWAYS = 2; +} + +message CommonReplicaSpec { + // Number of replicas + int32 replicas = 1; + + // Image used for the replica group + string image = 2; + + // Resources required for the replica group + core.Resources resources = 3; + + // RestartPolicy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto index 6795dca11b8..37655caf3d8 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto @@ -2,14 +2,9 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; - -enum RestartPolicy { - RESTART_POLICY_NEVER = 0; - RESTART_POLICY_ON_FAILURE = 1; - RESTART_POLICY_ALWAYS = 2; -} +import public "flyteidl/plugins/common.proto"; enum CleanPodPolicy { CLEANPOD_POLICY_NONE = 0; @@ -30,4 +25,4 @@ message RunPolicy { // Number of retries before marking this job failed. int32 backoff_limit = 4; -} \ No newline at end of file +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto index 6eda161f924..b98e8aad992 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -26,18 +26,22 @@ message DistributedMPITrainingTask { // Replica specification for distributed MPI training message DistributedMPITrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; - + core.Resources resources = 3 [deprecated = true]; + // Restart policy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + RestartPolicy restart_policy = 4 [deprecated = true]; // MPI sometimes requires different command set for different replica groups repeated string command = 5; -} \ No newline at end of file + + // The common replica spec + CommonReplicaSpec common = 6; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto index bd3ddbdf978..0433384e751 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -35,15 +35,19 @@ message DistributedPyTorchTrainingTask { } message DistributedPyTorchTrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; - - // RestartPolicy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + core.Resources resources = 3 [deprecated = true]; + + // Restart policy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4 [deprecated = true]; + + // The common replica spec + CommonReplicaSpec common = 5; } diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto index 789666b989e..251526f7e08 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; @@ -28,15 +28,19 @@ message DistributedTensorflowTrainingTask { } message DistributedTensorflowTrainingReplicaSpec { + // 1~4 deprecated. Use common instead. // Number of replicas - int32 replicas = 1; + int32 replicas = 1 [deprecated = true]; // Image used for the replica group - string image = 2; + string image = 2 [deprecated = true]; // Resources required for the replica group - core.Resources resources = 3; + core.Resources resources = 3 [deprecated = true]; - // RestartPolicy Determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + // Restart policy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4 [deprecated = true]; + + // The common replica spec + CommonReplicaSpec common = 5; }