diff --git a/cacholote/clean.py b/cacholote/clean.py
index 2e11bec..b812c4f 100644
--- a/cacholote/clean.py
+++ b/cacholote/clean.py
@@ -208,7 +208,7 @@ def delete_cache_files(
     elif method == "LFU":
         sorters.extend([database.CacheEntry.counter, database.CacheEntry.timestamp])
     else:
-        raise ValueError("`method` must be 'LRU' or 'LFU'.")
+        raise ValueError(f"{method=} is invalid. Choose either 'LRU' or 'LFU'.")
     sorters.append(database.CacheEntry.expiration)
 
     # Clean database files
@@ -277,10 +277,30 @@ def clean_cache_files(
     )
 
 
-def clean_invalid_cache_entries() -> None:
-    with config.get().sessionmaker() as session:
-        for cache_entry in session.scalars(sa.select(database.CacheEntry)):
-            try:
-                decode.loads(cache_entry._result_as_string)
-            except decode.DecodeError:
+def clean_invalid_cache_entries(check_expiration=True, try_decode=True) -> None:
+    """Clean invalid cache entries.
+
+    Parameters
+    ----------
+    check_expiration: bool
+        Whether or not to delete expired entries
+    try_decode: bool
+        Whether or not to delete entries that raise DecodeError (this can be slow!)
+    """
+    filters = []
+    if check_expiration:
+        filters.append(database.CacheEntry.expiration <= utils.utcnow())
+    if filters:
+        with config.get().sessionmaker() as session:
+            for cache_entry in session.scalars(
+                sa.select(database.CacheEntry).filter(*filters)
+            ):
                 _delete_cache_entry(session, cache_entry)
+
+    if try_decode:
+        with config.get().sessionmaker() as session:
+            for cache_entry in session.scalars(sa.select(database.CacheEntry)):
+                try:
+                    decode.loads(cache_entry._result_as_string)
+                except decode.DecodeError:
+                    _delete_cache_entry(session, cache_entry)
diff --git a/tests/test_60_clean.py b/tests/test_60_clean.py
index 7a0cc39..569e61d 100644
--- a/tests/test_60_clean.py
+++ b/tests/test_60_clean.py
@@ -1,5 +1,7 @@
 import contextlib
+import datetime
 import pathlib
+import time
 from typing import Any, Literal, Optional, Sequence
 
 import fsspec
@@ -197,25 +199,45 @@ def test_delete_cache_entry_and_files(tmpdir: pathlib.Path) -> None:
     assert len(fs.ls(dirname)) == 1
 
 
-def test_clean_invalid_cache_entries(tmpdir: pathlib.Path) -> None:
+@pytest.mark.parametrize("check_expiration", [True, False])
+@pytest.mark.parametrize("try_decode", [True, False])
+def test_clean_invalid_cache_entries(
+    tmpdir: pathlib.Path, check_expiration: bool, try_decode: bool
+) -> None:
     fs, dirname = utils.get_cache_files_fs_dirname()
     con = config.get().engine.raw_connection()
     cur = con.cursor()
 
-    # Create and cache 2 files
-    for i in range(2):
-        filepath = tmpdir / f"{i}.txt"
-        fsspec.filesystem("file").touch(filepath)
-        open_url(filepath)
+    # Valid cache file
+    fsspec.filesystem("file").touch(tmpdir / "valid.txt")
+    valid = open_url(tmpdir / "valid.txt").path
+
+    # Corrupted cache file
+    fsspec.filesystem("file").touch(tmpdir / "corrupted.txt")
+    corrupted = open_url(tmpdir / "corrupted.txt").path
+    fs.touch(corrupted)
 
-    # Invalidate one file
-    valid, invalid = fs.ls(dirname)
-    fs.touch(invalid)
+    # Expired cache file
+    fsspec.filesystem("file").touch(tmpdir / "expired.txt")
+    dt = datetime.timedelta(seconds=0.1)
+    expiration = datetime.datetime.now(tz=datetime.timezone.utc) + dt
+    with config.set(expiration=expiration):
+        expired = open_url(tmpdir / "expired.txt").path
+    time.sleep(0.1)
 
-    clean.clean_invalid_cache_entries()
+    # Clean
+    clean.clean_invalid_cache_entries(
+        check_expiration=check_expiration, try_decode=try_decode
+    )
 
     cur.execute("SELECT * FROM cache_entries", ())
-    assert len(cur.fetchall()) == 1
-    assert fs.ls(dirname) == [valid]
+    assert len(cur.fetchall()) == 3 - check_expiration - try_decode
+    assert valid in fs.ls(dirname)
+    assert (
+        corrupted not in fs.ls(dirname) if try_decode else corrupted in fs.ls(dirname)
+    )
+    assert (
+        expired not in fs.ls(dirname) if check_expiration else expired in fs.ls(dirname)
+    )
 
 
 def test_cleaner_logging(