Commit: amend
vmoens committed Feb 12, 2024
1 parent f5d2b83 · commit f095c01
Showing 3 changed files with 33 additions and 18 deletions.
test/conftest.py: 2 changes (1 addition, 1 deletion)
@@ -53,7 +53,7 @@ def fin():
     request.addfinalizer(fin)
 
 
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def set_warnings() -> None:
     warnings.filterwarnings(
         "ignore",
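
This scope change is substantive rather than cosmetic: pytest wraps each test item in its own warnings context, so filters installed once by a session-scoped fixture are likely discarded after the first test, while a function-scoped autouse fixture re-installs them before every test. A minimal sketch of the resulting behavior, with a hypothetical warning pattern:

# Minimal sketch (hypothetical test module). The fixture runs before
# every test because it is autouse and function-scoped (the default),
# so the filter survives pytest's per-test warnings handling.
import warnings

import pytest


@pytest.fixture(autouse=True)  # default scope is "function"
def set_warnings() -> None:
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="noisy library warning",  # hypothetical pattern
    )


def test_first():
    warnings.warn("noisy library warning", UserWarning)  # filtered


def test_second():
    warnings.warn("noisy library warning", UserWarning)  # still filtered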
torchrl/objectives/a2c.py: 25 changes (16 additions, 9 deletions)
@@ -255,7 +255,8 @@ def __init__(
 
         if functional:
             self.convert_to_functional(
-                actor_network, "actor_network",
+                actor_network,
+                "actor_network",
             )
         else:
             self.actor_network = actor_network
@@ -350,7 +351,7 @@ def in_keys(self):
             *[("next", key) for key in self.actor_network.in_keys],
         ]
         if self.critic_coef:
-            keys.extend(self.critic.in_keys)
+            keys.extend(self.critic_network.in_keys)
         return list(set(keys))
 
     @property
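
Incidentally, the return list(set(keys)) in the surrounding context is what deduplicates this list: both plain string keys and ("next", key) tuples are hashable, so set conversion is safe, though ordering is not preserved. A tiny illustration with hypothetical keys:

# Hypothetical keys: strings and ("next", key) tuples are both hashable,
# so set-based dedup works; ordering is discarded in the process.
keys = ["observation", ("next", "observation"), "observation"]
print(list(set(keys)))  # e.g. ['observation', ('next', 'observation')]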
@@ -414,11 +415,11 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         # TODO: if the advantage is gathered by forward, this introduces an
         # overhead that we could easily reduce.
         target_return = tensordict.get(self.tensor_keys.value_target)
-        tensordict_select = tensordict.select(*self.critic.in_keys)
+        tensordict_select = tensordict.select(*self.critic_network.in_keys)
         with self.critic_network_params.to_module(
-            self.critic
+            self.critic_network
         ) if self.functional else contextlib.nullcontext():
-            state_value = self.critic(
+            state_value = self.critic_network(
                 tensordict_select,
             ).get(self.tensor_keys.value)
         loss_value = distance_loss(
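
For readers unfamiliar with the pattern in this hunk: with params.to_module(module) is the tensordict library's context manager for temporarily swapping a module's parameters with the contents of a TensorDict, restoring the originals on exit, which is what makes the functional branch work. A self-contained sketch with a hypothetical value network (not TorchRL's actual critic):

# Self-contained sketch of tensordict's to_module() context manager
# (hypothetical two-layer value net, not TorchRL's actual critic).
import torch
from tensordict import TensorDict
from torch import nn

critic = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))

params = TensorDict.from_module(critic)  # parameters as a TensorDict
zeroed = params.clone().zero_()          # an alternative parameter set

x = torch.randn(8, 4)
with zeroed.to_module(critic):
    # Inside the block the module runs with the swapped parameters.
    assert critic(x).abs().sum() == 0
# On exit the original parameters are restored.
assert critic(x).abs().sum() != 0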
@@ -477,13 +478,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD1Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD0Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(value_network=self.critic, **hp)
+            self._value_estimator = GAE(value_network=self.critic_network, **hp)
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+            self._value_estimator = TDLambdaEstimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.VTrace:
             # VTrace currently does not support functional call on the actor
             if self.functional:
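
For context, this elif chain is what a user drives through make_value_estimator. A hedged usage sketch follows; the module shapes and hyperparameter values are illustrative, not library defaults:

# Hedged usage sketch of the dispatch above (hypothetical 4-dim
# observation, 2-dim continuous action; values are illustrative).
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss, ValueEstimators

actor_backbone = TensorDictModule(
    nn.Sequential(nn.Linear(4, 4), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    actor_backbone,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])

loss_module = A2CLoss(actor, critic)
# Routes through the elif chain above: GAE is selected and the critic
# network is bound to the estimator.
loss_module.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)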
torchrl/objectives/reinforce.py: 24 changes (16 additions, 8 deletions)
@@ -351,7 +351,7 @@ def _set_in_keys(self):
             ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
-            *self.critic.in_keys,
+            *self.critic_network.in_keys,
         ]
         self._in_keys = list(set(keys))
 
@@ -398,11 +398,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         try:
             target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(*self.critic.in_keys)
+            tensordict_select = tensordict.select(*self.critic_network.in_keys)
             with self.critic_network_params.to_module(
-                self.critic
+                self.critic_network
             ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic(tensordict_select).get(self.tensor_keys.value)
+                state_value = self.critic_network(tensordict_select).get(
+                    self.tensor_keys.value
+                )
             loss_value = distance_loss(
                 target_return,
                 state_value,
@@ -427,13 +429,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             hp["gamma"] = self.gamma
         hp.update(hyperparams)
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD1Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD0Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(value_network=self.critic, **hp)
+            self._value_estimator = GAE(value_network=self.critic_network, **hp)
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+            self._value_estimator = TDLambdaEstimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.VTrace:
             # VTrace currently does not support functional call on the actor
             if self.functional:
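
Taken together, the two files complete the migration from self.critic to self.critic_network, matching the critic_network_params name the old code already referenced. A rename like this is typically shipped with a backward-compatible alias; the sketch below shows what such a shim could look like, and is hypothetical rather than TorchRL's actual code:

# Hypothetical backward-compatibility shim for the rename; illustrative
# only, not necessarily what TorchRL ships.
import warnings


class ReinforceLoss:  # stand-in for the real loss class
    def __init__(self, critic_network):
        self.critic_network = critic_network

    @property
    def critic(self):
        warnings.warn(
            "'critic' is deprecated; use 'critic_network' instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        return self.critic_network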
