In certain cases, it seems like using _is_jax_data as a criterion for flattening trees can lead to structural incompatibilities, which can then result in errors when mapping over trees derived from distrax distributions.
To elaborate: let's consider that we have a model represented as a PyTree of parameters and metadata, and that this model contains a distrax distribution (or more generally Jittable) as a child node. We now wish to perform some selective update or partition operation on our model tree — for instance, to separate the tree into DeviceArray leaves and non-DeviceArray leaves. To do this, we will first perform a tree_map on our existing tree, mapping leaves that match the selection criterion to True and leaves that don't match to False. We will then use this mapped “mask” tree to specify the leaves to set to None on either side of the partition.
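For reference, the mask-then-partition workflow described above can be sketched on an ordinary nested container, where it works as intended. This is a minimal standard-library illustration, not JAX or equinox code: `FakeArray`, the toy `tree_map`, and the example `model` are all hypothetical stand-ins, and unlike JAX this toy treats `None` as a plain leaf.

```python
# Toy sketch of "mask, then partition" on a well-behaved tree.
# Standard library only; every name here is a hypothetical stand-in.

class FakeArray:
    """Stand-in for a DeviceArray leaf."""
    def __init__(self, values):
        self.values = values


def tree_map(fn, tree, *rest):
    """Apply fn leaf-wise over nested dicts/lists/tuples sharing one structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, *xs) for xs in zip(tree, *rest))
    return fn(tree, *rest)  # anything else (including None) is a leaf here


model = {
    "weights": FakeArray([1.0, 2.0]),
    "bias": FakeArray([0.5]),
    "activation": "relu",
}

# Step 1: map every leaf to True/False according to the selection criterion.
mask = tree_map(lambda leaf: isinstance(leaf, FakeArray), model)

# Step 2: use the mask to blank out leaves on either side of the partition.
arrays = tree_map(lambda leaf, keep: leaf if keep else None, model, mask)
others = tree_map(lambda leaf, keep: None if keep else leaf, model, mask)

# Step 3: undo the partition by taking whichever side is not None.
combined = tree_map(lambda a, o: o if a is None else a, arrays, others)
```

Every step relies on the mask and the original tree having the same structure, which holds here because the container's structure does not depend on what its leaves contain.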
Unfortunately, this is where we hit a snag. Since our mask tree now contains boolean values in place of DeviceArrays, _is_jax_data will return False for our mask tree where it returned True for the original tree, and the children field could be left empty for the mask tree. Because the flattened distribution and mask trees do not thereafter share the same structure, we cannot use the mask tree as needed to create our partition. (Side note: Even if we didn't create a mask tree for our partition, we'd still end up with None on the side of the partition without DeviceArrays, ultimately resulting in the same structural incompatibility if we later wish to undo the partition.) I'm not actually sure whether the data-based flattening switch is the only cause here, but wanted to share my observations.
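The failure mode can be modelled without distrax at all. The sketch below is a standard-library toy, not the actual `Jittable` implementation: `ToyJittable`, `is_data`, and `FakeArray` are hypothetical names, and `is_data` only loosely mimics a data-based criterion such as `_is_jax_data`. It shows how a container that decides its children by inspecting field values ends up with a different structure once those values are replaced by booleans.

```python
# Toy model of data-dependent flattening. Standard library only; none of
# these names come from distrax or JAX.

class FakeArray:
    """Stand-in for a DeviceArray leaf."""
    def __init__(self, value):
        self.value = value


def is_data(x):
    # Hypothetical stand-in for a data-based check on a field's value.
    return isinstance(x, FakeArray)


class ToyJittable:
    def __init__(self, **fields):
        self.fields = fields

    def flatten(self):
        items = sorted(self.fields.items())
        # Fields whose values look like data become children...
        keys = tuple(k for k, v in items if is_data(v))
        children = tuple(v for k, v in items if is_data(v))
        # ...and everything else is folded into the static structure.
        metadata = tuple((k, v) for k, v in items if not is_data(v))
        return children, (keys, metadata)


def tree_map(fn, tree):
    """Map fn over the children and rebuild the container around the results."""
    children, (keys, metadata) = tree.flatten()
    fields = dict(metadata)
    fields.update({k: fn(c) for k, c in zip(keys, children)})
    return ToyJittable(**fields)


dist = ToyJittable(loc=FakeArray(0.0), scale=FakeArray(1.0), name="normal")
mask = tree_map(lambda leaf: True, dist)  # booleans replace the arrays

dist_children, dist_structure = dist.flatten()
mask_children, mask_structure = mask.flatten()

# The booleans fail is_data, so the mask flattens to zero children and the
# two structures no longer agree.
structures_match = dist_structure == mask_structure
```

In this toy, `dist` flattens to two children while `mask` flattens to none, so any operation that zips the two trees leaf by leaf has nothing to pair up. Whether this is the only cause of the error reported here is not something the sketch can establish.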
Here is a minimal reproducible example demonstrating the issue:
As a result of this design choice, distrax distributions are not currently compatible with equinox’s filter transforms, like eqx.filter_jit. This doesn't actually matter much for my use case — I can mark any model fields that are distrax.Distributions as static without recompiling since the instance doesn't change — but it is possible there are other use cases where this could make a difference.
Details
JAX v0.3.16
distrax v0.1.2 (nightly from c013670)
Running on CPU
rciric added a commit to hypercoil/hypercoil that referenced this issue on Aug 27, 2022
Whilst we're here, _is_jax_data also uses a try-except around abstractify here, which is probably going to really hurt performance.
Equinox actually used to do something similar to filter arrays from non-arrays, but switched to doing instance checks (e.g. isinstance(leaf, (np.ndarray, jnp.ndarray))) because the try-except approach can be very slow.
Modulo these issues, the Jittable base class used here is basically doing the same thing as eqx.Module. I think realistically it's unlikely to happen (this repo hasn't seen any activity in a while) but one possible fix might be to replace it with eqx.Module.