diff --git a/python/ray/_private/runtime_env/default_impl.py b/python/ray/_private/runtime_env/default_impl.py
index 40fd5485be61..ad2dbc609513 100644
--- a/python/ray/_private/runtime_env/default_impl.py
+++ b/python/ray/_private/runtime_env/default_impl.py
@@ -3,3 +3,9 @@
 
 def get_image_uri_plugin(ray_tmp_dir: str):
     return ImageURIPlugin(ray_tmp_dir)
+
+
+def get_protocols_provider():
+    from ray._private.runtime_env.protocol import ProtocolsProvider
+
+    return ProtocolsProvider
diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py
index 73ada86c4fb0..83efeb8953e1 100644
--- a/python/ray/_private/runtime_env/packaging.py
+++ b/python/ray/_private/runtime_env/packaging.py
@@ -4,7 +4,6 @@
 import logging
 import os
 import shutil
-from enum import Enum
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Callable, List, Optional, Tuple
@@ -20,6 +19,7 @@
     RAY_RUNTIME_ENV_IGNORE_GITIGNORE,
 )
 from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger
+from ray._private.runtime_env.protocol import Protocol
 from ray._private.thirdparty.pathspec import PathSpec
 from ray.experimental.internal_kv import (
     _internal_kv_exists,
@@ -73,33 +73,6 @@ async def __aexit__(self, exc_type, exc, tb):
         self.file.release()
 
 
-class Protocol(Enum):
-    """A enum for supported storage backends."""
-
-    # For docstring
-    def __new__(cls, value, doc=None):
-        self = object.__new__(cls)
-        self._value_ = value
-        if doc is not None:
-            self.__doc__ = doc
-        return self
-
-    GCS = "gcs", "For packages dynamically uploaded and managed by the GCS."
-    CONDA = "conda", "For conda environments installed locally on each node."
-    PIP = "pip", "For pip environments installed locally on each node."
-    UV = "uv", "For uv environments install locally on each node."
-    HTTPS = "https", "Remote https path, assumes everything packed in one zip file."
-    S3 = "s3", "Remote s3 path, assumes everything packed in one zip file."
-    GS = "gs", "Remote google storage path, assumes everything packed in one zip file."
-    FILE = "file", "File storage path, assumes everything packed in one zip file."
-
-    @classmethod
-    def remote_protocols(cls):
-        # Returns a list of protocols that support remote storage
-        # These protocols should only be used with paths that end in ".zip" or ".whl"
-        return [cls.HTTPS, cls.S3, cls.GS, cls.FILE]
-
-
 def _xor_bytes(left: bytes, right: bytes) -> bytes:
     if left and right:
         return bytes(a ^ b for (a, b) in zip(left, right))
@@ -725,7 +698,7 @@ async def download_and_unpack_package(
         if local_dir.exists():
             assert local_dir.is_dir(), f"{local_dir} is not a directory"
         else:
-            protocol, pkg_name = parse_uri(pkg_uri)
+            protocol, _ = parse_uri(pkg_uri)
             if protocol == Protocol.GCS:
                 if gcs_aio_client is None:
                     raise ValueError(
@@ -766,55 +739,7 @@ async def download_and_unpack_package(
                 else:
                     return str(pkg_file)
             elif protocol in Protocol.remote_protocols():
-                # Download package from remote URI
-                tp = None
-                install_warning = (
-                    "Note that these must be preinstalled "
-                    "on all nodes in the Ray cluster; it is not "
-                    "sufficient to install them in the runtime_env."
-                )
-
-                if protocol == Protocol.S3:
-                    try:
-                        import boto3
-                        from smart_open import open as open_file
-                    except ImportError:
-                        raise ImportError(
-                            "You must `pip install smart_open` and "
-                            "`pip install boto3` to fetch URIs in s3 "
-                            "bucket. " + install_warning
-                        )
-                    tp = {"client": boto3.client("s3")}
-                elif protocol == Protocol.GS:
-                    try:
-                        from google.cloud import storage  # noqa: F401
-                        from smart_open import open as open_file
-                    except ImportError:
-                        raise ImportError(
-                            "You must `pip install smart_open` and "
-                            "`pip install google-cloud-storage` "
-                            "to fetch URIs in Google Cloud Storage bucket."
-                            + install_warning
-                        )
-                elif protocol == Protocol.FILE:
-                    pkg_uri = pkg_uri[len("file://") :]
-
-                    def open_file(uri, mode, *, transport_params=None):
-                        return open(uri, mode)
-
-                else:
-                    try:
-                        from smart_open import open as open_file
-                    except ImportError:
-                        raise ImportError(
-                            "You must `pip install smart_open` "
-                            f"to fetch {protocol.value.upper()} URIs. "
-                            + install_warning
-                        )
-
-                with open_file(pkg_uri, "rb", transport_params=tp) as package_zip:
-                    with open_file(pkg_file, "wb") as fin:
-                        fin.write(package_zip.read())
+                protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file)
 
             if pkg_file.suffix in [".zip", ".jar"]:
                 unzip_package(
diff --git a/python/ray/_private/runtime_env/protocol.py b/python/ray/_private/runtime_env/protocol.py
new file mode 100644
index 000000000000..7fc2beb92666
--- /dev/null
+++ b/python/ray/_private/runtime_env/protocol.py
@@ -0,0 +1,108 @@
+import enum
+from ray._private.runtime_env.default_impl import get_protocols_provider
+
+
+class ProtocolsProvider:
+    @classmethod
+    def get_protocols(cls):
+        return {
+            # For packages dynamically uploaded and managed by the GCS.
+            "gcs",
+            # For conda environments installed locally on each node.
+            "conda",
+            # For pip environments installed locally on each node.
+            "pip",
+            # For uv environments installed locally on each node.
+            "uv",
+            # Remote https path, assumes everything packed in one zip file.
+            "https",
+            # Remote s3 path, assumes everything packed in one zip file.
+            "s3",
+            # Remote google storage path, assumes everything packed in one zip file.
+            "gs",
+            # File storage path, assumes everything packed in one zip file.
+            "file",
+        }
+
+    @classmethod
+    def get_remote_protocols(cls):
+        return {"https", "s3", "gs", "file"}
+
+    @classmethod
+    def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):
+        """Download a file from a remote URI to a local destination file."""
+        assert protocol in cls.get_remote_protocols()
+
+        tp = None
+        install_warning = (
+            "Note that these must be preinstalled "
+            "on all nodes in the Ray cluster; it is not "
+            "sufficient to install them in the runtime_env."
+        )
+
+        if protocol == "file":
+            source_uri = source_uri[len("file://") :]
+
+            def open_file(uri, mode, *, transport_params=None):
+                return open(uri, mode)
+
+        elif protocol == "s3":
+            try:
+                import boto3
+                from smart_open import open as open_file
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` and "
+                    "`pip install boto3` to fetch URIs in s3 "
+                    "bucket. " + install_warning
+                )
+            tp = {"client": boto3.client("s3")}
+        elif protocol == "gs":
+            try:
+                from google.cloud import storage  # noqa: F401
+                from smart_open import open as open_file
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` and "
+                    "`pip install google-cloud-storage` "
+                    "to fetch URIs in Google Cloud Storage bucket."
+                    + install_warning
+                )
+        else:
+            try:
+                from smart_open import open as open_file
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` "
+                    f"to fetch {protocol.upper()} URIs. " + install_warning
+                )
+
+        with open_file(source_uri, "rb", transport_params=tp) as fin:
+            with open_file(dest_file, "wb") as fout:
+                fout.write(fin.read())
+
+
+_protocols_provider = get_protocols_provider()
+
+Protocol = enum.Enum(
+    "Protocol",
+    {protocol.upper(): protocol for protocol in _protocols_provider.get_protocols()},
+)
+
+
+@classmethod
+def _remote_protocols(cls):
+    # Returns a list of protocols that support remote storage
+    # These protocols should only be used with paths that end in ".zip" or ".whl"
+    return [
+        cls[protocol.upper()] for protocol in _protocols_provider.get_remote_protocols()
+    ]
+
+
+Protocol.remote_protocols = _remote_protocols
+
+
+def _download_remote_uri(self, source_uri, dest_file):
+    return _protocols_provider.download_remote_uri(self.value, source_uri, dest_file)
+
+
+Protocol.download_remote_uri = _download_remote_uri