
[Feature] Adds ordinal distributions #2520


Merged 8 commits on Oct 29, 2024
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
@@ -553,6 +553,8 @@ Some distributions are typically used in RL scripts.
OneHotCategorical
MaskedCategorical
MaskedOneHotCategorical
Ordinal
OneHotOrdinal

Utils
-----
122 changes: 122 additions & 0 deletions test/test_distributions.py
@@ -17,6 +17,8 @@
from torchrl.modules import (
NormalParamWrapper,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
TanhNormal,
TruncatedNormal,
@@ -28,6 +30,7 @@
TanhDelta,
)
from torchrl.modules.distributions.continuous import SafeTanhTransform
from torchrl.modules.distributions.discrete import _generate_ordinal_logits

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
@@ -677,6 +680,125 @@ def test_reparam(self, grad_method, sparse):
assert logits.grad is not None and logits.grad.norm() > 0


class TestOrdinal:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
def test_correct_sampling_shape(
self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
) -> None:
logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)

sampler = Ordinal(scores=logits)
actions = sampler.sample() # type: ignore[no-untyped-call]
log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call]

expected_log_prob_shape = logit_shape[:-1]
expected_action_shape = logit_shape[:-1]

assert actions.size() == torch.Size(expected_action_shape)
assert log_probs.size() == torch.Size(expected_log_prob_shape)

@pytest.mark.parametrize("num_categories", [1, 10, 20])
def test_correct_range(self, num_categories: int) -> None:
seq_size = 10
batch_size = 100
logits = torch.ones((batch_size, seq_size, num_categories))

sampler = Ordinal(scores=logits)

actions = sampler.sample() # type: ignore[no-untyped-call]

assert actions.min() >= 0
assert actions.max() < num_categories

def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
dtype=torch.float32,
)

sampler = Ordinal(scores=logits)

actions = sampler.sample()
log_probs = sampler.log_prob(actions)

dummy_objective = log_probs.sum()
dummy_objective.backward()

assert logits.grad is not None
assert not torch.isnan(logits.grad).any()

def test_generate_ordinal_logits_numerical(self) -> None:
logits = torch.ones((3, 4))

ordinal_logits = _generate_ordinal_logits(scores=logits)
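# With uniform scores of 1.0, logsigmoid(1) ≈ -0.3133 and logsigmoid(-1) ≈ -1.3133,
# so the k-th ordinal logit is (k + 1) * logsigmoid(1) + (3 - k) * logsigmoid(-1).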

expected_ordinal_logits = torch.tensor(
[
[-4.2530, -3.2530, -2.2530, -1.2530],
[-4.2530, -3.2530, -2.2530, -1.2530],
[-4.2530, -3.2530, -2.2530, -1.2530],
]
)

torch.testing.assert_close(
ordinal_logits, expected_ordinal_logits, atol=1e-4, rtol=1e-6
)


class TestOneHotOrdinal:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
def test_correct_sampling_shape(
self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
) -> None:
logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)

sampler = OneHotOrdinal(scores=logits)
actions = sampler.sample() # type: ignore[no-untyped-call]
log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call]
expected_log_prob_shape = logit_shape[:-1]

expected_action_shape = logit_shape

assert actions.size() == torch.Size(expected_action_shape)
assert log_probs.size() == torch.Size(expected_log_prob_shape)

@pytest.mark.parametrize("num_categories", [2, 10, 20])
def test_correct_range(self, num_categories: int) -> None:
seq_size = 10
batch_size = 100
logits = torch.ones((batch_size, seq_size, num_categories))

sampler = OneHotOrdinal(scores=logits)

actions = sampler.sample() # type: ignore[no-untyped-call]

assert torch.all(actions.sum(-1))
assert actions.shape[-1] == num_categories

def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
dtype=torch.float32,
)

sampler = OneHotOrdinal(scores=logits)

actions = sampler.sample()
log_probs = sampler.log_prob(actions)

dummy_objective = log_probs.sum()
dummy_objective.backward()

assert logits.grad is not None
assert not torch.isnan(logits.grad).any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
@@ -14,6 +14,8 @@
NormalParamExtractor,
NormalParamWrapper,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
TanhDelta,
TanhNormal,
4 changes: 4 additions & 0 deletions torchrl/modules/distributions/__init__.py
@@ -17,6 +17,8 @@
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
)

@@ -31,5 +33,7 @@
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
Ordinal,
OneHotOrdinal,
)
}
95 changes: 85 additions & 10 deletions torchrl/modules/distributions/discrete.py
@@ -9,11 +9,9 @@

import torch
import torch.distributions as D
import torch.nn.functional as F

__all__ = [
"OneHotCategorical",
"MaskedCategorical",
]
__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]


def _treat_categorical_params(
@@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum):
class OneHotCategorical(D.Categorical):
"""One-hot categorical distribution.

This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings
This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
of the discrete tensors.

Args:
Expand All @@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical):
reparameterized samples.
``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
by using the softmax valued log-probability as a proxy to the
samples gradients.
sample gradients.
``ReparamGradientStrategy.RelaxedOneHot`` will use
:class:`torch.distributions.RelaxedOneHot` to sample from the distribution.

@@ -81,8 +79,6 @@

"""

num_params: int = 1

def __init__(
self,
logits: Optional[torch.Tensor] = None,
@@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical):
Args:
logits (torch.Tensor): event log probabilities (unnormalized)
probs (torch.Tensor): event probabilities. If provided, the probabilities
corresponding to to masked items will be zeroed and the probability
corresponding to masked items will be zeroed and the probability
re-normalized along its last dimension.

Keyword Args:
@@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
Args:
logits (torch.Tensor): event log probabilities (unnormalized)
probs (torch.Tensor): event probabilities. If provided, the probabilities
corresponding to to masked items will be zeroed and the probability
corresponding to masked items will be zeroed and the probability
re-normalized along its last dimension.

Keyword Args:
@@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten
raise ValueError(
f"Unknown reparametrization strategy {self.reparam_strategy}."
)


class Ordinal(D.Categorical):
"""A discrete distribution for learning to sample from finite ordered sets.

It is defined in contrast with the `Categorical` distribution, which does
not impose any notion of proximity or ordering over its support's atoms.
The `Ordinal` distribution explicitly encodes those concepts, which is
useful for learning discrete sampling from continuous sets. See §5 of
`Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details.

.. note::
This class is mostly useful when you want to learn a distribution over
a finite set which is obtained by discretising a continuous set.

Args:
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distribution.
Typically, the output of a neural network parametrising the distribution.
Collaborator:

If you could add a simple example it'd be awesome, what about (feel free to edit)

Suggested change
Typically, the output of a neural network parametrising the distribution.
Typically, the output of a neural network parametrising the distribution.
Examples:
>>> batch_size = 1
>>> seq_size = 4
>>> num_categories = 3
>>> logits = torch.ones((batch_size, seq_size, num_categories))
>>> sampler = Ordinal(scores=logits)
>>> # The probabilities and logits are not equal anymore, but reflect the ordering of the samples
>>> sampler.probs
tensor([[[0.0900, 0.2447, 0.6652],
[0.0900, 0.2447, 0.6652],
[0.0900, 0.2447, 0.6652],
[0.0900, 0.2447, 0.6652]]])
>>> torch.manual_seed(10)
>>> # We print an histogram to reflect the sampling frequency of each item (0, 1 and 2)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=3)
torch.return_types.histogram(
hist=tensor([ 386., 989., 2625.]),
bin_edges=tensor([0.0000, 0.6667, 1.3333, 2.0000]))

Contributor Author:

I thought about this but I was unsure about its meaningfulness -- since this is a "learning-time" feature, here we wouldn't be showcasing anything that cannot be done by a Categorical distribution.

I thought about it and wrote an example which involves some learning, maybe it is more useful like that. Wdyt ?

Collaborator:

My personal view on doctests is that they should be (1) almost always there because people want to understand by seeing code, (2) simple enough that you can grasp what this class / function is about just by looking at one example, and (3) they should not add an extra layer of complexity.

If you think there's an example that can be put together that fulfills these three principles and also shows some learning I'm open to it, but I'm just afraid that it'll require other components of the lib and give the wrong impression that this class requires other classes to run.

That being said I agree that it's not obvious from the example above that you can do things you couldn't easily do with just a categorical!

Contributor Author:

Agreed with all of the above; do you think the one I pushed this morning (copied below for visibility) clashes with (2) and (3)?

"""
    Examples:
        >>> num_atoms, num_samples = 5, 20
        >>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
        >>> torch.manual_seed(42)
        >>> logits = torch.ones((num_atoms), requires_grad=True)
        >>> optimizer = torch.optim.Adam([logits], lr=0.1)
        >>>
        >>> # Perform optimisation loop to minimise deviation from `mean`
        >>> for _ in range(20):
        >>>     sampler = Ordinal(scores=logits)
        >>>     samples = sampler.sample((num_samples,))
        >>>     # Define loss to encourage samples around the mean by penalising deviation from mean
        >>>     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     optimizer.zero_grad()
        >>>
        >>> sampler.probs
        tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
        >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
        >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
        torch.return_types.histogram(
            hist=tensor([ 24., 158., 478., 228., 112.]),
            bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
    """

Collaborator:

Amazing I love it


Examples:
>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
>>> sampler = Ordinal(scores=logits)
>>> samples = sampler.sample((num_samples,))
>>> # Define loss to encourage samples around the mean by penalising deviation from mean
>>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
hist=tensor([ 24., 158., 478., 228., 112.]),
bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
"""

def __init__(self, scores: torch.Tensor):
logits = _generate_ordinal_logits(scores)
super().__init__(logits=logits)


class OneHotOrdinal(OneHotCategorical):
"""The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution.

Args:
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distribution.
Typically, the output of a neural network parametrising the distribution.
"""

def __init__(self, scores: torch.Tensor):
logits = _generate_ordinal_logits(scores)
super().__init__(logits=logits)
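For context, a minimal usage sketch of the one-hot variant (not part of the diff): samples are one-hot encoded along the last dimension, so they keep the category axis that `Ordinal` drops, while log-probabilities are reduced over it, matching the shape expectations in `TestOneHotOrdinal.test_correct_sampling_shape` above.

>>> import torch
>>> from torchrl.modules import OneHotOrdinal
>>> scores = torch.ones(8, 5)  # batch of 8, 5 ordered categories
>>> dist = OneHotOrdinal(scores=scores)
>>> dist.sample().shape  # one-hot over the last dimension
torch.Size([8, 5])
>>> dist.log_prob(dist.sample()).shape
torch.Size([8])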


def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
"""Implements Eq. 4 of `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`__."""
# Assigns Bernoulli-like probabilities for each class in the set
log_probs = F.logsigmoid(scores)
complementary_log_probs = F.logsigmoid(-scores)

# Total log-probability for being "larger than k"
larger_than_log_probs = log_probs.cumsum(dim=-1)

# Total log-probability for being "smaller than k"
smaller_than_log_probs = (
complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- complementary_log_probs
)

return larger_than_log_probs + smaller_than_log_probs
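For readers cross-checking the vectorized expression above against Eq. 4: each ordinal logit sums log sigmoid(s_i) over the classes at or below k plus log sigmoid(-s_i) over the classes above k. A minimal loop-based sketch follows (the `_naive_ordinal_logits` helper is illustrative only, not part of the library) that should agree with `_generate_ordinal_logits`:

import torch
import torch.nn.functional as F

from torchrl.modules.distributions.discrete import _generate_ordinal_logits


def _naive_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
    # Loop form of Eq. 4: logit_k = sum_{i <= k} logsigmoid(s_i) + sum_{i > k} logsigmoid(-s_i)
    columns = []
    for k in range(scores.shape[-1]):
        columns.append(
            F.logsigmoid(scores[..., : k + 1]).sum(-1)
            + F.logsigmoid(-scores[..., k + 1 :]).sum(-1)
        )
    return torch.stack(columns, dim=-1)


scores = torch.randn(2, 5)
torch.testing.assert_close(_naive_ordinal_logits(scores), _generate_ordinal_logits(scores))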