From 12e27bcc0dab76ed8da79d1e95d32076fe52fabb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 19:04:52 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/td3/utils.py | 2 +- torchrl/modules/tensordict_module/exploration.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 7d27f34dea7..13a234e31be 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -229,7 +229,7 @@ def make_td3_agent(cfg, train_env, eval_env, device): net(td) # Exploration wrappers: actor_model_explore = TensorDictSequential( - model[0], + actor, AdditiveGaussianModule( sigma_init=1, sigma_end=1, diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 6e8296a677a..da0c6dc3260 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -397,7 +397,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): default: "action" safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space given the :obj:`TensorSpec.project` heuristic. - default: True + default: False device (torch.device, optional): the device where the buffers have to be stored. .. note:: @@ -420,7 +420,8 @@ def __init__( std: float = 1.0, *, action_key: Optional[NestedKey] = "action", - safe: bool = True, + # safe is already implemented because we project in the noise addition + safe: bool = False, device: torch.device | None = None, ): if not isinstance(sigma_init, float):