Skip to content

Commit

Permalink
more typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Jan 10, 2024
1 parent 603c11c commit 9b1bb27
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 82 deletions.
18 changes: 9 additions & 9 deletions src/simfmri/analysis/stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Analysis module."""
import logging
from typing import Literal
from numpy.typing import ArrayLike
from numpy.typing import NDArray
from sklearn.metrics import roc_curve
import numpy as np

Expand Down Expand Up @@ -123,16 +123,16 @@ def get_scores(
N = gt_f.size - P

fpr, tpr, thresholds = roc_curve(gt_f, contrast.flatten())
stats["fpr"] = list(fpr)
stats["tpr"] = list(tpr)
stats["tresh"] = list(thresholds)
stats["tp"] = np.int32(tpr * P).tolist()
stats["fp"] = np.int32(fpr * N).tolist()
stats["tn"] = np.int32(N * (1 - fpr)).tolist()
stats["fn"] = np.int32(P * (1 - tpr)).tolist()
stats["fpr"] = fpr.tolist()
stats["tpr"] = tpr.tolist()
stats["tresh"] = thresholds.tolist()
stats["tp"] = np.int_(tpr * P).tolist()
stats["fp"] = np.int_(fpr * N).tolist()
stats["tn"] = np.int_(N * (1 - fpr)).tolist()
stats["fn"] = np.int_(P * (1 - tpr)).tolist()
return stats


def bacc(tpr: ArrayLike, fpr: ArrayLike) -> ArrayLike:
def bacc(tpr: NDArray, fpr: NDArray) -> NDArray:
"""Compute Balanced Accuracy from TPR and FPR."""
return (tpr + 1 - fpr) / 2
90 changes: 46 additions & 44 deletions src/simfmri/handlers/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,34 @@
"""
from __future__ import annotations
from .base import AbstractHandler, requires_field
from ..simulation import SimData
from ..simulation import SimData, LazySimArray

from simfmri.utils import validate_rng
from simfmri.utils import validate_rng, RngType, real_type

import numpy as np
from numpy.typing import NDArray
import scipy.stats as sps


def _lazy_add_noise(
data: np.ndarray,
data: NDArray[np.complexfloating] | NDArray[np.floating],
noise_std: float,
root_seed: int,
frame_idx: int = None,
frame_idx: int = 0,
) -> np.ndarray:
"""Add noise to data."""
rng = np.random.default_rng([frame_idx, root_seed])
if data.dtype in [np.complex128, np.complex64]:
noise_std /= np.sqrt(2)
noise = noise_std * rng.standard_normal(data.shape, dtype=abs(data[:][0]).dtype)
noise = noise.astype(data.dtype)
noise = (
noise_std * rng.standard_normal(data.shape, dtype=real_type(data.dtype))
).astype(data.dtype)

if data.dtype in [np.complex128, np.complex64]:
noise += (
1j
* noise_std
* rng.standard_normal(
data.shape,
dtype=abs(data[:][0]).dtype,
)
noise = noise * 1j
noise += noise_std * rng.standard_normal(
data.shape,
dtype=abs(data[:][0]).dtype,
)
return data + noise

Expand Down Expand Up @@ -64,15 +63,15 @@ def _handle(self, sim: SimData) -> SimData:
else:
# SNR is defined as average(brain signal) / noise_std
noise_std = np.mean(abs(sim.static_vol[sim.static_vol > 0])) / self._snr
if sim.lazy:
self._add_noise_lazy(sim, sim.rng, noise_std)
if isinstance(sim.data_acq, LazySimArray):
self._add_noise_lazy(sim.data_acq, sim.rng, noise_std)
else:
self._add_noise(sim, sim.rng, noise_std)

sim.extra_infos["input_snr"] = self._snr
return sim

def _add_noise(self, sim: SimData, noise_std: float) -> None:
def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None:
"""Add noise to the simulation.
This should only update the attribute data_acq of a Simulation object
Expand All @@ -85,7 +84,9 @@ def _add_noise(self, sim: SimData, noise_std: float) -> None:
"""
raise NotImplementedError

def _add_noise_lazy(self, sim: SimData, noise_std: float) -> None:
def _add_noise_lazy(
self, lazy_arr: LazySimArray, rng_seed: int, noise_std: float
) -> None:
"""Lazily add noise to the simulation.
This should only update the attribute data_acq of a Simulation object
Expand All @@ -104,39 +105,38 @@ class GaussianNoiseHandler(BaseNoiseHandler):

name = "noise-gaussian"

def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None:
def _add_noise(self, sim: SimData, rng_seed: RngType, noise_std: float) -> None:
rng = validate_rng(rng_seed)
if np.iscomplexobj(sim.data_ref):
noise_std /= np.sqrt(2)
noise = noise_std * rng.standard_normal(
sim.data_acq.shape, dtype=abs(sim.data_acq[:][0]).dtype
)
noise = noise.astype(sim.data_acq.dtype)
noise = (
noise_std
* rng.standard_normal(
sim.data_acq.shape, dtype=real_type(sim.data_acq.dtype)
)
).astype(sim.data_acq.dtype)

if sim.data_acq.dtype in [np.complex128, np.complex64]:
noise += (
1j
* noise_std
* rng.standard_normal(
sim.data_ref.shape,
dtype=abs(sim.data_ref[:][0]).dtype,
)
noise = noise * 1j
noise += noise_std * rng.standard_normal(
sim.data_ref.shape,
dtype=real_type(sim.data_ref.dtype),
)
sim.data_acq += noise
self.log.debug(f"{sim.data_acq}, {sim.data_ref}")

def _add_noise_lazy(self, sim: SimData, rng_seed: int, noise_std: float) -> None:
sim.data_acq.apply(_lazy_add_noise, noise_std, rng_seed)

self.log.debug(f"{sim.data_acq}, {sim.data_ref}")
def _add_noise_lazy(
self, lazy_arr: LazySimArray, rng_seed: int, noise_std: float
) -> None:
lazy_arr.apply(_lazy_add_noise, noise_std, rng_seed)


class RicianNoiseHandler(BaseNoiseHandler):
"""Add rician noise to the data."""

name = "noise-rician"

def _add_noise(self, sim: SimData, noise_std: float) -> None:
def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None:
if np.any(np.iscomplex(sim)):
raise ValueError(
"The Rice distribution is only applicable to real-valued data."
Expand Down Expand Up @@ -171,10 +171,11 @@ def _add_noise(self, sim: SimData, rng_seed: int, noise_std: float) -> None:
# Complex Value, so the std is spread.
noise_std /= np.sqrt(2)
for kf in range(len(sim.kspace_data)):
kspace_noise = np.complex64(
rng.standard_normal(sim.kspace_data.shape[1:], dtype="float32")
)
kspace_noise += 1j * rng.standard_normal(
kspace_noise = rng.standard_normal(
sim.kspace_data.shape[1:], dtype="float32"
).astype("complex64")
kspace_noise *= 1j
kspace_noise += rng.standard_normal(
sim.kspace_data.shape[1:], dtype="float32"
)
kspace_noise *= noise_std
Expand Down Expand Up @@ -215,15 +216,16 @@ def __init__(
self,
drift_model: str = "polynomial",
order: int = 1,
high_pass: float = None,
drift_intensities: float = 0.01,
high_pass: float | None = None,
drift_intensities: float | np.ndarray = 0.01,
):
super().__init__()
self._drift_model = drift_model
if not isinstance(drift_intensities, np.ndarray):
if isinstance(drift_intensities, (int, float)):
drift_intensities = [drift_intensities] * order
drift_intensities = np.array([drift_intensities])
drift_intensities = np.array([drift_intensities] * order)
else:
drift_intensities = np.array([drift_intensities])
self._drift_intensities = drift_intensities
self._drift_order = order
self._drift_high_pass = high_pass
Expand All @@ -242,11 +244,11 @@ def _handle(self, sim: SimData) -> SimData:
)
drift_matrix = drift_matrix[:, :-1] # remove the intercept column

drift_intensity = np.linspace(1, 1 + self.drift_intensities, sim.n_frames)
drift_intensity = np.linspace(1, 1 + self._drift_intensities, sim.n_frames)

timeseries = drift_intensity @ drift_matrix

if sim.lazy:
if isinstance(sim.data_acq, LazySimArray):
raise NotImplementedError(
"lazy is not compatible with scanner drift (for now)"
)
Expand Down
4 changes: 2 additions & 2 deletions src/simfmri/simulation/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from copy import deepcopy
from typing import Any, Callable, TypeVar, Mapping
import numpy as np
from numpy.typing import ArrayLike, NDArray
from numpy.typing import ArrayLike, NDArray, DTypeLike
from functools import wraps

T = TypeVar("T")
Expand Down Expand Up @@ -96,7 +96,7 @@ def shape(self) -> tuple[int, ...]:
return (len(self), *(self._base_array.shape))

@property
def dtype(self) -> np.dtype:
def dtype(self) -> DTypeLike:
"""Get dtype."""
return self._base_array.dtype

Expand Down
82 changes: 74 additions & 8 deletions src/simfmri/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,82 @@ def __init__(
rng=rng,
extra_infos=extra_infos,
)
self.static_vol: OptionalArray
self.data_ref: OptionalArray | LazySimArray | None = None
self.roi: OptionalArray = None
self.data_acq: OptionalArray | LazySimArray = None
self.data_rec: OptionalArray | LazySimArray = None
self.kspace_data: OptionalArray = None
self.kspace_mask: OptionalArray = None
self.smaps: OptionalArray = None
self._static_vol: NDArray | None
self._roi: NDArray | None = None
self._data_ref: NDArray | None | LazySimArray = None
self._data_acq: NDArray | None | LazySimArray = None
self.data_rec: NDArray | None = None
self._kspace_data: NDArray | None = None
self._kspace_mask: NDArray | None = None
self.smaps: NDArray | None = None
self.lazy = lazy

@property
def static_vol(self) -> NDArray:
"""Static volume."""
if self._static_vol is not None:
return self._static_vol
raise ValueError("static_vol is not defined")

@static_vol.setter
def static_vol(self, value: NDArray) -> None:
self._static_vol = value

@property
def kspace_data(self) -> NDArray:
"""Static volume."""
if self._kspace_data is not None:
return self._kspace_data
raise ValueError("static_vol is not defined")

@kspace_data.setter
def kspace_data(self, value: NDArray) -> None:
self._kspace_data = value

@property
def kspace_mask(self) -> NDArray:
"""Static volume."""
if self._kspace_mask is not None:
return self._kspace_mask
raise ValueError("static_vol is not defined")

@kspace_mask.setter
def kspace_mask(self, value: NDArray) -> None:
self._kspace_mask = value

@property
def data_ref(self) -> NDArray | LazySimArray:
"""Static volume."""
if self._data_ref is not None:
return self._data_ref
raise ValueError("data_ref is not defined")

@data_ref.setter
def data_ref(self, value: NDArray) -> None:
self._data_ref = value

@property
def data_acq(self) -> NDArray | LazySimArray:
"""Acquired Volume."""
if self._data_acq is not None:
return self._data_acq
raise ValueError("data_acq is not defined")

@data_acq.setter
def data_acq(self, value: NDArray) -> None:
self._data_acq = value

@property
def roi(self) -> NDArray:
"""Reference data volume."""
if self._roi is not None:
return self._roi
raise ValueError("static_vol is not defined")

@roi.setter
def roi(self, value: NDArray) -> None:
self._roi = value

@classmethod
def from_params(cls, sim_params: SimParams, in_place: bool = False) -> SimData:
"""Create a Simulation from its meta parameters.
Expand Down
9 changes: 4 additions & 5 deletions src/simfmri/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Utilities tools for simfmri."""
from .typing import RngType, AnyShape, Shape2d, Shape3d
from .utils import validate_rng, cplx_type
from .typing import RngType, AnyShape
from .utils import validate_rng, cplx_type, real_type

__all__ = [
"RngType",
"AnyShape",
"Shape2d",
"Shape3d",
"RngType",
"validate_rng",
"cplx_type",
"real_type",
]
10 changes: 1 addition & 9 deletions src/simfmri/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,4 @@
or a numpy.random.Generator.
"""


Shape2d = tuple[int, int]
"""Type for a 2D shape."""
Shape3d = tuple[int, int, int]
"""Type for a 3D shape."""

AnyShape = Shape2d | Shape3d

"""Type for a 2D or 3D shape."""
AnyShape = tuple[int, ...]
26 changes: 21 additions & 5 deletions src/simfmri/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,28 @@ def cplx_type(dtype: DTypeLike) -> DTypeLike:
d = np.dtype(dtype)
if d.type is np.float64:
return np.complex128
elif d.type is np.float128:
return np.complex256
elif d.type is np.float32:
return np.complex64
else:
sim_logger.warning(
"not supported dtype, use matching complex64", stack_info=True
)
sim_logger.warning("unsupported dtype, use matching complex64", stack_info=True)
return np.complex64


def real_type(
dtype: DTypeLike,
) -> np.dtype[np.float32] | np.dtype[np.float64]:
"""Return the real type associated with the complex one.
Examples
--------
>>> cplx_type(np.float32)
np.complex64
"""
d = np.dtype(dtype)
if d.type is np.complex64:
return np.dtype("float32")
elif d.type is np.complex128:
return np.dtype("float64")
else:
sim_logger.warning("unsupported dtype, use matching float32", stack_info=True)
return np.dtype("float32")

0 comments on commit 9b1bb27

Please sign in to comment.