Skip to content

Commit

Permalink
feat(replicaspec): Keeps old fields for backward-compatibility
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#4408
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jul 8, 2024
1 parent 9100245 commit a972abb
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 2 deletions.
10 changes: 10 additions & 0 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def _convert_replica_spec(
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
),
command=replica_config.command,
# The forllowing fields are deprecated. They are kept for backwards compatibility.
# The following fields are deprecated and will be removed in the future
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down Expand Up @@ -207,10 +213,14 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

launcher = self._convert_replica_spec(self.task_config.launcher)
if self.task_config.num_launcher_replicas:
launcher.common.replicas = self.task_config.num_launcher_replicas
# Deprecated. Only kept for backwards compatibility.
launcher.replicas = self.task_config.num_launcher_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
mpi_job = mpi_task.DistributedMPITrainingTask(
Expand Down
37 changes: 37 additions & 0 deletions plugins/flytekit-kf-mpi/tests/test_mpi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ def my_mpi_task(x: int, y: str) -> int:
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"slots": 1,
}
Expand Down Expand Up @@ -83,12 +87,16 @@ def my_mpi_task(x: int, y: str) -> int:
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
"slots": 1,
}
Expand Down Expand Up @@ -146,6 +154,12 @@ def my_mpi_task(x: int, y: str) -> int:
"limits": [{"name": "CPU", "value": "2"}],
},
},
"replicas": 1,
"image": "launcher:latest",
"resources": {
"requests": [{"name": "CPU", "value": "1"}],
"limits": [{"name": "CPU", "value": "2"}],
},
},
"workerReplicas": {
"common": {
Expand All @@ -162,6 +176,18 @@ def my_mpi_task(x: int, y: str) -> int:
],
},
},
"replicas": 5,
"image": "worker:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
},
"slots": 2,
"runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"},
Expand Down Expand Up @@ -214,12 +240,23 @@ def my_horovod_task(): ...
],
},
},
"replicas": 1,
"resources": {
"requests": [
{"name": "CPU", "value": "1"},
],
"limits": [
{"name": "CPU", "value": "2"},
],
},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
"command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"],
},
"slots": 2,
Expand Down
11 changes: 10 additions & 1 deletion plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,21 @@ def _convert_replica_spec(
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)
),
# The forllowing fields are deprecated. They are kept for backwards compatibility.
replicas=replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
# support v0 config for backwards compatibility
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

run_policy = (
_convert_run_policy_to_flyte_idl(self.task_config.run_policy) if self.task_config.run_policy else None
Expand Down Expand Up @@ -459,6 +466,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
job = pytorch_task.DistributedPyTorchTrainingTask(
worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec(
common=plugins_common.CommonReplicaSpec(replicas=self.max_nodes),
# The following fields are deprecated. They are kept for backwards compatibility.
replicas=self.max_nodes,
),
elastic_config=elastic_config,
run_policy=run_policy,
Expand Down
24 changes: 24 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ def my_pytorch_task(x: int, y: str) -> int:
"replicas": 10,
"resources": {},
},
"replicas": 10,
"resources": {},
},
"masterReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
}
assert my_pytorch_task.resources.limits == Resources()
Expand Down Expand Up @@ -78,12 +82,16 @@ def my_pytorch_task(x: int, y: str) -> int:
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
"workerReplicas": {
"common": {
"replicas": 1,
"resources": {},
},
"replicas": 1,
"resources": {},
},
}
assert my_pytorch_task.get_custom(serialization_settings) == expected_dict
Expand Down Expand Up @@ -143,13 +151,29 @@ def my_pytorch_task(x: int, y: str) -> int:
},
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
"replicas": 5,
"image": "worker:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
"masterReplicas": {
"common": {
"resources": {},
"replicas": 1,
"restartPolicy": "RESTART_POLICY_ALWAYS",
},
"resources": {},
"replicas": 1,
"restartPolicy": "RESTART_POLICY_ALWAYS",
},
"runPolicy": {
"cleanPodPolicy": "CLEANPOD_POLICY_ALL",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ def _convert_replica_spec(
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)
),
# The following fields are deprecated. They are kept for backwards compatibility.
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)

def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand All @@ -195,18 +200,26 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
chief = self._convert_replica_spec(self.task_config.chief)
if self.task_config.num_chief_replicas:
chief.common.replicas = self.task_config.num_chief_replicas
# Deprecated. Only kept for backwards compatibility.
chief.replicas = self.task_config.num_chief_replicas

worker = self._convert_replica_spec(self.task_config.worker)
if self.task_config.num_workers:
worker.common.replicas = self.task_config.num_workers
# Deprecated. Only kept for backwards compatibility.
worker.replicas = self.task_config.num_workers

ps = self._convert_replica_spec(self.task_config.ps)
if self.task_config.num_ps_replicas:
ps.common.replicas = self.task_config.num_ps_replicas
# Deprecated. Only kept for backwards compatibility.
ps.replicas = self.task_config.num_ps_replicas

evaluator = self._convert_replica_spec(self.task_config.evaluator)
if self.task_config.num_evaluator_replicas:
evaluator.common.replicas = self.task_config.num_evaluator_replicas
# Deprecated. Only kept for backwards compatibility.
evaluator.replicas = self.task_config.num_evaluator_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
training_task = tensorflow_task.DistributedTensorflowTrainingTask(
Expand Down
Loading

0 comments on commit a972abb

Please sign in to comment.