From 9b1bb27dafbf6be70c0f8e6780513e39aa910570 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 10 Jan 2024 15:09:52 +0100 Subject: [PATCH] more typing. --- src/simfmri/analysis/stats.py | 18 +++--- src/simfmri/handlers/noise.py | 90 ++++++++++++++-------------- src/simfmri/simulation/lazy.py | 4 +- src/simfmri/simulation/simulation.py | 82 ++++++++++++++++++++++--- src/simfmri/utils/__init__.py | 9 ++- src/simfmri/utils/typing.py | 10 +--- src/simfmri/utils/utils.py | 26 ++++++-- 7 files changed, 157 insertions(+), 82 deletions(-) diff --git a/src/simfmri/analysis/stats.py b/src/simfmri/analysis/stats.py index 17dab39..2620337 100644 --- a/src/simfmri/analysis/stats.py +++ b/src/simfmri/analysis/stats.py @@ -1,7 +1,7 @@ """Analysis module.""" import logging from typing import Literal -from numpy.typing import ArrayLike +from numpy.typing import NDArray from sklearn.metrics import roc_curve import numpy as np @@ -123,16 +123,16 @@ def get_scores( N = gt_f.size - P fpr, tpr, thresholds = roc_curve(gt_f, contrast.flatten()) - stats["fpr"] = list(fpr) - stats["tpr"] = list(tpr) - stats["tresh"] = list(thresholds) - stats["tp"] = np.int32(tpr * P).tolist() - stats["fp"] = np.int32(fpr * N).tolist() - stats["tn"] = np.int32(N * (1 - fpr)).tolist() - stats["fn"] = np.int32(P * (1 - tpr)).tolist() + stats["fpr"] = fpr.tolist() + stats["tpr"] = tpr.tolist() + stats["tresh"] = thresholds.tolist() + stats["tp"] = np.int_(tpr * P).tolist() + stats["fp"] = np.int_(fpr * N).tolist() + stats["tn"] = np.int_(N * (1 - fpr)).tolist() + stats["fn"] = np.int_(P * (1 - tpr)).tolist() return stats -def bacc(tpr: ArrayLike, fpr: ArrayLike) -> ArrayLike: +def bacc(tpr: NDArray, fpr: NDArray) -> NDArray: """Compute Balanced Accuracy from TPR and FPR.""" return (tpr + 1 - fpr) / 2 diff --git a/src/simfmri/handlers/noise.py b/src/simfmri/handlers/noise.py index d31e736..6460d9b 100644 --- a/src/simfmri/handlers/noise.py +++ b/src/simfmri/handlers/noise.py @@ -4,35 +4,34 @@ """ from __future__ import annotations from .base import AbstractHandler, requires_field -from ..simulation import SimData +from ..simulation import SimData, LazySimArray -from simfmri.utils import validate_rng +from simfmri.utils import validate_rng, RngType, real_type import numpy as np +from numpy.typing import NDArray import scipy.stats as sps def _lazy_add_noise( - data: np.ndarray, + data: NDArray[np.complexfloating] | NDArray[np.floating], noise_std: float, root_seed: int, - frame_idx: int = None, + frame_idx: int = 0, ) -> np.ndarray: """Add noise to data.""" rng = np.random.default_rng([frame_idx, root_seed]) if data.dtype in [np.complex128, np.complex64]: noise_std /= np.sqrt(2) - noise = noise_std * rng.standard_normal(data.shape, dtype=abs(data[:][0]).dtype) - noise = noise.astype(data.dtype) + noise = ( + noise_std * rng.standard_normal(data.shape, dtype=real_type(data.dtype)) + ).astype(data.dtype) if data.dtype in [np.complex128, np.complex64]: - noise += ( - 1j - * noise_std - * rng.standard_normal( - data.shape, - dtype=abs(data[:][0]).dtype, - ) + noise = noise * 1j + noise += noise_std * rng.standard_normal( + data.shape, + dtype=abs(data[:][0]).dtype, ) return data + noise @@ -64,15 +63,15 @@ def _handle(self, sim: SimData) -> SimData: else: # SNR is defined as average(brain signal) / noise_std noise_std = np.mean(abs(sim.static_vol[sim.static_vol > 0])) / self._snr - if sim.lazy: - self._add_noise_lazy(sim, sim.rng, noise_std) + if isinstance(sim.data_acq, LazySimArray): + self._add_noise_lazy(sim.data_acq, sim.rng, noise_std) else: self._add_noise(sim, sim.rng, noise_std) sim.extra_infos["input_snr"] = self._snr return sim - def _add_noise(self, sim: SimData, noise_std: float) -> None: + def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None: """Add noise to the simulation. This should only update the attribute data_acq of a Simulation object @@ -85,7 +84,9 @@ def _add_noise(self, sim: SimData, noise_std: float) -> None: """ raise NotImplementedError - def _add_noise_lazy(self, sim: SimData, noise_std: float) -> None: + def _add_noise_lazy( + self, lazy_arr: LazySimArray, rng_seed: int, noise_std: float + ) -> None: """Lazily add noise to the simulation. This should only update the attribute data_acq of a Simulation object @@ -104,31 +105,30 @@ class GaussianNoiseHandler(BaseNoiseHandler): name = "noise-gaussian" - def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None: + def _add_noise(self, sim: SimData, rng_seed: RngType, noise_std: float) -> None: rng = validate_rng(rng_seed) if np.iscomplexobj(sim.data_ref): noise_std /= np.sqrt(2) - noise = noise_std * rng.standard_normal( - sim.data_acq.shape, dtype=abs(sim.data_acq[:][0]).dtype - ) - noise = noise.astype(sim.data_acq.dtype) + noise = ( + noise_std + * rng.standard_normal( + sim.data_acq.shape, dtype=real_type(sim.data_acq.dtype) + ) + ).astype(sim.data_acq.dtype) if sim.data_acq.dtype in [np.complex128, np.complex64]: - noise += ( - 1j - * noise_std - * rng.standard_normal( - sim.data_ref.shape, - dtype=abs(sim.data_ref[:][0]).dtype, - ) + noise = noise * 1j + noise += noise_std * rng.standard_normal( + sim.data_ref.shape, + dtype=real_type(sim.data_ref.dtype), ) sim.data_acq += noise self.log.debug(f"{sim.data_acq}, {sim.data_ref}") - def _add_noise_lazy(self, sim: SimData, rng_seed: int, noise_std: float) -> None: - sim.data_acq.apply(_lazy_add_noise, noise_std, rng_seed) - - self.log.debug(f"{sim.data_acq}, {sim.data_ref}") + def _add_noise_lazy( + self, lazy_arr: LazySimArray, rng_seed: int, noise_std: float + ) -> None: + lazy_arr.apply(_lazy_add_noise, noise_std, rng_seed) class RicianNoiseHandler(BaseNoiseHandler): @@ -136,7 +136,7 @@ class RicianNoiseHandler(BaseNoiseHandler): name = "noise-rician" - def _add_noise(self, sim: SimData, noise_std: float) -> None: + def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None: if np.any(np.iscomplex(sim)): raise ValueError( "The Rice distribution is only applicable to real-valued data." @@ -171,10 +171,11 @@ def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None: # Complex Value, so the std is spread. noise_std /= np.sqrt(2) for kf in range(len(sim.kspace_data)): - kspace_noise = np.complex64( - rng.standard_normal(sim.kspace_data.shape[1:], dtype="float32") - ) - kspace_noise += 1j * rng.standard_normal( + kspace_noise = rng.standard_normal( + sim.kspace_data.shape[1:], dtype="float32" + ).astype("complex64") + kspace_noise *= 1j + kspace_noise += rng.standard_normal( sim.kspace_data.shape[1:], dtype="float32" ) kspace_noise *= noise_std @@ -215,15 +216,16 @@ def __init__( self, drift_model: str = "polynomial", order: int = 1, - high_pass: float = None, - drift_intensities: float = 0.01, + high_pass: float | None = None, + drift_intensities: float | np.ndarray = 0.01, ): super().__init__() self._drift_model = drift_model if not isinstance(drift_intensities, np.ndarray): if isinstance(drift_intensities, (int, float)): - drift_intensities = [drift_intensities] * order - drift_intensities = np.array([drift_intensities]) + drift_intensities = np.array([drift_intensities] * order) + else: + drift_intensities = np.array([drift_intensities]) self._drift_intensities = drift_intensities self._drift_order = order self._drift_high_pass = high_pass @@ -242,11 +244,11 @@ def _handle(self, sim: SimData) -> SimData: ) drift_matrix = drift_matrix[:, :-1] # remove the intercept column - drift_intensity = np.linspace(1, 1 + self.drift_intensities, sim.n_frames) + drift_intensity = np.linspace(1, 1 + self._drift_intensities, sim.n_frames) timeseries = drift_intensity @ drift_matrix - if sim.lazy: + if isinstance(sim.data_acq, LazySimArray): raise NotImplementedError( "lazy is not compatible with scanner drift (for now)" ) diff --git a/src/simfmri/simulation/lazy.py b/src/simfmri/simulation/lazy.py index 0bc2c2e..cec6529 100644 --- a/src/simfmri/simulation/lazy.py +++ b/src/simfmri/simulation/lazy.py @@ -10,7 +10,7 @@ from copy import deepcopy from typing import Any, Callable, TypeVar, Mapping import numpy as np -from numpy.typing import ArrayLike, NDArray +from numpy.typing import ArrayLike, NDArray, DTypeLike from functools import wraps T = TypeVar("T") @@ -96,7 +96,7 @@ def shape(self) -> tuple[int, ...]: return (len(self), *(self._base_array.shape)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> DTypeLike: """Get dtype.""" return self._base_array.dtype diff --git a/src/simfmri/simulation/simulation.py b/src/simfmri/simulation/simulation.py index fa1dce6..7836bd9 100644 --- a/src/simfmri/simulation/simulation.py +++ b/src/simfmri/simulation/simulation.py @@ -129,16 +129,82 @@ def __init__( rng=rng, extra_infos=extra_infos, ) - self.static_vol: OptionalArray - self.data_ref: OptionalArray | LazySimArray | None = None - self.roi: OptionalArray = None - self.data_acq: OptionalArray | LazySimArray = None - self.data_rec: OptionalArray | LazySimArray = None - self.kspace_data: OptionalArray = None - self.kspace_mask: OptionalArray = None - self.smaps: OptionalArray = None + self._static_vol: NDArray | None + self._roi: NDArray | None = None + self._data_ref: NDArray | None | LazySimArray = None + self._data_acq: NDArray | None | LazySimArray = None + self.data_rec: NDArray | None = None + self._kspace_data: NDArray | None = None + self._kspace_mask: NDArray | None = None + self.smaps: NDArray | None = None self.lazy = lazy + @property + def static_vol(self) -> NDArray: + """Static volume.""" + if self._static_vol is not None: + return self._static_vol + raise ValueError("static_vol is not defined") + + @static_vol.setter + def static_vol(self, value: NDArray) -> None: + self._static_vol = value + + @property + def kspace_data(self) -> NDArray: + """Static volume.""" + if self._kspace_data is not None: + return self._kspace_data + raise ValueError("static_vol is not defined") + + @kspace_data.setter + def kspace_data(self, value: NDArray) -> None: + self._kspace_data = value + + @property + def kspace_mask(self) -> NDArray: + """Static volume.""" + if self._kspace_mask is not None: + return self._kspace_mask + raise ValueError("static_vol is not defined") + + @kspace_mask.setter + def kspace_mask(self, value: NDArray) -> None: + self._kspace_mask = value + + @property + def data_ref(self) -> NDArray | LazySimArray: + """Static volume.""" + if self._data_ref is not None: + return self._data_ref + raise ValueError("data_ref is not defined") + + @data_ref.setter + def data_ref(self, value: NDArray) -> None: + self._data_ref = value + + @property + def data_acq(self) -> NDArray | LazySimArray: + """Acquired Volume.""" + if self._data_acq is not None: + return self._data_acq + raise ValueError("data_acq is not defined") + + @data_acq.setter + def data_acq(self, value: NDArray) -> None: + self._data_acq = value + + @property + def roi(self) -> NDArray: + """Reference data volume.""" + if self._roi is not None: + return self._roi + raise ValueError("static_vol is not defined") + + @roi.setter + def roi(self, value: NDArray) -> None: + self._roi = value + @classmethod def from_params(cls, sim_params: SimParams, in_place: bool = False) -> SimData: """Create a Simulation from its meta parameters. diff --git a/src/simfmri/utils/__init__.py b/src/simfmri/utils/__init__.py index a37797c..e85684e 100644 --- a/src/simfmri/utils/__init__.py +++ b/src/simfmri/utils/__init__.py @@ -1,12 +1,11 @@ """Utilities tools for simfmri.""" -from .typing import RngType, AnyShape, Shape2d, Shape3d -from .utils import validate_rng, cplx_type +from .typing import RngType, AnyShape +from .utils import validate_rng, cplx_type, real_type __all__ = [ - "RngType", "AnyShape", - "Shape2d", - "Shape3d", + "RngType", "validate_rng", "cplx_type", + "real_type", ] diff --git a/src/simfmri/utils/typing.py b/src/simfmri/utils/typing.py index e88f146..dfa7c71 100644 --- a/src/simfmri/utils/typing.py +++ b/src/simfmri/utils/typing.py @@ -10,12 +10,4 @@ or a numpy.random.Generator. """ - -Shape2d = tuple[int, int] -"""Type for a 2D shape.""" -Shape3d = tuple[int, int, int] -"""Type for a 3D shape.""" - -AnyShape = Shape2d | Shape3d - -"""Type for a 2D or 3D shape.""" +AnyShape = tuple[int, ...] diff --git a/src/simfmri/utils/utils.py b/src/simfmri/utils/utils.py index efe70f6..3ae45ea 100644 --- a/src/simfmri/utils/utils.py +++ b/src/simfmri/utils/utils.py @@ -31,12 +31,28 @@ def cplx_type(dtype: DTypeLike) -> DTypeLike: d = np.dtype(dtype) if d.type is np.float64: return np.complex128 - elif d.type is np.float128: - return np.complex256 elif d.type is np.float32: return np.complex64 else: - sim_logger.warning( - "not supported dtype, use matching complex64", stack_info=True - ) + sim_logger.warning("unsupported dtype, use matching complex64", stack_info=True) return np.complex64 + + +def real_type( + dtype: DTypeLike, +) -> np.dtype[np.float32] | np.dtype[np.float64]: + """Return the real type associated with the complex one. + + Examples + -------- + >>> cplx_type(np.float32) + np.complex64 + """ + d = np.dtype(dtype) + if d.type is np.complex64: + return np.dtype("float32") + elif d.type is np.complex128: + return np.dtype("float64") + else: + sim_logger.warning("unsupported dtype, use matching float32", stack_info=True) + return np.dtype("float32")