diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index e096171fb..0012e9065 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -28,7 +28,7 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 -FLOAT16_EPSILON = 1e-4 +FLOAT16_EPSILON = 2e-5 _supported_layers = ( nn.ConvTranspose1d, @@ -280,7 +280,8 @@ def _combine_weights_bias( weight = weight.data.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) - weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight) + epsilon = FLOAT16_EPSILON if weight.dtype == torch.float16 else EPSILON + weight = torch.where(torch.abs(weight) < epsilon, torch.tensor(epsilon).type_as(weight), weight) factor = torch.abs(bias) / torch.abs(weight) # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450