28 | 28 | __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']
29 | 29 |
30 | 30 | EPSILON = 1e-9
   | 31 | +FLOAT16_EPSILON = 1e-4
31 | 32 |
32 | 33 | _supported_layers = (
33 | 34 |     nn.ConvTranspose1d,
@@ -334,6 +335,7 @@ def _cross_layer_equalization(
334 | 335 |     # Determine device and type of tensors
335 | 336 |     device = next(sinks[0].parameters()).device
336 | 337 |     dtype = next(sinks[0].parameters()).dtype
    | 338 | +    epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON
337 | 339 |
338 | 340 |     # If equalization criteria are not met, we return a scalar one to indicate that no equalization
339 | 341 |     # has been performed
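
Why a separate float16 floor: `EPSILON = 1e-9` is below float16's smallest subnormal (about 6e-8), so once cast to float16 it flushes to zero and the clamps below no longer protect the division and `torch.reciprocal` from producing `inf`. A standalone sketch of the failure mode (plain PyTorch, values local to this example):

```python
import torch

# 1e-9 is not representable in float16, so the old floor silently becomes 0
# and the clamp no longer prevents reciprocal from returning inf.
print(torch.tensor(1e-9, dtype=torch.float16))  # tensor(0., dtype=torch.float16)

ranges = torch.tensor([0.0, 0.5], dtype=torch.float16)
print(torch.reciprocal(torch.clamp(ranges, min=1e-9)))  # [inf, 2.0]
print(torch.reciprocal(torch.clamp(ranges, min=1e-4)))  # finite: ~[1e4, 2.0]
```
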
@@ -398,7 +400,7 @@ def _no_equalize():
398 | 400 |     scale_fn = _select_scale_computation_fn(scale_computation_type)
399 | 401 |     sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
400 | 402 |     sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
401 |     | -    sinks_range = torch.clamp(sinks_range, EPSILON)
    | 403 | +    sinks_range = torch.clamp(sinks_range, epsilon)
402 | 404 |
403 | 405 |     # Determine the srcs_range based on where we are performing activation equalization or
404 | 406 |     # weight equalization
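
For context, `scale_fn` reduces the concatenated sink weights to one value per output channel, which is what the clamp then floors. A minimal sketch of that reduction, assuming the `'maxabs'` scale-computation option (the helper name below is illustrative, not the library's API):

```python
import torch

def maxabs_range(x: torch.Tensor) -> torch.Tensor:
    # One value per output channel (row): the largest absolute weight.
    return x.abs().max(dim=1).values

# Two sink weights already flattened to (out_channels, -1), as in the diff above.
w1, w2 = torch.randn(8, 16), torch.randn(8, 32)
sinks_range = maxabs_range(torch.cat([w1, w2], dim=1))
print(sinks_range.shape)  # torch.Size([8])
```
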
@@ -434,7 +436,7 @@ def _no_equalize():
434 | 436 |     srcs_range = torch.pow(srcs_range, alpha)
435 | 437 |     sinks_range = torch.pow(sinks_range, 1 - alpha)
436 | 438 |     scaling_factors = srcs_range / sinks_range
437 |     | -    scaling_factors = torch.clamp(scaling_factors, EPSILON)
    | 439 | +    scaling_factors = torch.clamp(scaling_factors, epsilon)
438 | 440 |     inverse_scaling_factors = torch.reciprocal(scaling_factors)
439 | 441 |
440 | 442 |     if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
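
The clamp here guards `scaling_factors` right before `torch.reciprocal`: sources are scaled by `1/s` and sinks by `s`, which leaves the composed function unchanged. A toy end-to-end check of that invariant (self-contained; `alpha`, the layer shapes, and the `1e-4` floor are arbitrary choices for this example):

```python
import torch

torch.manual_seed(0)
src = torch.nn.Linear(4, 6)
sink = torch.nn.Linear(6, 3)
x = torch.randn(2, 4)
ref = sink(src(x)).detach()

alpha = 0.5
srcs_range = src.weight.abs().max(dim=1).values             # per src output channel
sinks_range = sink.weight.abs().max(dim=0).values.clamp(min=1e-4)
scaling = (srcs_range.pow(alpha) / sinks_range.pow(1 - alpha)).clamp(min=1e-4)

with torch.no_grad():
    src.weight.div_(scaling.unsqueeze(1))   # src output channels scaled by 1/s
    src.bias.div_(scaling)
    sink.weight.mul_(scaling.unsqueeze(0))  # sink input channels scaled by s

print(torch.allclose(ref, sink(src(x)), atol=1e-5))  # True: function preserved
```
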