diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index e891cf01..bb6f6d94 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -98,3 +98,8 @@ restore_file: null
 # Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
 # Set it to 0 to disable checkpointing
 checkpoint_interval: 0
+# Whether to checkpoint when the experiment is done
+checkpoint_at_end: False
+# How many checkpoints to keep. As new checkpoints are taken, older checkpoints are deleted to keep this number of
+# checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
+keep_checkpoints_num: 3
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 58a0e818..6ace35b4 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -11,7 +11,7 @@
 import os
 import time
 
-from collections import OrderedDict
+from collections import deque, OrderedDict
 from dataclasses import dataclass, MISSING
 from pathlib import Path
 from typing import Dict, List, Optional
@@ -96,7 +96,9 @@ class ExperimentConfig:
 
     save_folder: Optional[str] = MISSING
     restore_file: Optional[str] = MISSING
-    checkpoint_interval: float = MISSING
+    checkpoint_interval: int = MISSING
+    checkpoint_at_end: bool = MISSING
+    keep_checkpoints_num: Optional[int] = MISSING
 
     def train_batch_size(self, on_policy: bool) -> int:
         """
@@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
                 f"checkpoint_interval ({self.checkpoint_interval}) "
                 f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
             )
+        if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
+            raise ValueError("keep_checkpoints_num must be greater than zero or null")
         if self.max_n_frames is None and self.max_n_iters is None:
             raise ValueError("n_iters and total_frames are both not set")
 
@@ -483,6 +487,7 @@ def _setup_name(self):
         self.model_name = self.model_config.associated_class().__name__.lower()
         self.environment_name = self.task.env_name().lower()
         self.task_name = self.task.name.lower()
+        self._checkpointed_files = deque([])
 
         if self.config.restore_file is not None and self.config.save_folder is not None:
             raise ValueError(
@@ -668,6 +673,8 @@ def _collection_loop(self):
             pbar.update()
             sampling_start = time.time()
 
+        if self.config.checkpoint_at_end:
+            self._save_experiment()
         self.close()
 
     def close(self):
@@ -835,10 +842,16 @@ def load_state_dict(self, state_dict: Dict) -> None:
 
     def _save_experiment(self) -> None:
         """Checkpoint trainer"""
+        if self.config.keep_checkpoints_num is not None:
+            while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
+                file_to_delete = self._checkpointed_files.popleft()
+                file_to_delete.unlink(missing_ok=False)
+
         checkpoint_folder = self.folder_name / "checkpoints"
         checkpoint_folder.mkdir(parents=False, exist_ok=True)
         checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
         torch.save(self.state_dict(), checkpoint_file)
+        self._checkpointed_files.append(checkpoint_file)
 
     def _load_experiment(self) -> Experiment:
         """Load trainer from checkpoint"""
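For context, the two new options plug into the existing `ExperimentConfig`. Below is a minimal usage sketch, not part of this diff: it assumes the `Experiment` / `get_from_yaml()` entry points shown in the BenchMARL README, and the task, algorithm, and numeric values are purely illustrative.

```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

experiment_config = ExperimentConfig.get_from_yaml()

# checkpoint_interval must remain a multiple of the collected frames per batch,
# as enforced by ExperimentConfig.validate().
experiment_config.checkpoint_interval = (
    experiment_config.on_policy_collected_frames_per_batch * 10
)
experiment_config.checkpoint_at_end = True  # also checkpoint when training finishes
experiment_config.keep_checkpoints_num = 3  # keep only the 3 newest checkpoints (null/None keeps all)

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),         # illustrative task choice
    algorithm_config=MappoConfig.get_from_yaml(),  # illustrative algorithm choice
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
)
experiment.run()  # older checkpoint files are pruned as new ones are written
```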