Skip to content

Commit

Permalink
feat: Update files with respect to common ReplicaSpec refactor
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#4408
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed May 25, 2024
1 parent 70332db commit 7b5ea22
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
11 changes: 7 additions & 4 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -170,11 +171,13 @@ def _convert_replica_spec(
) -> mpi_task.DistributedMPITrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return mpi_task.DistributedMPITrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
),
command=replica_config.command,
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down
15 changes: 9 additions & 6 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -174,10 +175,12 @@ def _convert_replica_spec(
if not isinstance(replica_config, Master):
replicas = replica_config.replicas
return pytorch_task.DistributedPyTorchTrainingReplicaSpec(
replicas=replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
common=plugins_common.CommonReplicaSpec(
replicas=replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down Expand Up @@ -279,7 +282,7 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
task_type=task_type,
task_function=task_function,
# task_type_version controls the version of the task template, do not change
task_type_version=1,
task_type_version=2,
**kwargs,
)
try:
Expand Down Expand Up @@ -446,7 +449,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
)
job = pytorch_task.DistributedPyTorchTrainingTask(
worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec(
replicas=self.max_nodes,
common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes),
),
elastic_config=elastic_config,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -173,10 +174,12 @@ def _convert_replica_spec(
) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
common=plugins_common.CommonReplicaSpec(
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down

0 comments on commit 7b5ea22

Please sign in to comment.