From 4ebf74f654f60bb247b0f65e5885d102d1b40d6a Mon Sep 17 00:00:00 2001 From: "Joshua A. Anderson" Date: Sun, 11 Feb 2024 20:31:39 -0500 Subject: [PATCH] Validate cached_statepoint when read from disk. This has the added benefit of validating all statepoints that are added to the cache. I needed to add a validate argument to update_cache because one of the unit tests relies on adding invalid statepoints to the cache. --- signac/project.py | 42 ++++++++++++++++++++++++++++-------------- tests/test_project.py | 2 +- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/signac/project.py b/signac/project.py index cd3da31d1..2b739bb81 100644 --- a/signac/project.py +++ b/signac/project.py @@ -856,13 +856,20 @@ def _register(self, id_, statepoint): """ self._sp_cache[id_] = statepoint - def _get_statepoint_from_workspace(self, job_id): + def _get_statepoint_from_workspace(self, job_id, validate=True): """Attempt to read the state point from the workspace. Parameters ---------- job_id : str Identifier of the job. + validate : bool + When True, validate that any statepoint read from disk matches the job_id. + + Raises + ------ + :class:`signac.errors.JobsCorruptedError` + When one or more jobs are identified as corrupted. """ # Performance-critical path. 
We can rely on the project workspace, job @@ -871,7 +878,11 @@ def _get_statepoint_from_workspace(self, job_id): fn_statepoint = os.sep.join((self.workspace, job_id, Job.FN_STATE_POINT)) try: with open(fn_statepoint, "rb") as statepoint_file: - return json.loads(statepoint_file.read().decode()) + statepoint = json.loads(statepoint_file.read().decode()) + if validate and calc_id(statepoint) != job_id: + raise JobsCorruptedError([job_id]) + + return statepoint except (OSError, ValueError) as error: if os.path.isdir(os.sep.join((self.workspace, job_id))): logger.error( @@ -882,7 +893,7 @@ def _get_statepoint_from_workspace(self, job_id): raise JobsCorruptedError([job_id]) raise KeyError(job_id) - def _get_statepoint(self, job_id): + def _get_statepoint(self, job_id, validate=True): """Get the state point associated with a job id. The state point is retrieved from the internal cache, from @@ -892,6 +903,9 @@ def _get_statepoint(self, job_id): ---------- job_id : str A job id to get the state point for. + validate : bool + When True, validate that any statepoint read from disk matches the job_id. + Returns ------- @@ -926,7 +940,7 @@ def _get_statepoint(self, job_id): "updating the cache by running `signac update-cache`." 
) self._sp_cache_warned = True - statepoint = self._get_statepoint_from_workspace(job_id) + statepoint = self._get_statepoint_from_workspace(job_id, validate) # Update the project's state point cache from this cache miss self._sp_cache[job_id] = statepoint return statepoint @@ -1258,11 +1272,7 @@ def check(self): logger.info("Checking workspace for corruption...") for job_id in self._find_job_ids(): try: - statepoint = self._get_statepoint(job_id) - if calc_id(statepoint) != job_id: - corrupted.append(job_id) - else: - self.open_job(statepoint).init() + self._get_statepoint_from_workspace(job_id) except JobsCorruptedError as error: corrupted.extend(error.job_ids) if corrupted: @@ -1298,7 +1308,7 @@ def repair(self, job_ids=None): for job_id in job_ids: try: # First, check if we can look up the state point. - statepoint = self._get_statepoint(job_id) + statepoint = self._get_statepoint(job_id, validate=False) # Check if state point and id correspond. correct_id = calc_id(statepoint) if correct_id != job_id: @@ -1379,7 +1389,7 @@ def _build_index(self, include_job_document=False): raise yield job_id, doc - def _update_in_memory_cache(self): + def _update_in_memory_cache(self, validate=False): """Update the in-memory state point cache to reflect the workspace.""" logger.debug("Updating in-memory cache...") start = time.time() @@ -1392,7 +1402,7 @@ def _update_in_memory_cache(self): del self._sp_cache[id_] def _add(id_): - self._sp_cache[id_] = self._get_statepoint_from_workspace(id_) + self._sp_cache[id_] = self._get_statepoint_from_workspace(id_, validate) to_add_chunks = _split_and_print_progress( iterable=list(to_add), @@ -1419,7 +1429,7 @@ def _remove_persistent_cache_file(self): if error.errno != errno.ENOENT: raise error - def update_cache(self): + def update_cache(self, validate=True): """Update the persistent state point cache. 
This function updates a persistent state point cache, which @@ -1428,12 +1438,16 @@ def update_cache(self): to be significantly faster after calling this function, especially for large data spaces. + Parameters + ---------- + validate : bool + When True, validate that any statepoint read from disk matches the job_id. """ logger.info("Update cache...") start = time.time() cache = self._read_cache() cached_ids = set(self._sp_cache) - self._update_in_memory_cache() + self._update_in_memory_cache(validate) if cache is None or set(cache) != cached_ids: fn_cache = self.fn(self.FN_CACHE) fn_cache_tmp = fn_cache + "~" diff --git a/tests/test_project.py b/tests/test_project.py index da546ede0..19f13418e 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -2102,7 +2102,7 @@ class UpdateCacheAfterInitJob(signac.job.Job): def init(self, *args, **kwargs): job = super().init(*args, **kwargs) - self._project.update_cache() + self._project.update_cache(validate=False) return job