[Feature] Adds ordinal distributions #2520
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2520
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 5 Unrelated Failures as of commit f414355 with merge base a70b258.
👉 Rebase onto the `viable/strict` branch to avoid the broken-trunk failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, thanks for this!
See the couple of suggestions below, but otherwise good to go.
Args:
    scores (torch.Tensor): a tensor of shape [..., N], where N is the size of the set supporting the distribution.
        Typically, the output of a neural network parametrising the distribution.
If you could add a simple example it'd be awesome. What about the following (feel free to edit):
Typically, the output of a neural network parametrising the distribution.

Examples:
    >>> batch_size = 1
    >>> seq_size = 4
    >>> num_categories = 3
    >>> logits = torch.ones((batch_size, seq_size, num_categories))
    >>> sampler = Ordinal(scores=logits)
    >>> # The probabilities and logits are no longer equal, but reflect the ordering of the samples
    >>> sampler.probs
    tensor([[[0.0900, 0.2447, 0.6652],
             [0.0900, 0.2447, 0.6652],
             [0.0900, 0.2447, 0.6652],
             [0.0900, 0.2447, 0.6652]]])
    >>> torch.manual_seed(10)
    >>> # We print a histogram to reflect the sampling frequency of each item (0, 1 and 2)
    >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=3)
    torch.return_types.histogram(
    hist=tensor([ 386., 989., 2625.]),
    bin_edges=tensor([0.0000, 0.6667, 1.3333, 2.0000]))
I thought about this but I was unsure about its meaningfulness: since this is a "learning-time" feature, here we wouldn't be showcasing anything that cannot be done by a Categorical distribution.
I thought about it and wrote an example which involves some learning; maybe it is more useful like that. Wdyt?
My personal view on doctests is that they should be (1) almost always there, because people want to understand by seeing code, (2) simple enough that you can grasp what this class / function is about just by looking at one example, and (3) they should not add an extra layer of complexity.
If you think there's an example that can be put together that fulfills these three principles and also shows some learning, I'm open to it, but I'm just afraid that it'll require other components of the lib and give the wrong impression that this class requires other classes to run.
That being said, I agree that it's not obvious from the example above that you can do things you couldn't easily do with just a Categorical!
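To make the contrast concrete, here is a minimal side-by-side sketch (assuming the `Ordinal` class from this PR ends up exposed via `torchrl.modules`): the same scores parametrise a uniform `Categorical`, but order-aware `Ordinal` probabilities.

>>> import torch
>>> from torch.distributions import Categorical
>>> from torchrl.modules import Ordinal  # assumed import path once this PR is merged
>>> scores = torch.ones(3)
>>> Categorical(logits=scores).probs  # categories are unordered: equal scores, uniform probabilities
tensor([0.3333, 0.3333, 0.3333])
>>> Ordinal(scores=scores).probs  # ordinal coupling shifts mass toward higher atoms
tensor([0.0900, 0.2447, 0.6652])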
Agreed with all of the above; do you think the one I pushed this morning (copied below for visibility) clashes with (2) and (3)?
"""
Examples:
>>> num_atoms, num_samples = 5, 20
>>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom
>>> torch.manual_seed(42)
>>> logits = torch.ones((num_atoms), requires_grad=True)
>>> optimizer = torch.optim.Adam([logits], lr=0.1)
>>>
>>> # Perform optimisation loop to minimise deviation from `mean`
>>> for _ in range(20):
>>> sampler = Ordinal(scores=logits)
>>> samples = sampler.sample((num_samples,))
>>> # Define loss to encourage samples around the mean by penalising deviation from mean
>>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
>>>
>>> sampler.probs
tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
>>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
>>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
torch.return_types.histogram(
hist=tensor([ 24., 158., 478., 228., 112.]),
bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
"""
Amazing I love it
Description
Adds an ordinally parametrised distribution from [Tang & Agrawal, 2020].
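For reference, the ordinal parametrisation of Tang & Agrawal (2020) maps unnormalised scores s to logits via logit_k = sum_{j <= k} log sigmoid(s_j) + sum_{j > k} log(1 - sigmoid(s_j)), coupling neighbouring atoms. Below is a minimal sketch of that transform; it illustrates the paper's construction and is not necessarily the exact code in this PR.

import torch
import torch.nn.functional as F

def ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
    # logsigmoid(-x) == log(1 - sigmoid(x))
    log_sig = F.logsigmoid(scores)
    log_comp = F.logsigmoid(-scores)
    # logit_k = sum_{j <= k} log sigmoid(s_j) + sum_{j > k} log(1 - sigmoid(s_j))
    lower = torch.cumsum(log_sig, dim=-1)
    upper = log_comp.sum(dim=-1, keepdim=True) - torch.cumsum(log_comp, dim=-1)
    return lower + upper

# Reproduces the probabilities from the doctests above:
torch.softmax(ordinal_logits(torch.ones(3)), dim=-1)  # tensor([0.0900, 0.2447, 0.6652])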
Motivation and Context
This parametrisation has proved useful when learning distributions over finite sets obtained by discretising continuous sets.
Note: We can provide an example, although it should probably live in the torchrl repository.
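For instance, a hypothetical sketch of what such an example could look like (the support grid and sizes below are made up, and the `torchrl.modules` import path assumes this PR has been merged): a continuous interval is discretised into atoms, atom indices are sampled ordinally, then mapped back to the continuous support.

import torch
from torchrl.modules import Ordinal  # assumed import path

num_atoms = 5
support = torch.linspace(-1.0, 1.0, num_atoms)  # atom k <-> a continuous action value
scores = torch.zeros(num_atoms)  # stand-in for a network output
dist = Ordinal(scores=scores)
indices = dist.sample((8,))  # integer atom indices in [0, num_atoms)
actions = support[indices]   # map back to continuous values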
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!