Skip to content

Commit

Permalink
Fix backport imports
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed Jun 27, 2023
1 parent e6a7d70 commit 32bfbe9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/brevitas/backport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,4 @@ def sym_min(a, b):

# Populate magic methods on SymInt and SymFloat
import brevitas.backport.fx.experimental.symbolic_shapes
import brevitas.backport.fx
5 changes: 3 additions & 2 deletions src/brevitas/backport/fx/experimental/proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree

from brevitas import backport
from brevitas.backport import fx
from brevitas.backport import SymBool
from brevitas.backport import SymFloat
from brevitas.backport import SymInt
from brevitas.backport.fx import GraphModule
from brevitas.backport.fx import Proxy
from brevitas.backport.fx import Tracer
import brevitas.backport.fx as fx
from brevitas.backport.fx.passes.shape_prop import _extract_tensor_metadata
from brevitas.backport.utils._stats import count
from brevitas.backport.utils.weak import WeakTensorKeyDictionary
Expand Down Expand Up @@ -316,7 +317,7 @@ def proxy_call(proxy_mode, func, args, kwargs):
# `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload`
# We treat all other functions as an `external_call`, for instance, a function decorated
# with `@torch.fx.wrap`
external_call = not isinstance(func, fx.backport._ops.OpOverload)
external_call = not isinstance(func, backport._ops.OpOverload)

def can_handle_tensor(x):
return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
Expand Down

0 comments on commit 32bfbe9

Please sign in to comment.