fix a problem in some actor-critic methods
Summary: In my previous diff, which fixed the incompatibility between actor-critic methods and LSTM (D52946247), I resolved the double-backward issue by summing the actor and critic losses and computing the gradient of that sum with a single backward() call. For DDPG and TD3, however, this approach is problematic: the partial derivative of the actor loss w.r.t. the critic parameters should not be used to update the critic. This diff addresses that issue.
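For concreteness, here is a minimal PyTorch sketch of the problem (not Pearl code; the networks, shapes, and losses are made up): a DDPG/TD3-style actor loss -Q(s, pi(s)) backpropagates into the critic's parameters, so a critic optimizer step taken after a summed backward would apply that unwanted gradient, whereas separate backward passes with a zero_grad in between do not.

import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Linear(4, 1)   # stand-in for a Q-network
actor = nn.Linear(4, 4)    # stand-in for a policy network
state = torch.randn(8, 4)

# DDPG/TD3-style actor loss: -Q(s, pi(s)); its graph runs through the critic.
actor_loss = -critic(actor(state)).mean()
# Placeholder critic loss (normally a TD error); here it depends only on the critic.
critic_loss = (critic(state) - torch.randn(8, 1)).pow(2).mean()

# Summed backward: d(actor_loss)/d(critic params) is accumulated into
# critic.weight.grad, so a critic optimizer step would apply it.
(actor_loss + critic_loss).backward(retain_graph=True)
summed_grad = critic.weight.grad.clone()

# Separate backward passes: zero the critic's grads after the actor backward,
# so only d(critic_loss)/d(critic params) reaches the critic update.
actor.zero_grad()
critic.zero_grad()
actor_loss.backward()   # gradients used only for the actor update
critic.zero_grad()      # discard the actor-loss gradient left on the critic
critic_loss.backward()
separate_grad = critic.weight.grad.clone()

print(torch.allclose(summed_grad, separate_grad))  # False: the two critic updates differ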

Reviewed By: rodrigodesalvobraz

Differential Revision: D54048016

fbshipit-source-id: 1a7e0a11fd6215263b13f5ca469125e96ab70f89
yiwan-rl authored and facebook-github-bot committed Feb 24, 2024
1 parent af330a6 commit c189690
Showing 3 changed files with 30 additions and 24 deletions.
@@ -182,8 +182,6 @@ def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
if self._use_critic:
self._critic_optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value

def act(
@@ -272,18 +270,31 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"""
actor_loss = self._actor_loss(batch)
self._actor_optimizer.zero_grad()
"""
If the history summarization module is a neural network,
the computation graph of this neural network is used
to obtain both actor and critic losses.
Without retain_graph=True, after actor_loss.backward(), the computation graph is cleared.
After the graph is cleared, critic_loss.backward() fails.
"""
actor_loss.backward(retain_graph=True)
if self._use_critic:
critic_loss = self._critic_loss(batch)
self._critic_optimizer.zero_grad()
(actor_loss + critic_loss).backward()
critic_loss = self._critic_loss(batch)
"""
This backward operation needs to happen before the actor_optimizer.step().
This is because actor_optimizer.step() updates the history summarization neural network
and critic_loss.backward() fails
once parameters involved in critic_loss's computational graph change.
"""
critic_loss.backward()
self._actor_optimizer.step()
self._critic_optimizer.step()
report = {
"actor_loss": actor_loss.item(),
"critic_loss": critic_loss.item(),
}
else:
actor_loss.backward()
self._actor_optimizer.step()
report = {"actor_loss": actor_loss.item()}

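The two comment blocks in the hunk above describe constraints that can be reproduced outside Pearl. Below is a standalone sketch (all names are illustrative; encoder plays the role of the history summarization module): the first backward needs retain_graph=True because both losses share the encoder's graph, and the critic backward must run before the actor optimizer's step because that step modifies encoder parameters saved in the critic loss's graph.

import torch
import torch.nn as nn

encoder = nn.Linear(6, 4)   # stand-in for the history summarization module
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

actor_opt = torch.optim.Adam(list(actor.parameters()) + list(encoder.parameters()))
critic_opt = torch.optim.Adam(critic.parameters())

obs = torch.randn(8, 6)
h = encoder(obs)                         # shared computation graph for both losses
actor_loss = actor(h).pow(2).mean()
critic_loss = (critic(h) - 1.0).pow(2).mean()

actor_opt.zero_grad()
# Without retain_graph=True this backward would free the shared graph through
# `encoder`, and critic_loss.backward() below would raise a RuntimeError.
actor_loss.backward(retain_graph=True)

critic_opt.zero_grad()
# The critic backward must come before actor_opt.step(): stepping updates
# `encoder` in place, and backpropagating through a graph whose saved tensors
# were modified raises a RuntimeError.
critic_loss.backward()

actor_opt.step()
critic_opt.step()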
@@ -149,8 +149,6 @@ def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
self._critic_optimizer.add_param_group({"params": value.parameters()})
self._value_network_optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
pearl/policy_learners/sequential_decision_making/td3.py (31 changes: 14 additions & 17 deletions)
@@ -89,29 +89,26 @@ def __init__(
self._critic_update_count = 0

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
# TODO: this method is very similar to that of ActorCriticBase.
# Can we refactor?
# The actor and the critic updates are arranged in the following way
# for the same reason as in the comment "If the history summarization module ..."
# in the learn_batch function in actor_critic_base.py.

critic_loss = self._critic_loss(batch) # critic update
self._critic_update_count += 1

report = {}
# delayed actor update
self._critic_optimizer.zero_grad()
self._actor_optimizer.zero_grad()
if self._critic_update_count % self._actor_update_freq == 0:
# see ddpg base class for actor update details
actor_loss = self._actor_loss(batch)
self._actor_optimizer.zero_grad()
(actor_loss + critic_loss).backward()
self._actor_optimizer.step()
self._critic_optimizer.step()
report = {
"actor_loss": actor_loss.item(),
"critic_loss": critic_loss.item(),
}
else:
critic_loss.backward()
self._critic_optimizer.step()
report = {"critic_loss": critic_loss.item()}
actor_loss.backward(retain_graph=True)
report["actor_loss"] = actor_loss.item()

self._critic_optimizer.zero_grad()
critic_loss = self._critic_loss(batch) # critic update
critic_loss.backward()
self._actor_optimizer.step()
self._critic_optimizer.step()
report["critic_loss"] = critic_loss.item()

if self._critic_update_count % self._actor_update_freq == 0:
# update targets of critics using soft updates
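Read as ordinary code rather than a flattened diff, the update order this hunk arranges appears to be roughly the following (a sketch only, not the actual td3.py; attribute names follow the diff, and the target-network soft updates that come afterwards are omitted):

from typing import Any, Dict

def learn_batch(self, batch: Any) -> Dict[str, Any]:
    self._critic_update_count += 1
    report: Dict[str, Any] = {}
    self._actor_optimizer.zero_grad()
    # Delayed actor update: only every `_actor_update_freq`-th call touches the actor.
    if self._critic_update_count % self._actor_update_freq == 0:
        actor_loss = self._actor_loss(batch)
        # retain_graph=True keeps the shared history-summarization graph alive
        # for the critic backward below.
        actor_loss.backward(retain_graph=True)
        report["actor_loss"] = actor_loss.item()

    # Zeroing the critic grads here discards the actor-loss gradient that the
    # backward above deposited on the critic parameters.
    self._critic_optimizer.zero_grad()
    critic_loss = self._critic_loss(batch)
    critic_loss.backward()
    # Both steps come after both backwards, so no graph is invalidated in between.
    self._actor_optimizer.step()
    self._critic_optimizer.step()
    report["critic_loss"] = critic_loss.item()
    return report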
