diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index a6a6ef3647..fc3feae9bc 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict @@ -171,7 +172,15 @@ def _convert_replica_spec( ) -> mpi_task.DistributedMPITrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return mpi_task.DistributedMPITrainingReplicaSpec( + common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ), command=replica_config.command, + # The following fields are deprecated. They are kept for backwards compatibility + # and will be removed in the future. replicas=replica_config.replicas, image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, @@ -203,10 +212,14 @@ def get_command(self, settings: SerializationSettings) -> List[str]: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: + worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers launcher = self._convert_replica_spec(self.task_config.launcher) if self.task_config.num_launcher_replicas: + launcher.common.replicas = self.task_config.num_launcher_replicas + # Deprecated. Only kept for backwards compatibility. launcher.replicas = self.task_config.num_launcher_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index 05efff84b0..6b1cb0e762 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.6.1,<2.0.0"] +plugin_requires = ["flyteidl>1.12.2", "flytekit>=1.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 36758bfb6f..28c686eeb2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -34,8 +34,22 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None assert my_mpi_task.get_custom(serialization_settings) == { - "launcherReplicas": {"replicas": 10, "resources": {}}, - "workerReplicas": {"replicas": 10, "resources": {}}, + "launcherReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + "replicas": 10, + "resources": {}, + }, + "workerReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + "replicas": 10, + "resources": {}, + }, "slots": 1, } assert my_mpi_task.task_type == "mpi" @@ -69,10 +83,18 @@ def my_mpi_task(x: int, y: str) -> int: expected_dict = { "launcherReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "workerReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, @@ -124,6 +146,14 @@ def my_mpi_task(x: 
int, y: str) -> int: expected_custom_dict = { "launcherReplicas": { + "common": { + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, "replicas": 1, "image": "launcher:latest", "resources": { @@ -132,6 +162,20 @@ def my_mpi_task(x: int, y: str) -> int: }, }, "workerReplicas": { + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + }, "replicas": 5, "image": "worker:latest", "resources": { @@ -188,6 +232,17 @@ def my_horovod_task(): ... # CleanPodPolicy.NONE is the default, so it should not be in the output dictionary expected_dict = { "launcherReplicas": { + "common": { + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, + }, "replicas": 1, "resources": { "requests": [ @@ -199,6 +254,10 @@ def my_horovod_task(): ... 
}, }, "workerReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 966425f901..7939cd3195 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -8,6 +8,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict @@ -203,6 +204,13 @@ def _convert_replica_spec( if not isinstance(replica_config, Master): replicas = replica_config.replicas return pytorch_task.DistributedPyTorchTrainingReplicaSpec( + common=plugins_common.CommonReplicaSpec( + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ), + # The following fields are deprecated. They are kept for backwards compatibility. replicas=replicas, image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, @@ -213,6 +221,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) # support v0 config for backwards compatibility if self.task_config.num_workers: + worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers run_policy = ( @@ -514,6 +524,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ) job = pytorch_task.DistributedPyTorchTrainingTask( worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( + common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes), + # The following fields are deprecated. They are kept for backwards compatibility. replicas=self.max_nodes, ), elastic_config=elastic_config, diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index 317ca7b8a0..52c5f807a3 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,8 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1", "kubernetes"] + +plugin_requires = ["cloudpickle", "flyteidl>1.12.2", "flytekit>=1.6.1", "kubernetes"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 0c15886c3a..7714cc583f 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -33,8 +33,22 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.task_config is not None assert my_pytorch_task.get_custom(serialization_settings) == { - "workerReplicas": {"replicas": 10, "resources": {}}, - "masterReplicas": {"replicas": 1, "resources": {}}, + "workerReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + "replicas": 10, + "resources": {}, + }, + "masterReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, + "replicas": 1, + "resources": {}, + }, } assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") @@ -64,10 +78,18 @@ def my_pytorch_task(x: int, y: str) -> int: expected_dict = { 
"masterReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "workerReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, @@ -114,6 +136,21 @@ def my_pytorch_task(x: int, y: str) -> int: expected_custom_dict = { "workerReplicas": { + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, "replicas": 5, "image": "worker:latest", "resources": { @@ -129,6 +166,11 @@ def my_pytorch_task(x: int, y: str) -> int: "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "masterReplicas": { + "common": { + "resources": {}, + "replicas": 1, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, "resources": {}, "replicas": 1, "restartPolicy": "RESTART_POLICY_ALWAYS", diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 62cd482416..bf07f2bab0 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict @@ -174,6 +175,13 @@ def _convert_replica_spec( ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( + 
common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ), + # The following fields are deprecated. They are kept for backwards compatibility. replicas=replica_config.replicas, image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, @@ -191,18 +199,26 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: chief = self._convert_replica_spec(self.task_config.chief) if self.task_config.num_chief_replicas: + chief.common.replicas = self.task_config.num_chief_replicas + # Deprecated. Only kept for backwards compatibility. chief.replicas = self.task_config.num_chief_replicas worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: + worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility. worker.replicas = self.task_config.num_workers ps = self._convert_replica_spec(self.task_config.ps) if self.task_config.num_ps_replicas: + ps.common.replicas = self.task_config.num_ps_replicas + # Deprecated. Only kept for backwards compatibility. ps.replicas = self.task_config.num_ps_replicas evaluator = self._convert_replica_spec(self.task_config.evaluator) if self.task_config.num_evaluator_replicas: + evaluator.common.replicas = self.task_config.num_evaluator_replicas + # Deprecated. Only kept for backwards compatibility. 
evaluator.replicas = self.task_config.num_evaluator_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index 25ffe19eec..c3983cec50 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flyteidl>=1.10.0", "flytekit>=1.6.1"] +plugin_requires = ["flyteidl>1.12.2", "flytekit>=1.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 0ae32439d7..f6fe8c4d6b 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -44,16 +44,29 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, "workerReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "psReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, "evaluatorReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, } @@ -106,6 +119,14 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_custom_dict = { "chiefReplicas": { + "common": { + "replicas": 1, + "image": "chief:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, "replicas": 1, "image": "chief:latest", "resources": { @@ -114,6 +135,21 @@ def my_tensorflow_task(x: int, y: str) -> int: }, }, "workerReplicas": { + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + 
{"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, "replicas": 5, "image": "worker:latest", "resources": { @@ -129,11 +165,31 @@ def my_tensorflow_task(x: int, y: str) -> int: "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "psReplicas": { + "common": { + "resources": {}, + "replicas": 2, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, "resources": {}, "replicas": 2, "restartPolicy": "RESTART_POLICY_ALWAYS", }, "evaluatorReplicas": { + "common": { + "replicas": 5, + "image": "evaluator:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, "replicas": 5, "image": "evaluator:latest", "resources": { @@ -185,16 +241,29 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, "workerReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "psReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, "evaluatorReplicas": { + "common": { + "resources": {}, + }, "resources": {}, }, "runPolicy": { @@ -233,18 +302,34 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "workerReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, "replicas": 10, "resources": {}, }, "psReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, }, "evaluatorReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, "replicas": 1, "resources": {}, },