Skip to content

Commit

Permalink
[Feature] Adds ordinal distributions (#2520)
Browse files Browse the repository at this point in the history
Co-authored-by: Louis Faury <[email protected]>
  • Loading branch information
louisfaury and Louis Faury authored Oct 29, 2024
1 parent d524d0d commit c851e16
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 10 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ Some distributions are typically used in RL scripts.
OneHotCategorical
MaskedCategorical
MaskedOneHotCategorical
Ordinal
OneHotOrdinal

Utils
-----
Expand Down
122 changes: 122 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from torchrl.modules import (
NormalParamWrapper,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
TanhNormal,
TruncatedNormal,
Expand All @@ -28,6 +30,7 @@
TanhDelta,
)
from torchrl.modules.distributions.continuous import SafeTanhTransform
from torchrl.modules.distributions.discrete import _generate_ordinal_logits

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
Expand Down Expand Up @@ -677,6 +680,125 @@ def test_reparam(self, grad_method, sparse):
assert logits.grad is not None and logits.grad.norm() > 0


class TestOrdinal:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
def test_correct_sampling_shape(
self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
) -> None:
logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)

sampler = Ordinal(scores=logits)
actions = sampler.sample() # type: ignore[no-untyped-call]
log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call]

expected_log_prob_shape = logit_shape[:-1]
expected_action_shape = logit_shape[:-1]

assert actions.size() == torch.Size(expected_action_shape)
assert log_probs.size() == torch.Size(expected_log_prob_shape)

@pytest.mark.parametrize("num_categories", [1, 10, 20])
def test_correct_range(self, num_categories: int) -> None:
seq_size = 10
batch_size = 100
logits = torch.ones((batch_size, seq_size, num_categories))

sampler = Ordinal(scores=logits)

actions = sampler.sample() # type: ignore[no-untyped-call]

assert actions.min() >= 0
assert actions.max() < num_categories

def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
dtype=torch.float32,
)

sampler = Ordinal(scores=logits)

actions = sampler.sample()
log_probs = sampler.log_prob(actions)

dummy_objective = log_probs.sum()
dummy_objective.backward()

assert logits.grad is not None
assert not torch.isnan(logits.grad).any()

def test_generate_ordinal_logits_numerical(self) -> None:
logits = torch.ones((3, 4))

ordinal_logits = _generate_ordinal_logits(scores=logits)

expected_ordinal_logits = torch.tensor(
[
[-4.2530, -3.2530, -2.2530, -1.2530],
[-4.2530, -3.2530, -2.2530, -1.2530],
[-4.2530, -3.2530, -2.2530, -1.2530],
]
)

torch.testing.assert_close(
ordinal_logits, expected_ordinal_logits, atol=1e-4, rtol=1e-6
)


class TestOneHotOrdinal:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
def test_correct_sampling_shape(
self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
) -> None:
logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)

sampler = OneHotOrdinal(scores=logits)
actions = sampler.sample() # type: ignore[no-untyped-call]
log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call]
expected_log_prob_shape = logit_shape[:-1]

expected_action_shape = logit_shape

assert actions.size() == torch.Size(expected_action_shape)
assert log_probs.size() == torch.Size(expected_log_prob_shape)

@pytest.mark.parametrize("num_categories", [2, 10, 20])
def test_correct_range(self, num_categories: int) -> None:
seq_size = 10
batch_size = 100
logits = torch.ones((batch_size, seq_size, num_categories))

sampler = OneHotOrdinal(scores=logits)

actions = sampler.sample() # type: ignore[no-untyped-call]

assert torch.all(actions.sum(-1))
assert actions.shape[-1] == num_categories

def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
dtype=torch.float32,
)

sampler = OneHotOrdinal(scores=logits)

actions = sampler.sample()
log_probs = sampler.log_prob(actions)

dummy_objective = log_probs.sum()
dummy_objective.backward()

assert logits.grad is not None
assert not torch.isnan(logits.grad).any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
NormalParamExtractor,
NormalParamWrapper,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
TanhDelta,
TanhNormal,
Expand Down
4 changes: 4 additions & 0 deletions torchrl/modules/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
ReparamGradientStrategy,
)

Expand All @@ -31,5 +33,7 @@
MaskedCategorical,
MaskedOneHotCategorical,
OneHotCategorical,
Ordinal,
OneHotOrdinal,
)
}
95 changes: 85 additions & 10 deletions torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@

import torch
import torch.distributions as D
import torch.nn.functional as F

__all__ = [
"OneHotCategorical",
"MaskedCategorical",
]
__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]


def _treat_categorical_params(
Expand Down Expand Up @@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum):
class OneHotCategorical(D.Categorical):
"""One-hot categorical distribution.
This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings
This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
of the discrete tensors.
Args:
Expand All @@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical):
reparameterized samples.
``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
by using the softmax valued log-probability as a proxy to the
samples gradients.
sample gradients.
``ReparamGradientStrategy.RelaxedOneHot`` will use
:class:`torch.distributions.RelaxedOneHot` to sample from the distribution.
Expand All @@ -81,8 +79,6 @@ class OneHotCategorical(D.Categorical):
"""

num_params: int = 1

def __init__(
self,
logits: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical):
Args:
logits (torch.Tensor): event log probabilities (unnormalized)
probs (torch.Tensor): event probabilities. If provided, the probabilities
corresponding to to masked items will be zeroed and the probability
corresponding to masked items will be zeroed and the probability
re-normalized along its last dimension.
Keyword Args:
Expand Down Expand Up @@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
Args:
logits (torch.Tensor): event log probabilities (unnormalized)
probs (torch.Tensor): event probabilities. If provided, the probabilities
corresponding to to masked items will be zeroed and the probability
corresponding to masked items will be zeroed and the probability
re-normalized along its last dimension.
Keyword Args:
Expand Down Expand Up @@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten
raise ValueError(
f"Unknown reparametrization strategy {self.reparam_strategy}."
)


class Ordinal(D.Categorical):
"""A discrete distribution for learning to sample from finite ordered sets.
It is defined in contrast with the `Categorical` distribution, which does
not impose any notion of proximity or ordering over its support's atoms.
The `Ordinal` distribution explicitly encodes those concepts, which is
useful for learning discrete sampling from continuous sets. See §5 of
`Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
.. note::
This class is mostly useful when you want to learn a distribution over
a finite set which is obtained by discretising a continuous set.
Args:
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
Typically, the output of a neural network parametrising the distribution.
Examples:
>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
>>> sampler = Ordinal(scores=logits)
>>> samples = sampler.sample((num_samples,))
>>> # Define loss to encourage samples around the mean by penalising deviation from mean
>>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
hist=tensor([ 24., 158., 478., 228., 112.]),
bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
"""

def __init__(self, scores: torch.Tensor):
logits = _generate_ordinal_logits(scores)
super().__init__(logits=logits)


class OneHotOrdinal(OneHotCategorical):
"""The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution.
Args:
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
Typically, the output of a neural network parametrising the distribution.
"""

def __init__(self, scores: torch.Tensor):
logits = _generate_ordinal_logits(scores)
super().__init__(logits=logits)


def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
"""Implements Eq. 4 of `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`__."""
# Assigns Bernoulli-like probabilities for each class in the set
log_probs = F.logsigmoid(scores)
complementary_log_probs = F.logsigmoid(-scores)

# Total log-probability for being "larger than k"
larger_than_log_probs = log_probs.cumsum(dim=-1)

# Total log-probability for being "smaller than k"
smaller_than_log_probs = (
complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- complementary_log_probs
)

return larger_than_log_probs + smaller_than_log_probs

1 comment on commit c851e16

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: c851e16 Previous: d524d0d Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 315.61984495814033 iter/sec (stddev: 0.056513598058797354) 1465.2834300211289 iter/sec (stddev: 0.00003511093602279113) 4.64

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.