From 73b5916974a8d3e3dc97ae800e0cd54c9da086cf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 09:32:48 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/a2c/utils_atari.py | 2 +- sota-implementations/a2c/utils_mujoco.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 2 +- sota-implementations/gail/gail.py | 1 - sota-implementations/impala/utils.py | 2 +- sota-implementations/ppo/utils_atari.py | 2 +- sota-implementations/ppo/utils_mujoco.py | 2 +- torchrl/data/tensor_specs.py | 2 +- 8 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 0397f7dc5f3..6ff62bbe520 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index b78f52b7eb4..5ce5ed1902d 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 532fe4e1fe9..7d8b9d6d618 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -546,7 +546,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=proof_environment.single_full_action_spec.to("cpu"), + spec=proof_environment.full_action_spec_unbatched.to("cpu"), ), ), SafeModule( diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 1075a78eba6..a02845cfe4d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -20,7 +20,6 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models -from tensordict import TensorDict from tensordict.nn import CudaGraphModule from torchrl._utils import compile_with_warmup diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 738bb83bf55..e174bc2e71c 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 755c6311729..040259377ad 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index e7eb4534c45..f2e08ffb129 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 6f214fba6de..ad29b63db04 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3388,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self._undefined_n(): + if self._undefined_n: if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). "