
Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere #6490

Merged · 14 commits into main · Jun 13, 2024

Conversation


@lstein (Collaborator) commented on Jun 6, 2024

Summary

Torch profiling of LoRA patching indicates that most of the CPU time is spent in the RAM->VRAM and VRAM->RAM layer-copying steps. See the CPU total column in the upper table for the aten::to, aten::_to_copy, cudaMemcpyAsync and cudaStreamSynchronize rows.
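For reference, a profile like the tables below can be captured with torch.profiler. This is a generic sketch rather than the exact harness used here; the profiled region is a placeholder and a CUDA device is assumed:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Record both CPU and CUDA activity around the code being measured.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    # ... run the LoRA patch/unpatch and generation steps here (placeholder) ...
    x = torch.randn(1024, 1024).to("cuda")  # stand-in workload
    y = x @ x.T

# Sort by total CPU time to surface aten::to, aten::_to_copy, cudaMemcpyAsync, etc.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```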

I have experimented with adding non_blocking=True to the various direct and indirect calls to torch.nn.Module.to() and found that the amount of time spent by the CPU in these calls can be reduced by about 40% as shown in the lower of the two tables.
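As a rough illustration of the pattern (a sketch only, not this PR's code; layer_weight is a placeholder tensor):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer_weight = torch.randn(4096, 4096)  # stand-in for a LoRA/model layer

# With non_blocking=True the call can return before the copy has finished.
# The copy only overlaps with host work when the source is in pinned memory,
# but kernels launched later on the same CUDA stream still see the completed copy.
gpu_weight = layer_weight.to(device, non_blocking=True)
result = gpu_weight @ gpu_weight.T  # safe: ordered after the copy on the stream
```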

When non-blocking mode is enabled for LoRA patching, TI patching, IP adapters and the model manager's RAM cache code, I get a small but significant walltime speedup of about 0.3s for full generations. The timing test used 3 LoRAs, one TI, an IP adapter and an SDXL model. This is not a large effect, but might be worth it.

I also cleaned up some type checking errors that appeared with pyright but not mypy.

I have not observed any anomalous behavior.

NON_BLOCKING=FALSE
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::to         9.51%      19.967ms        74.39%     156.121ms      30.891us       0.000us         0.00%      25.010ms       4.949us          5054  
                                         aten::_to_copy         7.85%      16.468ms        73.68%     154.636ms      30.597us       0.000us         0.00%      29.555ms       5.848us          5054  
                                            aten::copy_         5.59%      11.735ms        66.62%     139.816ms      27.664us      31.916ms        34.01%      31.916ms       6.315us          5054  
                                        cudaMemcpyAsync        45.75%      96.013ms        45.75%      96.013ms      33.245us       0.000us         0.00%       0.000us       0.000us          2888  
                                  cudaStreamSynchronize        13.88%      29.140ms        13.88%      29.140ms      10.090us       0.000us         0.00%       0.000us       0.000us          2888  
                                           aten::matmul         0.60%       1.255ms         6.90%      14.485ms      20.062us       0.000us         0.00%      12.486ms      17.294us           722  
                                               aten::mm         4.91%      10.299ms         6.68%      14.011ms      19.406us      14.791ms        15.76%      14.791ms      20.486us           722  
                                    aten::empty_strided         4.39%       9.207ms         4.39%       9.207ms       1.822us       0.000us         0.00%       0.000us       0.000us          5054  
                                       cudaLaunchKernel         3.26%       6.851ms         3.26%       6.851ms       1.581us       0.000us         0.00%       0.000us       0.000us          4332  
                                              aten::mul         1.63%       3.423ms         2.15%       4.505ms       6.240us      27.716ms        29.53%      27.716ms      38.388us           722  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 209.872ms
Self CUDA time total: 93.853ms

NON_BLOCKING=TRUE
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::to        12.37%      16.342ms        73.04%      96.475ms      19.089us       0.000us         0.00%      26.937ms       5.330us          5054  
                                         aten::_to_copy         5.95%       7.853ms        71.63%      94.603ms      18.718us       0.000us         0.00%      29.446ms       5.826us          5054  
                                    aten::empty_strided         7.80%      10.297ms        48.42%      63.958ms      12.655us       0.000us         0.00%       0.000us       0.000us          5054  
                                          cudaHostAlloc        40.66%      53.703ms        40.66%      53.703ms      37.190us       0.000us         0.00%       0.000us       0.000us          1444  
                                            aten::copy_         7.27%       9.606ms        18.76%      24.777ms       4.902us      32.473ms        34.34%      32.473ms       6.425us          5054  
                                        cudaMemcpyAsync         8.51%      11.245ms         8.51%      11.245ms       3.894us       0.000us         0.00%       0.000us       0.000us          2888  
                                           aten::matmul         1.12%       1.476ms         6.71%       8.857ms      12.267us       0.000us         0.00%      12.405ms      17.181us           722  
                                               aten::mm         4.22%       5.572ms         6.47%       8.550ms      11.842us      14.772ms        15.62%      14.772ms      20.460us           722  
                                       cudaLaunchKernel         5.31%       7.016ms         5.31%       7.016ms       1.620us       0.000us         0.00%       0.000us       0.000us          4332  
                                              aten::mul         3.25%       4.299ms         4.08%       5.383ms       7.456us      27.828ms        29.43%      27.828ms      38.543us           722  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 132.080ms
Self CUDA time total: 94.561ms

Related Issues / Discussions

See discussion thread in merged PR #6439.

QA Instructions

Try various combinations of LoRAs, TIs, IP adapters and models to see if anything breaks.

Merge Plan

Squash merge when approved.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)

@github-actions bot added the python (PRs that change python files) and backend (PRs that change backend files) labels on Jun 6, 2024

@RyanJDick (Collaborator) left a comment

It looks like the commit history got a little messy on this. It's probably worth doing a quick rebase to just keep the relevant commits.

The change looks good to me, and should be low-risk.

It's probably also worth experimenting with pinned memory (often recommended in combination with non_blocking=True). Up to you if you want to explore that in this PR or a separate one.
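For illustration, the pinned-memory variant the reviewer mentions would look roughly like this (a sketch under the assumption of a CUDA device; tensor names are placeholders, not code from this PR):

```python
import torch

cpu_weights = torch.randn(4096, 4096)  # stand-in host tensor

# pin_memory() returns a page-locked copy of the host tensor, which allows the
# subsequent non_blocking host-to-device copy to run truly asynchronously.
pinned = cpu_weights.pin_memory()
gpu_weights = pinned.to("cuda", non_blocking=True)
```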

@lstein (Collaborator, Author) commented on Jun 13, 2024

> It looks like the commit history got a little messy on this. It's probably worth doing a quick rebase to just keep the relevant commits.

I did a merge from main in the middle of working on this, and rebasing is creating a lot of conflicts.

@lstein enabled auto-merge (squash) on June 13, 2024 16:59
@lstein merged commit a3cb5da into main on Jun 13, 2024
14 checks passed
@lstein deleted the lstein/feat/lora_patch_optimization_2 branch on June 13, 2024 17:10
psychedelicious added a commit that referenced this pull request Jul 15, 2024
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When using this torch feature, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor. Theoretically, that should make this a safe feature to use.

This provides a small performance improvement but causes race conditions in some situations. Specific platforms/systems are affected, and complicated data dependencies can make this unsafe.

- Intermittent black images on MPS devices - reported on discord and #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU.

One workaround is to use torch streams with manual sync points. Our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots.

Much safer is to fully revert non-blocking - which is what this change does.
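For context, a minimal sketch (not project code; a CUDA device is assumed) of the device-to-host hazard described above and the manual sync point that closes it:

```python
import torch

gpu_tensor = torch.randn(4096, 4096, device="cuda")

# A non_blocking device-to-host copy can return before the data has actually
# landed in host memory, so reading cpu_tensor immediately may observe stale
# or partially written values.
cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)

torch.cuda.synchronize()  # manual sync point: the copy is now complete
checksum = cpu_tensor.sum()  # safe to read after synchronization
```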