diff --git a/cacholote/config.py b/cacholote/config.py
index f677803..22046eb 100644
--- a/cacholote/config.py
+++ b/cacholote/config.py
@@ -60,6 +60,7 @@ class Settings(pydantic.BaseSettings):
     logger: Union[
         structlog.BoundLogger, structlog._config.BoundLoggerLazyProxy
     ] = _DEFAULT_LOGGER
+    lock_timeout: Optional[float] = None
 
     @pydantic.validator("create_engine_kwargs", allow_reuse=True)
     def validate_create_engine_kwargs(
diff --git a/cacholote/extra_encoders.py b/cacholote/extra_encoders.py
index a82cf36..8514843 100644
--- a/cacholote/extra_encoders.py
+++ b/cacholote/extra_encoders.py
@@ -216,7 +216,9 @@ def _maybe_store_xr_dataset(
     obj: "xr.Dataset", fs: fsspec.AbstractFileSystem, urlpath: str, filetype: str
 ) -> None:
     if filetype == "application/vnd+zarr":
-        with utils._Locker(fs, urlpath) as file_exists:
+        with utils.FileLock(
+            fs, urlpath, timeout=config.get().lock_timeout
+        ) as file_exists:
             if not file_exists:
                 # Write directly on any filesystem
                 mapper = fs.get_mapper(urlpath)
@@ -285,7 +287,9 @@ def _maybe_store_file_object(
 ) -> None:
     if io_delete_original is None:
         io_delete_original = config.get().io_delete_original
-    with utils._Locker(fs_out, urlpath_out) as file_exists:
+    with utils.FileLock(
+        fs_out, urlpath_out, timeout=config.get().lock_timeout
+    ) as file_exists:
         if not file_exists:
             kwargs = {}
             content_type = _guess_type(fs_in, urlpath_in)
@@ -321,7 +325,9 @@ def _maybe_store_io_object(
     fs_out: fsspec.AbstractFileSystem,
     urlpath_out: str,
 ) -> None:
-    with utils._Locker(fs_out, urlpath_out) as file_exists:
+    with utils.FileLock(
+        fs_out, urlpath_out, timeout=config.get().lock_timeout
+    ) as file_exists:
         if not file_exists:
             f_out = fs_out.open(urlpath_out, "wb")
             with _logging_timer("upload", urlpath=fs_out.unstrip_protocol(urlpath_out)):
diff --git a/cacholote/utils.py b/cacholote/utils.py
index a143894..34653a3 100644
--- a/cacholote/utils.py
+++ b/cacholote/utils.py
@@ -15,7 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 import datetime
+import functools
 import hashlib
 import io
 import time
@@ -67,21 +69,15 @@ def copy_buffered_file(
         f_out.write(data if isinstance(data, bytes) else data.encode())
 
 
-class _Locker:
-    def __init__(
-        self,
-        fs: fsspec.AbstractFileSystem,
-        urlpath: str,
-        lock_validity_period: Optional[float] = None,
-    ) -> None:
-        self.fs = fs
-        self.urlpath = urlpath
-        self.lockfile = urlpath + ".lock"
-        self.lock_validity_period = lock_validity_period
+@dataclasses.dataclass
+class FileLock:
+    fs: fsspec.AbstractFileSystem  # fsspec file system
+    urlpath: str  # file to lock
+    timeout: Optional[float]  # lock timeout in seconds
 
-    @property
-    def file_exists(self) -> bool:
-        return bool(self.fs.exists(self.urlpath))
+    @functools.cached_property
+    def lockfile(self) -> str:
+        return self.urlpath + ".lock"
 
     def acquire(self) -> None:
         self.fs.touch(self.lockfile)
@@ -92,32 +88,24 @@ def release(self) -> None:
 
     @property
     def is_locked(self) -> bool:
-        if not self.fs.exists(self.lockfile):
-            return False
-
-        delta = datetime.datetime.now() - self.fs.modified(self.lockfile)
-        if self.lock_validity_period is None or delta < datetime.timedelta(
-            seconds=self.lock_validity_period
-        ):
-            return True
-
-        return False
+        return bool(self.fs.exists(self.lockfile))
 
     def wait_until_released(self) -> None:
         warned = False
+        message = f"{self.urlpath!r} is locked: {self.lockfile!r}"
+        start = time.perf_counter()
         while self.is_locked:
+            if self.timeout is not None and time.perf_counter() - start > self.timeout:
+                raise TimeoutError(message)
             if not warned:
-                warnings.warn(
-                    f"can NOT proceed until file is released: {self.lockfile!r}.",
-                    UserWarning,
-                )
+                warnings.warn(message, UserWarning)
                 warned = True
-            time.sleep(1)
+            time.sleep(min(1, self.timeout or 1))
 
     def __enter__(self) -> bool:
         self.wait_until_released()
         self.acquire()
-        return self.file_exists
+        return bool(self.fs.exists(self.urlpath))
 
     def __exit__(
         self,
diff --git a/tests/test_50_io_encoder.py b/tests/test_50_io_encoder.py
index 995898a..075f209 100644
--- a/tests/test_50_io_encoder.py
+++ b/tests/test_50_io_encoder.py
@@ -1,9 +1,10 @@
+import contextlib
 import hashlib
 import importlib
 import io
 import pathlib
-import threading
-from typing import Any, Dict, Tuple, Union
+import subprocess
+from typing import Any, Dict, Optional, Tuple, Union
 
 import fsspec
 import pytest
@@ -146,30 +147,32 @@ def test_io_corrupted_files(
     assert fs.exists(f"{dirname}/{cached_basename}")
 
 
-def test_io_locker_warning(tmpdir: pathlib.Path) -> None:
+@pytest.mark.parametrize(
+    "lock_timeout, raises_or_warns",
+    (
+        [None, pytest.warns(UserWarning, match="is locked")],
+        [0, pytest.raises(TimeoutError, match="is locked")],
+    ),
+)
+def test_io_locker(
+    tmpdir: pathlib.Path,
+    lock_timeout: Optional[float],
+    raises_or_warns: contextlib.nullcontext,  # type: ignore[type-arg]
+) -> None:
+    config.set(lock_timeout=lock_timeout, raise_all_encoding_errors=True)
     # Create tmpfile
     tmpfile = tmpdir / "test.txt"
     fsspec.filesystem("file").touch(tmpfile)
 
     # Acquire lock
     fs, dirname = utils.get_cache_files_fs_dirname()
-    file_hash = f"{fsspec.filesystem('file').checksum(tmpfile):x}"
-    lock = f"{dirname}/{file_hash}.txt.lock"
-    fs.touch(lock)
-
-    def release_lock(fs: fsspec.AbstractFileSystem, lock: str) -> None:
-        fs.rm(lock)
-
-    # Threading
-    t1 = threading.Timer(0, cached_open, args=(tmpfile,))
-    t2 = threading.Timer(0.1, release_lock, args=(fs, lock))
-    with pytest.warns(
-        UserWarning, match=f"can NOT proceed until file is released: {lock!r}."
-    ):
-        t1.start()
-        t2.start()
-        t1.join()
-        t2.join()
+    file_path = f"{dirname}/{fsspec.filesystem('file').checksum(tmpfile):x}.txt"
+    fs.touch(f"{file_path}.lock")
+
+    process = subprocess.Popen(f"sleep 0.1; rm {file_path}.lock", shell=True)
+    with raises_or_warns:
+        cached_open(tmpfile)
+    assert process.wait() == 0
 
 
 @pytest.mark.parametrize("set_cache", ["cads"], indirect=True)
diff --git a/tests/test_60_clean.py b/tests/test_60_clean.py
index 569e61d..6b9d0a3 100644
--- a/tests/test_60_clean.py
+++ b/tests/test_60_clean.py
@@ -204,9 +204,9 @@ def test_delete_cache_entry_and_files(tmpdir: pathlib.Path) -> None:
 def test_clean_invalid_cache_entries(
     tmpdir: pathlib.Path, check_expiration: bool, try_decode: bool
 ) -> None:
-    fs, dirname = utils.get_cache_files_fs_dirname()
     con = config.get().engine.raw_connection()
     cur = con.cursor()
+    fs, dirname = utils.get_cache_files_fs_dirname()
 
     # Valid cache file
     fsspec.filesystem("file").touch(tmpdir / "valid.txt")
@@ -225,12 +225,13 @@ def test_clean_invalid_cache_entries(
     expired = open_url(tmpdir / "expired.txt").path
     time.sleep(0.1)
 
-    # Clean
+    # Clean and check
     clean.clean_invalid_cache_entries(
         check_expiration=check_expiration, try_decode=try_decode
     )
     cur.execute("SELECT * FROM cache_entries", ())
-    assert len(cur.fetchall()) == 3 - check_expiration - try_decode
+    nrows = len(cur.fetchall())
+    assert nrows == 3 - check_expiration - try_decode
     assert valid in fs.ls(dirname)
     assert (
         corrupted not in fs.ls(dirname) if try_decode else corrupted in fs.ls(dirname)
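
For reviewers, a minimal sketch of the new timeout behaviour, using only `FileLock` and the `lock_timeout` setting introduced above (the `/tmp` path is illustrative, not part of the patch):

```python
import fsspec

from cacholote import utils

fs = fsspec.filesystem("file")
urlpath = "/tmp/example.txt"
fs.touch(urlpath + ".lock")  # simulate another process holding the lock

# With timeout=None (the Settings.lock_timeout default) FileLock keeps the
# old behaviour: warn once, then poll until the lockfile disappears.
# A finite timeout turns a lock that is never released into a hard error:
try:
    with utils.FileLock(fs, urlpath, timeout=0) as file_exists:
        print(file_exists)  # not reached: the lock is never released
except TimeoutError as exc:
    print(exc)  # '/tmp/example.txt' is locked: '/tmp/example.txt.lock'
finally:
    fs.rm(urlpath + ".lock")
```

The encoders read the timeout from `config.get().lock_timeout`, so callers opt in with `config.set(lock_timeout=...)` rather than passing a value around, as the updated `test_io_locker` does.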