diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md index 59820ad1..c2d984d6 100644 --- a/howto/eval_your_agent.md +++ b/howto/eval_your_agent.md @@ -10,6 +10,7 @@ The agent and the configs used during the training are loaded automatically. The 1. `fabric` related ones: you can use the accelerator you want for evaluating the agent, you just need to specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the GPU. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. By default, the number of devices and nodes is set to 1, while the precision and the plugins are set to the ones set in the checkpoint config. 2. `env.capture_video`: you can decide whether to capture the video of the episode during the evaluation or not. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=Ture` for capturing the video of the evaluation. +3. `seed`: the user can specify the seed used for evaluation with `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt seed=42`. By default, the seed is set to `None`. All the other parameters are loaded from the checkpoint config file used during the training. 
Moreover, the following parameters are automatically set during the evaluation: diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 4861d827..9360c578 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -178,6 +178,9 @@ def eval_algorithm(cfg: DictConfig): cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1, _convert_="all" ) + # Seed everything + fabric.seed_everything(cfg.seed) + # Load the checkpoint state = fabric.load(cfg.checkpoint_path) @@ -303,6 +306,7 @@ def evaluation(cfg: DictConfig): # Load the checkpoint configuration checkpoint_path = Path(cfg.checkpoint_path) ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent / "config.yaml") + ckpt_cfg.pop("seed", None) # Merge the two configs with open_dict(cfg): diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml index 6487d7a3..bc1b4df6 100644 --- a/sheeprl/configs/eval_config.yaml +++ b/sheeprl/configs/eval_config.yaml @@ -17,5 +17,6 @@ fabric: env: capture_video: True +seed: null disable_grads: True checkpoint_path: ??? 
diff --git a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml index 3b6967c3..8781beba 100644 --- a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml @@ -12,7 +12,7 @@ total_steps: 1000000 # Environment env: num_envs: 1 - id: reward + id: crafter_reward # Checkpoint checkpoint: diff --git a/sheeprl/envs/crafter.py b/sheeprl/envs/crafter.py index ceb37e1c..ae5a94cc 100644 --- a/sheeprl/envs/crafter.py +++ b/sheeprl/envs/crafter.py @@ -55,6 +55,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: + self.env._seed = seed obs = self.env.reset() return self._convert_obs(obs), {} diff --git a/sheeprl/envs/dmc.py b/sheeprl/envs/dmc.py index c32a7657..bf0bc3cc 100644 --- a/sheeprl/envs/dmc.py +++ b/sheeprl/envs/dmc.py @@ -120,7 +120,12 @@ def __init__( self._camera_id = camera_id self._channels_first = channels_first - # create task + # Remove random seed from task_kwargs since it will be set by the wrapper + # upon calling the reset method + task_kwargs = task_kwargs or {} + task_kwargs.pop("random", None) + + # Create task env = suite.load( domain_name=domain_name, task_name=task_name, @@ -225,6 +230,9 @@ def step( def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], Dict[str, Any]]: + if not isinstance(seed, np.random.RandomState): + seed = np.random.RandomState(seed) + self.env.task._random = seed time_step = self.env.reset() self.current_state = _flatten_obs(time_step.observation) obs = self._get_obs(time_step) diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py index 9648470b..a5fa5904 100644 --- a/sheeprl/envs/wrappers.py +++ b/sheeprl/envs/wrappers.py @@ -116,7 +116,7 @@ def reset( gym.logger.warn(f"RESET - Restarting 
env after crash with {type(e).__name__}: {e}") time.sleep(self._wait) self.env = self._env_fn() - new_obs, info = self.env.reset() + new_obs, info = self.env.reset(seed=seed, options=options) info.update({"restart_on_exception": True}) return new_obs, info