diff --git a/src/brevitas/backport/__init__.py b/src/brevitas/backport/__init__.py index 00102698a..57912f0e2 100644 --- a/src/brevitas/backport/__init__.py +++ b/src/brevitas/backport/__init__.py @@ -251,3 +251,4 @@ def sym_min(a, b): # Populate magic methods on SymInt and SymFloat import brevitas.backport.fx.experimental.symbolic_shapes +import brevitas.backport.fx diff --git a/src/brevitas/backport/fx/experimental/proxy_tensor.py b/src/brevitas/backport/fx/experimental/proxy_tensor.py index f920e9f8e..430623e9d 100644 --- a/src/brevitas/backport/fx/experimental/proxy_tensor.py +++ b/src/brevitas/backport/fx/experimental/proxy_tensor.py @@ -62,13 +62,14 @@ from torch.utils._python_dispatch import TorchDispatchMode import torch.utils._pytree as pytree +from brevitas import backport +from brevitas.backport import fx from brevitas.backport import SymBool from brevitas.backport import SymFloat from brevitas.backport import SymInt from brevitas.backport.fx import GraphModule from brevitas.backport.fx import Proxy from brevitas.backport.fx import Tracer -import brevitas.backport.fx as fx from brevitas.backport.fx.passes.shape_prop import _extract_tensor_metadata from brevitas.backport.utils._stats import count from brevitas.backport.utils.weak import WeakTensorKeyDictionary @@ -316,7 +317,7 @@ def proxy_call(proxy_mode, func, args, kwargs): # `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload` # We treat all other functions as an `external_call`, for instance, a function decorated # with `@torch.fx.wrap` - external_call = not isinstance(func, fx.backport._ops.OpOverload) + external_call = not isinstance(func, backport._ops.OpOverload) def can_handle_tensor(x): return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)