Skip to content

Commit 8f6d656

Browse files
committed
Fix
1 parent 7d724f6 commit 8f6d656

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/brevitas_examples/llm/llm_quant/equalize.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
2828
return outs
2929

3030

31+
def trace_and_standardize(model, ref_kwargs):
    """Trace ``model`` into a graph module and normalize it for equalization.

    Runs value tracing with ``ref_kwargs`` as example inputs, converts
    functional calls to module form, and duplicates stateless modules that
    are shared across the graph so each use can be rewritten independently.

    :param model: the model to trace.
    :param ref_kwargs: keyword arguments passed as tracing example inputs.
    :return: the traced and standardized graph module.
    """
    graph_model = value_trace(model, value_args=ref_kwargs)
    graph_model = TorchFunctionalToModule().apply(graph_model)
    graph_model = DuplicateSharedStatelessModule().apply(graph_model)
    # BUG FIX: the original version fell off the end and implicitly returned
    # None, so callers doing `graph_model = trace_and_standardize(...)`
    # received None instead of the traced graph.
    return graph_model
3137
@torch.no_grad()
3238
def apply_act_equalization(
3339
model,
@@ -51,9 +57,7 @@ def apply_act_equalization(
5157
# We can't do fp16 tracing on CPU as many kernels are not implemented
5258
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
5359
with cast_to_float32(model, dtype):
54-
graph_model = value_trace(model, value_args=ref_kwargs)
55-
graph_model = TorchFunctionalToModule().apply(graph_model)
56-
graph_model = DuplicateSharedStatelessModule().apply(graph_model)
60+
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
5761
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
5862
# or an FX interpreter to run it on GPU
5963
warnings.warn(
@@ -74,5 +78,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='
7478
# We can't do fp16 tracing on CPU as many kernels are not implemented
7579
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
7680
with cast_to_float32(model, dtype):
77-
graph_model = value_trace(model, value_args=ref_kwargs)
81+
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
7882
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)

0 commit comments

Comments
 (0)