diff --git a/test/test_cost.py b/test/test_cost.py
index c48b4a28b99..1f191e41db6 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -4493,6 +4493,7 @@ def test_sac_terminating(
             actor_network=actor,
             qvalue_network=qvalue,
             value_network=value,
+            skip_done_states=True,
         )
         loss.set_keys(
             action=action_key,
@@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating(
             qvalue_network=qvalue,
             num_actions=actor.spec[action_key].space.n,
             action_space="one-hot",
+            skip_done_states=True,
         )
         loss.set_keys(
             action=action_key,
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index cd7039c323d..dafff17011e 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -126,6 +126,10 @@ class SACLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of elements in the output,
             ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -320,6 +324,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        skip_done_states: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -418,6 +423,7 @@ def __init__(
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._make_vmap()
         self.reduction = reduction
+        self.skip_done_states = skip_done_states
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
@@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor:
                 ExplorationType.RANDOM
             ), self.actor_network_params.to_module(self.actor_network):
                 next_tensordict = tensordict.get("next").copy()
-                # Check done state and avoid passing these to the actor
-                done = next_tensordict.get(self.tensor_keys.done)
-                if done is not None and done.any():
-                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-                else:
-                    next_tensordict_select = next_tensordict
-                next_dist = self.actor_network.get_dist(next_tensordict_select)
-                next_action = next_dist.rsample()
-                next_sample_log_prob = compute_log_prob(
-                    next_dist, next_action, self.tensor_keys.log_prob
-                )
-                if next_tensordict_select is not next_tensordict:
-                    mask = ~done.squeeze(-1)
-                    if mask.ndim < next_action.ndim:
-                        mask = expand_right(
-                            mask, (*mask.shape, *next_action.shape[mask.ndim :])
-                        )
-                    next_action = next_action.new_zeros(mask.shape).masked_scatter_(
-                        mask, next_action
+                if self.skip_done_states:
+                    # Check done state and avoid passing these to the actor
+                    done = next_tensordict.get(self.tensor_keys.done)
+                    if done is not None and done.any():
+                        next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                    else:
+                        next_tensordict_select = next_tensordict
+                    next_dist = self.actor_network.get_dist(next_tensordict_select)
+                    next_action = next_dist.rsample()
+                    next_sample_log_prob = compute_log_prob(
+                        next_dist, next_action, self.tensor_keys.log_prob
                     )
-                    mask = ~done.squeeze(-1)
-                    if mask.ndim < next_sample_log_prob.ndim:
-                        mask = expand_right(
-                            mask,
-                            (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                    if next_tensordict_select is not next_tensordict:
+                        mask = ~done.squeeze(-1)
+                        if mask.ndim < next_action.ndim:
+                            mask = expand_right(
+                                mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                            )
+                        next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                            mask, next_action
                         )
-                    next_sample_log_prob = next_sample_log_prob.new_zeros(
-                        mask.shape
-                    ).masked_scatter_(mask, next_sample_log_prob)
-                next_tensordict.set(self.tensor_keys.action, next_action)
+                        mask = ~done.squeeze(-1)
+                        if mask.ndim < next_sample_log_prob.ndim:
+                            mask = expand_right(
+                                mask,
+                                (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                            )
+                        next_sample_log_prob = next_sample_log_prob.new_zeros(
+                            mask.shape
+                        ).masked_scatter_(mask, next_sample_log_prob)
+                    next_tensordict.set(self.tensor_keys.action, next_action)
+                else:
+                    next_dist = self.actor_network.get_dist(next_tensordict)
+                    next_action = next_dist.rsample()
+                    next_tensordict.set(self.tensor_keys.action, next_action)
+                    next_sample_log_prob = compute_log_prob(
+                        next_dist, next_action, self.tensor_keys.log_prob
+                    )
 
             # get q-values
             next_tensordict_expand = self._vmap_qnetworkN0(
@@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule):
             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
             ``"mean"``: the sum of the output will be divided by the number of elements in the output,
             ``"sum"``: the output will be summed. Default: ``"mean"``.
+        skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
+            valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
+            shape of the data and that masking the data results in a valid data structure. Among other things, this may
+            not be true in MARL settings or when using RNNs. Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -1051,6 +1069,7 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
+        skip_done_states: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -1133,6 +1152,7 @@ def __init__(
         )
         self._make_vmap()
         self.reduction = reduction
+        self.skip_done_states = skip_done_states
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
@@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor:
         with torch.no_grad():
             next_tensordict = tensordict.get("next").clone(False)
 
-            done = next_tensordict.get(self.tensor_keys.done)
-            if done is not None and done.any():
-                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
-            else:
-                next_tensordict_select = next_tensordict
+            if self.skip_done_states:
+                done = next_tensordict.get(self.tensor_keys.done)
+                if done is not None and done.any():
+                    next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+                else:
+                    next_tensordict_select = next_tensordict
 
-            # get probs and log probs for actions computed from "next"
-            with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict_select)
-                next_log_prob = next_dist.logits
-                next_prob = next_log_prob.exp()
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict_select)
+                    next_log_prob = next_dist.logits
+                    next_prob = next_log_prob.exp()
 
-            # get q-values for all actions
-            next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict_select, self.target_qvalue_network_params
-            )
-            next_action_value = next_tensordict_expand.get(
-                self.tensor_keys.action_value
-            )
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict_select, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )
 
-            # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
-            next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
-            # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
-            next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
-            if next_tensordict_select is not next_tensordict:
-                mask = ~done
-                next_state_value = next_state_value.new_zeros(
-                    mask.shape
-                ).masked_scatter_(mask, next_state_value)
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+                if next_tensordict_select is not next_tensordict:
+                    mask = ~done
+                    next_state_value = next_state_value.new_zeros(
+                        mask.shape
+                    ).masked_scatter_(mask, next_state_value)
+            else:
+                # get probs and log probs for actions computed from "next"
+                with self.actor_network_params.to_module(self.actor_network):
+                    next_dist = self.actor_network.get_dist(next_tensordict)
+                    next_prob = next_dist.probs
+                    next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+
+                # get q-values for all actions
+                next_tensordict_expand = self._vmap_qnetworkN0(
+                    next_tensordict, self.target_qvalue_network_params
+                )
+                next_action_value = next_tensordict_expand.get(
+                    self.tensor_keys.action_value
+                )
+                # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
+                next_state_value = (
+                    next_action_value.min(0)[0] - self._alpha * next_log_prob
+                )
+                # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
+                next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
 
         tensordict.set(
             ("next", self.value_estimator.tensor_keys.value), next_state_value
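
Usage note (not part of the diff): a minimal sketch of how the new ``skip_done_states`` flag could be passed to ``SACLoss``, following the pattern of the existing docstring examples. The toy networks, sizes, and the ``BoundedTensorSpec`` import are illustrative assumptions and may need adjusting to the torchrl version in use.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives.sac import SACLoss

n_obs, n_act, batch = 3, 4, 8

# Toy actor: observation -> (loc, scale) -> TanhNormal policy.
actor_net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
actor_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
actor = ProbabilisticActor(
    module=actor_module,
    in_keys=["loc", "scale"],
    spec=BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)),
    distribution_class=TanhNormal,
)


# Toy Q-value network: (observation, action) -> state-action value.
class QValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))


qvalue = ValueOperator(module=QValueNet(), in_keys=["observation", "action"])

# The new flag: only run the actor on non-terminating next states when
# computing the value target (assumes the done mask broadcasts to the data).
loss = SACLoss(actor_network=actor, qvalue_network=qvalue, skip_done_states=True)

# Dummy transition batch with the keys SACLoss expects.
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": torch.rand(batch, n_act) * 2 - 1,
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)
print(loss(data))  # TensorDict with loss_actor / loss_qvalue / loss_alpha entries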