From acc71d93fa032c505e6427dd3629656236be64fc Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 18 Nov 2023 22:32:57 +0100 Subject: [PATCH 1/2] Add tests for get model and iterative prediction --- micro_sam/evaluation/__init__.py | 1 + micro_sam/util.py | 10 +++++----- test/test_training.py | 30 +++++++++++++++++++++++------- test/test_util.py | 20 ++++++++++++++++++++ 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/micro_sam/evaluation/__init__.py b/micro_sam/evaluation/__init__.py index a63d8597..cf68ee35 100644 --- a/micro_sam/evaluation/__init__.py +++ b/micro_sam/evaluation/__init__.py @@ -9,6 +9,7 @@ from .evaluation import run_evaluation from .inference import ( get_predictor, + run_inference_with_iterative_prompting, run_inference_with_prompts, precompute_all_embeddings, precompute_all_prompts, diff --git a/micro_sam/util.py b/micro_sam/util.py index c769eecc..50421b0a 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -49,8 +49,7 @@ "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", } -_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') -_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models') + _CHECKSUMS = { # the default segment anything models "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", @@ -87,7 +86,7 @@ def get_cache_directory() -> None: Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. """ - default_cache_directory = os.path.expanduser(pooch.os_cache('micro-sam')) + default_cache_directory = os.path.expanduser(pooch.os_cache('micro_sam')) cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory)) return cache_directory @@ -127,11 +126,12 @@ def _get_checkpoint(model_type, checkpoint_path=None): if checkpoint_path is None: checkpoint_url = _MODEL_URLS[model_type] checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1]) - checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name) + checkpoint_folder = os.path.join(get_cache_directory(), "models") + checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name) # download the checkpoint if necessary if not os.path.exists(checkpoint_path): - os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True) + os.makedirs(checkpoint_folder, exist_ok=True) _download(checkpoint_url, checkpoint_path, model_type) elif not os.path.exists(checkpoint_path): raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") diff --git a/test/test_training.py b/test/test_training.py index 60a59903..df5124f6 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -9,7 +9,7 @@ import torch_em from micro_sam.sample_data import synthetic_data -from micro_sam.util import VIT_T_SUPPORT +from micro_sam.util import VIT_T_SUPPORT, get_custom_sam_model, SamPredictor @unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.") @@ -133,6 +133,9 @@ def _run_inference_and_check_results( inference_function(predictor, image_paths, label_paths, embedding_dir, prediction_dir) pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif"))) + if len(pred_paths) == 0: # we need to go to subfolder for iterative inference + pred_paths = sorted(glob(os.path.join(prediction_dir, "iteration02", "*.tif"))) + self.assertEqual(len(pred_paths), len(label_paths)) eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False) result = eval_res["sa50"].values.item() @@ -150,6 +153,10 @@ def test_training(self): checkpoint_path = os.path.join(self.tmp_folder, "checkpoints", "test", "best.pt") self.assertTrue(os.path.exists(checkpoint_path)) + # Check that the model can be loaded from a custom checkpoint. + predictor = get_custom_sam_model(checkpoint_path, model_type=model_type, device=device) + self.assertTrue(isinstance(predictor, SamPredictor)) + # Export the model. export_path = os.path.join(self.tmp_folder, "exported_model.pth") self._export_model(checkpoint_path, export_path, model_type) @@ -157,7 +164,7 @@ def test_training(self): # Check the model with inference with a single point prompt. prediction_dir = os.path.join(self.tmp_folder, "predictions-points") - normal_inference = partial( + point_inference = partial( evaluation.run_inference_with_prompts, use_points=True, use_boxes=False, n_positives=1, n_negatives=0, @@ -165,12 +172,12 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=normal_inference, expected_sa=0.9 + inference_function=point_inference, expected_sa=0.9 ) # Check the model with inference with a box point prompt. prediction_dir = os.path.join(self.tmp_folder, "predictions-boxes") - normal_inference = partial( + box_inference = partial( evaluation.run_inference_with_prompts, use_points=False, use_boxes=True, n_positives=1, n_negatives=0, @@ -178,11 +185,20 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=normal_inference, expected_sa=0.95, + inference_function=box_inference, expected_sa=0.95, ) - # Check the model with interactive inference - # TODO + # Check the model with interactive inference. + prediction_dir = os.path.join(self.tmp_folder, "predictions-iterative") + iterative_inference = partial( + evaluation.run_inference_with_iterative_prompting, + start_with_box_prompt=False, + n_iterations=3, + ) + self._run_inference_and_check_results( + export_path, model_type, prediction_dir=prediction_dir, + inference_function=iterative_inference, expected_sa=0.95, + ) if __name__ == "__main__": diff --git a/test/test_util.py b/test/test_util.py index 192d14fb..505d0208 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -8,9 +8,11 @@ from skimage.data import binary_blobs from skimage.measure import label +from micro_sam.util import VIT_T_SUPPORT, SamPredictor, get_cache_directory class TestUtil(unittest.TestCase): + model_type = "vit_t" if VIT_T_SUPPORT else "vit_b" tmp_folder = "tmp-files" def setUp(self): @@ -19,6 +21,24 @@ def setUp(self): def tearDown(self): rmtree(self.tmp_folder) + def test_get_sam_model(self): + from micro_sam.util import get_sam_model + + def check_predictor(predictor): + self.assertTrue(isinstance(predictor, SamPredictor)) + self.assertEqual(predictor.model_type, self.model_type) + + # check predictor with download + predictor = get_sam_model(model_type=self.model_type) + check_predictor(predictor) + + # check predictor with checkpoint path (using the cached model) + checkpoint_path = os.path.join( + get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth" + ) + predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path) + check_predictor(predictor) + def test_compute_iou(self): from micro_sam.util import compute_iou From c38a3f3b2c7ea2f502cbd44e96d5d637eebb11b1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 20 Nov 2023 09:23:22 +0100 Subject: [PATCH 2/2] Update codecov command --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1ce35bd9..38c4ba01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=42.0.0", "wheel"] build-backend = "setuptools.build_meta" [tool.pytest.ini_options] -addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml" +addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml --cov-report term-missing" markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "gui: marks GUI tests (deselect with '-m \"not gui\"')",