From c24bc7341f115a36c909bf9ca0050ab61367b345 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:28:12 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/iql/utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 42ca848318a..bc643eb6149 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -10,6 +10,7 @@ import torch.optim from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Categorical from torchrl.collectors import SyncDataCollector from torchrl.data import ( @@ -36,7 +37,6 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - OneHotCategorical, ProbabilisticActor, SafeModule, TanhNormal, @@ -44,7 +44,6 @@ ) from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate from torchrl.record import VideoRecorder - from torchrl.trainers.helpers.models import ACTIVATIONS @@ -58,7 +57,11 @@ def env_maker(cfg, device="cpu", from_pixels=False): if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + categorical_action_encoding=True, ) elif lib == "dm_control": env = DMControlEnv( @@ -221,8 +224,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": action_spec.space.low.to(device), + "high": action_spec.space.high.to(device), "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, @@ -318,7 +321,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): module=actor_module, in_keys=["logits"], out_keys=["action"], - distribution_class=OneHotCategorical, + distribution_class=Categorical, distribution_kwargs={}, default_interaction_type=InteractionType.RANDOM, return_log_prob=False,