diff --git a/src/invoke_training/training/shared/checkpoint_tracker.py b/src/invoke_training/training/shared/checkpoint_tracker.py
index 46ebe28b..e07280b8 100644
--- a/src/invoke_training/training/shared/checkpoint_tracker.py
+++ b/src/invoke_training/training/shared/checkpoint_tracker.py
@@ -54,6 +54,9 @@ def prune(self, buffer_num: int = 1) -> int:
         Returns:
             int: The number of checkpoints deleted.
         """
+        if self._max_checkpoints is None:
+            return 0
+
         checkpoints = os.listdir(self._base_dir)
         checkpoints = [p for p in checkpoints if p.startswith(self._prefix)]
         checkpoints = sorted(
diff --git a/tests/invoke_training/training/shared/test_checkpoint_tracker.py b/tests/invoke_training/training/shared/test_checkpoint_tracker.py
index f54223af..9e0212b8 100644
--- a/tests/invoke_training/training/shared/test_checkpoint_tracker.py
+++ b/tests/invoke_training/training/shared/test_checkpoint_tracker.py
@@ -81,3 +81,23 @@ def test_checkpoint_tracker_prune_directories():
         # Verify that the correct checkpoints were pruned.
         assert all([not os.path.exists(checkpoint_tracker.get_path(i)) for i in range(3)])
         assert all([os.path.exists(checkpoint_tracker.get_path(i)) for i in range(3, 6)])
+
+
+def test_checkpoint_tracker_prune_no_max():
+    """Test that CheckpointTracker.prune() deletes nothing and returns 0 when max_checkpoints is None."""
+    with tempfile.TemporaryDirectory() as dir_name:
+        checkpoint_tracker = CheckpointTracker(
+            base_dir=dir_name, prefix="prefix", extension=".ckpt", max_checkpoints=None
+        )
+        # Create 6 checkpoint files (indices 0-5), each with placeholder content.
+        for i in range(6):
+            path = checkpoint_tracker.get_path(i)
+            with open(path, "w") as f:
+                f.write("hi")
+
+        # Call prune with buffer_num=2; with no max_checkpoints limit it should report 0 deletions.
+        num_pruned = checkpoint_tracker.prune(2)
+        assert num_pruned == 0
+
+        # Verify that all 6 checkpoint files still exist on disk.
+        assert all([os.path.exists(checkpoint_tracker.get_path(i)) for i in range(6)])