diff --git a/test/test_instance_segmentation.py b/test/test_instance_segmentation.py index 8ee81614..c942589e 100644 --- a/test/test_instance_segmentation.py +++ b/test/test_instance_segmentation.py @@ -119,6 +119,15 @@ def test_embedding_mask_generator(self): def test_tiled_embedding_mask_generator(self): from micro_sam.instance_segmentation import TiledEmbeddingMaskGenerator + # Release all unoccupied cached memory, tiling requires a lot of memory + device = util._get_device(None) + if device == "cuda": + import torch.cuda + torch.cuda.empty_cache() + elif device == "mps": + import torch.mps + torch.mps.empty_cache() + mask, image = self.large_mask, self.large_image predictor, image_embeddings = self.predictor, self.tiled_embeddings pred_iou_thresh, stability_score_thresh = 0.90, 0.60 @@ -144,6 +153,15 @@ def test_tiled_embedding_mask_generator(self): def test_tiled_automatic_mask_generator(self): from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation + # Release all unoccupied cached memory, tiling requires a lot of memory + device = util._get_device(None) + if device == "cuda": + import torch.cuda + torch.cuda.empty_cache() + elif device == "mps": + import torch.mps + torch.mps.empty_cache() + mask, image = self.large_mask, self.large_image predictor, image_embeddings = self.predictor, self.tiled_embeddings