diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index af9c3639e047..080442efb0d1 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -268,6 +268,43 @@ images = pipe( images[0].save("flux-redux.png") ``` +## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux + +We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to do this with the Flux Control LoRA for depth and a turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD). + +```py +from diffusers import FluxControlPipeline +from image_gen_aux import DepthPreprocessor +from diffusers.utils import load_image +from huggingface_hub import hf_hub_download +import torch + +control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) +control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth") +control_pipe.load_lora_weights( + hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" +) +control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125]) +control_pipe.enable_model_cpu_offload() + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = control_pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=8, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` +
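+The same recipe should carry over to the other pipelines mentioned above. Below is a minimal, untested sketch for [`FluxFillPipeline`] with the same Hyper-SD turbo LoRA; the mask construction, prompt, adapter weight, and `guidance_scale` are illustrative assumptions and will likely need tuning.
+
+```py
+import torch
+from PIL import Image, ImageDraw
+from diffusers import FluxFillPipeline
+from diffusers.utils import load_image
+from huggingface_hub import hf_hub_download
+
+fill_pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
+fill_pipe.load_lora_weights(
+    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+)
+fill_pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+fill_pipe.enable_model_cpu_offload()
+
+# Reuse the image from the example above and mask out its central region (purely illustrative).
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+mask = Image.new("L", image.size, 0)
+w, h = image.size
+ImageDraw.Draw(mask).rectangle([w // 4, h // 4, 3 * w // 4, 3 * h // 4], fill=255)
+
+image = fill_pipe(
+    prompt="a robot made of exotic candies and chocolates",
+    image=image,
+    mask_image=mask,
+    num_inference_steps=8,
+    guidance_scale=30.0,
+    generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("flux-fill-turbo.png")
+```
+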
## Running FP16 inference Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 46d744233014..e69681611a4a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1863,6 +1863,9 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging."
) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict + ) if len(transformer_lora_state_dict) > 0: self.load_lora_into_transformer( @@ -2309,16 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False - + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data module_bias = module.bias.data if module.bias is not None else None bias = module_bias is not None - lora_A_weight_name = f"{name}.lora_A.weight" - lora_B_weight_name = f"{name}.lora_B.weight" - if lora_A_weight_name not in state_dict.keys(): + lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name + lora_A_weight_name = f"{lora_base_name}.lora_A.weight" + lora_B_weight_name = f"{lora_base_name}.lora_B.weight" + if lora_A_weight_name not in state_dict: continue in_features = state_dict[lora_A_weight_name].shape[1] @@ -2329,56 +2333,105 @@ def _maybe_expand_transformer_param_shape_or_error_( continue module_out_features, module_in_features = module_weight.shape - if out_features < module_out_features or in_features < module_in_features: - raise NotImplementedError( - f"Only LoRAs with input/output features higher than the current module's input/output features " - f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which " - f"are lower than {module_in_features=} and {module_out_features=}. If you require support for " - f"this please open an issue at https://github.com/huggingface/diffusers/issues." + debug_message = "" + if in_features > module_in_features: + debug_message += ( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}" ) - - debug_message = ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if module_out_features != out_features: + if out_features > module_out_features: debug_message += ( ", and the number of output features will be " f"expanded from {module_out_features} to {out_features}." ) else: debug_message += "." - logger.debug(debug_message) + if debug_message: + logger.debug(debug_message) + + if out_features > module_out_features or in_features > module_in_features: + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + with torch.device("meta"): + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. This is because only the input dimensions + # are changed while the output dimensions remain the same. The shape of the weight tensor + # is (out_features, in_features), while the shape of the bias tensor is (out_features,), which + # explains the reason why only weights are expanded.
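+ # For example, with Flux.1-Dev: loading a Control LoRA whose lora_A weight has in_features=128 into the
+ # base x_embedder (weight shape (3072, 64)) swaps in a Linear(128, 3072) whose weight starts as zeros;
+ # the original (3072, 64) values are copied into the first 64 input columns below, so the expanded layer
+ # behaves identically to the original one on the original input features.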
+ new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + tmp_state_dict = {"weight": new_weight} + if module_bias is not None: + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) + + setattr(parent_module, current_module_name, expanded_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) + return has_param_with_shape_update - # TODO: consider initializing this under meta device for optims. - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. - new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + @classmethod + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + for k in lora_module_names: + if k in unexpected_modules: + continue + + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight" + ) + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] + + if base_weight_param.shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + elif base_weight_param.shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." 
) - slices = tuple(slice(0, dim) for dim in module_weight.shape) - new_weight[slices] = module_weight - expanded_module.weight.data.copy_(new_weight) - if module_bias is not None: - expanded_module.bias.data.copy_(module_bias) - - setattr(parent_module, current_module_name, expanded_module) - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + if expanded_module_names: + logger.info( + f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + ) - return has_param_with_shape_update + return lora_state_dict # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b28fdde91574..1378c048b868 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -340,21 +340,6 @@ def test_lora_parameter_expanded_shapes(self): self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - dummy_lora_A = torch.nn.Linear(1, rank, bias=False) - dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) - lora_state_dict = { - "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, - "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, - } - # We should error out because lora input features is less than original. We only - # support expanding the module, not shrinking it - with self.assertRaises(NotImplementedError): - pipe.load_lora_weights(lora_state_dict, "adapter-1") - @require_peft_version_greater("0.13.2") def test_lora_B_bias(self): components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -430,10 +415,10 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - def test_lora_expanding_shape_with_normal_lora_raises_error(self): - # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but - # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error. - # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180 + def test_lora_expanding_shape_with_normal_lora(self): + # This test checks that loading a LoRA with expanded shapes (like Control LoRAs) followed by + # another LoRA with regular shapes works. The opposite direction isn't supported and is + # covered later in this test. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case.
@@ -478,27 +463,18 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct - # input features before expansion. This should raise an error about the weight shapes being incompatible. - self.assertRaisesRegex( - RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", - pipe.load_lora_weights, - lora_state_dict, - "adapter-2", - ) - # We should have `adapter-1` as the only adapter. - self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-2") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) - # Check if the output is the same after lora loading error - lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3)) + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. - # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the - # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora - # weight is compatible with the current model inadequate. This should be addressed when attempting support for - # https://github.com/huggingface/diffusers/issues/10180 (TODO) + # This should raise a runtime error on input shapes being incompatible. components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) # Change the transformer config to mimic a real use case. num_channels_without_control = 4 @@ -521,14 +497,11 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } + pipe.load_lora_weights(lora_state_dict, "adapter-1") - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) self.assertTrue(pipe.transformer.config.in_channels == in_features) - self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, @@ -546,6 +519,98 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self): "adapter-2", ) + def test_fuse_expanded_lora_with_regular_lora(self): + # This test checks that fusing works when a LoRA with expanded shapes (like Control LoRAs) is + # loaded together with another LoRA with regular shapes, and that the fused output matches the + # unfused output of the two adapters combined.
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + components["transformer"] = transformer + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + + shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, + "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, + } + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + pipe.load_lora_weights(lora_state_dict, "adapter-2") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) + lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) + self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) + + pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) + lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) + + def test_load_regular_lora(self): + # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded + # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those + # transformers include Flux Fill, Flux Control, etc. + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + out_features, in_features = pipe.transformer.x_embedder.weight.shape + rank = 4 + in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. 
+ normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) + normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, + "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, + } + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.INFO) + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) + self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass