diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 0b69ce087..1b945cf8e 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -20,7 +20,7 @@ from brevitas.proxy.quant_proxy import QuantProxyProtocol from brevitas.quant_tensor import QuantTensor from brevitas.utils.jit_utils import clear_class_registry -from brevitas.utils.jit_utils import jit_patches_generator +# from brevitas.utils.jit_utils import jit_patches_generator from brevitas.utils.python_utils import patch @@ -162,7 +162,7 @@ class BaseManager(ABC): target_name = None handlers = [] - _base_trace_patches_generator = jit_patches_generator + _base_trace_patches_generator = None # jit_patches_generator _fn_to_cache = [] _fn_cache = [] _cached_io_handler_map = {} diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index f5993a7b0..a33ec4219 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -6,7 +6,7 @@ from brevitas.config import JIT_ENABLED -IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0') +# IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0') def _disabled(fn): @@ -20,10 +20,10 @@ def _disabled(fn): ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute - if not IS_ABOVE_110: - script_method_110_disabled = _disabled - else: - script_method_110_disabled = script_method + script_method_110_disabled = script_method + # script_method_110_disabled = _disabled + # if not IS_ABOVE_110: + # else: else: diff --git a/src/brevitas/utils/jit_utils.py b/src/brevitas/utils/jit_utils.py index 4f9d1aa26..1f461484d 100644 --- a/src/brevitas/utils/jit_utils.py +++ b/src/brevitas/utils/jit_utils.py @@ -1,38 +1,36 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import inspect +# import inspect -import torch - -try: - from torch._jit_internal import get_torchscript_modifier -except: - get_torchscript_modifier = None - -from dependencies import Injector +# from dependencies import Injector from packaging import version +import torch from brevitas import torch_version -from brevitas.inject import ExtendedInjector -from brevitas.jit import IS_ABOVE_110 -from .python_utils import patch +# try: +# from torch._jit_internal import get_torchscript_modifier +# except: +# get_torchscript_modifier = None +# from brevitas.inject import ExtendedInjector +# from brevitas.jit import IS_ABOVE_110 -def _get_modifier_wrapper(fn): - if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)): - return None - else: - return get_torchscript_modifier(fn) +# from .python_utils import patch +# def _get_modifier_wrapper(fn): +# if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)): +# return None +# else: +# return get_torchscript_modifier(fn) -if IS_ABOVE_110: +# if IS_ABOVE_110: - def jit_patches_generator(): - return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)] -else: - jit_patches_generator = None +# def jit_patches_generator(): +# return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)] +# else: +# jit_patches_generator = None def clear_class_registry(): diff --git a/tests/brevitas_examples/test_jit_trace.py b/tests/brevitas_examples/test_jit_trace.py index 52a4c6b1d..4c16bf800 100644 --- a/tests/brevitas_examples/test_jit_trace.py +++ b/tests/brevitas_examples/test_jit_trace.py @@ -6,7 +6,6 @@ import pytest import torch -from brevitas.utils.jit_utils import jit_patches_generator from brevitas_examples.bnn_pynq.models import model_with_cfg FC_INPUT_SIZE = (1, 1, 28, 28) @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits): fc, _ = model_with_cfg(nname.lower(), pretrained=False) fc.train(False) input_tensor = torch.randn(FC_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(fc, input_tensor) out_traced = traced_model(input_tensor) out = fc(input_tensor) @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits): cnv, _ = model_with_cfg(nname.lower(), pretrained=False) cnv.train(False) input_tensor = torch.randn(CNV_INPUT_SIZE) - with ExitStack() as stack: - for mgr in jit_patches_generator(): - stack.enter_context(mgr) traced_model = torch.jit.trace(cnv, input_tensor) out_traced = traced_model(input_tensor) out = cnv(input_tensor)