Commit
feat: Update files with respect to common ReplicaSpec refactor
Resolves: flyteorg/flyte#4408
Signed-off-by: Chi-Sheng Liu <[email protected]>
MortalHappiness committed Jun 12, 2024
1 parent daeff3f commit e2d3a4b
Showing 6 changed files with 204 additions and 123 deletions.
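At a glance, the refactor moves the per-replica fields (replicas, image, resources, restart_policy) off each plugin-specific replica spec and nests them under a shared CommonReplicaSpec message exposed as a "common" field. A minimal before/after sketch, assuming the flyteidl messages referenced in the diffs below (illustrative only, not an excerpt of this commit):

# Sketch of the ReplicaSpec refactor; assumes flyteidl's common_pb2 and kubeflow mpi_pb2 generated modules.
from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task

# Before: the shared fields sat directly on the plugin-specific spec, e.g.
#     spec = mpi_task.DistributedMPITrainingReplicaSpec(replicas=10, image="worker:latest")
#     spec.replicas  # -> 10

# After: the shared fields live on a nested CommonReplicaSpec.
spec = mpi_task.DistributedMPITrainingReplicaSpec(
    common=plugins_common.CommonReplicaSpec(replicas=10, image="worker:latest"),
)
assert spec.common.replicas == 10  # reads and writes now go through .common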
15 changes: 9 additions & 6 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
@@ -7,6 +7,7 @@
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Union
 
+from flyteidl.plugins import common_pb2 as plugins_common
 from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
 from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
 from google.protobuf.json_format import MessageToDict
@@ -171,11 +172,13 @@ def _convert_replica_spec(
     ) -> mpi_task.DistributedMPITrainingReplicaSpec:
         resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
         return mpi_task.DistributedMPITrainingReplicaSpec(
+            common=plugins_common.CommonReplicaSpec(
+                replicas=replica_config.replicas,
+                image=replica_config.image,
+                resources=resources.to_flyte_idl() if resources else None,
+                restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+            ),
             command=replica_config.command,
-            replicas=replica_config.replicas,
-            image=replica_config.image,
-            resources=resources.to_flyte_idl() if resources else None,
-            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
         )
 
     def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
@@ -203,11 +206,11 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
     def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         worker = self._convert_replica_spec(self.task_config.worker)
         if self.task_config.num_workers:
-            worker.replicas = self.task_config.num_workers
+            worker.common.replicas = self.task_config.num_workers
 
         launcher = self._convert_replica_spec(self.task_config.launcher)
         if self.task_config.num_launcher_replicas:
-            launcher.replicas = self.task_config.num_launcher_replicas
+            launcher.common.replicas = self.task_config.num_launcher_replicas
 
         run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
         mpi_job = mpi_task.DistributedMPITrainingTask(
86 changes: 54 additions & 32 deletions plugins/flytekit-kf-mpi/tests/test_mpi_task.py
@@ -34,8 +34,18 @@ def my_mpi_task(x: int, y: str) -> int:
     assert my_mpi_task.task_config is not None
 
     assert my_mpi_task.get_custom(serialization_settings) == {
-        "launcherReplicas": {"replicas": 10, "resources": {}},
-        "workerReplicas": {"replicas": 10, "resources": {}},
+        "launcherReplicas": {
+            "common": {
+                "replicas": 10,
+                "resources": {},
+            },
+        },
+        "workerReplicas": {
+            "common": {
+                "replicas": 10,
+                "resources": {},
+            },
+        },
         "slots": 1,
     }
     assert my_mpi_task.task_type == "mpi"
@@ -69,12 +79,16 @@ def my_mpi_task(x: int, y: str) -> int:
 
     expected_dict = {
         "launcherReplicas": {
-            "replicas": 1,
-            "resources": {},
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
         },
         "workerReplicas": {
-            "replicas": 1,
-            "resources": {},
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
         },
         "slots": 1,
     }
@@ -124,25 +138,29 @@ def my_mpi_task(x: int, y: str) -> int:
 
     expected_custom_dict = {
         "launcherReplicas": {
-            "replicas": 1,
-            "image": "launcher:latest",
-            "resources": {
-                "requests": [{"name": "CPU", "value": "1"}],
-                "limits": [{"name": "CPU", "value": "2"}],
+            "common": {
+                "replicas": 1,
+                "image": "launcher:latest",
+                "resources": {
+                    "requests": [{"name": "CPU", "value": "1"}],
+                    "limits": [{"name": "CPU", "value": "2"}],
+                },
             },
         },
         "workerReplicas": {
-            "replicas": 5,
-            "image": "worker:latest",
-            "resources": {
-                "requests": [
-                    {"name": "CPU", "value": "2"},
-                    {"name": "MEMORY", "value": "2Gi"},
-                ],
-                "limits": [
-                    {"name": "CPU", "value": "4"},
-                    {"name": "MEMORY", "value": "2Gi"},
-                ],
+            "common": {
+                "replicas": 5,
+                "image": "worker:latest",
+                "resources": {
+                    "requests": [
+                        {"name": "CPU", "value": "2"},
+                        {"name": "MEMORY", "value": "2Gi"},
+                    ],
+                    "limits": [
+                        {"name": "CPU", "value": "4"},
+                        {"name": "MEMORY", "value": "2Gi"},
+                    ],
+                },
             },
         },
         "slots": 2,
@@ -185,19 +203,23 @@ def my_horovod_task(): ...
     # CleanPodPolicy.NONE is the default, so it should not be in the output dictionary
     expected_dict = {
         "launcherReplicas": {
-            "replicas": 1,
-            "resources": {
-                "requests": [
-                    {"name": "CPU", "value": "1"},
-                ],
-                "limits": [
-                    {"name": "CPU", "value": "2"},
-                ],
+            "common": {
+                "replicas": 1,
+                "resources": {
+                    "requests": [
+                        {"name": "CPU", "value": "1"},
+                    ],
+                    "limits": [
+                        {"name": "CPU", "value": "2"},
+                    ],
+                },
             },
         },
         "workerReplicas": {
-            "replicas": 1,
-            "resources": {},
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
             "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"],
         },
         "slots": 2,
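The nested dictionaries these tests now expect follow directly from how MessageToDict serializes the new layout: the CommonReplicaSpec submessage becomes its own JSON object under a "common" key. A small illustrative sketch (not part of this commit), assuming the same flyteidl messages:

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict

spec = mpi_task.DistributedMPITrainingReplicaSpec(
    common=plugins_common.CommonReplicaSpec(replicas=10),
)
# Submessages serialize as nested objects, hence the extra "common" level in the expected dicts.
print(MessageToDict(spec))  # {'common': {'replicas': 10}}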
15 changes: 9 additions & 6 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -8,6 +8,7 @@
 from enum import Enum
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
 
+from flyteidl.plugins import common_pb2 as plugins_common
 from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
 from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
 from google.protobuf.json_format import MessageToDict
@@ -175,10 +176,12 @@ def _convert_replica_spec(
         if not isinstance(replica_config, Master):
             replicas = replica_config.replicas
         return pytorch_task.DistributedPyTorchTrainingReplicaSpec(
-            replicas=replicas,
-            image=replica_config.image,
-            resources=resources.to_flyte_idl() if resources else None,
-            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+            common=plugins_common.CommonReplicaSpec(
+                replicas=replicas,
+                image=replica_config.image,
+                resources=resources.to_flyte_idl() if resources else None,
+                restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+            )
         )
 
     def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
@@ -193,7 +196,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         worker = self._convert_replica_spec(self.task_config.worker)
         # support v0 config for backwards compatibility
         if self.task_config.num_workers:
-            worker.replicas = self.task_config.num_workers
+            worker.common.replicas = self.task_config.num_workers
 
         run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
         pytorch_job = pytorch_task.DistributedPyTorchTrainingTask(
@@ -447,7 +450,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
         )
         job = pytorch_task.DistributedPyTorchTrainingTask(
             worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec(
-                replicas=self.max_nodes,
+                common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes),
             ),
             elastic_config=elastic_config,
         )
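A note on the worker.common.replicas assignments in the hunks above: with protobuf-generated Python classes, writing a scalar field through a nested message path implicitly creates the submessage, so the v0 backwards-compatibility overrides keep working without constructing a CommonReplicaSpec by hand. Illustrative sketch, assuming the flyteidl-generated classes:

from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task

spec = pytorch_task.DistributedPyTorchTrainingReplicaSpec()
spec.common.replicas = 4  # the CommonReplicaSpec submessage is created on first write
assert spec.common.replicas == 4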
60 changes: 39 additions & 21 deletions plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py
@@ -33,8 +33,18 @@ def my_pytorch_task(x: int, y: str) -> int:
     assert my_pytorch_task.task_config is not None
 
     assert my_pytorch_task.get_custom(serialization_settings) == {
-        "workerReplicas": {"replicas": 10, "resources": {}},
-        "masterReplicas": {"replicas": 1, "resources": {}},
+        "workerReplicas": {
+            "common": {
+                "replicas": 10,
+                "resources": {},
+            },
+        },
+        "masterReplicas": {
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
+        },
     }
     assert my_pytorch_task.resources.limits == Resources()
     assert my_pytorch_task.resources.requests == Resources(cpu="1")
@@ -64,12 +74,16 @@ def my_pytorch_task(x: int, y: str) -> int:
 
     expected_dict = {
         "masterReplicas": {
-            "replicas": 1,
-            "resources": {},
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
         },
         "workerReplicas": {
-            "replicas": 1,
-            "resources": {},
+            "common": {
+                "replicas": 1,
+                "resources": {},
+            },
         },
     }
     assert my_pytorch_task.get_custom(serialization_settings) == expected_dict
@@ -114,24 +128,28 @@ def my_pytorch_task(x: int, y: str) -> int:
 
     expected_custom_dict = {
         "workerReplicas": {
-            "replicas": 5,
-            "image": "worker:latest",
-            "resources": {
-                "requests": [
-                    {"name": "CPU", "value": "2"},
-                    {"name": "MEMORY", "value": "2Gi"},
-                ],
-                "limits": [
-                    {"name": "CPU", "value": "4"},
-                    {"name": "MEMORY", "value": "2Gi"},
-                ],
+            "common": {
+                "replicas": 5,
+                "image": "worker:latest",
+                "resources": {
+                    "requests": [
+                        {"name": "CPU", "value": "2"},
+                        {"name": "MEMORY", "value": "2Gi"},
+                    ],
+                    "limits": [
+                        {"name": "CPU", "value": "4"},
+                        {"name": "MEMORY", "value": "2Gi"},
+                    ],
+                },
+                "restartPolicy": "RESTART_POLICY_ON_FAILURE",
             },
-            "restartPolicy": "RESTART_POLICY_ON_FAILURE",
         },
         "masterReplicas": {
-            "resources": {},
-            "replicas": 1,
-            "restartPolicy": "RESTART_POLICY_ALWAYS",
+            "common": {
+                "resources": {},
+                "replicas": 1,
+                "restartPolicy": "RESTART_POLICY_ALWAYS",
+            },
         },
         "runPolicy": {
             "cleanPodPolicy": "CLEANPOD_POLICY_ALL",
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py
@@ -7,6 +7,7 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Optional, Union
 
+from flyteidl.plugins import common_pb2 as plugins_common
 from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
 from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
 from google.protobuf.json_format import MessageToDict
@@ -174,10 +175,12 @@ def _convert_replica_spec(
     ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
         resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
         return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
-            replicas=replica_config.replicas,
-            image=replica_config.image,
-            resources=resources.to_flyte_idl() if resources else None,
-            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+            common=plugins_common.CommonReplicaSpec(
+                replicas=replica_config.replicas,
+                image=replica_config.image,
+                resources=resources.to_flyte_idl() if resources else None,
+                restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
+            )
         )
 
     def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
@@ -191,19 +194,19 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
     def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         chief = self._convert_replica_spec(self.task_config.chief)
         if self.task_config.num_chief_replicas:
-            chief.replicas = self.task_config.num_chief_replicas
+            chief.common.replicas = self.task_config.num_chief_replicas
 
         worker = self._convert_replica_spec(self.task_config.worker)
         if self.task_config.num_workers:
-            worker.replicas = self.task_config.num_workers
+            worker.common.replicas = self.task_config.num_workers
 
         ps = self._convert_replica_spec(self.task_config.ps)
         if self.task_config.num_ps_replicas:
-            ps.replicas = self.task_config.num_ps_replicas
+            ps.common.replicas = self.task_config.num_ps_replicas
 
         evaluator = self._convert_replica_spec(self.task_config.evaluator)
         if self.task_config.num_evaluator_replicas:
-            evaluator.replicas = self.task_config.num_evaluator_replicas
+            evaluator.common.replicas = self.task_config.num_evaluator_replicas
 
         run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
         training_task = tensorflow_task.DistributedTensorflowTrainingTask(