# Partial Loading PR3: Integrate 1) partial loading, 2) quantized models, 3) model patching (#7500)

## Summary

This PR is the third in a sequence of PRs working towards support for partial loading of models onto the compute device (for low-VRAM operation). This PR updates the LoRA patching code so that the following features can cooperate fully:

- Partial loading of weights onto the GPU
- Quantized layers / weights
- Model patches (e.g. LoRA)

Note that this PR does not yet enable partial loading. It adds support in the model patching code so that partial loading can be enabled in a future PR.

## Technical Design Decisions

The layer patching logic has been integrated into the custom layers (via `CustomModuleMixin`) rather than keeping it in a separate set of wrapper layers, as before. This has the following advantages:

- It makes it easier to calculate the modified weights on the fly and then reuse the normal `forward()` logic (a minimal sketch of this pattern follows the checklist below).
- In the future, it makes it possible to pass original parameters that have already been cast to the device down to the LoRA calculation without having to re-cast them (though the current implementation hasn't fully taken advantage of this yet).

## Known Limitations

1. I haven't fully solved device management for patch types that require the original layer value to calculate the patch. These aren't very common, and are not compatible with some quantized layers, so I'm leaving this for the future if there's demand.
2. There is a small speed regression for models that have CPU bottlenecks. This seems to be caused by slightly slower method resolution on the custom layer sub-classes. The regression does not show up on larger models, like FLUX, that are almost entirely GPU-limited. I think this small regression is tolerable, but if we decide that it's not, the slowdown can easily be reclaimed by optimizing other CPU operations (e.g. if we only sent every 2nd progress image, we'd see a much more significant speedup).

## Related Issues / Discussions

- #7492
- #7494

## QA Instructions

Speed tests:

- Vanilla SD1 speed regression
  - Before: 3.156s (8.78 it/s)
  - After: 3.54s (8.35 it/s)
- Vanilla SDXL speed regression
  - Before: 6.23s (4.46 it/s)
  - After: 6.45s (4.31 it/s)
- Vanilla FLUX speed regression
  - Before: 12.02s (2.27 it/s)
  - After: 11.91s (2.29 it/s)

LoRA tests with default configuration:

- [x] SD1: a handful of LoRA variants
- [x] SDXL: a handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

LoRA tests with sidecar patching forced:

- [x] SD1: a handful of LoRA variants
- [x] SDXL: a handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

Other:

- [x] Smoke testing of IP-Adapter, ControlNet

All tests repeated on:

- [x] cuda
- [x] cpu (SD1 only, because larger models are prohibitively slow)
- [x] mps (FLUX tests skipped, because my Mac doesn't have enough memory to run them in a reasonable amount of time)

## Merge Plan

No special instructions.
## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
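For illustration only, the sketch below shows the general pattern described under Technical Design Decisions: computing a patched weight on the fly inside a layer that sub-classes its `torch.nn` counterpart and then reusing the standard forward math. The `LoRAResidual` and `PatchedLinear` names and their attributes are hypothetical and are not part of this PR.

```python
import torch


class LoRAResidual:
    """Hypothetical sidecar patch holding low-rank factors for one layer."""

    def __init__(self, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0):
        self.down = down  # shape: (rank, in_features)
        self.up = up  # shape: (out_features, rank)
        self.scale = scale

    def weight_residual(self) -> torch.Tensor:
        # Residual to add to the original weight matrix.
        return self.scale * (self.up @ self.down)


class PatchedLinear(torch.nn.Linear):
    """Sub-classes torch.nn.Linear, so isinstance checks and the normal
    linear math are preserved; only the weight is adjusted on the fly."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.patches: list[LoRAResidual] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        for patch in self.patches:
            # Compute the modified weight on the fly, on the input's device.
            weight = weight + patch.weight_residual().to(device=x.device, dtype=weight.dtype)
        # Reuse the standard linear forward with the patched weight.
        return torch.nn.functional.linear(x, weight, self.bias)


# Illustrative usage: the patched layer is still a torch.nn.Linear.
layer = PatchedLinear(16, 32)
layer.patches.append(LoRAResidual(torch.randn(4, 16), torch.randn(32, 4), scale=0.5))
assert isinstance(layer, torch.nn.Linear)
out = layer(torch.randn(2, 16))  # shape: (2, 32)
```

The actual implementation in this PR goes further: patches are aggregated per parameter (see `_aggregate_patch_parameters` in `CustomConv1d` below), and the original parameters are first cast to the execution device.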
Showing 50 changed files with 1,732 additions and 1,033 deletions.
`invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py` (50 changes: 0 additions & 50 deletions)

This file was deleted.
`...d/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md` (8 changes: 8 additions & 0 deletions)
This directory contains custom implementations of common torch.nn.Module classes that add support for:

- Streaming weights to the execution device
- Applying sidecar patches at execution time (e.g. sidecar LoRA layers)

Each custom class sub-classes the original module type that it is replacing, so the following properties are preserved:

- `isinstance(m, torch.nn.OriginalModule)` should still work.
- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.)
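As a standalone illustration of the first bullet (streaming weights to the execution device), the sketch below moves parameters to the input's device at call time. It is a hypothetical simplification: the repository's custom modules use the `cast_to_device` helper and `CustomModuleMixin`, as shown in `custom_conv1d.py` below, rather than this minimal class.

```python
import torch


class StreamedLinear(torch.nn.Linear):
    """Hypothetical sketch: weights may live on the CPU and are copied to the
    input's device only for the duration of the forward pass."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Copy the (possibly CPU-resident) parameters to the device the input
        # lives on, e.g. a CUDA device, then run the normal linear math.
        weight = self.weight.to(x.device, non_blocking=True)
        bias = self.bias.to(x.device, non_blocking=True) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
```

Because `StreamedLinear` sub-classes `torch.nn.Linear`, the `isinstance` and direct weight-patching properties called out above are preserved.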
File renamed without changes.
`...kend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py` (43 changes: 43 additions & 0 deletions)
```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)


class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)

        # Prepare the original parameters for the patch aggregation.
        orig_params = {"weight": weight, "bias": bias}
        # Filter out None values.
        orig_params = {k: v for k, v in orig_params.items() if v is not None}

        aggregated_param_residuals = self._aggregate_patch_parameters(
            patches_and_weights=self._patches_and_weights,
            orig_params=orig_params,
            device=input.device,
        )

        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
        return self._conv_forward(input, weight, bias)

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
```
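A note on the dispatch order in `forward()` above: when sidecar patches are present, a single code path handles both device casting and patch aggregation before reusing `self._conv_forward`, which is the "calculate the modified weights on the fly, then reuse the normal forward logic" decision from the PR summary. When there are no patches and device autocasting is disabled, the layer falls straight through to `super().forward()`, so the stock `torch.nn.Conv1d` behavior is untouched and the only added cost is the branch plus the slightly slower method resolution on the sub-class noted in the summary.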