diff --git a/flyteidl/protos/flyteidl/plugins/common.proto b/flyteidl/protos/flyteidl/plugins/common.proto new file mode 100644 index 00000000000..15f31cf2d22 --- /dev/null +++ b/flyteidl/protos/flyteidl/plugins/common.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package flyteidl.plugins; + +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; + +import "flyteidl/core/tasks.proto"; + +enum RestartPolicy { + RESTART_POLICY_NEVER = 0; + RESTART_POLICY_ON_FAILURE = 1; + RESTART_POLICY_ALWAYS = 2; +} + +message CommonReplicaSpec { + // Number of replicas + int32 replicas = 1; + + // Image used for the replica group + string image = 2; + + // Resources required for the replica group + core.Resources resources = 3; + + // RestartPolicy determines whether pods will be restarted when they exit + RestartPolicy restart_policy = 4; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto index 6795dca11b8..37655caf3d8 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto @@ -2,14 +2,9 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; - -enum RestartPolicy { - RESTART_POLICY_NEVER = 0; - RESTART_POLICY_ON_FAILURE = 1; - RESTART_POLICY_ALWAYS = 2; -} +import public "flyteidl/plugins/common.proto"; enum CleanPodPolicy { CLEANPOD_POLICY_NONE = 0; @@ -30,4 +25,4 @@ message RunPolicy { // Number of retries before marking this job failed. int32 backoff_limit = 4; -} \ No newline at end of file +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto index 6eda161f924..9de48793617 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/mpi.proto @@ -2,9 +2,8 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; -import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; // Proto for plugin that enables distributed training using https://github.com/kubeflow/mpi-operator @@ -26,18 +25,11 @@ message DistributedMPITrainingTask { // Replica specification for distributed MPI training message DistributedMPITrainingReplicaSpec { - // Number of replicas - int32 replicas = 1; - - // Image used for the replica group - string image = 2; - - // Resources required for the replica group - core.Resources resources = 3; - - // Restart policy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + reserved 1, 2, 3, 4; // MPI sometimes requires different command set for different replica groups repeated string command = 5; -} \ No newline at end of file + + // The common replica spec + CommonReplicaSpec common = 6; +} diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto index bd3ddbdf978..b71339c8739 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/pytorch.proto @@ -2,9 +2,8 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; -import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; // Custom proto for torch elastic config for distributed training using @@ -35,15 +34,8 @@ message DistributedPyTorchTrainingTask { } message DistributedPyTorchTrainingReplicaSpec { - // Number of replicas - int32 replicas = 1; + reserved 1, 2, 3, 4; - // Image used for the replica group - string image = 2; - - // Resources required for the replica group - core.Resources resources = 3; - - // RestartPolicy determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + // The common replica spec + CommonReplicaSpec common = 5; } diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto index 789666b989e..0e9e832f255 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto @@ -2,9 +2,8 @@ syntax = "proto3"; package flyteidl.plugins.kubeflow; -option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"; -import "flyteidl/core/tasks.proto"; import "flyteidl/plugins/kubeflow/common.proto"; // Proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator @@ -28,15 +27,8 @@ message DistributedTensorflowTrainingTask { } message DistributedTensorflowTrainingReplicaSpec { - // Number of replicas - int32 replicas = 1; + reserved 1, 2, 3, 4; - // Image used for the replica group - string image = 2; - - // Resources required for the replica group - core.Resources resources = 3; - - // RestartPolicy Determines whether pods will be restarted when they exit - RestartPolicy restart_policy = 4; + // The common replica spec + CommonReplicaSpec common = 5; }