Skip to content

Commit

Permalink
Fix some more type errors from D52890934 and D65753120
Browse files (browse the repository at this point in the history)
Reviewed By: aakhundov

Differential Revision: D66344734

fbshipit-source-id: 8546045798d0bac87a17c9c179745fac399de193
  • Loading branch information
ezyang authored and facebook-github-bot committed Nov 22, 2024
1 parent 44eafb6 commit 47394cf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -128,11 +128,15 @@ def __init__(
# actor network takes state as input and outputs an action vector
self._actor: nn.Module = actor_network_type(
input_dim=(
# pyre-fixme[58]: `+` is not supported for operand types `int`
# and `Union[Module, Tensor]`.
state_dim + self.action_representation_module.representation_dim
if issubclass(actor_network_type, DynamicActionActorNetwork)
else state_dim
),
hidden_dims=actor_hidden_dims,
# pyre-fixme[6]: For 3rd argument expected `int` but got `Union[int,
# Module, Tensor]`.
output_dim=(
1
if issubclass(actor_network_type, DynamicActionActorNetwork)
Expand Down Expand Up @@ -175,7 +179,11 @@ def __init__(
parameter critic_network_instance has not been provided."

self._critic: nn.Module = make_critic(
# pyre-fixme[6]: For 1st argument expected `int` but got
# `Optional[int]`.
state_dim=self._state_dim,
# pyre-fixme[6]: For 2nd argument expected `Optional[int]` but
# got `Union[Module, Tensor]`.
action_dim=self.action_representation_module.representation_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
Expand Down Expand Up @@ -203,6 +211,8 @@ def set_history_summarization_module(
"""
The history summarization module uses its own optimizer.
"""
# pyre-fixme[16]: `ActorCriticBase` has no attribute
# `_history_summarization_optimizer`.
self._history_summarization_optimizer: optim.Optimizer = optim.AdamW(
[
{
Expand Down Expand Up @@ -302,6 +312,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
Dict[str, Any]: A dictionary containing the loss reports from the critic
and actor updates. These can be useful to track for debugging purposes.
"""
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `zero_grad`.
self._history_summarization_optimizer.zero_grad()
actor_loss = self._actor_loss(batch)
self._actor_optimizer.zero_grad()
Expand All @@ -321,6 +333,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
critic_loss.backward()
self._critic_optimizer.step()
report["critic_loss"] = critic_loss.item()
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `step`.
self._history_summarization_optimizer.step()

if self._use_critic_target:
Expand Down Expand Up @@ -348,6 +361,8 @@ def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
# attribute `lambda_constraint`.
batch.reward
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
# attribute `lambda_constraint`.
- self.safety_module.lambda_constraint * batch.cost
)
batch = super().preprocess_batch(batch)
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -157,6 +157,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
value_loss = self._value_loss(batch)
critic_loss = self._critic_loss(batch)
actor_loss = self._actor_loss(batch)
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `zero_grad`.
self._history_summarization_optimizer.zero_grad()
self._value_network_optimizer.zero_grad()
self._actor_optimizer.zero_grad()
Expand All @@ -166,6 +168,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._value_network_optimizer.step()
self._actor_optimizer.step()
self._critic_optimizer.step()
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `step`.
self._history_summarization_optimizer.step()
# update critic and target Twin networks;
update_target_networks(
Expand Down
3 changes: 3 additions & 0 deletions pearl/policy_learners/sequential_decision_making/td3.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -104,6 +104,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._critic_update_count += 1
report = {}
# delayed actor update
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `zero_grad`.
self._history_summarization_optimizer.zero_grad()
if self._critic_update_count % self._actor_update_freq == 0:
self._actor_optimizer.zero_grad()
Expand All @@ -118,6 +120,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
critic_loss.backward()
self._critic_optimizer.step()
report["critic_loss"] = critic_loss.item()
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `step`.
self._history_summarization_optimizer.step()

if self._critic_update_count % self._actor_update_freq == 0:
Expand Down

0 comments on commit 47394cf

Please sign in to comment.