Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
2 parents 352511f + 12e27bc commit de6ebcd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def make_td3_agent(cfg, train_env, eval_env, device):
net(td)
# Exploration wrappers:
actor_model_explore = TensorDictSequential(
model[0],
actor,
AdditiveGaussianModule(
sigma_init=1,
sigma_end=1,
Expand Down
5 changes: 3 additions & 2 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class AdditiveGaussianModule(TensorDictModuleBase):
default: "action"
safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
given the :obj:`TensorSpec.project` heuristic.
default: True
default: False
device (torch.device, optional): the device where the buffers have to be stored.
.. note::
Expand All @@ -420,7 +420,8 @@ def __init__(
std: float = 1.0,
*,
action_key: Optional[NestedKey] = "action",
safe: bool = True,
# safe is already implemented because we project in the noise addition
safe: bool = False,
device: torch.device | None = None,
):
if not isinstance(sigma_init, float):
Expand Down

0 comments on commit de6ebcd

Please sign in to comment.