From 58202f64af77bfba2feecc0870f55d041dac722a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Nov 2023 13:05:12 +0000 Subject: [PATCH] Fix (jit): remove patcher --- src/brevitas/export/manager.py | 4 ++-- src/brevitas/jit.py | 10 ++++---- src/brevitas/utils/jit_utils.py | 42 ++++++++++++++++----------------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 0b69ce087..1b945cf8e 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -20,7 +20,7 @@ from brevitas.proxy.quant_proxy import QuantProxyProtocol from brevitas.quant_tensor import QuantTensor from brevitas.utils.jit_utils import clear_class_registry -from brevitas.utils.jit_utils import jit_patches_generator +# from brevitas.utils.jit_utils import jit_patches_generator from brevitas.utils.python_utils import patch @@ -162,7 +162,7 @@ class BaseManager(ABC): target_name = None handlers = [] - _base_trace_patches_generator = jit_patches_generator + _base_trace_patches_generator = None # jit_patches_generator _fn_to_cache = [] _fn_cache = [] _cached_io_handler_map = {} diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index f5993a7b0..a33ec4219 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -6,7 +6,7 @@ from brevitas.config import JIT_ENABLED -IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0') +# IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0') def _disabled(fn): @@ -20,10 +20,10 @@ def _disabled(fn): ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute - if not IS_ABOVE_110: - script_method_110_disabled = _disabled - else: - script_method_110_disabled = script_method + script_method_110_disabled = script_method + # script_method_110_disabled = _disabled + # if not IS_ABOVE_110: + # else: else: diff --git a/src/brevitas/utils/jit_utils.py b/src/brevitas/utils/jit_utils.py index 4f9d1aa26..1f461484d 100644 --- a/src/brevitas/utils/jit_utils.py +++ b/src/brevitas/utils/jit_utils.py @@ -1,38 +1,36 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import inspect +# import inspect -import torch - -try: - from torch._jit_internal import get_torchscript_modifier -except: - get_torchscript_modifier = None - -from dependencies import Injector +# from dependencies import Injector from packaging import version +import torch from brevitas import torch_version -from brevitas.inject import ExtendedInjector -from brevitas.jit import IS_ABOVE_110 -from .python_utils import patch +# try: +# from torch._jit_internal import get_torchscript_modifier +# except: +# get_torchscript_modifier = None +# from brevitas.inject import ExtendedInjector +# from brevitas.jit import IS_ABOVE_110 -def _get_modifier_wrapper(fn): - if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)): - return None - else: - return get_torchscript_modifier(fn) +# from .python_utils import patch +# def _get_modifier_wrapper(fn): +# if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)): +# return None +# else: +# return get_torchscript_modifier(fn) -if IS_ABOVE_110: +# if IS_ABOVE_110: - def jit_patches_generator(): - return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)] -else: - jit_patches_generator = None +# def jit_patches_generator(): +# return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)] +# else: +# jit_patches_generator = None def clear_class_registry():