-
Notifications
You must be signed in to change notification settings - Fork 11
EagerJAXArrayContext: return mutable numpy array in to_numpy #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR modifies the to_numpy
method in the JAX array context to return a mutable NumPy array by explicitly copying the result of jax.device_get
.
- Wraps
jax.device_get
result innp.copy
to produce a writable array.
Comments suppressed due to low confidence (1)
arraycontext/impl/jax/init.py:104
- The code uses
np.copy
butnumpy
is not imported in this scope, which will raise aNameError
. Addimport numpy as np
at the module level or within the function.
return np.copy(jax.device_get(ary))
I'm a bit out of the loop, so I have questions 😁
|
Thanks for the review @alexfikl !
There are in-place modifications in grudge, e.g. Without this PR, some tests in grudge with Jax fail (see e.g. inducer/grudge#380).
I think there are copies done, both for CPU and GPU, the problem is that |
I remember writing some version of that code.. it should really be rewritten in a nicer way :( But fair enough, I can see how that breaks. Are there any other places? (the tests in inducer/grudge#380 seem to be passing nicely at the moment) EDIT: It wasn't running tests with jax, I see them now! 😁 |
5d68e6e
to
4386b19
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would personally prefer if this gets fixed in grudge, since it's just that one place that seems to fail (?). Doing two copies on to_numpy
here seems like a hack.
If we do go for it, we should also document it somewhere that the arrays returned by to_numpy
are writable, since it's apparently not a given!
I agree, and I've just added a (sketchy) reimplementation in inducer/grudge#380, feel free to take a look. Setting this PR back to draft.
Right, either way of what happens with this PR, we should document this in |
AFAICS, there is no way to make an immutable numpy array mutable again except for copying it.