Skip to content

Commit

Permalink
Add CheckpointPath abstraction in utils/checkpoint.py
Browse files Browse the repository at this point in the history
Differential Revision: D56260188
  • Loading branch information
Diego Urgell authored and facebook-github-bot committed Apr 26, 2024
1 parent 0159a07 commit 3336b0a
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 0 deletions.
143 changes: 143 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import unittest

from torchtnt.utils.checkpoint import CheckpointPath, MetricData


class CheckpointPathTest(unittest.TestCase):
    """Unit tests covering parsing, serialization and comparison of CheckpointPath."""

    def test_from_str(self) -> None:
        """Malformed paths must raise; well-formed paths must round-trip through `path`."""
        # Strings that do not follow <dirpath>/epoch_<epoch>_step_<step>[_<name>=<value>]
        rejected = (
            "foo/step_20",
            "foo/epoch_50",
            "epoch_30",
            "foo/epoch_20_step",
            "foo/epoch_20_step_30_val_loss=1a",
            "foo/epoch_2_step_15_mean=hello",
            "foo/epoch_2.6_step_23",
        )
        for bad_path in rejected:
            expected_msg = (
                f"Attempted to parse malformed checkpoint path: {bad_path}"
            )
            with self.assertRaisesRegex(ValueError, expected_msg):
                CheckpointPath.from_str(bad_path)

        # Well-formed strings mapped to the object they are expected to parse into
        accepted = {
            "foo/epoch_0_step_1": CheckpointPath("foo", epoch=0, step=1),
            "foo/epoch_14_step_3_mean=15.0": CheckpointPath(
                "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
            ),
            "foo/epoch_14_step_3_loss=-27.35": CheckpointPath(
                "foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
            ),
            "/foo/epoch_14_step_3_loss=-27.35": CheckpointPath(
                "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
            ),
            "foo/bar/epoch_23_step_31_mean_loss_squared=0.0": CheckpointPath(
                "foo/bar/",
                epoch=23,
                step=31,
                metric_data=MetricData("mean_loss_squared", 0.0),
            ),
            "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98": CheckpointPath(
                "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61",
                epoch=2,
                step=1,
                metric_data=MetricData("acc", 0.98),
            ),
        }
        for raw_path, expected in accepted.items():
            parsed = CheckpointPath.from_str(raw_path)
            self.assertEqual(parsed, expected)
            # Serializing the parsed object must give back the original string
            self.assertEqual(parsed.path, raw_path)

        # A trailing slash is accepted on input but not reproduced on output
        with_slash = CheckpointPath.from_str("foo/epoch_0_step_1/")
        self.assertEqual(with_slash, CheckpointPath("foo", epoch=0, step=1))
        self.assertEqual(with_slash.path, "foo/epoch_0_step_1")

    def test_compare_by_recency(self) -> None:
        """`newer_than` looks at epoch first and step second; `==` looks at every field."""
        older_newer_pairs = [
            # differing epoch
            (
                CheckpointPath("foo", epoch=0, step=1),
                CheckpointPath("foo", epoch=1, step=1),
            ),
            # same epoch, differing step
            (
                CheckpointPath("foo", epoch=3, step=5),
                CheckpointPath("foo", epoch=3, step=9),
            ),
        ]
        for old, new in older_newer_pairs:
            self.assertTrue(new.newer_than(old))
            self.assertFalse(old.newer_than(new))
            self.assertFalse(new == old)

        # Same epoch and step but different metric: neither is newer, yet they are not equal
        first = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
        )
        different_metric = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0)
        )
        self.assertFalse(first.newer_than(different_metric))
        self.assertFalse(different_metric.newer_than(first))
        self.assertFalse(first == different_metric)

        # Identical in every field: equal
        second = CheckpointPath(
            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
        )
        self.assertTrue(first == second)

    def test_compare_by_optimality(self) -> None:
        """`more_optimal_than` needs two checkpoints tracking the same metric."""
        no_metric = CheckpointPath("foo", epoch=0, step=1)
        also_no_metric = CheckpointPath("foo", epoch=1, step=1)
        tracks_bar = CheckpointPath(
            "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0)
        )

        # At least one side carries no metric
        for candidate in (also_no_metric, tracks_bar):
            with self.assertRaisesRegex(
                AssertionError,
                "Attempted to compare optimality of non metric-aware checkpoints",
            ):
                no_metric.more_optimal_than(candidate, mode="min")

        # Both sides carry a metric, but not the same one
        tracks_baz = CheckpointPath(
            "foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0)
        )
        with self.assertRaisesRegex(
            AssertionError,
            "Attempted to compare optimality of checkpoints tracking different metrics",
        ):
            tracks_bar.more_optimal_than(tracks_baz, mode="min")

        # Same metric: the result depends on the requested mode
        low = CheckpointPath(
            "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0)
        )
        high = CheckpointPath(
            "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0)
        )
        self.assertTrue(high.more_optimal_than(low, mode="max"))
        self.assertFalse(low.more_optimal_than(high, mode="max"))
        self.assertTrue(low.more_optimal_than(high, mode="min"))
        self.assertFalse(high.more_optimal_than(low, mode="min"))
3 changes: 3 additions & 0 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

from .checkpoint import CheckpointPath, MetricData
from .device import (
copy_data_to_device,
CPUStats,
Expand Down Expand Up @@ -148,4 +149,6 @@
"is_windows",
"get_pet_launch_config",
"spawn_multi_process",
"CheckpointPath",
"MetricData",
]
178 changes: 178 additions & 0 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import os
import re
from dataclasses import dataclass
from functools import total_ordering
from typing import Literal, Optional, Pattern

from pyre_extensions import none_throws


@dataclass
class MetricData:
    """
    A single metric reading: the name of the metric paired with its value.
    """

    # Name identifying the metric.
    name: str
    # Numeric value recorded for that metric.
    value: float


@total_ordering
class CheckpointPath:
    """
    Representation of a checkpoint path. Handles parsing and serialization of the specific path format.
    Currently, the basic compliant path format is: <dirpath>/epoch_<epoch>_step_<step>
    If a metric is being tracked, it's added to the name: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>

    Ordering comparisons (<, <=, >, >=) operate on checkpoint recency, i.e. on epoch and then step only.
    Equality is stricter: it also takes the directory path and the metric data into account, so two
    checkpoints may be neither newer than each other nor equal. Sorting by metric can be done by
    extracting the metric value from the metric_data attribute.
    """

    # Groups: dirpath (with its trailing separator), epoch, step, metric name, metric value.
    # The metric value accepts an optional exponent because `path` serializes floats with
    # their default string form, which uses scientific notation for very small or very
    # large magnitudes (e.g. 1e-05). Without it such paths could not be parsed back.
    PATH_REGEX: Pattern = re.compile(
        r"^(.+)epoch_(\d+)_step_(\d+)(?:_(.+)=(-?\d+\.?\d*(?:[eE][-+]?\d+)?))?\/?$"
    )

    def __init__(
        self,
        dirpath: str,
        epoch: int,
        step: int,
        metric_data: Optional["MetricData"] = None,
    ) -> None:
        """
        Args:
            dirpath: The base directory path that checkpoints are saved in. Trailing slashes are stripped.
            epoch: The epoch number of this checkpoint.
            step: The step number of this checkpoint.
            metric_data: Optional data about the metric being tracked. Should contain both metric name and value.
        """
        self.dirpath: str = dirpath.rstrip("/")
        self.epoch = epoch
        self.step = step
        self.metric_data = metric_data

    @classmethod
    def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
        """
        Given a directory path, try to parse it and extract the checkpoint data.
        The expected format is: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>,
        where the metric name and value are optional. A single trailing slash is tolerated.

        Args:
            checkpoint_path: The path to the checkpoint directory.

        Returns:
            A CheckpointPath instance built from the parsed components.

        Raises:
            ValueError: If the path is malformed and can't be parsed.
        """
        path_match = cls.PATH_REGEX.match(checkpoint_path)
        if not path_match:
            raise ValueError(
                f"Attempted to parse malformed checkpoint path: {checkpoint_path}."
            )

        dirpath, epoch, step, metric_name, metric_value = path_match.groups()
        try:
            metric_data: Optional[MetricData] = None
            if metric_name:
                metric_data = MetricData(name=metric_name, value=float(metric_value))

            return cls(
                dirpath=dirpath,
                epoch=int(epoch),
                step=int(step),
                metric_data=metric_data,
            )

        except ValueError as e:
            # Should never happen since path matches regex. Keep the original
            # error chained so the failing conversion stays visible.
            raise ValueError(
                f"Invalid data types found in checkpoint path: {checkpoint_path}."
            ) from e

    @property
    def path(self) -> str:
        """
        Returns:
            The full path to the checkpoint directory.
        """
        name = f"epoch_{self.epoch}_step_{self.step}"
        if self.metric_data is not None:
            name += f"_{self.metric_data.name}={self.metric_data.value}"

        return os.path.join(self.dirpath, name)

    def newer_than(self, other: "CheckpointPath") -> bool:
        """
        Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
        The epoch is compared first; the step only breaks ties between equal epochs.

        Returns:
            True if this checkpoint is newer than the other, otherwise False.
        """
        if self.epoch != other.epoch:
            return self.epoch > other.epoch

        return self.step > other.step

    def more_optimal_than(
        self, other: "CheckpointPath", mode: Literal["min", "max"]
    ) -> bool:
        """
        Given another CheckpointPath instance, determine if this checkpoint is strictly more optimal than the other.
        Optimality is determined by comparing the metric value of the two checkpoints. The mode indicates if the
        metric value should be minimized or maximized. This only works for metric-aware checkpoints.

        Args:
            other: The other checkpoint path to compare against.
            mode: The mode to use for comparison. Any value other than "min" is treated as "max".

        Returns:
            True if this checkpoint is more optimal than the other, otherwise False.

        Raises:
            AssertionError: If either checkpoint has no metric data, or if the two checkpoints
                track metrics with different names.
        """
        # Raised explicitly rather than through `assert` statements so the
        # validation is still performed when Python runs with -O.
        if self.metric_data is None or other.metric_data is None:
            raise AssertionError(
                f"Attempted to compare optimality of non metric-aware checkpoints: {self} and {other}"
            )

        if self.metric_data.name != other.metric_data.name:
            raise AssertionError(
                f"Attempted to compare optimality of checkpoints tracking different metrics: {self} and {other}"
            )

        if mode == "min":
            return (
                none_throws(self.metric_data).value
                < none_throws(other.metric_data).value
            )

        return (
            none_throws(self.metric_data).value > none_throws(other.metric_data).value
        )

    def __str__(self) -> str:
        return self.path

    def __repr__(self) -> str:
        return f"CheckpointPath(dirpath={self.dirpath}, epoch={self.epoch}, step={self.step}, metric_data={self.metric_data})"

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of failing on a
        # missing attribute; Python then falls back to identity, i.e. "not equal".
        if not isinstance(other, CheckpointPath):
            return NotImplemented

        return (
            self.dirpath == other.dirpath
            and self.epoch == other.epoch
            and self.step == other.step
            and self.metric_data == other.metric_data
        )

    def __gt__(self, other: "CheckpointPath") -> bool:
        # total_ordering derives <, <= and >= from this method and __eq__.
        if not isinstance(other, CheckpointPath):
            return NotImplemented

        return self.newer_than(other)

0 comments on commit 3336b0a

Please sign in to comment.