diff --git a/examples/agents/composite_actor.py b/examples/agents/composite_actor.py index ae08062e084..c7e83095983 100644 --- a/examples/agents/composite_actor.py +++ b/examples/agents/composite_actor.py @@ -50,3 +50,9 @@ def forward(self, x): data = TensorDict({"x": torch.rand(10)}, []) module(data) print(actor(data)) + + +# TODO: +# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action") +# 2. Must multi-head require an action_key to be a list of keys (I guess so) +# 3. Using maps in the Actor diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py new file mode 100644 index 00000000000..d75ce3218b3 --- /dev/null +++ b/examples/agents/composite_ppo.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multi-head agent and PPO loss +============================= + +This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions +(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses. + +The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict. +It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution +object containing the three distributions. + +The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters, +creates a distribution from these parameters, and samples from the distribution to output multiple actions. + +The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss. + +Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a +fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities` +argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False` +if not specified. + +In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in +the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used. + +""" + +import functools + +import torch +from tensordict import TensorDict +from tensordict.nn import ( + CompositeDistribution, + InteractionType, + ProbabilisticTensorDictModule as Prob, + ProbabilisticTensorDictSequential as ProbSeq, + TensorDictModule as Mod, + TensorDictSequential as Seq, + WrapModule as Wrap, +) +from torch import distributions as d +from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss + +make_params = Mod( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], +) + + +def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + +# ============================================================================= +# Example 0: aggregate_probabilities=None (default) =========================== + +dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=None, +) + + +policy = ProbSeq( + make_params, + Prob( + in_keys=["params"], + out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), +) + +td = policy(TensorDict(batch_size=[4])) +print("0. result of policy call", td) + +dist = policy.get_dist(td) +log_prob = dist.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob") + +# We can also get the log-prob from the policy directly +log_prob = policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob (from policy)") + +# Build a dummy value operator +value_operator = Seq( + Wrap( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) +) + +# Create fake data +data = policy(TensorDict(batch_size=[4])) +data.set( + "next", + TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)), +) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("0. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 1: aggregate_probabilities=True =========================== + +dist_constructor.keywords["aggregate_probabilities"] = True + +td = policy(TensorDict(batch_size=[4])) +print("1. result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since + # there is only one. + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")] + ) + + # Get the loss values + loss_vals = ppo(data) + print("1. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 2: aggregate_probabilities=False =========================== + +dist_constructor.keywords["aggregate_probabilities"] = False + +td = policy(TensorDict(batch_size=[4])) +print("2. result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("2. ", loss_cls, loss_vals) diff --git a/test/test_cost.py b/test/test_cost.py index 1f191e41db6..a538a8d3418 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -34,6 +34,7 @@ TensorDictModule as Mod, TensorDictSequential, TensorDictSequential as Seq, + WrapModule, ) from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key @@ -7907,27 +7908,30 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", - action_key="action", + action_key=None, observation_key="observation", sample_log_prob_key="sample_log_prob", composite_action_dist=False, - aggregate_probabilities=True, + aggregate_probabilities=None, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - if composite_action_dist: - action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) if composite_action_dist: + if action_key is None: + action_key = ("action", "action1") + else: + action_key = (action_key, "action1") + action_spec = Composite({action_key: {"action1": action_spec}}) distribution_class = functools.partial( CompositeDistribution, distribution_map={ "action1": TanhNormal, }, name_map={ - "action1": (action_key, "action1"), + "action1": action_key, }, log_prob_key=sample_log_prob_key, aggregate_probabilities=aggregate_probabilities, @@ -7938,6 +7942,8 @@ def _create_mock_actor( ] actor_in_keys = ["params"] else: + if action_key is None: + action_key = "action" distribution_class = TanhNormal module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( @@ -8148,8 +8154,8 @@ def _create_seq_mock_data_ppo( action_dim=4, atoms=None, device="cpu", - sample_log_prob_key="sample_log_prob", - action_key="action", + sample_log_prob_key=None, + action_key=None, composite_action_dist=False, ): # create a tensordict @@ -8171,6 +8177,17 @@ def _create_seq_mock_data_ppo( params_scale = torch.rand_like(action) / 10 loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) + if sample_log_prob_key is None: + if composite_action_dist: + sample_log_prob_key = ("action", "action1_log_prob") + else: + sample_log_prob_key = "sample_log_prob" + + if action_key is None: + if composite_action_dist: + action_key = ("action", "action1") + else: + action_key = "action" td = TensorDict( batch_size=(batch, T), source={ @@ -8182,7 +8199,7 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: {"action1": action} if composite_action_dist else action, + action_key: action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), @@ -8262,6 +8279,13 @@ def test_ppo( loss_critic_type="l2", functional=functional, ) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) if advantage is not None: advantage(td) else: @@ -8355,7 +8379,12 @@ def test_ppo_composite_no_aggregate( loss_critic_type="l2", functional=functional, ) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) else: if td_est is not None: @@ -8463,7 +8492,15 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8570,7 +8607,20 @@ def test_ppo_shared_seq( ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + loss_fn2.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + loss = loss_fn(td).exclude("entropy") sum(val for key, val in loss.items() if key.startswith("loss_")).backward() @@ -8658,7 +8708,14 @@ def zero_param(p): # assert len(list(floss_fn.parameters())) == 0 with params.to_module(loss_fn): if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8748,10 +8805,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - @pytest.mark.parametrize("composite_action_dist", [True, False]) - def test_ppo_tensordict_keys_run( - self, loss_class, advantage, td_est, composite_action_dist - ): + def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8766,11 +8820,9 @@ def test_ppo_tensordict_keys_run( td = self._create_seq_mock_data_ppo( sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], - composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( sample_log_prob_key=tensor_keys["sample_log_prob"], - composite_action_dist=composite_action_dist, action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8864,9 +8916,7 @@ def test_ppo_tensordict_keys_run( @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize( "composite_action_dist", - [ - False, - ], + [False], ) def test_ppo_notensordict( self, @@ -8987,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): reduction=reduction, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) if reduction == "none": for key in loss.keys(): if key.startswith("loss_"): - assert loss[key].shape == td.shape + assert loss[key].shape == td.shape, key else: for key in loss.keys(): if not key.startswith("loss_"): @@ -9039,6 +9094,11 @@ def test_ppo_value_clipping( clip_value=clip_value, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) value = td.pop(loss_fn.tensor_keys.value) @@ -9060,6 +9120,110 @@ def test_ppo_value_clipping( loss = loss_fn(td) assert "loss_critic" in loss.keys() + def test_ppo_composite_dists(self): + d = torch.distributions + + make_params = TensorDictModule( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], + ) + + def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=False, + include_sum=False, + inplace=True, + ) + policy = ProbSeq( + make_params, + ProbabilisticTensorDictModule( + in_keys=["params"], + out_keys=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), + ) + # We want to make sure there is no warning + td = policy(TensorDict(batch_size=[4])) + assert isinstance( + policy.get_dist(td).log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + TensorDict, + ) + assert isinstance( + policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + TensorDict, + ) + value_operator = Seq( + WrapModule( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) + ) + for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + data = policy(TensorDict(batch_size=[4])) + data.set( + "next", + TensorDict( + reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool) + ), + ) + ppo = cls(policy, value_operator) + ppo.set_keys( + action=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + loss = ppo(data) + loss.sum(reduce=True) + class TestA2C(LossModuleTestBase): seed = 0 diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index eb9a916dfc1..2412ea62180 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,10 +5,11 @@ from __future__ import annotations import contextlib +import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import List, Tuple import torch from tensordict import ( @@ -27,12 +28,15 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _replace_last from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_add_or_extend_key, + _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, @@ -67,7 +71,10 @@ class PPOLoss(LossModule): Args: actor_network (ProbabilisticTensorDictSequential): policy operator. - critic_network (ValueOperator): value operator. + Typically, a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations + as input and outputting an action (or actions) as well as its log-probability value. + critic_network (ValueOperator): value operator. The critic will usually take the observations as input + and return a scalar value (``state_value`` by default) in the output keys. Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -267,16 +274,16 @@ class _AcceptedKeys: Will be used for the underlying value estimator Defaults to ``"value_target"``. value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - sample_log_prob (NestedKey): The input tensordict key where the + sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the sample log probability is expected. Defaults to ``"sample_log_prob"``. - action (NestedKey): The input tensordict key where the action is expected. + action (NestedKey or list of nested keys): The input tensordict key where the action is expected. Defaults to ``"action"``. - reward (NestedKey): The input tensordict key where the reward is expected. + reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected. Will be used for the underlying value estimator. Defaults to ``"reward"``. - done (NestedKey): The key in the input TensorDict that indicates + done (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. - terminated (NestedKey): The key in the input TensorDict that indicates + terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is terminated. Will be used for the underlying value estimator. Defaults to ``"terminated"``. """ @@ -284,11 +291,11 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey = "sample_log_prob" - action: NestedKey = "action" - reward: NestedKey = "reward" - done: NestedKey = "done" - terminated: NestedKey = "terminated" + sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob" + action: NestedKey | List[NestedKey] = "action" + reward: NestedKey | List[NestedKey] = "reward" + done: NestedKey | List[NestedKey] = "done" + terminated: NestedKey | List[NestedKey] = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -369,7 +376,7 @@ def __init__( try: device = next(self.parameters()).device - except AttributeError: + except (AttributeError, StopIteration): device = torch.device("cpu") self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) @@ -408,16 +415,16 @@ def functional(self): return self._functional def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + self._in_keys = list(set(keys)) @property @@ -456,6 +463,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: reward=self.tensor_keys.reward, done=self.tensor_keys.done, terminated=self.tensor_keys.terminated, + sample_log_prob=self.tensor_keys.sample_log_prob, ) self._set_in_keys() @@ -463,34 +471,58 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + if isinstance(dist, CompositeDistribution): + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False + kwargs = {"aggregate_probabilities": aggregate, "include_sum": include_sum} + else: + kwargs = {} try: - if isinstance(dist, CompositeDistribution): - kwargs = {"aggregate_probabilities": False, "include_sum": False} - else: - kwargs = {} entropy = dist.entropy(**kwargs) - if is_tensor_collection(entropy): - entropy = _sum_td_features(entropy) except NotImplementedError: - x = dist.rsample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) + if getattr(dist, "has_rsample", False): + x = dist.rsample((self.samples_mc_entropy,)) + else: + x = dist.sample((self.samples_mc_entropy,)) + log_prob = dist.log_prob(x, **kwargs) + if is_tensor_collection(log_prob): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + try: + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, log_prob, err) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) + if is_tensor_collection(entropy): + entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: + # current log_prob of actions - action = tensordict.get(self.tensor_keys.action) + action = _maybe_get_or_select(tensordict, self.tensor_keys.action) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) + try: + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob + ) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, tensordict, err) + if prev_log_prob.requires_grad: raise RuntimeError( f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." @@ -500,26 +532,47 @@ def _log_weight( raise RuntimeError( f"tensordict stored {self.tensor_keys.action} requires grad." ) - if isinstance(action, torch.Tensor): + if isinstance(dist, CompositeDistribution): + is_composite = True + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False + kwargs = { + "inplace": False, + "aggregate_probabilities": aggregate, + "include_sum": include_sum, + } + else: + is_composite = False + kwargs = {} + if not is_composite: log_prob = dist.log_prob(action) else: - if isinstance(dist, CompositeDistribution): - is_composite = True - kwargs = { - "inplace": False, - "aggregate_probabilities": False, - "include_sum": False, - } - else: - is_composite = False - kwargs = {} - log_prob = dist.log_prob(tensordict, **kwargs) - if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) + if not is_tensor_collection(prev_log_prob): + # this isn't great, in general multihead actions should have a composite log-prob too + warnings.warn( + "You are using a composite distribution, yet your log-probability is a tensor. " + "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True " + "or include_sum=True. These options should be avoided: leaf log-probs should be written " + "independently and PPO will take care of the aggregation.", + category=UserWarning, + ) + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) + if is_tensor_collection(kl_approx): + kl_approx = _sum_td_features(kl_approx) return log_weight, dist, kl_approx @@ -892,7 +945,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ratio = log_weight_clip.exp() gain2 = ratio * advantage - gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] + gain = torch.stack([gain1, gain2], -1).min(dim=-1).values + if is_tensor_collection(gain): + gain = _sum_td_features(gain) td_out = TensorDict({"loss_objective": -gain}, batch_size=[]) td_out.set("clip_fraction", clip_fraction) @@ -1087,16 +1142,16 @@ def __init__( self.samples_mc_kl = samples_mc_kl def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + # Get the parameter keys from the actor dist actor_dist_module = None for module in self.actor_network.modules(): @@ -1156,6 +1211,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: advantage = (advantage - loc) / scale log_weight, dist, kl_approx = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage + if is_tensor_collection(neg_loss): + neg_loss = _sum_td_features(neg_loss) with self.actor_network_params.to_module( self.actor_network @@ -1166,17 +1223,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) if isinstance(previous_dist, CompositeDistribution): + aggregate = previous_dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = previous_dist.include_sum + if include_sum is None: + include_sum = False kwargs = { - "aggregate_probabilities": False, + "aggregate_probabilities": aggregate, "inplace": False, - "include_sum": False, + "include_sum": include_sum, } else: kwargs = {} previous_log_prob = previous_dist.log_prob(x, **kwargs) current_log_prob = current_dist.log_prob(x, **kwargs) - if is_tensor_collection(current_log_prob): + if is_tensor_collection(previous_log_prob): previous_log_prob = _sum_td_features(previous_log_prob) + # Both dists have presumably the same params current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) @@ -1214,3 +1278,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init + + +def _make_lp_get_error(tensor_keys, log_prob, err): + result = ( + f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does " + f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. " + ) + # now check if we can substitute the actions with action_log_prob and retrieve the log-probs + action_keys = tensor_keys.action + if isinstance(action_keys, list): + has_all_log_probs = True + log_prob_keys = [] + for action_key in action_keys: + log_prob_key = _replace_last(action_key, "action_log_prob") + log_prob_keys.append(log_prob_key) + if log_prob_key not in log_prob: + has_all_log_probs = False + break + if has_all_log_probs: + result += ( + f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the " + f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve " + f"this error." + ) + return KeyError(result) + result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=)." + return KeyError(result) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 9c46fc98262..3e0b97de710 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -8,10 +8,10 @@ import re import warnings from enum import Enum -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictModule from torch import nn, Tensor from torch.nn import functional as F @@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize def _sum_td_features(data: TensorDictBase) -> torch.Tensor: # Sum all features and return a tensor return data.sum(dim="feature", reduce=True) + + +def _maybe_get_or_select(td, key_or_keys): + if isinstance(key_or_keys, (str, tuple)): + return td.get(key_or_keys) + return td.select(*key_or_keys) + + +def _maybe_add_or_extend_key( + tensor_keys: List[NestedKey], + key_or_list_of_keys: NestedKey | List[NestedKey], + prefix: NestedKey = None, +): + if prefix is not None: + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(unravel_key((prefix, key_or_list_of_keys))) + else: + tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys]) + return + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(key_or_list_of_keys) + else: + tensor_keys.extend(key_or_list_of_keys) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 3b08780e24c..fa05c8860a6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -13,7 +13,7 @@ from typing import Callable, List, Union import torch -from tensordict import TensorDictBase +from tensordict import is_tensor_collection, TensorDictBase from tensordict.nn import ( CompositeDistribution, dispatch, @@ -23,13 +23,18 @@ TensorDictModuleBase, ) from tensordict.nn.probabilistic import interaction_type -from tensordict.utils import NestedKey +from tensordict.utils import NestedKey, unravel_key from torch import Tensor from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST +from torchrl.objectives.utils import ( + _maybe_get_or_select, + _vmap_func, + hold_out_net, + RANDOM_MODULE_LIST, +) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -293,13 +298,18 @@ def out_keys(self): def set_keys(self, **kwargs) -> None: """Set tensordict key names.""" - for key, value in kwargs.items(): - if not isinstance(value, (str, tuple)): + for key, value in list(kwargs.items()): + if isinstance(value, list): + value = [unravel_key(k) for k in value] + elif not isinstance(value, (str, tuple)): + if value is None: + raise ValueError("tensordict keys cannot be None") raise ValueError( f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}" ) - if value is None: - raise ValueError("tensordict keys cannot be None") + else: + value = unravel_key(value) + if key not in self._AcceptedKeys.__dict__: raise KeyError( f"{key} is not an accepted tensordict key for advantages" @@ -312,6 +322,7 @@ def set_keys(self, **kwargs) -> None: raise KeyError( f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}" ) + kwargs[key] = value if self._tensor_keys is None: conf = asdict(self.default_keys) conf.update(self.dep_keys) @@ -1765,12 +1776,11 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) - # Make sure we have the log prob computed at collection time - if self.tensor_keys.sample_log_prob not in tensordict.keys(): - raise ValueError( - f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" - ) - log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob) + if is_tensor_collection(lp): + # Sum all values to match the batch size + lp = lp.sum(dim="feature", reduce=True) + log_mu = lp.view_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network):