Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase test coverage #280

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions micro_sam/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .evaluation import run_evaluation
from .inference import (
get_predictor,
run_inference_with_iterative_prompting,
run_inference_with_prompts,
precompute_all_embeddings,
precompute_all_prompts,
Expand Down
10 changes: 5 additions & 5 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
}
_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models')

_CHECKSUMS = {
# the default segment anything models
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
Expand Down Expand Up @@ -87,7 +86,7 @@ def get_cache_directory() -> None:

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
"""
default_cache_directory = os.path.expanduser(pooch.os_cache('micro-sam'))
default_cache_directory = os.path.expanduser(pooch.os_cache('micro_sam'))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the cachedir was different from the previous one in _CHECKPOINT_FOLDER (see above).
I think micro_sam is better than micro-sam (because it's the name of the actual python package) but if there is a reason to switch to micro-sam instead I would also be fine with it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind either, so that sounds fine. I don't have any reason for it. I think it's also a good idea to double check there's no remaining micro-sam left anywhere else in the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I quickly checked and this was the only location in the code.
(It's still a bit inconsistent in overall naming, but we can address that later on.)

cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory))
return cache_directory

Expand Down Expand Up @@ -127,11 +126,12 @@ def _get_checkpoint(model_type, checkpoint_path=None):
if checkpoint_path is None:
checkpoint_url = _MODEL_URLS[model_type]
checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1])
checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name)
checkpoint_folder = os.path.join(get_cache_directory(), "models")
checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name)

# download the checkpoint if necessary
if not os.path.exists(checkpoint_path):
os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True)
os.makedirs(checkpoint_folder, exist_ok=True)
_download(checkpoint_url, checkpoint_path, model_type)
elif not os.path.exists(checkpoint_path):
raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
Expand Down
30 changes: 23 additions & 7 deletions test/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch_em

from micro_sam.sample_data import synthetic_data
from micro_sam.util import VIT_T_SUPPORT
from micro_sam.util import VIT_T_SUPPORT, get_custom_sam_model, SamPredictor


@unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.")
Expand Down Expand Up @@ -133,6 +133,9 @@ def _run_inference_and_check_results(
inference_function(predictor, image_paths, label_paths, embedding_dir, prediction_dir)

pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
if len(pred_paths) == 0: # we need to go to subfolder for iterative inference
pred_paths = sorted(glob(os.path.join(prediction_dir, "iteration02", "*.tif")))

self.assertEqual(len(pred_paths), len(label_paths))
eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False)
result = eval_res["sa50"].values.item()
Expand All @@ -150,39 +153,52 @@ def test_training(self):
checkpoint_path = os.path.join(self.tmp_folder, "checkpoints", "test", "best.pt")
self.assertTrue(os.path.exists(checkpoint_path))

# Check that the model can be loaded from a custom checkpoint.
predictor = get_custom_sam_model(checkpoint_path, model_type=model_type, device=device)
self.assertTrue(isinstance(predictor, SamPredictor))

# Export the model.
export_path = os.path.join(self.tmp_folder, "exported_model.pth")
self._export_model(checkpoint_path, export_path, model_type)
self.assertTrue(os.path.exists(export_path))

# Check the model with inference with a single point prompt.
prediction_dir = os.path.join(self.tmp_folder, "predictions-points")
normal_inference = partial(
point_inference = partial(
evaluation.run_inference_with_prompts,
use_points=True, use_boxes=False,
n_positives=1, n_negatives=0,
batch_size=64,
)
self._run_inference_and_check_results(
export_path, model_type, prediction_dir=prediction_dir,
inference_function=normal_inference, expected_sa=0.9
inference_function=point_inference, expected_sa=0.9
)

# Check the model with inference with a box point prompt.
prediction_dir = os.path.join(self.tmp_folder, "predictions-boxes")
normal_inference = partial(
box_inference = partial(
evaluation.run_inference_with_prompts,
use_points=False, use_boxes=True,
n_positives=1, n_negatives=0,
batch_size=64,
)
self._run_inference_and_check_results(
export_path, model_type, prediction_dir=prediction_dir,
inference_function=normal_inference, expected_sa=0.95,
inference_function=box_inference, expected_sa=0.95,
)

# Check the model with interactive inference
# TODO
# Check the model with interactive inference.
prediction_dir = os.path.join(self.tmp_folder, "predictions-iterative")
iterative_inference = partial(
evaluation.run_inference_with_iterative_prompting,
start_with_box_prompt=False,
n_iterations=3,
)
self._run_inference_and_check_results(
export_path, model_type, prediction_dir=prediction_dir,
inference_function=iterative_inference, expected_sa=0.95,
)


if __name__ == "__main__":
Expand Down
20 changes: 20 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from skimage.data import binary_blobs
from skimage.measure import label
from micro_sam.util import VIT_T_SUPPORT, SamPredictor, get_cache_directory


class TestUtil(unittest.TestCase):
model_type = "vit_t" if VIT_T_SUPPORT else "vit_b"
tmp_folder = "tmp-files"

def setUp(self):
Expand All @@ -19,6 +21,24 @@ def setUp(self):
def tearDown(self):
rmtree(self.tmp_folder)

def test_get_sam_model(self):
from micro_sam.util import get_sam_model

def check_predictor(predictor):
self.assertTrue(isinstance(predictor, SamPredictor))
self.assertEqual(predictor.model_type, self.model_type)

# check predictor with download
predictor = get_sam_model(model_type=self.model_type)
check_predictor(predictor)

# check predictor with checkpoint path (using the cached model)
checkpoint_path = os.path.join(
get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth"
)
predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path)
check_predictor(predictor)

def test_compute_iou(self):
from micro_sam.util import compute_iou

Expand Down