Skip to content

Commit

Permalink
Fix (jit): remove Injector patcher (#752)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 1, 2023
1 parent 652ca8e commit 58936c4
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/core/quant/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, quant_delay_steps):
super(_DelayQuant, self).__init__()
self.quant_delay_steps: int = brevitas.jit.Attribute(quant_delay_steps, int)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def forward(self, x: Tensor, y: Tensor) -> Tensor:
if self.quant_delay_steps > 0:
self.quant_delay_steps = self.quant_delay_steps - 1
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/core/quant/int_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
y = x / scale
y = y + zero_point
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method_110_disabled
@brevitas.jit.script_method
def to_int(
self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor,
x: Tensor) -> Tensor:
Expand Down
7 changes: 1 addition & 6 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from brevitas.proxy.quant_proxy import QuantProxyProtocol
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.jit_utils import jit_patches_generator
from brevitas.utils.python_utils import patch


Expand Down Expand Up @@ -162,7 +161,6 @@ class BaseManager(ABC):

target_name = None
handlers = []
_base_trace_patches_generator = jit_patches_generator
_fn_to_cache = []
_fn_cache = []
_cached_io_handler_map = {}
Expand All @@ -183,10 +181,7 @@ def _gen_patches(cls, fn_dispatcher):

@classmethod
def _trace_patches(cls):
patches = []
if cls._base_trace_patches_generator is not None:
patches += cls._base_trace_patches_generator()
patches += cls._gen_patches(cls._trace_fn_dispatcher)
patches = cls._gen_patches(cls._trace_fn_dispatcher)
return patches

@classmethod
Expand Down
9 changes: 0 additions & 9 deletions src/brevitas/jit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from packaging import version
import torch

from brevitas.config import JIT_ENABLED

IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')


def _disabled(fn):
return fn
Expand All @@ -20,15 +17,9 @@ def _disabled(fn):
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

if not IS_ABOVE_110:
script_method_110_disabled = _disabled
else:
script_method_110_disabled = script_method

else:

script_method = _disabled
script = _disabled
script_method_110_disabled = _disabled
ScriptModule = torch.nn.Module
Attribute = lambda val, type: val
30 changes: 1 addition & 29 deletions src/brevitas/utils/jit_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import inspect

import torch

try:
from torch._jit_internal import get_torchscript_modifier
except:
get_torchscript_modifier = None

from dependencies import Injector
from packaging import version
import torch

from brevitas import torch_version
from brevitas.inject import ExtendedInjector
from brevitas.jit import IS_ABOVE_110

from .python_utils import patch


def _get_modifier_wrapper(fn):
if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
return None
else:
return get_torchscript_modifier(fn)


if IS_ABOVE_110:

def jit_patches_generator():
return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
else:
jit_patches_generator = None


def clear_class_registry():
Expand Down
7 changes: 0 additions & 7 deletions tests/brevitas_examples/test_jit_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch

from brevitas.utils.jit_utils import jit_patches_generator
from brevitas_examples.bnn_pynq.models import model_with_cfg

FC_INPUT_SIZE = (1, 1, 28, 28)
Expand All @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits):
fc, _ = model_with_cfg(nname.lower(), pretrained=False)
fc.train(False)
input_tensor = torch.randn(FC_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(fc, input_tensor)
out_traced = traced_model(input_tensor)
out = fc(input_tensor)
Expand All @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits):
cnv, _ = model_with_cfg(nname.lower(), pretrained=False)
cnv.train(False)
input_tensor = torch.randn(CNV_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(cnv, input_tensor)
out_traced = traced_model(input_tensor)
out = cnv(input_tensor)
Expand Down

0 comments on commit 58936c4

Please sign in to comment.