TD3BC
Summary: Offline RL: TD3BC implementation, based on the TD3 algorithm.

Reviewed By: jb3618columbia

Differential Revision: D55509906

fbshipit-source-id: b06dd5654e116ec7fa39c0a2cc53f7ee8fce63d9
Yonathan Efroni authored and facebook-github-bot committed Apr 9, 2024
1 parent 9da5168 commit c3a485d
Showing 1 changed file with 78 additions and 0 deletions: pearl/policy_learners/sequential_decision_making/td3.py
@@ -185,3 +185,81 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
            critic=self._critic,
        )
        return loss


class TD3BC(TD3):
    """
    Implementation of the TD3BC algorithm, in which a behaviour cloning term is
    added to the actor loss. The actor loss is implemented similarly to
    https://arxiv.org/pdf/2106.06860.pdf: the actor minimizes
    MSE(pi(s), a) - lmbda * Q(s, pi(s)), where a is the behavior policy's
    action and lmbda = alpha_bc / mean(|Q|) normalizes the Q term.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        behavior_policy: torch.nn.Module,
        actor_hidden_dims: Optional[List[int]] = None,
        critic_hidden_dims: Optional[List[int]] = None,
        exploration_module: Optional[ExplorationModule] = None,
        actor_learning_rate: float = 1e-3,
        critic_learning_rate: float = 1e-3,
        actor_network_type: Type[ActorNetwork] = VanillaContinuousActorNetwork,
        critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        actor_soft_update_tau: float = 0.005,
        critic_soft_update_tau: float = 0.005,
        discount_factor: float = 0.99,
        training_rounds: int = 1,
        batch_size: int = 256,
        actor_update_freq: int = 2,
        actor_update_noise: float = 0.2,
        actor_update_noise_clip: float = 0.5,
        action_representation_module: Optional[ActionRepresentationModule] = None,
        actor_network_instance: Optional[ActorNetwork] = None,
        critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None,
        alpha_bc: float = 2.5,
    ) -> None:
        super(TD3BC, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            exploration_module=exploration_module,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            critic_network_type=critic_network_type,
            actor_soft_update_tau=actor_soft_update_tau,
            critic_soft_update_tau=critic_soft_update_tau,
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            actor_update_freq=actor_update_freq,
            actor_update_noise=actor_update_noise,
            actor_update_noise_clip=actor_update_noise_clip,
            action_representation_module=action_representation_module,
            actor_network_instance=actor_network_instance,
            critic_network_instance=critic_network_instance,
        )
        self.alpha_bc: float = alpha_bc
        self._behavior_policy: torch.nn.Module = behavior_policy

    def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
        # sample a batch of actions from the actor network; shape (batch_size, action_dim)
        action_batch = self._actor.sample_action(batch.state)

        # q values for (batch.state, action_batch) from the first of the twin critics
        q, _ = self._critic.get_q_values(
            state_batch=batch.state, action_batch=action_batch
        )

        # behaviour cloning loss term: mean squared error between the actor's
        # actions and those of the behavior policy that generated the data
        with torch.no_grad():
            behaviour_action_batch = self._behavior_policy(batch.state)
        lmbda = self.alpha_bc / q.abs().mean().detach()
        behavior_loss_mse = ((action_batch - behaviour_action_batch).pow(2)).mean()

        # optimization objective: maximize the normalized Q-value while keeping
        # the actor's actions close to those of the behavior policy
        loss = behavior_loss_mse - lmbda * q.mean()
        return loss
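For context, here is a self-contained sketch of the actor objective above, computed with toy linear networks and plain PyTorch. The names `toy_actor`, `toy_critic`, and `toy_behavior` are hypothetical stand-ins for Pearl's actor, one of the twin critics, and the dataset's behavior policy; they are not part of the library's API.

import torch
import torch.nn as nn

state_dim, action_dim, batch_size, alpha_bc = 4, 2, 256, 2.5

toy_actor = nn.Linear(state_dim, action_dim)       # stand-in for self._actor
toy_critic = nn.Linear(state_dim + action_dim, 1)  # stand-in for one twin critic
toy_behavior = nn.Linear(state_dim, action_dim)    # stand-in for self._behavior_policy

# freeze the critic, as TD3 does during the actor update
for p in toy_critic.parameters():
    p.requires_grad_(False)

state = torch.randn(batch_size, state_dim)
action = toy_actor(state)                           # pi(s)
q = toy_critic(torch.cat([state, action], dim=-1))  # Q(s, pi(s))

with torch.no_grad():
    behaviour_action = toy_behavior(state)          # behavior policy's action

# lmbda rescales the Q term so that alpha_bc sets the RL/BC trade-off
# regardless of the scale of the Q-values
lmbda = alpha_bc / q.abs().mean().detach()
bc_mse = (action - behaviour_action).pow(2).mean()

loss = bc_mse - lmbda * q.mean()  # same form as _actor_loss above
loss.backward()                   # gradients reach only toy_actor

Minimizing this loss maximizes lmbda * Q(s, pi(s)) while penalizing deviation from the logged actions; alpha_bc = 2.5 is the default recommended in the TD3+BC paper linked in the docstring.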
