Skip to content

Commit

Permalink
Fix CheckpointTracker behaviour when max_checkpoints is None.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed Aug 7, 2023
1 parent 9a6d977 commit b33d971
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/invoke_training/training/shared/checkpoint_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def prune(self, buffer_num: int = 1) -> int:
Returns:
int: The number of checkpoints deleted.
"""
if self._max_checkpoints is None:
return 0

checkpoints = os.listdir(self._base_dir)
checkpoints = [p for p in checkpoints if p.startswith(self._prefix)]
checkpoints = sorted(
Expand Down
20 changes: 20 additions & 0 deletions tests/invoke_training/training/shared/test_checkpoint_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,23 @@ def test_checkpoint_tracker_prune_directories():
# Verify that the correct checkpoints were pruned.
assert all([not os.path.exists(checkpoint_tracker.get_path(i)) for i in range(3)])
assert all([os.path.exists(checkpoint_tracker.get_path(i)) for i in range(3, 6)])


def test_checkpoint_tracker_prune_no_max():
"""Test that CheckpointTracker.prune() is a no-op when max_checkpoints is None."""
with tempfile.TemporaryDirectory() as dir_name:
checkpoint_tracker = CheckpointTracker(
base_dir=dir_name, prefix="prefix", extension=".ckpt", max_checkpoints=None
)
# Create 6 checkpoints.
for i in range(6):
path = checkpoint_tracker.get_path(i)
with open(path, "w") as f:
f.write("hi")

# Call prune, which should have no effect.
num_pruned = checkpoint_tracker.prune(2)
assert num_pruned == 0

# Verify that no checkpoints were deleted.
assert all([os.path.exists(checkpoint_tracker.get_path(i)) for i in range(6)])

0 comments on commit b33d971

Please sign in to comment.