diff --git a/src/brevitas/core/quant/delay.py b/src/brevitas/core/quant/delay.py index fc5e35002..804aaa81b 100644 --- a/src/brevitas/core/quant/delay.py +++ b/src/brevitas/core/quant/delay.py @@ -21,7 +21,7 @@ def __init__(self, quant_delay_steps): super(_DelayQuant, self).__init__() self.quant_delay_steps: int = brevitas.jit.Attribute(quant_delay_steps, int) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def forward(self, x: Tensor, y: Tensor) -> Tensor: if self.quant_delay_steps > 0: self.quant_delay_steps = self.quant_delay_steps - 1 diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index b79f13316..7a7a0f828 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -61,7 +61,7 @@ def __init__( self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: y = x / scale y = y + zero_point @@ -134,7 +134,7 @@ def __init__( self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) - @brevitas.jit.script_method_110_disabled + @brevitas.jit.script_method def to_int( self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index c3dce6d13..0719e1017 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -1,7 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from packaging import version import torch from brevitas.config import JIT_ENABLED @@ -17,12 +16,10 @@ def _disabled(fn): script = torch.jit.script ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute - script_method_110_disabled = script_method else: script_method = _disabled script = _disabled - script_method_110_disabled = _disabled ScriptModule = torch.nn.Module Attribute = lambda val, type: val