Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix for cleaning working dir in case of same uri #49313

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 100 additions & 103 deletions python/ray/_private/runtime_env/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,116 +723,113 @@ async def download_and_unpack_package(
local_dir = get_local_dir_from_uri(pkg_uri, base_directory)
assert local_dir != pkg_file, "Invalid pkg_file!"
if local_dir.exists():
assert local_dir.is_dir(), f"{local_dir} is not a directory"
else:
protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS:
if gcs_aio_client is None:
raise ValueError(
"GCS client must be provided to download from GCS."
)
shutil.rmtree(local_dir)

# Download package from the GCS.
code = await gcs_aio_client.internal_kv_get(
pkg_uri.encode(), namespace=None, timeout=None
protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS:
if gcs_aio_client is None:
raise ValueError("GCS client must be provided to download from GCS.")

# Download package from the GCS.
code = await gcs_aio_client.internal_kv_get(
pkg_uri.encode(), namespace=None, timeout=None
)
if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR):
code = None
if code is None:
raise IOError(
f"Failed to download runtime_env file package {pkg_uri} "
"from the GCS to the Ray worker node. The package may "
"have prematurely been deleted from the GCS due to a "
"long upload time or a problem with Ray. Try setting the "
"environment variable "
f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} "
" to a value larger than the upload time in seconds "
"(the default is "
f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). "
"If this fails, try re-running "
"after making any change to a file in the file package."
)
if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR):
code = None
if code is None:
raise IOError(
f"Failed to download runtime_env file package {pkg_uri} "
"from the GCS to the Ray worker node. The package may "
"have prematurely been deleted from the GCS due to a "
"long upload time or a problem with Ray. Try setting the "
"environment variable "
f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} "
" to a value larger than the upload time in seconds "
"(the default is "
f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). "
"If this fails, try re-running "
"after making any change to a file in the file package."
)
code = code or b""
pkg_file.write_bytes(code)

if is_zip_uri(pkg_uri):
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=False,
unlink_zip=True,
logger=logger,
)
else:
return str(pkg_file)
elif protocol in Protocol.remote_protocols():
# Download package from remote URI
tp = None
install_warning = (
"Note that these must be preinstalled "
"on all nodes in the Ray cluster; it is not "
"sufficient to install them in the runtime_env."
code = code or b""
pkg_file.write_bytes(code)

if is_zip_uri(pkg_uri):
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=False,
unlink_zip=True,
logger=logger,
)
else:
return str(pkg_file)
elif protocol in Protocol.remote_protocols():
# Download package from remote URI
tp = None
install_warning = (
"Note that these must be preinstalled "
"on all nodes in the Ray cluster; it is not "
"sufficient to install them in the runtime_env."
)

if protocol == Protocol.S3:
try:
import boto3
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install boto3` to fetch URIs in s3 "
"bucket. " + install_warning
)
tp = {"client": boto3.client("s3")}
elif protocol == Protocol.GS:
try:
from google.cloud import storage # noqa: F401
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install google-cloud-storage` "
"to fetch URIs in Google Cloud Storage bucket."
+ install_warning
)
elif protocol == Protocol.FILE:
pkg_uri = pkg_uri[len("file://") :]

def open_file(uri, mode, *, transport_params=None):
return open(uri, mode)

else:
try:
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` "
f"to fetch {protocol.value.upper()} URIs. "
+ install_warning
)

with open_file(pkg_uri, "rb", transport_params=tp) as package_zip:
with open_file(pkg_file, "wb") as fin:
fin.write(package_zip.read())

if pkg_file.suffix in [".zip", ".jar"]:
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=True,
unlink_zip=True,
logger=logger,
if protocol == Protocol.S3:
try:
import boto3
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install boto3` to fetch URIs in s3 "
"bucket. " + install_warning
)
elif pkg_file.suffix == ".whl":
return str(pkg_file)
else:
raise NotImplementedError(
f"Package format {pkg_file.suffix} is ",
"not supported for remote protocols",
tp = {"client": boto3.client("s3")}
elif protocol == Protocol.GS:
try:
from google.cloud import storage # noqa: F401
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install google-cloud-storage` "
"to fetch URIs in Google Cloud Storage bucket."
+ install_warning
)
elif protocol == Protocol.FILE:
pkg_uri = pkg_uri[len("file://") :]

def open_file(uri, mode, *, transport_params=None):
return open(uri, mode)

else:
try:
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` "
f"to fetch {protocol.value.upper()} URIs. " + install_warning
)

with open_file(pkg_uri, "rb", transport_params=tp) as package_zip:
with open_file(pkg_file, "wb") as fin:
fin.write(package_zip.read())

if pkg_file.suffix in [".zip", ".jar"]:
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=True,
unlink_zip=True,
logger=logger,
)
elif pkg_file.suffix == ".whl":
return str(pkg_file)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
raise NotImplementedError(
f"Package format {pkg_file.suffix} is ",
"not supported for remote protocols",
)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")

return str(local_dir)

Expand Down
7 changes: 5 additions & 2 deletions python/ray/_private/runtime_env/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ async def create_for_plugin_if_needed(

uris = plugin.get_uris(runtime_env)

logger.info(f"Setting up runtime env {plugin.name} with URIs {uris}.")
if not uris:
logger.debug(
f"No URIs for runtime env plugin {plugin.name}; "
Expand All @@ -252,13 +253,15 @@ async def create_for_plugin_if_needed(
if uri not in uri_cache:
logger.debug(f"Cache miss for URI {uri}.")
size_bytes = await plugin.create(uri, runtime_env, context, logger=logger)
uri_cache.add(uri, size_bytes, logger=logger)
if plugin.name is None or plugin.name != "working_dir":
uri_cache.add(uri, size_bytes, logger=logger)
else:
logger.info(
f"Runtime env {plugin.name} {uri} is already installed "
"and will be reused. Search "
"all runtime_env_setup-*.log to find the corresponding setup log."
)
uri_cache.mark_used(uri, logger=logger)
if plugin.name is None or plugin.name != "working_dir":
uri_cache.mark_used(uri, logger=logger)

plugin.modify_context(uris, runtime_env, context, logger)
24 changes: 24 additions & 0 deletions python/ray/tests/test_runtime_env_working_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ def insert_test_dir_in_pythonpath():
yield


@pytest.mark.asyncio
async def test_working_dir_cleanup(tmpdir, ray_start_regular):
    """Check that re-creating the same working_dir URI refreshes its files.

    Calls ``WorkingDirPlugin.create`` twice with the same URI and asserts that
    the ``st_ctime`` of the first entry under ``working_dir_files`` differs
    between the two calls.
    """
    gcs_aio_client = gcs_utils.GcsAioClient(
        address=ray.worker.global_worker.gcs_client.address
    )
    plugin = WorkingDirPlugin(tmpdir, gcs_aio_client)
    files_dir = os.path.join(str(tmpdir), "working_dir_files")

    def first_entry_ctime():
        # os.listdir returns entries in arbitrary order; sort so that both
        # measurements look at the same entry.
        entries = sorted(os.listdir(files_dir))
        assert entries, f"{files_dir} is empty after plugin.create()"
        # NOTE: st_ctime is the inode-change time on Unix and the creation
        # time on Windows; either way it changes when the entry is re-created.
        return os.stat(os.path.join(files_dir, entries[0])).st_ctime

    await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext())
    ctime_before = first_entry_ctime()

    # Presumably needed so that a re-created entry gets a distinguishable
    # timestamp even on filesystems with coarse timestamp granularity.
    time.sleep(1)

    await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext())
    ctime_after = first_entry_ctime()

    assert ctime_before != ctime_after, (
        "working_dir entry was not re-created on the second create() call "
        "for the same URI"
    )


@pytest.mark.asyncio
async def test_create_delete_size_equal(tmpdir, ray_start_regular):
"""Tests that `create` and `delete_uri` return the same size for a URI."""
Expand Down