Skip to content

Commit

Permalink
Fix/deterministic evaluate (#202)
Browse files Browse the repository at this point in the history
* Evaluate functions run with torch.no_grad() by defeault

* Import user-defined evaluation file

* Set seed before the evaluation starts

* Rollback unwanted modification

* Fix setting seed for DMC envs

* Fix setting seed in Crafter env

* Fix RestartOnException call to restart on exception

* Make amend
  • Loading branch information
belerico authored Feb 5, 2024
1 parent 062b21b commit 033ad74
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 3 deletions.
1 change: 1 addition & 0 deletions howto/eval_your_agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The agent and the configs used during the training are loaded automatically. The

1. `fabric` related ones: you can use the accelerator you want for evaluating the agent, you just need to specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the GPU. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. By default, the number of devices and nodes is set to 1, while the precision and the plugins are set to the ones set in the checkpoint config.
2. `env.capture_video`: you can decide whether to capture the video of the episode during the evaluation or not. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=Ture` for capturing the video of the evaluation.
3. `seed`: the user can specify the seed used for evaluation with `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt seed=42`. By default the seed is set to `None`.

All the other parameters are loaded from the checkpoint config file used during the training. Moreover, the following parameters are automatically set during the evaluation:

Expand Down
4 changes: 4 additions & 0 deletions sheeprl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def eval_algorithm(cfg: DictConfig):
cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1, _convert_="all"
)

# Seed everything
fabric.seed_everything(cfg.seed)

# Load the checkpoint
state = fabric.load(cfg.checkpoint_path)

Expand Down Expand Up @@ -303,6 +306,7 @@ def evaluation(cfg: DictConfig):
# Load the checkpoint configuration
checkpoint_path = Path(cfg.checkpoint_path)
ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent / "config.yaml")
ckpt_cfg.pop("seed", None)

# Merge the two configs
with open_dict(cfg):
Expand Down
1 change: 1 addition & 0 deletions sheeprl/configs/eval_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ fabric:
env:
capture_video: True

seed: null
disable_grads: True
checkpoint_path: ???
2 changes: 1 addition & 1 deletion sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ total_steps: 1000000
# Environment
env:
num_envs: 1
id: reward
id: crafter_reward

# Checkpoint
checkpoint:
Expand Down
1 change: 1 addition & 0 deletions sheeprl/envs/crafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[Any, Dict[str, Any]]:
self.env._seed = seed
obs = self.env.reset()
return self._convert_obs(obs), {}

Expand Down
10 changes: 9 additions & 1 deletion sheeprl/envs/dmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ def __init__(
self._camera_id = camera_id
self._channels_first = channels_first

# create task
# Remove random seed from task_kwargs since it will be set by the wrapper
# upon calling the reset method
task_kwargs = task_kwargs or {}
task_kwargs.pop("random", None)

# Create task
env = suite.load(
domain_name=domain_name,
task_name=task_name,
Expand Down Expand Up @@ -225,6 +230,9 @@ def step(
def reset(
self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], Dict[str, Any]]:
if not isinstance(seed, np.random.RandomState):
seed = np.random.RandomState(seed)
self.env.task._random = seed
time_step = self.env.reset()
self.current_state = _flatten_obs(time_step.observation)
obs = self._get_obs(time_step)
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/envs/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def reset(
gym.logger.warn(f"RESET - Restarting env after crash with {type(e).__name__}: {e}")
time.sleep(self._wait)
self.env = self._env_fn()
new_obs, info = self.env.reset()
new_obs, info = self.env.reset(seed=seed, options=options)
info.update({"restart_on_exception": True})
return new_obs, info

Expand Down

0 comments on commit 033ad74

Please sign in to comment.