Skip to content

Commit

Permalink
Partially revert 4162ef3bfda8aaaa055eeeb2e9f5044282953e74 -> forward
Browse files Browse the repository at this point in the history
function of float proxy.
  • Loading branch information
nickfraser committed Sep 3, 2024
1 parent 4162ef3 commit 71c93c8
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,21 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
y = self.fused_activation_quant_proxy.activation_impl(y)
y = (self.fused_activation_quant_proxy.activation_impl(y), None)
else:
y = self.fused_activation_quant_proxy(y)

# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor

# If the second value (i.e., scale) is None, then quant is disabled
if isinstance(y, tuple) and y[1] is not None:
out = self.create_quant_tensor(y)
elif self.is_passthrough_act and isinstance(x, QuantTensor):
# preserve quant_metadata
if isinstance(y, tuple):
y = y[0]
y = y[0]
out = self.create_quant_tensor(y, x=x)
else:
if isinstance(y, tuple):
y = y[0]
out = y
out = y[0]

if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
Expand Down

0 comments on commit 71c93c8

Please sign in to comment.