@@ -28,6 +28,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
     return outs


+def trace_and_standardize(model, ref_kwargs):
+    graph_model = value_trace(model, value_args=ref_kwargs)
+    graph_model = TorchFunctionalToModule().apply(graph_model)
+    graph_model = DuplicateSharedStatelessModule().apply(graph_model)
+    return graph_model
+
 @torch.no_grad()
 def apply_act_equalization(
         model,
@@ -51,9 +57,7 @@ def apply_act_equalization(
     # We can't do fp16 tracing on CPU as many kernels are not implemented
     # So we have to cast to fp32 first, trace, apply equalization, and then cast back
     with cast_to_float32(model, dtype):
-        graph_model = value_trace(model, value_args=ref_kwargs)
-        graph_model = TorchFunctionalToModule().apply(graph_model)
-        graph_model = DuplicateSharedStatelessModule().apply(graph_model)
+        graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)

         # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
         # or an FX interpreter to run it on GPU
         warnings.warn(
@@ -74,5 +78,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='
     # We can't do fp16 tracing on CPU as many kernels are not implemented
     # So we have to cast to fp32 first, trace, apply equalization, and then cast back
     with cast_to_float32(model, dtype):
-        graph_model = value_trace(model, value_args=ref_kwargs)
+        graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
         EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
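For orientation, the call flow after this refactor reduces to the short sketch below. It is a minimal illustration, not the file's actual body: it assumes only the calls visible in the diff above (cast_to_float32 used as a context manager, the new trace_and_standardize helper, and EqualizeGraph). Imports are omitted because these names are already available in the modified module, and model, dtype, ref_kwargs, and scale_computation_type stand in for the caller's arguments.

    # Sketch only: fp16 kernels are largely missing on CPU, so the model is cast
    # to fp32 for tracing and equalization, then cast back when the context exits.
    with cast_to_float32(model, dtype):
        # trace the model once and standardize the graph (functional calls turned
        # into modules, shared stateless modules duplicated) via the new helper
        graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
        # run graph equalization on the standardized traced model
        EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
    # leaving cast_to_float32 casts the model back to its original dtype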