
Commit 06bcaae

Fix
1 parent 8f6d656 commit 06bcaae

2 files changed, +8 −5 lines changed


src/brevitas/graph/equalize.py

+7 −5
@@ -437,12 +437,14 @@ def _no_equalize():
 
     # Instead of clipping very low values, which would cause their reciprocal to be very large
     # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
-    sinks_range = torch.where(
-        sinks_range > EPSILON, sinks_range, torch.tensor(1., dtype=dtype, device=device))
-    srcs_range = torch.where(
-        srcs_range > EPSILON, srcs_range, torch.tensor(1., dtype=dtype, device=device))
-    srcs_range = torch.pow(srcs_range, alpha)
+    sinks_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
+                              torch.tensor(1., dtype=dtype, device=device),
+                              sinks_range)
+    srcs_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
+                             torch.tensor(1., dtype=dtype, device=device),
+                             srcs_range)
 
+    srcs_range = torch.pow(srcs_range, alpha)
     sinks_range = torch.pow(sinks_range, 1 - alpha)
     scaling_factors = srcs_range / sinks_range
     inverse_scaling_factors = torch.reciprocal(scaling_factors)
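
The change couples the two guards: where the old code reset srcs_range and sinks_range to 1. independently, the new condition resets both whenever either falls below EPSILON, so the scaling factor for that channel becomes exactly 1 (a true no-op) instead of a partial rescale against a degenerate range. Below is a minimal standalone sketch of the guarded computation; the EPSILON and alpha values and the example tensors are illustrative only, not Brevitas defaults, and the mask is hoisted into a single degenerate tensor for readability rather than repeated inside each torch.where as in the committed code.

    import torch

    # Names mirror equalize.py; the values below are illustrative only.
    EPSILON = 1e-9
    alpha = 0.5
    dtype = torch.float32
    device = 'cpu'

    srcs_range = torch.tensor([0.5, 2.0, 1e-12], dtype=dtype, device=device)
    sinks_range = torch.tensor([1e-12, 4.0, 3.0], dtype=dtype, device=device)

    # If either range is (near) zero, equalizing that channel is numerically meaningless and
    # its reciprocal would blow up, so both ranges are forced to 1., the no-op for equalization.
    degenerate = (sinks_range < EPSILON) | (srcs_range < EPSILON)
    one = torch.tensor(1., dtype=dtype, device=device)
    sinks_range = torch.where(degenerate, one, sinks_range)
    srcs_range = torch.where(degenerate, one, srcs_range)

    srcs_range = torch.pow(srcs_range, alpha)
    sinks_range = torch.pow(sinks_range, 1 - alpha)
    scaling_factors = srcs_range / sinks_range
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    print(scaling_factors)          # tensor([1.0000, 0.7071, 1.0000]) -- degenerate channels left untouched
    print(inverse_scaling_factors)  # tensor([1.0000, 1.4142, 1.0000])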

src/brevitas_examples/llm/llm_quant/equalize.py

+1 −0
@@ -32,6 +32,7 @@ def trace_and_standardize(model, ref_kwargs):
     graph_model = value_trace(model, value_args=ref_kwargs)
     graph_model = TorchFunctionalToModule().apply(graph_model)
     graph_model = DuplicateSharedStatelessModule().apply(graph_model)
+    return graph_model
 
 
 @torch.no_grad()
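
For context, trace_and_standardize previously built graph_model but fell through without returning it, so every caller received None. The snippet below is not Brevitas code, just a self-contained illustration of that failure mode and of the one-line fix:

    def standardize(x):            # pre-commit shape: result computed but never returned
        x = x * 2
        x = x + 1

    def standardize_fixed(x):      # post-commit shape: the result is handed back to the caller
        x = x * 2
        x = x + 1
        return x

    assert standardize(3) is None      # caller silently gets None
    assert standardize_fixed(3) == 7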
