diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py index 06e36af186..81517e2218 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/models.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -191,12 +191,14 @@ class RayJob(_common.FlyteIdlEntity): def __init__( self, ray_cluster: RayCluster, - runtime_env: typing.Optional[str], + runtime_env: typing.Optional[str] = None, + runtime_env_yaml: typing.Optional[str] = None, ttl_seconds_after_finished: typing.Optional[int] = None, shutdown_after_job_finishes: bool = False, ): self._ray_cluster = ray_cluster self._runtime_env = runtime_env + self._runtime_env_yaml = runtime_env_yaml self._ttl_seconds_after_finished = ttl_seconds_after_finished self._shutdown_after_job_finishes = shutdown_after_job_finishes @@ -208,6 +210,10 @@ def ray_cluster(self) -> RayCluster: def runtime_env(self) -> typing.Optional[str]: return self._runtime_env + @property + def runtime_env_yaml(self) -> typing.Optional[str]: + return self._runtime_env_yaml + @property def ttl_seconds_after_finished(self) -> typing.Optional[int]: # ttl_seconds_after_finished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes. 
@@ -222,6 +228,7 @@ def to_flyte_idl(self) -> _ray_pb2.RayJob: return _ray_pb2.RayJob( ray_cluster=self.ray_cluster.to_flyte_idl(), runtime_env=self.runtime_env, + runtime_env_yaml=self.runtime_env_yaml, ttl_seconds_after_finished=self.ttl_seconds_after_finished, shutdown_after_job_finishes=self.shutdown_after_job_finishes, ) @@ -231,6 +238,7 @@ def from_flyte_idl(cls, proto: _ray_pb2.RayJob): return cls( ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None, runtime_env=proto.runtime_env, + runtime_env_yaml=proto.runtime_env_yaml, ttl_seconds_after_finished=proto.ttl_seconds_after_finished, shutdown_after_job_finishes=proto.shutdown_after_job_finishes, ) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index 76688d74cd..49ded26bfd 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, Optional +import yaml from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec from google.protobuf.json_format import MessageToDict @@ -63,6 +64,11 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: cfg = self._task_config + # Deprecated: runtime_env is removed in KubeRay >= 1.1.0. 
It is replaced by runtime_env_yaml + runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None + + runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None + ray_job = RayJob( ray_cluster=RayCluster( head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None, @@ -72,8 +78,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ], enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False, ), - # Use base64 to encode runtime_env dict and convert it to byte string - runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(), + runtime_env=runtime_env, + runtime_env_yaml=runtime_env_yaml, ttl_seconds_after_finished=cfg.ttl_seconds_after_finished, shutdown_after_job_finishes=cfg.shutdown_after_job_finishes, ) diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index 0c0ada1944..6fad11dd3e 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -2,6 +2,7 @@ import json import ray +import yaml from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig from google.protobuf.json_format import MessageToDict @@ -42,6 +43,7 @@ def t1(a: int) -> str: ray_job_pb = RayJob( ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True), runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), + runtime_env_yaml=yaml.dump({"pip": ["numpy"]}), shutdown_after_job_finishes=True, ttl_seconds_after_finished=20, ).to_flyte_idl()