From a0ee4f0071d0e969424ccec0ffdbafc6ad8bc3f2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 12:06:37 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/iql/discrete_iql.py | 2 +- sota-implementations/iql/discrete_iql.yaml | 2 +- sota-implementations/iql/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 17153a59913..e51bd25a8a8 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create loss - loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device) # Create optimizer optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 81b5a88e59f..3f53ab9a68a 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -62,5 +62,5 @@ loss: compile: compile: False - compile_mode: + compile_mode: default cudagraphs: False diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 168416d80da..04cc2b250ab 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -380,7 +380,7 @@ def make_loss(loss_cfg, model, device): return loss_module, target_net_updater -def make_discrete_loss(loss_cfg, model): +def make_discrete_loss(loss_cfg, model, device): loss_module = DiscreteIQLLoss( model[0], model[1], @@ -390,7 +390,7 @@ def make_discrete_loss(loss_cfg, model): expectile=loss_cfg.expectile, action_space="categorical", ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gammam, device=device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=loss_cfg.hard_update_interval )