@@ -112,18 +112,22 @@ def _maybe_calibrate_or_quantize(
112112 }:
113113 return value
114114
115- device = next (module .parameters ()).device
116- scale = getattr (module , f"{ base_name } _scale" )
117- # zero_point = getattr(module, f"{base_name}_zero_point").data
118- zero_point = getattr (module , f"{ base_name } _zero_point" )
119-
120- if module .quantization_status == QuantizationStatus .CALIBRATION :
121- # get observer and get new quant params from observation
122- observer = getattr (module , f"{ base_name } _observer" )
123- updated_scale , updated_zero_point = observer (value )
124-
125- # update scale and zero point
126- scale .data = updated_scale .to (device )
127- zero_point .data = updated_zero_point .to (device )
115+ observer = getattr (module , f"{ base_name } _observer" )
116+ if observer .DYNAMIC :
117+ # dynamic quantization - get scale and zero point directly from observer
118+ scale , zero_point = observer (value )
119+ else :
120+ # static quantization - get previous scale and zero point from layer
121+ scale = getattr (module , f"{ base_name } _scale" )
122+ zero_point = getattr (module , f"{ base_name } _zero_point" )
123+
124+ if module .quantization_status == QuantizationStatus .CALIBRATION :
125+ # calibration mode - get new quant params from observer
126+ updated_scale , updated_zero_point = observer (value )
127+
128+ # update scale and zero point
129+ device = next (module .parameters ()).device
130+ scale .data = updated_scale .to (device )
131+ zero_point .data = updated_zero_point .to (device )
128132
129133 return fake_quantize (value , scale , zero_point , args )
0 commit comments