Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 17, 2024
1 parent 77cfc26 commit 839feb0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
self.cache_inference_quant_bias = False
self.cache_inference_quant_bias_metadata_only = False
self.requires_input_scale = self.quant_injector.requires_input_scale
self.skip_create_quant_tensor = True
self.skip_create_quant_tensor = False

@property
def tracked_parameter_list(self):
Expand Down
12 changes: 8 additions & 4 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
@runtime_checkable
class AccQuantProxyProtocol(QuantProxyProtocol, Protocol):

def __init__(self):
super().__init__()
self.skip_create_quant_tensor = False

def forward(self, x: QuantTensor) -> QuantTensor:
...

Expand Down Expand Up @@ -251,6 +247,10 @@ def zero_point(self, force_eval=True):

class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def __init__(self):
super().__init__()
self.skip_create_quant_tensor = False

def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
Expand All @@ -264,6 +264,10 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:

class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.skip_create_quant_tensor = False

def bit_width(self):
if not self.is_quant_enabled:
return None
Expand Down

0 comments on commit 839feb0

Please sign in to comment.