diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index e1642868228..349d1277c98 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -553,6 +553,8 @@ Some distributions are typically used in RL scripts.
     OneHotCategorical
     MaskedCategorical
     MaskedOneHotCategorical
+    Ordinal
+    OneHotOrdinal
 
 Utils
 -----
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 79929135bc8..e283fb9a9b8 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -17,6 +17,8 @@
 from torchrl.modules import (
     NormalParamWrapper,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
     TanhNormal,
     TruncatedNormal,
@@ -28,6 +30,7 @@
     TanhDelta,
 )
 from torchrl.modules.distributions.continuous import SafeTanhTransform
+from torchrl.modules.distributions.discrete import _generate_ordinal_logits
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
@@ -677,6 +680,125 @@ def test_reparam(self, grad_method, sparse):
         assert logits.grad is not None and logits.grad.norm() > 0
 
 
+class TestOrdinal:
+    @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
+    def test_correct_sampling_shape(
+        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+    ) -> None:
+        logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
+
+        sampler = Ordinal(scores=logits)
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+        log_probs = sampler.log_prob(actions)  # type: ignore[no-untyped-call]
+
+        expected_log_prob_shape = logit_shape[:-1]
+        expected_action_shape = logit_shape[:-1]
+
+        assert actions.size() == torch.Size(expected_action_shape)
+        assert log_probs.size() == torch.Size(expected_log_prob_shape)
+
+    @pytest.mark.parametrize("num_categories", [1, 10, 20])
+    def test_correct_range(self, num_categories: int) -> None:
+        seq_size = 10
+        batch_size = 100
+        logits = torch.ones((batch_size, seq_size, num_categories))
+
+        sampler = Ordinal(scores=logits)
+
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+
+        assert actions.min() >= 0
+        assert actions.max() < num_categories
+
+    def test_bounded_gradients(self) -> None:
+        logits = torch.tensor(
+            [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
+            requires_grad=True,
+            dtype=torch.float32,
+        )
+
+        sampler = Ordinal(scores=logits)
+
+        actions = sampler.sample()
+        log_probs = sampler.log_prob(actions)
+
+        dummy_objective = log_probs.sum()
+        dummy_objective.backward()
+
+        assert logits.grad is not None
+        assert not torch.isnan(logits.grad).any()
+
+    def test_generate_ordinal_logits_numerical(self) -> None:
+        logits = torch.ones((3, 4))
+
+        ordinal_logits = _generate_ordinal_logits(scores=logits)
+
+        expected_ordinal_logits = torch.tensor(
+            [
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+            ]
+        )
+
+        torch.testing.assert_close(
+            ordinal_logits, expected_ordinal_logits, atol=1e-4, rtol=1e-6
+        )
+
+
+class TestOneHotOrdinal:
+    @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
+    def test_correct_sampling_shape(
+        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+    ) -> None:
+        logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
+
+        sampler = OneHotOrdinal(scores=logits)
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+        log_probs = sampler.log_prob(actions)  # type: ignore[no-untyped-call]
+        expected_log_prob_shape = logit_shape[:-1]
+
+        expected_action_shape = logit_shape
+
+        assert actions.size() == torch.Size(expected_action_shape)
+        assert log_probs.size() == torch.Size(expected_log_prob_shape)
+
+    @pytest.mark.parametrize("num_categories", [2, 10, 20])
+    def test_correct_range(self, num_categories: int) -> None:
+        seq_size = 10
+        batch_size = 100
+        logits = torch.ones((batch_size, seq_size, num_categories))
+
+        sampler = OneHotOrdinal(scores=logits)
+
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+
+        assert torch.all(actions.sum(-1))
+        assert actions.shape[-1] == num_categories
+
+    def test_bounded_gradients(self) -> None:
+        logits = torch.tensor(
+            [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
+            requires_grad=True,
+            dtype=torch.float32,
+        )
+
+        sampler = OneHotOrdinal(scores=logits)
+
+        actions = sampler.sample()
+        log_probs = sampler.log_prob(actions)
+
+        dummy_objective = log_probs.sum()
+        dummy_objective.backward()
+
+        assert logits.grad is not None
+        assert not torch.isnan(logits.grad).any()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index f65461842bb..8523a783676 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -14,6 +14,8 @@
     NormalParamExtractor,
     NormalParamWrapper,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
     TanhDelta,
     TanhNormal,
diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py
index 367765812bb..52f8f302a35 100644
--- a/torchrl/modules/distributions/__init__.py
+++ b/torchrl/modules/distributions/__init__.py
@@ -17,6 +17,8 @@
     MaskedCategorical,
     MaskedOneHotCategorical,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
 )
 
@@ -31,5 +33,7 @@
         MaskedCategorical,
         MaskedOneHotCategorical,
         OneHotCategorical,
+        Ordinal,
+        OneHotOrdinal,
     )
 }
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index d2ffba30686..eb802294a12 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -9,11 +9,9 @@
 import torch
 import torch.distributions as D
+import torch.nn.functional as F
 
-__all__ = [
-    "OneHotCategorical",
-    "MaskedCategorical",
-]
+__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]
 
 
 def _treat_categorical_params(
@@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum):
 class OneHotCategorical(D.Categorical):
     """One-hot categorical distribution.
 
-    This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings
+    This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
     of the discrete tensors.
 
     Args:
@@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical):
             reparameterized samples.
             ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
              by using the softmax valued log-probability as a proxy to the
-             samples gradients.
+             sample gradients.
             ``ReparamGradientStrategy.RelaxedOneHot`` will use
             :class:`torch.distributions.RelaxedOneHot` to sample from the distribution.
@@ -81,8 +79,6 @@ class OneHotCategorical(D.Categorical):
 
     """
 
-    num_params: int = 1
-
     def __init__(
         self,
         logits: Optional[torch.Tensor] = None,
@@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical):
     Args:
         logits (torch.Tensor): event log probabilities (unnormalized)
         probs (torch.Tensor): event probabilities. If provided, the probabilities
-            corresponding to to masked items will be zeroed and the probability
+            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.
 
     Keyword Args:
@@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
     Args:
         logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities. If provided, the probabilities
-            corresponding to to masked items will be zeroed and the probability
+            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.
 
     Keyword Args:
@@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten
         raise ValueError(
             f"Unknown reparametrization strategy {self.reparam_strategy}."
         )
+
+
+class Ordinal(D.Categorical):
+    """A discrete distribution for learning to sample from finite ordered sets.
+
+    It is defined in contrast with the `Categorical` distribution, which does
+    not impose any notion of proximity or ordering over its support's atoms.
+    The `Ordinal` distribution explicitly encodes those concepts, which is
+    useful for learning discrete sampling from continuous sets. See §5 of
+    `Tang & Agrawal, 2020`_ for details.
+
+    .. note::
+        This class is mostly useful when you want to learn a distribution over
+        a finite set which is obtained by discretising a continuous set.
+
+    Args:
+        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distribution.
+            Typically, the output of a neural network parametrising the distribution.
+
+    Examples:
+        >>> num_atoms, num_samples = 5, 20
+        >>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
+        >>> torch.manual_seed(42)
+        >>> logits = torch.ones((num_atoms), requires_grad=True)
+        >>> optimizer = torch.optim.Adam([logits], lr=0.1)
+        >>>
+        >>> # Perform optimisation loop to minimise deviation from `mean`
+        >>> for _ in range(20):
+        ...     sampler = Ordinal(scores=logits)
+        ...     samples = sampler.sample((num_samples,))
+        ...     # Define loss to encourage samples around the mean by penalising deviation from mean
+        ...     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
+        ...     loss.backward()
+        ...     optimizer.step()
+        ...     optimizer.zero_grad()
+        >>>
+        >>> sampler.probs
+        tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
+        >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
+        >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
+        torch.return_types.histogram(
+            hist=tensor([ 24., 158., 478., 228., 112.]),
+            bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
+    """
+
+    def __init__(self, scores: torch.Tensor):
+        logits = _generate_ordinal_logits(scores)
+        super().__init__(logits=logits)
+
+
+class OneHotOrdinal(OneHotCategorical):
+    """The one-hot version of the :class:`~torchrl.modules.distributions.Ordinal` distribution.
+
+    Args:
+        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distribution.
+            Typically, the output of a neural network parametrising the distribution.
+    """
+
+    def __init__(self, scores: torch.Tensor):
+        logits = _generate_ordinal_logits(scores)
+        super().__init__(logits=logits)
+
+
+def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
+    """Implements Eq. 4 of `Tang & Agrawal, 2020`__."""
+    # Assigns Bernoulli-like probabilities for each class in the set
+    log_probs = F.logsigmoid(scores)
+    complementary_log_probs = F.logsigmoid(-scores)
+
+    # Total log-probability for being "larger than k"
+    larger_than_log_probs = log_probs.cumsum(dim=-1)
+
+    # Total log-probability for being "smaller than k"
+    smaller_than_log_probs = (
+        complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+        - complementary_log_probs
+    )
+
+    return larger_than_log_probs + smaller_than_log_probs
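
Note on the ordinal parametrisation added in discrete.py: for a score vector s of length N, _generate_ordinal_logits builds the logit of category k as sum_{i <= k} log sigmoid(s_i) + sum_{i > k} log sigmoid(-s_i) (Eq. 4 of Tang & Agrawal, 2020). The snippet below is a minimal standalone sketch of that construction, not part of the patch; the helper name ordinal_logits is illustrative only and it assumes nothing beyond a working torch install. It reproduces the values asserted in test_generate_ordinal_logits_numerical and shows the resulting logits being fed to a plain Categorical, which is all Ordinal.__init__ does.

    import torch
    import torch.nn.functional as F

    def ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
        # log sigmoid(s_i) and log sigmoid(-s_i) = log(1 - sigmoid(s_i))
        log_probs = F.logsigmoid(scores)
        comp_log_probs = F.logsigmoid(-scores)
        # cumulative sum over i <= k
        larger_than = log_probs.cumsum(dim=-1)
        # reverse cumulative sum over i > k (flip/flip, then drop the k-th term itself)
        smaller_than = (
            comp_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - comp_log_probs
        )
        return larger_than + smaller_than

    print(ordinal_logits(torch.ones(4)))
    # tensor([-4.2530, -3.2530, -2.2530, -1.2530])  -- matches the expected tensor in the test

    dist = torch.distributions.Categorical(logits=ordinal_logits(torch.randn(3, 5)))
    print(dist.sample())  # integer categories in [0, 5), one per batch entry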