Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 54d9949 commit bc85b94
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

# Create optimizer
optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/discrete_iql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ loss:

compile:
compile: False
compile_mode:
compile_mode: default
cudagraphs: False
4 changes: 2 additions & 2 deletions sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def make_loss(loss_cfg, model, device):
return loss_module, target_net_updater


def make_discrete_loss(loss_cfg, model):
def make_discrete_loss(loss_cfg, model, device):
loss_module = DiscreteIQLLoss(
model[0],
model[1],
Expand All @@ -390,7 +390,7 @@ def make_discrete_loss(loss_cfg, model):
expectile=loss_cfg.expectile,
action_space="categorical",
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = HardUpdate(
loss_module, value_network_update_interval=loss_cfg.hard_update_interval
)
Expand Down

0 comments on commit bc85b94

Please sign in to comment.