diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index eb9a916dfc1..c35bd8e818d 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -8,7 +8,7 @@
 
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Tuple
+from typing import List, Tuple
 
 import torch
 from tensordict import (
@@ -16,6 +16,7 @@
     TensorDict,
     TensorDictBase,
     TensorDictParams,
+    unravel_key,
 )
 from tensordict.nn import (
     CompositeDistribution,
@@ -33,6 +34,8 @@
     _cache_values,
     _clip_value_loss,
     _GAMMA_LMBDA_DEPREC_ERROR,
+    _maybe_add_or_extend_key,
+    _maybe_get_or_select,
     _reduce,
     _sum_td_features,
     default_value_kwargs,
@@ -67,7 +70,10 @@ class PPOLoss(LossModule):
 
     Args:
         actor_network (ProbabilisticTensorDictSequential): policy operator.
-        critic_network (ValueOperator): value operator.
+            Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
+            as input and outputting an action (or actions) as well as its log-probability value.
+        critic_network (ValueOperator): value operator. The critic will usually take the observations as input
+            and return a scalar value (``state_value`` by default) in the output keys.
 
     Keyword Args:
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
@@ -267,16 +273,16 @@ class _AcceptedKeys:
                 Will be used for the underlying value estimator Defaults to ``"value_target"``.
             value (NestedKey): The input tensordict key where the state value is expected.
                 Will be used for the underlying value estimator. Defaults to ``"state_value"``.
-            sample_log_prob (NestedKey): The input tensordict key where the
+            sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the
                sample log probability is expected. Defaults to ``"sample_log_prob"``.
-            action (NestedKey): The input tensordict key where the action is expected.
+            action (NestedKey or list of nested keys): The input tensordict key where the action is expected.
                 Defaults to ``"action"``.
-            reward (NestedKey): The input tensordict key where the reward is expected.
+            reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected.
                 Will be used for the underlying value estimator. Defaults to ``"reward"``.
-            done (NestedKey): The key in the input TensorDict that indicates
+            done (NestedKey or list of nested keys): The key in the input TensorDict that indicates
                 whether a trajectory is done. Will be used for the underlying value estimator.
                 Defaults to ``"done"``.
-            terminated (NestedKey): The key in the input TensorDict that indicates
+            terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates
                 whether a trajectory is terminated. Will be used for the underlying value estimator.
                 Defaults to ``"terminated"``.
         """
@@ -284,11 +290,11 @@ class _AcceptedKeys:
         advantage: NestedKey = "advantage"
         value_target: NestedKey = "value_target"
         value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
-        action: NestedKey = "action"
-        reward: NestedKey = "reward"
-        done: NestedKey = "done"
-        terminated: NestedKey = "terminated"
+        sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob"
+        action: NestedKey | List[NestedKey] = "action"
+        reward: NestedKey | List[NestedKey] = "reward"
+        done: NestedKey | List[NestedKey] = "done"
+        terminated: NestedKey | List[NestedKey] = "terminated"
 
     default_keys = _AcceptedKeys()
     default_value_estimator = ValueEstimators.GAE
@@ -369,7 +375,7 @@ def __init__(
 
         try:
            device = next(self.parameters()).device
-        except AttributeError:
+        except (AttributeError, StopIteration):
            device = torch.device("cpu")
 
         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
@@ -409,15 +415,36 @@ def functional(self):
 
     def _set_in_keys(self):
         keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
             *self.critic_network.in_keys,
         ]
+
+        if isinstance(self.tensor_keys.action, NestedKey):
+            keys.append(self.tensor_keys.action)
+        else:
+            keys.extend(self.tensor_keys.action)
+
+        if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+            keys.append(self.tensor_keys.sample_log_prob)
+        else:
+            keys.extend(self.tensor_keys.sample_log_prob)
+
+        if isinstance(self.tensor_keys.reward, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.reward)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward])
+
+        if isinstance(self.tensor_keys.done, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.done)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done])
+
+        if isinstance(self.tensor_keys.terminated, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.terminated)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated])
+
         self._in_keys = list(set(keys))
 
     @property
@@ -472,25 +499,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             if is_tensor_collection(entropy):
                 entropy = _sum_td_features(entropy)
         except NotImplementedError:
-            x = dist.rsample((self.samples_mc_entropy,))
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
             log_prob = dist.log_prob(x)
-            if is_tensor_collection(log_prob):
+
+            if is_tensor_collection(log_prob) and isinstance(
+                self.tensor_keys.sample_log_prob, NestedKey
+            ):
                 log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
             entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)
 
     def _log_weight(
         self, tensordict: TensorDictBase
     ) -> Tuple[torch.Tensor, d.Distribution]:
+
         # current log_prob of actions
-        action = tensordict.get(self.tensor_keys.action)
+        action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
 
         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
 
-        prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+        prev_log_prob = _maybe_get_or_select(
+            tensordict, self.tensor_keys.sample_log_prob
+        )
+
         if prev_log_prob.requires_grad:
             raise RuntimeError(
                 f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
@@ -513,8 +553,8 @@ def _log_weight(
         else:
             is_composite = False
             kwargs = {}
-        log_prob = dist.log_prob(tensordict, **kwargs)
-        if is_composite and not isinstance(prev_log_prob, TensorDict):
+        log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
+        if is_composite and not is_tensor_collection(prev_log_prob):
             log_prob = _sum_td_features(log_prob)
         log_prob.view_as(prev_log_prob)
 
@@ -1088,15 +1128,16 @@ def __init__(
 
     def _set_in_keys(self):
         keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
             *self.critic_network.in_keys,
         ]
+        _maybe_add_or_extend_key(keys, self.tensor_keys.action)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
+
         # Get the parameter keys from the actor dist
         actor_dist_module = None
         for module in self.actor_network.modules():
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 9c46fc98262..3e0b97de710 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -8,10 +8,10 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Union
 
 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictModule
 from torch import nn, Tensor
 from torch.nn import functional as F
@@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
 def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
     # Sum all features and return a tensor
     return data.sum(dim="feature", reduce=True)
+
+
+def _maybe_get_or_select(td, key_or_keys):
+    if isinstance(key_or_keys, (str, tuple)):
+        return td.get(key_or_keys)
+    return td.select(*key_or_keys)
+
+
+def _maybe_add_or_extend_key(
+    tensor_keys: List[NestedKey],
+    key_or_list_of_keys: NestedKey | List[NestedKey],
+    prefix: NestedKey = None,
+):
+    if prefix is not None:
+        if isinstance(key_or_list_of_keys, NestedKey):
+            tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
+        else:
+            tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
+        return
+    if isinstance(key_or_list_of_keys, NestedKey):
+        tensor_keys.append(key_or_list_of_keys)
+    else:
+        tensor_keys.extend(key_or_list_of_keys)
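
A minimal sketch (not part of the patch) of how the two helpers added to `torchrl/objectives/utils.py` are intended to behave. The tensordict contents and the nested key names (`"agents"`, etc.) are made up for illustration; only the helpers themselves come from this diff.

```python
# Illustrative only: exercises _maybe_get_or_select / _maybe_add_or_extend_key
# as defined in this patch, with arbitrary key names and shapes.
import torch
from tensordict import TensorDict

from torchrl.objectives.utils import _maybe_add_or_extend_key, _maybe_get_or_select

td = TensorDict({"action": torch.randn(4, 2)}, batch_size=[4])
td.set(("agents", "action"), torch.randn(4, 3))
td.set(("agents", "sample_log_prob"), torch.randn(4))

# A single NestedKey keeps the old behavior: td.get -> plain tensor.
single = _maybe_get_or_select(td, "action")

# A list of keys returns a sub-tensordict holding only those entries (td.select).
multi = _maybe_get_or_select(
    td, [("agents", "action"), ("agents", "sample_log_prob")]
)

# Building in_keys: a single key is appended, a list is extended, and an
# optional prefix (e.g. "next") is prepended and flattened via unravel_key.
in_keys = []
_maybe_add_or_extend_key(in_keys, "action")
_maybe_add_or_extend_key(in_keys, [("agents", "reward")], "next")
# in_keys is now ["action", ("next", "agents", "reward")]
```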
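The `get_entropy_bonus` change also replaces the unconditional `rsample` with a `has_rsample` check, so the Monte Carlo entropy fallback no longer fails for distributions without reparameterized sampling. A standalone sketch of that pattern (not PPO-specific; the batch size and sample count below are arbitrary):

```python
# Sketch of the sampling fallback used in get_entropy_bonus: rsample when the
# distribution supports reparameterization, plain sample otherwise.
import torch
from torch import distributions as d

samples_mc_entropy = 8
dist = d.Categorical(logits=torch.randn(4, 3))  # has_rsample is False

if getattr(dist, "has_rsample", False):
    x = dist.rsample((samples_mc_entropy,))
else:
    x = dist.sample((samples_mc_entropy,))

# Monte Carlo entropy estimate, averaged over the sample dimension -> shape [4]
entropy = -dist.log_prob(x).mean(0)
```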