patching lora load for faster loading #39
Conversation
force-pushed from 126730e to 34b4947
Great sleuthing! A couple nitpicks but not blockers.
@@ -101,6 +102,9 @@ def setup(self) -> None:  # pyright: ignore
    "FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
dev_pipe.__class__.load_lora_into_transformer = classmethod(
This is fine, but it could probably also be done with a pattern similar to https://github.com/replicate/flux-fine-tuner/blob/main/submodule_patches.py to avoid duplicating this line.
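For what that pattern might look like, here is a minimal sketch. It assumes the patched loader lives in lora_loading_patch.py (as described in the PR summary below); the patch_lora_loading helper name is hypothetical, not code from this PR.

```python
# Hypothetical helper in the spirit of submodule_patches.py: patch each Flux pipeline
# class once at import time instead of repeating the classmethod assignment in setup().
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline

from lora_loading_patch import load_lora_into_transformer  # the patched loader from this PR


def patch_lora_loading() -> None:
    for pipeline_cls in (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline):
        # Shadow the stock loader with the version that passes low_cpu_mem_usage=True.
        pipeline_cls.load_lora_into_transformer = classmethod(load_lora_into_transformer)
```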
@@ -430,6 +440,7 @@ def load_single_lora(self, lora_url: str, model: str):
lora_path = self.weights_cache.ensure(lora_url)
pipe.load_lora_weights(lora_path, adapter_name="main")
self.loaded_lora_urls[model] = LoadedLoRAs(main=lora_url, extra=None)
pipe = pipe.to("cuda")
Is it not already on gpu at this point? How did it work before? Might be worth adding a comment here.
force-pushed from 0861945 to f53e34c
Did a test-only cog-safe-push; it passed: https://github.com/replicate/flux-fine-tuner/actions/runs/11072824989
Profiled LoRA loading in prod to see what was taking so long - it looks like our old nemesis kaiming_uniform, and other such functions that are used to randomly initialize empty tensors. Conveniently, someone just pushed functionality to peft to disable this, so we don't have to get too wild to turn it off; we just need to patch load_lora_into_transformer to take advantage of that functionality until it's integrated into diffusers. That's what this PR does. With this I saw LoRA load times (after download) drop from about 10 seconds to about 1.2 seconds. Tested locally with dev and schnell.
Important

Patch load_lora_into_transformer to use low_cpu_mem_usage=True, significantly reducing LoRA load times, and update dependencies accordingly.

- Patch load_lora_into_transformer in lora_loading_patch.py to use low_cpu_mem_usage=True, reducing LoRA load times from ~10s to ~1.2s.
- Update predict.py to use the patched load_lora_into_transformer for FluxPipeline, FluxImg2ImgPipeline, and FluxInpaintPipeline (see the sketch below).
- Bump the peft version from 0.12.0 to 0.13.0 in cog.yaml.
- Add ./weights-cache to .dockerignore.

This description was created by for f53e34c. It will automatically update as commits are pushed.
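As an illustration of the predict.py change listed above, here is a hedged end-to-end usage sketch. Pipeline construction, the local "FLUX.1-dev" path, and the adapter name mirror the diffs earlier in the thread; the LoRA path is a placeholder and everything else is assumed, not taken from this PR.

```python
import torch
from diffusers import FluxPipeline

from lora_loading_patch import load_lora_into_transformer  # module named in the summary above

# Install the patched loader on the pipeline class (mirrors the setup() diff above).
FluxPipeline.load_lora_into_transformer = classmethod(load_lora_into_transformer)

dev_pipe = FluxPipeline.from_pretrained(
    "FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# With low_cpu_mem_usage=True inside the patched loader, this call skips the
# kaiming_uniform_ initialization of empty LoRA tensors (~10s -> ~1.2s after download
# in the PR's tests).
dev_pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="main")
dev_pipe = dev_pipe.to("cuda")  # as in the diff above: make sure everything ends up on the GPU
```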