
Status of MPS Support #181

Closed

constantinpape opened this issue Sep 7, 2023 · 21 comments

@constantinpape
Contributor

constantinpape commented Sep 7, 2023

Keeping track of the MPS-related things @GenevieveBuckley and I discovered while working on MPS support:

@constantinpape
Contributor Author

I have now implemented the benchmarks in #182. They are in https://github.com/computational-cell-analytics/micro-sam/blob/dev/development/benchmark.py.
The script accepts the argument -m to set the model type (either vit_h, vit_l or vit_b) and -d to set the device (either cpu, cuda or mps; the latter two only if available).

@GenevieveBuckley @Marei33 could you run these benchmarks on your laptops with the following four settings please:
(vit_b, cpu), (vit_b, mps), (vit_h, cpu), (vit_h, mps).

And then paste the results here. To run them you will need to check out the latest dev branch. Let me know if anything is unclear.

For reference, here are the results for vit_b on my (linux) laptop:

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_b | cpu    | embeddings            | 5.72605   |
| vit_b | cpu    | prompt-p1n0           | 0.0527573 |
| vit_b | cpu    | prompt-p2n4           | 0.0473237 |
| vit_b | cpu    | prompt-box            | 0.0519929 |
| vit_b | cpu    | prompt-box-and-points | 0.0487959 |
| vit_b | cpu    | amg                   | 70.4746   |

@GenevieveBuckley
Collaborator

Link to gist with the results from yesterday (before we had the benchmark script) that caused us to start this investigation: https://gist.github.com/GenevieveBuckley/874b6b282b388524b6d62f25b3b9bb1c

@GenevieveBuckley
Collaborator

Note: to run the benchmark script, pandas requires the (usually optional) dependency tabulate. To save yourself some time and avoid having to run it twice, double-check that you have run pip install tabulate / conda install tabulate before running the benchmarks.
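
For anyone wondering where that requirement comes from: the results table is (I believe) printed via pandas' DataFrame.to_markdown(), which needs tabulate under the hood. A tiny sketch with made-up numbers:

```python
import pandas as pd

# Made-up example row, just to show where tabulate comes in.
df = pd.DataFrame(
    {"model": ["vit_b"], "device": ["cpu"], "benchmark": ["embeddings"], "runtime": [5.73]}
)

# DataFrame.to_markdown() raises an ImportError if `tabulate` is not installed.
print(df.to_markdown(index=False))
```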

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 8, 2023

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_h | cpu    | embeddings            | 16.9334   |
| vit_h | cpu    | prompt-p1n0           | 0.040143  |
| vit_h | cpu    | prompt-p2n4           | 0.0383821 |
| vit_h | cpu    | prompt-box            | 0.041321  |
| vit_h | cpu    | prompt-box-and-points | 0.0410571 |
| vit_h | cpu    | amg                   | 52.4115   |
| vit_h | mps    | embeddings            | 11.4617   |
| vit_h | mps    | prompt-p1n0           | 0.0342882 |
| vit_h | mps    | prompt-p2n4           | 0.0313947 |
| vit_h | mps    | prompt-box            | 0.0284584 |
| vit_h | mps    | prompt-box-and-points | 0.032028  |
| vit_h | mps    | amg                   | 220.851   |
| vit_l | cpu    | embeddings            | 10.1647   |
| vit_l | cpu    | prompt-p1n0           | 0.0386679 |
| vit_l | cpu    | prompt-p2n4           | 0.039386  |
| vit_l | cpu    | prompt-box            | 0.037823  |
| vit_l | cpu    | prompt-box-and-points | 0.0404921 |
| vit_l | cpu    | amg                   | 41.5105   |
| vit_l | mps    | embeddings            | 4.63046   |
| vit_l | mps    | prompt-p1n0           | 0.0345128 |
| vit_l | mps    | prompt-p2n4           | 0.0316033 |
| vit_l | mps    | prompt-box            | 0.0278041 |
| vit_l | mps    | prompt-box-and-points | 0.0320668 |
| vit_l | mps    | amg                   | 145.669   |
| vit_b | cpu    | embeddings            | 4.01707   |
| vit_b | cpu    | prompt-p1n0           | 0.0395901 |
| vit_b | cpu    | prompt-p2n4           | 0.0408981 |
| vit_b | cpu    | prompt-box            | 0.0390179 |
| vit_b | cpu    | prompt-box-and-points | 0.0438592 |
| vit_b | cpu    | amg                   | 40.5832   |
| vit_b | mps    | embeddings            | 1.97907   |
| vit_b | mps    | prompt-p1n0           | 0.0335419 |
| vit_b | mps    | prompt-p2n4           | 0.0308049 |
| vit_b | mps    | prompt-box            | 0.0274649 |
| vit_b | mps    | prompt-box-and-points | 0.0313318 |
| vit_b | mps    | amg                   | 148.047   |

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 8, 2023

Summary: something about AMG is really killing the performance for MPS backends.

We'll need to do some line profiling on that part of the code to get more information.

I am suspicious that the fallback to CPU might somehow involve a lot of transferring of large tensors back and forth between the CPU and MPS devices, and perhaps that is a large part of why it is so much slower than the CPU-only computation.

Running benchmark_amg ...
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())

I also see this conversion to int32 happening automatically (it's possible the int64 tensor is being created in segment-anything, not in the micro-sam code)

/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 8, 2023

Line profiling with line_profiler

The slowest part is in amg.initialize, so we need to look more closely at what's happening in there next.
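
For reference, this is roughly how a function gets marked up for kernprof. The try/except fallback is a common idiom I am assuming here (not necessarily what benchmark.py does), the function body is copied from the profile output below, and the micro_sam import paths are my guess at how the script pulls these in:

```python
import time

import numpy as np
from micro_sam import instance_segmentation as instance_seg
from micro_sam import util

# kernprof injects a `profile` decorator into builtins when invoked as
# `kernprof -lv benchmark.py ...`; this fallback keeps plain `python` runs working.
try:
    profile
except NameError:
    def profile(func):
        return func

@profile
def benchmark_amg(image, predictor, n=3):
    print("Running benchmark_amg ...")
    embeddings = util.precompute_image_embeddings(predictor, image)
    amg = instance_seg.AutomaticMaskGenerator(predictor)
    times = []
    for _ in range(n):
        t0 = time.time()
        amg.initialize(image, embeddings)
        amg.generate()
        times.append(time.time() - t0)
    runtime = np.mean(times)
    return ["amg"], [runtime]
```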

CPU

kernprof -lv benchmark.py -m vit_h -d cpu -e -p

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   127                                           @profile
   128                                           def benchmark_amg(image, predictor, n=3):
   129         1         43.0     43.0      0.0      print("Running benchmark_amg ...")
   130         1   16802565.0    2e+07     10.4      embeddings = util.precompute_image_embeddings(predictor, image)
   131         1        662.0    662.0      0.0      amg = instance_seg.AutomaticMaskGenerator(predictor)
   132         1          0.0      0.0      0.0      times = []
   133         4          1.0      0.2      0.0      for _ in range(n):
   134         3         10.0      3.3      0.0          t0 = time.time()
   135         3  129829213.0    4e+07     80.4          amg.initialize(image, embeddings)
   136         3   14811001.0    5e+06      9.2          amg.generate()
   137         3         38.0     12.7      0.0          times.append(time.time() - t0)
   138         1        147.0    147.0      0.0      runtime = np.mean(times)
   139         1          1.0      1.0      0.0      return ["amg"], [runtime]

MPS

kernprof -lv benchmark.py -m vit_h -d mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | mps      | amg         |   267.354 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 811.088 s
File: benchmark.py
Function: benchmark_amg at line 127

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   127                                           @profile
   128                                           def benchmark_amg(image, predictor, n=3):
   129         1        258.0    258.0      0.0      print("Running benchmark_amg ...")
   130         1    9019222.0    9e+06      1.1      embeddings = util.precompute_image_embeddings(predictor, image)
   131         1       4216.0   4216.0      0.0      amg = instance_seg.AutomaticMaskGenerator(predictor)
   132         1          0.0      0.0      0.0      times = []
   133         4          2.0      0.5      0.0      for _ in range(n):
   134         3         11.0      3.7      0.0          t0 = time.time()
   135         3  782484795.0    3e+08     96.5          amg.initialize(image, embeddings)
   136         3   19574430.0    7e+06      2.4          amg.generate()
   137         3        681.0    227.0      0.0          times.append(time.time() - t0)
   138         1       3923.0   3923.0      0.0      runtime = np.mean(times)
   139         1          0.0      0.0      0.0      return ["amg"], [runtime]

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 8, 2023

The speed is not better with pytorch-nightly, sadly.

But the warning messages about CPU fallback are gone with pytorch-nightly, so it seems my suspicion about that being the reason things are slow is not correct.

EDIT: it might still be the case that there is a lot of data transfer between the CPU and MPS devices regardless, but either way it is very slow. It would be good to have some CUDA benchmarks to compare the CPU performance against as well.

Details

```
(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % python benchmark.py -m vit_h -d cpu
Running benchmarks for vit_h
with device: cpu
Running benchmark_embeddings ...
Running benchmark_prompts ...
Running benchmark_amg ...
| model   | device   | benchmark             |   runtime |
|:--------|:---------|:----------------------|----------:|
| vit_h   | cpu      | embeddings            | 14.7882   |
| vit_h   | cpu      | prompt-p1n0           | 0.0370929 |
| vit_h   | cpu      | prompt-p2n4           | 0.0354571 |
| vit_h   | cpu      | prompt-box            | 0.0344546 |
| vit_h   | cpu      | prompt-box-and-points | 0.0367029 |
| vit_h   | cpu      | amg                   | 41.8242   |
(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % python benchmark.py -m vit_h -d mps
Running benchmarks for vit_h
with device: mps
Running benchmark_embeddings ...
Running benchmark_prompts ...
Running benchmark_amg ...
| model   | device   | benchmark             |   runtime |
|:--------|:---------|:----------------------|----------:|
| vit_h   | mps      | embeddings            | 9.93057   |
| vit_h   | mps      | prompt-p1n0           | 0.032023  |
| vit_h   | mps      | prompt-p2n4           | 0.030647  |
| vit_h   | mps      | prompt-box            | 0.0267961 |
| vit_h   | mps      | prompt-box-and-points | 0.0303428 |
| vit_h   | mps      | amg                   | 282.099   |
(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % conda list | grep torch
pytorch                    2.2.0.dev20230907    py3.10_0     pytorch-nightly
torchvision                0.17.0.dev20230907   py310_cpu    pytorch-nightly
```

@constantinpape
Contributor Author

Hi @GenevieveBuckley,
thanks for checking this carefully! Indeed most things speed up with MPS, but AMG takes a significant performance hit.
My theory: this is because we use too large a batch size here, which slows down MPS due to memory management overhead.
You can control the batch size by passing points_per_batch to AutomaticMaskGenerator, e.g. AutomaticMaskGenerator(predictor, points_per_batch=16) (the default is 64). Can you check if the performance is better for smaller batch sizes?
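
Something along these lines could be used for the comparison (a rough sketch, not part of the repo; image and predictor are assumed to be set up as in development/benchmark.py, and the module paths are my guess at the imports):

```python
import time

from micro_sam import instance_segmentation as instance_seg
from micro_sam import util

# `image` and `predictor` need to be created as in development/benchmark.py.
embeddings = util.precompute_image_embeddings(predictor, image)

for points_per_batch in (64, 16, 8, 4):
    amg = instance_seg.AutomaticMaskGenerator(predictor, points_per_batch=points_per_batch)
    t0 = time.time()
    amg.initialize(image, embeddings)
    amg.generate()
    print(f"points_per_batch={points_per_batch}: {time.time() - t0:.1f} s")
```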

@GenevieveBuckley
Collaborator

> My theory: this is because we use too large a batch size here, which slows down MPS due to memory management overhead. ... Can you check if the performance is better for smaller batch sizes?

No, not really 😢

n=1 iterations, because this benchmark takes a long time to run

| model | device | benchmark | points_per_batch |  runtime |
|:------|:-------|:----------|-----------------:|---------:|
| vit_h | cpu    | amg       |               64 |  41.2176 |
| vit_h | mps    | amg       |               64 | 465.819  |
| vit_h | cpu    | amg       |               16 |  38.4009 |
| vit_h | mps    | amg       |               16 | 158.788  |
| vit_h | cpu    | amg       |                8 |  39.693  |
| vit_h | mps    | amg       |                8 | 166.654  |
| vit_h | cpu    | amg       |                4 |  42.2534 |
| vit_h | mps    | amg       |                4 | 199.705  |

@GenevieveBuckley
Collaborator

However, your theory about memory management problems with MPS may still be correct. It might just be happening in a different place: the run-length encoding function segment_anything.utils.amg.mask_to_rle_pytorch().

I found this kaggle notebook on faster RLE implementations.

Next steps:

  • line profile the mask_to_rle_pytorch function - Done! See here for details
  • Put a pdb breakpoint in the mask_to_rle_pytorch() function, and then call torch.mps.current_allocated_memory() as you step through each line to see what happens (a small sketch of this is below, after the list). (Although I'm not sure how I find out what the total MPS memory is, so I can know how close to the limit we are.)
  • Consider using a proper memory profiler to run the whole benchmark script with
  • Maybe play with this kaggle notebook for faster run length encodings - but it might not be optimising for constrained memory situations. (Even if it doesn't, it could still lead to a good performance improvement PR to segment-anything, assuming they are open to contributions)
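
For the pdb idea, the check could look roughly like this (a sketch; the helper name is made up, and torch.mps.driver_allocated_memory() should be available in recent PyTorch releases):

```python
import torch

def report_mps_memory(label):
    # Hypothetical helper: print how much memory the MPS allocator currently holds,
    # e.g. before/after a suspect call or while stepping through with pdb.
    if torch.backends.mps.is_available():
        current_gb = torch.mps.current_allocated_memory() / 1e9
        driver_gb = torch.mps.driver_allocated_memory() / 1e9
        print(f"{label}: {current_gb:.2f} GB allocated, {driver_gb:.2f} GB reserved by the driver")

report_mps_memory("before mask_to_rle_pytorch")
# ... call segment_anything.utils.amg.mask_to_rle_pytorch(masks) here ...
report_mps_memory("after mask_to_rle_pytorch")
```

driver_allocated_memory() at least gives a sense of how much the MPS driver has reserved in total; the hard ceiling on Apple silicon is the unified system memory, since CPU and GPU share it.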

@constantinpape
Contributor Author

> No, not really 😢

I would say "partially" ;). We see the runtime decrease from 465.819 s to 158.788 s when going from batch size 64 to 16. That's a huge improvement already!

But you're of course right that we are not close to the CPU performance yet, and indeed the reason seems to be that the RLE encoding is also significantly slower on MPS. I am not sure if this is due to memory or just due to some operation used there not being as efficient on MPS yet. (And in general, thanks for running the line profiler; I only skimmed it, but it's quite interesting to see where the runtime is spent!)

My suggestion would be:

  • skip the memory profiling for now
  • compare the different RLE implementations from the kaggle notebook (a rough sketch of the general approach is shown below)
  • if one of them gives better performance then implement it in our vendored code
  • set the default batch size to 16 if using MPS
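
To make the RLE comparison more concrete, here is a rough sketch of what a vectorized run-length encoding for a single 2D boolean mask could look like. This is not the kaggle notebook's code and not a drop-in replacement for mask_to_rle_pytorch (which operates on a whole batch of masks), just an illustration of the general approach:

```python
import torch

def mask_to_rle_single(mask: torch.Tensor) -> dict:
    """Uncompressed COCO-style RLE for one 2D boolean mask (column-major order)."""
    h, w = mask.shape
    flat = mask.t().flatten()  # column-major (Fortran) flattening, as COCO expects
    # Indices where the value changes; +1 turns them into the start index of each new run.
    change = torch.nonzero(flat[1:] != flat[:-1]).flatten() + 1
    boundaries = torch.cat([
        torch.zeros(1, dtype=change.dtype, device=flat.device),
        change,
        torch.tensor([flat.numel()], dtype=change.dtype, device=flat.device),
    ])
    counts = (boundaries[1:] - boundaries[:-1]).tolist()
    # By convention the counts start with the length of the initial run of zeros.
    if bool(flat[0]):
        counts = [0] + counts
    return {"size": [h, w], "counts": counts}
```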

And re contributions to Segment Anything: I agree that it would be nice to contribute this upstream, but given that they are currently not really responding in the repo, opening a PR there directly is probably not worth it right now. I will try to get in contact with them at some point in the next weeks and see if there is some way to get them to review "sensible" community contributions. (I know the SAM first author a tiny bit, and there's always the alternative of twitter shaming Yann LeCun ;) ).

@constantinpape
Contributor Author

@GenevieveBuckley: I think we have a consistent speed up with MPS now, after merging #190. Can you confirm?

@Marei33 could you also benchmark this on your Mac?
For this:

  • check out the recent dev branch and make sure it's installed (if you installed via pip install -e . it should already be, but you can also just rerun that)
  • run the benchmark script four times, with the following settings:
    • python benchmark.py -d cpu -m vit_b
    • python benchmark.py -d cpu -m vit_h
    • python benchmark.py -d mps -m vit_b
    • python benchmark.py -d mps -m vit_h
  • and post the results here :)

(Let me know if anything is still unclear.)

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 19, 2023

Here are the benchmarks from the updated dev branch

python benchmark.py -d cpu -m vit_b

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_b | cpu    | embeddings            | 3.97884   |
| vit_b | cpu    | prompt-p1n0           | 0.040869  |
| vit_b | cpu    | prompt-p2n4           | 0.0423908 |
| vit_b | cpu    | prompt-box            | 0.0412707 |
| vit_b | cpu    | prompt-box-and-points | 0.0434921 |
| vit_b | cpu    | amg                   | 36.5019   |

python benchmark.py -d cpu -m vit_h

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_h | cpu    | embeddings            | 14.0066   |
| vit_h | cpu    | prompt-p1n0           | 0.038532  |
| vit_h | cpu    | prompt-p2n4           | 0.040257  |
| vit_h | cpu    | prompt-box            | 0.0395749 |
| vit_h | cpu    | prompt-box-and-points | 0.041683  |
| vit_h | cpu    | amg                   | 37.9739   |

python benchmark.py -d mps -m vit_b

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_b | mps    | embeddings            | 1.61226   |
| vit_b | mps    | prompt-p1n0           | 0.0352719 |
| vit_b | mps    | prompt-p2n4           | 0.032263  |
| vit_b | mps    | prompt-box            | 0.0290198 |
| vit_b | mps    | prompt-box-and-points | 0.0322769 |
| vit_b | mps    | amg                   | 33.7368   |

python benchmark.py -d mps -m vit_h

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_h | mps    | embeddings            | 7.96755   |
| vit_h | mps    | prompt-p1n0           | 0.0349841 |
| vit_h | mps    | prompt-p2n4           | 0.0316401 |
| vit_h | mps    | prompt-box            | 0.0281978 |
| vit_h | mps    | prompt-box-and-points | 0.0327029 |
| vit_h | mps    | amg                   | 34.6613   |

@GenevieveBuckley
Collaborator

GenevieveBuckley commented Sep 19, 2023

It's really striking how much faster the embeddings are with a smaller model. I should really consider using a smaller model wherever possible. The advantage is not just the embedding time, but also the time taken to load the model (kinda obvious, but it's a huge chunk of the time taken in the example scripts).

The micro-sam project should probably also consider whether the default model can be changed to a smaller model without a substantial decrease in segmentation quality. That's not entirely straightforward: (a) someone would need to assess the difference in segmentation quality, and (b) there are no vit_l_em or vit_l_lm finetuned models right now (if it turns out that vit_l is a good default choice of model, users will probably want to try finetuned versions).

Should we make an issue for this discussion? It seems like a good long term goal for the project, and aligned with some of the other current work around adding mobileSAM / fastSAM models too.

@constantinpape
Contributor Author

constantinpape commented Sep 19, 2023

> It's really striking how much faster the embeddings are with a smaller model. I should really consider using a smaller model wherever possible. The advantage is not just the embedding time, but also the time taken to load the model (kinda obvious, but it's a huge chunk of the time taken in the example scripts).
>
> The micro-sam project should probably also consider whether the default model can be changed to a smaller model without a substantial decrease in segmentation quality. That's not entirely straightforward: (a) someone would need to assess the difference in segmentation quality, and (b) there are no vit_l_em or vit_l_lm finetuned models right now (if it turns out that vit_l is a good default choice of model, users will probably want to try finetuned versions).
>
> Should we make an issue for this discussion? It seems like a good long term goal for the project, and aligned with some of the other current work around adding mobileSAM / fastSAM models too.

Yes on both fronts:

  1. this is a good long-term goal and is aligned with our current work around integrating mobileSAM / fastSAM.
  2. please open a separate issue on this, as it would be important to get external feedback! I can then elaborate on our current plans in more detail.

And good to see that we now have a consistent advantage for MPS here. I will wait for @Marei33 to run the benchmarks as well and then close the issue (unless we see something unexpected).

@Marei33

Marei33 commented Sep 19, 2023

Hi, here are the benchmarks on my Mac:

1. cpu, vit_b
Screenshot 2023-09-19 at 18 57 31

2. cpu, vit_h
Screenshot 2023-09-19 at 19 12 55

3. mps, vit_b
Screenshot 2023-09-19 at 19 30 34

4. mps, vit_h
Screenshot 2023-09-19 at 20 22 38

In both MPS cases, I got the warning that int64 is not supported and will be cast to int32.
Additionally, the script with the vit_h model on MPS didn't work on the first try, because I had some memory issues. When I closed all open windows, it was fine. This was the error message I got:
Screenshot 2023-09-19 at 20 06 53

@constantinpape
Contributor Author

Hi @Marei33,
thanks for running the benchmarks!
It looks like the AMG is affected a lot by memory in all the benchmarks. Could you rerun all four once again and make sure that nothing else is running? Thanks!

P.S.: you can copy-paste the text of the table here and it will be displayed as a nice table. (But you can also keep posting screenshots if you prefer ;)).

@Marei33

Marei33 commented Sep 20, 2023

Yes, sure. Here are the results with nothing else running. But they look quite similar.

| model | device | benchmark             |   runtime |
|:------|:-------|:----------------------|----------:|
| vit_b | cpu    | embeddings            | 17.1556   |
| vit_b | cpu    | prompt-p1n0           | 0.108205  |
| vit_b | cpu    | prompt-p2n4           | 0.107451  |
| vit_b | cpu    | prompt-box            | 0.105359  |
| vit_b | cpu    | prompt-box-and-points | 0.108766  |
| vit_b | cpu    | amg                   | 143.788   |
| vit_h | cpu    | embeddings            | 89.6328   |
| vit_h | cpu    | prompt-p1n0           | 0.111439  |
| vit_h | cpu    | prompt-p2n4           | 0.107213  |
| vit_h | cpu    | prompt-box            | 0.106269  |
| vit_h | cpu    | prompt-box-and-points | 0.113369  |
| vit_h | cpu    | amg                   | 145.32    |
| vit_b | mps    | embeddings            | 2.53891   |
| vit_b | mps    | prompt-p1n0           | 0.050036  |
| vit_b | mps    | prompt-p2n4           | 0.0502841 |
| vit_b | mps    | prompt-box            | 0.0446169 |
| vit_b | mps    | prompt-box-and-points | 0.0510437 |
| vit_b | mps    | amg                   | 261.172   |
| vit_h | mps    | embeddings            | 30.8518   |
| vit_h | mps    | prompt-p1n0           | 0.050679  |
| vit_h | mps    | prompt-p2n4           | 0.05003   |
| vit_h | mps    | prompt-box            | 0.038578  |
| vit_h | mps    | prompt-box-and-points | 0.0498359 |
| vit_h | mps    | amg                   | 77.7267   |

For the vit_h I had to run it three times, because there was still the memory error.

@constantinpape
Contributor Author

Ok, thanks for checking again!
We still see some memory issues with vit_b as well. It looks like it would still be good to decrease the memory requirement...

@GenevieveBuckley @Marei33 how much RAM do your Macs have?

@constantinpape
Contributor Author

I will go ahead and close this for now; I think we have done everything that is currently possible for MPS support. We will expose more control over the devices to users in any case (I will create a separate issue on this).

@GenevieveBuckley
Collaborator

I have a 16 GB MacBook Pro (from the first M1 generation).
