Make csac's autotune more efficient
Summary: Use caching for autotune: cache the action log-probabilities that `_actor_loss` already computes and reuse them in `learn_batch` when updating the entropy coefficient, instead of re-sampling actions from the actor a second time.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65922397

fbshipit-source-id: c980afcce1219aeed393b512c1ad59f877010d2f
yiwan-rl authored and facebook-github-bot committed Dec 7, 2024
1 parent 7f022a9 commit 3ac6732
Showing 1 changed file with 3 additions and 8 deletions.
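
For background, the quantity being updated here is the standard SAC temperature objective (Haarnoja et al., 2018), which the code parameterizes through alpha = exp(log alpha) to keep it positive:

    J(\alpha) = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ -\alpha \,\bigl(\log \pi(a \mid s) + \bar{\mathcal{H}}\bigr) \right]

where \bar{\mathcal{H}} is the target entropy (`self._target_entropy`). Gradients should flow only into log alpha; the log-probability term is held constant, which the `.detach()` in the new code makes explicit now that the cached log-probs are no longer produced under `torch.no_grad()`.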
@@ -120,23 +120,17 @@ def __init__(
             )
         else:
             self.register_buffer("_entropy_coef", torch.tensor(entropy_coef))
+        self._action_batch_log_prob_cache: torch.Tensor = torch.tensor(0.0)
 
     def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
         actor_critic_loss = super().learn_batch(batch)
-        state_batch = batch.state  # shape: (batch_size x state_dim)
 
         if self._entropy_autotune:
-            with torch.no_grad():
-                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-                _, action_batch_log_prob = self._actor.sample_action(
-                    state_batch, get_log_prob=True
-                )
-
             entropy_optimizer_loss = (
                 # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                 # `Union[Module, Tensor]`.
                 -torch.exp(self._log_entropy)
-                * (action_batch_log_prob + self._target_entropy)
+                * (self._action_batch_log_prob_cache + self._target_entropy).detach()
             ).mean()
 
             self._entropy_optimizer.zero_grad()
@@ -214,6 +208,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
         ) = self._actor.sample_action(state_batch, get_log_prob=True)
 
+        self._action_batch_log_prob_cache = action_batch_log_prob
        # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
        q1, q2 = self._critic.get_q_values(
            state_batch=state_batch, action_batch=action_batch
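
The change amounts to this: `_actor_loss` already samples actions and their log-probabilities for the actor update, so `learn_batch` can reuse that tensor for the entropy-coefficient update instead of running a second forward pass through the actor. A minimal self-contained sketch of the pattern, with a hypothetical `CachedEntropyAutotune` class standing in for Pearl's actual policy-learner wiring:

import torch

class CachedEntropyAutotune:
    """Sketch of the cached entropy-coefficient update (not Pearl's API)."""

    def __init__(self, action_dim: int) -> None:
        self._log_entropy = torch.zeros(1, requires_grad=True)
        self._target_entropy = -float(action_dim)  # common SAC heuristic
        self._entropy_optimizer = torch.optim.Adam([self._log_entropy], lr=3e-4)
        # Filled by the actor update, consumed by the entropy update.
        self._action_batch_log_prob_cache: torch.Tensor = torch.tensor(0.0)

    def actor_update(self, action_batch_log_prob: torch.Tensor) -> None:
        # The actor loss needs log pi(a|s) anyway; cache it here instead of
        # re-sampling actions later just for the entropy update.
        self._action_batch_log_prob_cache = action_batch_log_prob

    def entropy_update(self) -> float:
        # Same objective as the diff: -exp(log alpha) * (log pi + H_target).
        # .detach() keeps gradients out of the actor's (already used) graph.
        loss = (
            -torch.exp(self._log_entropy)
            * (self._action_batch_log_prob_cache + self._target_entropy).detach()
        ).mean()
        self._entropy_optimizer.zero_grad()
        loss.backward()
        self._entropy_optimizer.step()
        return loss.item()

# Usage: the actor step runs first in learn_batch, so the cache is fresh.
tuner = CachedEntropyAutotune(action_dim=6)
tuner.actor_update(torch.randn(32))  # stand-in for log pi(a|s) over a batch
tuner.entropy_update()

The `.detach()` matters more after this change: the cached log-probs now carry the actor's autograd graph (previously they were recomputed under `torch.no_grad()`), so detaching prevents the entropy loss from backpropagating into, or re-traversing, a graph the actor update has already consumed.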
