Trace detection when args are pytrees #17704
-
Hey all! Currently, … seems to assume that …
-
Hi - you're correct that this function assumes the input is a flattened list of array-like objects. At the level of the stack where this is called, pytrees have already been flattened away, so there shouldn't be any other structure here. This is an internal utility that you shouldn't need to call directly – can you say more about what code you were running when you ran into this problem?
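The flattening mentioned above can be illustrated with the public `jax.tree_util` API. This is a minimal sketch of the idea only; the thread does not name the internal call site, and this is not the code path JAX uses internally. The example pytree `args` is hypothetical.

```python
import jax
import jax.numpy as jnp

# A hypothetical nested pytree: a dict holding an array and a tuple.
args = {"w": jnp.ones((2, 3)), "b": (jnp.zeros(3), 1.0)}

# Flatten into a flat list of leaves plus a treedef that records the structure.
# Dict keys are visited in sorted order, so "b" comes before "w".
leaves, treedef = jax.tree_util.tree_flatten(args)

# Each leaf is now an array or scalar; no nested containers remain.
print(len(leaves), [type(x).__name__ for x in leaves])

# The original structure can be rebuilt from the leaves and the treedef.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Code that runs after this step only ever sees `leaves`, which is why a utility at that level can assume a flat list.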
-
I assume this holds because … But my use case is different: essentially, I wanted to write a small compatibility bridge between "PyTorch functions" (i.e. functions manipulating torch tensors) and JAX, because I was using such functions with other libraries that expected JAX functions rather than PyTorch functions. So I created a small decorator utility that transforms a function expecting torch tensors into one expecting JAX arrays. Under the hood, the function returned by the decorator just converts the JAX arrays into torch tensors using … This all works quite well, and even supports nested …
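A decorator like the one described could look roughly like the sketch below. It is an assumption throughout: the post does not show its conversion method, so this version goes through NumPy (`np.array` followed by `torch.from_numpy`), and the names `torch_fn_to_jax` and `add_and_scale` are hypothetical. Nested inputs and outputs are handled with `jax.tree_util.tree_map`, which treats torch tensors as leaves because they are not registered pytree nodes.

```python
import numpy as np
import torch
import jax
import jax.numpy as jnp


def torch_fn_to_jax(fn):
    """Hypothetical decorator: wrap a function on torch tensors so that it
    accepts and returns JAX arrays, including inside nested pytrees."""

    def to_torch(x):
        # Assumed conversion path via NumPy; np.array makes a writable copy,
        # which torch.from_numpy expects. Non-array leaves pass through.
        if isinstance(x, jax.Array):
            return torch.from_numpy(np.array(x))
        return x

    def to_jax(x):
        if isinstance(x, torch.Tensor):
            return jnp.asarray(x.detach().cpu().numpy())
        return x

    def wrapped(*args, **kwargs):
        # Convert every leaf of the (possibly nested) arguments to torch.
        t_args, t_kwargs = jax.tree_util.tree_map(to_torch, (args, kwargs))
        out = fn(*t_args, **t_kwargs)
        # Convert every torch leaf of the (possibly nested) output back to JAX.
        return jax.tree_util.tree_map(to_jax, out)

    return wrapped


@torch_fn_to_jax
def add_and_scale(inputs, scale):
    # Plain torch code: inputs arrive here as torch tensors.
    return {"sum": inputs["a"] + inputs["b"][0], "scaled": inputs["b"][1] * scale}


result = add_and_scale(
    {"a": jnp.ones(3), "b": (jnp.arange(3.0), jnp.full(3, 2.0))}, 10.0
)
```

This sketch only works eagerly: under `jax.jit` the leaves are tracers rather than concrete arrays, and converting a tracer to NumPy raises an error.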