diff --git a/olmo/train.py b/olmo/train.py index b55c7c284..b491e4ecf 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -672,7 +672,9 @@ def restore_checkpoint( ) elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None: try: - legacy_mode = resource_path(load_path, f"rank{get_global_rank()}.pt").is_file() + legacy_mode = resource_path( + load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache + ).is_file() except FileNotFoundError: legacy_mode = False if legacy_mode: