Convert General PyTree to builtins #15116
Hi, I'm mixing libraries with JAX logic and ran into a problem: some of the custom nested PyTree objects I use in my code are not directly compatible with other libraries (e.g., dm-reverb, dm-acme, dm-tree). These libraries do work with builtin containers, though, so is there a straightforward way to convert any nested PyTree object into builtin types?

Right now I've tried implementing this by recursing through the treedef. I currently have something like this:

```python
from collections.abc import Iterator
from typing import Any

from jax.tree_util import PyTreeDef, tree_flatten, treedef_children


def traverse(
    tree: Any, stop_branch: PyTreeDef | None = None
) -> Iterator[tuple[bool, PyTreeDef, Any]]:
    """Pre-order, depth-first walk over the treedef of `tree`.

    Yields (is_leaf, branch, value): value is the leaf's data for leaves
    and None for internal nodes.
    """
    leaves, treedef = tree_flatten(tree)
    data_gen = iter(leaves)
    # Start from the root only; children are pushed when their parent is
    # visited (pushing the root's children up front too would visit them twice).
    children_stack = [treedef]
    while children_stack:
        branch = children_stack.pop()
        if stop_branch is not None and branch == stop_branch:
            return
        # A real leaf holds exactly one value; `None` or `()` are single nodes
        # with zero leaves and must not consume a value.
        if branch.num_nodes == 1 and branch.num_leaves == 1:
            yield True, branch, next(data_gen)
        else:
            yield False, branch, None
            # Reverse so children are popped left to right.
            children_stack += treedef_children(branch)[::-1]
    if stop_branch is not None:
        raise ValueError(f"Stop-condition not met! {stop_branch} not in {treedef}")


my_tree = (1, 2, {5: 2, 'hi': 3})
for is_leaf, branch, value in traverse(my_tree):
    ...
    # Desired (hypothetical: a `node_type` on the treedef that I could check):
    # if not issubclass(branch.node_type, Sequence) and not issubclass(branch.node_type, Mapping):
    #     # Do something: convert the branch type or raise ValueError.
```
Hi, thanks for the question. I don't know of any API to do what you have in mind. PyTrees were not really designed with this level of introspection in mind. I think you'd have better luck unflattening the object and inspecting things at the Python level, where you have access to all the type names.
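Following that suggestion, here is a minimal sketch of the Python-level approach, assuming a recent JAX; `to_builtins` and the `Point` class are hypothetical names for illustration. It recurses on the actual objects rather than on the treedef: mappings become `dict`s, namedtuples become `dict`s keyed by field name, lists and tuples stay lists and tuples, and any other registered pytree node is flattened exactly one level by passing `is_leaf=lambda x: x is not tree` to `jax.tree_util.tree_flatten`, so its immediate children can be converted recursively.

```python
from collections.abc import Mapping
from typing import Any

import jax


def to_builtins(tree: Any) -> Any:
    """Recursively convert a pytree into nested builtin containers (sketch)."""
    # Mappings (dict, OrderedDict, ...) become plain dicts.
    if isinstance(tree, Mapping):
        return {k: to_builtins(v) for k, v in tree.items()}
    # Namedtuples become dicts keyed by their field names.
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):
        return {f: to_builtins(getattr(tree, f)) for f in tree._fields}
    if isinstance(tree, list):
        return [to_builtins(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(to_builtins(v) for v in tree)
    if tree is None:
        return None
    # Any other registered pytree node: flatten exactly one level by treating
    # everything except the root object itself as a leaf.
    children, treedef = jax.tree_util.tree_flatten(
        tree, is_leaf=lambda x: x is not tree
    )
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return tree  # a genuine leaf (array, scalar, unregistered object)
    # Custom node: return its immediate children, converted, as a list.
    return [to_builtins(c) for c in children]


# Usage with a hypothetical custom pytree node.
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),             # flatten: (children, aux_data)
    lambda _aux, children: Point(*children),  # unflatten
)

print(to_builtins({"origin": Point(0, 1), "path": [Point(2, 3)]}))
# → {'origin': [0, 1], 'path': [[2, 3]]}
```

One design caveat: custom nodes come back as plain lists of their children in flatten order, so any field names are lost unless you add a dedicated branch for that type.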