diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index aafde363d3b..aa81d6b989e 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -9,7 +9,7 @@ from contextlib import nullcontext import torch -from tensordict.nn import TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -32,8 +32,6 @@ from torchrl.modules import ( AdditiveGaussianModule, MLP, - SafeModule, - SafeSequential, TanhModule, ValueOperator, ) @@ -199,14 +197,12 @@ def make_td3_agent(cfg, train_env, eval_env, device): ) in_keys_actor = in_keys - actor_module = SafeModule( + actor_module = TensorDictModule( actor_net, in_keys=in_keys_actor, - out_keys=[ - "param", - ], + out_keys=["param"], ) - actor = SafeSequential( + actor = TensorDictSequential( actor_module, TanhModule( in_keys=["param"],