Skip to content

Commit

Permalink
convert : identify missing model files (ggerganov#9397)
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade authored Sep 16, 2024
1 parent 19514d6 commit d54c21d
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,22 @@ def set_vocab(self):
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts: set[str] = set()

if len(self.part_names) > 1:
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name

if index_file.is_file():
self.tensor_names = set()
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
logger.info(f"gguf: loading model weight map from '{index_name}'")
with open(self.dir_model / index_name, "r", encoding="utf-8") as f:
with open(index_file, "r", encoding="utf-8") as f:
index: dict[str, Any] = json.load(f)
weight_map = index.get("weight_map")
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
self.tensor_names.update(weight_map.keys())
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}

for part_name in self.part_names:
logger.info(f"gguf: loading model part '{part_name}'")
Expand All @@ -171,9 +174,17 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
data = LazyTorchTensor.from_eager(data)
yield name, data

# only verify tensor name presence; it doesn't matter if they are not in the right files
if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
# verify tensor name presence and identify potentially missing files
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}")
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")

def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
Expand Down

0 comments on commit d54c21d

Please sign in to comment.