Skip to content

Commit

Permalink
Adapt flytekit ray plugin to kuberay 1.1.0 (#2274)
Browse files Browse the repository at this point in the history
* runtime env yaml

Signed-off-by: Pin-Lun Hsu <[email protected]>

* fix test

Signed-off-by: Pin-Lun Hsu <[email protected]>

* fix fmt

Signed-off-by: Pin-Lun Hsu <[email protected]>

* Update setup.py

Signed-off-by: Pin-Lun Hsu <[email protected]>

---------

Signed-off-by: Pin-Lun Hsu <[email protected]>
Co-authored-by: Pin-Lun Hsu <[email protected]>
  • Loading branch information
ByronHsu and ByronHsu authored Mar 29, 2024
1 parent 6c917ed commit d074786
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
10 changes: 9 additions & 1 deletion plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,14 @@ class RayJob(_common.FlyteIdlEntity):
def __init__(
self,
ray_cluster: RayCluster,
runtime_env: typing.Optional[str],
runtime_env: typing.Optional[str] = None,
runtime_env_yaml: typing.Optional[str] = None,
ttl_seconds_after_finished: typing.Optional[int] = None,
shutdown_after_job_finishes: bool = False,
):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env
self._runtime_env_yaml = runtime_env_yaml
self._ttl_seconds_after_finished = ttl_seconds_after_finished
self._shutdown_after_job_finishes = shutdown_after_job_finishes

Expand All @@ -208,6 +210,10 @@ def ray_cluster(self) -> RayCluster:
def runtime_env(self) -> typing.Optional[str]:
return self._runtime_env

@property
def runtime_env_yaml(self) -> typing.Optional[str]:
return self._runtime_env_yaml

@property
def ttl_seconds_after_finished(self) -> typing.Optional[int]:
# ttl_seconds_after_finished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes.
Expand All @@ -222,6 +228,7 @@ def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
runtime_env_yaml=self.runtime_env_yaml,
ttl_seconds_after_finished=self.ttl_seconds_after_finished,
shutdown_after_job_finishes=self.shutdown_after_job_finishes,
)
Expand All @@ -231,6 +238,7 @@ def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
return cls(
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
runtime_env=proto.runtime_env,
runtime_env_yaml=proto.runtime_env_yaml,
ttl_seconds_after_finished=proto.ttl_seconds_after_finished,
shutdown_after_job_finishes=proto.shutdown_after_job_finishes,
)
10 changes: 8 additions & 2 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import yaml
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from google.protobuf.json_format import MessageToDict

Expand Down Expand Up @@ -63,6 +64,11 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
cfg = self._task_config

# Deprecated: runtime_env is removed KubeRay >= 1.1.0. It is replaced by runtime_env_yaml
runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None

runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
Expand All @@ -72,8 +78,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
),
# Use base64 to encode runtime_env dict and convert it to byte string
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(),
runtime_env=runtime_env,
runtime_env_yaml=runtime_env_yaml,
ttl_seconds_after_finished=cfg.ttl_seconds_after_finished,
shutdown_after_job_finishes=cfg.shutdown_after_job_finishes,
)
Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

import ray
import yaml
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -42,6 +43,7 @@ def t1(a: int) -> str:
ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=20,
).to_flyte_idl()
Expand Down

0 comments on commit d074786

Please sign in to comment.