From e958040feaa7797822fba0ab949d1f2a79456b32 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Wed, 18 Dec 2024 11:00:01 -0800 Subject: [PATCH 1/3] Refactor ray._private.runtime_env.packaging.Protocol to make it extensible Signed-off-by: Jiajun Yao --- .../ray/_private/runtime_env/default_impl.py | 6 +++ python/ray/_private/runtime_env/packaging.py | 29 +---------- python/ray/_private/runtime_env/protocol.py | 49 +++++++++++++++++++ 3 files changed, 56 insertions(+), 28 deletions(-) create mode 100644 python/ray/_private/runtime_env/protocol.py diff --git a/python/ray/_private/runtime_env/default_impl.py b/python/ray/_private/runtime_env/default_impl.py index 40fd5485be61..ad2dbc609513 100644 --- a/python/ray/_private/runtime_env/default_impl.py +++ b/python/ray/_private/runtime_env/default_impl.py @@ -3,3 +3,9 @@ def get_image_uri_plugin(ray_tmp_dir: str): return ImageURIPlugin(ray_tmp_dir) + + +def get_protocols_provider(): + from ray._private.runtime_env.protocol import ProtocolsProvider + + return ProtocolsProvider diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 73ada86c4fb0..23292df5967d 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -4,7 +4,6 @@ import logging import os import shutil -from enum import Enum from pathlib import Path from tempfile import TemporaryDirectory from typing import Callable, List, Optional, Tuple @@ -20,6 +19,7 @@ RAY_RUNTIME_ENV_IGNORE_GITIGNORE, ) from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger +from ray._private.runtime_env.protocol import Protocol from ray._private.thirdparty.pathspec import PathSpec from ray.experimental.internal_kv import ( _internal_kv_exists, @@ -73,33 +73,6 @@ async def __aexit__(self, exc_type, exc, tb): self.file.release() -class Protocol(Enum): - """A enum for supported storage backends.""" - - # For docstring - def __new__(cls, value, doc=None): - self = 
object.__new__(cls) - self._value_ = value - if doc is not None: - self.__doc__ = doc - return self - - GCS = "gcs", "For packages dynamically uploaded and managed by the GCS." - CONDA = "conda", "For conda environments installed locally on each node." - PIP = "pip", "For pip environments installed locally on each node." - UV = "uv", "For uv environments install locally on each node." - HTTPS = "https", "Remote https path, assumes everything packed in one zip file." - S3 = "s3", "Remote s3 path, assumes everything packed in one zip file." - GS = "gs", "Remote google storage path, assumes everything packed in one zip file." - FILE = "file", "File storage path, assumes everything packed in one zip file." - - @classmethod - def remote_protocols(cls): - # Returns a list of protocols that support remote storage - # These protocols should only be used with paths that end in ".zip" or ".whl" - return [cls.HTTPS, cls.S3, cls.GS, cls.FILE] - - def _xor_bytes(left: bytes, right: bytes) -> bytes: if left and right: return bytes(a ^ b for (a, b) in zip(left, right)) diff --git a/python/ray/_private/runtime_env/protocol.py b/python/ray/_private/runtime_env/protocol.py new file mode 100644 index 000000000000..f3c44cd1beb1 --- /dev/null +++ b/python/ray/_private/runtime_env/protocol.py @@ -0,0 +1,49 @@ +import enum +from ray._private.runtime_env.default_impl import get_protocols_provider + + +class ProtocolsProvider: + @classmethod + def get_protocols(cls): + return { + # For packages dynamically uploaded and managed by the GCS. + "gcs", + # For conda environments installed locally on each node. + "conda", + # For pip environments installed locally on each node. + "pip", + # For uv environments installed locally on each node. + "uv", + # Remote https path, assumes everything packed in one zip file. + "https", + # Remote s3 path, assumes everything packed in one zip file. + "s3", + # Remote google storage path, assumes everything packed in one zip file. 
+ "gs", + # File storage path, assumes everything packed in one zip file. + "file", + } + + @classmethod + def get_remote_protocols(cls): + return {"https", "s3", "gs", "file"} + + +_protocols_provider = get_protocols_provider() + +Protocol = enum.Enum( + "Protocol", + {protocol.upper(): protocol for protocol in _protocols_provider.get_protocols()}, +) + + +@classmethod +def remote_protocols(cls): + # Returns a list of protocols that support remote storage + # These protocols should only be used with paths that end in ".zip" or ".whl" + return [ + cls[protocol.upper()] for protocol in _protocols_provider.get_remote_protocols() + ] + + +Protocol.remote_protocols = remote_protocols From 422592975e551ba6a71f37274662a3e175c7f43a Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Wed, 18 Dec 2024 16:01:26 -0800 Subject: [PATCH 2/3] up Signed-off-by: Jiajun Yao --- python/ray/_private/runtime_env/packaging.py | 40 ++------------- python/ray/_private/runtime_env/protocol.py | 51 +++++++++++++++++++- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 23292df5967d..51f978ea7c33 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -741,49 +741,15 @@ async def download_and_unpack_package( elif protocol in Protocol.remote_protocols(): # Download package from remote URI tp = None - install_warning = ( - "Note that these must be preinstalled " - "on all nodes in the Ray cluster; it is not " - "sufficient to install them in the runtime_env." - ) - - if protocol == Protocol.S3: - try: - import boto3 - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` and " - "`pip install boto3` to fetch URIs in s3 " - "bucket. 
" + install_warning - ) - tp = {"client": boto3.client("s3")} - elif protocol == Protocol.GS: - try: - from google.cloud import storage # noqa: F401 - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` and " - "`pip install google-cloud-storage` " - "to fetch URIs in Google Cloud Storage bucket." - + install_warning - ) - elif protocol == Protocol.FILE: + if protocol == Protocol.FILE: pkg_uri = pkg_uri[len("file://") :] def open_file(uri, mode, *, transport_params=None): return open(uri, mode) else: - try: - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` " - f"to fetch {protocol.value.upper()} URIs. " - + install_warning - ) + tp = protocol.get_smart_open_transport_params() + from smart_open import open as open_file with open_file(pkg_uri, "rb", transport_params=tp) as package_zip: with open_file(pkg_file, "wb") as fin: diff --git a/python/ray/_private/runtime_env/protocol.py b/python/ray/_private/runtime_env/protocol.py index f3c44cd1beb1..b9e255f487ee 100644 --- a/python/ray/_private/runtime_env/protocol.py +++ b/python/ray/_private/runtime_env/protocol.py @@ -28,6 +28,46 @@ def get_protocols(cls): def get_remote_protocols(cls): return {"https", "s3", "gs", "file"} + @classmethod + def get_smart_open_transport_params(cls, protocol): + install_warning = ( + "Note that these must be preinstalled " + "on all nodes in the Ray cluster; it is not " + "sufficient to install them in the runtime_env." + ) + + if protocol == "s3": + try: + import boto3 + from smart_open import open as open_file # noqa: F401 + except ImportError: + raise ImportError( + "You must `pip install smart_open` and " + "`pip install boto3` to fetch URIs in s3 " + "bucket. 
" + install_warning + ) + return {"client": boto3.client("s3")} + elif protocol == "gs": + try: + from google.cloud import storage # noqa: F401 + from smart_open import open as open_file # noqa: F401 + except ImportError: + raise ImportError( + "You must `pip install smart_open` and " + "`pip install google-cloud-storage` " + "to fetch URIs in Google Cloud Storage bucket." + install_warning + ) + return None + else: + try: + from smart_open import open as open_file # noqa: F401 + except ImportError: + raise ImportError( + "You must `pip install smart_open` " + f"to fetch {protocol.upper()} URIs. " + install_warning + ) + return None + _protocols_provider = get_protocols_provider() @@ -38,7 +78,7 @@ def get_remote_protocols(cls): @classmethod -def remote_protocols(cls): +def _remote_protocols(cls): # Returns a list of protocols that support remote storage # These protocols should only be used with paths that end in ".zip" or ".whl" return [ @@ -46,4 +86,11 @@ def remote_protocols(cls): ] -Protocol.remote_protocols = remote_protocols +Protocol.remote_protocols = _remote_protocols + + +def _get_smart_open_transport_params(self): + return _protocols_provider.get_smart_open_transport_params(self.value) + + +Protocol.get_smart_open_transport_params = _get_smart_open_transport_params From c28af3c69428ad222118366f5e24da5fb4c30ecc Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Wed, 18 Dec 2024 22:36:48 -0800 Subject: [PATCH 3/3] comments Signed-off-by: Jiajun Yao --- python/ray/_private/runtime_env/packaging.py | 18 ++--------- python/ray/_private/runtime_env/protocol.py | 34 +++++++++++++------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 51f978ea7c33..83efeb8953e1 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -698,7 +698,7 @@ async def download_and_unpack_package( if local_dir.exists(): 
assert local_dir.is_dir(), f"{local_dir} is not a directory" else: - protocol, pkg_name = parse_uri(pkg_uri) + protocol, _ = parse_uri(pkg_uri) if protocol == Protocol.GCS: if gcs_aio_client is None: raise ValueError( @@ -739,21 +739,7 @@ async def download_and_unpack_package( else: return str(pkg_file) elif protocol in Protocol.remote_protocols(): - # Download package from remote URI - tp = None - if protocol == Protocol.FILE: - pkg_uri = pkg_uri[len("file://") :] - - def open_file(uri, mode, *, transport_params=None): - return open(uri, mode) - - else: - tp = protocol.get_smart_open_transport_params() - from smart_open import open as open_file - - with open_file(pkg_uri, "rb", transport_params=tp) as package_zip: - with open_file(pkg_file, "wb") as fin: - fin.write(package_zip.read()) + protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file) if pkg_file.suffix in [".zip", ".jar"]: unzip_package( diff --git a/python/ray/_private/runtime_env/protocol.py b/python/ray/_private/runtime_env/protocol.py index b9e255f487ee..7fc2beb92666 100644 --- a/python/ray/_private/runtime_env/protocol.py +++ b/python/ray/_private/runtime_env/protocol.py @@ -29,44 +29,56 @@ def get_remote_protocols(cls): return {"https", "s3", "gs", "file"} @classmethod - def get_smart_open_transport_params(cls, protocol): + def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str): + """Download file from remote URI to dest file""" + assert protocol in cls.get_remote_protocols() + + tp = None install_warning = ( "Note that these must be preinstalled " "on all nodes in the Ray cluster; it is not " "sufficient to install them in the runtime_env." 
) - if protocol == "s3": + if protocol == "file": + source_uri = source_uri[len("file://") :] + + def open_file(uri, mode, *, transport_params=None): + return open(uri, mode) + + elif protocol == "s3": try: import boto3 - from smart_open import open as open_file # noqa: F401 + from smart_open import open as open_file except ImportError: raise ImportError( "You must `pip install smart_open` and " "`pip install boto3` to fetch URIs in s3 " "bucket. " + install_warning ) - return {"client": boto3.client("s3")} + tp = {"client": boto3.client("s3")} elif protocol == "gs": try: from google.cloud import storage # noqa: F401 - from smart_open import open as open_file # noqa: F401 + from smart_open import open as open_file except ImportError: raise ImportError( "You must `pip install smart_open` and " "`pip install google-cloud-storage` " "to fetch URIs in Google Cloud Storage bucket." + install_warning ) - return None else: try: - from smart_open import open as open_file # noqa: F401 + from smart_open import open as open_file except ImportError: raise ImportError( "You must `pip install smart_open` " f"to fetch {protocol.upper()} URIs. " + install_warning ) - return None + + with open_file(source_uri, "rb", transport_params=tp) as fin: + with open_file(dest_file, "wb") as fout: + fout.write(fin.read()) _protocols_provider = get_protocols_provider() @@ -89,8 +101,8 @@ def _remote_protocols(cls): Protocol.remote_protocols = _remote_protocols -def _get_smart_open_transport_params(self): - return _protocols_provider.get_smart_open_transport_params(self.value) +def _download_remote_uri(self, source_uri, dest_file): + return _protocols_provider.download_remote_uri(self.value, source_uri, dest_file) -Protocol.get_smart_open_transport_params = _get_smart_open_transport_params +Protocol.download_remote_uri = _download_remote_uri