Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere #6490
Conversation
…com/invoke-ai/InvokeAI into lstein/feat/lora_patch_optimization
Co-authored-by: Ryan Dick <[email protected]>
…s, for slight performance increases
It looks like the commit history got a little messy on this. It's probably worth doing a quick rebase to just keep the relevant commits.
The change looks good to me, and should be low-risk.
It's probably also worth experimenting with pinned memory (often recommended in combination with `non_blocking=True`). Up to you if you want to explore that in this PR or a separate one.
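A rough sketch of the pinned-memory idea (illustrative only, not code from this PR):

```python
import torch

# pin_memory() copies the tensor into page-locked host RAM; host->device copies
# from pinned memory can then overlap with other work when non_blocking=True is set.
cpu_weights = torch.randn(1024, 1024).pin_memory()
gpu_weights = cpu_weights.to("cuda", non_blocking=True)
```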
I did a merge from main in the middle of working on this, and rebase is creating a lot of conflicts.
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When using this torch feature, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor, which should, in theory, make it safe to use. It provides a small performance improvement but causes race conditions in some situations: specific platforms/systems are affected, and complicated data dependencies can make it unsafe.

- Intermittent black images on MPS devices - reported on Discord and in #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU. One workaround is to use torch streams with manual sync points. Our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots. Much safer is to fully revert non-blocking - which is what this change does.
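A minimal sketch (not from the codebase) of the CUDA->CPU hazard and the kind of manual sync point mentioned above:

```python
import torch

src = torch.randn(4096, 4096, device="cuda")

# The device->host copy is queued asynchronously; the call returns before the
# data has actually landed in dst's (pageable) host memory.
dst = src.to("cpu", non_blocking=True)

# Without an explicit sync point, reading dst here can observe stale data.
torch.cuda.synchronize()

checksum = dst.sum()  # safe only after the synchronize (or an equivalent event wait)
```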
Summary
Torch profiling of LoRA patching indicates that most of the CPU time is spent in the RAM->VRAM and VRAM->RAM layer copying steps. See the `CPU total` column in the upper table for `to`, `_to_copy`, `cudaMemcpyAsync` and `cudaStreamSynchronize`.

I have experimented with adding `non_blocking=True` to the various direct and indirect calls to `torch.nn.Module.to()` and found that the amount of time spent by the CPU in these calls can be reduced by about 40%, as shown in the lower of the two tables.

When non-blocking mode is enabled for LoRA patching, TI patching, IP adapters and the model manager's RAM cache code, I get a small but significant walltime speedup of about 0.3s for full generations. The timing test used 3 LoRAs, one TI, an IP adapter and an SDXL model. This is not a large effect, but might be worth it.
I also cleaned up some type checking errors that appeared with pyright but not mypy.
I have not observed any anomalous behavior.
Related Issues / Discussions
See discussion thread in merged PR #6439.
QA Instructions
Try various combinations of LoRAs, TIs, IP adapters and models to see if anything breaks.
Merge Plan
Squash merge when approved.
Checklist