Add support for MPS device #176
Conversation
I hit an error involving the float type of the MPS tensor. I don't think we need float64 dtype anywhere. Perhaps it will be easiest to explicitly specify all tensors as float32?

```
$ python GitHub/micro-sam/examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/data
Using apple MPS device.
Precomputing the state for instance segmentation.
Predict masks for point grid prompts:   0%| | 0/16 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 75, in <module>
    main()
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 68, in main
    hela_2d_annotator(use_finetuned_model)
  File "/Users/genevieb/Documents/GitHub/micro-sam/examples/annotator_2d.py", line 37, in hela_2d_annotator
    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/sam_annotator/annotator_2d.py", line 237, in annotator_2d
    AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/precompute_state.py", line 55, in cache_amg_state
    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
  File "/Users/genevieb/mambaforge/envs/napari/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 451, in initialize
    crop_data = self._process_crop(
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 401, in _process_crop
    batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
  File "/Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py", line 367, in _process_batch
    in_points = torch.as_tensor(transformed_points, device=self._predictor.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
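A minimal sketch of the workaround suggested above. The variable names mirror the failing call in `_process_batch`, but this is an illustration, not the project's actual fix: NumPy arrays default to float64, so converting one to a tensor on a device without float64 support fails unless the dtype is specified explicitly.

```python
import numpy as np
import torch

# Point coordinates built with numpy default to float64,
# which the MPS backend rejects.
transformed_points = np.array([[10.5, 20.5], [30.0, 40.0]])
assert transformed_points.dtype == np.float64

# Fall back to CPU where MPS is unavailable, so the sketch runs anywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Explicitly requesting float32 avoids the float64 -> MPS conversion error.
in_points = torch.as_tensor(transformed_points, dtype=torch.float32, device=device)
print(in_points.dtype)  # torch.float32
```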
Update: I've made a PR to this branch, addressing some of this problem: #178
Hi @GenevieveBuckley,
Supporting Apple Silicon devices
Ok, a quick update:
@GenevieveBuckley I looked a bit more into the dtype issues and:
- This should only be a problem in the automatic mask generation, because for all other computations the dtypes are already cast to float32.
- I have narrowed down the change you made that causes the test to fail. Once you paste the error you get without it, I will think about what to do. See the comment in the code for details.
micro_sam/instance_segmentation.py (outdated)

```diff
@@ -269,6 +269,7 @@ def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
         # threshold masks and calculate boxes
         data["masks"] = data["masks"] > self._predictor.model.mask_threshold
+        data["masks"] = data["masks"].type(torch.int)
```
I have traced down the failing test to this line. It seems like the bounding boxes computed by `batched_mask_to_box` are different when a `bool` tensor is passed than when an `int` tensor is passed. This is unfortunate, but it would be a bit of effort to fix since this is part of upstream code.

@GenevieveBuckley I assume that you have made this change because something fails with `mps` if a bool tensor is used. Can you please send me the exact error message you get without this line? Then I will think about the best strategy here.
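One plausible source of the discrepancy (an assumption on my part, not confirmed in this thread) is that `~` behaves differently on integer and boolean tensors: on a bool tensor it is a logical NOT, but on an int tensor it is a bitwise NOT, so a mask value of 0 becomes -1 rather than True.

```python
import torch

bool_mask = torch.tensor([True, False, True])
int_mask = bool_mask.type(torch.int)  # tensor([1, 0, 1])

# Logical NOT on the bool tensor:
print(~bool_mask)  # tensor([False,  True, False])

# Bitwise NOT on the int tensor gives very different values (~x == -x - 1):
print(~int_mask)   # tensor([-2, -1, -2])
```

If the upstream box computation uses an expression like `~mask` on what it assumes is a boolean tensor, an int mask would silently produce different coordinates rather than raising an error.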
1. Wow, I did not know boolean masks returned correct results, but integer masks return something totally different. Yikes! I have changed the type of this line to `torch.bool`, thank you for pointing it out.
2. I have also gone ahead with my first suggestion here about how to handle the problem with `batched_mask_to_box()`. I have made a new file `_vendored.py` containing a copy of `batched_mask_to_box()` from `segment_anything/util/amg.py`, which I have edited to (a) make sure the input mask is boolean, and (b) make it compatible with the MPS PyTorch backend for Apple silicon.

I think this should fix the problems.
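A sketch of the first of those two edits, the dtype guard that coerces the input to boolean before any box computation. This is illustrative only: the function name and the single-mask box logic here are simplified stand-ins, not the vendored code, which handles batched input and empty masks.

```python
import torch

def mask_to_box(mask: torch.Tensor) -> torch.Tensor:
    """Illustrative only: XYXY bounding box of a single 2D mask.

    The point is the guard at the top, which makes the result independent
    of whether the caller passes a bool, int, or float mask.
    """
    if mask.dtype != torch.bool:  # coerce int/float masks to boolean
        mask = mask > 0
    ys, xs = torch.where(mask)
    return torch.stack([xs.min(), ys.min(), xs.max(), ys.max()])

mask = torch.zeros(8, 8, dtype=torch.int)
mask[2:5, 3:7] = 1
print(mask_to_box(mask))  # tensor([3, 2, 6, 4])
```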
> 1. Wow, I did not know boolean masks returned correct results, but integer masks return something totally different. Yikes! I have changed the type of this line to `torch.bool`, thank you for pointing it out.

Yes, it is really unfortunate... I think this is a pretty big bug on the torch / SegmentAnything side. It could be worth it to figure out what exactly causes it and report that at some point.

> 2. I have also gone ahead with my first suggestion here about how to handle the problem with `batched_mask_to_box()`. I have made a new file `_vendored.py` containing a copy of `batched_mask_to_box()` from `segment_anything/util/amg.py`, which I have edited to (a) make sure the input mask is boolean, and (b) make it compatible with the MPS PyTorch backend for Apple silicon. I think this should fix the problems.

👍
Ah, I see you made an issue already :D
facebookresearch/segment-anything#552
@constantinpape I have made another PR to your branch. If you can review that, I think it will solve a lot of the discussion here 😄
I have bad news though: somehow the mps branch runs three and a half times SLOWER than the dev branch (with the CPU PyTorch backend) when predicting masks for point grid prompts: https://gist.github.com/GenevieveBuckley/874b6b282b388524b6d62f25b3b9bb1c

I don't quite understand how it is possible for me to get that result. I am worried about this: we definitely don't want to automatically select the mps backend for the user if it is going to be much slower. I hope I've done something wrong, but I don't think so.

It might be that the fallback to CPU we have, because torchvision.ops.nms does not support MPS backends, is the reason. Could it be time-consuming to move large tensors from the MPS backend to the CPU and back to MPS again? I think we need to do some line profiling to figure out where the slow parts are. Perhaps you and @Marei33 can try some benchmarks yourself to compare?
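To test the device-transfer hypothesis, a rough micro-benchmark along these lines could be used. This is a sketch, not the profiling actually run in this thread; the tensor size and iteration count are arbitrary, and it falls back to CPU when MPS is absent, in which case the round trip is essentially a no-op.

```python
import time
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.rand(2048, 2048, device=device)

def roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Simulate the nms fallback: device -> cpu -> device.
    return t.cpu().to(device)

start = time.perf_counter()
for _ in range(20):
    y = roundtrip(x)
elapsed = time.perf_counter() - start
print(f"20 round trips on {device}: {elapsed:.4f} s")
```

If the round-trip time is a large fraction of the per-batch prediction time, the cpu fallback for `torchvision.ops.nms` would indeed be a plausible culprit.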
This is probably not correct. I've just tried removing the
Add MPS Pytorch support for batched_mask_to_box function
Thanks for all the work on this @GenevieveBuckley! Regarding the next steps:
Should provide a significant speed up on MAC notebooks with M1 or M2.
@Marei33: could you check out this PR at some point and see if `micro_sam` still works, and if it runs faster?