vmap crash with tree_flatten trigger #14553
jecampagne asked this question in Q&A
-
Thanks for the question. This issue is covered in Custom PyTrees and Initialization: you need to allow for the fact that your pytrees may be instantiated with placeholder values. You can do that by explicitly checking for such values, or by avoiding calling …
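A minimal sketch of the "avoid calling `__init__`" option, assuming a plain class registered with `jax.tree_util.register_pytree_node_class` (the real `BaseObj` from the third-party library may register itself differently). `T` here is a hypothetical stand-in: `tree_unflatten` builds the instance with `object.__new__` instead of running `__init__`, similar to the pattern shown in the JAX pytree docs, and `_m` is carried as a pytree child, so it is computed once at construction and never recomputed from a placeholder.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class T:
    """Hypothetical stand-in for T; not the third-party BaseObj."""

    def __init__(self, a):
        self.a = a
        # Derived quantity computed once, at construction time.
        self._m = a * 2

    def tree_flatten(self):
        # Carry both the input and the derived value as children, so that
        # unflattening never needs to recompute _m from a placeholder `a`.
        return (self.a, self._m), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass placeholders (e.g. object() or None)
        # as children, and `a * 2` would fail on those.
        obj = object.__new__(cls)
        obj.a, obj._m = children
        return obj

    def f(self, x):
        return jnp.sin(self.a * x) + self._m


objs = T(jnp.arange(3.0))  # a "batched" T: a and _m both have shape (3,)
out = jax.vmap(lambda t: t.f(1.0))(objs)
print(out)
```

Because `_m` is a child, `vmap` maps it alongside `a`; the trade-off is that anything rebuilding a `T` through `tree_unflatten` must supply a `_m` that is consistent with `a`.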
-
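Since …

```python
import jax.numpy as jnp

# BaseObj comes from the third-party library mentioned in the question.
class T(BaseObj):
    def __init__(self, a, gsparams=None):
        super().__init__(a=a, gsparams=gsparams)

    @property
    def _m(self):
        # _m is derived from a on each access instead of being stored in __init__
        return self.a * 2

    ...

    def f(self, x):
        return jnp.sin(self.a * x) + self._m  # <<< use it here
```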
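A self-contained sketch of the property approach, assuming a plain class registered with `jax.tree_util.register_pytree_node_class` in place of the third-party `BaseObj` (`T` is a hypothetical stand-in). Since `__init__` only stores its input, it is harmless when JAX calls it with placeholder values, and `_m` is always computed from the real (possibly traced) `a`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class T:
    """Hypothetical stand-in for T; not the third-party BaseObj."""

    def __init__(self, a):
        # Only store the input, so placeholder values cannot break __init__.
        self.a = a

    @property
    def _m(self):
        # Derived value recomputed on every access from the current `a`.
        return self.a * 2

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Safe to call __init__ here because it does no computation.
        return cls(*children)

    def f(self, x):
        return jnp.sin(self.a * x) + self._m


objs = T(jnp.arange(3.0))
out = jax.vmap(lambda t: t.f(1.0))(objs)
print(out)
```

The cost is that `_m` is recomputed on each access; that is negligible for `a * 2` but may matter for an expensive derived quantity.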
-
Hi,
here is a snippet that reproduces my concern, using a third-party library that provides the equivalent of `BaseObj`: …

If I instantiate a collection of `T` objects and apply `vmap` …

then, if I need the `self._m` parameter to compute the result of the `f` function, I get a crash. Do you know how I can compute `_m` at initialization and still use it in the `T` methods later? Thanks
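A minimal reproduction of this kind of crash, using a hypothetical `NaiveT` class (not the original snippet, which is not preserved here). It assumes that `jax.vmap` rebuilds the pytree with placeholder leaves while processing `in_axes`, as described in the JAX docs on custom pytrees; re-running `__init__` on such a placeholder makes `a * 2` raise a `TypeError`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class NaiveT:
    """Hypothetical minimal class; not the third-party BaseObj."""

    def __init__(self, a):
        self.a = a
        self._m = a * 2  # derived value computed eagerly in __init__

    def tree_flatten(self):
        return (self.a,), None  # only `a` is a child

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Re-runs __init__ on whatever JAX passes in, including placeholders.
        return cls(*children)

    def f(self, x):
        return jnp.sin(self.a * x) + self._m


def crashes_under_vmap():
    objs = NaiveT(jnp.arange(3.0))
    try:
        jax.vmap(lambda t: t.f(1.0))(objs)
    except TypeError:
        # Raised by `placeholder * 2` inside __init__.
        return True
    return False


print(crashes_under_vmap())
```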