Skip to content

Commit

Permalink
Clean documentation and add type hints (#228)
Browse files Browse the repository at this point in the history
* Clean documentation and add type hints

* Apply ruff ANN rules to trajectories

* Replace np.ndarray type hints with np.typing.NDArray

* Fix missing import

* Fix Literal use, change types to NDArray and reduce mypy errors

* Add type alias for clustering coordinates
  • Loading branch information
Daval-G authored Jan 10, 2025
1 parent 1790c06 commit 3abd8d9
Show file tree
Hide file tree
Showing 13 changed files with 732 additions and 547 deletions.
130 changes: 67 additions & 63 deletions src/mrinufft/trajectories/display.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Display functions for trajectories."""

from __future__ import annotations

import itertools
from typing import Any

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from numpy.typing import NDArray

from .utils import (
DEFAULT_GMAX,
Expand Down Expand Up @@ -46,7 +50,7 @@ class displayConfig:
"""Font size for most labels and texts, by default ``18``."""
small_fontsize: int = 14
"""Font size for smaller texts, by default ``14``."""
nb_colors = 10
nb_colors: int = 10
"""Number of colors to use in the color cycle, by default ``10``."""
palette: str = "tab10"
"""Name of the color palette to use, by default ``"tab10"``.
Expand All @@ -58,33 +62,33 @@ class displayConfig:
slewrate_point_color: str = "b"
"""Matplotlib color for slew rate constraint points, by default ``"b"`` (blue)."""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None: # noqa ANN401
"""Update the display configuration."""
self.update(**kwargs)

def update(self, **kwargs):
def update(self, **kwargs: Any) -> None: # noqa ANN401
"""Update the display configuration."""
self._old_values = {}
for key, value in kwargs.items():
self._old_values[key] = getattr(displayConfig, key)
setattr(displayConfig, key, value)

def reset(self):
def reset(self) -> None:
"""Restore the display configuration."""
for key, value in self._old_values.items():
setattr(displayConfig, key, value)
delattr(self, "_old_values")

def __enter__(self):
def __enter__(self) -> displayConfig:
"""Enter the context manager."""
return self

def __exit__(self, *args):
def __exit__(self, *args: Any) -> None: # noqa ANN401
"""Exit the context manager."""
self.reset()

@classmethod
def get_colorlist(cls):
def get_colorlist(cls) -> list[str | NDArray]:
"""Extract a list of colors from a matplotlib palette.
If the palette is continuous, the colors will be sampled from it.
Expand Down Expand Up @@ -124,7 +128,7 @@ def get_colorlist(cls):
##############


def _setup_2D_ticks(figsize, fig=None):
def _setup_2D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes:
"""Add ticks to 2D plot."""
if fig is None:
fig = plt.figure(figsize=(figsize, figsize))
Expand All @@ -139,7 +143,7 @@ def _setup_2D_ticks(figsize, fig=None):
return ax


def _setup_3D_ticks(figsize, fig=None):
def _setup_3D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes:
"""Add ticks to 3D plot."""
if fig is None:
fig = plt.figure(figsize=(figsize, figsize))
Expand All @@ -163,21 +167,21 @@ def _setup_3D_ticks(figsize, fig=None):


def display_2D_trajectory(
trajectory,
figsize=5,
one_shot=False,
subfigure=None,
show_constraints=False,
gmax=DEFAULT_GMAX,
smax=DEFAULT_SMAX,
constraints_order=None,
**constraints_kwargs,
):
trajectory: NDArray,
figsize: float = 5,
one_shot: bool | int = False,
subfigure: plt.Figure | plt.Axes | None = None,
show_constraints: bool = False,
gmax: float = DEFAULT_GMAX,
smax: float = DEFAULT_SMAX,
constraints_order: int | str | None = None,
**constraints_kwargs: Any, # noqa ANN401
) -> plt.Axes:
"""Display 2D trajectories.
Parameters
----------
trajectory : array_like
trajectory : NDArray
Trajectory to display.
figsize : float, optional
Size of the figure.
Expand All @@ -204,7 +208,7 @@ def display_2D_trajectory(
typically 2 or `np.inf`, following the `numpy.linalg.norm`
conventions on parameter `ord`.
The default is None.
**kwargs
**constraints_kwargs
Acquisition parameters used to check on hardware constraints,
following the parameter convention from
`mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.
Expand Down Expand Up @@ -278,23 +282,23 @@ def display_2D_trajectory(


def display_3D_trajectory(
trajectory,
nb_repetitions=None,
figsize=5,
per_plane=True,
one_shot=False,
subfigure=None,
show_constraints=False,
gmax=DEFAULT_GMAX,
smax=DEFAULT_SMAX,
constraints_order=None,
**constraints_kwargs,
):
trajectory: NDArray,
nb_repetitions: int | None = None,
figsize: float = 5,
per_plane: bool = True,
one_shot: bool | int = False,
subfigure: plt.Figure | plt.Axes | None = None,
show_constraints: bool = False,
gmax: float = DEFAULT_GMAX,
smax: float = DEFAULT_SMAX,
constraints_order: int | str | None = None,
**constraints_kwargs: Any, # noqa ANN401
) -> plt.Axes:
"""Display 3D trajectories.
Parameters
----------
trajectory : array_like
trajectory : NDArray
Trajectory to display.
nb_repetitions : int
Number of repetitions (planes, cones, shells, etc).
Expand Down Expand Up @@ -417,22 +421,22 @@ def display_3D_trajectory(


def display_gradients_simply(
trajectory,
shot_ids=(0,),
figsize=5,
fill_area=True,
show_signal=True,
uni_signal="gray",
uni_gradient=None,
subfigure=None,
):
trajectory: NDArray,
shot_ids: tuple[int, ...] = (0,),
figsize: float = 5,
fill_area: bool = True,
show_signal: bool = True,
uni_signal: str | None = "gray",
uni_gradient: str | None = None,
subfigure: plt.Figure | None = None,
) -> tuple[plt.Axes]:
"""Display gradients based on trajectory of any dimension.
Parameters
----------
trajectory : array_like
trajectory : NDArray
Trajectory to display.
shot_ids : list of int
shot_ids : tuple[int, ...], optional
Indices of the shots to display.
The default is `[0]`.
figsize : float, optional
Expand All @@ -455,7 +459,7 @@ def display_gradients_simply(
unique color given as argument or just by the default
color cycle when `None`.
The default is `None`.
subfigure: plt.Figure or plt.SubFigure, optional
subfigure: plt.Figure, optional
The figure where the trajectory should be displayed.
The default is `None`.
Expand Down Expand Up @@ -531,26 +535,26 @@ def display_gradients_simply(


def display_gradients(
trajectory,
shot_ids=(0,),
figsize=5,
fill_area=True,
show_signal=True,
uni_signal="gray",
uni_gradient=None,
subfigure=None,
show_constraints=False,
gmax=DEFAULT_GMAX,
smax=DEFAULT_SMAX,
constraints_order=None,
raster_time=DEFAULT_RASTER_TIME,
**constraints_kwargs,
):
trajectory: NDArray,
shot_ids: tuple[int, ...] = (0,),
figsize: float = 5,
fill_area: bool = True,
show_signal: bool = True,
uni_signal: str | None = "gray",
uni_gradient: str | None = None,
subfigure: plt.Figure | plt.Axes | None = None,
show_constraints: bool = False,
gmax: float = DEFAULT_GMAX,
smax: float = DEFAULT_SMAX,
constraints_order: int | str | None = None,
raster_time: float = DEFAULT_RASTER_TIME,
**constraints_kwargs: Any, # noqa ANN401
) -> tuple[plt.Axes]:
"""Display gradients based on trajectory of any dimension.
Parameters
----------
trajectory : array_like
trajectory : NDArray
Trajectory to display.
shot_ids : list of int
Indices of the shots to display.
Expand Down Expand Up @@ -597,7 +601,7 @@ def display_gradients(
Amount of time between the acquisition of two
consecutive samples in ms.
The default is `DEFAULT_RASTER_TIME`.
**kwargs
**constraints_kwargs
Acquisition parameters used to check on hardware constraints,
following the parameter convention from
`mrinufft.trajectories.utils.compute_gradients_and_slew_rates`.
Expand Down
31 changes: 17 additions & 14 deletions src/mrinufft/trajectories/gradients.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
"""Functions to improve/modify gradients."""

from typing import Callable

import numpy as np
import numpy.linalg as nl
from numpy.typing import NDArray
from scipy.interpolate import CubicSpline


def patch_center_anomaly(
shot_or_params,
update_shot=None,
update_parameters=None,
in_out=False,
learning_rate=1e-1,
):
shot_or_params: NDArray | tuple,
update_shot: Callable[..., NDArray] | None = None,
update_parameters: Callable[..., tuple] | None = None,
in_out: bool = False,
learning_rate: float = 1e-1,
) -> tuple[NDArray, tuple]:
"""Re-position samples to avoid center anomalies.
Some trajectories behave slightly differently from expected when
approaching definition bounds, most often the k-space center as
for spirals in some cases.
This function enforces non-strictly increasing monoticity of
This function enforces non-strictly increasing monotonicity of
sample distances from the center, effectively reducing slew
rates and smoothing gradient transitions locally.
Expand All @@ -31,27 +34,27 @@ def patch_center_anomaly(
shot_or_params : np.array, list
Either a single shot of shape (Ns, Nd), or a list of arbitrary
arguments used by ``update_shot`` to initialize a single shot.
update_shot : function, optional
update_shot : Callable[..., NDArray], optional
Function used to initialize a single shot based on parameters
provided by ``update_parameters``. If None, cubic splines are
used as an approximation instead, by default None
update_parameters : function, optional
update_parameters : Callable[..., tuple], optional
Function used to update shot parameters when provided in
``shot_or_params`` from an updated shot and parameters.
If None, cubic spline parameterization is used instead,
by default None
in_out : bool, optional
Whether the shot is going in-and-out or start from the center,
Whether the shot is going in-and-out or starts from the center,
by default False
learning_rate : float, optional
Learning rate used in the iterative optimization process,
by default 1e-1
Returns
-------
array_like
NDArray
N-D trajectory based on ``shot_or_params`` if a shot or
update_shot otherwise.
``update_shot`` otherwise.
list
Updated parameters either in the ``shot_or_params`` format
if params, or cubic spline parameterization as an array of
Expand All @@ -70,7 +73,7 @@ def patch_center_anomaly(

if update_shot is None or update_parameters is None:

def _default_update_parameters(shot, *parameters):
def _default_update_parameters(shot: NDArray, *parameters: list) -> list:
return parameters

update_parameters = _default_update_parameters
Expand Down Expand Up @@ -114,5 +117,5 @@ def _default_update_parameters(shot, *parameters):
single_shot = cbs(x_axis).T
parameters = update_parameters(single_shot, *parameters)

single_shot = single_shot = update_shot(*parameters)
single_shot = update_shot(*parameters)
return single_shot, parameters
Loading

0 comments on commit 3abd8d9

Please sign in to comment.