Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 5721ce5 commit 2af4122
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions sota-implementations/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import nullcontext

import torch
from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
Expand All @@ -32,8 +32,6 @@
from torchrl.modules import (
AdditiveGaussianModule,
MLP,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
Expand Down Expand Up @@ -199,14 +197,12 @@ def make_td3_agent(cfg, train_env, eval_env, device):
)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
Expand Down

0 comments on commit 2af4122

Please sign in to comment.