From c1896909eeccfba8c97c9a0d4f6cde1526acd858 Mon Sep 17 00:00:00 2001
From: Yi Wan
Date: Fri, 23 Feb 2024 18:14:13 -0800
Subject: [PATCH] fix a problem in some actor-critic methods

Summary:
In my previous diff that fixes the incompatibility of actor-critic methods
and LSTM (D52946247), I resolved the double backward issue by summing the
actor and critic losses. I then computed the gradient w.r.t. this sum by
calling backward(). However, for DDPG and TD3 this approach is problematic:
for these two algorithms, the partial derivative of the actor loss w.r.t.
the critic parameters should not be used to update the critic. This diff
addresses this issue.

Reviewed By: rodrigodesalvobraz

Differential Revision: D54048016

fbshipit-source-id: 1a7e0a11fd6215263b13f5ca469125e96ab70f89
---
 .../actor_critic_base.py                      | 21 ++++++++++---
 .../implicit_q_learning.py                    |  2 --
 .../sequential_decision_making/td3.py         | 31 +++++++++----------
 3 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
index 2154501c..a26d4d04 100644
--- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
+++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -182,8 +182,6 @@ def set_history_summarization_module(
         self, value: HistorySummarizationModule
     ) -> None:
         self._actor_optimizer.add_param_group({"params": value.parameters()})
-        if self._use_critic:
-            self._critic_optimizer.add_param_group({"params": value.parameters()})
         self._history_summarization_module = value
 
     def act(
@@ -272,10 +270,24 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         """
         actor_loss = self._actor_loss(batch)
         self._actor_optimizer.zero_grad()
+        """
+        If the history summarization module is a neural network,
+        the computation graph of this neural network is used
+        to obtain both actor and critic losses.
+        Without retain_graph=True, after actor_loss.backward(), the computation graph is cleared.
+        After the graph is cleared, critic_loss.backward() fails.
+        """
+        actor_loss.backward(retain_graph=True)
         if self._use_critic:
-            critic_loss = self._critic_loss(batch)
             self._critic_optimizer.zero_grad()
-            (actor_loss + critic_loss).backward()
+            critic_loss = self._critic_loss(batch)
+            """
+            This backward operation needs to happen before actor_optimizer.step().
+            This is because actor_optimizer.step() updates the history summarization
+            neural network, and critic_loss.backward() fails
+            once parameters involved in critic_loss's computational graph change.
+            """
+            critic_loss.backward()
             self._actor_optimizer.step()
             self._critic_optimizer.step()
             report = {
@@ -283,7 +295,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
                 "critic_loss": critic_loss.item(),
             }
         else:
-            actor_loss.backward()
             self._actor_optimizer.step()
             report = {"actor_loss": actor_loss.item()}
 
diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index 30d732e7..6288dda9 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -149,8 +149,6 @@ def set_history_summarization_module(
         self, value: HistorySummarizationModule
     ) -> None:
         self._actor_optimizer.add_param_group({"params": value.parameters()})
-        self._critic_optimizer.add_param_group({"params": value.parameters()})
-        self._value_network_optimizer.add_param_group({"params": value.parameters()})
         self._history_summarization_module = value
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py
index c67d3e13..29a0c6c4 100644
--- a/pearl/policy_learners/sequential_decision_making/td3.py
+++ b/pearl/policy_learners/sequential_decision_making/td3.py
@@ -89,29 +89,26 @@ def __init__(
         self._critic_update_count = 0
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
-        # TODO: this method is very similar to that of ActorCriticBase.
-        # Can we refactor?
+        # The actor and the critic updates are arranged in the following way
+        # for the same reason as in the comment "If the history summarization module ..."
+        # in the learn_batch function in actor_critic_base.py.
 
-        critic_loss = self._critic_loss(batch)  # critic update
         self._critic_update_count += 1
-
+        report = {}
         # delayed actor update
-        self._critic_optimizer.zero_grad()
+        self._actor_optimizer.zero_grad()
         if self._critic_update_count % self._actor_update_freq == 0:
             # see ddpg base class for actor update details
             actor_loss = self._actor_loss(batch)
-            self._actor_optimizer.zero_grad()
-            (actor_loss + critic_loss).backward()
-            self._actor_optimizer.step()
-            self._critic_optimizer.step()
-            report = {
-                "actor_loss": actor_loss.item(),
-                "critic_loss": critic_loss.item(),
-            }
-        else:
-            critic_loss.backward()
-            self._critic_optimizer.step()
-            report = {"critic_loss": critic_loss.item()}
+            actor_loss.backward(retain_graph=True)
+            report["actor_loss"] = actor_loss.item()
+
+        self._critic_optimizer.zero_grad()
+        critic_loss = self._critic_loss(batch)  # critic update
+        critic_loss.backward()
+        self._actor_optimizer.step()
+        self._critic_optimizer.step()
+        report["critic_loss"] = critic_loss.item()
 
         if self._critic_update_count % self._actor_update_freq == 0:
             # update targets of critics using soft updates