Fixes for APS integration
Summary:
Minor modifications for APS training
1) Log the actor loss at each step (otherwise we get a TensorBoard error).
2) Removed the assertion that the actor type is GaussianActorNetwork or VanillaContinuousActorNetwork. This allows us to use customized actors that may not be of these types (generally, the user should be able to specify the actor rather than us forcing a particular type). Note: we don't have such an assertion in the other actor-critic methods, so it is better to remove it for now.

Reviewed By: danielrjiang

Differential Revision: D58140074

fbshipit-source-id: 02cf789b7581275328c9c8a5fd3a1da033308c20
Yonathan Efroni authored and facebook-github-bot committed Jun 5, 2024
1 parent 75324b2 commit 88a41f3
Showing 2 changed files with 3 additions and 11 deletions.
@@ -125,16 +125,6 @@ def __init__(
         self._expectile = expectile
         self._is_action_continuous: bool = action_space.is_continuous
 
-        # TODO: create actor network interfaces for discrete and continuous actor networks
-        # and use the continuous one in this test.
-        if self._is_action_continuous:
-            torch._assert(
-                actor_network_type == GaussianActorNetwork
-                or actor_network_type == VanillaContinuousActorNetwork,
-                "continuous action space requires a deterministic or a stochastic actor which works"
-                "with continuous action spaces",
-            )
-
         self._temperature_advantage_weighted_regression = (
             temperature_advantage_weighted_regression
        )
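For context, a minimal sketch of the kind of customized actor this change enables. The class name, constructor arguments, and deterministic forward pass below are illustrative assumptions, not part of Pearl's API; the point is only that, with the assertion removed, an actor no longer has to be a GaussianActorNetwork or VanillaContinuousActorNetwork.

import torch
import torch.nn as nn

class MyCustomContinuousActor(nn.Module):
    """Hypothetical user-defined deterministic actor for a continuous action space."""

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256) -> None:
        super().__init__()
        self._net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # assumes actions are normalized to [-1, 1]
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # Deterministic action for a batch of states.
        return self._net(state)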
4 changes: 3 additions & 1 deletion pearl/policy_learners/sequential_decision_making/td3.py
@@ -94,6 +94,7 @@ def __init__(
         self._actor_update_noise = actor_update_noise
         self._actor_update_noise_clip = actor_update_noise_clip
         self._critic_update_count = 0
+        self._last_actor_loss: float = 0.0
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         # The actor and the critic updates are arranged in the following way
@@ -108,7 +109,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             # see ddpg base class for actor update details
             actor_loss = self._actor_loss(batch)
             actor_loss.backward(retain_graph=True)
-            report["actor_loss"] = actor_loss.item()
+            self._last_actor_loss = actor_loss.item()
+            report["actor_loss"] = self._last_actor_loss
 
         self._critic_optimizer.zero_grad()
         critic_loss = self._critic_loss(batch)  # critic update
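The pattern behind the td3.py change, as a standalone sketch (the class and helper names here are hypothetical, not Pearl's code): cache the most recent actor loss so that every learning step can report an actor_loss value, even on steps where the delayed actor update is skipped, which keeps per-step TensorBoard logging from hitting a missing scalar.

from typing import Any, Callable, Dict

class DelayedActorLossLogger:
    """Hypothetical sketch: report a cached actor loss on every learning step."""

    def __init__(
        self,
        actor_loss_fn: Callable[[Any], float],
        critic_loss_fn: Callable[[Any], float],
        actor_update_freq: int = 2,
    ) -> None:
        self._actor_loss_fn = actor_loss_fn
        self._critic_loss_fn = critic_loss_fn
        self._actor_update_freq = actor_update_freq
        self._critic_update_count = 0
        self._last_actor_loss: float = 0.0  # reported on steps without an actor update

    def learn_batch(self, batch: Any) -> Dict[str, Any]:
        report: Dict[str, Any] = {}
        self._critic_update_count += 1
        # Delayed actor update: only refresh the actor loss every few critic updates.
        if self._critic_update_count % self._actor_update_freq == 0:
            self._last_actor_loss = self._actor_loss_fn(batch)
        # Always report the cached value so the "actor_loss" key exists at every step.
        report["actor_loss"] = self._last_actor_loss
        report["critic_loss"] = self._critic_loss_fn(batch)
        return report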
