From 7e25f87c184dc1f22d7ef7a2219bd73851229d82 Mon Sep 17 00:00:00 2001 From: ujjawal-khare Date: Wed, 18 Dec 2024 01:32:19 +0530 Subject: [PATCH] fix for cleaning working dir in case of same uri Signed-off-by: ujjawal-khare --- python/ray/_private/runtime_env/packaging.py | 203 +++++++++--------- python/ray/_private/runtime_env/plugin.py | 7 +- .../ray/tests/test_runtime_env_working_dir.py | 24 +++ 3 files changed, 129 insertions(+), 105 deletions(-) diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 73ada86c4fb0..675c7125d32e 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -723,116 +723,113 @@ async def download_and_unpack_package( local_dir = get_local_dir_from_uri(pkg_uri, base_directory) assert local_dir != pkg_file, "Invalid pkg_file!" if local_dir.exists(): - assert local_dir.is_dir(), f"{local_dir} is not a directory" - else: - protocol, pkg_name = parse_uri(pkg_uri) - if protocol == Protocol.GCS: - if gcs_aio_client is None: - raise ValueError( - "GCS client must be provided to download from GCS." - ) + shutil.rmtree(local_dir) - # Download package from the GCS. - code = await gcs_aio_client.internal_kv_get( - pkg_uri.encode(), namespace=None, timeout=None + protocol, pkg_name = parse_uri(pkg_uri) + if protocol == Protocol.GCS: + if gcs_aio_client is None: + raise ValueError("GCS client must be provided to download from GCS.") + + # Download package from the GCS. + code = await gcs_aio_client.internal_kv_get( + pkg_uri.encode(), namespace=None, timeout=None + ) + if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR): + code = None + if code is None: + raise IOError( + f"Failed to download runtime_env file package {pkg_uri} " + "from the GCS to the Ray worker node. The package may " + "have prematurely been deleted from the GCS due to a " + "long upload time or a problem with Ray. Try setting the " + "environment variable " + f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} " + " to a value larger than the upload time in seconds " + "(the default is " + f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). " + "If this fails, try re-running " + "after making any change to a file in the file package." ) - if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR): - code = None - if code is None: - raise IOError( - f"Failed to download runtime_env file package {pkg_uri} " - "from the GCS to the Ray worker node. The package may " - "have prematurely been deleted from the GCS due to a " - "long upload time or a problem with Ray. Try setting the " - "environment variable " - f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} " - " to a value larger than the upload time in seconds " - "(the default is " - f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). " - "If this fails, try re-running " - "after making any change to a file in the file package." - ) - code = code or b"" - pkg_file.write_bytes(code) - - if is_zip_uri(pkg_uri): - unzip_package( - package_path=pkg_file, - target_dir=local_dir, - remove_top_level_directory=False, - unlink_zip=True, - logger=logger, - ) - else: - return str(pkg_file) - elif protocol in Protocol.remote_protocols(): - # Download package from remote URI - tp = None - install_warning = ( - "Note that these must be preinstalled " - "on all nodes in the Ray cluster; it is not " - "sufficient to install them in the runtime_env." + code = code or b"" + pkg_file.write_bytes(code) + + if is_zip_uri(pkg_uri): + unzip_package( + package_path=pkg_file, + target_dir=local_dir, + remove_top_level_directory=False, + unlink_zip=True, + logger=logger, ) + else: + return str(pkg_file) + elif protocol in Protocol.remote_protocols(): + # Download package from remote URI + tp = None + install_warning = ( + "Note that these must be preinstalled " + "on all nodes in the Ray cluster; it is not " + "sufficient to install them in the runtime_env." + ) - if protocol == Protocol.S3: - try: - import boto3 - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` and " - "`pip install boto3` to fetch URIs in s3 " - "bucket. " + install_warning - ) - tp = {"client": boto3.client("s3")} - elif protocol == Protocol.GS: - try: - from google.cloud import storage # noqa: F401 - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` and " - "`pip install google-cloud-storage` " - "to fetch URIs in Google Cloud Storage bucket." - + install_warning - ) - elif protocol == Protocol.FILE: - pkg_uri = pkg_uri[len("file://") :] - - def open_file(uri, mode, *, transport_params=None): - return open(uri, mode) - - else: - try: - from smart_open import open as open_file - except ImportError: - raise ImportError( - "You must `pip install smart_open` " - f"to fetch {protocol.value.upper()} URIs. " - + install_warning - ) - - with open_file(pkg_uri, "rb", transport_params=tp) as package_zip: - with open_file(pkg_file, "wb") as fin: - fin.write(package_zip.read()) - - if pkg_file.suffix in [".zip", ".jar"]: - unzip_package( - package_path=pkg_file, - target_dir=local_dir, - remove_top_level_directory=True, - unlink_zip=True, - logger=logger, + if protocol == Protocol.S3: + try: + import boto3 + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open` and " + "`pip install boto3` to fetch URIs in s3 " + "bucket. " + install_warning ) - elif pkg_file.suffix == ".whl": - return str(pkg_file) - else: - raise NotImplementedError( - f"Package format {pkg_file.suffix} is ", - "not supported for remote protocols", + tp = {"client": boto3.client("s3")} + elif protocol == Protocol.GS: + try: + from google.cloud import storage # noqa: F401 + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open` and " + "`pip install google-cloud-storage` " + "to fetch URIs in Google Cloud Storage bucket." + + install_warning ) + elif protocol == Protocol.FILE: + pkg_uri = pkg_uri[len("file://") :] + + def open_file(uri, mode, *, transport_params=None): + return open(uri, mode) + + else: + try: + from smart_open import open as open_file + except ImportError: + raise ImportError( + "You must `pip install smart_open` " + f"to fetch {protocol.value.upper()} URIs. " + install_warning + ) + + with open_file(pkg_uri, "rb", transport_params=tp) as package_zip: + with open_file(pkg_file, "wb") as fin: + fin.write(package_zip.read()) + + if pkg_file.suffix in [".zip", ".jar"]: + unzip_package( + package_path=pkg_file, + target_dir=local_dir, + remove_top_level_directory=True, + unlink_zip=True, + logger=logger, + ) + elif pkg_file.suffix == ".whl": + return str(pkg_file) else: - raise NotImplementedError(f"Protocol {protocol} is not supported") + raise NotImplementedError( + f"Package format {pkg_file.suffix} is ", + "not supported for remote protocols", + ) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") return str(local_dir) diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py index a1e03a507b59..b7e1a15db24c 100644 --- a/python/ray/_private/runtime_env/plugin.py +++ b/python/ray/_private/runtime_env/plugin.py @@ -241,6 +241,7 @@ async def create_for_plugin_if_needed( uris = plugin.get_uris(runtime_env) + logger.info(f"Setting up runtime env {plugin.name} with URIs {uris}.") if not uris: logger.debug( f"No URIs for runtime env plugin {plugin.name}; " @@ -252,13 +253,15 @@ async def create_for_plugin_if_needed( if uri not in uri_cache: logger.debug(f"Cache miss for URI {uri}.") size_bytes = await plugin.create(uri, runtime_env, context, logger=logger) - uri_cache.add(uri, size_bytes, logger=logger) + if plugin.name is None or plugin.name != "working_dir": + uri_cache.add(uri, size_bytes, logger=logger) else: logger.info( f"Runtime env {plugin.name} {uri} is already installed " "and will be reused. Search " "all runtime_env_setup-*.log to find the corresponding setup log." ) - uri_cache.mark_used(uri, logger=logger) + if plugin.name is None or plugin.name != "working_dir": + uri_cache.mark_used(uri, logger=logger) plugin.modify_context(uris, runtime_env, context, logger) diff --git a/python/ray/tests/test_runtime_env_working_dir.py b/python/ray/tests/test_runtime_env_working_dir.py index e667b0c712b1..6dd35970397c 100644 --- a/python/ray/tests/test_runtime_env_working_dir.py +++ b/python/ray/tests/test_runtime_env_working_dir.py @@ -45,6 +45,30 @@ def insert_test_dir_in_pythonpath(): yield +@pytest.mark.asyncio +async def test_working_dir_cleanup(tmpdir, ray_start_regular): + gcs_aio_client = gcs_utils.GcsAioClient( + address=ray.worker.global_worker.gcs_client.address + ) + + plugin = WorkingDirPlugin(tmpdir, gcs_aio_client) + size = await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext()) + + files = os.listdir(f"{tmpdir}/working_dir_files") + file_metadata = os.stat(f"{tmpdir}/working_dir_files/{files[0]}") + creation_time = file_metadata.st_ctime + + time.sleep(1) + + size = await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext()) + files = os.listdir(f"{tmpdir}/working_dir_files") + + file_metadata = os.stat(f"{tmpdir}/working_dir_files/{files[0]}") + creation_time_after = file_metadata.st_ctime + + assert creation_time != creation_time_after + + @pytest.mark.asyncio async def test_create_delete_size_equal(tmpdir, ray_start_regular): """Tests that `create` and `delete_uri` return the same size for a URI."""