-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax backend): removed manual dtype casting. #23655
feat(jax backend): removed manual dtype casting. #23655
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR Compliance Checks
Thank you for your Pull Request! We have run several checks on this pull request in order to make sure it's suitable for merging into this project. The results are listed in the following section.
Issue Reference
In order to be considered for merging, the pull request description must refer to a specific issue number. This is described in our contributing guide and our PR template.
This check is looking for a phrase similar to: "Fixes #XYZ" or "Resolves #XYZ" where XYZ is the issue number that this PR is meant to address.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! Feel free to merge, thanks @Madjid-CH 😄
(I'm assuming you've already run the tests for the functions changed and the tests pass or fail, whichever was the state of those tests before the changes. And even if it fails, the logs of the failure are the same as those before making the changes.)
actually test_eigvals is failing because of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just requested a minor change, thanks @Madjid-CH 😄
elif x.dtype in (jnp.int16, jnp.int8, jnp.uint8): | ||
x = x.astype(jnp.int64) | ||
elif x.dtype in (jnp.complex128, jnp.complex64): | ||
if x.dtype in (jnp.complex128, jnp.complex64): | ||
x = jnp.real(x).astype(jnp.float64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think here we should just allow this casting for complex128
and mark complex64
as unsupported because if we cast the real element of complex64
(which is float32
) to float64
, we're having the same kind of additional memory consumption as with casting all int
values to int64
and so on.
I think the same has happened in the PRs we merged yesterday. I couldn't notice that then but seems like doing this would be preferable, what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually even with complex128
we will have to copy the data if we want to get the real part. Since most likely an array of complex numbers is not 2 arrays of real part and imag part so we can extract only the real part, the values are interleaved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ved you were right I think we can get the real part without creating a new array
a = np.random.uniform(size=10)
a.dtype
Out[49]: dtype('float64')
a.astype("complex128")
Out[50]:
array([0.53129972+0.j, 0.16200054+0.j, 0.08087117+0.j, 0.76545754+0.j,
0.51599048+0.j, 0.37456005+0.j, 0.16486012+0.j, 0.77161369+0.j,
0.92052105+0.j, 0.90555509+0.j])
ac = a.astype("complex128")
ac.real
Out[53]:
array([0.53129972, 0.16200054, 0.08087117, 0.76545754, 0.51599048,
0.37456005, 0.16486012, 0.77161369, 0.92052105, 0.90555509])
ac.real[0] = 123
ac
Out[55]:
array([1.23000000e+02+0.j, 1.62000537e-01+0.j, 8.08711702e-02+0.j,
7.65457537e-01+0.j, 5.15990475e-01+0.j, 3.74560048e-01+0.j,
1.64860120e-01+0.j, 7.71613695e-01+0.j, 9.20521046e-01+0.j,
9.05555095e-01+0.j])
notice how the first element has changed.
I used numpy directly since jax array are immutable.
the same applies for complex64
a = np.random.uniform(size=10)
a = a.astype("float32")
a = a + 1j
a.dtype
Out[65]: dtype('complex64')
a.real
Out[66]:
array([0.15835479, 0.28751314, 0.5719754 , 0.50390446, 0.51206774,
0.6538808 , 0.6945784 , 0.19854428, 0.63909876, 0.5192142 ],
dtype=float32)
a.real[1: 6] = 0
a
Out[68]:
array([0.15835479+1.j, 0. +1.j, 0. +1.j, 0. +1.j,
0. +1.j, 0. +1.j, 0.6945784 +1.j, 0.19854428+1.j,
0.63909876+1.j, 0.5192142 +1.j], dtype=complex64)
a.real.dtype
Out[69]: dtype('float32')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! Feel free to merge, thanks @Madjid-CH 😄
elif x.dtype in (jnp.complex128, jnp.complex64): | ||
x = jnp.real(x).astype(jnp.float64) | ||
if x.dtype in (jnp.complex128, jnp.complex64): | ||
x = x.real |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good, let's make this change in the other backends too in a different PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
understood
PR Description
the used dtype checker: