@@ -75,6 +75,12 @@ public class Sam2 extends AbstractSamJ {
7575 + "from skimage import measure" + System .lineSeparator ()
7676 + "measure.label(np.ones((10, 10)), connectivity=1)" + System .lineSeparator ()
7777 + "import torch" + System .lineSeparator ()
78+ + "device = 'cpu'" + System .lineSeparator ()
79+ + ((!IS_APPLE_SILICON ) ? ""
80+ : "from torch.backends import mps" + System .lineSeparator ()
81+ + "if mps.is_built() and mps.is_available():" + System .lineSeparator ()
82+ + " device = 'mps'" + System .lineSeparator ())
83+ + "print(device)" + System .lineSeparator ()
7884 + "from scipy.ndimage import binary_fill_holes" + System .lineSeparator ()
7985 + "from scipy.ndimage import label" + System .lineSeparator ()
8086 + "import sys" + System .lineSeparator ()
@@ -83,7 +89,7 @@ public class Sam2 extends AbstractSamJ {
8389 + "from sam2.build_sam import build_sam2" + System .lineSeparator ()
8490 + "from sam2.sam2_image_predictor import SAM2ImagePredictor" + System .lineSeparator ()
8591 + "from sam2.utils.misc import variant_to_config_mapping" + System .lineSeparator ()
86- + "model = build_sam2(variant_to_config_mapping['%s'],r'%s')" + System .lineSeparator ()
92+ + "model = build_sam2(variant_to_config_mapping['%s'],r'%s').to(device) " + System .lineSeparator ()
8793 + "predictor = SAM2ImagePredictor(model)" + System .lineSeparator ()
8894 + "task.update('created predictor')" + System .lineSeparator ()
8995 + "encodings_map = {}" + System .lineSeparator ()
@@ -94,7 +100,8 @@ public class Sam2 extends AbstractSamJ {
94100 + "globals()['torch'] = torch" + System .lineSeparator ()
95101 + "globals()['label'] = label" + System .lineSeparator ()
96102 + "globals()['binary_fill_holes'] = binary_fill_holes" + System .lineSeparator ()
97- + "globals()['predictor'] = predictor" + System .lineSeparator ();
103+ + "globals()['predictor'] = predictor" + System .lineSeparator ()
104+ + "globals()['device'] = device" + System .lineSeparator ();
98105 /**
99106 * String containing the Python imports code after it has been formated with the correct
100107 * paths and names
0 commit comments