Doesn't work on Python 3.12 or 3.11 #92
Hey @timoklein, thank you very much for spotting this. I wonder if we can fix the bug directly and also support Python …
Hi Eduardo, thanks for the quick reply! When it throws the error, Python points me to this section: lines 48 to 69 in commit 5fc4975.
I'm not sure whether it also occurs elsewhere. As you said, it shouldn't be that difficult to fix. I'm just starting with JAX, but I should have some time next week to try my hand at this. xland-minigrid had the same issue, so I can use their fix as a template. This issue might also be related. Would the benchmark script be a reasonable way to verify that compilation still works properly after the fix?
Hi @epignatelli and @timoklein, just to chime in here: this issue comes from CPython using hashability as a proxy for immutability. That causes problems for JAX arrays, which are immutable but not hashable. Related issues: jax-ml/jax#14295 and python/cpython#99401. It doesn't look like this will change anytime soon. I think the part @timoklein linked to is the only part of Navix that needs changing. Something like the following will likely solve it, though I'm not sure it's the most elegant solution:

```python
from dataclasses import field

import jax.numpy as jnp
from jax import Array

# Positionable, HasColour, PALETTE and EventType are defined elsewhere in Navix.


class Event(Positionable, HasColour):
    # ... (other members elided)

    # Change to support Python 3.11+: use default_factory instead of array defaults
    position: Array = field(default_factory=lambda: jnp.asarray([-1, -1], dtype=jnp.int32))
    colour: Array = field(default_factory=lambda: PALETTE.UNSET)
    happened: Array = field(default_factory=lambda: jnp.asarray(False, dtype=jnp.bool_))
    event_type: Array = field(default_factory=lambda: EventType.NONE)
```
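To make the mechanism concrete, here is a minimal sketch that runs without JAX. It assumes a hypothetical `FakeArray` class standing in for jaxlib's `ArrayImpl`: like a JAX array, it defines `__eq__` but not `__hash__`, so Python sets `FakeArray.__hash__` to `None`. On Python 3.11 and later the plain default should be rejected with the same kind of `ValueError` reported in this issue, while the `default_factory` version is accepted on every version.

```python
import sys
from dataclasses import dataclass, field


class FakeArray:
    """Hypothetical stand-in for a JAX array: immutable, and it defines __eq__
    without __hash__, so Python implicitly sets FakeArray.__hash__ to None."""

    def __init__(self, *values):
        self._values = tuple(values)

    def __eq__(self, other):
        return isinstance(other, FakeArray) and self._values == other._values


def define_with_plain_default():
    """Try to declare a dataclass whose field default is a FakeArray instance.

    Returns the ValueError message if dataclasses rejects it, else None."""
    try:
        @dataclass
        class Event:
            position: FakeArray = FakeArray(-1, -1)
        return None
    except ValueError as err:
        return str(err)


@dataclass
class EventWithFactory:
    # default_factory builds a fresh value per instance, so the field has no
    # plain default for the mutability check to inspect.
    position: FakeArray = field(default_factory=lambda: FakeArray(-1, -1))


error = define_with_plain_default()
print(sys.version_info[:2], error)
print(EventWithFactory().position == FakeArray(-1, -1))
```

The factory approach sidesteps the check entirely, which is why it works regardless of the Python version.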
Thank you @ponseko for the valuable context. I agree with Jake that this should be considered a CPython bug, and that the solution using the […]. I still think there might be a more elegant workaround than […]. Would you, @ponseko and @timoklein, be interested in setting up a PR that we can use to reason about an action plan?
As I said, I can probably start working on this next week. The fix with […]
That would be non-trivial for me given my limited exposure to JAX so far, but I can try next week.
That is very appreciated, @timoklein, thank you very much!
I've been looking into this a bit, and it doesn't seem that trivial. Or maybe I'm just not getting it right now. I'm going to give it a little more time next week and see if I can make progress.
As per `pyproject.toml`, this package should work with Python >=3.8. However, it only works up to and including Python 3.10. On newer versions it throws:

```
ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field position is not allowed: use default_factory
```

Fix

Set `requires-python = ">=3.8,<3.11"`.
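To see why the breakage starts exactly at 3.11, the sketch below is a simplified model (not the CPython source) of the check `dataclasses` applies to plain field defaults: up to 3.10 only `list`, `dict` and `set` instances were rejected, while 3.11+ rejects any default whose class has `__hash__` set to `None`. `ImmutableButUnhashable` and `rejected_as_mutable_default` are hypothetical names used only for illustration.

```python
class ImmutableButUnhashable:
    """Hypothetical stand-in for jaxlib's ArrayImpl: defining __eq__ without
    __hash__ makes Python set ImmutableButUnhashable.__hash__ to None."""

    def __eq__(self, other):
        return self is other


def rejected_as_mutable_default(value, python_minor):
    """Simplified model of the dataclasses mutable-default check on
    CPython 3.<python_minor>; returns True if the default would be rejected."""
    if python_minor >= 11:
        # 3.11+: unhashable class is treated as a proxy for "mutable".
        return type(value).__hash__ is None
    # 3.10 and earlier: only instances of a few known mutable types.
    return isinstance(value, (list, dict, set))


for minor in (10, 11, 12):
    print(f"3.{minor}:", rejected_as_mutable_default(ImmutableButUnhashable(), minor))
```

Under this model an unhashable array default slips through on 3.10 but fails on 3.11+, which matches the observed cutoff behind the `<3.11` pin.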