From 3d735b4ab4562a0454f2847babe57995e9d04a9d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 11 Dec 2024 16:31:09 +0530 Subject: [PATCH 01/13] lora expansion with dummy zeros. --- src/diffusers/loaders/lora_pipeline.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1445394b8784..d34319dea207 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1863,6 +1863,7 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging." ) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict(transformer=transformer, lora_state_dict=transformer_lora_state_dict) if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( @@ -2373,6 +2374,32 @@ def _maybe_expand_transformer_param_shape_or_error_( return has_param_with_shape_update + @classmethod + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + lora_module_names = set([k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k]) + lora_module_names = sorted(lora_module_names) + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + + for k in lora_module_names: + base_param_name = f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(f'{cls.transformer_name}.', '')}.weight" + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{k}.lora_A.weight"] + # lora_B_param = lora_state_dict[f"{k}.lora_B.weight"] + + if base_weight_param.shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, :lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + + if expanded_module_names: + logger.info(f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new.") + return lora_state_dict + + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. From ed91c533f037bcf523cf6b5940e73d71224c86e3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 11:37:00 +0530 Subject: [PATCH 02/13] updates --- src/diffusers/loaders/lora_pipeline.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 46640daca96f..969c2189666c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1863,7 +1863,9 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging." ) - transformer_lora_state_dict = self._maybe_expand_lora_state_dict(transformer=transformer, lora_state_dict=transformer_lora_state_dict) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict + ) if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( @@ -2385,29 +2387,32 @@ def _maybe_expand_transformer_param_shape_or_error_( def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): expanded_module_names = set() transformer_state_dict = transformer.state_dict() - lora_module_names = set([k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k]) - lora_module_names = sorted(lora_module_names) + lora_module_names = sorted({k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k}) is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: - base_param_name = f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(f'{cls.transformer_name}.', '')}.weight" + base_param_name = ( + f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" + if is_peft_loaded + else f"{k.replace(f'{cls.transformer_name}.', '')}.weight" + ) base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{k}.lora_A.weight"] - # lora_B_param = lora_state_dict[f"{k}.lora_B.weight"] if base_weight_param.shape[1] > lora_A_param.shape[1]: - shape = (lora_A_param.shape[0], base_weight_param.shape[1]) - expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) - expanded_state_dict_weight[:, :lora_A_param.shape[1]].copy_(lora_A_param) + # could be made more advanced with `repeats`. + # have tried zero-padding but that doesn't work, either. + expanded_state_dict_weight = torch.cat([lora_A_param, lora_A_param], dim=1) lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) if expanded_module_names: - logger.info(f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new.") + logger.info( + f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + ) return lora_state_dict - # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): From d3e177c6ffd2153af03df2f382ad3a43f3b0b26d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 13:04:52 +0530 Subject: [PATCH 03/13] =?UTF-8?q?fix=20working=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/loaders/lora_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 969c2189666c..dfa6d3ccb5cf 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2400,9 +2400,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): lora_A_param = lora_state_dict[f"{k}.lora_A.weight"] if base_weight_param.shape[1] > lora_A_param.shape[1]: - # could be made more advanced with `repeats`. - # have tried zero-padding but that doesn't work, either. - expanded_state_dict_weight = torch.cat([lora_A_param, lora_A_param], dim=1) + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) From 258a3980c05970cf62d26ba6b2d064207d9772ef Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 13:59:06 +0530 Subject: [PATCH 04/13] working. --- src/diffusers/loaders/lora_pipeline.py | 83 +++++++++++++------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index dfa6d3ccb5cf..d2cba70055b9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2312,7 +2312,6 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False - for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data @@ -2332,54 +2331,52 @@ def _maybe_expand_transformer_param_shape_or_error_( continue module_out_features, module_in_features = module_weight.shape - if out_features < module_out_features or in_features < module_in_features: - raise NotImplementedError( - f"Only LoRAs with input/output features higher than the current module's input/output features " - f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " - f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " - f"this please open an issue at https://github.com/huggingface/diffusers/issues." + debug_message = "" + if in_features > module_in_features: + debug_message += ( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}" ) - - debug_message = ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if module_out_features != out_features: + if out_features > module_out_features: debug_message += ( ", and the number of output features will be " f"expanded from {module_out_features} to {out_features}." ) else: debug_message += "." - logger.debug(debug_message) + if debug_message: + logger.debug(debug_message) - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) + if out_features > module_out_features or in_features > module_in_features: + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) - # TODO: consider initializing this under meta device for optims. - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. - new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype - ) - slices = tuple(slice(0, dim) for dim in module_weight.shape) - new_weight[slices] = module_weight - expanded_module.weight.data.copy_(new_weight) - if module_bias is not None: - expanded_module.bias.data.copy_(module_bias) - - setattr(parent_module, current_module_name, expanded_module) - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + # TODO: consider initializing this under meta device for optims. + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + expanded_module.weight.data.copy_(new_weight) + if module_bias is not None: + expanded_module.bias.data.copy_(module_bias) + + setattr(parent_module, current_module_name, expanded_module) + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) return has_param_with_shape_update @@ -2405,10 +2402,14 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) + elif base_weight_param.shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + "We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." + ) if expanded_module_names: logger.info( - f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." ) return lora_state_dict From 5ef79f30060702f2e3dd467bc798b26e3c08b1b1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 14:14:58 +0530 Subject: [PATCH 05/13] use torch.device meta for state dict expansion. --- src/diffusers/loaders/lora_pipeline.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index d2cba70055b9..7f72ac4f346f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2353,22 +2353,25 @@ def _maybe_expand_transformer_param_shape_or_error_( parent_module_name, _, current_module_name = name.rpartition(".") parent_module = transformer.get_submodule(parent_module_name) - # TODO: consider initializing this under meta device for optims. - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype - ) + with torch.device("meta"): + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, dtype=module_weight.dtype + ) # Only weights are expanded and biases are not. new_weight = torch.zeros_like( expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) slices = tuple(slice(0, dim) for dim in module_weight.shape) new_weight[slices] = module_weight - expanded_module.weight.data.copy_(new_weight) + tmp_state_dict = {"weight": new_weight} if module_bias is not None: - expanded_module.bias.data.copy_(module_bias) + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) setattr(parent_module, current_module_name, expanded_module) + del tmp_state_dict + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] new_value = int(expanded_module.weight.data.shape[1]) From 4eef79e1c34e82f296452e64b61624cac2bdb14d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 14:59:22 +0530 Subject: [PATCH 06/13] tests Co-authored-by: a-r-r-o-w --- src/diffusers/loaders/lora_pipeline.py | 2 +- tests/lora/test_lora_layers_flux.py | 131 ++++++++++++++++++++----- 2 files changed, 109 insertions(+), 24 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7f72ac4f346f..c831fb0ed4a0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2407,7 +2407,7 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): expanded_module_names.add(k) elif base_weight_param.shape[1] < lora_A_param.shape[1]: raise NotImplementedError( - "We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." ) if expanded_module_names: diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b28fdde91574..dc08a3c07e79 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -430,10 +430,10 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - def test_lora_expanding_shape_with_normal_lora_raises_error(self): - # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but - # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error. - # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180 + def test_lora_expanding_shape_with_normal_lora(self): + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. The opposite direction isn't supported and is + # tested with it. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. @@ -478,21 +478,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct - # input features before expansion. This should raise an error about the weight shapes being incompatible. - self.assertRaisesRegex( - RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", - pipe.load_lora_weights, - lora_state_dict, - "adapter-2", - ) - # We should have `adapter-1` as the only adapter. - self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-2") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - # Check if the output is the same after lora loading error - lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3)) + self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out) + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the @@ -524,8 +519,8 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) self.assertTrue(pipe.transformer.config.in_channels == in_features) self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) @@ -535,17 +530,107 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } - # We should check for input shapes being incompatible here. But because above mentioned issue is - # not a supported use case, and because of the PEFT renaming, we will currently have a shape - # mismatch error. + # We should check for input shapes being incompatible here. self.assertRaisesRegex( RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", + "x_embedder.lora_A.weight", pipe.load_lora_weights, lora_state_dict, "adapter-2", ) + def test_fuse_expanded_lora_with_regular_lora(self): + # This test checks if it works when a lora with expanded shapes (like control loras) but + # another lora with correct shapes is loaded. The opposite direction isn't supported and is + # tested with it. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + pipe.load_lora_weights(lora_state_dict, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) + lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) + + pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) + lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) + + def test_load_regular_lora(self): + # This test checks if a regular lora (think of one trained Flux.1 Dev for example) can be loaded + # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those + # transformers include Flux Fill, Flux Control, etc. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) + self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From 3785dfea6e5447ecfb8fe157f609a626ff19c8ab Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 16:57:08 +0530 Subject: [PATCH 07/13] fixes --- src/diffusers/loaders/lora_pipeline.py | 19 ++++++++++++++++--- tests/lora/test_lora_layers_flux.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index dac050a69c97..97cea117e9dd 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2387,17 +2387,30 @@ def _maybe_expand_transformer_param_shape_or_error_( def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): expanded_module_names = set() transformer_state_dict = transformer.state_dict() - lora_module_names = sorted({k.replace(".lora_A.weight", "") for k in lora_state_dict if "lora_A" in k}) - is_peft_loaded = getattr(transformer, "peft_config", None) is not None + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.info(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: + if k in unexpected_modules: + continue + base_param_name = ( f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(f'{cls.transformer_name}.', '')}.weight" ) base_weight_param = transformer_state_dict[base_param_name] - lora_A_param = lora_state_dict[f"{k}.lora_A.weight"] + lora_A_param = lora_state_dict[f"{cls.transformer_name}.{k}.lora_A.weight"] if base_weight_param.shape[1] > lora_A_param.shape[1]: shape = (lora_A_param.shape[0], base_weight_param.shape[1]) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index dc08a3c07e79..2cb7394cfad1 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -352,7 +352,7 @@ def test_lora_parameter_expanded_shapes(self): } # We should error out because lora input features is less than original. We only # support expanding the module, not shrinking it - with self.assertRaises(NotImplementedError): + with self.assertRaises(RuntimeError): pipe.load_lora_weights(lora_state_dict, "adapter-1") @require_peft_version_greater("0.13.2") From b7269f415f93caeaacf13636bd7b91c0a5f73af3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 17:23:19 +0530 Subject: [PATCH 08/13] fixes --- src/diffusers/loaders/lora_pipeline.py | 11 +++++------ tests/lora/test_lora_layers_flux.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 97cea117e9dd..4175c29c21c2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2405,18 +2405,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): continue base_param_name = ( - f"{k.replace(f'{cls.transformer_name}.', '')}.base_layer.weight" - if is_peft_loaded - else f"{k.replace(f'{cls.transformer_name}.', '')}.weight" + f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight" ) base_weight_param = transformer_state_dict[base_param_name] - lora_A_param = lora_state_dict[f"{cls.transformer_name}.{k}.lora_A.weight"] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] if base_weight_param.shape[1] > lora_A_param.shape[1]: shape = (lora_A_param.shape[0], base_weight_param.shape[1]) expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) - lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) elif base_weight_param.shape[1] < lora_A_param.shape[1]: raise NotImplementedError( @@ -2425,8 +2423,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): if expanded_module_names: logger.info( - f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." ) + return lora_state_dict diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 2cb7394cfad1..30fd320a2e1e 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -627,7 +627,7 @@ def test_load_regular_lora(self): lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out) + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) From 6da9697f503ef868e6f89df4dbbddb5e25f8f394 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 17:26:36 +0530 Subject: [PATCH 09/13] switch to debug --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4175c29c21c2..d9d6590059ed 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2397,7 +2397,7 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) unexpected_modules = set(lora_module_names) - set(transformer_module_names) if unexpected_modules: - logger.info(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: From b9a2670d1fb15807bffd3c6f65879d74694f485d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 18:31:33 +0530 Subject: [PATCH 10/13] fix --- tests/lora/test_lora_layers_flux.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 30fd320a2e1e..815c688dfc13 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -340,21 +340,6 @@ def test_lora_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - dummy_lora_A = torch.nn.Linear(1, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - # We should error out because lora input features is less than original. We only - # support expanding the module, not shrinking it - with self.assertRaises(RuntimeError): - pipe.load_lora_weights(lora_state_dict, "adapter-1") - @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -486,7 +471,7 @@ def test_lora_expanding_shape_with_normal_lora(self): lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue("Found some LoRA modules for which the weights were zero-padded" in cap_logger.out) + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. From eb2ad022253aa7cad6dfeff109fd9f0be32b129a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Dec 2024 11:15:38 +0530 Subject: [PATCH 11/13] Apply suggestions from code review Co-authored-by: Aryan --- src/diffusers/loaders/lora_pipeline.py | 5 ++++- tests/lora/test_lora_layers_flux.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e71152eb1c29..74f8e11f11ea 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2357,7 +2357,10 @@ def _maybe_expand_transformer_param_shape_or_error_( expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, dtype=module_weight.dtype ) - # Only weights are expanded and biases are not. + # Only weights are expanded and biases are not. This is because only the input dimensions + # are changed while the output dimensions remain the same. The shape of the weight tensor + # is (out_features, in_features), while the shape of bias tensor is (out_features,), which + # explains the reason why only weights are expanded. new_weight = torch.zeros_like( expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 815c688dfc13..86c9b99d9cfc 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -583,7 +583,7 @@ def test_fuse_expanded_lora_with_regular_lora(self): self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) def test_load_regular_lora(self): - # This test checks if a regular lora (think of one trained Flux.1 Dev for example) can be loaded + # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those # transformers include Flux Fill, Flux Control, etc. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) From 143df0c42db5d7951246240bf3cf7a93814847a3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Dec 2024 12:51:15 +0530 Subject: [PATCH 12/13] fix stuff --- src/diffusers/loaders/lora_pipeline.py | 8 +++++--- tests/lora/test_lora_layers_flux.py | 19 +++++++------------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 74f8e11f11ea..3ed29a905532 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2312,15 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data module_bias = module.bias.data if module.bias is not None else None bias = module_bias is not None - lora_A_weight_name = f"{name}.lora_A.weight" - lora_B_weight_name = f"{name}.lora_B.weight" - if lora_A_weight_name not in state_dict.keys(): + lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name + lora_A_weight_name = f"{lora_base_name}.lora_A.weight" + lora_B_weight_name = f"{lora_base_name}.lora_B.weight" + if lora_A_weight_name not in state_dict: continue in_features = state_dict[lora_A_weight_name].shape[1] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 86c9b99d9cfc..1378c048b868 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -467,18 +467,14 @@ def test_lora_expanding_shape_with_normal_lora(self): pipe.load_lora_weights(lora_state_dict, "adapter-2") self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. - # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the - # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora - # weight is compatible with the current model inadequate. This should be addressed when attempting support for - # https://github.com/huggingface/diffusers/issues/10180 (TODO) + # This should raise a runtime error on input shapes being incompatible. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. num_channels_without_control = 4 @@ -501,24 +497,23 @@ def test_lora_expanding_shape_with_normal_lora(self): "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") + pipe.load_lora_weights(lora_state_dict, "adapter-1") self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) self.assertTrue(pipe.transformer.config.in_channels == in_features) - self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } - # We should check for input shapes being incompatible here. + # We should check for input shapes being incompatible here. But because above mentioned issue is + # not a supported use case, and because of the PEFT renaming, we will currently have a shape + # mismatch error. self.assertRaisesRegex( RuntimeError, - "x_embedder.lora_A.weight", + "size mismatch for x_embedder.lora_A.adapter-2.weight", pipe.load_lora_weights, lora_state_dict, "adapter-2", From d62125ad7fadb419516f0cb4e2bca10ef5b4a576 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Dec 2024 12:55:06 +0530 Subject: [PATCH 13/13] docs --- docs/source/en/api/pipelines/flux.md | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index af9c3639e047..080442efb0d1 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -268,6 +268,43 @@ images = pipe( images[0].save("flux-redux.png") ``` +## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux + +We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD). + +```py +from diffusers import FluxControlPipeline +from image_gen_aux import DepthPreprocessor +from diffusers.utils import load_image +from huggingface_hub import hf_hub_download +import torch + +control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) +control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth") +control_pipe.load_lora_weights( + hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" +) +control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125]) +control_pipe.enable_model_cpu_offload() + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = control_pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=8, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + ## Running FP16 inference Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.