Skip to content

Commit

Permalink
Upstream dev with master (#799)
Browse files Browse the repository at this point in the history
Important changes for the automatic segmentation CLI and misc evaluation script changes
---------

Co-authored-by: Anwai Archit <[email protected]>
Co-authored-by: Constantin Pape <[email protected]>
Co-authored-by: Parsa744 <[email protected]>
  • Loading branch information
4 people authored Nov 25, 2024
1 parent 8560a65 commit 5feed91
Show file tree
Hide file tree
Showing 24 changed files with 1,002 additions and 587 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ Compared to these we support more applications (2d, 3d and tracking), and provid

## Release Overview

**New in version 1.1.1**

Fixing minor issues with 1.1.0 and enabling pytorch 2.5 support.

**New in version 1.1.0**

This version introduces several improvements:

- Bugfixes and several minor improvements
- Compatibility with napari >=0.5
- Automatic instance segmentation CLI
- Initial support for parameter efficient fine-tuning and automatic semantic segmentation in 2d and 3d (not available in napari plugin, part of the python library)

**New in version 1.0.1**

Use stable URL for model downloads and fix issues in state precomputation for automatic segmentation.
Expand Down
24 changes: 24 additions & 0 deletions doc/cli_tools.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Using the Command Line Interface (CLI)

`micro-sam` extends access to a bunch of functionalities using the command line interface (CLI) scripts via terminal.

The supported CLIs can be used by
- Running `$ micro_sam.precompute_embeddings` for precomputing and caching the image embeddings.
- Running `$ micro_sam.annotator_2d` for starting the 2d annotator.
- Running `$ micro_sam.annotator_3d` for starting the 3d annotator.
- Running `$ micro_sam.annotator_tracking` for starting the tracking annotator.
- Running `$ micro_sam.image_series_annotator` for starting the image series annotator.
- Running `$ micro_sam.automatic_segmentation` for automatic instance segmentation.
- We support all post-processing parameters for automatic instance segmentation (for both AMG and AIS).
- The automatic segmentation mode can be controlled by: `--mode <MODE_NAME>`, where the available choice for `MODE_NAME` is `amg` / `ais`.
- AMG is supported by both default Segment Anything models and `micro-sam` models / finetuned models.
- AIS is supported by `micro-sam` models (or finetuned models; subjected to they are trained with the additional instance segmentation decoder)
- If these parameters are not provided by the user, `micro-sam` makes use of the best post-processing parameters (depending on the choice of model).
- The post-processing parameters can be changed by parsing the parameters via the CLI using `--<PARAMETER_NAME> <VALUE>.` For example, one can update the parameter values (eg. `pred_iou_thresh`, `stability_iou_thresh`, etc. - supported by AMG) using
```bash
$ micro_sam.automatic_segmentation ... --pred_iou_thresh 0.6 --stability_iou_thresh 0.6 ...
```
- Remember to specify the automatic segmentation mode using `--mode <MODE_NAME>` when using additional post-processing parameters.
- You can check details for supported parameters and their respective default values at `micro_sam/instance_segmentation.py` under the `generate` method for `AutomaticMaskGenerator` and `InstanceSegmentationWithDecoder` class.

NOTE: For all CLIs above, you can find more details by adding the argument `-h` to the CLI script (eg. `$ micro_sam.annotator_2d -h`).
21 changes: 21 additions & 0 deletions doc/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ The PyPI page for `micro-sam` exists only so that the [napari-hub](https://www.n
### 7. I get the following error: `importError: cannot import name 'UNETR' from 'torch_em.model'`.
It's possible that you have an older version of `torch-em` installed. Similar errors could often be raised from other libraries, the reasons being: a) Outdated packages installed, or b) Some non-existent module being called. If the source of such error is from `micro_sam`, then `a)` is most likely the reason . We recommend installing the latest version following the [installation instructions](https://github.com/constantinpape/torch-em?tab=readme-ov-file#installation).

### 8. My system does not support internet connection. Where should I put the model checkpoints for the `micro-sam` models?
We recommend transferring the model checkpoints to the system-level cache directory (you can find yours by running the following in terminal: `python -c "from micro_sam import util; print(util.microsam_cachedir())`). Once you have identified the cache directory, you need to create an additional `models` directory inside the `micro-sam` cache directory (if not present already) and move the model checkpoints there. At last, you **must** rename the transferred checkpoints as per the respective [key values](https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/util.py#L87) in the url dictionaries located in the `micro_sam.util.models` function (below mentioned is an example for Linux users).

```bash
# Download and transfer the model checkpoints for 'vit_b_lm' and `vit_b_lm_decoder`.
# Next, verify the cache directory.
> python -c "from micro_sam import util; print(util.microsam_cachedir())"
/home/anwai/.cache/micro_sam

# Create 'models' folder in the cache directory
> mkdir /home/anwai/.cache/micro_sam/models

# Move the checkpoints to the models directory and rename them
# The following steps transfer and rename the checkpoints to the desired filenames.
> mv vit_b.pt /home/anwai/.cache/micro_sam/models/vit_b_lm
> mv vit_b_decoder.pt /home/anwai/.cache/micro_sam/models/vit_b_lm_decoder
```

## Usage questions

Expand Down Expand Up @@ -141,6 +158,10 @@ The `micro-sam` CLIs for precomputation of image embeddings and annotators (Anno
NOTE: It is important to choose the correct model type when you opt for the above recommendation, using the `-m / --model_type` argument or selecting it from the `Model` dropdown in `Embedding Settings` respectively. Otherwise you will face parameter mismatch issues.


### 16. Some parameters in the annotator / finetuning widget are unclear to me.
`micro-sam` has tooltips for menu options across all widgets (i.e. an information window will appear if you hover over name of the menu), which briefly describe the utility of the specific menu option.


## Fine-tuning questions


Expand Down
11 changes: 6 additions & 5 deletions examples/annotator_2d.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os

import imageio.v3 as imageio
from micro_sam.util import get_cache_directory
from micro_sam.sam_annotator import annotator_2d
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data
from micro_sam.util import get_cache_directory


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def livecell_annotator(use_finetuned_model):
"""Run the 2d annotator for an example image from the LiveCELL dataset.
"""Run the 2d annotator for an example image from the LIVEcell dataset.
See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
"""
Expand All @@ -29,7 +30,7 @@ def livecell_annotator(use_finetuned_model):


def hela_2d_annotator(use_finetuned_model):
"""Run the 2d annotator for an example image form the cell tracking challenge HeLa 2d dataset.
"""Run the 2d annotator for an example image from the Cell Tracking Challenge (HeLa 2d) dataset.
"""
example_data = fetch_hela_2d_example_data(DATA_CACHE)
image = imageio.imread(example_data)
Expand All @@ -46,7 +47,7 @@ def hela_2d_annotator(use_finetuned_model):

def wholeslide_annotator(use_finetuned_model):
"""Run the 2d annotator with tiling for an example whole-slide image from the
NeuRIPS cell segmentation challenge.
NeurIPS Cell Segmentation challenge.
See https://neurips22-cellseg.grand-challenge.org/ for details on the data.
"""
Expand Down Expand Up @@ -79,6 +80,6 @@ def main():

# The corresponding CLI call for hela_2d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr
# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr # noqa
if __name__ == "__main__":
main()
150 changes: 150 additions & 0 deletions examples/automatic_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os

import imageio.v3 as imageio

from micro_sam.util import get_cache_directory
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")


def livecell_automatic_segmentation(model_type, use_amg, generate_kwargs):
"""Run the automatic segmentation for an example image from the LIVECell dataset.
See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
"""
example_data = fetch_livecell_example_data(DATA_CACHE)
image = imageio.imread(example_data)

predictor, segmenter = get_predictor_and_segmenter(
model_type=model_type,
checkpoint=None, # Replace this with your custom checkpoint.
amg=use_amg,
is_tiled=False, # Switch to 'True' in case you would like to perform tiling-window based prediction.
)

segmentation = automatic_instance_segmentation(
predictor=predictor,
segmenter=segmenter,
input_path=image,
ndim=2,
tile_shape=None, # If you set 'is_tiled' in 'get_predictor_and_segmeter' to True, set a tile shape
halo=None, # If you set 'is_tiled' in 'get_predictor_and_segmeter' to True, set a halo shape.
**generate_kwargs
)

import napari
v = napari.Viewer()
v.add_image(image)
v.add_labels(segmentation)
napari.run()


def hela_automatic_segmentation(model_type, use_amg, generate_kwargs):
"""Run the automatic segmentation for an example image from the Cell Tracking Challenge (HeLa 2d) dataset.
"""
example_data = fetch_hela_2d_example_data(DATA_CACHE)
image = imageio.imread(example_data)

predictor, segmenter = get_predictor_and_segmenter(
model_type=model_type,
checkpoint=None, # Replace this with your custom checkpoint.
amg=use_amg,
is_tiled=False, # Switch to 'True' in case you would like to perform tiling-window based prediction.
)

segmentation = automatic_instance_segmentation(
predictor=predictor,
segmenter=segmenter,
input_path=image,
ndim=2,
tile_shape=None, # If you set 'is_tiled' in 'get_predictor_and_segmeter' to True, set a tile shape
halo=None, # If you set 'is_tiled' in 'get_predictor_and_segmeter' to True, set a halo shape.
**generate_kwargs
)

import napari
v = napari.Viewer()
v.add_image(image)
v.add_labels(segmentation)
napari.run()


def wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs):
"""Run the automatic segmentation with tiling for an example whole-slide image from the
NeurIPS Cell Segmentation challenge.
"""
example_data = fetch_wholeslide_example_data(DATA_CACHE)
image = imageio.imread(example_data)

predictor, segmenter = get_predictor_and_segmenter(
model_type=model_type,
checkpoint=None, # Replace this with your custom checkpoint.
amg=use_amg,
is_tiled=True,
)

segmentation = automatic_instance_segmentation(
predictor=predictor,
segmenter=segmenter,
input_path=image,
ndim=2,
tile_shape=(1024, 1024),
halo=(256, 256),
**generate_kwargs
)

import napari
v = napari.Viewer()
v.add_image(image)
v.add_labels(segmentation)
napari.run()


def main():
# The choice of Segment Anything model.
model_type = "vit_b_lm"

# Whether to use:
# the automatic mask generation (AMG): supported by all our models.
# the automatic instance segmentation (AIS): supported by 'micro-sam' models.
use_amg = False # 'False' chooses AIS as the automatic segmentation mode.

# Post-processing parameters for automatic segmentation.
if use_amg: # AMG parameters
generate_kwargs = {
"pred_iou_thresh": 0.88,
"stability_score_thresh": 0.95,
"box_nms_thresh": 0.7,
"crop_nms_thresh": 0.7,
"min_mask_region_area": 0,
"output_mode": "binary_mask",
}
else: # AIS parameters
generate_kwargs = {
"center_distance_threshold": 0.5,
"boundary_distance_threshold": 0.5,
"foreground_threshold": 0.5,
"foreground_smoothing": 1.0,
"distance_smoothing": 1.6,
"min_size": 0,
"output_mode": "binary_mask",
}

# Automatic segmentation for livecell data.
livecell_automatic_segmentation(model_type, use_amg, generate_kwargs)

# Automatic segmentation for cell tracking challenge hela data.
# hela_automatic_segmentation(model_type, use_amg, generate_kwargs)

# Automatic segmentation for a whole slide image.
# wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs)


# The corresponding CLI call for hela_automatic_segmentation:
# (replace with cache directory on your machine)
# $ micro_sam.automatic_segmentation -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -o hela-2d-image_segmentation.tif # noqa
if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions micro_sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.. include:: ../doc/start_page.md
.. include:: ../doc/installation.md
.. include:: ../doc/annotation_tools.md
.. include:: ../doc/cli_tools.md
.. include:: ../doc/python_library.md
.. include:: ../doc/finetuned_models.md
.. include:: ../doc/faq.md
Expand Down
4 changes: 2 additions & 2 deletions micro_sam/_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def njit(func):


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask.
"""Calculates boxes in XYXY format around masks. Return [0, 0, 0, 0] for an empty mask.
For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
Expand All @@ -38,7 +38,7 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
It further ensures that inputs are boolean tensors, otherwise the function yields wrong results.
See https://github.com/facebookresearch/segment-anything/issues/552 for details.
"""
assert masks.dtype == torch.bool
assert masks.dtype == torch.bool, masks.dtype

# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
Expand Down
Loading

0 comments on commit 5feed91

Please sign in to comment.