Doesn't work on Python 3.12 or 3.11 #92

Open
timoklein opened this issue Oct 8, 2024 · 7 comments

Comments

@timoklein

As per pyproject.toml, this package should work on Python >=3.8. However, it only works up to and including Python 3.10. On newer versions, it raises ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field position is not allowed: use default_factory.
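A minimal reproduction of the error, assuming JAX is installed. It uses a plain frozen dataclasses.dataclass instead of navix's base classes, so the Event class here is a simplified stand-in, and the exact class path printed in the message depends on the jaxlib version.

```python
import dataclasses
import sys

import jax
import jax.numpy as jnp


def define_event_with_array_default():
    """Try to define a dataclass whose field default is a JAX array.

    Returns the new class on success, or the ValueError message on failure.
    """
    try:
        @dataclasses.dataclass(frozen=True)
        class Event:
            position: jax.Array = jnp.asarray([-1, -1], dtype=jnp.int32)

        return Event
    except ValueError as err:
        return str(err)


result = define_event_with_array_default()
# Python 3.10 and earlier: the class is created.
# Python 3.11 and later: "mutable default <class '...ArrayImpl'> for field position ..."
print(sys.version_info[:2], result)
```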

Fix

Set requires-python = ">=3.8,<3.11".

@epignatelli
Owner

epignatelli commented Oct 11, 2024

Hey @timoklein, thank you very much for spotting this.

I wonder if we can fix that bug directly and also support Python >=3.11.
Can you share more info about the error? Where does it happen? I think it should just be a matter of fixing the default value of that property.

@timoklein
Author

Hi Eduardo,

thanks for the quick reply! Python points me to this section when raising the error:

navix/navix/states.py

Lines 48 to 69 in 5fc4975

    class Event(Positionable, HasColour):
        """A struct representing an event that happened in the environment. It contains the
        position of the event, the colour of the entity involved in the event, whether the event
        happened, and the type of event that happened.

        !!! note
            Notice that we need the `happened` property, which flags if an event has
            happened or not, because JAX does not support variable size arrays.
            This means that we cannot add an event to the list in the middle of training.
            Instead, we initialise all events, and mask them out as not happened.

        Attributes:
            position (Array): The (row, column) position of the event in the grid.
            colour (Array): The colour of the entity involved in the event.
            happened (Array): A boolean flag indicating whether the event happened.
            event_type (Array): The type of event that happened."""

        position: Array = jnp.asarray([-1, -1], dtype=jnp.int32)
        colour: Array = PALETTE.UNSET
        happened: Array = jnp.asarray(False, dtype=jnp.bool_)
        event_type: Array = EventType.NONE

Not sure if it also occurs in other places.

As you said, it shouldn't be that difficult to fix. I'm just starting with JAX, but I should have some time next week to try my hand at this. xland-minigrid had the same issue, so I can use their fix as a template. This issue might also be related. Is it reasonable to use the benchmark script to verify that compilation works properly after the fix?

@ponseko

ponseko commented Oct 14, 2024

Hi @epignatelli and @timoklein ,

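To chime in: this issue stems from CPython's dataclasses using hashability as a proxy for immutability. Since Python 3.11, any field default whose type is unhashable is rejected as a "mutable default", instead of only list, dict and set. That breaks JAX arrays, which are immutable but not hashable.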

Related issues: jax-ml/jax#14295 and python/cpython#99401. It doesn't look like this will change anytime soon.
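To make the mechanism concrete, here is a sketch using only the standard library. FrozenPair and dataclass_accepts_default are illustrative names, not part of navix or JAX; FrozenPair plays the role of a JAX array, being immutable in practice but unhashable.

```python
import dataclasses
import sys


class FrozenPair:
    """Immutable in practice, but unhashable: defining __eq__ without __hash__
    makes Python set FrozenPair.__hash__ to None."""

    def __init__(self, first, second):
        # Bypass our own __setattr__ override to store the data once.
        object.__setattr__(self, "_items", (first, second))

    def __setattr__(self, name, value):
        raise AttributeError("FrozenPair is immutable")

    def __eq__(self, other):
        return isinstance(other, FrozenPair) and self._items == other._items


def dataclass_accepts_default(value):
    """Return True if dataclasses allows `value` as a plain field default."""
    try:
        dataclasses.make_dataclass(
            "Probe", [("x", object, dataclasses.field(default=value))]
        )
        return True
    except ValueError:
        return False


print(dataclass_accepts_default((1, 2)))            # True: tuples are hashable
print(dataclass_accepts_default([1, 2]))            # False: lists are rejected on every version
print(dataclass_accepts_default(FrozenPair(1, 2)))  # False on 3.11+, True on 3.10 and earlier
```

The third call is the JAX situation in miniature: nothing about the object changed between Python versions, only the rule dataclasses uses to guess whether it is mutable.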

I think the part @timoklein linked to is the only part of Navix that needs changing. Something like the following will likely solve it, though I'm not sure it is the most elegant solution:

    from dataclasses import field  # assumed import; flax.struct.field also forwards default_factory

    class Event(Positionable, HasColour):
        ...  # docstring and other members unchanged
        # Changed to support Python 3.11+: build array defaults with default_factory
        position: Array = field(default_factory=lambda: jnp.asarray([-1, -1], dtype=jnp.int32))
        colour: Array = field(default_factory=lambda: PALETTE.UNSET)
        happened: Array = field(default_factory=lambda: jnp.asarray(False, dtype=jnp.bool_))
        event_type: Array = field(default_factory=lambda: EventType.NONE)
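The default_factory approach can be checked in isolation with a minimal sketch, assuming only JAX and the standard library. This Event is a simplified stand-in for navix's class (a plain frozen dataclass, not the Positionable/HasColour hierarchy); colour and event_type are omitted because their real defaults (PALETTE.UNSET, EventType.NONE) live in navix.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Event:
    # With default_factory the field has no plain default, so the Python 3.11+
    # "mutable default" check never sees a JAX array.
    position: jax.Array = dataclasses.field(
        default_factory=lambda: jnp.asarray([-1, -1], dtype=jnp.int32)
    )
    happened: jax.Array = dataclasses.field(
        default_factory=lambda: jnp.asarray(False, dtype=jnp.bool_)
    )
    # colour and event_type would follow the same pattern in navix.


default_event = Event()
custom_event = Event(position=jnp.asarray([2, 3], dtype=jnp.int32))
print(default_event.position, custom_event.position)
```

Each factory runs once per instance, so every Event built without arguments gets fresh default arrays.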

@epignatelli
Owner

epignatelli commented Oct 15, 2024

Thank you @ponseko for the valuable context.

I agree with Jake that this should be considered a CPython bug, and that the default_factory solution is not the most elegant one.
However, I value Python >=3.11 support more, and I think we should pay the cost of inelegance, if necessary, to support Python >=3.11 for the moment.

I still think there might be a more elegant workaround than default_factory, though.
For example, I wonder if we can intervene on the flax.struct.PyTreeNode with a wrapper that accepts JAX arrays, and programmatically converts them to fields with a default_factory.
If this is a viable solution, we could also consider pushing it upstream to Flax (WDYT @cgarciae?)

Would you guys @ponseko @timoklein be interested in setting up a PR that we can use to reason about an action plan?

@timoklein
Author

Hi @ponseko @epignatelli

As I said, I can probably start working on this next week. The fix with default_factory is trivial to implement.

> For example, I wonder if we can intervene on the flax.struct.PyTreeNode with a wrapper that accepts JAX arrays, and programmatically converts them to fields with a default_factory.

That would be non-trivial for me due to my lack of exposure to JAX so far. But I can try next week.

@epignatelli
Owner

epignatelli commented Oct 16, 2024

That is very appreciated, @timoklein, thank you very much!

flax.struct.PyTreeNode has very little to do with JAX.
It is pretty much a wrapper around dataclasses.dataclass. If we want to take that direction, I think we could subclass flax.struct.PyTreeNode and, at initialisation, check the type of the inputs. When an input is a JAX array, we wrap it in a dataclasses.field; otherwise we move on to the next property, until there are none left.

@timoklein
Author

I've been looking a bit into this and it doesn't seem that trivial to do. Or maybe I'm just not getting it right now. I'm gonna give it a little more time next week and see if I can make progress.

Labels
None yet
Projects
Status: TODOs
Development

No branches or pull requests

3 participants