diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index af8d3334950..e36d87deaa8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -298,7 +298,9 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy,