diff --git a/src/napari_trackastra/__init__.py b/src/napari_trackastra/__init__.py index f510a10..23e6fed 100644 --- a/src/napari_trackastra/__init__.py +++ b/src/napari_trackastra/__init__.py @@ -1,11 +1,4 @@ __version__ = "0.0.1" -from ._sample_data import make_sample_data -from ._widget import ExampleQWidget, ImageThreshold, threshold_autogenerate_widget, threshold_magic_widget +from ._widget import Tracker -__all__ = ( - "make_sample_data", - "ExampleQWidget", - "ImageThreshold", - "threshold_autogenerate_widget", - "threshold_magic_widget", -) + diff --git a/src/napari_trackastra/_sample_data.py b/src/napari_trackastra/_sample_data.py index 453db51..241618a 100644 --- a/src/napari_trackastra/_sample_data.py +++ b/src/napari_trackastra/_sample_data.py @@ -8,14 +8,17 @@ """ from __future__ import annotations +from pathlib import Path import numpy +import tifffile +from trackastra import data -def make_sample_data(): - """Generates an image""" - # Return list of tuples - # [(data1, add_image_kwargs1), (data2, add_image_kwargs2)] - # Check the documentation for more information about the - # add_image_kwargs - # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - return [(numpy.random.rand(512, 512), {})] +def test_data_bacteria() -> list[tuple[numpy.ndarray, dict, str]]: + imgs, masks = data.test_data_bacteria() + return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')] + + +def test_data_hela() -> list[tuple[numpy.ndarray, dict, str]]: + imgs, masks = data.test_data_hela() + return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')] diff --git a/src/napari_trackastra/_tests/test_sample_data.py b/src/napari_trackastra/_tests/test_sample_data.py deleted file mode 100644 index 71b37bf..0000000 --- a/src/napari_trackastra/_tests/test_sample_data.py +++ /dev/null @@ -1,7 +0,0 @@ -# from napari_trackastra import make_sample_data - -# add your tests here... - - -def test_something(): - pass diff --git a/src/napari_trackastra/_tests/test_widget.py b/src/napari_trackastra/_tests/test_widget.py index 8b45194..95f0309 100644 --- a/src/napari_trackastra/_tests/test_widget.py +++ b/src/napari_trackastra/_tests/test_widget.py @@ -1,66 +1,21 @@ import numpy as np +import napari +from napari_trackastra._widget import Tracker +from trackastra.data import test_data_bacteria -from napari_trackastra._widget import ( - ExampleQWidget, - ImageThreshold, - threshold_autogenerate_widget, - threshold_magic_widget, -) +def test_widget(): + viewer = napari.Viewer() + img, mask = test_data_bacteria() + viewer.add_image(img) + viewer.add_labels(mask) -def test_threshold_autogenerate_widget(): - # because our "widget" is a pure function, we can call it and - # test it independently of napari - im_data = np.random.random((100, 100)) - thresholded = threshold_autogenerate_widget(im_data, 0.5) - assert thresholded.shape == im_data.shape - # etc. + viewer.window.add_dock_widget(Tracker(viewer)) + -# make_napari_viewer is a pytest fixture that returns a napari viewer object -# you don't need to import it, as long as napari is installed -# in your testing environment -def test_threshold_magic_widget(make_napari_viewer): - viewer = make_napari_viewer() - layer = viewer.add_image(np.random.random((100, 100))) +if __name__ == "__main__": + test_widget() - # our widget will be a MagicFactory or FunctionGui instance - my_widget = threshold_magic_widget() - - # if we "call" this object, it'll execute our function - thresholded = my_widget(viewer.layers[0], 0.5) - assert thresholded.shape == layer.data.shape - # etc. - - -def test_image_threshold_widget(make_napari_viewer): - viewer = make_napari_viewer() - layer = viewer.add_image(np.random.random((100, 100))) - my_widget = ImageThreshold(viewer) - - # because we saved our widgets as attributes of the container - # we can set their values without having to "interact" with the viewer - my_widget._image_layer_combo.value = layer - my_widget._threshold_slider.value = 0.5 - - # this allows us to run our functions directly and ensure - # correct results - my_widget._threshold_im() - assert len(viewer.layers) == 2 - - -# capsys is a pytest fixture that captures stdout and stderr output streams -def test_example_q_widget(make_napari_viewer, capsys): - # make viewer and add an image layer using our fixture - viewer = make_napari_viewer() - viewer.add_image(np.random.random((100, 100))) - - # create our widget, passing in the viewer - my_widget = ExampleQWidget(viewer) - - # call our widget method - my_widget._on_click() - - # read captured output and check that it's as we expected - captured = capsys.readouterr() - assert captured.out == "napari has 1 layers\n" + napari.run() + \ No newline at end of file diff --git a/src/napari_trackastra/_widget.py b/src/napari_trackastra/_widget.py index ed8c358..99da322 100644 --- a/src/napari_trackastra/_widget.py +++ b/src/napari_trackastra/_widget.py @@ -1,128 +1,128 @@ -""" -This module contains four napari widgets declared in -different ways: - -- a pure Python function flagged with `autogenerate: true` - in the plugin manifest. Type annotations are used by - magicgui to generate widgets for each parameter. Best - suited for simple processing tasks - usually taking - in and/or returning a layer. -- a `magic_factory` decorated function. The `magic_factory` - decorator allows us to customize aspects of the resulting - GUI, including the widgets associated with each parameter. - Best used when you have a very simple processing task, - but want some control over the autogenerated widgets. If you - find yourself needing to define lots of nested functions to achieve - your functionality, maybe look at the `Container` widget! -- a `magicgui.widgets.Container` subclass. This provides lots - of flexibility and customization options while still supporting - `magicgui` widgets and convenience methods for creating widgets - from type annotations. If you want to customize your widgets and - connect callbacks, this is the best widget option for you. -- a `QWidget` subclass. This provides maximal flexibility but requires - full specification of widget layouts, callbacks, events, etc. - -References: -- Widget specification: https://napari.org/stable/plugins/guides.html?#widgets -- magicgui docs: https://pyapp-kit.github.io/magicgui/ - -Replace code below according to your needs. -""" -from typing import TYPE_CHECKING - -from magicgui import magic_factory -from magicgui.widgets import CheckBox, Container, create_widget -from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget -from skimage.util import img_as_float - -if TYPE_CHECKING: - import napari - - -# Uses the `autogenerate: true` flag in the plugin manifest -# to indicate it should be wrapped as a magicgui to autogenerate -# a widget. -def threshold_autogenerate_widget( - img: "napari.types.ImageData", - threshold: "float", -) -> "napari.types.LabelsData": - return img_as_float(img) > threshold - - -# the magic_factory decorator lets us customize aspects of our widget -# we specify a widget type for the threshold parameter -# and use auto_call=True so the function is called whenever -# the value of a parameter changes -@magic_factory( - threshold={"widget_type": "FloatSlider", "max": 1}, auto_call=True -) -def threshold_magic_widget( - img_layer: "napari.layers.Image", threshold: "float" -) -> "napari.types.LabelsData": - return img_as_float(img_layer.data) > threshold + +import torch +import numpy as np + + +import napari +from magicgui import magic_factory, magicgui +from magicgui.widgets import CheckBox, Container, create_widget, PushButton, FileEdit, ComboBox, RadioButtons +from pathlib import Path +from typing import List +from napari.utils import progress +import trackastra +from trackastra.utils import normalize +from trackastra.model import Trackastra +from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _track_function(model, imgs, masks, **kwargs): + print("Normalizing...") + imgs = np.stack([normalize(x) for x in imgs]) + print("Tracking...") + track_graph = model.track(imgs, masks, mode="greedy", + max_distance=128, + progbar_class=progress, + **kwargs) # or mode="ilp" + # Visualise in napari + df, masks_tracked = graph_to_ctc(track_graph,masks,outdir=None) + napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph) + return track_graph, masks_tracked, napari_tracks + + +# logo = Path(__file__).parent/"resources"/"trackastra_logo_small.png" + +# @magicgui(call_button="track", +# label_head=dict(widget_type="Label", label=f'