-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: rotation matrices using s2fft #212
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
…/jaxoplanet into mpmath_starry_core
for more information, see https://pre-commit.ci
…/jaxoplanet into mpmath_starry_core
for more information, see https://pre-commit.ci
…/jaxoplanet into mpmath_starry_core
for more information, see https://pre-commit.ci
…/jaxoplanet into mpmath_starry_core
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
…/jaxoplanet into mpmath_starry_core
for more information, see https://pre-commit.ci
…oplanet into s2fft_rotation
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
…oplanet into s2fft_rotation
The issue here must be related to s2fft! @lgrcia, can you try to work out which inputs we're passing in the test that hangs to isolate exactly what s2fft call we're executing. It would be interesting to know if we can get the same hang using just s2fft and no jaxoplanet. In that case we can report upstream and see what they say. |
I think I identified where the problem is from. Here is a way to reproduce it on macosimport numpy as np
from jaxoplanet.experimental.starry.rotation import dot_rotation_matrix
l_max = 5
theta = 0
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)
calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident) This runs ok. But then, when l_max = 6
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident) it freezes. I think it has to do in how I (or s2fft) combine the static arguments in the jitted functions from I don't really understand why it behaves like this but a workaround for me is to avoid decorating the s2fft rotation functions with jit. So I copied all required functions (we only need 100 lines of python from s2fft) and removed the s2fft dependency, for now. I'm down to understand the problem better before reintroducing s2fft as a dependency. |
for more information, see https://pre-commit.ci
…oplanet into s2fft_rotation
for more information, see https://pre-commit.ci
…oplanet into s2fft_rotation
@dfm, here is a way to reproduce only with s2fft: import jax
from functools import partial
from s2fft.utils.rotation import generate_rotate_dls
@partial(jax.jit, static_argnums=(0,))
def f(deg, alpha):
return generate_rotate_dls(deg, alpha)
_ = f(5, 0.0) # this executes fine
_ = f(10, 0.0) # this freezes I might open an issue but I think this is not a proper use of this function given the static args (see https://github.com/astro-informatics/s2fft/blob/main/s2fft/utils/rotation.py#L75). Any idea why this would happen? I understand that each test run in separate python instances should pass. So could the issue be due to how pytest runners are dispatched on macOS? |
It's fascinating to me that that happens and that your solution works! I don't see any reason why the Regardless: I think this is a good "fix"! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! All done
Using the s2fft Python package to compute the Wigner D-matrices used to rotate the spherical harmonics. See also #140