diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a755b0a60bf0a..c2aba909706d0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -148,9 +148,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_names_from_parts.update(model_part.keys()) for name in model_part.keys(): - data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] - if self.lazy: - data = LazyTorchTensor.from_eager(data) + if self.is_safetensors: + if self.lazy: + data = model_part.get_slice(name) + data = LazyTorchTensor.from_safetensors_slice(data) + else: + data = model_part.get_tensor(name) + else: + data = model_part[name] + if self.lazy: + data = LazyTorchTensor.from_eager(data) yield name, data # only verify tensor name presence; it doesn't matter if they are not in the right files @@ -3424,19 +3431,46 @@ class LazyTorchTensor(gguf.LazyBase): torch.float32: np.float32, } + # used for safetensors slices + # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 + # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 + _dtype_str_map: dict[str, torch.dtype] = { + "F64": torch.float64, + "F32": torch.float32, + "BF16": torch.bfloat16, + "F16": torch.float16, + # "U64": torch.uint64, + "I64": torch.int64, + # "U32": torch.uint32, + "I32": torch.int32, + # "U16": torch.uint16, + "I16": torch.int16, + "U8": torch.uint8, + "I8": torch.int8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + } + def numpy(self) -> gguf.LazyNumpyTensor: dtype = self._dtype_map[self.dtype] return gguf.LazyNumpyTensor( meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), - lazy=self._lazy, args=(self,), - func=(lambda s: s[0].numpy()) + func=(lambda s: s.numpy()) ) @classmethod - def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor: + def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: return torch.empty(size=shape, dtype=dtype, device="meta") + @classmethod + def from_safetensors_slice(cls, st_slice: Any) -> Tensor: + dtype = cls._dtype_str_map[st_slice.get_dtype()] + shape: tuple[int, ...] = tuple(st_slice.get_shape()) + lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + return cast(torch.Tensor, lazy) + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): del types # unused @@ -3447,7 +3481,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if func is torch.Tensor.numpy: return args[0].numpy() - return LazyTorchTensor._wrap_fn(func)(*args, **kwargs) + return cls._wrap_fn(func)(*args, **kwargs) def parse_args() -> argparse.Namespace: diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index 6e266f34f76ac..ac98d9a92a3e9 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -3,7 +3,6 @@ import logging from typing import Any, Callable -from collections import deque import numpy as np from numpy.typing import DTypeLike @@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta): _tensor_type: type _meta: Any _data: Any | None - _lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager _args: tuple - _func: Callable[[tuple], Any] | None + _kwargs: dict[str, Any] + _func: Callable[[Any], Any] | None - def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None): + def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None): super().__init__() self._meta = meta self._data = data - self._lazy = lazy if lazy is not None else deque() self._args = args + self._kwargs = kwargs if kwargs is not None else {} self._func = func assert self._func is not None or self._data is not None - if self._data is None: - self._lazy.append(self) def __init_subclass__(cls) -> None: if "_tensor_type" not in cls.__dict__: @@ -117,6 +114,7 @@ def wrapped_fn(*args, **kwargs): args = ((use_self,) if use_self is not None else ()) + args meta_args = LazyBase._recurse_apply(args, lambda t: t._meta) + # TODO: maybe handle tensors in kwargs too if isinstance(meta_noop, bool) and not meta_noop: try: @@ -140,23 +138,7 @@ def wrapped_fn(*args, **kwargs): res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) if isinstance(res, cls._tensor_type): - class CollectSharedLazy: - # emulating a static variable - shared_lazy: None | deque[LazyBase] = None - - @staticmethod - def collect_replace(t: LazyBase): - if CollectSharedLazy.shared_lazy is None: - CollectSharedLazy.shared_lazy = t._lazy - else: - CollectSharedLazy.shared_lazy.extend(t._lazy) - t._lazy = CollectSharedLazy.shared_lazy - - LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace) - - shared_lazy = CollectSharedLazy.shared_lazy - - return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs)) + return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn) else: del res # not needed # non-tensor return likely relies on the contents of the args @@ -168,26 +150,18 @@ def collect_replace(t: LazyBase): @classmethod def to_eager(cls, t: Any) -> Any: def simple_to_eager(_t: LazyBase) -> Any: - def already_eager_to_eager(_t: LazyBase) -> Any: - assert _t._data is not None + if _t._data is not None: return _t._data - while _t._data is None: - lt = _t._lazy.popleft() - if lt._data is not None: - # Lazy tensor did not belong in the lazy queue. - # Weirdly only happens with Bloom models... - # likely because tensors aren't unique in the queue. - # The final output is still the same as in eager mode, - # so it's safe to ignore this. - continue - assert lt._func is not None - lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) - lt._data = lt._func(lt._args) - # sanity check - assert lt._data is not None - assert lt._data.dtype == lt._meta.dtype - assert lt._data.shape == lt._meta.shape + # NOTE: there's a recursion limit in Python (usually 1000) + + assert _t._func is not None + _t._args = cls._recurse_apply(_t._args, simple_to_eager) + _t._data = _t._func(*_t._args, **_t._kwargs) + # sanity check + assert _t._data is not None + assert _t._data.dtype == _t._meta.dtype + assert _t._data.shape == _t._meta.shape return _t._data @@ -206,7 +180,7 @@ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass @classmethod def from_eager(cls, t: Any) -> Any: if type(t) is cls: - # already eager + # already lazy return t elif isinstance(t, cls._tensor_type): return cls(meta=cls.eager_to_meta(t), data=t) @@ -228,8 +202,7 @@ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> def astype(self, dtype, *args, **kwargs): meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape) full_args = (self, dtype,) + args - # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere. - return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs))) + return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs))) def tofile(self, *args, **kwargs): eager = LazyNumpyTensor.to_eager(self) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7264240f5e17a..9aa2209e2f69f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -602,14 +602,12 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): for tensor, keys in self.block_mappings_cfg.items(): if tensor not in MODEL_TENSORS[arch]: continue - # TODO: make this configurable - n_experts = 160 - for xid in range(n_experts): - tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) - self.mapping[tensor_name] = (tensor, tensor_name) - for key in keys: - key = key.format(bid = bid, xid = xid) - self.mapping[key] = (tensor, tensor_name) + + tensor_name = TENSOR_NAMES[tensor].format(bid = bid) + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + key = key.format(bid = bid) + self.mapping[key] = (tensor, tensor_name) def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key)