diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py
index 75e8b06..40481ef 100644
--- a/tests/test_snapshot.py
+++ b/tests/test_snapshot.py
@@ -10,6 +10,7 @@
 import copy
 from pathlib import Path
 from typing import Any, Dict, List
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -226,3 +227,23 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     snapshot = Snapshot.take(app_state={"state": src}, path=str(tmp_path))
     snapshot.restore(app_state={"state": dst})
     assert check_state_dict_eq(src.state_dict(), dst.state_dict())
+
+
+@pytest.mark.usefixtures("toggle_batching")
+def test_snapshot_metadata_error(tmp_path: Path) -> None:
+    """A failing metadata read must surface as a descriptive RuntimeError."""
+    mock_storage_plugin = MagicMock()
+    mock_event_loop = MagicMock()
+    mock_storage_plugin.sync_read.side_effect = Exception(
+        "Mock error reading from storage"
+    )
+    with pytest.raises(
+        expected_exception=RuntimeError,
+        # ``match`` is applied as a regex, so the literal dots are escaped.
+        match=(
+            r"Failed to read \.snapshot_metadata\. "
+            r"Ensure path to snapshot is correct, "
+            r"otherwise snapshot is likely incomplete or corrupted\."
+        ),
+    ):
+        Snapshot._read_snapshot_metadata(mock_storage_plugin, mock_event_loop)
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index e675daf..3e37c85 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -840,7 +840,22 @@ def _read_snapshot_metadata(
         storage: StoragePlugin, event_loop: asyncio.AbstractEventLoop
     ) -> SnapshotMetadata:
+        """
+        Read the snapshot metadata file via ``storage`` and parse it from YAML.
+
+        Raises:
+            RuntimeError: If reading the metadata file from storage fails.
+        """
         read_io = ReadIO(path=SNAPSHOT_METADATA_FNAME)
-        storage.sync_read(read_io=read_io, event_loop=event_loop)
+        try:
+            storage.sync_read(read_io=read_io, event_loop=event_loop)
+        except Exception as e:
+            # Re-raise with a message that names the file; the original
+            # error stays attached as the cause.
+            raise RuntimeError(
+                f"Failed to read {SNAPSHOT_METADATA_FNAME}. "
+                "Ensure path to snapshot is correct, "
+                "otherwise snapshot is likely incomplete or corrupted."
+            ) from e
         yaml_str = read_io.buf.getvalue().decode("utf-8")
         return SnapshotMetadata.from_yaml(yaml_str)
 