From df96edc1fbd5e6cbeab031cb58d70b7359354d0f Mon Sep 17 00:00:00 2001
From: Rodrigo de Salvo Braz
Date: Thu, 22 Feb 2024 00:41:29 -0800
Subject: [PATCH] Fix bug (TD3 not returning error)

Summary: TD3's learn_batch() was returning an empty dict instead of reporting its actor and critic losses. This diff corrects that by returning the loss values.

Reviewed By: yiwan-rl

Differential Revision: D54049188

fbshipit-source-id: 2a20a162ac99c354f0b7dbc0f5fed7ebbb2c7a8d
---
 pearl/policy_learners/sequential_decision_making/td3.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py
index 45186679..c67d3e13 100644
--- a/pearl/policy_learners/sequential_decision_making/td3.py
+++ b/pearl/policy_learners/sequential_decision_making/td3.py
@@ -89,6 +89,8 @@ def __init__(
         self._critic_update_count = 0
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
+        # TODO: this method is very similar to that of ActorCriticBase.
+        # Can we refactor?
         critic_loss = self._critic_loss(batch)  # critic update
 
         self._critic_update_count += 1
@@ -102,9 +104,14 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             (actor_loss + critic_loss).backward()
             self._actor_optimizer.step()
             self._critic_optimizer.step()
+            report = {
+                "actor_loss": actor_loss.item(),
+                "critic_loss": critic_loss.item(),
+            }
         else:
             critic_loss.backward()
             self._critic_optimizer.step()
+            report = {"critic_loss": critic_loss.item()}
 
         if self._critic_update_count % self._actor_update_freq == 0:
             # update targets of critics using soft updates
@@ -119,7 +126,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
                 self._actor_target, self._actor, self._actor_soft_update_tau
             )
 
-        return {}
+        return report
 
     def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         with torch.no_grad():
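
Usage note: a minimal sketch of how a caller might consume the report that learn_batch() now returns. Here `policy_learner` and `batch` are hypothetical stand-ins for a configured TD3 instance and a sampled TransitionBatch; only the learn_batch() call and the report keys come from the patch above.

    from typing import Any, Dict

    def log_learn_step(policy_learner, batch, step: int) -> None:
        # learn_batch() now returns loss metrics instead of an empty dict.
        report: Dict[str, Any] = policy_learner.learn_batch(batch)
        # "critic_loss" is reported on every step; "actor_loss" only on
        # steps where the delayed actor update fires.
        msg = f"step {step}: critic_loss={report['critic_loss']:.4f}"
        if "actor_loss" in report:
            msg += f", actor_loss={report['actor_loss']:.4f}"
        print(msg)

Note that on steps without an actor update the report deliberately omits "actor_loss", so consumers should not assume that key is always present.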