
patching lora load for faster loading #39

Merged: 4 commits merged into main from fast-lora-load on Sep 27, 2024

Conversation

@daanelson (Collaborator) commented on Sep 27, 2024

Profiled LoRA loading in prod to see what was taking so long. It looks like our old nemesis kaiming_uniform, along with other functions used to randomly initialize empty tensors.

Conveniently, someone just pushed functionality to peft to disable this, so we don't have to get too wild to turn it off; we just need to patch load_lora_into_transformer to take advantage of that functionality until it's integrated into diffusers.

That's what this PR does. With it, I saw LoRA load times (after download) drop from about 10 seconds to about 1.2 seconds. Tested locally with dev and schnell.
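
For context, here is a minimal sketch of the peft feature the patch leans on. It assumes peft >= 0.13.0, where adapter injection accepts low_cpu_mem_usage; the toy model and LoraConfig below are illustrative, not the repo's actual code.

```python
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

# Toy stand-in for the Flux transformer; the real patch lives in
# lora_loading_patch.py and goes through diffusers' load_lora_into_transformer.
model = nn.Sequential(nn.Linear(64, 64))

config = LoraConfig(r=16, target_modules=["0"])  # hypothetical rank/targets

# With low_cpu_mem_usage=True (new in peft 0.13.0), the injected LoRA layers are
# created on the meta device, so kaiming_uniform_ and friends never run for
# weights that are about to be overwritten by the checkpoint anyway. The real
# values are assigned later, when the LoRA state dict is loaded.
model = inject_adapter_in_model(
    config, model, adapter_name="main", low_cpu_mem_usage=True
)
```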


Important

Patch load_lora_into_transformer to use low_cpu_mem_usage=True, significantly reducing LoRA load times, and update dependencies accordingly.

  • Behavior:
    • Patch load_lora_into_transformer in lora_loading_patch.py to use low_cpu_mem_usage=True, reducing LoRA load times from ~10s to ~1.2s.
    • Update predict.py to use the patched load_lora_into_transformer for FluxPipeline, FluxImg2ImgPipeline, and FluxInpaintPipeline.
  • Dependencies:
    • Upgrade peft version from 0.12.0 to 0.13.0 in cog.yaml.
  • Misc:
    • Add /weights-cache to .dockerignore.

This description was created by Ellipsis for f53e34c. It will automatically update as commits are pushed.
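
Roughly, the wiring in predict.py that the summary describes looks like the sketch below. The pipeline variable name and the from_pretrained call follow the diff shown later in this review; load_lora_into_transformer is the patched copy from lora_loading_patch.py.

```python
import torch
from diffusers import FluxPipeline

from lora_loading_patch import load_lora_into_transformer  # patched loader with low_cpu_mem_usage=True

# "FLUX.1-dev" is the local weights directory used in setup() per the diff below.
dev_pipe = FluxPipeline.from_pretrained(
    "FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Rebind the classmethod on the pipeline's class so subsequent
# pipe.load_lora_weights(...) calls go through the patched loader.
# The same assignment is repeated for the img2img and inpaint pipelines.
dev_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
```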

@andreasjansson (Member) left a comment


Great sleuthing! A couple nitpicks but not blockers.

@@ -101,6 +102,9 @@ def setup(self) -> None:  # pyright: ignore
             "FLUX.1-dev",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        dev_pipe.__class__.load_lora_into_transformer = classmethod(

This is fine, but it could probably also be done with a pattern similar to https://github.com/replicate/flux-fine-tuner/blob/main/submodule_patches.py to avoid duplicating this line.
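
For comparison, a hedged sketch of that pattern: patch the diffusers classes once in a helper module imported early (the way submodule_patches.py patches diffusers at import time), so the assignment isn't repeated for every pipeline instance. The helper function name here is hypothetical.

```python
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline

from lora_loading_patch import load_lora_into_transformer  # the patched loader


def patch_flux_lora_loading() -> None:
    """Swap in the low_cpu_mem_usage-aware loader for every Flux pipeline class."""
    for cls in (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline):
        cls.load_lora_into_transformer = classmethod(load_lora_into_transformer)


# Imported and called once (e.g. at the top of predict.py), so pipeline
# instances created later pick up the patched classmethod automatically.
patch_flux_lora_loading()
```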

@@ -430,6 +440,7 @@ def load_single_lora(self, lora_url: str, model: str):
        lora_path = self.weights_cache.ensure(lora_url)
        pipe.load_lora_weights(lora_path, adapter_name="main")
        self.loaded_lora_urls[model] = LoadedLoRAs(main=lora_url, extra=None)
+       pipe = pipe.to("cuda")

Is it not already on the GPU at this point? How did it work before? Might be worth adding a comment here.

@andreasjansson (Member)

Did a test-only cog-safe-push; it passed: https://github.com/replicate/flux-fine-tuner/actions/runs/11072824989

@andreasjansson merged commit 9359a41 into main on Sep 27, 2024 (3 checks passed).
@andreasjansson deleted the fast-lora-load branch on September 27, 2024 at 16:15.