
Add Truncated normal dispatches #7506

Open: wants to merge 6 commits into main


HarshvirSandhu (Contributor) commented Sep 17, 2024

Description

Add a JAX dispatch for the TruncatedNormal distribution.

Related Issue

Checklist

Type of change

  • New feature / enhancement

📚 Documentation preview 📚: https://pymc--7506.org.readthedocs.build/en/7506/

Review thread on the new test:

```python
    [pm.TruncatedNormal("b", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))],
)

assert jax.numpy.array_equal(a1=f_py(), a2=f_jax())
```
HarshvirSandhu (Contributor Author): This fails with `NotImplementedError: No JAX implementation for the given distribution: truncated_normal`.

Member: The dispatch file needs to be imported when pymc is imported, so that the dispatch gets registered.
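A minimal sketch of why the import matters, assuming PyTensor's dispatchers (such as `jax_funcify`) are `functools.singledispatch` functions: a `register` decorator only takes effect once the module containing it has actually run. The names `funcify`, `TruncatedNormalRV`, and `load_dispatch_module` below are hypothetical stand-ins, not PyMC or PyTensor APIs. In a package, calling `load_dispatch_module()` corresponds to importing the dispatch module from the package's `__init__.py`.

```python
from functools import singledispatch


@singledispatch
def funcify(op):
    # Fallback used when nothing has been registered for type(op).
    raise NotImplementedError(
        f"No JAX implementation for the given distribution: {op.name}"
    )


class TruncatedNormalRV:
    """Hypothetical stand-in for the random-variable Op."""

    name = "truncated_normal"


def load_dispatch_module():
    # Stands in for executing a dispatch module: the decorator below only
    # runs, and therefore only registers, when this code is executed.
    @funcify.register(TruncatedNormalRV)
    def _(op):
        return f"sampler for {op.name}"


op = TruncatedNormalRV()

try:
    funcify(op)  # nothing registered yet, so the fallback raises
    registered_before = True
except NotImplementedError:
    registered_before = False

load_dispatch_module()  # the equivalent of importing the dispatch module
result_after = funcify(op)
print(registered_before, result_after)  # False sampler for truncated_normal
```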

codecov bot commented Sep 17, 2024

Codecov Report

Attention: Patch coverage is 90.47619% with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.85%. Comparing base (97df9c3) to head (b31dc5d).
Report is 51 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc/dispatch/dispatch_jax.py | 90.47% | 2 Missing ⚠️ |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #7506      +/-   ##
==========================================
+ Coverage   88.58%   92.85%   +4.26%
==========================================
  Files         103      106       +3
  Lines       17104    17612     +508
==========================================
+ Hits        15152    16354    +1202
+ Misses       1952     1258     -694
```

| Files with missing lines | Coverage Δ |
|---|---|
| pymc/dispatch/dispatch_jax.py | 90.47% <90.47%> (ø) |

... and 87 files with indirect coverage changes
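The percentages in the report follow from the hit and line counts. The sketch below recomputes them, assuming the figures are truncated to two decimals rather than rounded (which is consistent with Codecov's `round: down` setting). The patch totals of 19 hit lines out of 21 are inferred from "90.47619% with 2 lines missing", not stated in the report.

```python
import math
from fractions import Fraction


def pct_round_down(value: Fraction) -> str:
    """Format a fraction as a percentage truncated (not rounded) to 2 decimals."""
    hundredths = math.floor(value * 10000)  # percent * 100, rounded down
    return f"{hundredths // 100}.{hundredths % 100:02d}"


base = Fraction(15152, 17104)  # hits / lines on main
head = Fraction(16354, 17612)  # hits / lines on the PR head
patch = Fraction(19, 21)       # 21 changed lines, 2 missing (inferred)

print(pct_round_down(base))         # 88.58
print(pct_round_down(head))         # 92.85
print(pct_round_down(head - base))  # 4.26
print(pct_round_down(patch))        # 90.47
```

Rounding to nearest instead would give 88.59%, 92.86%, and 90.48%, which is why truncation is the likelier convention here.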

Review thread on the dispatch code:

```python
@jax_funcify.register(TruncatedNormalRV)
def jax_funcify_TruncatedNormalRV(op, **kwargs):
    def trunc_normal_fn(key, size, mu, sigma, lower, upper):
        return None, jax.random.truncated_normal(
```
Member: `mu` and `sigma` are missing, and the split RNG should be returned, not `None`.

Check some of the dispatches in PyTensor for a template.
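A sketch of the RNG part of that fix, assuming (as the PR's code does elsewhere) that the RNG arrives as a dict holding a JAX key under `"jax_state"`: split the key once, sample with one half, and return the other half as the new state instead of `None`. `trunc_normal_fn` here is a hypothetical standalone function that only samples a standard normal truncated to `[lower, upper]`; handling `mu` and `sigma` is a separate step.

```python
import jax
import jax.numpy as jnp


def trunc_normal_fn(rng, size, lower, upper):
    """Hypothetical sampler: returns the updated RNG state with the sample."""
    # Split once: one key for this draw, one to carry forward as the new state.
    next_key, sampling_key = jax.random.split(rng["jax_state"], 2)
    sample = jax.random.truncated_normal(sampling_key, lower, upper, shape=size)
    new_rng = {**rng, "jax_state": next_key}  # leave the caller's dict untouched
    return new_rng, sample


rng0 = {"jax_state": jax.random.PRNGKey(123)}
rng1, draw1 = trunc_normal_fn(rng0, (3,), -1.0, 2.0)
rng2, draw2 = trunc_normal_fn(rng1, (3,), -1.0, 2.0)

# Threading the returned state gives fresh draws; reusing rng0 repeats draw1.
_, draw1_again = trunc_normal_fn(rng0, (3,), -1.0, 2.0)
print(bool(jnp.array_equal(draw1, draw1_again)), bool(jnp.array_equal(draw1, draw2)))
```

Returning a copy of the dict instead of mutating it is a design choice in this sketch; what matters is that the caller receives the new key, so the next call does not reuse the same randomness.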

HarshvirSandhu (Contributor Author): I am now using `jax.nn.initializers.truncated_normal`, but the tests still fail. I'm not sure whether I have used the `rng` parameter correctly in the tests.

ricardoV94 added the jax label Sep 17, 2024
Review thread on the dispatch code:

```python
rng_key, sampling_key = jax.random.split(rng_key, 2)
key["jax_state"] = rng_key

truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper)
```
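One likely reason the tests keep failing with the initializer: as far as I can tell, `jax.nn.initializers.truncated_normal` applies `lower` and `upper` to a standard normal *before* multiplying by `stddev`, so the bounds are in units of standard deviations, not data units. The sketch below checks this, assuming a JAX version whose initializer accepts `lower` and `upper` (as the PR's code does); the numbers are arbitrary.

```python
import jax
import jax.numpy as jnp

sigma, lower, upper = 10.0, -1.0, 2.0
key = jax.random.PRNGKey(0)

# The initializer truncates a standard normal to [lower, upper], then scales.
init = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper)
draws = init(key, (1000,))

# Effective bounds become [lower * sigma, upper * sigma] = [-10, 20], not [-1, 2].
print(float(draws.min()) >= lower * sigma, float(draws.max()) <= upper * sigma)
# Many draws land above the intended upper bound of 2.
print(bool(jnp.any(draws > upper)))
```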

HarshvirSandhu (Contributor Author): We can't pass `sigma` or `mu` as parameters to `jax.random.truncated_normal`.

HarshvirSandhu (Contributor Author): I was able to use `jax.random.truncated_normal`; I had to transform `lower` and `upper`.
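One way to do that transform, sketched under the assumption sigma > 0: X = mu + sigma * Z lies in [lower, upper] exactly when Z lies in [(lower - mu) / sigma, (upper - mu) / sigma], so the bounds can be standardized, a standard truncated normal sampled, and the result shifted and scaled back. `truncnorm_sample` is a hypothetical helper name, not the PR's code.

```python
import jax
import jax.numpy as jnp


def truncnorm_sample(key, mu, sigma, lower, upper, shape=None):
    """Sketch: draw TruncatedNormal(mu, sigma, lower, upper) via a standard normal.

    Assumes sigma > 0. With shape=None the output shape is the broadcast shape
    of the standardized bounds, which already includes mu and sigma.
    """
    mu, sigma = jnp.asarray(mu), jnp.asarray(sigma)
    # Move the bounds into standard-normal units.
    alpha = (lower - mu) / sigma
    beta = (upper - mu) / sigma
    z = jax.random.truncated_normal(key, alpha, beta, shape=shape)
    # Map back to the original location and scale.
    return mu + sigma * z


key = jax.random.PRNGKey(42)
x = truncnorm_sample(key, mu=0.0, sigma=1.0, lower=-1.0, upper=2.0, shape=(5,))
```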

ricardoV94 (Member), Oct 1, 2024: The NumPy and JAX draws are not expected to match in value, because JAX uses a different implementation than NumPy. As a check, you can make a TruncatedNormal with a large sigma and confirm that the draws do not go beyond the bounds.
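A sketch of such a property check, written against JAX directly so that it runs without the PR. It assumes the dispatch standardizes the bounds, samples, and rescales; `sigma=100` is chosen so that an untruncated normal would almost always fall outside `[-1, 2]`, which makes ignored or mis-scaled bounds easy to catch.

```python
import jax
import jax.numpy as jnp

mu, sigma, lower, upper = 0.0, 100.0, -1.0, 2.0
key = jax.random.PRNGKey(123)

# Standardize the bounds, sample, then rescale (assumed dispatch logic).
z = jax.random.truncated_normal(
    key, (lower - mu) / sigma, (upper - mu) / sigma, shape=(10_000,)
)
draws = mu + sigma * z

# Every draw must respect the bounds, regardless of which RNG produced it.
in_bounds = bool(jnp.all((draws >= lower) & (draws <= upper)))
print(in_bounds)
```

Unlike comparing against the NumPy output, this check holds for any correct implementation, whatever random stream it uses.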


Review thread on the dispatch code:

```python
truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper)

return key, truncnorm(key["jax_state"], size) + mu
```
Member: Adding `mu` like this is potentially wrong: when `size` is None, `mu` could have a larger shape than the draw, and we end up with repeated values.
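A small illustration of the repeated-values problem, assuming the draw's shape is taken from `sigma` alone when `size` is None: adding a larger `mu` afterwards broadcasts a single draw across every element, so all entries share the same noise. Taking the shape from all parameters broadcast together avoids this. The values are arbitrary, and a plain normal is used to keep the focus on shapes.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu = jnp.array([0.0, 10.0, 20.0])  # larger than sigma
sigma = jnp.array(1.0)

# Problematic: shape taken from sigma (a scalar), mu added afterwards.
z_bad = jax.random.normal(key, shape=jnp.shape(sigma))
bad = mu + sigma * z_bad  # one noise value shared by all three entries

# Better: shape taken from all parameters broadcast together.
shape = jnp.broadcast_shapes(jnp.shape(mu), jnp.shape(sigma))
z_good = jax.random.normal(key, shape=shape)
good = mu + sigma * z_good  # independent noise per entry

print(z_bad.shape, z_good.shape)  # () (3,)
```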

Member: Also, you should dispatch on the more specific `jax_sample_fn`. For the broadcasting issue, see how it is done for Normal, for example: https://github.com/pymc-devs/pytensor/blob/5d4b0c4b9a1e478dda48e912ee708a9e557e9343/pytensor/link/jax/dispatch/random.py#L147-L173
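A sketch of that pattern applied to a truncated normal, written without importing PyTensor so it runs on its own: register a sampler per random-variable type, split the key stored under `rng["jax_state"]`, infer the shape from the broadcast parameters when `size` is None, and return the updated `rng` together with the sample. `jax_sample_fn` here is a local `singledispatch` stand-in and `TruncatedNormalRV` a stub class, both hypothetical; the inner `sample_fn(rng, size, dtype, *parameters)` signature is an assumption and may not match the PyTensor version in use.

```python
from functools import singledispatch

import jax
import jax.numpy as jnp


@singledispatch
def jax_sample_fn(op, node=None):
    # Hypothetical local stand-in for PyTensor's dispatcher of the same name.
    raise NotImplementedError(
        f"No JAX implementation for the given distribution: {op.name}"
    )


class TruncatedNormalRV:
    """Hypothetical stub for the PyMC random-variable Op."""

    name = "truncated_normal"


@jax_sample_fn.register(TruncatedNormalRV)
def jax_sample_fn_truncated_normal(op, node=None):
    def sample_fn(rng, size, dtype, *parameters):
        mu, sigma, lower, upper = (jnp.asarray(p) for p in parameters)
        # Split the key: one half is used now, the other becomes the new state.
        rng_key, sampling_key = jax.random.split(rng["jax_state"], 2)
        if size is None:
            # Broadcast all parameters so each element gets its own draw.
            size = jnp.broadcast_shapes(
                mu.shape, sigma.shape, lower.shape, upper.shape
            )
        # Standardize the bounds, sample, then shift and scale back.
        alpha = (lower - mu) / sigma
        beta = (upper - mu) / sigma
        z = jax.random.truncated_normal(sampling_key, alpha, beta, size, dtype)
        rng["jax_state"] = rng_key
        return rng, mu + sigma * z

    return sample_fn


sample_fn = jax_sample_fn(TruncatedNormalRV())
rng = {"jax_state": jax.random.PRNGKey(0)}
rng, draws = sample_fn(
    rng, None, jnp.float32, jnp.array([0.0, 5.0, 10.0]), 1.0, -1.0, 12.0
)
```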

HarshvirSandhu marked this pull request as ready for review on October 12, 2024, 21:15
Labels: jax
Projects: None yet
Development: successfully merging this pull request may close these issues:
  • ENH: Add dispatches of TruncatedNormal distribution for forward sampling
2 participants