Commit

Fix (proxy): fix for attributes retrieval
Giuseppe5 committed Feb 27, 2024
1 parent 6079b12 commit 49a7683
Showing 2 changed files with 42 additions and 22 deletions.
34 changes: 23 additions & 11 deletions src/brevitas/proxy/parameter_quant.py
@@ -83,16 +83,22 @@ def requires_quant_input(self):
         return False
 
     def scale(self):
+        if not self.is_quant_enabled:
+            return None
         scale = self.__call__(self.tracked_parameter_list[0]).scale
         return scale
 
     def zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
         return zero_point
 
     def bit_width(self):
-        bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width
-        return bit_width_
+        if not self.is_quant_enabled:
+            return None
+        bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
+        return bit_width
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
@@ -106,11 +112,15 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
 class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
 
     def pre_scale(self):
+        if not self.is_quant_enabled:
+            return None
         output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_scale
 
     def pre_zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         output_tuple = self.tensor_quant(self.tracked_parameter_list[0])
         out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
         return pre_zero_point
@@ -152,7 +162,7 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor,
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
             return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else: # quantization disabled
-            return QuantTensor(x, training=self.training)
+            return x
 
 
 class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
@@ -176,23 +186,25 @@ def requires_input_scale(self) -> bool:
         return False
 
     def scale(self):
-        if self.requires_input_scale:
+        if self.requires_input_scale or not self.is_quant_enabled:
             return None
         zhs = self._zero_hw_sentinel()
-        scale = self.__call__(self.tracked_parameter_list[0], zhs, zhs).scale
-        return scale
+        out = self.__call__(self.tracked_parameter_list[0], zhs, zhs)
+        return out.scale
 
     def zero_point(self):
+        if not self.is_quant_enabled:
+            return None
         zhs = self._zero_hw_sentinel()
-        zero_point = self.__call__(self.tracked_parameter_list[0], zhs, zhs).zero_point
-        return zero_point
+        out = self.__call__(self.tracked_parameter_list[0], zhs, zhs)
+        return out.zero_point
 
     def bit_width(self):
-        if self.requires_input_bit_width:
+        if self.requires_input_bit_width or not self.is_quant_enabled:
             return None
         zhs = self._zero_hw_sentinel()
-        bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
-        return bit_width
+        out = self.__call__(self.tracked_parameter_list[0], zhs, zhs)
+        return out.bit_width
 
     def forward(
         self,
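
The net effect of the parameter_quant.py changes is that the weight and bias proxies now answer scale(), zero_point(), bit_width(), pre_scale() and pre_zero_point() with None whenever quantization is disabled, instead of trying to inspect the output of a disabled proxy, and the decoupled weight proxy's forward returns the raw tensor rather than an empty QuantTensor. A minimal sketch of how this surfaces at the layer level (QuantLinear, weight_quant and bias_quant are existing Brevitas names; the snippet and the values in the comments are illustrative and not part of this commit):

from brevitas.nn import QuantLinear

# Weight quantization enabled: the proxy getters return concrete values.
q_linear = QuantLinear(16, 8, bias=True, weight_bit_width=4)
print(q_linear.weight_quant.scale())      # scale tensor of the quantized weight
print(q_linear.weight_quant.bit_width())  # tensor(4.)

# Weight quantization disabled: after this change the getters short-circuit
# on is_quant_enabled and return None.
fp_linear = QuantLinear(16, 8, bias=True, weight_quant=None)
print(fp_linear.weight_quant.scale())       # None
print(fp_linear.weight_quant.zero_point())  # None
print(fp_linear.weight_quant.bit_width())   # None

# Bias quantization is left at its default (disabled) here, so the bias
# proxy getters likewise return None.
print(fp_linear.bias_quant.scale())  # None
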
30 changes: 19 additions & 11 deletions src/brevitas/proxy/runtime_quant.py
@@ -118,24 +118,34 @@ def init_tensor_quant(self):
         self.fused_activation_quant_proxy = None
 
     def scale(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
         current_status = self.training
         if force_eval:
             self.eval()
-        scale = self.__call__(self._zero_hw_sentinel()).scale
+        out = self.__call__(self._zero_hw_sentinel())
         self.train(current_status)
-        return scale
+        return out.scale
 
     def zero_point(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
         current_status = self.training
         if force_eval:
             self.eval()
-        zero_point = self.__call__(self._zero_hw_sentinel()).zero_point
+        out = self.__call__(self._zero_hw_sentinel())
         self.train(current_status)
-        return zero_point
+        return out.zero_point
 
-    def bit_width(self):
-        scale = self.__call__(self._zero_hw_sentinel()).bit_width
-        return scale
+    def bit_width(self, force_eval=True):
+        if not self.is_quant_enabled:
+            return None
+        current_status = self.training
+        if force_eval:
+            self.eval()
+        out = self.__call__(self._zero_hw_sentinel())
+        self.train(current_status)
+        return out.bit_width
 
     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         if self.fused_activation_quant_proxy is not None:
@@ -179,10 +189,6 @@ def scale(self, force_eval=True):
     def zero_point(self, force_eval=True):
         raise RuntimeError("Zero point for Dynamic Act Quant is input-dependant")
 
-    def bit_width(self):
-        bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
-        return bit_width
-
 
 class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):
 
@@ -198,6 +204,8 @@ def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
 class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):
 
     def bit_width(self):
+        if not self.is_quant_enabled:
+            return None
         zhs = self._zero_hw_sentinel()
         # Signed might or might not be defined. We just care about retrieving the bitwidth
         empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
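
On the runtime_quant.py side, the activation proxy's bit_width() now follows the same pattern as scale() and zero_point(): it returns None when quantization is disabled, and it accepts a force_eval flag that temporarily puts the proxy in eval mode while the value is read and then restores the previous training state. A short sketch of the resulting behaviour (QuantReLU, QuantIdentity and the act_quant attribute are existing Brevitas names; the snippet and the values in the comments are illustrative, not taken from this commit):

from brevitas.nn import QuantIdentity, QuantReLU

act = QuantReLU(bit_width=8)
act.train()
# force_eval=True (the default) reads the value in eval mode and then restores
# the previous training flag, so querying does not disturb training.
print(act.act_quant.scale())
print(act.act_quant.bit_width())  # bit_width() now also accepts force_eval

# With the activation quantizer disabled, the getters return None.
no_act = QuantIdentity(act_quant=None)
print(no_act.act_quant.scale())      # None
print(no_act.act_quant.bit_width())  # None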
