From 58936c495fd6e7f0514f84b8fb15d13c4bc2d7c2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 1 Dec 2023 17:56:22 +0100 Subject: [PATCH] Fix (jit): remove Injector patcher (#752) --- src/brevitas/core/quant/delay.py | 2 +- src/brevitas/core/quant/int_base.py | 4 +-- src/brevitas/export/manager.py | 7 +----- src/brevitas/jit.py | 9 ------- src/brevitas/utils/jit_utils.py | 30 +---------------------- tests/brevitas_examples/test_jit_trace.py | 7 ------ 6 files changed, 5 insertions(+), 54 deletions(-) diff --git a/src/brevitas/core/quant/delay.py b/src/brevitas/core/quant/delay.py index fc5e35002..804aaa81b 100644 --- a/src/brevitas/core/quant/delay.py +++ b/src/brevitas/core/quant/delay.py @@ -21,7 +21,7 @@ def __init__(self, quant_delay_steps): super(_DelayQuant, self).__init__() self.quant_delay_steps: int = brevitas.jit.Attribute(quant_delay_steps, int) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def forward(self, x: Tensor, y: Tensor) -> Tensor: if self.quant_delay_steps > 0: self.quant_delay_steps = self.quant_delay_steps - 1 diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index b79f13316..7a7a0f828 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -61,7 +61,7 @@ def __init__( self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: y = x / scale y = y + zero_point @@ -134,7 +134,7 @@ def __init__( self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def to_int( self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 0b69ce087..f8f1189fd 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -20,7 +20,6 @@ from brevitas.proxy.quant_proxy import QuantProxyProtocol from brevitas.quant_tensor import QuantTensor from brevitas.utils.jit_utils import clear_class_registry -from brevitas.utils.jit_utils import jit_patches_generator from brevitas.utils.python_utils import patch @@ -162,7 +161,6 @@ class BaseManager(ABC): target_name = None handlers = [] - _base_trace_patches_generator = jit_patches_generator _fn_to_cache = [] _fn_cache = [] _cached_io_handler_map = {} @@ -183,10 +181,7 @@ def _gen_patches(cls, fn_dispatcher): @classmethod def _trace_patches(cls): - patches = [] - if cls._base_trace_patches_generator is not None: - patches += cls._base_trace_patches_generator() - patches += cls._gen_patches(cls._trace_fn_dispatcher) + patches = cls._gen_patches(cls._trace_fn_dispatcher) return patches @classmethod diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index f5993a7b0..0719e1017 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -1,13 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from packaging import version import torch from brevitas.config import JIT_ENABLED -IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0') - def _disabled(fn): return fn @@ -20,15 +17,9 @@ def _disabled(fn): ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute - if not IS_ABOVE_110: - script_method_110_disabled = _disabled - else: - script_method_110_disabled = script_method - else: script_method = _disabled script = _disabled - script_method_110_disabled = _disabled ScriptModule = torch.nn.Module Attribute = lambda val, type: val diff --git a/src/brevitas/utils/jit_utils.py b/src/brevitas/utils/jit_utils.py index 4f9d1aa26..54a16ec8a 100644 --- a/src/brevitas/utils/jit_utils.py +++ b/src/brevitas/utils/jit_utils.py @@ -1,38 +1,10 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import inspect - -import torch - -try: - from torch._jit_internal import get_torchscript_modifier -except: - get_torchscript_modifier = None - -from dependencies import Injector from packaging import version +import torch from brevitas import torch_version -from brevitas.inject import ExtendedInjector -from brevitas.jit import IS_ABOVE_110 - -from .python_utils import patch - - -def _get_modifier_wrapper(fn): - if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)): - return None - else: - return get_torchscript_modifier(fn) - - -if IS_ABOVE_110: - - def jit_patches_generator(): - return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)] -else: - jit_patches_generator = None def clear_class_registry(): diff --git a/tests/brevitas_examples/test_jit_trace.py b/tests/brevitas_examples/test_jit_trace.py index 52a4c6b1d..4c16bf800 100644 --- a/tests/brevitas_examples/test_jit_trace.py +++ b/tests/brevitas_examples/test_jit_trace.py @@ -6,7 +6,6 @@ import pytest import torch -from brevitas.utils.jit_utils import jit_patches_generator from brevitas_examples.bnn_pynq.models import model_with_cfg FC_INPUT_SIZE = (1, 1, 28, 28) @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits): fc, _ = model_with_cfg(nname.lower(), pretrained=False) fc.train(False) input_tensor = torch.randn(FC_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(fc, input_tensor) out_traced = traced_model(input_tensor) out = fc(input_tensor) @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits): cnv, _ = model_with_cfg(nname.lower(), pretrained=False) cnv.train(False) input_tensor = torch.randn(CNV_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(cnv, input_tensor) out_traced = traced_model(input_tensor) out = cnv(input_tensor)