From 3cf93080a11579d5bf4cf7adcf080e47b1f4618b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:28:11 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/iql/iql_online.py | 4 +++- sota-implementations/iql/utils.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 6e9c8a0d8ea..f3537ef0d47 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -103,7 +103,9 @@ def main(cfg: "DictConfig"): # noqa: F821 compile_mode = "reduce-overhead" # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode) + collector = make_collector( + cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode + ) # Create loss loss_module, target_net_updater = make_loss(cfg.loss, model) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 42ca848318a..bc643eb6149 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -10,6 +10,7 @@ import torch.optim from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Categorical from torchrl.collectors import SyncDataCollector from torchrl.data import ( @@ -36,7 +37,6 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - OneHotCategorical, ProbabilisticActor, SafeModule, TanhNormal, @@ -44,7 +44,6 @@ ) from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate from torchrl.record import VideoRecorder - from torchrl.trainers.helpers.models import ACTIVATIONS @@ -58,7 +57,11 @@ def env_maker(cfg, device="cpu", from_pixels=False): if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + categorical_action_encoding=True, ) elif lib == "dm_control": env = DMControlEnv( @@ -221,8 +224,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": action_spec.space.low.to(device), + "high": action_spec.space.high.to(device), "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, @@ -318,7 +321,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): module=actor_module, in_keys=["logits"], out_keys=["action"], - distribution_class=OneHotCategorical, + distribution_class=Categorical, distribution_kwargs={}, default_interaction_type=InteractionType.RANDOM, return_log_prob=False,