-
Notifications
You must be signed in to change notification settings - Fork 385
[Feature] Adds ordinal distributions #2520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -9,11 +9,9 @@ | |||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||
import torch.distributions as D | ||||||||||||||||||||||||||||||||||||||||||||
import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
__all__ = [ | ||||||||||||||||||||||||||||||||||||||||||||
"OneHotCategorical", | ||||||||||||||||||||||||||||||||||||||||||||
"MaskedCategorical", | ||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||
__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"] | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def _treat_categorical_params( | ||||||||||||||||||||||||||||||||||||||||||||
|
@@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum): | |||||||||||||||||||||||||||||||||||||||||||
class OneHotCategorical(D.Categorical): | ||||||||||||||||||||||||||||||||||||||||||||
"""One-hot categorical distribution. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings | ||||||||||||||||||||||||||||||||||||||||||||
This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings | ||||||||||||||||||||||||||||||||||||||||||||
of the discrete tensors. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||
|
@@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical): | |||||||||||||||||||||||||||||||||||||||||||
reparameterized samples. | ||||||||||||||||||||||||||||||||||||||||||||
``ReparamGradientStrategy.PassThrough`` will compute the sample gradients | ||||||||||||||||||||||||||||||||||||||||||||
by using the softmax valued log-probability as a proxy to the | ||||||||||||||||||||||||||||||||||||||||||||
samples gradients. | ||||||||||||||||||||||||||||||||||||||||||||
sample gradients. | ||||||||||||||||||||||||||||||||||||||||||||
louisfaury marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
``ReparamGradientStrategy.RelaxedOneHot`` will use | ||||||||||||||||||||||||||||||||||||||||||||
:class:`torch.distributions.RelaxedOneHot` to sample from the distribution. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
@@ -81,8 +79,6 @@ class OneHotCategorical(D.Categorical): | |||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
num_params: int = 1 | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||
logits: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||
|
@@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical): | |||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||
logits (torch.Tensor): event log probabilities (unnormalized) | ||||||||||||||||||||||||||||||||||||||||||||
probs (torch.Tensor): event probabilities. If provided, the probabilities | ||||||||||||||||||||||||||||||||||||||||||||
corresponding to to masked items will be zeroed and the probability | ||||||||||||||||||||||||||||||||||||||||||||
corresponding to masked items will be zeroed and the probability | ||||||||||||||||||||||||||||||||||||||||||||
re-normalized along its last dimension. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Keyword Args: | ||||||||||||||||||||||||||||||||||||||||||||
|
@@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical): | |||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||
logits (torch.Tensor): event log probabilities (unnormalized) | ||||||||||||||||||||||||||||||||||||||||||||
probs (torch.Tensor): event probabilities. If provided, the probabilities | ||||||||||||||||||||||||||||||||||||||||||||
corresponding to to masked items will be zeroed and the probability | ||||||||||||||||||||||||||||||||||||||||||||
corresponding to masked items will be zeroed and the probability | ||||||||||||||||||||||||||||||||||||||||||||
re-normalized along its last dimension. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Keyword Args: | ||||||||||||||||||||||||||||||||||||||||||||
|
@@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten | |||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||
f"Unknown reparametrization strategy {self.reparam_strategy}." | ||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
class Ordinal(D.Categorical): | ||||||||||||||||||||||||||||||||||||||||||||
louisfaury marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
"""A discrete distribution for learning to sample from finite ordered sets. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
It is defined in contrast with the `Categorical` distribution, which does | ||||||||||||||||||||||||||||||||||||||||||||
not impose any notion of proximity or ordering over its support's atoms. | ||||||||||||||||||||||||||||||||||||||||||||
The `Ordinal` distribution explicitly encodes those concepts, which is | ||||||||||||||||||||||||||||||||||||||||||||
useful for learning discrete sampling from continuous sets. See §5 of | ||||||||||||||||||||||||||||||||||||||||||||
`Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
.. note:: | ||||||||||||||||||||||||||||||||||||||||||||
This class is mostly useful when you want to learn a distribution over | ||||||||||||||||||||||||||||||||||||||||||||
a finite set which is obtained by discretising a continuous set. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions. | ||||||||||||||||||||||||||||||||||||||||||||
Typically, the output of a neural network parametrising the distribution. | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you could add a simple example it'd be awesome, what about (feel free to edit)
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about this but I was unsure about its meaningfulness -- since this is a "learning-time" feature, here we wouldn't be showcasing anything that cannot be done by a Categorical distribution. I thought about it and wrote an example which involves some learning, maybe it is more useful like that. Wdyt ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My personal view on doctests is that they should be (1) almost alwyas there because people want to understand by seeing code, (2) simple enought that you can grasp what this class / function is about just by looking at one example and (3) they should not add an extra layer of complexity. If you think there's an example that can be put together that fullfills these three principles and also show some learning I'm open to it, but I'm just afraid that it'll require other components of the lib and give the wrong impression that this class requires other classes to run. That being said I agree that it's not obvious from the example above that you can do things you couldn't easily do with just a categorical! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed with all of the above; do you think the one I pushed this morning (copied below for visibility) clashes with 2. and 3 ? """
Examples:
>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
>>> sampler = Ordinal(scores=logits)
>>> samples = sampler.sample((num_samples,))
>>> # Define loss to encourage samples around the mean by penalising deviation from mean
>>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
hist=tensor([ 24., 158., 478., 228., 112.]),
bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
""" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Amazing I love it |
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Examples: | ||||||||||||||||||||||||||||||||||||||||||||
>>> num_atoms, num_samples = 5, 20 | ||||||||||||||||||||||||||||||||||||||||||||
>>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom | ||||||||||||||||||||||||||||||||||||||||||||
>>> torch.manual_seed(42) | ||||||||||||||||||||||||||||||||||||||||||||
>>> logits = torch.ones((num_atoms), requires_grad=True) | ||||||||||||||||||||||||||||||||||||||||||||
>>> optimizer = torch.optim.Adam([logits], lr=0.1) | ||||||||||||||||||||||||||||||||||||||||||||
>>> | ||||||||||||||||||||||||||||||||||||||||||||
>>> # Perform optimisation loop to minimise deviation from `mean` | ||||||||||||||||||||||||||||||||||||||||||||
>>> for _ in range(20): | ||||||||||||||||||||||||||||||||||||||||||||
>>> sampler = Ordinal(scores=logits) | ||||||||||||||||||||||||||||||||||||||||||||
>>> samples = sampler.sample((num_samples,)) | ||||||||||||||||||||||||||||||||||||||||||||
>>> # Define loss to encourage samples around the mean by penalising deviation from mean | ||||||||||||||||||||||||||||||||||||||||||||
>>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples)) | ||||||||||||||||||||||||||||||||||||||||||||
>>> loss.backward() | ||||||||||||||||||||||||||||||||||||||||||||
>>> optimizer.step() | ||||||||||||||||||||||||||||||||||||||||||||
>>> optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||
>>> | ||||||||||||||||||||||||||||||||||||||||||||
>>> sampler.probs | ||||||||||||||||||||||||||||||||||||||||||||
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...) | ||||||||||||||||||||||||||||||||||||||||||||
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4) | ||||||||||||||||||||||||||||||||||||||||||||
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms) | ||||||||||||||||||||||||||||||||||||||||||||
torch.return_types.histogram( | ||||||||||||||||||||||||||||||||||||||||||||
hist=tensor([ 24., 158., 478., 228., 112.]), | ||||||||||||||||||||||||||||||||||||||||||||
bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000])) | ||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, scores: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||
logits = _generate_ordinal_logits(scores) | ||||||||||||||||||||||||||||||||||||||||||||
super().__init__(logits=logits) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
class OneHotOrdinal(OneHotCategorical): | ||||||||||||||||||||||||||||||||||||||||||||
"""The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||
scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions. | ||||||||||||||||||||||||||||||||||||||||||||
Typically, the output of a neural network parametrising the distribution. | ||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, scores: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||
logits = _generate_ordinal_logits(scores) | ||||||||||||||||||||||||||||||||||||||||||||
super().__init__(logits=logits) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||
louisfaury marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
"""Implements Eq. 4 of `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`__.""" | ||||||||||||||||||||||||||||||||||||||||||||
# Assigns Bernoulli-like probabilities for each class in the set | ||||||||||||||||||||||||||||||||||||||||||||
log_probs = F.logsigmoid(scores) | ||||||||||||||||||||||||||||||||||||||||||||
complementary_log_probs = F.logsigmoid(-scores) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Total log-probability for being "larger than k" | ||||||||||||||||||||||||||||||||||||||||||||
larger_than_log_probs = log_probs.cumsum(dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Total log-probability for being "smaller than k" | ||||||||||||||||||||||||||||||||||||||||||||
smaller_than_log_probs = ( | ||||||||||||||||||||||||||||||||||||||||||||
complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||
- complementary_log_probs | ||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
return larger_than_log_probs + smaller_than_log_probs |
Uh oh!
There was an error while loading. Please reload this page.