Skip to content

Commit

Permalink
Add safetensors support (#4659)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
4 people authored Dec 1, 2023
1 parent 02288bc commit 7122362
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from .base_engine import CheckpointEngineBase
from typing import Iterable, Tuple
from functools import partial

from ..logging import inference_logger

Expand All @@ -28,6 +29,7 @@ def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
else:
self.model_config.max_seq_length = self.generation_config.max_length

self._local_checkpoint_dir = None
self._all_ckpt_paths = self._fetch_checkpoint_files()

def _fetch_checkpoint_files(self):
Expand All @@ -41,29 +43,53 @@ def _fetch_checkpoint_files(self):
# NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
# checkpoint files that may be present. Example of all files in the llama-2-7b
# repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, list_files_info

def model_has_safetensors(model_name_or_path: str) -> bool:
if os.path.isdir(model_name_or_path):
file_list = os.listdir(model_name_or_path)
else:
file_list = [rf.rfilename for rf in list_files_info(model_name_or_path)]
for f in file_list:
if f.endswith(".safetensors"):
return True
return False

if os.path.isdir(self.model_name_or_path):
self._local_checkpoint_dir = self.model_name_or_path
else:
# We need to download the checkpoint files from HF
if model_has_safetensors(self.model_name_or_path):
# Prioritize downloading safetensors if they are available
allow_patterns = ["*.safetensors", "*.json", "*.pt"]
else:
# Fallback to bin files when safetensors are not present
allow_patterns = ["*.bin", "*.json", "*.pt"]
self._local_checkpoint_dir = snapshot_download(self.model_name_or_path,
allow_patterns=[
"*.bin",
"*.json",
"*.pt",
],
allow_patterns=allow_patterns,
revision=None,
token=self.auth_token)

assert os.path.isdir(
self._local_checkpoint_dir
), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint."

model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json")
# Set the appropriate file names based on whether we have safetensors or not
if model_has_safetensors(self._local_checkpoint_dir):
from safetensors.torch import load_file
model_param_json_fname = "model.safetensors.index.json"
model_file_fname = "model.safetensors"
self._checkpoint_load_fn = load_file
else:
model_param_json_fname = "pytorch_model.bin.index.json"
model_file_fname = "pytorch_model.bin"
self._checkpoint_load_fn = partial(torch.load, map_location="cpu")

model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)

if not os.path.isfile(model_param_json):
# We don't need any json as all such HF models will have pytorch_model.bin
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')]
all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, model_file_fname)]
else:
param_map = json.load(open(model_param_json, "r"))

Expand All @@ -84,7 +110,7 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
"""
for checkpoint in self._all_ckpt_paths:
inference_logger().info(f"Loading checkpoint: {checkpoint}")
checkpoint_sd = torch.load(checkpoint, map_location='cpu')
checkpoint_sd = self._checkpoint_load_fn(checkpoint)
param_keys = list(checkpoint_sd.keys())
for param_name in param_keys:
param = checkpoint_sd[param_name]
Expand Down

0 comments on commit 7122362

Please sign in to comment.