Skip to content

Commit ead9765

Browse files
authored
Add .bin load in WOQ scripts (#3346)
1 parent 11b2b1d commit ead9765

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/cpu/llm/inference/single_instance/run_quantization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ def calib_func(prepared_model):
11001100
assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
11011101
if os.path.isfile(pathname):
11021102
low_precision_checkpoint = None
1103-
if pathname.endswith(".pt") or pathname.endswith(".pth"):
1103+
if pathname.endswith((".pt", ".pth", ".bin")):
11041104
low_precision_checkpoint = torch.load(pathname, weights_only=True)
11051105
elif pathname.endswith(".safetensors"):
11061106
try:
@@ -1113,13 +1113,13 @@ def calib_func(prepared_model):
11131113
low_precision_checkpoint = safetensors.torch.load_file(pathname)
11141114
assert (
11151115
low_precision_checkpoint is not None
1116-
), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth or .safetensors file."
1116+
), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth, .bin or .safetensors file."
11171117

11181118
quant_method = {"quant_method": "gptq"}
11191119

11201120
elif os.path.isdir(pathname):
11211121
low_precision_checkpoint = {}
1122-
for pattern in ["*.pt", "*.pth"]:
1122+
for pattern in ["*.pt", "*.pth", "*.bin"]:
11231123
files = list(pathlib.Path(pathname).glob(pattern))
11241124
if files:
11251125
for f in files:
@@ -1141,7 +1141,7 @@ def calib_func(prepared_model):
11411141
low_precision_checkpoint.update(data_f)
11421142
assert (
11431143
len(low_precision_checkpoint) > 0
1144-
), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path {pathname}."
1144+
), f"Cannot find checkpoint (.pt/.pth/.bin/.safetensors) files in path {pathname}."
11451145

11461146
try:
11471147
with open(pathname + "/config.json") as f:

0 commit comments

Comments
 (0)