From 3aa605e78a68595f3a3a6a25efa12ca907072bc5 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 26 Apr 2024 00:45:03 -0700 Subject: [PATCH] Move `get_x_checkpoint` functions to `utils/checkpoint.py` Reviewed By: JKSenthil Differential Revision: D56450720 --- .../callbacks/test_base_checkpointer.py | 2 +- .../callbacks/test_checkpoint_utils.py | 395 ----------------- tests/utils/test_checkpoint.py | 408 +++++++++++++++++- .../framework/callbacks/_checkpoint_utils.py | 254 +---------- .../framework/callbacks/base_checkpointer.py | 16 +- torchtnt/utils/__init__.py | 11 +- torchtnt/utils/checkpoint.py | 249 ++++++++++- 7 files changed, 675 insertions(+), 660 deletions(-) diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index fb8fea71d7..105d6052ea 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -249,7 +249,7 @@ def test_restore_from_latest_empty_dir(self) -> None: self.assertEqual( log.output, [ - f"WARNING:torchtnt.framework.callbacks._checkpoint_utils:Input dirpath doesn't contain any subdirectories: {temp_dir}" + f"WARNING:torchtnt.utils.checkpoint:Input dirpath doesn't contain any subdirectories: {temp_dir}" ], ) self.assertFalse(restored) diff --git a/tests/framework/callbacks/test_checkpoint_utils.py b/tests/framework/callbacks/test_checkpoint_utils.py index ac1a019a98..f917fcd942 100644 --- a/tests/framework/callbacks/test_checkpoint_utils.py +++ b/tests/framework/callbacks/test_checkpoint_utils.py @@ -6,411 +6,16 @@ # pyre-strict -import os -import shutil -import tempfile import unittest -import torch -import torch.distributed as dist -from torch import nn -from torchsnapshot import Snapshot -from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state from torchtnt.framework.callbacks._checkpoint_utils import ( - _delete_checkpoint, - _metadata_exists, _prepare_app_state_for_checkpoint, - _retrieve_checkpoint_dirpaths, - _sort_by_metric_value, - _sort_by_recency, - get_best_checkpoint_path, - get_checkpoint_dirpaths, - get_latest_checkpoint_path, - rank_zero_read_and_broadcast, ) -from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process -from torchtnt.utils.env import init_from_env -from torchtnt.utils.fsspec import get_filesystem -from torchtnt.utils.test_utils import skip_if_not_distributed - -METADATA_FNAME: str = ".metadata" class CheckpointUtilsTest(unittest.TestCase): - @staticmethod - def _create_snapshot_metadata(output_dir: str) -> None: - path = os.path.join(output_dir, METADATA_FNAME) - with open(path, "w"): - pass - - def test_latest_checkpoint_path(self) -> None: - with tempfile.TemporaryDirectory() as temp_dir: - self.assertIsNone(get_latest_checkpoint_path(temp_dir)) - - with tempfile.TemporaryDirectory() as temp_dir: - latest_path = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(latest_path) - self.assertEqual( - get_latest_checkpoint_path(temp_dir), - latest_path, - ) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), - None, - ) - self._create_snapshot_metadata(latest_path) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), - latest_path, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - path_1 = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path_1) - self._create_snapshot_metadata(path_1) - path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002") - os.mkdir(path_2) - self._create_snapshot_metadata(path_2) - - # Missing metadata file - path_3 = os.path.join(temp_dir, "epoch_1_step_100") - os.mkdir(path_3) - - # Ill-formatted name - path_4 = os.path.join(temp_dir, "epoch_700") - os.mkdir(path_4) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 - ) - - @skip_if_not_distributed - def test_latest_checkpoint_path_distributed(self) -> None: - spawn_multi_process( - 2, - "gloo", - self._latest_checkpoint_path_distributed, - ) - - @staticmethod - def _latest_checkpoint_path_distributed() -> None: - tc = unittest.TestCase() - is_rank0 = get_global_rank() == 0 - - if is_rank0: - temp_dir = tempfile.mkdtemp() - else: - temp_dir = "" - tc.assertIsNone(get_latest_checkpoint_path(temp_dir)) - if is_rank0: - shutil.rmtree(temp_dir) # delete temp directory - - if is_rank0: - temp_dir = tempfile.mkdtemp() - path_1 = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path_1) - CheckpointUtilsTest._create_snapshot_metadata(path_1) - path_2 = os.path.join(temp_dir, "epoch_0_step_100") - os.mkdir(path_2) - CheckpointUtilsTest._create_snapshot_metadata(path_2) - - # Missing metadata file - path_3 = os.path.join(temp_dir, "epoch_1_step_100") - os.mkdir(path_3) - - # Ill-formatted name - path_4 = os.path.join(temp_dir, "epoch_700") - os.mkdir(path_4) - else: - temp_dir = "" - path_2 = "" - - pg = PGWrapper(dist.group.WORLD) - path_container = [path_2] if is_rank0 else [None] - pg.broadcast_object_list(path_container, 0) - expected_path = path_container[0] - tc.assertIsNotNone(expected_path) - tc.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path - ) - - if is_rank0: - shutil.rmtree(temp_dir) # delete temp directory - - def test_best_checkpoint_path(self) -> None: - with tempfile.TemporaryDirectory() as temp_dir: - self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) - - # no checkpoint w/ metric value - path = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path) - self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) - - with tempfile.TemporaryDirectory() as temp_dir: - best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01") - os.mkdir(best_path) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min"), - best_path, - ) - self.assertIsNone( - get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), - None, - ) - self._create_snapshot_metadata(best_path) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), - best_path, - ) - - # handle negative values - best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") - os.mkdir(best_path_2) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min"), - best_path_2, - ) - - # handle "max" mode correctly - best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1") - os.mkdir(best_path_3) - self.assertEqual( - get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), - best_path_3, - ) - - # handle different metric correctly - best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2") - os.mkdir(best_path_4) - self.assertEqual( - get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), - best_path_3, - ) - self.assertEqual( - get_best_checkpoint_path( - temp_dir, metric_name="train_loss", mode="max" - ), - best_path_4, - ) - - def test_retrieve_checkpoint_dirpaths(self) -> None: - """ - Tests retrieving checkpoint directories from a given root directory - """ - with tempfile.TemporaryDirectory() as temp_dir: - paths = [ - "epoch_0_step_10", - "epoch_1_step_10", - "epoch_2_step_10", - "epoch_0_step_5", - "epoch_0_step_6", - "epoch_0_step_3", - ] - for path in paths[:-1]: - os.mkdir(os.path.join(temp_dir, path)) - # make last path a file instead of a directory - with open(os.path.join(temp_dir, paths[-1]), "w"): - pass - - # compares set equality since order of returned dirpaths is not guaranteed - # in _retrieve_checkpoint_dirpaths - self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), - {os.path.join(temp_dir, path) for path in paths[:-1]}, - ) - self.assertEqual( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), - [], - ) - - # check metadata file is correct filtered for - # by creating metadata for 3rd path in list - with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"): - pass - - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata") - ), - {os.path.join(temp_dir, paths[2])}, - ) - - def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: - """ - Tests retrieving checkpoint (w/ metrics) directories from a given root directory - """ - with tempfile.TemporaryDirectory() as temp_dir: - paths = [ - "epoch_0_step_10_val_loss=10", - "epoch_1_step_10_val_loss=5", - "epoch_2_step_10", - "epoch_0_step_5", - "epoch_0_step_6_train_loss=13", - ] - for path in paths: - os.mkdir(os.path.join(temp_dir, path)) - # make last path a file instead of a directory - with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"): - pass - - # compares set equality since order of returned dirpaths is not guaranteed - # in _retrieve_checkpoint_dirpaths - self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), - {os.path.join(temp_dir, path) for path in paths}, - ) - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( - temp_dir, metadata_fname=None, metric_name="val_loss" - ) - ), - { - os.path.join(temp_dir, path) for path in paths[:2] - }, # since last path is a file - ) - self.assertEqual( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), - [], - ) - - # check metadata file is correct filtered for - # by creating metadata for 3rd path in list - with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"): - pass - - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( - temp_dir, metadata_fname=".metadata", metric_name="val_loss" - ) - ), - {os.path.join(temp_dir, paths[1])}, - ) - - @skip_if_not_distributed - def test_distributed_get_checkpoint_dirpaths(self) -> None: - spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths) - - @staticmethod - def _distributed_get_checkpoint_dirpaths() -> None: - """ - Tests that existing checkpoint directories are read and - properly registered on all ranks - """ - - @rank_zero_read_and_broadcast - def create_tmp_dir() -> str: - return tempfile.mkdtemp() - - init_from_env() - - temp_dir = create_tmp_dir() - try: - path1 = os.path.join(temp_dir, "epoch_0_step_10") - path2 = os.path.join(temp_dir, "epoch_1_step_20") - if get_global_rank() == 0: - os.mkdir(path1) - os.mkdir(path2) - torch.distributed.barrier() - - ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir) - tc = unittest.TestCase() - tc.assertEqual(set(ckpt_dirpaths), {path1, path2}) - - tc.assertEqual( - get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), [] - ) - finally: - if get_global_rank() == 0: - shutil.rmtree(temp_dir) # delete temp directory - - def test_get_checkpoint_dirpaths(self) -> None: - """ - Tests that `get_checkpoint_dirpaths` returns - the sorted checkpoint directories correctly - """ - with tempfile.TemporaryDirectory() as temp_dir: - path1 = os.path.join(temp_dir, "epoch_1_step_20") - path2 = os.path.join(temp_dir, "epoch_4_step_130") - path3 = os.path.join(temp_dir, "epoch_0_step_10") - os.mkdir(path1) - os.mkdir(path2) - os.mkdir(path3) - - self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir)), - {path1, path2, path3}, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01") - path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2") - path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12") - os.mkdir(path1) - os.mkdir(path2) - os.mkdir(path3) - - self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")), - {path1, path2, path3}, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - self.assertEqual( - get_checkpoint_dirpaths(temp_dir), - [], - ) - - def test_checkpoint_sorting_utils(self) -> None: - """ - Tests the sort utilities - """ - paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"] - self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]]) - - paths = [ - "epoch_1_step_20_val_loss=0.09", - "epoch_4_step_130_val_loss=29", - "epoch_0_step_10_val_loss=10", - ] - self.assertEqual( - _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]] - ) - self.assertEqual( - _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]] - ) - - def test_delete_checkpoint(self) -> None: - """ - Tests removing checkpoint directories - """ - app_state = {"module": nn.Linear(2, 2)} - with tempfile.TemporaryDirectory() as temp_dir: - dirpath = os.path.join(temp_dir, "checkpoint") - Snapshot.take(dirpath, app_state=app_state) - self.assertTrue(os.path.exists(dirpath)) - # check that error is thrown if .snapshot_metadata is not found in the directory when deleting - os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - with self.assertRaisesRegex( - RuntimeError, f"{temp_dir} does not contain .snapshot_metadata" - ): - _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME) - _delete_checkpoint(dirpath) - self.assertFalse(os.path.exists(dirpath)) - - def test_metadata_exists(self) -> None: - app_state = {"module": nn.Linear(2, 2)} - with tempfile.TemporaryDirectory() as temp_dir: - dirpath = os.path.join(temp_dir, "checkpoint") - Snapshot.take(dirpath, app_state=app_state) - - fs = get_filesystem(dirpath) - self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) - - os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) def test_get_app_state(self) -> None: my_unit = DummyTrainUnit(input_dim=2) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index 91778df115..2257e683c2 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -5,9 +5,40 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +import os +import shutil +import tempfile import unittest -from torchtnt.utils.checkpoint import CheckpointPath, MetricData +import torch + +import torch.distributed as dist +from torch import nn +from torchsnapshot import Snapshot +from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME +from torchtnt.utils import get_global_rank, init_from_env + +from torchtnt.utils.checkpoint import ( + _delete_checkpoint, + _metadata_exists, + _retrieve_checkpoint_dirpaths, + _sort_by_metric_value, + _sort_by_recency, + CheckpointPath, + get_best_checkpoint_path, + get_checkpoint_dirpaths, + get_latest_checkpoint_path, + MetricData, +) +from torchtnt.utils.distributed import ( + PGWrapper, + rank_zero_read_and_broadcast, + spawn_multi_process, +) +from torchtnt.utils.fsspec import get_filesystem +from torchtnt.utils.test_utils import skip_if_not_distributed + +METADATA_FNAME: str = ".metadata" class CheckpointPathTest(unittest.TestCase): @@ -141,3 +172,378 @@ def test_compare_by_optimality(self) -> None: self.assertFalse(smaller.more_optimal_than(larger, mode="max")) self.assertTrue(smaller.more_optimal_than(larger, mode="min")) self.assertFalse(larger.more_optimal_than(smaller, mode="min")) + + +class CheckpointUtilsTest(unittest.TestCase): + @staticmethod + def _create_snapshot_metadata(output_dir: str) -> None: + path = os.path.join(output_dir, METADATA_FNAME) + with open(path, "w"): + pass + + def test_latest_checkpoint_path(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + self.assertIsNone(get_latest_checkpoint_path(temp_dir)) + + with tempfile.TemporaryDirectory() as temp_dir: + latest_path = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(latest_path) + self.assertEqual( + get_latest_checkpoint_path(temp_dir), + latest_path, + ) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + None, + ) + self._create_snapshot_metadata(latest_path) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + latest_path, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + path_1 = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path_1) + self._create_snapshot_metadata(path_1) + path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002") + os.mkdir(path_2) + self._create_snapshot_metadata(path_2) + + # Missing metadata file + path_3 = os.path.join(temp_dir, "epoch_1_step_100") + os.mkdir(path_3) + + # Ill-formatted name + path_4 = os.path.join(temp_dir, "epoch_700") + os.mkdir(path_4) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 + ) + + @skip_if_not_distributed + def test_latest_checkpoint_path_distributed(self) -> None: + spawn_multi_process( + 2, + "gloo", + self._latest_checkpoint_path_distributed, + ) + + @staticmethod + def _latest_checkpoint_path_distributed() -> None: + tc = unittest.TestCase() + is_rank0 = get_global_rank() == 0 + + if is_rank0: + temp_dir = tempfile.mkdtemp() + else: + temp_dir = "" + tc.assertIsNone(get_latest_checkpoint_path(temp_dir)) + if is_rank0: + shutil.rmtree(temp_dir) # delete temp directory + + if is_rank0: + temp_dir = tempfile.mkdtemp() + path_1 = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path_1) + CheckpointUtilsTest._create_snapshot_metadata(path_1) + path_2 = os.path.join(temp_dir, "epoch_0_step_100") + os.mkdir(path_2) + CheckpointUtilsTest._create_snapshot_metadata(path_2) + + # Missing metadata file + path_3 = os.path.join(temp_dir, "epoch_1_step_100") + os.mkdir(path_3) + + # Ill-formatted name + path_4 = os.path.join(temp_dir, "epoch_700") + os.mkdir(path_4) + else: + temp_dir = "" + path_2 = "" + + pg = PGWrapper(dist.group.WORLD) + path_container = [path_2] if is_rank0 else [None] + pg.broadcast_object_list(path_container, 0) + expected_path = path_container[0] + tc.assertIsNotNone(expected_path) + tc.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path + ) + + if is_rank0: + shutil.rmtree(temp_dir) # delete temp directory + + def test_best_checkpoint_path(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) + + # no checkpoint w/ metric value + path = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path) + self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) + + with tempfile.TemporaryDirectory() as temp_dir: + best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01") + os.mkdir(best_path) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min"), + best_path, + ) + self.assertIsNone( + get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), + None, + ) + self._create_snapshot_metadata(best_path) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), + best_path, + ) + + # handle negative values + best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") + os.mkdir(best_path_2) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min"), + best_path_2, + ) + + # handle "max" mode correctly + best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1") + os.mkdir(best_path_3) + self.assertEqual( + get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), + best_path_3, + ) + + # handle different metric correctly + best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2") + os.mkdir(best_path_4) + self.assertEqual( + get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), + best_path_3, + ) + self.assertEqual( + get_best_checkpoint_path( + temp_dir, metric_name="train_loss", mode="max" + ), + best_path_4, + ) + + def test_retrieve_checkpoint_dirpaths(self) -> None: + """ + Tests retrieving checkpoint directories from a given root directory + """ + with tempfile.TemporaryDirectory() as temp_dir: + paths = [ + "epoch_0_step_10", + "epoch_1_step_10", + "epoch_2_step_10", + "epoch_0_step_5", + "epoch_0_step_6", + "epoch_0_step_3", + ] + for path in paths[:-1]: + os.mkdir(os.path.join(temp_dir, path)) + # make last path a file instead of a directory + with open(os.path.join(temp_dir, paths[-1]), "w"): + pass + + # compares set equality since order of returned dirpaths is not guaranteed + # in _retrieve_checkpoint_dirpaths + self.assertEqual( + set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + {os.path.join(temp_dir, path) for path in paths[:-1]}, + ) + self.assertEqual( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), + [], + ) + + # check metadata file is correct filtered for + # by creating metadata for 3rd path in list + with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"): + pass + + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata") + ), + {os.path.join(temp_dir, paths[2])}, + ) + + def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: + """ + Tests retrieving checkpoint (w/ metrics) directories from a given root directory + """ + with tempfile.TemporaryDirectory() as temp_dir: + paths = [ + "epoch_0_step_10_val_loss=10", + "epoch_1_step_10_val_loss=5", + "epoch_2_step_10", + "epoch_0_step_5", + "epoch_0_step_6_train_loss=13", + ] + for path in paths: + os.mkdir(os.path.join(temp_dir, path)) + # make last path a file instead of a directory + with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"): + pass + + # compares set equality since order of returned dirpaths is not guaranteed + # in _retrieve_checkpoint_dirpaths + self.assertEqual( + set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + {os.path.join(temp_dir, path) for path in paths}, + ) + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=None, metric_name="val_loss" + ) + ), + { + os.path.join(temp_dir, path) for path in paths[:2] + }, # since last path is a file + ) + self.assertEqual( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), + [], + ) + + # check metadata file is correct filtered for + # by creating metadata for 3rd path in list + with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"): + pass + + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=".metadata", metric_name="val_loss" + ) + ), + {os.path.join(temp_dir, paths[1])}, + ) + + @skip_if_not_distributed + def test_distributed_get_checkpoint_dirpaths(self) -> None: + spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths) + + @staticmethod + def _distributed_get_checkpoint_dirpaths() -> None: + """ + Tests that existing checkpoint directories are read and + properly registered on all ranks + """ + + @rank_zero_read_and_broadcast + def create_tmp_dir() -> str: + return tempfile.mkdtemp() + + init_from_env() + + temp_dir = create_tmp_dir() + try: + path1 = os.path.join(temp_dir, "epoch_0_step_10") + path2 = os.path.join(temp_dir, "epoch_1_step_20") + if get_global_rank() == 0: + os.mkdir(path1) + os.mkdir(path2) + torch.distributed.barrier() + + ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir) + tc = unittest.TestCase() + tc.assertEqual(set(ckpt_dirpaths), {path1, path2}) + + tc.assertEqual( + get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), [] + ) + finally: + if get_global_rank() == 0: + shutil.rmtree(temp_dir) # delete temp directory + + def test_get_checkpoint_dirpaths(self) -> None: + """ + Tests that `get_checkpoint_dirpaths` returns + the sorted checkpoint directories correctly + """ + with tempfile.TemporaryDirectory() as temp_dir: + path1 = os.path.join(temp_dir, "epoch_1_step_20") + path2 = os.path.join(temp_dir, "epoch_4_step_130") + path3 = os.path.join(temp_dir, "epoch_0_step_10") + os.mkdir(path1) + os.mkdir(path2) + os.mkdir(path3) + + self.assertEqual( + set(get_checkpoint_dirpaths(temp_dir)), + {path1, path2, path3}, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01") + path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2") + path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12") + os.mkdir(path1) + os.mkdir(path2) + os.mkdir(path3) + + self.assertEqual( + set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")), + {path1, path2, path3}, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + self.assertEqual( + get_checkpoint_dirpaths(temp_dir), + [], + ) + + def test_checkpoint_sorting_utils(self) -> None: + """ + Tests the sort utilities + """ + paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"] + self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]]) + + paths = [ + "epoch_1_step_20_val_loss=0.09", + "epoch_4_step_130_val_loss=29", + "epoch_0_step_10_val_loss=10", + ] + self.assertEqual( + _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]] + ) + self.assertEqual( + _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]] + ) + + def test_delete_checkpoint(self) -> None: + """ + Tests removing checkpoint directories + """ + app_state = {"module": nn.Linear(2, 2)} + with tempfile.TemporaryDirectory() as temp_dir: + dirpath = os.path.join(temp_dir, "checkpoint") + Snapshot.take(dirpath, app_state=app_state) + self.assertTrue(os.path.exists(dirpath)) + # check that error is thrown if .snapshot_metadata is not found in the directory when deleting + os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) + with self.assertRaisesRegex( + RuntimeError, f"{temp_dir} does not contain .snapshot_metadata" + ): + _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME) + _delete_checkpoint(dirpath) + self.assertFalse(os.path.exists(dirpath)) + + def test_metadata_exists(self) -> None: + app_state = {"module": nn.Linear(2, 2)} + with tempfile.TemporaryDirectory() as temp_dir: + dirpath = os.path.join(temp_dir, "checkpoint") + Snapshot.take(dirpath, app_state=app_state) + + fs = get_filesystem(dirpath) + self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) + + os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) + self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) diff --git a/torchtnt/framework/callbacks/_checkpoint_utils.py b/torchtnt/framework/callbacks/_checkpoint_utils.py index 087c15c15b..674eb18fe5 100644 --- a/torchtnt/framework/callbacks/_checkpoint_utils.py +++ b/torchtnt/framework/callbacks/_checkpoint_utils.py @@ -6,268 +6,16 @@ # pyre-strict -import logging -import os -import re -from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar - -import fsspec +from typing import Any, Dict from pyre_extensions import none_throws -from torch import distributed as dist from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.state import State from torchtnt.framework.unit import AppStateMixin -from torchtnt.utils.distributed import rank_zero_read_and_broadcast -from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.stateful import Stateful -logger: logging.Logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -@rank_zero_read_and_broadcast -def get_latest_checkpoint_path( - dirpath: str, - metadata_fname: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> Optional[str]: - """ - Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory. - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - - Raises: - AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. - """ - - return _latest_checkpoint_path(dirpath, metadata_fname) - - -def _latest_checkpoint_path( - dirpath: str, metadata_fname: Optional[str] -) -> Optional[str]: - candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) - - # Initialize variables to store the largest epoch and step numbers - largest_subdirectory = None - largest_epoch = -1 - largest_step = -1 - - # Iterate through all files and directories in the specified directory - for candidate in candidate_dirpaths: - # Extract the epoch and step numbers from the directory name - dirname = os.path.basename(candidate) - - # dirname will be of the format epoch_N_step_M - # where N is the epoch number and M is the step number as integers - split = dirname.split("_") - if len(split) < 4: - raise AssertionError( - f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})" - ) - - epoch_num, step_num = int(split[1]), int(split[3]) - # Check if the current epoch and step numbers are larger than the largest ones found so far - if epoch_num > largest_epoch: - largest_epoch = epoch_num - largest_step = step_num - largest_subdirectory = dirname - elif largest_epoch == epoch_num and step_num > largest_step: - largest_step = step_num - largest_subdirectory = dirname - - if largest_subdirectory is None: - return None - - # Rejoin with the parent directory path and return the largest subdirectory - return os.path.join(dirpath, none_throws(largest_subdirectory)) - - -@rank_zero_read_and_broadcast -def get_best_checkpoint_path( - dirpath: str, - metric_name: str, - mode: Literal["min", "max"], - metadata_fname: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> Optional[str]: - """ - Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory. - - Args: - dirpath: parent directory where checkpoints are saved. - metric_name: Name of the metric to use to find the best checkpoint. - mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - """ - - dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) - if len(dirpaths) == 0: - # no checkpoints found - return None - - best_checkpoint_path = None - best_metric_value = float("inf") if mode == "min" else float("-inf") - for dirpath in dirpaths: - dirname = os.path.basename(dirpath) - metric_value = float(dirname.split("=")[-1]) - - if mode == "min": - if metric_value < best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath - else: - if metric_value > best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath - - return best_checkpoint_path - - -@rank_zero_read_and_broadcast -def get_checkpoint_dirpaths( - dirpath: str, - metadata_fname: Optional[str] = None, - metric_name: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> List[str]: - """ - Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. - The order of the checkpoints is not guarenteed. - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - metric_name: fetches all the checkpoint directories containing the metric name only. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - """ - - return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) - - -def _sort_by_recency(dirpaths: List[str]) -> List[str]: - """ - Sorts the given list of directories by oldest to newest. - - Args: - dirpaths: A list of directory paths. - - Returns: - A sorted list of directory paths, sorted by recency. - """ - - def sort_fn(path: str) -> Tuple[int, int]: - x = os.path.basename(path) - return (int(x.split("_")[1]), int(x.split("_")[3])) - - return sorted(dirpaths, key=sort_fn) - - -def _sort_by_metric_value( - dirpaths: List[str], mode: Literal["min", "max"] -) -> List[str]: - """ - Sorts the given list of directories by the metric values. - - Args: - dirpaths: A list of directory paths. - mode: Either 'min' or 'max'. If 'min', sorts in descending order. If 'max', sorts in ascending order - - Returns: - A sorted list of directory paths, sorted by the metric values. - """ - - def sort_metric_fn(path: str) -> float: - x = os.path.basename(path) - metric_val = float(x.split("=")[-1]) - return metric_val - - return sorted( - dirpaths, - key=sort_metric_fn, - # sort descending if min, placing worst metric at top of list - reverse=(mode == "min"), - ) - - -def _retrieve_checkpoint_dirpaths( - dirpath: str, - metadata_fname: Optional[str], - metric_name: Optional[str] = None, -) -> List[str]: - """ - Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - metric_name: Name of the metric that must exist in checkpoint name. - """ - - if dirpath[-1] == "/": - # removes trailing forward slash if present - # required for regex search to work - dirpath = dirpath[:-1] - - fs = get_filesystem(dirpath) - - if not fs.exists(dirpath): - logger.warning(f"Input dirpath doesn't exist: {dirpath}") - return [] - - contents = fs.ls(dirpath, detail=True) - contents = [item["name"] for item in contents if item["type"] == "directory"] - if len(contents) == 0: - logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}") - return [] - - # Define the regex pattern to match the directory names - pattern = rf"^{dirpath}/epoch_\d+_step_\d+" - if metric_name: - # inject metric name in regex search - pattern += rf"_{metric_name}=" - snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern) - candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents)) - - if not metadata_fname: - # return early as we don't need to filter out any paths - return candidate_dirpaths - - # Iterate through all files and directories in the specified directory - # and check if metedata is present or not - valid_ckpt_dirpaths = [] - for candidate in candidate_dirpaths: - if not _metadata_exists(fs, candidate, metadata_fname): - logger.warning( - f"Snapshot metadata is missing from {candidate}! Skipping this path" - ) - continue - - valid_ckpt_dirpaths.append(candidate) - - return valid_ckpt_dirpaths - - -def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None: - fs = get_filesystem(dirpath) - if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname): - raise RuntimeError(f"{dirpath} does not contain {metadata_fname}") - fs.rm(dirpath, recursive=True) - - -def _metadata_exists( - fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str -) -> bool: - return fs.exists(os.path.join(dirpath, metadata_fname)) - # keys for use when checkpointing _TRAIN_PROGRESS_STATE_KEY = "train_progress" diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index f5f56fdf25..ded257c412 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -16,7 +16,14 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks._checkpoint_utils import ( +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + RestoreOptions, +) +from torchtnt.framework.state import EntryPoint, State +from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit +from torchtnt.framework.utils import get_timing_context +from torchtnt.utils.checkpoint import ( _delete_checkpoint, _metadata_exists, _sort_by_metric_value, @@ -25,13 +32,6 @@ get_checkpoint_dirpaths, get_latest_checkpoint_path, ) -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - RestoreOptions, -) -from torchtnt.framework.state import EntryPoint, State -from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit -from torchtnt.framework.utils import get_timing_context from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index c0ad4c3b8d..06cb6d33da 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -6,7 +6,13 @@ # pyre-strict -from .checkpoint import CheckpointPath, MetricData +from .checkpoint import ( + CheckpointPath, + get_best_checkpoint_path, + get_checkpoint_dirpaths, + get_latest_checkpoint_path, + MetricData, +) from .device import ( copy_data_to_device, CPUStats, @@ -151,4 +157,7 @@ "spawn_multi_process", "CheckpointPath", "MetricData", + "get_best_checkpoint_path", + "get_checkpoint_dirpaths", + "get_latest_checkpoint_path", ] diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 234464b6a7..d5ddc4b2f6 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -5,13 +5,20 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +import logging import os import re from dataclasses import dataclass from functools import total_ordering -from typing import Literal, Optional, Pattern +from typing import List, Literal, Optional, Pattern, Tuple +import fsspec +import torch.distributed as dist +from fsspec.core import url_to_fs from pyre_extensions import none_throws +from torchtnt.utils.distributed import rank_zero_read_and_broadcast + +logger: logging.Logger = logging.getLogger(__name__) @dataclass @@ -176,3 +183,243 @@ def __eq__(self, other: "CheckpointPath") -> bool: def __gt__(self, other: "CheckpointPath") -> bool: return self.newer_than(other) + + +@rank_zero_read_and_broadcast +def get_latest_checkpoint_path( + dirpath: str, + metadata_fname: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Optional[str]: + """ + Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory. + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + + Raises: + AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. + """ + + return _latest_checkpoint_path(dirpath, metadata_fname) + + +def _latest_checkpoint_path( + dirpath: str, metadata_fname: Optional[str] +) -> Optional[str]: + candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) + + # Initialize variables to store the largest epoch and step numbers + largest_subdirectory = None + largest_epoch = -1 + largest_step = -1 + + # Iterate through all files and directories in the specified directory + for candidate in candidate_dirpaths: + # Extract the epoch and step numbers from the directory name + dirname = os.path.basename(candidate) + + # dirname will be of the format epoch_N_step_M + # where N is the epoch number and M is the step number as integers + split = dirname.split("_") + if len(split) < 4: + raise AssertionError( + f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})" + ) + + epoch_num, step_num = int(split[1]), int(split[3]) + # Check if the current epoch and step numbers are larger than the largest ones found so far + if epoch_num > largest_epoch: + largest_epoch = epoch_num + largest_step = step_num + largest_subdirectory = dirname + elif largest_epoch == epoch_num and step_num > largest_step: + largest_step = step_num + largest_subdirectory = dirname + + if largest_subdirectory is None: + return None + + # Rejoin with the parent directory path and return the largest subdirectory + return os.path.join(dirpath, none_throws(largest_subdirectory)) + + +@rank_zero_read_and_broadcast +def get_best_checkpoint_path( + dirpath: str, + metric_name: str, + mode: Literal["min", "max"], + metadata_fname: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Optional[str]: + """ + Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory. + + Args: + dirpath: parent directory where checkpoints are saved. + metric_name: Name of the metric to use to find the best checkpoint. + mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + """ + + dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + if len(dirpaths) == 0: + # no checkpoints found + return None + + best_checkpoint_path = None + best_metric_value = float("inf") if mode == "min" else float("-inf") + for dirpath in dirpaths: + dirname = os.path.basename(dirpath) + metric_value = float(dirname.split("=")[-1]) + + if mode == "min": + if metric_value < best_metric_value: + best_metric_value = metric_value + best_checkpoint_path = dirpath + else: + if metric_value > best_metric_value: + best_metric_value = metric_value + best_checkpoint_path = dirpath + + return best_checkpoint_path + + +@rank_zero_read_and_broadcast +def get_checkpoint_dirpaths( + dirpath: str, + metadata_fname: Optional[str] = None, + metric_name: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> List[str]: + """ + Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. + The order of the checkpoints is not guarenteed. + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + metric_name: fetches all the checkpoint directories containing the metric name only. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + """ + + return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + + +def _sort_by_recency(dirpaths: List[str]) -> List[str]: + """ + Sorts the given list of directories by oldest to newest. + + Args: + dirpaths: A list of directory paths. + + Returns: + A sorted list of directory paths, sorted by recency. + """ + + def sort_fn(path: str) -> Tuple[int, int]: + x = os.path.basename(path) + return (int(x.split("_")[1]), int(x.split("_")[3])) + + return sorted(dirpaths, key=sort_fn) + + +def _sort_by_metric_value( + dirpaths: List[str], mode: Literal["min", "max"] +) -> List[str]: + """ + Sorts the given list of directories by the metric values. + + Args: + dirpaths: A list of directory paths. + mode: Either 'min' or 'max'. If 'min', sorts in descending order. If 'max', sorts in ascending order + + Returns: + A sorted list of directory paths, sorted by the metric values. + """ + + def sort_metric_fn(path: str) -> float: + x = os.path.basename(path) + metric_val = float(x.split("=")[-1]) + return metric_val + + return sorted( + dirpaths, + key=sort_metric_fn, + # sort descending if min, placing worst metric at top of list + reverse=(mode == "min"), + ) + + +def _retrieve_checkpoint_dirpaths( + dirpath: str, + metadata_fname: Optional[str], + metric_name: Optional[str] = None, +) -> List[str]: + """ + Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + metric_name: Name of the metric that must exist in checkpoint name. + """ + + if dirpath[-1] == "/": + # removes trailing forward slash if present + # required for regex search to work + dirpath = dirpath[:-1] + + fs, _ = url_to_fs(dirpath) + + if not fs.exists(dirpath): + logger.warning(f"Input dirpath doesn't exist: {dirpath}") + return [] + + contents = fs.ls(dirpath, detail=True) + contents = [item["name"] for item in contents if item["type"] == "directory"] + if len(contents) == 0: + logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}") + return [] + + # Define the regex pattern to match the directory names + pattern = rf"^{dirpath}/epoch_\d+_step_\d+" + if metric_name: + # inject metric name in regex search + pattern += rf"_{metric_name}=" + snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern) + candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents)) + + if not metadata_fname: + # return early as we don't need to filter out any paths + return candidate_dirpaths + + # Iterate through all files and directories in the specified directory + # and check if metedata is present or not + valid_ckpt_dirpaths = [] + for candidate in candidate_dirpaths: + if not _metadata_exists(fs, candidate, metadata_fname): + logger.warning( + f"Snapshot metadata is missing from {candidate}! Skipping this path" + ) + continue + + valid_ckpt_dirpaths.append(candidate) + + return valid_ckpt_dirpaths + + +def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None: + fs, _ = url_to_fs(dirpath) + if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname): + raise RuntimeError(f"{dirpath} does not contain {metadata_fname}") + fs.rm(dirpath, recursive=True) + + +def _metadata_exists( + fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str +) -> bool: + return fs.exists(os.path.join(dirpath, metadata_fname))