@@ -437,12 +437,14 @@ def _no_equalize():
437
437
438
438
# Instead of clipping very low values, which would cause their reciprocal to be very large
439
439
# thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
440
- sinks_range = torch .where (
441
- sinks_range > EPSILON , sinks_range , torch .tensor (1. , dtype = dtype , device = device ))
442
- srcs_range = torch .where (
443
- srcs_range > EPSILON , srcs_range , torch .tensor (1. , dtype = dtype , device = device ))
444
- srcs_range = torch .pow (srcs_range , alpha )
440
+ sinks_range = torch .where ((sinks_range < EPSILON ) | (srcs_range < EPSILON ),
441
+ torch .tensor (1. , dtype = dtype , device = device ),
442
+ sinks_range )
443
+ srcs_range = torch .where ((sinks_range < EPSILON ) | (srcs_range < EPSILON ),
444
+ torch .tensor (1. , dtype = dtype , device = device ),
445
+ srcs_range )
445
446
447
+ srcs_range = torch .pow (srcs_range , alpha )
446
448
sinks_range = torch .pow (sinks_range , 1 - alpha )
447
449
scaling_factors = srcs_range / sinks_range
448
450
inverse_scaling_factors = torch .reciprocal (scaling_factors )
0 commit comments