From a972abbe9a9512a7a757886d9ee86bae79916e3a Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Sun, 16 Jun 2024 11:34:03 +0800 Subject: [PATCH] feat(replicaspec): Keeps old fields for backward-compatibility Resolves: flyteorg/flyte#4408 Signed-off-by: Chi-Sheng Liu --- .../flytekitplugins/kfmpi/task.py | 10 ++++ .../flytekit-kf-mpi/tests/test_mpi_task.py | 37 +++++++++++++ .../flytekitplugins/kfpytorch/task.py | 11 +++- .../tests/test_pytorch_task.py | 24 +++++++++ .../flytekitplugins/kftensorflow/task.py | 15 +++++- .../tests/test_tensorflow_task.py | 53 +++++++++++++++++++ 6 files changed, 148 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index b911506bea..8e577542d7 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -179,6 +179,12 @@ def _convert_replica_spec( restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ), command=replica_config.command, + # The forllowing fields are deprecated. They are kept for backwards compatibility. + # The following fields are deprecated and will be removed in the future + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -207,10 +213,14 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility. + worker.replicas = self.task_config.num_workers launcher = self._convert_replica_spec(self.task_config.launcher) if self.task_config.num_launcher_replicas: launcher.common.replicas = self.task_config.num_launcher_replicas + # Deprecated. Only kept for backwards compatibility. + launcher.replicas = self.task_config.num_launcher_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None mpi_job = mpi_task.DistributedMPITrainingTask( diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 0b8b8353c1..fbf2b4d33d 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -39,12 +39,16 @@ def my_mpi_task(x: int, y: str) -> int: "replicas": 10, "resources": {}, }, + "replicas": 10, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 10, "resources": {}, }, + "replicas": 10, + "resources": {}, }, "slots": 1, } @@ -83,12 +87,16 @@ def my_mpi_task(x: int, y: str) -> int: "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "slots": 1, } @@ -146,6 +154,12 @@ def my_mpi_task(x: int, y: str) -> int: "limits": [{"name": "CPU", "value": "2"}], }, }, + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, }, "workerReplicas": { "common": { @@ -162,6 +176,18 @@ def my_mpi_task(x: int, y: str) -> int: ], }, }, + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, }, "slots": 2, "runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"}, @@ -214,12 +240,23 @@ def my_horovod_task(): ... ], }, }, + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, }, "workerReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], }, "slots": 2, diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 57fa6d8f3a..68107a8f9a 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -183,7 +183,12 @@ def _convert_replica_spec( image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, - ) + ), + # The forllowing fields are deprecated. They are kept for backwards compatibility. + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: @@ -191,6 +196,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: # support v0 config for backwards compatibility if self.task_config.num_workers: worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility. + worker.replicas = self.task_config.num_workers run_policy = ( _convert_run_policy_to_flyte_idl(self.task_config.run_policy) if self.task_config.run_policy else None @@ -459,6 +466,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] job = pytorch_task.DistributedPyTorchTrainingTask( worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes), + # The following fields are deprecated. They are kept for backwards compatibility. + replicas=self.max_nodes, ), elastic_config=elastic_config, run_policy=run_policy, diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 1bf9408e57..7714cc583f 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -38,12 +38,16 @@ def my_pytorch_task(x: int, y: str) -> int: "replicas": 10, "resources": {}, }, + "replicas": 10, + "resources": {}, }, "masterReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, } assert my_pytorch_task.resources.limits == Resources() @@ -78,12 +82,16 @@ def my_pytorch_task(x: int, y: str) -> int: "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, } assert my_pytorch_task.get_custom(serialization_settings) == expected_dict @@ -143,6 +151,19 @@ def my_pytorch_task(x: int, y: str) -> int: }, "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "masterReplicas": { "common": { @@ -150,6 +171,9 @@ def my_pytorch_task(x: int, y: str) -> int: "replicas": 1, "restartPolicy": "RESTART_POLICY_ALWAYS", }, + "resources": {}, + "replicas": 1, + "restartPolicy": "RESTART_POLICY_ALWAYS", }, "runPolicy": { "cleanPodPolicy": "CLEANPOD_POLICY_ALL", diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 8bab5bb5ff..bf07f2bab0 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -180,7 +180,12 @@ def _convert_replica_spec( image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, - ) + ), + # The following fields are deprecated. They are kept for backwards compatibility. + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -195,18 +200,26 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: chief = self._convert_replica_spec(self.task_config.chief) if self.task_config.num_chief_replicas: chief.common.replicas = self.task_config.num_chief_replicas + # Deprecated. Only kept for backwards compatibility. + chief.replicas = self.task_config.num_chief_replicas worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: worker.common.replicas = self.task_config.num_workers + # Deprecated. Only kept for backwards compatibility. + worker.replicas = self.task_config.num_workers ps = self._convert_replica_spec(self.task_config.ps) if self.task_config.num_ps_replicas: ps.common.replicas = self.task_config.num_ps_replicas + # Deprecated. Only kept for backwards compatibility. + ps.replicas = self.task_config.num_ps_replicas evaluator = self._convert_replica_spec(self.task_config.evaluator) if self.task_config.num_evaluator_replicas: evaluator.common.replicas = self.task_config.num_evaluator_replicas + # Deprecated. Only kept for backwards compatibility. + evaluator.replicas = self.task_config.num_evaluator_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None training_task = tensorflow_task.DistributedTensorflowTrainingTask( diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index fe7fa22466..f6fe8c4d6b 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -47,22 +47,27 @@ def my_tensorflow_task(x: int, y: str) -> int: "common": { "resources": {}, }, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "psReplicas": { "common": { "resources": {}, }, + "resources": {}, }, "evaluatorReplicas": { "common": { "resources": {}, }, + "resources": {}, }, } assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict @@ -122,6 +127,12 @@ def my_tensorflow_task(x: int, y: str) -> int: "limits": [{"name": "CPU", "value": "2"}], }, }, + "replicas": 1, + "image": "chief:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, }, "workerReplicas": { "common": { @@ -139,6 +150,19 @@ def my_tensorflow_task(x: int, y: str) -> int: }, "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "psReplicas": { "common": { @@ -146,6 +170,9 @@ def my_tensorflow_task(x: int, y: str) -> int: "replicas": 2, "restartPolicy": "RESTART_POLICY_ALWAYS", }, + "resources": {}, + "replicas": 2, + "restartPolicy": "RESTART_POLICY_ALWAYS", }, "evaluatorReplicas": { "common": { @@ -163,6 +190,19 @@ def my_tensorflow_task(x: int, y: str) -> int: }, "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, + "replicas": 5, + "image": "evaluator:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, } @@ -204,22 +244,27 @@ def my_tensorflow_task(x: int, y: str) -> int: "common": { "resources": {}, }, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "psReplicas": { "common": { "resources": {}, }, + "resources": {}, }, "evaluatorReplicas": { "common": { "resources": {}, }, + "resources": {}, }, "runPolicy": { "cleanPodPolicy": "CLEANPOD_POLICY_RUNNING", @@ -261,24 +306,32 @@ def my_tensorflow_task(x: int, y: str) -> int: "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "workerReplicas": { "common": { "replicas": 10, "resources": {}, }, + "replicas": 10, + "resources": {}, }, "psReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, "evaluatorReplicas": { "common": { "replicas": 1, "resources": {}, }, + "replicas": 1, + "resources": {}, }, }