From e2d3a4ba1589e80b58b4f0682bd9c4c7ecc377f0 Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Thu, 16 May 2024 23:35:05 +0800 Subject: [PATCH] feat: Update files with repect to common ReplicaSpec refactor Resolves: flyteorg/flyte#4408 Signed-off-by: Chi-Sheng Liu --- .../flytekitplugins/kfmpi/task.py | 15 +- .../flytekit-kf-mpi/tests/test_mpi_task.py | 86 +++++++----- .../flytekitplugins/kfpytorch/task.py | 15 +- .../tests/test_pytorch_task.py | 60 +++++--- .../flytekitplugins/kftensorflow/task.py | 19 +-- .../tests/test_tensorflow_task.py | 132 +++++++++++------- 6 files changed, 204 insertions(+), 123 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 7c8416d0077..b911506beaf 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict @@ -171,11 +172,13 @@ def _convert_replica_spec( ) -> mpi_task.DistributedMPITrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return mpi_task.DistributedMPITrainingReplicaSpec( + common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ), command=replica_config.command, - replicas=replica_config.replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -203,11 +206,11 @@ def get_command(self, settings: SerializationSettings) -> List[str]: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: - worker.replicas = self.task_config.num_workers + worker.common.replicas = self.task_config.num_workers launcher = self._convert_replica_spec(self.task_config.launcher) if self.task_config.num_launcher_replicas: - launcher.replicas = self.task_config.num_launcher_replicas + launcher.common.replicas = self.task_config.num_launcher_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None mpi_job = mpi_task.DistributedMPITrainingTask( diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index deec3ff3852..0b8b8353c14 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -34,8 +34,18 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None assert my_mpi_task.get_custom(serialization_settings) == { - "launcherReplicas": {"replicas": 10, "resources": {}}, - "workerReplicas": {"replicas": 10, "resources": {}}, + "launcherReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + }, + "workerReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + }, "slots": 1, } assert my_mpi_task.task_type == "mpi" @@ -69,12 +79,16 @@ def my_mpi_task(x: int, y: str) -> int: expected_dict = { "launcherReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "workerReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "slots": 1, } @@ -124,25 +138,29 @@ def my_mpi_task(x: int, y: str) -> int: expected_custom_dict = { "launcherReplicas": { - "replicas": 1, - "image": "launcher:latest", - "resources": { - "requests": [{"name": "CPU", "value": "1"}], - "limits": [{"name": "CPU", "value": "2"}], + "common": { + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, }, }, "workerReplicas": { - "replicas": 5, - "image": "worker:latest", - "resources": { - "requests": [ - {"name": "CPU", "value": "2"}, - {"name": "MEMORY", "value": "2Gi"}, - ], - "limits": [ - {"name": "CPU", "value": "4"}, - {"name": "MEMORY", "value": "2Gi"}, - ], + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, }, }, "slots": 2, @@ -185,19 +203,23 @@ def my_horovod_task(): ... # CleanPodPolicy.NONE is the default, so it should not be in the output dictionary expected_dict = { "launcherReplicas": { - "replicas": 1, - "resources": { - "requests": [ - {"name": "CPU", "value": "1"}, - ], - "limits": [ - {"name": "CPU", "value": "2"}, - ], + "common": { + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, }, }, "workerReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], }, "slots": 2, diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 94b575e2a9d..4e272705ee6 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -8,6 +8,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict @@ -175,10 +176,12 @@ def _convert_replica_spec( if not isinstance(replica_config, Master): replicas = replica_config.replicas return pytorch_task.DistributedPyTorchTrainingReplicaSpec( - replicas=replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + common=plugins_common.CommonReplicaSpec( + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -193,7 +196,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: worker = self._convert_replica_spec(self.task_config.worker) # support v0 config for backwards compatibility if self.task_config.num_workers: - worker.replicas = self.task_config.num_workers + worker.common.replicas = self.task_config.num_workers run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None pytorch_job = pytorch_task.DistributedPyTorchTrainingTask( @@ -447,7 +450,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ) job = pytorch_task.DistributedPyTorchTrainingTask( worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( - replicas=self.max_nodes, + common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes), ), elastic_config=elastic_config, ) diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 0c15886c3a7..1bf9408e57b 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -33,8 +33,18 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.task_config is not None assert my_pytorch_task.get_custom(serialization_settings) == { - "workerReplicas": {"replicas": 10, "resources": {}}, - "masterReplicas": {"replicas": 1, "resources": {}}, + "workerReplicas": { + "common": { + "replicas": 10, + "resources": {}, + }, + }, + "masterReplicas": { + "common": { + "replicas": 1, + "resources": {}, + }, + }, } assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") @@ -64,12 +74,16 @@ def my_pytorch_task(x: int, y: str) -> int: expected_dict = { "masterReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "workerReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, } assert my_pytorch_task.get_custom(serialization_settings) == expected_dict @@ -114,24 +128,28 @@ def my_pytorch_task(x: int, y: str) -> int: expected_custom_dict = { "workerReplicas": { - "replicas": 5, - "image": "worker:latest", - "resources": { - "requests": [ - {"name": "CPU", "value": "2"}, - {"name": "MEMORY", "value": "2Gi"}, - ], - "limits": [ - {"name": "CPU", "value": "4"}, - {"name": "MEMORY", "value": "2Gi"}, - ], + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, - "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "masterReplicas": { - "resources": {}, - "replicas": 1, - "restartPolicy": "RESTART_POLICY_ALWAYS", + "common": { + "resources": {}, + "replicas": 1, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, }, "runPolicy": { "cleanPodPolicy": "CLEANPOD_POLICY_ALL", diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 62cd4824165..8bab5bb5ff5 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict @@ -174,10 +175,12 @@ def _convert_replica_spec( ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( - replicas=replica_config.replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -191,19 +194,19 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: chief = self._convert_replica_spec(self.task_config.chief) if self.task_config.num_chief_replicas: - chief.replicas = self.task_config.num_chief_replicas + chief.common.replicas = self.task_config.num_chief_replicas worker = self._convert_replica_spec(self.task_config.worker) if self.task_config.num_workers: - worker.replicas = self.task_config.num_workers + worker.common.replicas = self.task_config.num_workers ps = self._convert_replica_spec(self.task_config.ps) if self.task_config.num_ps_replicas: - ps.replicas = self.task_config.num_ps_replicas + ps.common.replicas = self.task_config.num_ps_replicas evaluator = self._convert_replica_spec(self.task_config.evaluator) if self.task_config.num_evaluator_replicas: - evaluator.replicas = self.task_config.num_evaluator_replicas + evaluator.common.replicas = self.task_config.num_evaluator_replicas run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None training_task = tensorflow_task.DistributedTensorflowTrainingTask( diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 0ae32439d7a..fe7fa22466f 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -44,17 +44,25 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, "workerReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "psReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, "evaluatorReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, } assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict @@ -106,47 +114,55 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_custom_dict = { "chiefReplicas": { - "replicas": 1, - "image": "chief:latest", - "resources": { - "requests": [{"name": "CPU", "value": "1"}], - "limits": [{"name": "CPU", "value": "2"}], + "common": { + "replicas": 1, + "image": "chief:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, }, }, "workerReplicas": { - "replicas": 5, - "image": "worker:latest", - "resources": { - "requests": [ - {"name": "CPU", "value": "2"}, - {"name": "MEMORY", "value": "2Gi"}, - ], - "limits": [ - {"name": "CPU", "value": "4"}, - {"name": "MEMORY", "value": "2Gi"}, - ], + "common": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, - "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, "psReplicas": { - "resources": {}, - "replicas": 2, - "restartPolicy": "RESTART_POLICY_ALWAYS", + "common": { + "resources": {}, + "replicas": 2, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, }, "evaluatorReplicas": { - "replicas": 5, - "image": "evaluator:latest", - "resources": { - "requests": [ - {"name": "CPU", "value": "2"}, - {"name": "MEMORY", "value": "2Gi"}, - ], - "limits": [ - {"name": "CPU", "value": "4"}, - {"name": "MEMORY", "value": "2Gi"}, - ], + "common": { + "replicas": 5, + "image": "evaluator:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, - "restartPolicy": "RESTART_POLICY_ON_FAILURE", }, } @@ -185,17 +201,25 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, "workerReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "psReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, "evaluatorReplicas": { - "resources": {}, + "common": { + "resources": {}, + }, }, "runPolicy": { "cleanPodPolicy": "CLEANPOD_POLICY_RUNNING", @@ -233,20 +257,28 @@ def my_tensorflow_task(x: int, y: str) -> int: expected_dict = { "chiefReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "workerReplicas": { - "replicas": 10, - "resources": {}, + "common": { + "replicas": 10, + "resources": {}, + }, }, "psReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, "evaluatorReplicas": { - "replicas": 1, - "resources": {}, + "common": { + "replicas": 1, + "resources": {}, + }, }, }