From b6405445961f51033bcf8d30a1667a18c2c22a1d Mon Sep 17 00:00:00 2001 From: Xin Chen Date: Tue, 26 Sep 2023 20:53:06 +0800 Subject: [PATCH] load safetensors first --- lmdeploy/serve/turbomind/deploy.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py index 49b31def42..81129623ef 100644 --- a/lmdeploy/serve/turbomind/deploy.py +++ b/lmdeploy/serve/turbomind/deploy.py @@ -117,10 +117,15 @@ def load_checkpoint(model_path): Returns: Dict[str, torch.Tensor]: weight in torch format """ - files = [ - file for file in os.listdir(model_path) - if file.endswith('.bin') or file.endswith('.safetensors') - ] + suffixes = ['.safetensors', '.bin'] + for suffix in suffixes: + files = [ + file for file in os.listdir(model_path) if file.endswith(suffix) + ] + if len(files) > 0: + break + + assert len(files) > 0, f'could not find checkpoints in {model_path}' files = sorted(files) print(files) params = {}