Add support for MPS device #176

Merged: 11 commits merged into dev on Sep 7, 2023
Conversation

constantinpape (Contributor)

This should provide a significant speed-up on Mac notebooks with M1 or M2 chips.

@Marei33: could you check out this PR at some point and see if micro_sam still works, and if it runs faster?
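For context, device selection for Apple's MPS backend typically follows the pattern below; this is a generic sketch (the helper name get_device and the exact order of checks are my assumptions, not necessarily what micro_sam does):

```python
# Generic sketch of MPS-aware device selection; not the exact micro_sam code.
import torch

def get_device() -> torch.device:  # hypothetical helper name
    if torch.cuda.is_available():
        return torch.device("cuda")
    # MPS is available on Apple Silicon (M1/M2) with a recent PyTorch build.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")
```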

GenevieveBuckley (Collaborator)

I hit an error involving the float dtype of the MPS tensor.

I don't think we need float64 dtype anywhere. Perhaps it will be easiest to explicitly specify all tensors as float32?

python GitHub/micro-sam/examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/data
Using apple MPS device.
Precomputing the state for instance segmentation.
Predict masks for point grid prompts:   0%|              | 0/16 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 75, in <module>
    main()
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 68, in main
    hela_2d_annotator(use_finetuned_model)
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 37, in hela_2d_annotator
    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/sam_annotator/annotator_2d.py", line 237, in annotator_2d
    AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/precompute_state.py", line 55, in cache_amg_state
    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
  File "/Users/genevieb/mambaforge/envs/napari/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 451, in initialize
    crop_data = self._process_crop(
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 401, in _process_crop
    batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 367, in _process_batch
    in_points = torch.as_tensor(transformed_points, device=self._predictor.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
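The error above happens because MPS has no float64 support, so float64 numpy input has to be downcast before it is turned into a tensor on the device. A minimal sketch of that kind of fix (to_device_tensor is a hypothetical helper, not the change that actually landed):

```python
# Hypothetical helper illustrating the fix: downcast float64 inputs to float32
# before creating a tensor on a device that may be MPS.
import numpy as np
import torch

def to_device_tensor(points: np.ndarray, device: torch.device) -> torch.Tensor:
    if points.dtype == np.float64:
        points = points.astype(np.float32)  # MPS does not support float64
    return torch.as_tensor(points, device=device)
```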

GenevieveBuckley (Collaborator)

Update: I've made a PR to this branch that addresses some of this problem: #178

constantinpape (Contributor, Author)

Hi @GenevieveBuckley,
thanks for looking into this! Just a quick comment: indeed there is no reason to use float64 anywhere, and it is harmful for computational performance without giving any quality benefits. It's just an oversight that we don't already cast to float32 explicitly. If you want you can add this to #178, otherwise I can look into it later as well.

constantinpape (Contributor, Author)

Ok, a quick update:

  • I have merged Supporting Apple Silicon devices #178. This leads to reproducible test errors in the instance segmentation, which points to the fact that these tests are brittle; I will investigate this later.
  • I think the issue of tensors being represented as torch.float64 when the input data is given as float64 is still present and should be fixed. @GenevieveBuckley I will also look into this later and let you know what I find.

constantinpape (Contributor, Author) left a comment

@GenevieveBuckley I looked a bit more into the dtype issues:

  • This should only be a problem in the automatic mask generation, because for the other computations the dtypes are always cast to float32.
  • I have narrowed down the change you made that causes the test to fail. I will think about what to do once you paste the error you get without it. See the comment in the code for details.

@@ -269,6 +269,7 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
     # threshold masks and calculate boxes
     data["masks"] = data["masks"] > self._predictor.model.mask_threshold
+    data["masks"] = data["masks"].type(torch.int)
constantinpape (Contributor, Author) commented on Sep 6, 2023

I have traced the failing test down to this line. It seems like the bounding boxes computed by batched_mask_to_box are different when a bool tensor is passed than when an int tensor is passed. This is unfortunate, but it would be a bit of effort to fix since this is part of the upstream code.
@GenevieveBuckley I assume you made this change because something fails with MPS if a bool tensor is used. Can you please send me the exact error message you get without this line? Then I will think about the best strategy here.
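A minimal check along these lines (my own sketch, assuming the upstream helper is importable from segment_anything.utils.amg) would compare the boxes returned for a boolean mask and for the same mask cast to int:

```python
# Sketch of a minimal comparison; not part of the PR. If the two results
# differ, the dtype sensitivity of batched_mask_to_box is confirmed.
import torch
from segment_anything.utils.amg import batched_mask_to_box

mask = torch.zeros(1, 32, 32, dtype=torch.bool)
mask[0, 8:16, 4:20] = True

boxes_from_bool = batched_mask_to_box(mask)
boxes_from_int = batched_mask_to_box(mask.to(torch.int))
print(boxes_from_bool)
print(boxes_from_int)
```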

GenevieveBuckley (Collaborator)

  1. Wow, I did not know that boolean masks return correct results but integer masks return something totally different. Yikes! I have changed the dtype on this line to torch.bool, thank you for pointing it out.

  2. I have also gone ahead with my first suggestion here about how to handle the problem with batched_mask_to_box(). I have made a new file _vendored.py containing a copy of batched_mask_to_box() from segment_anything/utils/amg.py, which I have edited to (a) make sure the input mask is boolean, and (b) make it compatible with the MPS PyTorch backend for Apple Silicon.

I think this should fix the problems.
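The guard described in point 2 could look roughly like the wrapper below; this is a sketch of the idea only (the real _vendored.py copies and edits the upstream function body rather than wrapping it):

```python
# Sketch only: force a boolean mask and, for tensors on MPS, run the upstream
# box computation on the CPU and move the result back to the original device.
import torch
from segment_anything.utils.amg import batched_mask_to_box as _upstream_batched_mask_to_box

def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    masks = masks.to(torch.bool)  # (a) make sure the input mask is boolean
    if masks.device.type == "mps":
        # (b) sidestep ops that are not (fully) supported on the MPS backend
        return _upstream_batched_mask_to_box(masks.cpu()).to(masks.device)
    return _upstream_batched_mask_to_box(masks)
```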

constantinpape (Contributor, Author)

  1. Wow, I did not know that boolean masks return correct results but integer masks return something totally different. Yikes! I have changed the dtype on this line to torch.bool, thank you for pointing it out.

Yes, it is really unfortunate... I think this is a pretty big bug on the torch / SegmentAnything side. It could be worth figuring out what exactly causes it and reporting it at some point.

  2. I have also gone ahead with my first suggestion here about how to handle the problem with batched_mask_to_box(). I have made a new file _vendored.py containing a copy of batched_mask_to_box() from segment_anything/utils/amg.py, which I have edited to (a) make sure the input mask is boolean, and (b) make it compatible with the MPS PyTorch backend for Apple Silicon.

I think this should fix the problems.

👍

constantinpape (Contributor, Author)

Ah, I see you made an issue already :D
facebookresearch/segment-anything#552

GenevieveBuckley (Collaborator)

@constantinpape I have made another PR to your mps branch: #180

If you can review that, I think it will resolve a lot of the discussion here 😄

GenevieveBuckley (Collaborator) commented on Sep 7, 2023

I have bad news though: somehow the mps branch runs three and a half times SLOWER than the dev branch (with the CPU PyTorch backend) when predicting masks for point grid prompts: https://gist.github.com/GenevieveBuckley/874b6b282b388524b6d62f25b3b9bb1c

I don't quite understand how it is possible for me to get that result. I am worried about this: we definitely don't want to automatically select the mps backend for the user if it is going to be much slower. I hope I've done something wrong, but I don't think so.

Could the fallback to CPU that we have because torchvision.ops.nms does not support the MPS backend be the reason? Could it be time-consuming to move large tensors from the MPS backend to the CPU and back to MPS again?

I think we need to do some line profiling to figure out where the slow parts are.

Perhaps you and @Marei33 can try some benchmarks yourself to compare?

GenevieveBuckley (Collaborator)

Could the fallback to CPU that we have because torchvision.ops.nms does not support the MPS backend be the reason? Could it be time-consuming to move large tensors from the MPS backend to the CPU and back to MPS again?

This is probably not correct. I've just tried removing the PYTORCH_ENABLE_MPS_FALLBACK environment variable, and everything still runs? Maybe the fallback is never actually triggered.
... It's late, I'm going to stop for now
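For reference, the CPU fallback is controlled by an environment variable that, as far as I understand, is read when torch is imported, so it has to be set beforehand. A minimal sketch:

```python
# Sketch: enabling the MPS CPU fallback. The variable appears to be read at
# torch import time (my assumption), so set it first, e.g. via the shell
# (export PYTORCH_ENABLE_MPS_FALLBACK=1) or at the very top of the script.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # imported only after the environment variable is set
```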

constantinpape (Contributor, Author)

Thanks for all the work on this @GenevieveBuckley! Regarding the next steps:

  • If the tests pass I will go ahead and merge this now.
  • I will create some benchmark scripts that we can use to compare runtimes of MPS vs CPU in different settings, which you and @Marei33 can then run to see what's going on with performance; a rough sketch of such a comparison follows below. (@Marei33 I will ping you once this is ready and try to explain it clearly ;).)
  • Based on this we will decide what the default should be if MPS is available. (For now we merge onto dev anyway, so this will not affect any users.)
  • I will also create a follow-up issue for all MPS related stuff.
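A rough sketch of the kind of MPS vs CPU timing comparison meant above (a toy matmul workload, not the actual micro_sam benchmark scripts):

```python
# Toy benchmark sketch: times a matrix multiply on CPU and (if available) MPS.
# The real benchmarks would time SAM embedding / mask prediction instead.
import time
import torch

def time_matmul(device: torch.device, size: int = 2048, repeats: int = 10) -> float:
    x = torch.randn(size, size, device=device)
    (x @ x).sum().item()  # warm-up, also forces synchronization
    start = time.perf_counter()
    for _ in range(repeats):
        (x @ x).sum().item()  # .item() synchronizes the device
    return (time.perf_counter() - start) / repeats

print("cpu:", time_matmul(torch.device("cpu")))
if torch.backends.mps.is_available():
    print("mps:", time_matmul(torch.device("mps")))
```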

constantinpape merged commit 80d9e42 into dev on Sep 7, 2023 (2 checks passed)
constantinpape deleted the mps branch on September 7, 2023 at 10:31