[bugfix] Fix bug in LoRA checkpoint saving step #10252

Status: Closed. Wants to merge 14 commits.
84 changes (31 additions, 53 deletions): examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -29,7 +29,7 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
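The only import change is `DistributedType`, which the rest of the patch uses to detect a DeepSpeed launch. A minimal sketch of that check with accelerate (my own illustration, not code from this PR):

```python
# Minimal sketch (not from this PR): branching on the accelerate backend.
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed wraps models in an engine object and treats checkpointing as a
    # collective operation, so save/load hooks cannot assume the plain
    # main-process-only flow used by the other backends.
    print("running under DeepSpeed")
```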
@@ -1292,9 +1292,12 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        transformer_model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))) and args.train_text_encoder:  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
                     hidden_size = unwrap_model(model).config.hidden_size
                     if hidden_size == 768:
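The first change in `save_model_hook` compares `unwrap_model(model)` rather than the raw `model` against the reference types, and, when `--upcast_before_saving` is set, upcasts the LoRA transformer to fp32 before its state dict is extracted. The unwrap matters because the objects handed to the hook can arrive wrapped (DDP, DeepSpeed engine), and an `isinstance` check of a wrapped object against the unwrapped class fails. A toy illustration of that failure mode, using a stand-in `Wrapper`/`unwrap` pair rather than the real accelerate machinery:

```python
# Toy illustration (stand-in Wrapper/unwrap, not the accelerate internals): why both
# sides of the isinstance() check are unwrapped before comparing types.
import torch.nn as nn

class Wrapper(nn.Module):  # plays the role of a DDP/DeepSpeed engine wrapper
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap(m):
    return m.module if isinstance(m, Wrapper) else m

base = nn.Linear(4, 4)
wrapped = Wrapper(base)

print(isinstance(wrapped, type(unwrap(base))))          # False: the wrapper is not a Linear
print(isinstance(unwrap(wrapped), type(unwrap(base))))  # True once both sides are unwrapped
```

The `elif` branch additionally requires `args.train_text_encoder`, so the text encoders are only matched when they are actually being trained.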
@@ -1305,7 +1308,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
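The second `save_model_hook` change wraps `weights.pop()` in an `if weights:` guard. The pop exists so that a model the hook has already serialized is not saved again by accelerate's default path; the guard tolerates the hook being invoked with an empty `weights` list, which presumably is what happens under DeepSpeed and previously raised `IndexError`. A minimal sketch of the before/after behavior:

```python
# Minimal sketch: the unconditional pop() fails on an empty list, the guarded version is a no-op.
weights = []

try:
    weights.pop()                 # old behavior
except IndexError:
    print("IndexError on empty weights list")

if weights:                       # new behavior: skip the pop when nothing was collected
    weights.pop()
```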
@@ -1315,54 +1319,30 @@ def save_model_hook(models, weights, output_dir):
             )

         def load_model_hook(models, input_dir):
-            transformer_ = None
-            text_encoder_one_ = None
-            text_encoder_two_ = None
-
-            while len(models) > 0:
-                model = models.pop()
-
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_ = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    text_encoder_one_ = model
-                elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                    text_encoder_two_ = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-            lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
-
-            transformer_state_dict = {
-                f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
-            }
-            transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
-            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-            if args.train_text_encoder:
-                # Do we need to call `scale_lora_layers()` here?
-                _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
-
-                _set_state_dict_into_text_encoder(
-                    lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
-                )
-
-            # Make sure the trainable params are in float32. This is again needed since the base models
-            # are in `weight_dtype`. More details:
-            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if args.mixed_precision == "fp16":
-                models = [transformer_]
-                if args.train_text_encoder:
-                    models.extend([text_encoder_one_, text_encoder_two_])
-                # only upcast trainable parameters (LoRA) into fp32
-                cast_training_params(models)
+            if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+                transformer_ = None
+                text_encoder_one_ = None
+                text_encoder_two_ = None
+
+                while len(models) > 0:
+                    model = models.pop()
+
+                    if isinstance(model, type(unwrap_model(transformer))):
+                        transformer_ = model
+                    elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                        text_encoder_one_ = model
+                    elif isinstance(model, type(unwrap_model(text_encoder_two))):
+                        text_encoder_two_ = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # Make sure the trainable params are in float32. This is again needed since the base models
+                # are in `weight_dtype`. More details:
+                # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+                if args.mixed_precision == "fp16":
+                    models = [transformer_]
+                    if args.train_text_encoder:
+                        models.extend([text_encoder_one_, text_encoder_two_])

         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
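`load_model_hook` is restructured so that the manual model matching only runs when the backend is not DeepSpeed; in this revision of the PR the explicit `lora_state_dict` loading that the old hook performed is dropped entirely, leaving state restoration under DeepSpeed to `accelerator.load_state` and the engine itself. For orientation, this is when the registered hooks fire during checkpointing (a sketch, not from the PR; `checkpoint-500` is a hypothetical path):

```python
# Sketch (not from this PR) of when the registered hooks run.
from accelerate import Accelerator

accelerator = Accelerator()
# accelerator.register_save_state_pre_hook(save_model_hook)   # as registered above
# accelerator.register_load_state_pre_hook(load_model_hook)

ckpt_dir = "checkpoint-500"               # hypothetical path
accelerator.save_state(ckpt_dir)          # invokes save_model_hook(models, weights, ckpt_dir)
accelerator.load_state(ckpt_dir)          # invokes load_model_hook(models, ckpt_dir)
```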
@@ -1382,8 +1362,6 @@ def load_model_hook(models, input_dir):
         models = [transformer]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)

     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
     if args.train_text_encoder:
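This hunk removes the explicit fp32 upcast of the trainable parameters before the optimizer is built. For reference, the dropped `cast_training_params(models, dtype=torch.float32)` helper from `diffusers.training_utils` does roughly the following (my paraphrase, not the library source): it walks each model and upcasts only parameters with `requires_grad=True`, so the LoRA weights train in fp32 while the frozen base stays in `weight_dtype`.

```python
# Rough paraphrase (not the diffusers source) of what the removed
# cast_training_params(models, dtype=torch.float32) call did.
import torch

def cast_trainable_params(models, dtype=torch.float32):
    for model in models:
        for param in model.parameters():
            if param.requires_grad:              # only the LoRA / trainable parameters
                param.data = param.data.to(dtype)
```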
@@ -1829,10 +1807,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
-                        if args.checkpoints_total_limit is not None:
+                        if args.checkpoints_total_limit is not None and accelerator.is_main_process:
                             checkpoints = os.listdir(args.output_dir)
                             checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                             checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
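This last hunk is the checkpoint-saving fix the title refers to. The saving branch is now also entered when `accelerator.distributed_type == DistributedType.DEEPSPEED`, since DeepSpeed checkpointing is a collective operation that every rank has to participate in, while the `checkpoints_total_limit` rotation gains an explicit `accelerator.is_main_process` check so that only rank 0 deletes old checkpoint directories. The hunk is truncated before the actual save call; the sketch below shows the resulting control flow, reusing the script's own names and assuming the usual `accelerator.save_state(save_path)` call that follows in the file (`prune_old_checkpoints` is a hypothetical stand-in for the rotation code above):

```python
# Control-flow sketch using the script's own names; the save_state call and the
# prune_old_checkpoints() helper are assumptions about the surrounding, truncated code.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
    if global_step % args.checkpointing_steps == 0:
        # only the main process rotates old checkpoints
        if args.checkpoints_total_limit is not None and accelerator.is_main_process:
            prune_old_checkpoints(args.output_dir, args.checkpoints_total_limit)
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        # collective under DeepSpeed: every rank must reach this call
        accelerator.save_state(save_path)
```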