diff --git a/README.md b/README.md
index feb2f9a..8fd15b0 100644
--- a/README.md
+++ b/README.md
@@ -370,11 +370,26 @@ mflux-generate \
 
 Just to see the difference, this image displays the four cases: one of having both adapters fully active, partially active, and no LoRA at all. The example above also shows the usage of the `--lora-scales` flag.
 
+#### Supported LoRA formats (updated)
+
+Since different fine-tuning services can use different implementations of FLUX, the corresponding
+LoRA weights trained on these services can differ from one another. The aim of MFLUX is to support the most common ones.
+The following table shows the currently supported formats:
+
+| Supported | Name      | Example                                                                                                   | Notes                               |
+|-----------|-----------|-----------------------------------------------------------------------------------------------------------|-------------------------------------|
+| ✅        | BFL       | [civitai - Impressionism](https://civitai.com/models/545264/impressionism-sdxl-pony-flux)                  | Many things on civitai seem to work |
+| ✅        | Diffusers | [Flux_1_Dev_LoRA_Paper-Cutout-Style](https://huggingface.co/Norod78/Flux_1_Dev_LoRA_Paper-Cutout-Style/)   |                                     |
+| ❌        | XLabs-AI  | [flux-RealismLora](https://huggingface.co/XLabs-AI/flux-RealismLora/tree/main)                             |                                     |
+
+To report additional formats, examples, or any other suggestions related to LoRA format support, please see [issue #47](https://github.com/filipstrand/mflux/issues/47).
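+
+As a minimal sketch of applying a LoRA in one of these formats (the file name below is hypothetical, and the
+flags are assumed to match the `--lora-paths` / `--lora-scales` usage shown in the example above), a
+BFL-format file can be passed directly; it is detected and converted to the Diffusers layout at load time:
+
+```sh
+mflux-generate \
+    --prompt "an impressionist painting of a fishing village at dawn" \
+    --steps 20 \
+    --seed 43 \
+    --lora-paths "impressionism_bfl_format.safetensors" \
+    --lora-scales 1.0
+```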
+
 ### Current limitations
 
 - Images are generated one by one.
 - Negative prompts not supported.
 - LoRA weights are only supported for the transformer part of the network.
+- Some LoRA adapters do not work.
 
 ### TODO

diff --git a/src/mflux/weights/lora_converter.py b/src/mflux/weights/lora_converter.py
new file mode 100644
index 0000000..8e8a7d2
--- /dev/null
+++ b/src/mflux/weights/lora_converter.py
@@ -0,0 +1,239 @@
+import logging
+
+import mlx.core as mx
+import torch
+from mlx.utils import tree_unflatten
+from safetensors import safe_open
+
+logger = logging.getLogger(__name__)
+
+
+# This script is based on `convert_flux_lora.py` from `kohya-ss/sd-scripts`.
+# For more info, see: https://github.com/kohya-ss/sd-scripts/blob/sd3/networks/convert_flux_lora.py
+
+
+class LoRAConverter:
+
+    @staticmethod
+    def load_weights(lora_path: str) -> dict:
+        state_dict = LoRAConverter._load_pytorch_weights(lora_path)
+        state_dict = LoRAConverter._convert_weights_to_diffusers(state_dict)
+        state_dict = LoRAConverter._convert_to_mlx(state_dict)
+        state_dict = list(state_dict.items())
+        state_dict = tree_unflatten(state_dict)
+        return state_dict
+
+    @staticmethod
+    def _load_pytorch_weights(lora_path: str) -> dict:
+        state_dict = {}
+        with safe_open(lora_path, framework="pt") as f:
+            metadata = f.metadata()
+            for k in f.keys():
+                state_dict[k] = f.get_tensor(k)
+        return state_dict
+
+    @staticmethod
+    def _convert_weights_to_diffusers(source: dict) -> dict:
+        target = {}
+        for i in range(19):
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_img_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_out.0"
+            )
+            LoRAConverter._convert_to_diffusers_cat(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_img_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.to_q",
+                    f"transformer.transformer_blocks.{i}.attn.to_k",
+                    f"transformer.transformer_blocks.{i}.attn.to_v",
+                ],
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_img_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff.net.0.proj"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_img_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff.net.2"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_img_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1.linear"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_txt_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_add_out"
+            )
+            LoRAConverter._convert_to_diffusers_cat(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+                ],
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_txt_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_txt_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff_context.net.2"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_double_blocks_{i}_txt_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1_context.linear"
+            )
+
+        for i in range(38):
+            LoRAConverter._convert_to_diffusers_cat(
+                source,
+                target,
+                f"lora_unet_single_blocks_{i}_linear1",
+                [
+                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                ],
+                dims=[3072, 3072, 3072, 12288],
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_single_blocks_{i}_linear2",
+                f"transformer.single_transformer_blocks.{i}.proj_out"
+            )
+            LoRAConverter._convert_to_diffusers(
+                source,
+                target,
+                f"lora_unet_single_blocks_{i}_modulation_lin",
+                f"transformer.single_transformer_blocks.{i}.norm.linear"
+            )
+
+        if len(source) > 0:
+            logger.warning(f"Unsupported keys for diffusers: {source.keys()}")
+        return target
+
+    @staticmethod
+    def _convert_to_diffusers(
+        source: dict,
+        target: dict,
+        source_key: str,
+        target_key: str
+    ):
+        if source_key + ".lora_down.weight" not in source:
+            return
+        down_weight = source.pop(source_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        alpha = source.pop(source_key + ".alpha").item()  # alpha is scalar
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in the forward pass, so we need to scale it back here
+
+        # split the total scale between the down and up weights while keeping their product equal to `scale`
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        target[target_key + ".lora_A.weight"] = down_weight * scale_down
+        target[target_key + ".lora_B.weight"] = source.pop(source_key + ".lora_up.weight") * scale_up
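+
+    # Worked example of the rescaling above (illustrative numbers, not from the upstream script):
+    # for a LoRA trained with alpha=1 and rank=16, scale = 1/16 = 0.0625. The while-loop then
+    # rebalances the factor across the two matrices: (0.0625, 1.0) -> (0.125, 0.5) -> (0.25, 0.25),
+    # so both lora_A and lora_B end up scaled by 0.25. Their product is still 0.0625, which keeps
+    # (lora_B @ lora_A) numerically identical to (lora_up @ lora_down) * alpha / rank.
+    #
+    # The `_cat` variant below handles fused projections: for example, the BFL key
+    # `lora_unet_double_blocks_0_img_attn_qkv` holds Q, K and V stacked along dim 0 of the
+    # up-weight, and is split into the three Diffusers keys `...attn.to_q`, `...attn.to_k`
+    # and `...attn.to_v`.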
+
+    @staticmethod
+    def _convert_to_diffusers_cat(
+        source: dict,
+        target: dict,
+        source_key: str,
+        target_keys: list[str],
+        dims=None
+    ):
+        if source_key + ".lora_down.weight" not in source:
+            return
+        down_weight = source.pop(source_key + ".lora_down.weight")
+        up_weight = source.pop(source_key + ".lora_up.weight")
+        source_lora_rank = down_weight.shape[0]
+
+        # scale weight by alpha and dim
+        alpha = source.pop(source_key + ".alpha")
+        scale = alpha / source_lora_rank
+
+        # calculate scale_down and scale_up
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        # calculate dims if not provided
+        num_splits = len(target_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check whether the up-weight is sparse (block-diagonal) or not
+        is_sparse = False
+        if source_lora_rank % num_splits == 0:
+            diffusers_rank = source_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i: i + dims[j], k * diffusers_rank: (k + 1) * diffusers_rank] == 0)
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {source_key}")
+
+        # make diffusers weight
+        diffusers_down_keys = [k + ".lora_A.weight" for k in target_keys]
+        diffusers_up_keys = [k + ".lora_B.weight" for k in target_keys]
+        if not is_sparse:
+            # down_weight is copied to each split
+            target.update({k: down_weight for k in diffusers_down_keys})
+
+            # up_weight is split to each split
+            target.update({k: v for k, v in zip(diffusers_up_keys, torch.split(up_weight, dims, dim=0))})
+        else:
+            # down_weight is chunked to each split
+            target.update({k: v for k, v in zip(diffusers_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            i = 0
+            for j in range(len(dims)):
+                target[diffusers_up_keys[j]] = up_weight[i: i + dims[j], j * diffusers_rank: (j + 1) * diffusers_rank].contiguous()
+                i += dims[j]
+
+    @staticmethod
+    def _convert_to_mlx(torch_dict: dict) -> dict:
+        mlx_dict = {}
+        for key, value in torch_dict.items():
+            if isinstance(value, torch.Tensor):
+                mlx_dict[key] = mx.array(value.detach().cpu())
+            else:
+                mlx_dict[key] = value
+        return mlx_dict
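+
+
+# Usage sketch (illustrative only; the path below is hypothetical):
+#
+#     weights = LoRAConverter.load_weights("impressionism_bfl_format.safetensors")
+#
+# The result is a nested dict in the Diffusers layout consumed by WeightHandler, e.g.
+# weights["transformer"]["transformer_blocks"][0]["attn"]["to_q"]["lora_A"]["weight"]
+# holds the (already alpha-scaled) MLX array for the first double block's Q projection.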
diff --git a/src/mflux/weights/lora_util.py b/src/mflux/weights/lora_util.py
index 43d3fd8..11bf3d2 100644
--- a/src/mflux/weights/lora_util.py
+++ b/src/mflux/weights/lora_util.py
@@ -28,9 +28,6 @@ def _validate_lora_scales(lora_files: list[str], lora_scales: list[float]) -> li
 
     @staticmethod
     def _apply_lora(transformer: dict, lora_file: str, lora_scale: float) -> None:
-        if lora_scale < 0.0 or lora_scale > 1.0:
-            raise Exception(f"Invalid scale {lora_scale} provided for {lora_file}. Valid Range [0.0 - 1.0] ")
-
         from mflux.weights.weight_handler import WeightHandler
         lora_transformer, _ = WeightHandler.load_transformer(lora_path=lora_file)
         LoraUtil._apply_transformer(transformer, lora_transformer, lora_scale)
diff --git a/src/mflux/weights/weight_handler.py b/src/mflux/weights/weight_handler.py
index fc485ca..0204728 100644
--- a/src/mflux/weights/weight_handler.py
+++ b/src/mflux/weights/weight_handler.py
@@ -4,6 +4,7 @@
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_unflatten
 
+from mflux.weights.lora_converter import LoRAConverter
 from mflux.weights.lora_util import LoraUtil
 from mflux.weights.weight_util import WeightUtil
 
@@ -65,7 +66,7 @@ def load_transformer(root_path: Path | None = None, lora_path: str | None = None
 
         if lora_path:
             if 'transformer' not in weights:
-                raise Exception("The key `transformer` is missing in the LoRA safetensors file. Please ensure that the file is correctly formatted and contains the expected keys.")
+                weights = LoRAConverter.load_weights(lora_path)
             weights = weights["transformer"]
 
         # Quantized weights (i.e. ones exported from this project) don't need any post-processing.
@@ -75,10 +76,11 @@ def load_transformer(root_path: Path | None = None, lora_path: str | None = None
         # Reshape and process the huggingface weights
         if "transformer_blocks" in weights:
             for block in weights["transformer_blocks"]:
-                block["ff"] = {
-                    "linear1": block["ff"]["net"][0]["proj"],
-                    "linear2": block["ff"]["net"][2]
-                }
+                if block.get("ff") is not None:
+                    block["ff"] = {
+                        "linear1": block["ff"]["net"][0]["proj"],
+                        "linear2": block["ff"]["net"][2]
+                    }
                 if block.get("ff_context") is not None:
                     block["ff_context"] = {
                         "linear1": block["ff_context"]["net"][0]["proj"],