Skip to content

Commit

Permalink
Update
Browse files (browse the repository at this point in the history)
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 8476c5e commit 73b5916
Show file tree
Hide file tree
Showing 8 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=proof_environment.single_full_action_spec.to(device),
+ spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
- spec=proof_environment.single_full_action_spec.to(device),
+ spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def _dreamer_make_actor_real(
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=proof_environment.single_full_action_spec.to("cpu"),
+ spec=proof_environment.full_action_spec_unbatched.to("cpu"),
),
),
SafeModule(
Expand Down
1 change: 0 additions & 1 deletion sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=proof_environment.single_full_action_spec,
+ spec=proof_environment.full_action_spec_unbatched,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=proof_environment.single_full_action_spec,
+ spec=proof_environment.full_action_spec_unbatched,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
- spec=proof_environment.single_full_action_spec,
+ spec=proof_environment.full_action_spec_unbatched,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3388,7 +3388,7 @@ def set_provisional_n(self, n: int):
self._provisional_n = n

def rand(self, shape: torch.Size = None) -> torch.Tensor:
- if self._undefined_n():
+ if self._undefined_n:
if self._provisional_n is None:
raise RuntimeError(
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
Expand Down

0 comments on commit 73b5916

Please sign in to comment.