Skip to content

Commit

Permalink
replace _critic_update_count with _training_steps in td3
Browse files Browse the repository at this point in the history
Summary: _critic_update_count is not neccessary given that we already have _training_steps.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65922517

fbshipit-source-id: 944fb8150ae65a230d8bda068cbe8bb3c3ef7786
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 7, 2024
1 parent 3ac6732 commit 3e52c4a
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions pearl/policy_learners/sequential_decision_making/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,19 @@ def __init__(
self._actor_update_freq = actor_update_freq
self._actor_update_noise = actor_update_noise
self._actor_update_noise_clip = actor_update_noise_clip
self._critic_update_count = 0
self._last_actor_loss: float = 0.0

def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
# The actor and the critic updates are arranged in the following way
# for the same reason as in the comment "If the history summarization module ..."
# in the learn_batch function in actor_critic_base.py.

self._critic_update_count += 1
report = {}
# delayed actor update
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `zero_grad`.
self._history_summarization_optimizer.zero_grad()
if self._critic_update_count % self._actor_update_freq == 0:
if self._training_steps % self._actor_update_freq == 0:
self._actor_optimizer.zero_grad()
actor_loss = self._actor_loss(batch)
actor_loss.backward(retain_graph=True)
Expand All @@ -123,7 +121,7 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `step`.
self._history_summarization_optimizer.step()

if self._critic_update_count % self._actor_update_freq == 0:
if self._training_steps % self._actor_update_freq == 0:
# update targets of critics using soft updates
update_critic_target_network(
self._critic_target,
Expand Down

0 comments on commit 3e52c4a

Please sign in to comment.