Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(csr): Set type of node_props to float64 #235

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def csr_to_nbgraph(csr, node_props=None):
csr.indices,
csr.data,
np.array(csr.shape, dtype=np.int32),
node_props,
node_props.astype(np.float64),
)


Expand Down Expand Up @@ -525,7 +525,7 @@ def __init__(
if np.issubdtype(skeleton_image.dtype, np.floating):
self.pixel_values = skeleton_image[coords]
elif np.issubdtype(skeleton_image.dtype, np.integer):
self.pixel_values = skeleton_image.astype(float)[coords]
self.pixel_values = skeleton_image.astype(np.float64)[coords]
else:
self.pixel_values = None
self.graph = graph
Expand Down
49 changes: 48 additions & 1 deletion src/skan/test/test_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from numpy.testing import assert_equal, assert_almost_equal
import pandas as pd
import pytest
import scipy
from scipy import ndimage as ndi
from skimage import data
from skimage.draw import line
from skimage.morphology import skeletonize

from skan import csr
from skan._testdata import (
Expand Down Expand Up @@ -357,6 +361,16 @@ def test_skeleton_integer_dtype(dtype):
assert stats['mean_pixel_value'].max() > 1


@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_skeleton_all_float_dtypes(dtype):
"""Test that skeleton data types can be both float32 and float64."""
horse = ~data.horse()
skeleton_image = skeletonize(horse)
dt = ndi.distance_transform_edt(horse)
float_skel = (dt * skeleton_image).astype(dtype)
_ = csr.Skeleton(float_skel)


def test_default_summarize_separator():
with pytest.warns(np.exceptions.VisibleDeprecationWarning,
match='separator in column name'):
Expand Down Expand Up @@ -523,7 +537,7 @@ def test_nx_to_skeleton(


@pytest.mark.parametrize(
'wrong_skeleton',
('wrong_skeleton'),
[
pytest.param(skeleton0, id='Numpy Array.'),
pytest.param(csr.Skeleton(skeleton0), id='Skeleton.'),
Expand All @@ -538,3 +552,36 @@ def test_nx_to_skeleton_attribute_error(wrong_skeleton: Any) -> None:
"""Test various errors are raised by nx_to_skeleton()."""
with pytest.raises(Exception):
csr.nx_to_skeleton(wrong_skeleton)


@pytest.mark.parametrize(
('skeleton'),
[
pytest.param(skeleton0, id='Numpy Array'),
pytest.param(csr.Skeleton(skeleton0), id='Skeleton'),
pytest.param(nx_graph, id='NetworkX Graph without edges.'),
],
)
def test_csr_to_nbgraph_attribute_error(skeleton: Any) -> None:
"""Raise AttributeError if csr_to_nbgraph() passed incomplete objects."""
with pytest.raises(AttributeError):
csr.csr_to_nbgraph(skeleton)


@pytest.mark.parametrize(
('graph'),
[
pytest.param(
scipy.sparse.csr_matrix(skeleton0),
id='Sparse matrix directly from Numpy Array',
),
pytest.param(
scipy.sparse.csr_matrix(csr.Skeleton(skeleton0)),
id='Sparse matrix from csr.Skeleton',
),
],
)
def test_csr_to_nbgraph_type_error(graph: scipy.sparse.csr_matrix) -> None:
"""Test TypeError is raised by csr_to_nbgraph() if wrong type is passed."""
with pytest.raises(TypeError):
csr.csr_to_nbgraph(graph)
Loading