[BUG] Incorrect Calculation of Mode for TanhNormal Distribution #2186

Closed
Emile-Aquila opened this issue May 31, 2024 · 8 comments
Labels
bug Something isn't working


@Emile-Aquila

Describe the bug

I'd like to express my gratitude for the swift and diligent maintenance of the torchrl library.

I have identified a potential issue with the implementation of the mode property of the TanhNormal distribution in TorchRL.
The calculation of the mode appears to be incorrect due to the nature of the tanh function applied to a normal distribution.

I think the mode of the TanhNormal distribution should reflect the peak of the probability density function after the tanh transformation is applied to the underlying normal distribution. Because tanh is non-linear, that peak is in general not tanh applied to the mode of the underlying normal, so the calculation has to account for the change of variables.
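A short change-of-variables derivation, assuming bounds of exactly (-1, 1) and that loc and scale are used as-is by the base Normal, shows where the peak actually sits:

```latex
% Density of Y = tanh(X) with X ~ N(mu, sigma^2), via the change of variables x = atanh(y):
p_Y(y) = \frac{\mathcal{N}\!\left(\operatorname{atanh} y;\ \mu, \sigma^2\right)}{1 - y^2}, \qquad y \in (-1, 1)

% atanh is a monotone bijection onto R, so maximizing over y is the same as maximizing over x.
% Using 1 - \tanh^2 x = \operatorname{sech}^2 x:
\log p_Y(\tanh x) = -\frac{(x - \mu)^2}{2\sigma^2} + 2\log\cosh x + \text{const}

% Stationarity condition in x:
-\frac{x - \mu}{\sigma^2} + 2\tanh x = 0
\quad\Longrightarrow\quad
x^\star = \mu + 2\sigma^2 \tanh x^\star, \qquad \operatorname{mode}(Y) = \tanh x^\star
```

Plugging in x = μ shows that tanh(μ) satisfies this condition only when μ = 0. For μ = 0.2 and σ = 1 the positive root is x* ≈ 2.146, giving a mode of about 0.973; when 2σ² > 1 the density can also have a second, lower local maximum on the opposite side of zero.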


To Reproduce

The current implementation of the mode property does not compute the mode correctly.
For example, in the following scenario the mode is expected to be around 1, but the property returns approximately 0.197, which is simply tanh(loc) = tanh(0.2).

import torch
from torchrl.modules import TanhNormal
import matplotlib.pyplot as plt

torch.random.manual_seed(0)

# Base Normal(loc=0.2, scale=1.0), squashed by tanh into (-1, 1)
loc = torch.tensor([0.2], dtype=torch.float32)
scale = torch.tensor([1.0], dtype=torch.float32)

dist = TanhNormal(loc, scale, min=-1, max=1)
print("mode: ", dist.mode.item())  # mode:  0.1973753273487091

# Draw samples and plot their histogram to see where the density peaks
# (sample((n,)) replaces the deprecated sample_n(n))
sample = dist.sample((10000,))
plt.hist(sample.numpy(), bins=500, range=(-1, 1))
plt.show()
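The same picture can be checked without torchrl. The sketch below (function names are illustrative, and it assumes bounds of (-1, 1) with no extra rescaling of loc or scale) evaluates the analytic log-density of tanh(Normal(loc, scale)) on a fine grid and compares the grid maximum with tanh(loc), which is the 0.197 value printed above.

```python
import math


def tanh_normal_log_prob(y: float, loc: float, scale: float) -> float:
    """Log-density of Y = tanh(X), X ~ Normal(loc, scale), for y in (-1, 1)."""
    x = math.atanh(y)
    log_normal = -0.5 * ((x - loc) / scale) ** 2 - math.log(scale * math.sqrt(2 * math.pi))
    # Change of variables: |dx/dy| = 1 / (1 - y^2)
    return log_normal - math.log1p(-y * y)


def grid_mode(loc: float, scale: float, n: int = 200_001) -> float:
    """Brute-force argmax of the density over an evenly spaced grid strictly inside (-1, 1)."""
    best_y, best_lp = 0.0, -math.inf
    for i in range(1, n - 1):  # skip the endpoints -1 and 1, where atanh diverges
        y = -1.0 + 2.0 * i / (n - 1)
        lp = tanh_normal_log_prob(y, loc, scale)
        if lp > best_lp:
            best_y, best_lp = y, lp
    return best_y


loc, scale = 0.2, 1.0
print("tanh(loc):", math.tanh(loc))         # what the old property returned
print("grid mode:", grid_mode(loc, scale))  # about 0.973
```

A grid search is slow and only works in one dimension, but it makes a convenient reference value for checking faster iterative estimates.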


Thank you again for your continuous support and hard work on maintaining the torchrl library.

@Emile-Aquila Emile-Aquila added the bug Something isn't working label May 31, 2024
@vmoens
Contributor

vmoens commented Jun 3, 2024

Hello
Thanks for reporting this
#2198 should fix it and improve the API.
Note that properly finding the mode of the distribution requires finding its maximum (well, the mode is the maximum haha), but there is no analytical expression for it. For the mean I used a regular stochastic expectation, but for the maximum I had to rely on Newton-Raphson, which is considerably slower than what we had before and non-differentiable.

Basically, before it was fast, now it's accurate
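For readers curious what these two estimators look like, here is a plain-Python sketch; it is not torchrl's implementation, and the names `mc_mean` and `newton_mode` are hypothetical. The mean is a Monte Carlo average of tanh over Normal samples. The mode applies Newton-Raphson to the stationarity condition -(x - loc)/scale² + 2·tanh(x) = 0 in pre-tanh space x = atanh(y), starting from loc + 2·scale²·sign(loc), and maps the root back with tanh.

```python
import math
import random


def mc_mean(loc: float, scale: float, n: int = 100_000, seed: int = 0) -> float:
    """Stochastic estimate of E[tanh(X)] with X ~ Normal(loc, scale)."""
    rng = random.Random(seed)
    return sum(math.tanh(rng.gauss(loc, scale)) for _ in range(n)) / n


def newton_mode(loc: float, scale: float, tol: float = 1e-10, max_iter: int = 100) -> float:
    """Mode of tanh(Normal(loc, scale)) via Newton-Raphson in pre-tanh space."""
    var = scale * scale
    # Start beyond the stationary point on the side of loc; from there the
    # iterates approach that root monotonically.
    x = loc + 2.0 * var * (1.0 if loc >= 0 else -1.0)
    for _ in range(max_iter):
        t = math.tanh(x)
        grad = -(x - loc) / var + 2.0 * t          # d/dx of the log-density
        hess = -1.0 / var + 2.0 * (1.0 - t * t)    # d^2/dx^2 of the log-density
        step = grad / hess
        x -= step
        if abs(step) < tol:
            break
    return math.tanh(x)


loc, scale = 0.2, 1.0
print("MC mean:", mc_mean(loc, scale))
print("Newton mode:", newton_mode(loc, scale))  # about 0.973
```

Each Newton iteration is cheap here, but in a batched policy this loop runs per element and is not differentiated through, which matches the speed and differentiability caveats above.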


@vmoens
Contributor

vmoens commented Jun 3, 2024


import torch
from torchrl.modules import TanhNormal
import matplotlib.pyplot as plt

torch.random.manual_seed(0)

loc = torch.tensor([0.2], dtype=torch.float32)
scale = torch.tensor([1.0], dtype=torch.float32)

dist = TanhNormal(loc, scale, min=-1, max=1)
print("mode: ", dist.mode.item())  # printed 0.1973753273487091 before the fix

# More samples and coarser bins than in the original report give a smoother histogram
sample = dist.sample((100000,))
plt.hist(sample.numpy(), bins=64, range=(-1, 1))
plt.show()

Now results in: 1.0

@Emile-Aquila
Author

Emile-Aquila commented Jun 5, 2024

Hello

Thank you for your prompt response and for fixing the mode calculation of the TanhNormal distribution.
While the calculations have become slower, I understand that this is a necessary trade-off.

I really appreciate it.

@vmoens
Contributor

vmoens commented Jun 5, 2024

Given the time it takes to compute the mode now, wouldn't it make sense to make it a method (not a property) and raise a deprecation warning in the current property?
Intuitively, I think most users expect a property to be "fast". I'm mostly worried about the runtime of common algorithms if we adopt this new property, whereas we could redirect users toward an alternative method if they want the accurate mode.
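A minimal sketch of that split, using hypothetical names (`SquashedNormalSketch`, `compute_mode`) that are not torchrl's actual API: the existing property stays cheap (tanh of the base loc) but emits a DeprecationWarning pointing to an explicit, slower method. For brevity the method uses a simple fixed-point iteration on x = loc + 2·scale²·tanh(x), a swapped-in technique rather than the optimizer discussed in this thread.

```python
import math
import warnings


class SquashedNormalSketch:
    """Hypothetical stand-in for a tanh-squashed Normal, only to illustrate the API split."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc = loc
        self.scale = scale

    @property
    def mode(self) -> float:
        # Cheap but inaccurate: tanh of the base distribution's mode.
        warnings.warn(
            "`mode` returns tanh(loc), which is not the true mode; "
            "call `compute_mode()` for an accurate but slower estimate.",
            DeprecationWarning,
            stacklevel=2,
        )
        return math.tanh(self.loc)

    def compute_mode(self, iters: int = 100) -> float:
        # Expensive but accurate: iterate x <- loc + c * tanh(x) starting from
        # loc + c * sign(loc); the sequence moves monotonically toward the
        # stationary point on the side of loc, then tanh maps it back to (-1, 1).
        c = 2.0 * self.scale ** 2
        x = self.loc + c * (1.0 if self.loc >= 0 else -1.0)
        for _ in range(iters):
            x = self.loc + c * math.tanh(x)
        return math.tanh(x)


d = SquashedNormalSketch(0.2, 1.0)
print(d.compute_mode())  # about 0.973
```

Code that only needs a cheap deterministic action can keep reading the property (and filter the warning), while evaluation code that needs the true peak opts into the method explicitly.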

@vmoens
Copy link
Contributor

vmoens commented Jun 5, 2024

I updated the PR to rely on Adam, which is faster and more accurate than L-BFGS, SGD and Newton-Raphson (crazy, right?).
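To illustrate the kind of optimizer loop this implies, here is a hand-written Adam update in plain Python (standard Adam with bias correction). It is not the PR's code: the learning rate, step count, starting point x = loc, and the "keep the best iterate" rule are all assumptions. It ascends the log-density in pre-tanh space and maps the result back with tanh.

```python
import math


def adam_mode(loc: float, scale: float, lr: float = 0.05, steps: int = 2000,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Mode of tanh(Normal(loc, scale)) by Adam ascent on log p in x = atanh(y) space."""
    var = scale * scale

    def log_density(x: float) -> float:
        # log N(x; loc, scale) + 2 log cosh(x), dropping additive constants
        return -0.5 * (x - loc) ** 2 / var + 2.0 * math.log(math.cosh(x))

    x, m, v = loc, 0.0, 0.0  # start from the base distribution's mode
    best_x, best_val = x, log_density(x)
    for t in range(1, steps + 1):
        g = -(x - loc) / var + 2.0 * math.tanh(x)  # gradient of log_density
        m = beta1 * m + (1 - beta1) * g             # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g         # second-moment estimate
        m_hat = m / (1 - beta1 ** t)                # bias corrections
        v_hat = v / (1 - beta2 ** t)
        x += lr * m_hat / (math.sqrt(v_hat) + eps)  # ascent step
        val = log_density(x)
        if val > best_val:                          # keep the best iterate seen
            best_x, best_val = x, val
    return math.tanh(best_x)


print("Adam mode:", adam_mode(0.2, 1.0))  # about 0.973
```

Starting at loc means the first gradient, 2·tanh(loc), already points toward the correct side; if loc is exactly 0 the gradient vanishes and this sketch stays at 0.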

@Emile-Aquila
Author

Does this mean we should implement the calculation of the correct mode as a separate method and revert the mode property to its original implementation? Now that you mention it, that seems more practical to me. Sorry for the confusion in my issue.

@vmoens
Contributor

vmoens commented Jun 5, 2024

Yes, if you look at the current status of the PR, this is how I went about it.

@vmoens
Contributor

vmoens commented Jun 28, 2024

Fixed by #2198

@vmoens vmoens closed this as completed Jun 28, 2024