[Conf] Config validation for experiment
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 11, 2023
1 parent ac6b4a0 commit 1fd9e7b
Showing 2 changed files with 43 additions and 44 deletions.
73 changes: 33 additions & 40 deletions benchmarl/experiment/experiment.py
@@ -190,9 +190,6 @@ def get_max_n_frames(self, on_policy: bool) -> int:
         Args:
             on_policy (bool): is the algorithm on_policy
         """
-
-        if self.max_n_frames is None and self.max_n_iters is None:
-            raise ValueError("n_iters and total_frames are both not set")
         if self.max_n_frames is not None and self.max_n_iters is not None:
             return min(
                 self.max_n_frames,
@@ -229,34 +226,6 @@ def get_exploration_anneal_frames(self, on_policy: bool):
             else self.exploration_anneal_frames
         )
-
-    def get_evaluation_interval(self, on_policy: bool):
-        """
-        Get the interval in terms of collected frames for running evaluation
-        Args:
-            on_policy (bool): is the algorithm on_policy
-        """
-        if self.evaluation_interval % self.collected_frames_per_batch(on_policy) != 0:
-            raise ValueError(
-                f"evaluation_interval ({self.evaluation_interval}) "
-                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
-            )
-        return self.evaluation_interval
-
-    def get_checkpoint_interval(self, on_policy: bool):
-        """
-        Get the interval in terms of collected frames for checkpointing
-        Args:
-            on_policy (bool): is the algorithm on_policy
-        """
-        if self.checkpoint_interval % self.collected_frames_per_batch(on_policy) != 0:
-            raise ValueError(
-                f"checkpoint_interval ({self.checkpoint_interval}) "
-                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
-            )
-        return self.checkpoint_interval
-
 
     @staticmethod
     def get_from_yaml(path: Optional[str] = None):
         """
@@ -280,6 +249,35 @@ def get_from_yaml(path: Optional[str] = None):
         else:
             return ExperimentConfig(**read_yaml_config(path))
 
+    def validate(self, on_policy: bool):
+        """
+        Validates config.
+        Args:
+            on_policy (bool): is the algorithm on_policy
+        """
+        if (
+            self.evaluation
+            and self.evaluation_interval % self.collected_frames_per_batch(on_policy)
+            != 0
+        ):
+            raise ValueError(
+                f"evaluation_interval ({self.evaluation_interval}) "
+                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
+            )
+        if (
+            self.checkpoint_interval != 0
+            and self.checkpoint_interval % self.collected_frames_per_batch(on_policy)
+            != 0
+        ):
+            raise ValueError(
+                f"checkpoint_interval ({self.checkpoint_interval}) "
+                f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
+            )
+        if self.max_n_frames is None and self.max_n_iters is None:
+            raise ValueError("n_iters and total_frames are both not set")
+
+
 class Experiment(CallbackNotifier):
     """
@@ -337,6 +335,7 @@ def on_policy(self) -> bool:
         return self.algorithm_config.on_policy()
 
     def _setup(self):
+        self.config.validate(self.on_policy)
         self._set_action_type()
         self._setup_task()
         self._setup_algorithm()
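
To make the new flow concrete, here is a minimal standalone sketch of what validate() enforces once _setup() calls it. The free function and all numbers below are hypothetical, not taken from the BenchMARL source:

    # Hypothetical mimic of ExperimentConfig.validate(); numbers are illustrative.
    collected_frames_per_batch = 6_000  # assumed frames collected per training batch

    def validate(evaluation_interval, checkpoint_interval, max_n_frames=None, max_n_iters=None):
        # Evaluation must line up with batch boundaries.
        if evaluation_interval % collected_frames_per_batch != 0:
            raise ValueError(
                f"evaluation_interval ({evaluation_interval}) is not a multiple "
                f"of the collected_frames_per_batch ({collected_frames_per_batch})"
            )
        # Checkpointing may be disabled (0); otherwise it must line up too.
        if checkpoint_interval != 0 and checkpoint_interval % collected_frames_per_batch != 0:
            raise ValueError(
                f"checkpoint_interval ({checkpoint_interval}) is not a multiple "
                f"of the collected_frames_per_batch ({collected_frames_per_batch})"
            )
        # At least one stopping criterion is required.
        if max_n_frames is None and max_n_iters is None:
            raise ValueError("n_iters and total_frames are both not set")

    validate(120_000, 0, max_n_frames=3_000_000)  # passes: 120_000 == 20 * 6_000
    validate(100_000, 0, max_n_frames=3_000_000)  # raises: 100_000 % 6_000 == 4_000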
@@ -609,11 +608,7 @@ def _collection_loop(self):
             # Evaluation
             if (
                 self.config.evaluation
-                and (
-                    self.total_frames
-                    % self.config.get_evaluation_interval(self.on_policy)
-                    == 0
-                )
+                and (self.total_frames % self.config.evaluation_interval == 0)
                 and (len(self.config.loggers) or self.config.create_json)
             ):
                 self._evaluation_loop()
@@ -622,10 +617,8 @@ def _collection_loop(self):
             self.n_iters_performed += 1
             self.logger.commit()
             if (
-                self.config.get_checkpoint_interval(self.on_policy) > 0
-                and self.total_frames
-                % self.config.get_checkpoint_interval(self.on_policy)
-                == 0
+                self.config.checkpoint_interval > 0
+                and self.total_frames % self.config.checkpoint_interval == 0
             ):
                 self._save_experiment()
             sampling_start = time.time()
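The inline modulo checks above are safe because validate() runs once at setup: when evaluation_interval is an exact multiple of the frames collected per batch, total_frames % evaluation_interval == 0 fires exactly every evaluation_interval // collected_frames_per_batch iterations, so the per-call getter methods are no longer needed. A small sketch with made-up numbers:

    collected_frames_per_batch = 6_000  # assumed
    evaluation_interval = 30_000        # a multiple of the batch size, as validate() guarantees

    total_frames = 0
    for _ in range(10):                 # ten collection iterations
        total_frames += collected_frames_per_batch
        if total_frames % evaluation_interval == 0:
            print(f"evaluate at {total_frames} frames")
    # prints at 30000 and 60000 frames, i.e. every 5th iteration
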
14 changes: 10 additions & 4 deletions benchmarl/experiment/logger.py
@@ -210,7 +210,10 @@ def log_evaluation(
             json_metrics["return"] = mean_group_return
         if self.json_writer is not None:
             self.json_writer.write(
-                metrics=json_metrics, total_frames=total_frames, step=step
+                metrics=json_metrics,
+                total_frames=total_frames,
+                evaluation_step=total_frames
+                // self.experiment_config.evaluation_interval,
             )
         self.log(to_log, step=step)
         if video_frames is not None:
@@ -309,20 +312,23 @@ def __init__(
             }
         }
 
-    def write(self, total_frames: int, metrics: Dict[str, List[Tensor]]):
+    def write(
+        self, total_frames: int, metrics: Dict[str, List[Tensor]], evaluation_step: int
+    ):
         """
         Writes a step into the json reporting file
         Args:
             total_frames (int): total frames collected so far in the experiment
             metrics (dictionary mapping str to tensor): each value is a 1-dim tensor for the metric in key
-               of len equal to the number of evaluation episodes for this step.
+                of len equal to the number of evaluation episodes for this step.
+            evaluation_step (int): the evaluation step
         """
         metrics = {k: val.tolist() for k, val in metrics.items()}
         step_metrics = {"step_count": total_frames}
         step_metrics.update(metrics)
-        step_str = f"step_{total_frames}"
+        step_str = f"step_{evaluation_step}"
         if step_str in self.run_data:
             self.run_data[step_str].update(step_metrics)
         else:
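The net effect of the logger change: JSON entries are keyed by a dense evaluation counter (evaluation_step = total_frames // evaluation_interval) instead of the raw frame count. A standalone sketch of the new keying, with hypothetical names and numbers:

    run_data = {}
    evaluation_interval = 60_000  # assumed config value

    def write(total_frames, metrics):
        evaluation_step = total_frames // evaluation_interval
        step_metrics = {"step_count": total_frames, **metrics}
        # Keys become "step_1", "step_2", ... instead of "step_60000", "step_120000", ...
        run_data.setdefault(f"step_{evaluation_step}", {}).update(step_metrics)

    write(60_000, {"return": [1.0, 2.0]})
    write(120_000, {"return": [1.5, 2.5]})
    print(sorted(run_data))  # ['step_1', 'step_2']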
