diff --git a/torch_xla/experimental/unbounded_dynamism_export.py b/torch_xla/experimental/unbounded_dynamism_export.py index b5f0637fd12..08958c86904 100644 --- a/torch_xla/experimental/unbounded_dynamism_export.py +++ b/torch_xla/experimental/unbounded_dynamism_export.py @@ -8,6 +8,7 @@ from torch.fx import Graph, GraphModule, subgraph_rewriter from torch.utils import _pytree as pytree from torch.utils._pytree import tree_map +from torch._dispatch.python import enable_python_dispatcher aten = torch.ops.aten @@ -30,7 +31,8 @@ def call_function( ) -> torch.fx.Node: node = graph.call_function(target, args, kwargs) _, args, kwargs = get_fake_args_kwargs(node) - node.meta["val"] = target(*args, **kwargs) + with enable_python_dispatcher(): + node.meta["val"] = target(*args, **kwargs) return node