[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill #10259

Open · wants to merge 16 commits into base: main
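For orientation, here is a minimal usage sketch of what this PR enables (the repository IDs below are placeholders, not part of this PR): a LoRA trained against the regular Flux.1-Dev transformer can now be loaded into a Flux Control or Fill pipeline, whose transformer has more input channels, because the loader zero-pads the LoRA weights to match.

import torch
from diffusers import FluxControlPipeline

# Placeholder checkpoint and LoRA IDs; substitute a real Flux Control checkpoint
# and any regular Flux.1-Dev LoRA.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16
).to("cuda")

# The x_embedder of this transformer has more input channels than the one the
# LoRA was trained on; the LoRA's lora_A weight is zero padded during loading.
pipe.load_lora_weights("some-user/flux-dev-style-lora", adapter_name="style")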
126 changes: 87 additions & 39 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1863,6 +1863,9 @@ def load_lora_weights(
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)

if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer(
@@ -2309,7 +2312,6 @@ def _maybe_expand_transformer_param_shape_or_error_(

# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -2329,56 +2331,102 @@
continue

module_out_features, module_in_features = module_weight.shape
if out_features < module_out_features or in_features < module_in_features:
raise NotImplementedError(
f"Only LoRAs with input/output features higher than the current module's input/output features "
f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
debug_message = ""
if in_features > module_in_features:
debug_message += (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)

debug_message = (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
if module_out_features != out_features:
if out_features > module_out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
logger.debug(debug_message)
if debug_message:
logger.debug(debug_message)

if out_features > module_out_features or in_features > module_in_features:
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
tmp_state_dict["bias"] = module_bias
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)

setattr(parent_module, current_module_name, expanded_module)

del tmp_state_dict

if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
return has_param_with_shape_update

# TODO: consider initializing this under meta device for optims.
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."

lora_module_names = [
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
lora_module_names = sorted(set(lora_module_names))
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for k in lora_module_names:
if k in unexpected_modules:
continue

base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

if base_weight_param.shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
expanded_module.bias.data.copy_(module_bias)

setattr(parent_module, current_module_name, expanded_module)

if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
if expanded_module_names:
logger.info(
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
)

return has_param_with_shape_update
return lora_state_dict


# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
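Condensed from the diff above, an illustrative sketch (not the library code) of the module-expansion step in _maybe_expand_transformer_param_shape_or_error_: the base nn.Linear is swapped for a wider one whose weight is zero-initialized and then overwritten with the original weight in its leading slice, so the pre-existing channels behave exactly as before.

import torch

def expand_linear_in_features(module: torch.nn.Linear, new_in_features: int) -> torch.nn.Linear:
    out_features, in_features = module.weight.shape
    if new_in_features < in_features:
        raise NotImplementedError("Only expansion of in_features is supported here.")
    expanded = torch.nn.Linear(
        new_in_features,
        out_features,
        bias=module.bias is not None,
        device=module.weight.device,
        dtype=module.weight.dtype,
    )
    new_weight = torch.zeros_like(expanded.weight.data)
    new_weight[:, :in_features] = module.weight.data  # copy original weights into the leading slice
    expanded.weight.data.copy_(new_weight)
    if module.bias is not None:
        expanded.bias.data.copy_(module.bias.data)  # biases are copied as-is, not expanded
    return expanded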
146 changes: 108 additions & 38 deletions tests/lora/test_lora_layers_flux.py
@@ -340,21 +340,6 @@ def test_lora_parameter_expanded_shapes(self):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because lora input features is less than original. We only
# support expanding the module, not shrinking it
with self.assertRaises(NotImplementedError):
pipe.load_lora_weights(lora_state_dict, "adapter-1")
Comment on lines -343 to -356
@sayakpaul (Member, Author) commented on Dec 17, 2024:
Removing this part of the test because, when the LoRA input feature dimensions are smaller than the original module's, we now expand the LoRA instead of raising an error.

This is tested below with test_lora_expanding_shape_with_normal_lora() and test_load_regular_lora().
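For clarity, a standalone sketch of the padding behavior described above (the helper name is illustrative, not the library API): when the LoRA A matrix has fewer input features than the base layer, it is right-padded with zeros, so the extra input channels contribute nothing to the LoRA update.

import torch

def pad_lora_A_to_base(lora_A: torch.Tensor, base_weight: torch.Tensor) -> torch.Tensor:
    # lora_A: (rank, lora_in_features); base_weight: (out_features, base_in_features)
    rank, lora_in = lora_A.shape
    base_in = base_weight.shape[1]
    if lora_in == base_in:
        return lora_A
    if lora_in > base_in:
        raise NotImplementedError("Shrinking a LoRA to a narrower base layer is not supported.")
    padded = torch.zeros(rank, base_in, device=lora_A.device, dtype=lora_A.dtype)
    padded[:, :lora_in] = lora_A  # the extra input channels get zero LoRA weights
    return padded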


@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -430,10 +415,10 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

def test_lora_expanding_shape_with_normal_lora_raises_error(self):
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
def test_lora_expanding_shape_with_normal_lora(self):
# This test checks that loading a LoRA with expanded shapes (like Control LoRAs) first and then
# another LoRA with the regular shapes works. The opposite direction isn't supported and is
# tested further below.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

# Change the transformer config to mimic a real use case.
@@ -478,21 +463,16 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
# input features before expansion. This should raise an error about the weight shapes being incompatible.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
# We should have `adapter-1` as the only adapter.
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])

# Check if the output is the same after lora loading error
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))

# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
@@ -524,8 +504,8 @@

with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
@@ -535,17 +515,107 @@ def test_lora_expanding_shape_with_normal_lora_raises_error(self):
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}

# We should check for input shapes being incompatible here. But because above mentioned issue is
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
# mismatch error.
# We should check for input shapes being incompatible here.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
"x_embedder.lora_A.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)

def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks that a LoRA with expanded shapes (like Control LoRAs) and another LoRA with
# the regular shapes can be loaded together and then fused. The opposite direction isn't
# supported.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer

pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4

shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))

pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))

def test_load_regular_lora(self):
# This test checks that a regular LoRA (e.g., one trained on Flux.1-Dev) can be loaded into a
# transformer that has more input channels than Flux.1-Dev, such as Flux Fill or Flux Control.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass