From 6ce723ef0f1db3274340cda730bfb268a73c1929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Tue, 17 Dec 2024 10:28:55 +0800 Subject: [PATCH 01/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_flux.py | 7 ++++--- examples/dreambooth/train_dreambooth_lora_sd3.py | 7 ++++--- examples/dreambooth/train_dreambooth_lora_sdxl.py | 9 +++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index f73269a48967..9d6b9f1ae29e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1242,15 +1242,16 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(transformer))): + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() FluxPipeline.save_lora_weights( output_dir, diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 3f721e56addf..7548ee1ee834 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1292,9 +1292,9 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(transformer))): + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): # or text_encoder_two # both text encoders are of the same class, so we check hidden size to distinguish between the two hidden_size = unwrap_model(model).config.hidden_size if hidden_size == 768: @@ -1305,7 +1305,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() StableDiffusion3Pipeline.save_lora_weights( output_dir, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9cd321f6d055..9d4ea1dae624 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1229,13 +1229,13 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(unet))): + if isinstance(unwrap_model(model), type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): 
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(unwrap_model(text_encoder_two))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -1243,7 +1243,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() StableDiffusionXLPipeline.save_lora_weights( output_dir, From aaf66dfcd881ab090422bc8c216a7e42f67f932c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Wed, 18 Dec 2024 15:00:06 +0800 Subject: [PATCH 02/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_flux.py | 8 +++++--- examples/dreambooth/train_dreambooth_lora_sd3.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 9d6b9f1ae29e..c906ed3f985c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -29,7 +29,7 @@ import torch import torch.utils.checkpoint import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1244,8 +1244,10 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(unwrap_model(model), type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder: text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: + text_encoder_one_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1786,7 +1788,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 7548ee1ee834..be938c259cd5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -29,7 +29,7 @@ import torch import torch.utils.checkpoint import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1293,14 +1293,20 @@ def save_model_hook(models, weights, output_dir): for model in models: if 
isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_model = unwrap_model(model) + if args.upcast_before_saving: + transformer_model.to(torch.float32 ) transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): # or text_encoder_two + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder: # or text_encoder_two # both text encoders are of the same class, so we check hidden size to distinguish between the two hidden_size = unwrap_model(model).config.hidden_size if hidden_size == 768: text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif hidden_size == 1280: text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1830,7 +1836,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: From 5f56eb4baaa828d48bb40bc8e1e1c7387b0dc622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 19 Dec 2024 09:58:44 +0800 Subject: [PATCH 03/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9d4ea1dae624..401704502fa8 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1231,14 +1231,18 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(unwrap_model(model), type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder: text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))) and args.train_text_encoder: text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: + text_encoder_one_lora_layers_to_save = None + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))) and not args.train_text_encoder: + text_encoder_two_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") From 4ede036a053b046ce396602ff7493b71ad65b408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 19 Dec 2024 16:43:14 +0800 Subject: [PATCH 04/10] [bugfix]Fix bug in Lora checkpoint saving step --- 
examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c906ed3f985c..d0f3112427c5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1791,7 +1791,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: + if args.checkpoints_total_limit is not None and accelerator.is_main_process: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index be938c259cd5..eaf3327c075f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1839,7 +1839,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: + if args.checkpoints_total_limit is not None and accelerator.is_main_process: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) From 9c1c2e7f6a856709acb311c4fc4b18b1df4681c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Thu, 19 Dec 2024 16:48:39 +0800 Subject: [PATCH 05/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d0f3112427c5..d7f759121758 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1788,7 +1788,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None and accelerator.is_main_process: diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index eaf3327c075f..07d712e2f1be 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1836,7 +1836,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process or accelerator.distributed_type 
== DistributedType.DEEPSPEED: + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None and accelerator.is_main_process: From 0a4734991964a21076a5c865959bda9e40245055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 23 Dec 2024 11:43:25 +0800 Subject: [PATCH 06/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d7f759121758..11e9c42fb7f6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1571,7 +1571,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: + if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 07d712e2f1be..e58fa77ca7c6 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1646,7 +1646,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: + if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 401704502fa8..1bfd06689581 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1570,7 +1570,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: + if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: From 85b116eda630b8fb1b325b37ea594445590d59b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 23 Dec 2024 17:27:23 +0800 Subject: [PATCH 07/10] [bugfix]Fix bug in Lora checkpoint saving step --- .../dreambooth/train_dreambooth_lora_flux.py | 17 ++-- .../dreambooth/train_dreambooth_lora_sd3.py | 91 ++++++++++--------- .../dreambooth/train_dreambooth_lora_sdxl.py | 15 +-- 3 files changed, 58 insertions(+), 65 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 11e9c42fb7f6..f73269a48967 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -29,7 +29,7 @@ import torch import 
torch.utils.checkpoint import transformers -from accelerate import Accelerator, DistributedType +from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1242,18 +1242,15 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = None for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder: + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: - text_encoder_one_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() + weights.pop() FluxPipeline.save_lora_weights( output_dir, @@ -1571,7 +1568,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: + if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: @@ -1788,10 +1785,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None and accelerator.is_main_process: + if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index e58fa77ca7c6..de2902c1380e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1322,54 +1322,55 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - transformer_ = None - text_encoder_one_ = None - text_encoder_two_ = None - - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_ = model - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + transformer_ = None + text_encoder_one_ = None + text_encoder_two_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, 
type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) + lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ ) - if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ - ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - models = [transformer_] - if args.train_text_encoder: - models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. 
More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1646,7 +1647,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: + if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 1bfd06689581..9cd321f6d055 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1229,26 +1229,21 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder: + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))) and args.train_text_encoder: + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: - text_encoder_one_lora_layers_to_save = None - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))) and not args.train_text_encoder: - text_encoder_two_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() + weights.pop() StableDiffusionXLPipeline.save_lora_weights( output_dir, @@ -1570,7 +1565,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): first_epoch = 0 # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint and not accelerator.distributed_type == DistributedType.DEEPSPEED: + if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: From 48073907a880e5d9c044059d69c9df0b9c0177d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 23 Dec 2024 18:21:32 +0800 Subject: [PATCH 08/10] [bugfix]Fix bug in Lora checkpoint saving step --- .../dreambooth/train_dreambooth_lora_sd3.py | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index de2902c1380e..03353b7762e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ 
b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1328,7 +1328,7 @@ def load_model_hook(models, input_dir): text_encoder_two_ = None while len(models) > 0: - model = models.pop() + model = models.pop( ) if isinstance(model, type(unwrap_model(transformer))): transformer_ = model @@ -1339,29 +1339,6 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) - - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) - - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ - ) - # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 From 1730a824c470462df79ad5ceed360e9a5a8cb431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 23 Dec 2024 18:58:31 +0800 Subject: [PATCH 09/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_sd3.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 03353b7762e1..bbde38f4c6f2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1304,9 +1304,6 @@ def save_model_hook(models, weights, output_dir): text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif hidden_size == 1280: text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and not args.train_text_encoder: - text_encoder_one_lora_layers_to_save = None - text_encoder_two_lora_layers_to_save = None else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1367,8 +1364,6 @@ def load_model_hook(models, input_dir): models = [transformer] if args.train_text_encoder: models.extend([text_encoder_one, text_encoder_two]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) if args.train_text_encoder: From 5103c73eb4aa307dd8a08f691f606bfeb1456aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?= Date: Mon, 23 Dec 2024 19:06:07 +0800 Subject: [PATCH 10/10] [bugfix]Fix bug in Lora checkpoint saving step --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 
bbde38f4c6f2..4ffafde72f04 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1343,8 +1343,6 @@ def load_model_hook(models, input_dir): models = [transformer_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)
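
Taken together, the patches above iterate on one DeepSpeed-aware pattern for the LoRA training hooks: isinstance checks compare unwrapped models (DeepSpeed wraps each prepared module, so the raw `model` no longer matches `type(unwrap_model(transformer))`), `weights.pop()` is guarded because under DeepSpeed the hook may be called with an empty `weights` list, and the checkpointing branch also runs on non-main ranks when `accelerator.distributed_type == DistributedType.DEEPSPEED` so that every rank reaches `accelerator.save_state()`. The sketch below is a minimal, hedged illustration of that pattern, not code taken from the diff: `save_fn` is a hypothetical stand-in for a pipeline's `save_lora_weights`, and `accelerator`, `transformer`, and `text_encoder_one` are assumed to come from an Accelerate/PEFT setup like the one in these training scripts.

# Minimal sketch (assumptions noted above) of the DeepSpeed-aware save hook and
# checkpointing condition that this patch series works toward.
from accelerate import DistributedType
from peft.utils import get_peft_model_state_dict


def make_save_model_hook(accelerator, transformer, text_encoder_one, save_fn):
    def unwrap(model):
        # Strip the Accelerate/DeepSpeed wrapper so type checks see the real module
        # (the training scripts additionally strip the torch.compile wrapper).
        return accelerator.unwrap_model(model)

    def save_model_hook(models, weights, output_dir):
        transformer_lora_layers = None
        text_encoder_lora_layers = None

        for model in models:
            # Compare unwrapped types: under DeepSpeed `model` is a wrapper object,
            # so a plain `isinstance(model, type(unwrap(transformer)))` would be False.
            if isinstance(unwrap(model), type(unwrap(transformer))):
                transformer_lora_layers = get_peft_model_state_dict(unwrap(model))
            elif isinstance(unwrap(model), type(unwrap(text_encoder_one))):
                text_encoder_lora_layers = get_peft_model_state_dict(unwrap(model))
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # DeepSpeed can hand the hook an empty `weights` list; only pop when
            # there is something to pop, so each model is still saved exactly once.
            if weights:
                weights.pop()

        # `save_fn` stands in for e.g. FluxPipeline.save_lora_weights(...).
        save_fn(output_dir, transformer_lora_layers, text_encoder_lora_layers)

    return save_model_hook


def should_checkpoint(accelerator, global_step, checkpointing_steps):
    # Under DeepSpeed every rank has to reach accelerator.save_state(), so the
    # branch is no longer gated on is_main_process alone; pruning of old
    # "checkpoint-*" directories can stay main-process-only, as patch 04 does.
    on_save_rank = (
        accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process
    )
    return on_save_rank and global_step % checkpointing_steps == 0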