diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 665e195b4b3..bbab74e16af 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict @@ -170,11 +171,13 @@ def _convert_replica_spec( ) -> mpi_task.DistributedMPITrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return mpi_task.DistributedMPITrainingReplicaSpec( + common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ), command=replica_config.command, - replicas=replica_config.replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 46eb086ad0d..0758d13a686 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict @@ -174,10 +175,12 @@ def _convert_replica_spec( if not isinstance(replica_config, Master): replicas = replica_config.replicas return pytorch_task.DistributedPyTorchTrainingReplicaSpec( - replicas=replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + common=plugins_common.CommonReplicaSpec( + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -279,7 +282,7 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): task_type=task_type, task_function=task_function, # task_type_version controls the version of the task template, do not change - task_type_version=1, + task_type_version=2, **kwargs, ) try: @@ -446,7 +449,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ) job = pytorch_task.DistributedPyTorchTrainingTask( worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( - replicas=self.max_nodes, + common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes), ), elastic_config=elastic_config, ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 7be1f7d0301..6f8a1eb8e0d 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins import common_pb2 as plugins_common from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict @@ -173,10 +174,12 @@ def _convert_replica_spec( ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( - replicas=replica_config.replicas, - image=replica_config.image, - resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + common=plugins_common.CommonReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) ) def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: