# Provenance (git format-patch envelope this content was extracted from):
#   From 5b4ef9e6207323950f0bce2e588cf00ea6132149 Mon Sep 17 00:00:00 2001
#   From: Diego Urgell
#   Date: Thu, 25 Apr 2024 23:40:02 -0700
#   Subject: [PATCH] Add CheckpointPath abstraction in utils/checkpoint.py
#   Differential Revision: D56260188
#
#   tests/utils/test_checkpoint.py | 143 ++++ (new file)
#   torchtnt/utils/__init__.py     |   3 +
#   torchtnt/utils/checkpoint.py   | 178 ++++ (new file)
#   3 files changed, 324 insertions(+)
#
# ---- tests/utils/test_checkpoint.py (new file, index 0000000000..91778df115) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import re
import unittest

from torchtnt.utils.checkpoint import CheckpointPath, MetricData


class CheckpointPathTest(unittest.TestCase):
    """Tests for parsing, serialization and comparison of ``CheckpointPath``."""

    def test_from_str(self) -> None:
        """``from_str`` rejects malformed paths and round-trips valid ones."""
        # invalid paths
        malformed_paths = [
            "foo/step_20",
            "foo/epoch_50",
            "epoch_30",
            "foo/epoch_20_step",
            "foo/epoch_20_step_30_val_loss=1a",
            "foo/epoch_2_step_15_mean=hello",
            "foo/epoch_2.6_step_23",
        ]
        for path in malformed_paths:
            # The path is interpolated into a regular expression, so it has to
            # be escaped: otherwise "." (as in "epoch_2.6") acts as a wildcard
            # and the assertion is looser than intended.
            expected_msg = re.escape(
                f"Attempted to parse malformed checkpoint path: {path}"
            )
            # subTest keeps going after a failure and reports which path broke.
            with self.subTest(path=path), self.assertRaisesRegex(
                ValueError, expected_msg
            ):
                CheckpointPath.from_str(path)

        # valid paths
        valid_paths = [
            ("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)),
            (
                "foo/epoch_14_step_3_mean=15.0",
                CheckpointPath(
                    "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
                ),
            ),
            (
                "foo/epoch_14_step_3_loss=-27.35",
                CheckpointPath(
                    "foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
                ),
            ),
            (
                "/foo/epoch_14_step_3_loss=-27.35",
                CheckpointPath(
                    "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
                ),
            ),
            (
                "foo/bar/epoch_23_step_31_mean_loss_squared=0.0",
                CheckpointPath(
                    "foo/bar/",
                    epoch=23,
                    step=31,
                    metric_data=MetricData("mean_loss_squared", 0.0),
                ),
            ),
            (
                "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
                CheckpointPath(
                    "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61",
                    epoch=2,
                    step=1,
                    metric_data=MetricData("acc", 0.98),
                ),
            ),
        ]
        for path, expected_ckpt in valid_paths:
            with self.subTest(path=path):
                parsed_ckpt = CheckpointPath.from_str(path)
                self.assertEqual(parsed_ckpt, expected_ckpt)
                # Serializing the parsed checkpoint must give back the input.
                self.assertEqual(parsed_ckpt.path, path)

        # with a trailing slash
        ckpt = CheckpointPath.from_str("foo/epoch_0_step_1/")
        self.assertEqual(ckpt, CheckpointPath("foo", epoch=0, step=1))
        self.assertEqual(ckpt.path, "foo/epoch_0_step_1")

    def test_compare_by_recency(self) -> None:
        """``newer_than`` orders by epoch first, then by step."""
        old = CheckpointPath("foo", epoch=0, step=1)
        new = CheckpointPath("foo", epoch=1, step=1)
        self.assertTrue(new.newer_than(old))
        self.assertFalse(old.newer_than(new))
        self.assertNotEqual(new, old)

        old = CheckpointPath("foo", epoch=3, step=5)
        new = CheckpointPath("foo", epoch=3, step=9)
        self.assertTrue(new.newer_than(old))
        self.assertFalse(old.newer_than(new))
        self.assertNotEqual(new, old)

        twin1 = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
        )
        almost_twin = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0)
        )

        # Same epoch and step: neither is newer, but the differing metric
        # still makes them unequal.
        self.assertFalse(twin1.newer_than(almost_twin))
        self.assertFalse(almost_twin.newer_than(twin1))
        self.assertNotEqual(twin1, almost_twin)

        twin2 = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
        )
        self.assertEqual(twin1, twin2)

    def test_compare_by_optimality(self) -> None:
        """``more_optimal_than`` compares metric values and validates its inputs."""
        # not both metric aware
        ckpt1 = CheckpointPath("foo", epoch=0, step=1)
        ckpt2 = CheckpointPath("foo", epoch=1, step=1)
        ckpt3 = CheckpointPath(
            "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0)
        )
        for ckpt in [ckpt2, ckpt3]:
            with self.subTest(other=ckpt), self.assertRaisesRegex(
                AssertionError,
                "Attempted to compare optimality of non metric-aware checkpoints",
            ):
                ckpt1.more_optimal_than(ckpt, mode="min")

        # tracking different metrics
        ckpt4 = CheckpointPath(
            "foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0)
        )
        with self.assertRaisesRegex(
            AssertionError,
            "Attempted to compare optimality of checkpoints tracking different metrics",
        ):
            ckpt3.more_optimal_than(ckpt4, mode="min")

        smaller = CheckpointPath(
            "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0)
        )
        larger = CheckpointPath(
            "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0)
        )
        self.assertTrue(larger.more_optimal_than(smaller, mode="max"))
        self.assertFalse(smaller.more_optimal_than(larger, mode="max"))
        self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
        self.assertFalse(larger.more_optimal_than(smaller, mode="min"))


# ---- torchtnt/utils/__init__.py (index cb973c13a6..c0ad4c3b8d, same patch) ----
# Hunk @@ -6,6 +6,7 @@ adds, directly above `from .device import (`:
#     from .checkpoint import CheckpointPath, MetricData
# Hunk @@ -148,4 +149,6 @@ appends to the module's export list (presumably
# `__all__` -- the list header is outside the hunk), after "spawn_multi_process":
#     "CheckpointPath",
#     "MetricData",
#
# ---- torchtnt/utils/checkpoint.py (new file, index 0000000000..234464b6a7) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import os
import re
from dataclasses import dataclass
from typing import Literal, Optional, Pattern


@dataclass
class MetricData:
    """
    Representation of a metric instance. Should provide both a metric name and its value.
    """

    name: str
    value: float


class CheckpointPath:
    """
    Representation of a checkpoint path. Handles parsing and serialization of the specific path format.
    Currently, the basic compliant path format is: {dirpath}/epoch_{epoch}_step_{step}
    If a metric is being tracked, it's added to the name:
    {dirpath}/epoch_{epoch}_step_{step}_{metric_name}={metric_value}

    Ordering comparisons (<, <=, >, >=) look only at recency, i.e. at (epoch, step). Equality
    additionally compares dirpath and metric_data, so two checkpoints can be equally recent
    (both a <= b and a >= b hold) without being equal. Sorting by metric can be done by
    extracting the metric value from the metric_data attribute, or with more_optimal_than.
    """

    # Capture groups: (1) directory prefix, (2) epoch, (3) step,
    # (4) optional metric name, (5) optional metric value.
    # The value alternatives accept plain decimals, scientific notation and the
    # literals inf / nan, so that every string the `path` property can emit for
    # a float (e.g. "1e-05") is parseable again by `from_str`.
    PATH_REGEX: Pattern[str] = re.compile(
        r"^(.+)epoch_(\d+)_step_(\d+)"
        r"(?:_(.+)=(-?(?:\d+\.?\d*(?:[eE][-+]?\d+)?|inf|nan)))?\/?$"
    )

    def __init__(
        self,
        dirpath: str,
        epoch: int,
        step: int,
        metric_data: Optional[MetricData] = None,
    ) -> None:
        """
        Args:
            dirpath: The base directory path that checkpoints are saved in. Trailing slashes are dropped.
            epoch: The epoch number of this checkpoint.
            step: The step number of this checkpoint.
            metric_data: Optional data about the metric being tracked. Should contain both metric name and value.
        """
        stripped = dirpath.rstrip("/")
        # A non-empty dirpath made only of slashes is the root directory. Keep a
        # single "/" for it; stripping to "" would turn `path` into a relative path.
        self.dirpath: str = stripped if stripped or not dirpath else "/"
        self.epoch: int = epoch
        self.step: int = step
        self.metric_data: Optional[MetricData] = metric_data

    @classmethod
    def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
        """
        Given a directory path, try to parse it and extract the checkpoint data.
        The expected format is: {dirpath}/epoch_{epoch}_step_{step}_{metric_name}={metric_value},
        where the metric name and value are optional. A single trailing slash is accepted.

        Args:
            checkpoint_path: The path to the checkpoint directory.

        Returns:
            An instance of this class built from the parsed components.

        Raises:
            ValueError: If the path is malformed and can't be parsed.
        """
        path_match = cls.PATH_REGEX.match(checkpoint_path)
        if not path_match:
            raise ValueError(
                f"Attempted to parse malformed checkpoint path: {checkpoint_path}."
            )

        dirpath, epoch, step, metric_name, metric_value = path_match.groups()
        try:
            metric_data: Optional[MetricData] = None
            if metric_name:
                metric_data = MetricData(name=metric_name, value=float(metric_value))

            # `cls` rather than a hard-coded class name, so subclasses get
            # instances of their own type from this alternate constructor.
            return cls(
                dirpath=dirpath,
                epoch=int(epoch),
                step=int(step),
                metric_data=metric_data,
            )

        except ValueError as err:
            # Should never happen since path matches regex
            raise ValueError(
                f"Invalid data types found in checkpoint path: {checkpoint_path}."
            ) from err

    @property
    def path(self) -> str:
        """
        Returns:
            The full path to the checkpoint directory.
        """
        name = f"epoch_{self.epoch}_step_{self.step}"
        if self.metric_data is not None:
            name += f"_{self.metric_data.name}={self.metric_data.value}"

        return os.path.join(self.dirpath, name)

    def newer_than(self, other: "CheckpointPath") -> bool:
        """
        Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
        Epochs are compared first; steps only break ties between equal epochs.

        Returns:
            True if this checkpoint is newer than the other, otherwise False.
        """
        return (self.epoch, self.step) > (other.epoch, other.step)

    def more_optimal_than(
        self, other: "CheckpointPath", mode: Literal["min", "max"]
    ) -> bool:
        """
        Given another CheckpointPath instance, determine if this checkpoint is strictly more optimal than the other.
        Optimality is determined by comparing the metric value of the two checkpoints. The mode indicates if the
        metric value should be minimized or maximized. This only works for metric-aware checkpoints.

        Args:
            other: The other checkpoint path to compare against.
            mode: The mode to use for comparison. "min" prefers the smaller value; any other
                value is treated as "max".

        Returns:
            True if this checkpoint is more optimal than the other, otherwise False.

        Raises:
            AssertionError: If either checkpoint has no metric data, or if the two checkpoints
                track metrics with different names.
        """
        mine, theirs = self.metric_data, other.metric_data

        # Raised explicitly instead of via `assert` so the checks still run
        # when Python is started with -O; exception type and message are the same.
        if mine is None or theirs is None:
            raise AssertionError(
                f"Attempted to compare optimality of non metric-aware checkpoints: {self} and {other}"
            )

        if mine.name != theirs.name:
            raise AssertionError(
                f"Attempted to compare optimality of checkpoints tracking different metrics: {self} and {other}"
            )

        if mode == "min":
            return mine.value < theirs.value

        return mine.value > theirs.value

    def __str__(self) -> str:
        return self.path

    def __repr__(self) -> str:
        return f"CheckpointPath(dirpath={self.dirpath}, epoch={self.epoch}, step={self.step}, metric_data={self.metric_data})"

    # NOTE: defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other: object) -> bool:
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, CheckpointPath):
            return NotImplemented
        return (
            self.dirpath == other.dirpath
            and self.epoch == other.epoch
            and self.step == other.step
            and self.metric_data == other.metric_data
        )

    # All four ordering operators are spelled out in terms of `newer_than`.
    # Deriving them from __gt__ and __eq__ would make `a < b` and `b < a` both
    # true for checkpoints that share epoch/step but differ in metric data.
    def __gt__(self, other: "CheckpointPath") -> bool:
        if not isinstance(other, CheckpointPath):
            return NotImplemented
        return self.newer_than(other)

    def __lt__(self, other: "CheckpointPath") -> bool:
        if not isinstance(other, CheckpointPath):
            return NotImplemented
        return other.newer_than(self)

    def __ge__(self, other: "CheckpointPath") -> bool:
        if not isinstance(other, CheckpointPath):
            return NotImplemented
        return not other.newer_than(self)

    def __le__(self, other: "CheckpointPath") -> bool:
        if not isinstance(other, CheckpointPath):
            return NotImplemented
        return not self.newer_than(other)