Merge pull request #46 from filipstrand/better-lora-support
Better LoRA support
filipstrand authored Sep 13, 2024
2 parents 3739a6b + d2553f2 commit 93e80ea
Showing 4 changed files with 261 additions and 8 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -370,11 +370,26 @@ mflux-generate \
Just to see the difference, this image displays the four cases: one with both adapters fully active, one with them only partially active, and one with no LoRA at all.
The example above also shows the usage of the `--lora-scales` flag.

#### Supported LoRA formats (updated)

Since different fine-tuning services can use different implementations of FLUX, the corresponding
LoRA weights trained on these services can differ from one another. The aim of MFLUX is to support the most common ones.
The following table shows the currently supported formats:

| Supported | Name      | Example                                                                                                   | Notes                               |
|-----------|-----------|-----------------------------------------------------------------------------------------------------------|-------------------------------------|
| ✅        | BFL       | [civitai - Impressionism](https://civitai.com/models/545264/impressionism-sdxl-pony-flux)                  | Many things on civitai seem to work |
| ✅        | Diffusers | [Flux_1_Dev_LoRA_Paper-Cutout-Style](https://huggingface.co/Norod78/Flux_1_Dev_LoRA_Paper-Cutout-Style/)   |                                     |
| ✅        | XLabs-AI  | [flux-RealismLora](https://huggingface.co/XLabs-AI/flux-RealismLora/tree/main)                             |                                     |

To report additional formats, examples, or any other suggestions related to LoRA format support, please see [issue #47](https://github.com/filipstrand/mflux/issues/47).
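
The formats differ mainly in how the weight keys are named: the converter added in this PR (see `lora_converter.py` below) translates `lora_unet_*` keys into Diffusers-style `transformer.*.lora_A`/`lora_B` keys. As a rough illustration only — the helper name and the prefix heuristics are assumptions, not part of MFLUX — one could peek at a file's keys to guess its convention:

```python
from safetensors import safe_open


def guess_lora_format(lora_path: str) -> str:
    """Hypothetical helper: guess a LoRA file's naming convention from its keys."""
    with safe_open(lora_path, framework="pt") as f:
        keys = list(f.keys())
    if any(k.startswith("lora_unet_") for k in keys):
        return "kohya-style (lora_unet_* keys, handled by the converter below)"
    if any(".lora_A." in k or ".lora_B." in k for k in keys):
        return "Diffusers-style (transformer.*.lora_A/lora_B keys)"
    return "other/unknown (may not be supported yet)"
```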

### Current limitations

- Images are generated one by one.
- Negative prompts are not supported.
- LoRA weights are only supported for the transformer part of the network.
- Some LoRA adapters do not work.

### TODO

239 changes: 239 additions & 0 deletions src/mflux/weights/lora_converter.py
@@ -0,0 +1,239 @@
import logging
import mlx.core as mx
import torch
from mlx.utils import tree_unflatten
from safetensors import safe_open

logger = logging.getLogger(__name__)


# This script is based on `convert_flux_lora.py` from `kohya-ss/sd-scripts`.
# For more info, see: https://github.com/kohya-ss/sd-scripts/blob/sd3/networks/convert_flux_lora.py

class LoRAConverter:

    @staticmethod
    def load_weights(lora_path: str) -> dict:
        state_dict = LoRAConverter._load_pytorch_weights(lora_path)
        state_dict = LoRAConverter._convert_weights_to_diffusers(state_dict)
        state_dict = LoRAConverter._convert_to_mlx(state_dict)
        state_dict = list(state_dict.items())
        state_dict = tree_unflatten(state_dict)
        return state_dict

    @staticmethod
    def _load_pytorch_weights(lora_path: str) -> dict:
        state_dict = {}
        with safe_open(lora_path, framework="pt") as f:
            metadata = f.metadata()
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        return state_dict

    @staticmethod
    def _convert_weights_to_diffusers(source: dict) -> dict:
        target = {}
        for i in range(19):
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_img_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_out.0"
            )
            LoRAConverter._convert_to_diffusers_cat(
                source,
                target,
                f"lora_unet_double_blocks_{i}_img_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.to_q",
                    f"transformer.transformer_blocks.{i}.attn.to_k",
                    f"transformer.transformer_blocks.{i}.attn.to_v",
                ],
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_img_mlp_0",
                f"transformer.transformer_blocks.{i}.ff.net.0.proj"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_img_mlp_2",
                f"transformer.transformer_blocks.{i}.ff.net.2"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_img_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1.linear"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_txt_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_add_out"
            )
            LoRAConverter._convert_to_diffusers_cat(
                source,
                target,
                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
                ],
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_txt_mlp_0",
                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_txt_mlp_2",
                f"transformer.transformer_blocks.{i}.ff_context.net.2"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_double_blocks_{i}_txt_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1_context.linear"
            )

        for i in range(38):
            LoRAConverter._convert_to_diffusers_cat(
                source,
                target,
                f"lora_unet_single_blocks_{i}_linear1",
                [
                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target,
                f"lora_unet_single_blocks_{i}_linear2",
                f"transformer.single_transformer_blocks.{i}.proj_out"
            )
            LoRAConverter._convert_to_diffusers(
                source,
                target, f"lora_unet_single_blocks_{i}_modulation_lin",
                f"transformer.single_transformer_blocks.{i}.norm.linear"
            )

        if len(source) > 0:
            logger.warning(f"Unsupported keys for diffusers: {source.keys()}")
        return target

    @staticmethod
    def _convert_to_diffusers(
            source: dict,
            target: dict,
            source_key: str,
            target_key: str
    ):
        if source_key + ".lora_down.weight" not in source:
            return
        down_weight = source.pop(source_key + ".lora_down.weight")

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        alpha = source.pop(source_key + ".alpha").item()  # alpha is scalar
        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        target[target_key + ".lora_A.weight"] = down_weight * scale_down
        target[target_key + ".lora_B.weight"] = source.pop(source_key + ".lora_up.weight") * scale_up

    @staticmethod
    def _convert_to_diffusers_cat(
            source: dict,
            target: dict,
            source_key: str,
            target_keys: list[str],
            dims=None
    ):
        if source_key + ".lora_down.weight" not in source:
            return
        down_weight = source.pop(source_key + ".lora_down.weight")
        up_weight = source.pop(source_key + ".lora_up.weight")
        source_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
        alpha = source.pop(source_key + ".alpha")
        scale = alpha / source_lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        # calculate dims if not provided
        num_splits = len(target_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # check up-weight is sparse or not
        is_sparse = False
        if source_lora_rank % num_splits == 0:
            diffusers_rank = source_lora_rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i: i + dims[j], k * diffusers_rank: (k + 1) * diffusers_rank] == 0)
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {source_key}")

        # make diffusers weight
        diffusers_down_keys = [k + ".lora_A.weight" for k in target_keys]
        diffusers_up_keys = [k + ".lora_B.weight" for k in target_keys]
        if not is_sparse:
            # down_weight is copied to each split
            target.update({k: down_weight for k in diffusers_down_keys})

            # up_weight is split to each split
            target.update({k: v for k, v in zip(diffusers_up_keys, torch.split(up_weight, dims, dim=0))})
        else:
            # down_weight is chunked to each split
            target.update({k: v for k, v in zip(diffusers_down_keys, torch.chunk(down_weight, num_splits, dim=0))})

            # up_weight is sparse: only non-zero values are copied to each split
            i = 0
            for j in range(len(dims)):
                target[diffusers_up_keys[j]] = up_weight[i: i + dims[j], j * diffusers_rank: (j + 1) * diffusers_rank].contiguous()
                i += dims[j]

    @staticmethod
    def _convert_to_mlx(torch_dict: dict):
        mlx_dict = {}
        for key, value in torch_dict.items():
            if isinstance(value, torch.Tensor):
                mlx_dict[key] = mx.array(value.detach().cpu())
            else:
                mlx_dict[key] = value
        return mlx_dict
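
For context on what the converted `lora_A`/`lora_B` pairs represent: the `alpha / rank` factor is folded into the two matrices (via `scale_down` and `scale_up`, keeping their product unchanged), so downstream code only needs the standard LoRA merge `W' = W + scale * (B @ A)`. A minimal MLX sketch of that merge — the function is illustrative only and not part of this PR:

```python
import mlx.core as mx


def merge_lora(base_weight: mx.array, lora_A: mx.array, lora_B: mx.array, lora_scale: float = 1.0) -> mx.array:
    # Standard LoRA merge: W' = W + scale * (B @ A).
    # lora_A has shape (rank, in_features) and lora_B has shape (out_features, rank);
    # the alpha/rank factor is assumed to already be baked into A and B by the converter above.
    return base_weight + lora_scale * mx.matmul(lora_B, lora_A)
```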
3 changes: 0 additions & 3 deletions src/mflux/weights/lora_util.py
@@ -28,9 +28,6 @@ def _validate_lora_scales(lora_files: list[str], lora_scales: list[float]) -> li

    @staticmethod
    def _apply_lora(transformer: dict, lora_file: str, lora_scale: float) -> None:
        if lora_scale < 0.0 or lora_scale > 1.0:
            raise Exception(f"Invalid scale {lora_scale} provided for {lora_file}. Valid Range [0.0 - 1.0] ")

        from mflux.weights.weight_handler import WeightHandler
        lora_transformer, _ = WeightHandler.load_transformer(lora_path=lora_file)
        LoraUtil._apply_transformer(transformer, lora_transformer, lora_scale)
12 changes: 7 additions & 5 deletions src/mflux/weights/weight_handler.py
@@ -4,6 +4,7 @@
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten

from mflux.weights.lora_converter import LoRAConverter
from mflux.weights.lora_util import LoraUtil
from mflux.weights.weight_util import WeightUtil

@@ -65,7 +66,7 @@ def load_transformer(root_path: Path | None = None, lora_path: str | None = None

        if lora_path:
            if 'transformer' not in weights:
                raise Exception("The key `transformer` is missing in the LoRA safetensors file. Please ensure that the file is correctly formatted and contains the expected keys.")
            weights = LoRAConverter.load_weights(lora_path)
            weights = weights["transformer"]

        # Quantized weights (i.e. ones exported from this project) don't need any post-processing.
@@ -75,10 +76,11 @@
        # Reshape and process the huggingface weights
        if "transformer_blocks" in weights:
            for block in weights["transformer_blocks"]:
                block["ff"] = {
                    "linear1": block["ff"]["net"][0]["proj"],
                    "linear2": block["ff"]["net"][2]
                }
                if block.get("ff") is not None:
                    block["ff"] = {
                        "linear1": block["ff"]["net"][0]["proj"],
                        "linear2": block["ff"]["net"][2]
                    }
                if block.get("ff_context") is not None:
                    block["ff_context"] = {
                        "linear1": block["ff_context"]["net"][0]["proj"],
