diff --git a/src/invoke_training/_shared/stable_diffusion/model_loading_utils.py b/src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
index 7a2c2878..b4047fa3 100644
--- a/src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
+++ b/src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
@@ -52,35 +52,56 @@ def load_pipeline(
     if os.path.isfile(model_name_or_path):
         return pipeline_class.from_single_file(model_name_or_path, torch_dtype=torch_dtype, load_safety_checker=False)
 
+    return from_pretrained_with_variant_fallback(
+        logger=logger,
+        model_class=pipeline_class,
+        model_name_or_path=model_name_or_path,
+        torch_dtype=torch_dtype,
+        variant=variant,
+        # kwargs
+        safety_checker=None,
+        requires_safety_checker=False,
+    )
+
+
+ModelT = typing.TypeVar("ModelT")
+
+
+def from_pretrained_with_variant_fallback(
+    logger: logging.Logger,
+    model_class: typing.Type[ModelT],
+    model_name_or_path: str,
+    torch_dtype: torch.dtype | None = None,
+    variant: str | None = None,
+    **kwargs,
+) -> ModelT:
+    """A wrapper for .from_pretrained() that tries multiple variants if the initial one fails."""
     variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]
 
-    pipeline = None
+    model: ModelT | None = None
     for variant_to_try in variants_to_try:
         if variant_to_try != variant:
             logger.warning(f"Trying fallback variant '{variant_to_try}'.")
         try:
-            pipeline = pipeline_class.from_pretrained(
+            model = model_class.from_pretrained(
                 model_name_or_path,
-                safety_checker=None,
                 torch_dtype=torch_dtype,
                 variant=variant_to_try,
-                requires_safety_checker=False,
+                **kwargs,
             )
         except OSError as e:
             if "no file named" in str(e):
                 # Ok; we'll try the variant fallbacks.
-                logger.warning(
-                    f"Failed to load pipeline '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}."
-                )
+                logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}.")
             else:
                 raise
 
-        if pipeline is not None:
+        if model is not None:
             break
 
-    if pipeline is None:
-        raise RuntimeError(f"Failed to load pipeline '{model_name_or_path}'.")
-    return pipeline
+    if model is None:
+        raise RuntimeError(f"Failed to load model '{model_name_or_path}'.")
+    return model
 
 
 def load_models_sd(
diff --git a/src/invoke_training/model_merge/scripts/extract_lora_from_checkpoint.py b/src/invoke_training/model_merge/scripts/extract_lora_from_checkpoint.py
index 995e4013..28c97a1a 100644
--- a/src/invoke_training/model_merge/scripts/extract_lora_from_checkpoint.py
+++ b/src/invoke_training/model_merge/scripts/extract_lora_from_checkpoint.py
@@ -20,7 +20,11 @@ UNET_TARGET_MODULES,
     save_sdxl_kohya_checkpoint,
 )
-from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
+from invoke_training._shared.stable_diffusion.model_loading_utils import (
+    PipelineVersionEnum,
+    from_pretrained_with_variant_fallback,
+    load_pipeline,
+)
 from invoke_training.model_merge.extract_lora import (
     PEFT_BASE_LAYER_PREFIX,
     extract_lora_from_diffs,
 )
@@ -64,9 +68,13 @@ def load_model(
         submodel_path: Path = model_path / submodel_name
         if submodel_path.exists():
            logger.info(f"Loading '{submodel_name}' from '{submodel_path}'.")
-            # TODO(ryand): Add variant fallbacks?
-            submodel = submodel_class.from_pretrained(
-                submodel_path, variant=variant, torch_dtype=dtype, local_files_only=True
+            submodel = from_pretrained_with_variant_fallback(
+                logger=logger,
+                model_class=submodel_class,
+                model_name_or_path=submodel_path,
+                torch_dtype=dtype,
+                variant=variant,
+                local_files_only=True,
             )
             setattr(sd_model, submodel_name, submodel)
         else:
@@ -110,24 +118,6 @@ def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
     raise ValueError(f"Unexpected device: {device_str}")
 
 
-# TODO(ryand): Delete this after integrating the variant fallback logic.
-# def load_sdxl_unet(model_path: str) -> UNet2DConditionModel:
-#     variants_to_try = [None, "fp16"]
-#     unet = None
-#     for variant in variants_to_try:
-#         try:
-#             unet = UNet2DConditionModel.from_pretrained(model_path, variant=variant, local_files_only=True)
-#         except OSError as e:
-#             if "no file named" in str(e):
-#                 # Ok. We'll try a different variant.
-#                 pass
-#             else:
-#                 raise
-#     if unet is None:
-#         raise RuntimeError(f"Failed to load UNet from '{model_path}'.")
-#     return unet
-
-
 def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
     return {k: v.to(device=device) for k, v in state_dict.items()}
 
@@ -172,6 +162,8 @@ def extract_lora_from_submodel(
     # We just use the device for this calculation, since it's slow, then we move the results back to the CPU.
     logger.info("Calculating LoRA weights with SVD.")
     diffs = state_dict_to_device(diffs, device)
+    # TODO(ryand): Should we skip the extraction if the diffs are all zeros? This would happen if this submodel is
+    # identical in both models, which could occur if only some submodels were fine-tuned.
     lora_weights = extract_lora_from_diffs(
         diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype
     )
@@ -322,9 +314,8 @@ def main():
 
     args = parser.parse_args()
 
-    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
     logger = logging.getLogger()
-    logger.setLevel(logging.DEBUG)
 
     orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig)
     tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned)
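
Minimal usage sketch of the new from_pretrained_with_variant_fallback helper. The diffusers UNet2DConditionModel class, the SDXL repo id, and the subfolder argument below are illustrative assumptions, not taken from this patch; the requested variant is tried first, then the entries of HF_VARIANT_FALLBACKS.

import logging

import torch
from diffusers import UNet2DConditionModel

from invoke_training._shared.stable_diffusion.model_loading_utils import (
    from_pretrained_with_variant_fallback,
)

logger = logging.getLogger(__name__)

# Try the requested "fp16" variant first; if those weights are missing from the repo,
# the helper logs a warning and retries with the fallback variants before raising.
# Extra keyword arguments (e.g. subfolder) are forwarded to .from_pretrained() via **kwargs.
unet = from_pretrained_with_variant_fallback(
    logger=logger,
    model_class=UNet2DConditionModel,
    model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # example repo id (assumed)
    torch_dtype=torch.float16,
    variant="fp16",
    subfolder="unet",  # forwarded to UNet2DConditionModel.from_pretrained()
)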