diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index c0cbfc713857..71161cc58529 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -60,6 +60,7 @@ def load_single_file_sub_model( local_files_only=False, torch_dtype=None, is_legacy_loading=False, + disable_mmap=False, **kwargs, ): if is_pipeline_module: @@ -106,6 +107,7 @@ def load_single_file_sub_model( subfolder=name, torch_dtype=torch_dtype, local_files_only=local_files_only, + disable_mmap=disable_mmap, **kwargs, ) @@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): hosted on the Hub. - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component configs in Diffusers format. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example @@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + disable_mmap = kwargs.pop("disable_mmap", False) is_legacy_loading = False @@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, + disable_mmap=disable_mmap, ) if config is None: @@ -504,6 +511,7 @@ def load_module(name, value): original_config=original_config, local_files_only=local_files_only, is_legacy_loading=is_legacy_loading, + disable_mmap=disable_mmap, **kwargs, ) except SingleFileComponentError as e: diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 79dc2691b9e4..0c998bab5e0f 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -182,6 +182,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (for example the pipeline components of the specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` @@ -229,6 +232,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) device = kwargs.pop("device", None) + disable_mmap = kwargs.pop("disable_mmap", False) if isinstance(pretrained_model_link_or_path_or_dict, dict): checkpoint = pretrained_model_link_or_path_or_dict @@ -241,6 +245,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, + disable_mmap=disable_mmap, ) if quantization_config is not None: hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 5933c634f4cc..cf694f4fc746 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -377,6 +377,7 @@ def load_single_file_checkpoint( cache_dir=None, local_files_only=None, revision=None, + disable_mmap=False, ): if os.path.isfile(pretrained_model_link_or_path): pretrained_model_link_or_path = pretrained_model_link_or_path @@ -394,7 +395,7 @@ def load_single_file_checkpoint( revision=revision, ) - checkpoint = load_state_dict(pretrained_model_link_or_path) + checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap) # some checkpoints contain the model state dict under a "state_dict" key while "state_dict" in checkpoint: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5f5ea2351709..a3d006f18994 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False +): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ @@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: - return safetensors.torch.load_file(checkpoint_file, device="cpu") + if disable_mmap: + return safetensors.torch.load(open(checkpoint_file, "rb").read()) + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") elif file_extension == GGUF_FILE_EXTENSION: return load_gguf_checkpoint(checkpoint_file) else: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d236ebb83983..9b85968ee7bf 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. @@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) quantization_config = kwargs.pop("quantization_config", None) + disable_mmap = kwargs.pop("disable_mmap", False) allow_pickle = False if use_safetensors is None: @@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # TODO (sayakpaul, SunMarc): remove this after model loading refactor else: param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict(model_file, variant=variant) + state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu @@ -983,7 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant) + state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(