Skip to content

Commit

Permalink
Turn DQN into a named config.
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Feb 26, 2024
1 parent 5769fc6 commit d3860a3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/imitation/scripts/ingredients/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def sac():
locals() # quieten flake8


@rl_ingredient.named_config
def dqn():
rl_cls = sb3.DQN



def _maybe_add_relabel_buffer(
rl_kwargs: Dict[str, Any],
relabel_reward_fn: Optional[RewardFn] = None,
Expand Down
6 changes: 1 addition & 5 deletions tuning/hp_search_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __call__(
sqil=RunSacredAsTrial(
sacred_ex=imitation.scripts.train_imitation.train_imitation_ex,
command_name="sqil",
suggest_named_configs=lambda _: [],
suggest_named_configs=lambda _: ["rl.dqn"],
suggest_config_updates=lambda trial: {
"seed": trial.number,
"demonstrations": {
Expand All @@ -134,12 +134,8 @@ def __call__(
},
"sqil": {
"total_timesteps": 1e6,
"train_kwargs": {

}
},
"rl": {
"rl_cls": sb3.DQN,
"rl_kwargs": {
"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
"buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
Expand Down

0 comments on commit d3860a3

Please sign in to comment.