diff --git a/src/imitation/scripts/ingredients/rl.py b/src/imitation/scripts/ingredients/rl.py
index d5373c773..9a829aae4 100644
--- a/src/imitation/scripts/ingredients/rl.py
+++ b/src/imitation/scripts/ingredients/rl.py
@@ -98,6 +98,13 @@ def sac():
     locals()  # quieten flake8
 
 
+@rl_ingredient.named_config
+def dqn():
+    # Select DQN as the RL algorithm class; enabled via the "rl.dqn" named config.
+    rl_cls = sb3.DQN
+    locals()  # quieten flake8
+
+
 def _maybe_add_relabel_buffer(
     rl_kwargs: Dict[str, Any],
     relabel_reward_fn: Optional[RewardFn] = None,
diff --git a/tuning/hp_search_spaces.py b/tuning/hp_search_spaces.py
index 4f5a7ee9d..0dbf18f62 100644
--- a/tuning/hp_search_spaces.py
+++ b/tuning/hp_search_spaces.py
@@ -125,7 +125,7 @@ def __call__(
     sqil=RunSacredAsTrial(
         sacred_ex=imitation.scripts.train_imitation.train_imitation_ex,
         command_name="sqil",
-        suggest_named_configs=lambda _: [],
+        suggest_named_configs=lambda _: ["rl.dqn"],
         suggest_config_updates=lambda trial: {
             "seed": trial.number,
             "demonstrations": {
@@ -134,12 +134,8 @@ def __call__(
             },
             "sqil": {
                 "total_timesteps": 1e6,
-                "train_kwargs": {
-
-                }
             },
             "rl": {
-                "rl_cls": sb3.DQN,
                 "rl_kwargs": {
                     "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
                     "buffer_size": trial.suggest_int("buffer_size", 1000, 100000),