
Test error raised when loading normal and expanding loras together in Flux #10188

Merged

merged 11 commits into main from more-flux-lora-tests on Dec 15, 2024

Conversation

a-r-r-o-w
Member

Context: #10182 (comment)

@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul December 11, 2024 08:51
lora_state_dict,
"adapter-2",
)

Member

Here I would run another inference round and make sure the outputs match the ones from the LoRA that was correctly loaded. This will help us check that the loading error didn't leave the pipeline in a broken state, which is important.

Comment on lines 479 to 483
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
# weight is compatible with the current model incorrect. This should be addressed when attempting support for
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
Member

Could you provide some concrete LoCs as references for what you mean by:

This makes our logic to check if a lora weight is compatible with the current model incorrect.

Would also love to understand how this relates to how peft names things.

Member Author

@a-r-r-o-w a-r-r-o-w Dec 11, 2024

The lines in question are:

lora_A_weight_name = f"{name}.lora_A.weight"
lora_B_weight_name = f"{name}.lora_B.weight"
if lora_A_weight_name not in state_dict.keys():
    continue

When the first lora layer is loaded, assuming it is named adapter-1 and the layer in question is x_embedder, the nn.Linear layer names are something like [x_embedder]. This check passes because x_embedder.lora_A.weight is indeed a key in the lora state dict.

After the first lora is loaded, peft updates the layer names to something like: [x_embedder.base_layer, x_embedder.adapter-1.lora_A, x_embedder.adapter-1.lora_B].

So, when the second lora is loaded, it tries to find x_embedder.base_layer.lora_A.weight in the lora state dict, which does not exist. It needs to instead search for x_embedder.lora_A.weight, because that is the correct key in the lora state dict. But this won't happen, because the model state dict now contains the original x_embedder linear layer under a renamed key.
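
To illustrate (a minimal sketch with a toy model, not code from the PR), the module names before and after PEFT injects an adapter look roughly like this, which is what breaks the f"{name}.lora_A.weight" lookup above:

import torch
from peft import LoraConfig, inject_adapter_in_model

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = torch.nn.Linear(16, 32)

model = ToyModel()
print([n for n, _ in model.named_modules() if "x_embedder" in n])
# Before injection: ['x_embedder']

config = LoraConfig(r=4, target_modules=["x_embedder"])
inject_adapter_in_model(config, model, adapter_name="adapter-1")
print([n for n, _ in model.named_modules() if "x_embedder" in n])
# After injection, the original Linear is wrapped: the base weights move under
# 'x_embedder.base_layer' and the LoRA weights live in adapter-keyed submodules,
# e.g. 'x_embedder.lora_A.adapter-1' (exact names depend on the PEFT version).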

Note that I don't recall the exact layer names, so they may differ when you test; I'm just giving an example. The rough idea is that the current logic only works for loading:

  • one or more "normal" loras
  • a single "shape expansion" lora

For cases where we load shape expansion lora followed by normal lora, or vice versa, it will always fail currently.

But as discussed in DM, this was not an anticipated use case; we only wanted to make control lora work as expected, so raising the shape mismatch error when loading weights (instead of during inference, where the input shapes wouldn't match) is OK for now.

Member

Hmm got it. Thanks Aryan. Just noting this is enough for now.

@a-r-r-o-w
Member Author

Hmm, not quite sure why the lora outputs after the error are NaNs...

    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        num_channels_without_control = 4
        transformer = FluxTransformer2DModel.from_config(
            components["transformer"].config, in_channels=num_channels_without_control
        ).to(torch_device)
        original_transformer_state_dict = pipe.transformer.state_dict()
        x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
        transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
        pipe.transformer = transformer

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
        rank = 4

        shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
        shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
        }
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
        normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }

        # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
        # input features before expansion. This should raise an error about the weight shapes being incompatible.
        self.assertRaisesRegex(
            RuntimeError,
            "size mismatch for x_embedder.lora_A.adapter-2.weight",
            pipe.load_lora_weights,
            lora_state_dict,
            "adapter-2",
        )

        # Check if the output is the same after lora loading error
        lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3)
        )

        # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
        # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
        # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
        # weight is compatible with the current model inadequate. This should be addressed when attempting support for
        # https://github.com/huggingface/diffusers/issues/10180 (TODO)
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
        rank = 4

        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }

        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
        self.assertTrue(pipe.transformer.config.in_channels == in_features)
        self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
        }

        # We should check for input shapes being incompatible here. But because above mentioned issue is
        # not a supported use case, and because of the PEFT renaming, we will currently have a shape
        # mismatch error.
        self.assertRaisesRegex(
            RuntimeError,
            "size mismatch for x_embedder.lora_A.adapter-2.weight",
            pipe.load_lora_weights,
            lora_state_dict,
            "adapter-2",
        )

@sayakpaul
Member

@a-r-r-o-w could you expand on what is the issue?

@a-r-r-o-w
Member Author

could you expand on what is the issue?

Yes. So you had asked me to compare the outputs before and after the second lora is loaded (which errors out because of the weight shape mismatch).

Here I would run another inference round and make sure the outputs match with the LoRA that was correctly loaded. This will help us check if this loading error didn't leave the pipeline in a broken state, which is important.

It seems like the pipeline is left in a permanently broken state. After loading the first lora (the shape expander), inference runs correctly. After trying to load the second lora, which errors out, we run inference again to verify that the outputs don't change. But here they do change, and are all NaNs. I haven't found time to investigate why yet.

@sayakpaul
Member

Thanks! Let me look into this :)

@sayakpaul
Member

@a-r-r-o-w WDYT about the changes in 7b5037f?

@BenjaminBossan WDYT about the changes introduced on the PEFT front to handle the following scenario? We first inject the adapter config into the requested model and then we set the PEFT model state dict:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/peft.py#L319-L320

Now, if the second step is unsuccessful, I think we should remove the corresponding layers we had injected and raise a good error message. LMK what you think.
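
A rough sketch of that inject-then-roll-back idea (not the exact diffusers implementation; `transformer`, `lora_config`, `state_dict`, and `adapter_name` are assumed to be defined by the caller):

from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

# Step 1: create the (empty) adapter layers inside the model.
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)

try:
    # Step 2: copy the LoRA weights from the checkpoint into the injected layers.
    set_peft_model_state_dict(transformer, state_dict, adapter_name=adapter_name)
except Exception:
    # Roll back step 1 so the model is not left with half-initialized adapter
    # layers, then surface the original error.
    for module in transformer.modules():
        if isinstance(module, BaseTunerLayer):
            module.delete_adapter(adapter_name)
    raise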

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +2340 to +2352
debug_message = (
    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
    f"checkpoint contains higher number of features than expected. The number of input_features will be "
    f"expanded from {module_in_features} to {in_features}"
)
if module_out_features != out_features:
    debug_message += (
        ", and the number of output features will be "
        f"expanded from {module_out_features} to {out_features}."
    )
else:
    debug_message += "."
logger.debug(debug_message)
Member

Better crafting of the debug message I guess?

@BenjaminBossan
Member

WDYT about the changes introduced on the PEFT front to handle the following scenario? We first inject the adapter config into the requested model and then we set the PEFT model state dict: https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/peft.py#L319-L320

Now, if the second step is unsuccessful, I think we should remove the corresponding layers we had injected and raise a good error message. LMK what you think.

To be sure I understand your suggestion correctly, do you want inject_adapter_in_model to catch potential errors during injection, and on error, try to roll back the changes and re-raise?

@sayakpaul
Member

Yeah exactly.

@BenjaminBossan
Member

BenjaminBossan commented Dec 13, 2024

Yes, I think it's reasonable, I'll put it on the backlog. Offhand, I'm not sure how easy it is to implement. How high a priority is it?

@sayakpaul
Member

Oh sorry, I think I conveyed it wrong. I am asking what you think of the changes I have already introduced here. I think it's okay for me if the changes remain in diffusers.

@BenjaminBossan
Member

Oh sorry, I think I conveyed it wrong. I am asking what you think of the changes I have already introduced here. I think it's okay for me if the changes remain in diffusers.

Ah okay, I got you now. So this is more about this part of the code, right?

https://github.com/huggingface/diffusers/pull/10188/files#diff-fdc2e14a8c1091d917cf31ca155be9baa8095d704bd2a2f58970ec31641c0caeR325-R335

I think it serves the purpose well enough. It is not a 100% rollback, as PEFT could do other things (mutate the PEFT config and toggle requires_grad on the base model), but I think those won't matter in this context.

Still, I'll leave the task in the backlog, as it would be useful to have in PEFT too (code would be similar to the code here).

@sayakpaul
Member

Alright then. Could you provide an approval then?

@sayakpaul
Member

@a-r-r-o-w WDYT about the changes? Will let you merge

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks, LGTM.

Alright then. Could you provide an approval then?

I was not asked for review ;)

Member Author

@a-r-r-o-w a-r-r-o-w left a comment

lgtm!

@a-r-r-o-w a-r-r-o-w merged commit 22c4f07 into main Dec 15, 2024
15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the more-flux-lora-tests branch December 15, 2024 16:16
sayakpaul added a commit that referenced this pull request Dec 23, 2024
… Flux (#10188)

* add test for expanding lora and normal lora error

* Update tests/lora/test_lora_layers_flux.py

* fix things.

* Update src/diffusers/loaders/peft.py

---------

Co-authored-by: Sayak Paul <[email protected]>