Skip to content

Commit

Permalink
Followup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 13, 2023
1 parent 8e39b02 commit 3c26d54
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

EPSILON = 1e-9
FLOAT16_EPSILON = 1e-4
FLOAT16_EPSILON = 2e-5

_supported_layers = (
nn.ConvTranspose1d,
Expand Down Expand Up @@ -280,7 +280,8 @@ def _combine_weights_bias(
weight = weight.data.reshape(weight.shape[0], -1)
bias = bias.reshape(-1, 1)

weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
epsilon = FLOAT16_EPSILON if weight.dtype == torch.float16 else EPSILON
weight = torch.where(torch.abs(weight) < epsilon, torch.tensor(epsilon).type_as(weight), weight)
factor = torch.abs(bias) / torch.abs(weight)

# From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
Expand Down

0 comments on commit 3c26d54

Please sign in to comment.