@@ -1100,7 +1100,7 @@ def calib_func(prepared_model):
1100
1100
assert os .path .exists (pathname ), f"Checkpoint file does not exist: { pathname } "
1101
1101
if os .path .isfile (pathname ):
1102
1102
low_precision_checkpoint = None
1103
- if pathname .endswith (".pt" ) or pathname . endswith ( ".pth" ):
1103
+ if pathname .endswith (( ".pt" , ".pth" , ".bin" ) ):
1104
1104
low_precision_checkpoint = torch .load (pathname , weights_only = True )
1105
1105
elif pathname .endswith (".safetensors" ):
1106
1106
try :
@@ -1113,13 +1113,13 @@ def calib_func(prepared_model):
1113
1113
low_precision_checkpoint = safetensors .torch .load_file (pathname )
1114
1114
assert (
1115
1115
low_precision_checkpoint is not None
1116
- ), f"Invalid checkpoint file: { pathname } . Should be a .pt, .pth or .safetensors file."
1116
+ ), f"Invalid checkpoint file: { pathname } . Should be a .pt, .pth, .bin or .safetensors file."
1117
1117
1118
1118
quant_method = {"quant_method" : "gptq" }
1119
1119
1120
1120
elif os .path .isdir (pathname ):
1121
1121
low_precision_checkpoint = {}
1122
- for pattern in ["*.pt" , "*.pth" ]:
1122
+ for pattern in ["*.pt" , "*.pth" , "*.bin" ]:
1123
1123
files = list (pathlib .Path (pathname ).glob (pattern ))
1124
1124
if files :
1125
1125
for f in files :
@@ -1141,7 +1141,7 @@ def calib_func(prepared_model):
1141
1141
low_precision_checkpoint .update (data_f )
1142
1142
assert (
1143
1143
len (low_precision_checkpoint ) > 0
1144
- ), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path { pathname } ."
1144
+ ), f"Cannot find checkpoint (.pt/.pth/.bin/. safetensors) files in path { pathname } ."
1145
1145
1146
1146
try :
1147
1147
with open (pathname + "/config.json" ) as f :
0 commit comments