Bring back variant fallback support in the LoRA extraction script.
RyanJDick committed Jun 3, 2024
1 parent 3024f73 commit 9ee9d18
Showing 2 changed files with 47 additions and 35 deletions.
43 changes: 32 additions & 11 deletions src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
@@ -52,35 +52,56 @@ def load_pipeline(
     if os.path.isfile(model_name_or_path):
         return pipeline_class.from_single_file(model_name_or_path, torch_dtype=torch_dtype, load_safety_checker=False)
 
+    return from_pretrained_with_variant_fallback(
+        logger=logger,
+        model_class=pipeline_class,
+        model_name_or_path=model_name_or_path,
+        torch_dtype=torch_dtype,
+        variant=variant,
+        # kwargs
+        safety_checker=None,
+        requires_safety_checker=False,
+    )
+
+
+ModelT = typing.TypeVar("ModelT")
+
+
+def from_pretrained_with_variant_fallback(
+    logger: logging.Logger,
+    model_class: typing.Type[ModelT],
+    model_name_or_path: str,
+    torch_dtype: torch.dtype | None = None,
+    variant: str | None = None,
+    **kwargs,
+) -> ModelT:
+    """A wrapper for .from_pretrained() that tries multiple variants if the initial one fails."""
     variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]
 
-    pipeline = None
+    model: ModelT | None = None
     for variant_to_try in variants_to_try:
         if variant_to_try != variant:
             logger.warning(f"Trying fallback variant '{variant_to_try}'.")
         try:
-            pipeline = pipeline_class.from_pretrained(
+            model = model_class.from_pretrained(
                 model_name_or_path,
-                safety_checker=None,
                 torch_dtype=torch_dtype,
                 variant=variant_to_try,
-                requires_safety_checker=False,
+                **kwargs,
             )
         except OSError as e:
             if "no file named" in str(e):
                 # Ok; we'll try the variant fallbacks.
-                logger.warning(
-                    f"Failed to load pipeline '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}."
-                )
+                logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}.")
             else:
                 raise
 
-        if pipeline is not None:
+        if model is not None:
             break
 
-    if pipeline is None:
-        raise RuntimeError(f"Failed to load pipeline '{model_name_or_path}'.")
-    return pipeline
+    if model is None:
+        raise RuntimeError(f"Failed to load model '{model_name_or_path}'.")
+    return model
 
 
 def load_models_sd(
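For reference, a minimal usage sketch of the new generic helper (the model id, dtype, and subfolder below are illustrative, and HF_VARIANT_FALLBACKS is assumed to contain entries along the lines of None and "fp16", matching the fallbacks in the commented-out load_sdxl_unet helper removed further down):

    import logging

    import torch
    from diffusers import UNet2DConditionModel

    from invoke_training._shared.stable_diffusion.model_loading_utils import from_pretrained_with_variant_fallback

    logger = logging.getLogger(__name__)

    # Ask for the fp16 variant first; if those files are missing, the helper logs a
    # warning and retries with each remaining fallback variant before giving up.
    unet = from_pretrained_with_variant_fallback(
        logger=logger,
        model_class=UNet2DConditionModel,
        model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # illustrative
        torch_dtype=torch.float16,
        variant="fp16",
        # Any extra kwargs are passed straight through to .from_pretrained().
        subfolder="unet",
    )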
15 additions & 24 deletions in the second changed file (the LoRA extraction script)
@@ -20,7 +20,11 @@
     UNET_TARGET_MODULES,
     save_sdxl_kohya_checkpoint,
 )
-from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
+from invoke_training._shared.stable_diffusion.model_loading_utils import (
+    PipelineVersionEnum,
+    from_pretrained_with_variant_fallback,
+    load_pipeline,
+)
 from invoke_training.model_merge.extract_lora import (
     PEFT_BASE_LAYER_PREFIX,
     extract_lora_from_diffs,
@@ -64,9 +68,13 @@ def load_model(
         submodel_path: Path = model_path / submodel_name
         if submodel_path.exists():
             logger.info(f"Loading '{submodel_name}' from '{submodel_path}'.")
-            # TODO(ryand): Add variant fallbacks?
-            submodel = submodel_class.from_pretrained(
-                submodel_path, variant=variant, torch_dtype=dtype, local_files_only=True
+            submodel = from_pretrained_with_variant_fallback(
+                logger=logger,
+                model_class=submodel_class,
+                model_name_or_path=submodel_path,
+                torch_dtype=dtype,
+                variant=variant,
+                local_files_only=True,
             )
             setattr(sd_model, submodel_name, submodel)
         else:
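The fallback matters here because a local submodel directory typically ships only one weight variant: diffusers encodes the variant in the file name (for example diffusion_pytorch_model.fp16.safetensors for variant="fp16", versus diffusion_pytorch_model.safetensors for no variant), and requesting a variant that is not on disk raises the OSError ("no file named ...") that the wrapper catches. A rough, hypothetical helper for inspecting what a directory actually provides (name and logic are illustrative, not part of this commit):

    from pathlib import Path


    def guess_available_variants(submodel_path: Path) -> set[str | None]:
        """Infer which weight variants a local diffusers submodel directory ships,
        based on the 'weights_name.variant.safetensors' naming convention."""
        variants: set[str | None] = set()
        for f in submodel_path.glob("*.safetensors"):
            parts = f.name.split(".")
            # e.g. ['diffusion_pytorch_model', 'fp16', 'safetensors'] -> 'fp16'
            #      ['diffusion_pytorch_model', 'safetensors']         -> None
            variants.add(parts[-2] if len(parts) >= 3 else None)
        return variants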
@@ -110,24 +118,6 @@ def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
     raise ValueError(f"Unexpected device: {device_str}")
 
 
-# TODO(ryand): Delete this after integrating the variant fallback logic.
-# def load_sdxl_unet(model_path: str) -> UNet2DConditionModel:
-#     variants_to_try = [None, "fp16"]
-#     unet = None
-#     for variant in variants_to_try:
-#         try:
-#             unet = UNet2DConditionModel.from_pretrained(model_path, variant=variant, local_files_only=True)
-#         except OSError as e:
-#             if "no file named" in str(e):
-#                 # Ok. We'll try a different variant.
-#                 pass
-#             else:
-#                 raise
-#     if unet is None:
-#         raise RuntimeError(f"Failed to load UNet from '{model_path}'.")
-#     return unet
-
-
 def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
     return {k: v.to(device=device) for k, v in state_dict.items()}
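state_dict_to_device is used in the next hunk to move the weight diffs onto the compute device before the SVD step, and with a CPU device it can move tensors back afterwards. A small illustrative round trip (tensor names and shapes are made up):

    import torch

    # Made-up diff tensors standing in for real (tuned - original) weights.
    diffs = {
        "to_q.weight": torch.randn(320, 320),
        "to_k.weight": torch.randn(320, 320),
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    diffs = state_dict_to_device(diffs, device)               # move to the compute device
    diffs = state_dict_to_device(diffs, torch.device("cpu"))  # and back to the CPU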

@@ -172,6 +162,8 @@ def extract_lora_from_submodel(
     # We just use the device for this calculation, since it's slow, then we move the results back to the CPU.
     logger.info("Calculating LoRA weights with SVD.")
     diffs = state_dict_to_device(diffs, device)
+    # TODO(ryand): Should we skip if the diffs are all zeros? This would happen if two models are identical. This could
+    # happen if some submodels differ while others don't.
     lora_weights = extract_lora_from_diffs(
         diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype
     )
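For background on the step above: LoRA extraction approximates each weight diff (tuned weight minus original weight) with a rank-limited product of two smaller matrices via truncated SVD, clamping extreme values at a quantile. A minimal sketch of that general technique for a single 2-D diff follows; it is not necessarily the exact implementation behind extract_lora_from_diffs, which presumably also handles conv weights and iterates over the whole diffs dict.

    import torch


    def lora_pair_from_diff(diff: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
        """Approximate a 2-D weight diff with a LoRA pair so that up @ down ~= diff."""
        u, s, vh = torch.linalg.svd(diff.float(), full_matrices=False)
        u, s, vh = u[:, :rank], s[:rank], vh[:rank, :]

        # Split the singular values evenly between the two factors.
        u = u * s.sqrt().unsqueeze(0)
        vh = vh * s.sqrt().unsqueeze(1)

        # Clamp outliers at the requested quantile of the absolute values.
        limit = torch.quantile(torch.cat([u.flatten(), vh.flatten()]).abs(), clamp_quantile)
        up = u.clamp(-limit, limit)     # (out_features, rank)
        down = vh.clamp(-limit, limit)  # (rank, in_features)
        return up, down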
@@ -322,9 +314,8 @@ def main():
 
     args = parser.parse_args()
 
-    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
     logger = logging.getLogger()
-    logger.setLevel(logging.DEBUG)
 
     orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig)
     tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned)
