Skip to content

Commit 8e39b02

Browse files
committed
Fix (graph/equalize): increase epsilon for float16
1 parent 52daf86 commit 8e39b02

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/brevitas/graph/equalize.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']
2929

3030
EPSILON = 1e-9
31+
FLOAT16_EPSILON = 1e-4
3132

3233
_supported_layers = (
3334
nn.ConvTranspose1d,
@@ -334,6 +335,7 @@ def _cross_layer_equalization(
334335
# Determine device and type of tensors
335336
device = next(sinks[0].parameters()).device
336337
dtype = next(sinks[0].parameters()).dtype
338+
epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON
337339

338340
# If equalization criteria are not met, we return a scalar one to indicate that no equalization
339341
# has been performed
@@ -398,7 +400,7 @@ def _no_equalize():
398400
scale_fn = _select_scale_computation_fn(scale_computation_type)
399401
sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
400402
sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
401-
sinks_range = torch.clamp(sinks_range, EPSILON)
403+
sinks_range = torch.clamp(sinks_range, epsilon)
402404

403405
# Determine the srcs_range based on where we are performing activation equalization or
404406
# weight equalization
@@ -434,7 +436,7 @@ def _no_equalize():
434436
srcs_range = torch.pow(srcs_range, alpha)
435437
sinks_range = torch.pow(sinks_range, 1 - alpha)
436438
scaling_factors = srcs_range / sinks_range
437-
scaling_factors = torch.clamp(scaling_factors, EPSILON)
439+
scaling_factors = torch.clamp(scaling_factors, epsilon)
438440
inverse_scaling_factors = torch.reciprocal(scaling_factors)
439441

440442
if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:

0 commit comments

Comments
 (0)