Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
maweigert committed May 27, 2024
1 parent a618ac8 commit 603b22c
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 222 deletions.
11 changes: 2 additions & 9 deletions src/napari_trackastra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
__version__ = "0.0.1"
from ._sample_data import make_sample_data
from ._widget import ExampleQWidget, ImageThreshold, threshold_autogenerate_widget, threshold_magic_widget
from ._widget import Tracker

__all__ = (
"make_sample_data",
"ExampleQWidget",
"ImageThreshold",
"threshold_autogenerate_widget",
"threshold_magic_widget",
)

19 changes: 11 additions & 8 deletions src/napari_trackastra/_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
"""
from __future__ import annotations

from pathlib import Path
import numpy
import tifffile
from trackastra import data


def make_sample_data():
"""Generates an image"""
# Return list of tuples
# [(data1, add_image_kwargs1), (data2, add_image_kwargs2)]
# Check the documentation for more information about the
# add_image_kwargs
# https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image
return [(numpy.random.rand(512, 512), {})]
def test_data_bacteria() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.test_data_bacteria()
return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')]


def test_data_hela() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.test_data_hela()
return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')]
7 changes: 0 additions & 7 deletions src/napari_trackastra/_tests/test_sample_data.py

This file was deleted.

73 changes: 14 additions & 59 deletions src/napari_trackastra/_tests/test_widget.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,21 @@
import numpy as np
import napari
from napari_trackastra._widget import Tracker
from trackastra.data import test_data_bacteria

from napari_trackastra._widget import (
ExampleQWidget,
ImageThreshold,
threshold_autogenerate_widget,
threshold_magic_widget,
)

def test_widget():
viewer = napari.Viewer()
img, mask = test_data_bacteria()
viewer.add_image(img)
viewer.add_labels(mask)

def test_threshold_autogenerate_widget():
# because our "widget" is a pure function, we can call it and
# test it independently of napari
im_data = np.random.random((100, 100))
thresholded = threshold_autogenerate_widget(im_data, 0.5)
assert thresholded.shape == im_data.shape
# etc.
viewer.window.add_dock_widget(Tracker(viewer))



# make_napari_viewer is a pytest fixture that returns a napari viewer object
# you don't need to import it, as long as napari is installed
# in your testing environment
def test_threshold_magic_widget(make_napari_viewer):
viewer = make_napari_viewer()
layer = viewer.add_image(np.random.random((100, 100)))
if __name__ == "__main__":
test_widget()

# our widget will be a MagicFactory or FunctionGui instance
my_widget = threshold_magic_widget()

# if we "call" this object, it'll execute our function
thresholded = my_widget(viewer.layers[0], 0.5)
assert thresholded.shape == layer.data.shape
# etc.


def test_image_threshold_widget(make_napari_viewer):
viewer = make_napari_viewer()
layer = viewer.add_image(np.random.random((100, 100)))
my_widget = ImageThreshold(viewer)

# because we saved our widgets as attributes of the container
# we can set their values without having to "interact" with the viewer
my_widget._image_layer_combo.value = layer
my_widget._threshold_slider.value = 0.5

# this allows us to run our functions directly and ensure
# correct results
my_widget._threshold_im()
assert len(viewer.layers) == 2


# capsys is a pytest fixture that captures stdout and stderr output streams
def test_example_q_widget(make_napari_viewer, capsys):
# make viewer and add an image layer using our fixture
viewer = make_napari_viewer()
viewer.add_image(np.random.random((100, 100)))

# create our widget, passing in the viewer
my_widget = ExampleQWidget(viewer)

# call our widget method
my_widget._on_click()

# read captured output and check that it's as we expected
captured = capsys.readouterr()
assert captured.out == "napari has 1 layers\n"
napari.run()

222 changes: 111 additions & 111 deletions src/napari_trackastra/_widget.py
Original file line number Diff line number Diff line change
@@ -1,128 +1,128 @@
"""
This module contains four napari widgets declared in
different ways:
- a pure Python function flagged with `autogenerate: true`
in the plugin manifest. Type annotations are used by
magicgui to generate widgets for each parameter. Best
suited for simple processing tasks - usually taking
in and/or returning a layer.
- a `magic_factory` decorated function. The `magic_factory`
decorator allows us to customize aspects of the resulting
GUI, including the widgets associated with each parameter.
Best used when you have a very simple processing task,
but want some control over the autogenerated widgets. If you
find yourself needing to define lots of nested functions to achieve
your functionality, maybe look at the `Container` widget!
- a `magicgui.widgets.Container` subclass. This provides lots
of flexibility and customization options while still supporting
`magicgui` widgets and convenience methods for creating widgets
from type annotations. If you want to customize your widgets and
connect callbacks, this is the best widget option for you.
- a `QWidget` subclass. This provides maximal flexibility but requires
full specification of widget layouts, callbacks, events, etc.
References:
- Widget specification: https://napari.org/stable/plugins/guides.html?#widgets
- magicgui docs: https://pyapp-kit.github.io/magicgui/
Replace code below according to your needs.
"""
from typing import TYPE_CHECKING

from magicgui import magic_factory
from magicgui.widgets import CheckBox, Container, create_widget
from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget
from skimage.util import img_as_float

if TYPE_CHECKING:
import napari


# Uses the `autogenerate: true` flag in the plugin manifest
# to indicate it should be wrapped as a magicgui to autogenerate
# a widget.
def threshold_autogenerate_widget(
img: "napari.types.ImageData",
threshold: "float",
) -> "napari.types.LabelsData":
return img_as_float(img) > threshold


# the magic_factory decorator lets us customize aspects of our widget
# we specify a widget type for the threshold parameter
# and use auto_call=True so the function is called whenever
# the value of a parameter changes
@magic_factory(
threshold={"widget_type": "FloatSlider", "max": 1}, auto_call=True
)
def threshold_magic_widget(
img_layer: "napari.layers.Image", threshold: "float"
) -> "napari.types.LabelsData":
return img_as_float(img_layer.data) > threshold

import torch
import numpy as np


import napari
from magicgui import magic_factory, magicgui
from magicgui.widgets import CheckBox, Container, create_widget, PushButton, FileEdit, ComboBox, RadioButtons
from pathlib import Path
from typing import List
from napari.utils import progress
import trackastra
from trackastra.utils import normalize
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks

device = "cuda" if torch.cuda.is_available() else "cpu"


def _track_function(model, imgs, masks, **kwargs):
print("Normalizing...")
imgs = np.stack([normalize(x) for x in imgs])
print("Tracking...")
track_graph = model.track(imgs, masks, mode="greedy",
max_distance=128,
progbar_class=progress,
**kwargs) # or mode="ilp"
# Visualise in napari
df, masks_tracked = graph_to_ctc(track_graph,masks,outdir=None)
napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)
return track_graph, masks_tracked, napari_tracks


# logo = Path(__file__).parent/"resources"/"trackastra_logo_small.png"

# @magicgui(call_button="track",
# label_head=dict(widget_type="Label", label=f'<h1>Trackastra</h1>'),
# model_path={"label": "Model Path", "mode": "d"},
# persist=True)
# def track(label_head, img_layer: napari.layers.Image, mask_layer:napari.layers.Labels, model_path:Path, distance_costs:bool=False) -> List[napari.types.LayerDataTuple]:
# if model_path.exists():
# model = Trackastra.from_folder(model_path, device=device)
# else:
# model = Trackastra.from_pretrained(model_path.name, device=device)
# imgs = np.asarray(img_layer.data)
# masks = np.asarray(mask_layer.data)
# track_graph, masks_tracked, napari_tracks = _track_function(model, imgs, masks, use_distance=distance_costs)
# mask_layer.visible = False
# return [(napari_tracks, dict(name='tracks',tail_length=5), "tracks"), (masks_tracked, dict(name='masks_tracked', opacity=0.3), "labels")]



# if we want even more control over our widget, we can use
# magicgui `Container`
class ImageThreshold(Container):
class Tracker(Container):
def __init__(self, viewer: "napari.viewer.Viewer"):
super().__init__()
self._viewer = viewer
# use create_widget to generate widgets from type annotations
self._image_layer_combo = create_widget(
label="Image", annotation="napari.layers.Image"
)
self._threshold_slider = create_widget(
label="Threshold", annotation=float, widget_type="FloatSlider"
)
self._threshold_slider.min = 0
self._threshold_slider.max = 1
# use magicgui widgets directly
self._invert_checkbox = CheckBox(text="Keep pixels below threshold")
self._label = create_widget(widget_type="Label", label=f'<h1>Trackastra</h1>')
self._image_layer = create_widget(label="Images", annotation="napari.layers.Image")

self._out_mask, self._out_tracks = None, None

self._mask_layer = create_widget(label="Masks", annotation="napari.layers.Labels")
self._model_type = RadioButtons(label="Model Type", choices=["Pretrained", "Custom"], orientation="horizontal", value="Pretrained")
self._model_pretrained = ComboBox(label="Pretrained Model",
choices=tuple(trackastra.model.pretrained._MODELS.keys()), value="general_2d")
self._model_path = FileEdit(label="Model Path", mode="d")
self._model_path.hide()
self._run_button = PushButton(label="Track")


self._model_type.changed.connect(self._model_type_changed)
self._model_pretrained.changed.connect(self._update_model)
self._model_path.changed.connect(self._update_model)
self._run_button.changed.connect(self._run)

# connect your own callbacks
self._threshold_slider.changed.connect(self._threshold_im)
self._invert_checkbox.changed.connect(self._threshold_im)

# append into/extend the container with your widgets
self.extend(
[
self._image_layer_combo,
self._threshold_slider,
self._invert_checkbox,
self._label,
self._image_layer,
self._mask_layer,
self._model_type,
self._model_pretrained,
self._model_path,
self._run_button,
]
)

def _threshold_im(self):
image_layer = self._image_layer_combo.value
if image_layer is None:
return

image = img_as_float(image_layer.data)
name = image_layer.name + "_thresholded"
threshold = self._threshold_slider.value
if self._invert_checkbox.value:
thresholded = image < threshold
def _model_type_changed(self, event):
if event == "Pretrained":
self._model_pretrained.show()
self._model_path.hide()
else:
thresholded = image > threshold
if name in self._viewer.layers:
self._viewer.layers[name].data = thresholded
self._model_pretrained.hide()
self._model_path.show()

def _run(self, event=None):
self._update_model()

if self.model is None:
raise ValueError("Model not loaded")

imgs = np.asarray(self._image_layer.value.data)
masks = np.asarray(self._mask_layer.value.data)
track_graph, masks_tracked, napari_tracks = _track_function(self.model, imgs, masks)
self._mask_layer.value.visible = False


lays = tuple(lay for lay in self._viewer.layers if lay.name=="masks_tracked")
if len(lays) > 0:
lays[0].data = masks_tracked
else:
self._viewer.add_labels(thresholded, name=name)


class ExampleQWidget(QWidget):
# your QWidget.__init__ can optionally request the napari viewer instance
# use a type annotation of 'napari.viewer.Viewer' for any parameter
def __init__(self, viewer: "napari.viewer.Viewer"):
super().__init__()
self.viewer = viewer

btn = QPushButton("Click me!")
btn.clicked.connect(self._on_click)

self.setLayout(QHBoxLayout())
self.layout().addWidget(btn)

def _on_click(self):
print("napari has", len(self.viewer.layers), "layers")
self._viewer.add_labels(masks_tracked, name="masks_tracked")

lays = tuple(lay for lay in self._viewer.layers if lay.name=="tracks")
if len(lays) > 0:
lays[0].data = napari_tracks
else:
self._viewer.add_tracks(napari_tracks, name="tracks")

def _update_model(self, event=None):
if self._model_type.value == "Pretrained":
self.model = Trackastra.from_pretrained(self._model_pretrained.value, device=device)
else:
self.model = Trackastra.from_folder(self._model_path.value, device=device)
Loading

0 comments on commit 603b22c

Please sign in to comment.