@@ -1065,7 +1065,7 @@ def __init__(
1065
1065
minmax_lr : float = None ,
1066
1066
disable_quanted_input : bool = False ,
1067
1067
nsamples : int = 512 ,
1068
- iters : int = 200 ,
1068
+ iters : int = None ,
1069
1069
use_ggml : bool = False ,
1070
1070
use_neural_speed : bool = False ,
1071
1071
llm_int8_skip_modules = None ,
@@ -1091,7 +1091,6 @@ def __init__(
1091
1091
self .lr = lr
1092
1092
self .minmax_lr = minmax_lr
1093
1093
self .disable_quanted_input = disable_quanted_input
1094
- self .iters = iters
1095
1094
self .llm_int8_skip_modules = (
1096
1095
llm_int8_skip_modules if llm_int8_skip_modules else []
1097
1096
)
@@ -1101,7 +1100,14 @@ def __init__(
1101
1100
self .calib_dataloader = kwargs .get ("calib_dataloader" , None )
1102
1101
self .calib_len = kwargs .get ("calib_len" , 2048 )
1103
1102
self .calib_func = kwargs .get ("calib_func" , None )
1104
- self .calib_iters = kwargs .get ("calib_iters" , 100 )
1103
+ calib_iters = kwargs .get ("calib_iters" , None )
1104
+ if iters is not None :
1105
+ self .calib_iters = iters
1106
+ if calib_iters is not None :
1107
+ logger .info ("cannot be set simultaneously for 'iters' and 'calib_iters', "
1108
+ "we will use 'iters' as calibration iterations!" )
1109
+ else :
1110
+ self .calib_iters = 200 if calib_iters is None else calib_iters
1105
1111
self .scheme = "sym" if self .sym else "asym"
1106
1112
if isinstance (compute_dtype , torch .dtype ):
1107
1113
self .compute_dtype = convert_dtype_torch2str (compute_dtype )
0 commit comments