Skip to content

Commit

Permalink
fix #3
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 7, 2024
1 parent 237f71f commit e5a511e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
6 changes: 6 additions & 0 deletions napari_spotiflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
face_color=[1.,.5,.2],
border_color=[1.,.5,.2])

_point_layer3d_default_kwargs = dict(size=8,
symbol='ring',
opacity=1,
face_color=[1.,.5,.2],
border_color=[1.,.5,.2],
out_of_slice_display=True)

# def sample_data_2d():
# from spotiflow.data import hybiss_data_2d
Expand Down
21 changes: 14 additions & 7 deletions napari_spotiflow/_dock_widget.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import functools
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import functools
import logging
from copy import deepcopy
from typing import List, Union
from warnings import warn
import logging

import napari
import numpy as np
Expand Down Expand Up @@ -53,10 +54,10 @@ def plugin_wrapper():
# delay imports until plugin is requested by user
import torch
from spotiflow.model import Spotiflow
from spotiflow.model.pretrained import list_registered, _REGISTERED
from spotiflow.model.pretrained import _REGISTERED, list_registered
from spotiflow.utils import normalize

from napari_spotiflow import _point_layer2d_default_kwargs
from napari_spotiflow import _point_layer2d_default_kwargs, _point_layer3d_default_kwargs

def get_data(image):
image = image.data[0] if image.multiscale else image.data
Expand Down Expand Up @@ -186,7 +187,7 @@ def plugin (
) -> list[napari.types.LayerDataTuple]:
if image_axes == "":
raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.")
should_use_mps = torch.backends.mps.is_available() and not IS_3D and not os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "0"
should_use_mps = torch.backends.mps.is_available() and (not IS_3D or os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" or os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None)
DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps else "cpu"
print(f'using device {DEVICE_STR}')

Expand Down Expand Up @@ -244,7 +245,10 @@ def plugin (
def progress(size):
def _progress(it, **kwargs):
progress_bar.label = 'Spotiflow Prediction'
progress_bar.range = (0, size)
if kwargs.get("total", None) is None:
progress_bar.range = (0, size+1)
else:
progress_bar.range = (0, kwargs["total"])
progress_bar.value = 0
progress_bar.show()
app.process_events()
Expand All @@ -268,6 +272,7 @@ def _progress(it, **kwargs):
progress_bar_wrapper=progress(np.prod(actual_n_tiles)),
device=DEVICE_STR,
subpix=subpix,
normalizer=None,
)

if cnn_output:
Expand All @@ -288,6 +293,7 @@ def _progress(it, **kwargs):
verbose=True,
device=DEVICE_STR,
subpix=subpix,
normalizer=None,
) for _x in progress(x.shape[0])(x))))

pred_points = tuple(np.concatenate([[i], p])
Expand All @@ -308,8 +314,9 @@ def _progress(it, **kwargs):
if l.name == points_layer_name:
viewer.layers.remove(l)

point_layer_kwargs = _point_layer2d_default_kwargs if not IS_3D else _point_layer3d_default_kwargs
layers.append((pred_points, dict(name=f'Spots ({image.name})',
**_point_layer2d_default_kwargs), 'points'))
**point_layer_kwargs), 'points'))


progress_bar.hide()
Expand Down

0 comments on commit e5a511e

Please sign in to comment.