-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
155 additions
and
222 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,4 @@ | ||
__version__ = "0.0.1" | ||
from ._sample_data import make_sample_data | ||
from ._widget import ExampleQWidget, ImageThreshold, threshold_autogenerate_widget, threshold_magic_widget | ||
from ._widget import Tracker | ||
|
||
__all__ = ( | ||
"make_sample_data", | ||
"ExampleQWidget", | ||
"ImageThreshold", | ||
"threshold_autogenerate_widget", | ||
"threshold_magic_widget", | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,21 @@ | ||
import numpy as np | ||
import napari | ||
from napari_trackastra._widget import Tracker | ||
from trackastra.data import test_data_bacteria | ||
|
||
from napari_trackastra._widget import ( | ||
ExampleQWidget, | ||
ImageThreshold, | ||
threshold_autogenerate_widget, | ||
threshold_magic_widget, | ||
) | ||
|
||
def test_widget(): | ||
viewer = napari.Viewer() | ||
img, mask = test_data_bacteria() | ||
viewer.add_image(img) | ||
viewer.add_labels(mask) | ||
|
||
def test_threshold_autogenerate_widget(): | ||
# because our "widget" is a pure function, we can call it and | ||
# test it independently of napari | ||
im_data = np.random.random((100, 100)) | ||
thresholded = threshold_autogenerate_widget(im_data, 0.5) | ||
assert thresholded.shape == im_data.shape | ||
# etc. | ||
viewer.window.add_dock_widget(Tracker(viewer)) | ||
|
||
|
||
|
||
# make_napari_viewer is a pytest fixture that returns a napari viewer object | ||
# you don't need to import it, as long as napari is installed | ||
# in your testing environment | ||
def test_threshold_magic_widget(make_napari_viewer): | ||
viewer = make_napari_viewer() | ||
layer = viewer.add_image(np.random.random((100, 100))) | ||
if __name__ == "__main__": | ||
test_widget() | ||
|
||
# our widget will be a MagicFactory or FunctionGui instance | ||
my_widget = threshold_magic_widget() | ||
|
||
# if we "call" this object, it'll execute our function | ||
thresholded = my_widget(viewer.layers[0], 0.5) | ||
assert thresholded.shape == layer.data.shape | ||
# etc. | ||
|
||
|
||
def test_image_threshold_widget(make_napari_viewer): | ||
viewer = make_napari_viewer() | ||
layer = viewer.add_image(np.random.random((100, 100))) | ||
my_widget = ImageThreshold(viewer) | ||
|
||
# because we saved our widgets as attributes of the container | ||
# we can set their values without having to "interact" with the viewer | ||
my_widget._image_layer_combo.value = layer | ||
my_widget._threshold_slider.value = 0.5 | ||
|
||
# this allows us to run our functions directly and ensure | ||
# correct results | ||
my_widget._threshold_im() | ||
assert len(viewer.layers) == 2 | ||
|
||
|
||
# capsys is a pytest fixture that captures stdout and stderr output streams | ||
def test_example_q_widget(make_napari_viewer, capsys): | ||
# make viewer and add an image layer using our fixture | ||
viewer = make_napari_viewer() | ||
viewer.add_image(np.random.random((100, 100))) | ||
|
||
# create our widget, passing in the viewer | ||
my_widget = ExampleQWidget(viewer) | ||
|
||
# call our widget method | ||
my_widget._on_click() | ||
|
||
# read captured output and check that it's as we expected | ||
captured = capsys.readouterr() | ||
assert captured.out == "napari has 1 layers\n" | ||
napari.run() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,128 +1,128 @@ | ||
""" | ||
This module contains four napari widgets declared in | ||
different ways: | ||
- a pure Python function flagged with `autogenerate: true` | ||
in the plugin manifest. Type annotations are used by | ||
magicgui to generate widgets for each parameter. Best | ||
suited for simple processing tasks - usually taking | ||
in and/or returning a layer. | ||
- a `magic_factory` decorated function. The `magic_factory` | ||
decorator allows us to customize aspects of the resulting | ||
GUI, including the widgets associated with each parameter. | ||
Best used when you have a very simple processing task, | ||
but want some control over the autogenerated widgets. If you | ||
find yourself needing to define lots of nested functions to achieve | ||
your functionality, maybe look at the `Container` widget! | ||
- a `magicgui.widgets.Container` subclass. This provides lots | ||
of flexibility and customization options while still supporting | ||
`magicgui` widgets and convenience methods for creating widgets | ||
from type annotations. If you want to customize your widgets and | ||
connect callbacks, this is the best widget option for you. | ||
- a `QWidget` subclass. This provides maximal flexibility but requires | ||
full specification of widget layouts, callbacks, events, etc. | ||
References: | ||
- Widget specification: https://napari.org/stable/plugins/guides.html?#widgets | ||
- magicgui docs: https://pyapp-kit.github.io/magicgui/ | ||
Replace code below according to your needs. | ||
""" | ||
from typing import TYPE_CHECKING | ||
|
||
from magicgui import magic_factory | ||
from magicgui.widgets import CheckBox, Container, create_widget | ||
from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget | ||
from skimage.util import img_as_float | ||
|
||
if TYPE_CHECKING: | ||
import napari | ||
|
||
|
||
# Uses the `autogenerate: true` flag in the plugin manifest | ||
# to indicate it should be wrapped as a magicgui to autogenerate | ||
# a widget. | ||
def threshold_autogenerate_widget( | ||
img: "napari.types.ImageData", | ||
threshold: "float", | ||
) -> "napari.types.LabelsData": | ||
return img_as_float(img) > threshold | ||
|
||
|
||
# the magic_factory decorator lets us customize aspects of our widget | ||
# we specify a widget type for the threshold parameter | ||
# and use auto_call=True so the function is called whenever | ||
# the value of a parameter changes | ||
@magic_factory( | ||
threshold={"widget_type": "FloatSlider", "max": 1}, auto_call=True | ||
) | ||
def threshold_magic_widget( | ||
img_layer: "napari.layers.Image", threshold: "float" | ||
) -> "napari.types.LabelsData": | ||
return img_as_float(img_layer.data) > threshold | ||
|
||
import torch | ||
import numpy as np | ||
|
||
|
||
import napari | ||
from magicgui import magic_factory, magicgui | ||
from magicgui.widgets import CheckBox, Container, create_widget, PushButton, FileEdit, ComboBox, RadioButtons | ||
from pathlib import Path | ||
from typing import List | ||
from napari.utils import progress | ||
import trackastra | ||
from trackastra.utils import normalize | ||
from trackastra.model import Trackastra | ||
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks | ||
|
||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
|
||
def _track_function(model, imgs, masks, **kwargs): | ||
print("Normalizing...") | ||
imgs = np.stack([normalize(x) for x in imgs]) | ||
print("Tracking...") | ||
track_graph = model.track(imgs, masks, mode="greedy", | ||
max_distance=128, | ||
progbar_class=progress, | ||
**kwargs) # or mode="ilp" | ||
# Visualise in napari | ||
df, masks_tracked = graph_to_ctc(track_graph,masks,outdir=None) | ||
napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph) | ||
return track_graph, masks_tracked, napari_tracks | ||
|
||
|
||
# logo = Path(__file__).parent/"resources"/"trackastra_logo_small.png" | ||
|
||
# @magicgui(call_button="track", | ||
# label_head=dict(widget_type="Label", label=f'<h1>Trackastra</h1>'), | ||
# model_path={"label": "Model Path", "mode": "d"}, | ||
# persist=True) | ||
# def track(label_head, img_layer: napari.layers.Image, mask_layer:napari.layers.Labels, model_path:Path, distance_costs:bool=False) -> List[napari.types.LayerDataTuple]: | ||
# if model_path.exists(): | ||
# model = Trackastra.from_folder(model_path, device=device) | ||
# else: | ||
# model = Trackastra.from_pretrained(model_path.name, device=device) | ||
# imgs = np.asarray(img_layer.data) | ||
# masks = np.asarray(mask_layer.data) | ||
# track_graph, masks_tracked, napari_tracks = _track_function(model, imgs, masks, use_distance=distance_costs) | ||
# mask_layer.visible = False | ||
# return [(napari_tracks, dict(name='tracks',tail_length=5), "tracks"), (masks_tracked, dict(name='masks_tracked', opacity=0.3), "labels")] | ||
|
||
|
||
|
||
# if we want even more control over our widget, we can use | ||
# magicgui `Container` | ||
class ImageThreshold(Container): | ||
class Tracker(Container): | ||
def __init__(self, viewer: "napari.viewer.Viewer"): | ||
super().__init__() | ||
self._viewer = viewer | ||
# use create_widget to generate widgets from type annotations | ||
self._image_layer_combo = create_widget( | ||
label="Image", annotation="napari.layers.Image" | ||
) | ||
self._threshold_slider = create_widget( | ||
label="Threshold", annotation=float, widget_type="FloatSlider" | ||
) | ||
self._threshold_slider.min = 0 | ||
self._threshold_slider.max = 1 | ||
# use magicgui widgets directly | ||
self._invert_checkbox = CheckBox(text="Keep pixels below threshold") | ||
self._label = create_widget(widget_type="Label", label=f'<h1>Trackastra</h1>') | ||
self._image_layer = create_widget(label="Images", annotation="napari.layers.Image") | ||
|
||
self._out_mask, self._out_tracks = None, None | ||
|
||
self._mask_layer = create_widget(label="Masks", annotation="napari.layers.Labels") | ||
self._model_type = RadioButtons(label="Model Type", choices=["Pretrained", "Custom"], orientation="horizontal", value="Pretrained") | ||
self._model_pretrained = ComboBox(label="Pretrained Model", | ||
choices=tuple(trackastra.model.pretrained._MODELS.keys()), value="general_2d") | ||
self._model_path = FileEdit(label="Model Path", mode="d") | ||
self._model_path.hide() | ||
self._run_button = PushButton(label="Track") | ||
|
||
|
||
self._model_type.changed.connect(self._model_type_changed) | ||
self._model_pretrained.changed.connect(self._update_model) | ||
self._model_path.changed.connect(self._update_model) | ||
self._run_button.changed.connect(self._run) | ||
|
||
# connect your own callbacks | ||
self._threshold_slider.changed.connect(self._threshold_im) | ||
self._invert_checkbox.changed.connect(self._threshold_im) | ||
|
||
# append into/extend the container with your widgets | ||
self.extend( | ||
[ | ||
self._image_layer_combo, | ||
self._threshold_slider, | ||
self._invert_checkbox, | ||
self._label, | ||
self._image_layer, | ||
self._mask_layer, | ||
self._model_type, | ||
self._model_pretrained, | ||
self._model_path, | ||
self._run_button, | ||
] | ||
) | ||
|
||
def _threshold_im(self): | ||
image_layer = self._image_layer_combo.value | ||
if image_layer is None: | ||
return | ||
|
||
image = img_as_float(image_layer.data) | ||
name = image_layer.name + "_thresholded" | ||
threshold = self._threshold_slider.value | ||
if self._invert_checkbox.value: | ||
thresholded = image < threshold | ||
def _model_type_changed(self, event): | ||
if event == "Pretrained": | ||
self._model_pretrained.show() | ||
self._model_path.hide() | ||
else: | ||
thresholded = image > threshold | ||
if name in self._viewer.layers: | ||
self._viewer.layers[name].data = thresholded | ||
self._model_pretrained.hide() | ||
self._model_path.show() | ||
|
||
def _run(self, event=None): | ||
self._update_model() | ||
|
||
if self.model is None: | ||
raise ValueError("Model not loaded") | ||
|
||
imgs = np.asarray(self._image_layer.value.data) | ||
masks = np.asarray(self._mask_layer.value.data) | ||
track_graph, masks_tracked, napari_tracks = _track_function(self.model, imgs, masks) | ||
self._mask_layer.value.visible = False | ||
|
||
|
||
lays = tuple(lay for lay in self._viewer.layers if lay.name=="masks_tracked") | ||
if len(lays) > 0: | ||
lays[0].data = masks_tracked | ||
else: | ||
self._viewer.add_labels(thresholded, name=name) | ||
|
||
|
||
class ExampleQWidget(QWidget): | ||
# your QWidget.__init__ can optionally request the napari viewer instance | ||
# use a type annotation of 'napari.viewer.Viewer' for any parameter | ||
def __init__(self, viewer: "napari.viewer.Viewer"): | ||
super().__init__() | ||
self.viewer = viewer | ||
|
||
btn = QPushButton("Click me!") | ||
btn.clicked.connect(self._on_click) | ||
|
||
self.setLayout(QHBoxLayout()) | ||
self.layout().addWidget(btn) | ||
|
||
def _on_click(self): | ||
print("napari has", len(self.viewer.layers), "layers") | ||
self._viewer.add_labels(masks_tracked, name="masks_tracked") | ||
|
||
lays = tuple(lay for lay in self._viewer.layers if lay.name=="tracks") | ||
if len(lays) > 0: | ||
lays[0].data = napari_tracks | ||
else: | ||
self._viewer.add_tracks(napari_tracks, name="tracks") | ||
|
||
def _update_model(self, event=None): | ||
if self._model_type.value == "Pretrained": | ||
self.model = Trackastra.from_pretrained(self._model_pretrained.value, device=device) | ||
else: | ||
self.model = Trackastra.from_folder(self._model_path.value, device=device) |
Oops, something went wrong.