feat: Update files with respect to common ReplicaSpec refactor #2424

Open · wants to merge 4 commits into master
13 changes: 13 additions & 0 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
@@ -7,6 +7,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict
@@ -171,7 +172,15 @@ def _convert_replica_spec(
) -> mpi_task.DistributedMPITrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return mpi_task.DistributedMPITrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
),
command=replica_config.command,
# The following fields are deprecated. They are kept for backwards compatibility.
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
@@ -203,10 +212,14 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

launcher = self._convert_replica_spec(self.task_config.launcher)
if self.task_config.num_launcher_replicas:
launcher.common.replicas = self.task_config.num_launcher_replicas
# Deprecated. Only kept for backwards compatibility.
launcher.replicas = self.task_config.num_launcher_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
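The refactor above writes every shared field twice. A minimal sketch of the resulting dual-write, assuming a flyteidl release (>1.12.2) that ships `flyteidl.plugins.common_pb2.CommonReplicaSpec`; the replica count is illustrative, not taken from the PR:

```python
# Sketch of the dual-write pattern in _convert_replica_spec, assuming
# flyteidl > 1.12.2 exposes flyteidl.plugins.common_pb2.CommonReplicaSpec.
from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict

spec = mpi_task.DistributedMPITrainingReplicaSpec(
    # New shared location for the replica fields.
    common=plugins_common.CommonReplicaSpec(replicas=10),
    # Deprecated top-level duplicate, kept so an older backend that never
    # reads `common` still sees the replica count.
    replicas=10,
)

# get_custom() ultimately feeds this message through MessageToDict, which is
# why the tests below expect the count both at the top level and under "common".
print(MessageToDict(spec))  # roughly {'common': {'replicas': 10}, 'replicas': 10}
```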
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-mpi/setup.py
@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.6.1,<2.0.0"]
plugin_requires = ["flyteidl>1.12.2", "flytekit>=1.6.1"]

__version__ = "0.0.0+develop"

63 changes: 61 additions & 2 deletions plugins/flytekit-kf-mpi/tests/test_mpi_task.py
@@ -34,8 +34,22 @@ def my_mpi_task(x: int, y: str) -> int:
assert my_mpi_task.task_config is not None

assert my_mpi_task.get_custom(serialization_settings) == {
"launcherReplicas": {"replicas": 10, "resources": {}},
"workerReplicas": {"replicas": 10, "resources": {}},
"launcherReplicas": {
"common": {
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"slots": 1,
}
assert my_mpi_task.task_type == "mpi"
@@ -69,10 +83,18 @@ def my_mpi_task(x: int, y: str) -> int:

expected_dict = {
"launcherReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
@@ -124,6 +146,14 @@ def my_mpi_task(x: int, y: str) -> int:

expected_custom_dict = {
"launcherReplicas": {
"common": {
"replicas": 1,
"image": "launcher:latest",
"resources": {
"requests": [{"name": "CPU", "value": "1"}],
"limits": [{"name": "CPU", "value": "2"}],
},
},
"replicas": 1,
"image": "launcher:latest",
"resources": {
@@ -132,6 +162,20 @@ def my_mpi_task(x: int, y: str) -> int:
},
},
"workerReplicas": {
"common": {
"replicas": 5,
"image": "worker:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
},
"replicas": 5,
"image": "worker:latest",
"resources": {
@@ -188,6 +232,17 @@ def my_horovod_task(): ...
# CleanPodPolicy.NONE is the default, so it should not be in the output dictionary
expected_dict = {
"launcherReplicas": {
"common": {
"replicas": 1,
"resources": {
"requests": [
{"name": "CPU", "value": "1"},
],
"limits": [
{"name": "CPU", "value": "2"},
],
},
},
"replicas": 1,
"resources": {
"requests": [
@@ -199,6 +254,10 @@ def my_horovod_task(): ...
},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
"command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"],
12 changes: 12 additions & 0 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -8,6 +8,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
from google.protobuf.json_format import MessageToDict
@@ -203,6 +204,13 @@ def _convert_replica_spec(
if not isinstance(replica_config, Master):
replicas = replica_config.replicas
return pytorch_task.DistributedPyTorchTrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(
replicas=replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
),
# The following fields are deprecated. They are kept for backwards compatibility.
replicas=replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
@@ -213,6 +221,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
# support v0 config for backwards compatibility
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

run_policy = (
@@ -514,6 +524,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
)
job = pytorch_task.DistributedPyTorchTrainingTask(
worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes),
# The following field is deprecated. It is kept for backwards compatibility.
replicas=self.max_nodes,
Comment on lines +527 to 529

Member: Will the replica arg stay duplicated here?

Member Author: What do you mean? To ensure backward compatibility, both replicas and common.replicas need to be sent to the backend.

fg91 (Member), Jul 19, 2024: In the respective backend plugin we already distinguish between taskTemplate.TaskTypeVersion == 0/1, see here. We'll have to add backwards compatibility for the refactoring in this PR there as well, right?

I wonder whether removing the duplication in the proto definitions is worth having to check for backwards compatibility in flytekit and flyteplugins. While it might have been better to share the replica spec from the beginning, maybe it's now better to leave it as is?

Happy to be convinced otherwise!! 🙏

@pingsutw fyi, let's maybe discuss in the contrib sync?

Member Author: @fg91 @pingsutw Have you finished the discussion in the contrib sync? What do you think in the end?

Member: @fg91 We have duplicated replicas here to maintain backward compatibility. If people only upgrade flytekit and don't upgrade the Flyte backend, it should still work, right?

Member: Did you test it? If it's duplicated, it should, but let's definitely test it.

My personal opinion is that maybe we should have used a shared replica spec in the first place, but duplicating the replicas now feels more cluttered than leaving separate replicas.
),
elastic_config=elastic_config,
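To make the compatibility contract debated in the thread above concrete, here is a hypothetical reader-side fallback — no such helper exists in this PR. A backend aware of the refactor prefers `common.replicas`, while an older backend keeps reading the deprecated top-level field, which is why both must be populated:

```python
# Hypothetical consumer-side sketch only; nothing in this PR adds such a
# helper. It illustrates why both fields are written: new readers prefer
# common.replicas, old readers only know the deprecated top-level field.
from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task


def resolve_replicas(spec: pytorch_task.DistributedPyTorchTrainingReplicaSpec) -> int:
    # `common` is a submessage, so HasField tells us whether a new-style
    # writer populated it at all.
    if spec.HasField("common") and spec.common.replicas:
        return spec.common.replicas
    # Fall back to the deprecated field written by old flytekit releases.
    return spec.replicas
```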
3 changes: 2 additions & 1 deletion plugins/flytekit-kf-pytorch/setup.py
@@ -4,7 +4,8 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1", "kubernetes"]

plugin_requires = ["cloudpickle", "flyteidl>1.12.2", "flytekit>=1.6.1", "kubernetes"]

__version__ = "0.0.0+develop"

46 changes: 44 additions & 2 deletions plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py
@@ -33,8 +33,22 @@ def my_pytorch_task(x: int, y: str) -> int:
assert my_pytorch_task.task_config is not None

assert my_pytorch_task.get_custom(serialization_settings) == {
"workerReplicas": {"replicas": 10, "resources": {}},
"masterReplicas": {"replicas": 1, "resources": {}},
"workerReplicas": {
"common": {
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"masterReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
}
assert my_pytorch_task.resources.limits == Resources()
assert my_pytorch_task.resources.requests == Resources(cpu="1")
@@ -64,10 +78,18 @@ def my_pytorch_task(x: int, y: str) -> int:

expected_dict = {
"masterReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
@@ -114,6 +136,21 @@ def my_pytorch_task(x: int, y: str) -> int:

expected_custom_dict = {
"workerReplicas": {
"common": {
"replicas": 5,
"image": "worker:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
"replicas": 5,
"image": "worker:latest",
"resources": {
@@ -129,6 +166,11 @@ def my_pytorch_task(x: int, y: str) -> int:
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
"masterReplicas": {
"common": {
"resources": {},
"replicas": 1,
"restartPolicy": "RESTART_POLICY_ALWAYS",
},
"resources": {},
"replicas": 1,
"restartPolicy": "RESTART_POLICY_ALWAYS",
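The expected dicts above render `restartPolicy` as a name string; that falls out of protobuf's JSON mapping rather than anything plugin-specific. A small sketch, assuming the refactored `flyteidl.plugins.common_pb2` exposes the same top-level `RESTART_POLICY_*` constants as the kubeflow common proto:

```python
# Sketch of why the tests expect "restartPolicy": "RESTART_POLICY_ON_FAILURE":
# the plugin stores the enum's integer value, and MessageToDict renders enum
# fields by name. Assumes flyteidl.plugins.common_pb2 exposes the same
# top-level RESTART_POLICY_* constants as the kubeflow common proto.
from flyteidl.plugins import common_pb2 as plugins_common
from google.protobuf.json_format import MessageToDict

spec = plugins_common.CommonReplicaSpec(
    replicas=5,
    restart_policy=plugins_common.RESTART_POLICY_ON_FAILURE,
)
print(MessageToDict(spec))
# roughly {'replicas': 5, 'restartPolicy': 'RESTART_POLICY_ON_FAILURE'}
```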
plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py
@@ -7,6 +7,7 @@
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union

from flyteidl.plugins import common_pb2 as plugins_common
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
from google.protobuf.json_format import MessageToDict
@@ -174,6 +175,13 @@ def _convert_replica_spec(
) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
),
# The following fields are deprecated. They are kept for backwards compatibility.
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
@@ -191,18 +199,26 @@ def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolic
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
chief = self._convert_replica_spec(self.task_config.chief)
if self.task_config.num_chief_replicas:
chief.common.replicas = self.task_config.num_chief_replicas
# Deprecated. Only kept for backwards compatibility.
chief.replicas = self.task_config.num_chief_replicas

worker = self._convert_replica_spec(self.task_config.worker)
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

ps = self._convert_replica_spec(self.task_config.ps)
if self.task_config.num_ps_replicas:
ps.common.replicas = self.task_config.num_ps_replicas
# Deprecated. Only kept for backwards compatibility.
ps.replicas = self.task_config.num_ps_replicas

evaluator = self._convert_replica_spec(self.task_config.evaluator)
if self.task_config.num_evaluator_replicas:
evaluator.common.replicas = self.task_config.num_evaluator_replicas
# Deprecated. Only kept for backwards compatibility.
evaluator.replicas = self.task_config.num_evaluator_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
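`get_custom` repeats the same override for chief, worker, ps, and evaluator. A hypothetical helper (not part of this PR) that captures the dual-write in one place:

```python
# Hypothetical helper, not part of this PR: the dual-write that get_custom()
# performs four times (chief, worker, ps, evaluator), expressed once. Any
# user-supplied count must land in both the new common.replicas and the
# deprecated top-level replicas.
from typing import Optional

from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task


def apply_replica_override(
    spec: tensorflow_task.DistributedTensorflowTrainingReplicaSpec,
    num_replicas: Optional[int],
) -> None:
    if num_replicas:
        spec.common.replicas = num_replicas  # new shared field
        spec.replicas = num_replicas  # deprecated, for pre-refactor backends
```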
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-tensorflow/setup.py
@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flyteidl>=1.10.0", "flytekit>=1.6.1"]
plugin_requires = ["flyteidl>1.12.2", "flytekit>=1.6.1"]

__version__ = "0.0.0+develop"
