diff --git a/src/gpgi/_boundaries.py b/src/gpgi/_boundaries.py index 403765e..f20cadc 100644 --- a/src/gpgi/_boundaries.py +++ b/src/gpgi/_boundaries.py @@ -1,26 +1,25 @@ -from __future__ import annotations - from collections.abc import Callable from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import Any, Literal, cast + +from numpy.typing import NDArray -if TYPE_CHECKING: - from gpgi._typing import RealArray +from gpgi._typing import FloatT BoundaryRecipeT = Callable[ [ - "RealArray", - "RealArray", - "RealArray", - "RealArray", - "RealArray", - "RealArray", - "RealArray", - "RealArray", + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], + NDArray[FloatT], Literal["left", "right"], dict[str, Any], ], - "RealArray", + NDArray[FloatT], ] @@ -126,64 +125,64 @@ def __contains__(self, key: str) -> bool: # basic recipes def open_boundary( - same_side_active_layer: RealArray, - same_side_ghost_layer: RealArray, - opposite_side_active_layer: RealArray, - opposite_side_ghost_layer: RealArray, - weight_same_side_active_layer: RealArray, - weight_same_side_ghost_layer: RealArray, - weight_opposite_side_active_layer: RealArray, - weight_opposite_side_ghost_layer: RealArray, + same_side_active_layer: NDArray[FloatT], + same_side_ghost_layer: NDArray[FloatT], + opposite_side_active_layer: NDArray[FloatT], + opposite_side_ghost_layer: NDArray[FloatT], + weight_same_side_active_layer: NDArray[FloatT], + weight_same_side_ghost_layer: NDArray[FloatT], + weight_opposite_side_active_layer: NDArray[FloatT], + weight_opposite_side_ghost_layer: NDArray[FloatT], side: Literal["left", "right"], metadata: dict[str, Any], -) -> RealArray: +) -> NDArray[FloatT]: # return the active layer unchanged return same_side_active_layer def wall_boundary( - same_side_active_layer: RealArray, - same_side_ghost_layer: RealArray, - opposite_side_active_layer: RealArray, - opposite_side_ghost_layer: RealArray, - weight_same_side_active_layer: RealArray, - weight_same_side_ghost_layer: RealArray, - weight_opposite_side_active_layer: RealArray, - weight_opposite_side_ghost_layer: RealArray, + same_side_active_layer: NDArray[FloatT], + same_side_ghost_layer: NDArray[FloatT], + opposite_side_active_layer: NDArray[FloatT], + opposite_side_ghost_layer: NDArray[FloatT], + weight_same_side_active_layer: NDArray[FloatT], + weight_same_side_ghost_layer: NDArray[FloatT], + weight_opposite_side_active_layer: NDArray[FloatT], + weight_opposite_side_ghost_layer: NDArray[FloatT], side: Literal["left", "right"], metadata: dict[str, Any], -) -> RealArray: - return cast("RealArray", same_side_active_layer + same_side_ghost_layer) +) -> NDArray[FloatT]: + return cast(NDArray[FloatT], same_side_active_layer + same_side_ghost_layer) def antisymmetric_boundary( - same_side_active_layer: RealArray, - same_side_ghost_layer: RealArray, - opposite_side_active_layer: RealArray, - opposite_side_ghost_layer: RealArray, - weight_same_side_active_layer: RealArray, - weight_same_side_ghost_layer: RealArray, - weight_opposite_side_active_layer: RealArray, - weight_opposite_side_ghost_layer: RealArray, + same_side_active_layer: NDArray[FloatT], + same_side_ghost_layer: NDArray[FloatT], + opposite_side_active_layer: NDArray[FloatT], + opposite_side_ghost_layer: NDArray[FloatT], + weight_same_side_active_layer: NDArray[FloatT], + weight_same_side_ghost_layer: NDArray[FloatT], + weight_opposite_side_active_layer: NDArray[FloatT], + weight_opposite_side_ghost_layer: NDArray[FloatT], side: Literal["left", "right"], metadata: dict[str, Any], -) -> RealArray: - return cast("RealArray", same_side_active_layer - same_side_ghost_layer) +) -> NDArray[FloatT]: + return cast(NDArray[FloatT], same_side_active_layer - same_side_ghost_layer) def periodic_boundary( - same_side_active_layer: RealArray, - same_side_ghost_layer: RealArray, - opposite_side_active_layer: RealArray, - opposite_side_ghost_layer: RealArray, - weight_same_side_active_layer: RealArray, - weight_same_side_ghost_layer: RealArray, - weight_opposite_side_active_layer: RealArray, - weight_opposite_side_ghost_layer: RealArray, + same_side_active_layer: NDArray[FloatT], + same_side_ghost_layer: NDArray[FloatT], + opposite_side_active_layer: NDArray[FloatT], + opposite_side_ghost_layer: NDArray[FloatT], + weight_same_side_active_layer: NDArray[FloatT], + weight_same_side_ghost_layer: NDArray[FloatT], + weight_opposite_side_active_layer: NDArray[FloatT], + weight_opposite_side_ghost_layer: NDArray[FloatT], side: Literal["left", "right"], metadata: dict[str, Any], -) -> RealArray: - return cast("RealArray", same_side_active_layer + opposite_side_ghost_layer) +) -> NDArray[FloatT]: + return cast(NDArray[FloatT], same_side_active_layer + opposite_side_ghost_layer) _base_registry: dict[str, BoundaryRecipeT] = { diff --git a/src/gpgi/_data_types.py b/src/gpgi/_data_types.py index 37bb15f..24a8f22 100644 --- a/src/gpgi/_data_types.py +++ b/src/gpgi/_data_types.py @@ -11,7 +11,7 @@ from textwrap import indent from threading import Lock from time import monotonic_ns -from typing import TYPE_CHECKING, Literal, cast, final +from typing import TYPE_CHECKING, Generic, Literal, cast, final import numpy as np @@ -36,7 +36,7 @@ GeometryValidator, Validator, ) -from gpgi._typing import FieldMap, Name +from gpgi._typing import FieldMap, FloatT, Name from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT if sys.version_info >= (3, 13): @@ -45,9 +45,13 @@ from _thread import LockType if TYPE_CHECKING: - from typing import Any, Self + from typing import Any, Self, TypeVar - from gpgi._typing import FieldMap, HCIArray, Name, RealArray + from numpy.typing import NDArray + + from gpgi._typing import FieldMap, HCIArray, Name + + _FloatingT = TypeVar("_FloatingT", bound=np.floating) BoundarySpec = tuple[tuple[str, str, str], ...] @@ -111,13 +115,13 @@ def check(cls, data: Grid) -> None: @final -class Grid: +class Grid(Generic[FloatT]): def __init__( self, *, geometry: Geometry, - cell_edges: FieldMap, - fields: FieldMap | None = None, + cell_edges: FieldMap[FloatT], + fields: FieldMap[FloatT] | None = None, ) -> None: r""" Define a Grid from cell left-edges and data fields. @@ -133,7 +137,7 @@ def __init__( fields (keyword-only, optional): gpgi.typing.FieldMap """ self.geometry = geometry - self.coordinates = cell_edges + self.coordinates: FieldMap[FloatT] = cell_edges if fields is None: fields = {} @@ -141,9 +145,11 @@ def __init__( self.axes = tuple(self.coordinates.keys()) self._validate() - self.dtype = self.coordinates[self.axes[0]].dtype + self.dtype: np.dtype[FloatT] = self.coordinates[self.axes[0]].dtype - self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype) + self._dx: NDArray[FloatT] = np.full( + (3,), -1, dtype=self.coordinates[self.axes[0]].dtype + ) for i, ax in enumerate(self.axes): if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16: # got a constant step in this direction, store it @@ -172,17 +178,17 @@ def __repr__(self) -> str: ) @property - def cell_edges(self) -> FieldMap: + def cell_edges(self) -> FieldMap[FloatT]: r"""An alias for self.coordinates.""" return self.coordinates @cached_property - def cell_centers(self) -> FieldMap: + def cell_centers(self) -> FieldMap[FloatT]: r"""The positions of cell centers in each direction.""" - return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()} + return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()} # type: ignore [misc] @cached_property - def cell_widths(self) -> FieldMap: + def cell_widths(self) -> FieldMap[FloatT]: r"""The width of cells, expressed as the difference between consecutive left edges.""" return {ax: np.diff(arr) for ax, arr in self.coordinates.items()} @@ -206,7 +212,7 @@ def ndim(self) -> int: return len(self.axes) @property - def cell_volumes(self) -> RealArray: + def cell_volumes(self) -> NDArray[FloatT]: r""" The generalized ND-volume of grid cells. @@ -217,7 +223,7 @@ def cell_volumes(self) -> RealArray: widths = list(self.cell_widths.values()) if self.geometry is Geometry.CARTESIAN: raw = np.prod(np.meshgrid(*widths), axis=0) - return cast("RealArray", np.swapaxes(raw, 0, 1)) + return np.swapaxes(raw, 0, 1) else: raise NotImplementedError( f"cell_volumes property is not implemented for {self.geometry} geometry" @@ -236,13 +242,13 @@ def check(cls, data: ParticleSet) -> None: @final -class ParticleSet: +class ParticleSet(Generic[FloatT]): def __init__( self, *, geometry: Geometry, - coordinates: FieldMap, - fields: FieldMap | None = None, + coordinates: FieldMap[FloatT], + fields: FieldMap[FloatT] | None = None, ) -> None: r""" Define a ParticleSet from point positions and data fields. @@ -257,15 +263,15 @@ def __init__( fields (keyword-only, optional): gpgi.typing.FieldMap """ self.geometry = geometry - self.coordinates = coordinates + self.coordinates: FieldMap[FloatT] = coordinates if fields is None: fields = {} - self.fields: FieldMap = fields + self.fields: FieldMap[FloatT] = fields self.axes = tuple(self.coordinates.keys()) self._validate() - self.dtype = self.coordinates[self.axes[0]].dtype + self.dtype: np.dtype[FloatT] = self.coordinates[self.axes[0]].dtype _validators: list[type[Validator[ParticleSet]]] = [ GeometryValidator, @@ -380,10 +386,12 @@ def _validate(self) -> None: f"- from particles: {self.particles.dtype}\n" ) - def _get_padded_cell_edges(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _get_padded_cell_edges( + self, + ) -> tuple[NDArray[FloatT], NDArray[FloatT], NDArray[FloatT]]: edges = iter(self.grid.cell_edges.values()) - def pad(a: np.ndarray) -> np.ndarray: + def pad(a: NDArray[FloatT]) -> NDArray[FloatT]: dx = a[1] - a[0] return np.concatenate([[a[0] - dx], a, [a[-1] + dx]]) @@ -391,7 +399,12 @@ def pad(a: np.ndarray) -> np.ndarray: cell_edges_x1 = pad(x1) DTYPE = cell_edges_x1.dtype - cell_edges_x2 = cell_edges_x3 = np.empty(0, DTYPE) + cell_edges_x2: np.ndarray[tuple[int, ...], np.dtype[FloatT]] = np.empty( + 0, DTYPE + ) + cell_edges_x3: np.ndarray[tuple[int, ...], np.dtype[FloatT]] = np.empty( + 0, DTYPE + ) if self.grid.ndim >= 2: cell_edges_x2 = pad(next(edges)) if self.grid.ndim == 3: @@ -399,7 +412,9 @@ def pad(a: np.ndarray) -> np.ndarray: return cell_edges_x1, cell_edges_x2, cell_edges_x3 - def _get_3D_particle_coordinates(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _get_3D_particle_coordinates( + self, + ) -> tuple[NDArray[FloatT], NDArray[FloatT], NDArray[FloatT]]: particle_coords = iter(self.particles.coordinates.values()) particles_x1 = next(particle_coords) DTYPE = particles_x1.dtype @@ -477,9 +492,7 @@ def _validate_sort_axes(self, axes: tuple[int, ...]) -> None: if any(axis > self.grid.ndim - 1 for axis in axes): raise ValueError(f"Expected all axes to be <{self.grid.ndim}, got {axes!r}") - def _get_sort_key( - self, axes: tuple[int, ...] - ) -> np.ndarray[Any, np.dtype[np.uint16]]: + def _get_sort_key(self, axes: tuple[int, ...]) -> NDArray[np.uint16]: self._validate_sort_axes(axes) hci = self.host_cell_index @@ -780,9 +793,9 @@ def _sanitize_boundaries(self, boundaries: dict[Name, tuple[Name, Name]]) -> Non def _apply_boundary_conditions( self, - array: RealArray, + array: NDArray[FloatT], boundaries: dict[Name, tuple[Name, Name]], - weight_array: RealArray | None, + weight_array: NDArray[FloatT] | None, ) -> None: axes = list(self.grid.axes) for ax, bv in boundaries.items(): diff --git a/src/gpgi/_spatial_data.py b/src/gpgi/_spatial_data.py index 60be47e..0cc8d53 100644 --- a/src/gpgi/_spatial_data.py +++ b/src/gpgi/_spatial_data.py @@ -7,7 +7,7 @@ import numpy as np from numpy.typing import NDArray -from gpgi._typing import FieldMap, Name, Real +from gpgi._typing import FieldMap, FloatT, Name class Geometry(StrEnum): @@ -154,7 +154,7 @@ def check( @staticmethod def _validate_shape_equality( name: str, - data: NDArray[Real], + data: NDArray[FloatT], ref_arr: NamedArray | None, ) -> NamedArray: if ref_arr is not None and data.shape != ref_arr.data.shape: @@ -165,7 +165,7 @@ def _validate_shape_equality( return ref_arr or NamedArray(name, data) @staticmethod - def _validate_sorted_state(name: str, data: NDArray[Real]) -> None: + def _validate_sorted_state(name: str, data: NDArray[FloatT]) -> None: a = data[0] for i, b in enumerate(data[1:], start=1): if a > b: @@ -178,7 +178,7 @@ def _validate_sorted_state(name: str, data: NDArray[Real]) -> None: @staticmethod def _validate_required_attributes( name: str, - data: NDArray[Real], + data: NDArray[FloatT], required_attrs: dict[str, Any], ) -> None: for attr, expected in required_attrs.items(): diff --git a/src/gpgi/_typing.py b/src/gpgi/_typing.py index 5cfdc94..6951327 100644 --- a/src/gpgi/_typing.py +++ b/src/gpgi/_typing.py @@ -1,61 +1,60 @@ -from typing import NotRequired, TypedDict, TypeVar +from typing import Generic, NotRequired, TypedDict, TypeVar import numpy as np -import numpy.typing as npt +from numpy.typing import NDArray -Real = TypeVar("Real", np.float32, np.float64) -RealArray = npt.NDArray[Real] -HCIArray = npt.NDArray[np.uint16] +FloatT = TypeVar("FloatT", np.float32, np.float64) +HCIArray = NDArray[np.uint16] Name = str -FieldMap = dict[str, np.ndarray] +FieldMap = dict[str, NDArray[FloatT]] -class CartesianCoordinates(TypedDict): - x: np.ndarray - y: NotRequired[np.ndarray] - z: NotRequired[np.ndarray] +class CartesianCoordinates(TypedDict, Generic[FloatT]): + x: NDArray[FloatT] + y: NotRequired[NDArray[FloatT]] + z: NotRequired[NDArray[FloatT]] -class CylindricalCoordinates(TypedDict): - radius: np.ndarray - azimuth: NotRequired[np.ndarray] - z: NotRequired[np.ndarray] +class CylindricalCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + azimuth: NotRequired[NDArray[FloatT]] + z: NotRequired[NDArray[FloatT]] -class PolarCoordinates(TypedDict): - radius: np.ndarray - z: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class PolarCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + z: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] -class SphericalCoordinates(TypedDict): - colatitude: np.ndarray - radius: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class SphericalCoordinates(TypedDict, Generic[FloatT]): + colatitude: NDArray[FloatT] + radius: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] -class EquatorialCoordinates(TypedDict): - radius: np.ndarray - latitude: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class EquatorialCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + latitude: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] CoordMap = ( - CartesianCoordinates - | CylindricalCoordinates - | PolarCoordinates - | SphericalCoordinates - | EquatorialCoordinates + CartesianCoordinates[FloatT] + | CylindricalCoordinates[FloatT] + | PolarCoordinates[FloatT] + | SphericalCoordinates[FloatT] + | EquatorialCoordinates[FloatT] ) -class GridDict(TypedDict): - cell_edges: CoordMap - fields: NotRequired[FieldMap] +class GridDict(TypedDict, Generic[FloatT]): + cell_edges: CoordMap[FloatT] + fields: NotRequired[FieldMap[FloatT]] -class ParticleSetDict(TypedDict): - coordinates: CoordMap - fields: NotRequired[FieldMap] +class ParticleSetDict(TypedDict, Generic[FloatT]): + coordinates: CoordMap[FloatT] + fields: NotRequired[FieldMap[FloatT]] diff --git a/src/gpgi/typing.py b/src/gpgi/typing.py index 1b431db..222c220 100644 --- a/src/gpgi/typing.py +++ b/src/gpgi/typing.py @@ -8,7 +8,9 @@ from gpgi._typing import FieldMap if TYPE_CHECKING: - from gpgi._typing import HCIArray, RealArray + from numpy.typing import NDArray + + from gpgi._typing import FloatT, HCIArray __all__ = [ @@ -21,32 +23,32 @@ class DepositionMethodT(Protocol): def __call__( # noqa D102 self, - cell_edges_x1: RealArray, - cell_edges_x2: RealArray, - cell_edges_x3: RealArray, - particles_x1: RealArray, - particles_x2: RealArray, - particles_x3: RealArray, - field: RealArray, - weight_field: RealArray, + cell_edges_x1: NDArray[FloatT], + cell_edges_x2: NDArray[FloatT], + cell_edges_x3: NDArray[FloatT], + particles_x1: NDArray[FloatT], + particles_x2: NDArray[FloatT], + particles_x3: NDArray[FloatT], + field: NDArray[FloatT], + weight_field: NDArray[FloatT], hci: HCIArray, - out: RealArray, + out: NDArray[FloatT], ) -> None: ... class DepositionMethodWithMetadataT(Protocol): def __call__( # noqa D102 self, - cell_edges_x1: RealArray, - cell_edges_x2: RealArray, - cell_edges_x3: RealArray, - particles_x1: RealArray, - particles_x2: RealArray, - particles_x3: RealArray, - field: RealArray, - weight_field: RealArray, + cell_edges_x1: NDArray[FloatT], + cell_edges_x2: NDArray[FloatT], + cell_edges_x3: NDArray[FloatT], + particles_x1: NDArray[FloatT], + particles_x2: NDArray[FloatT], + particles_x3: NDArray[FloatT], + field: NDArray[FloatT], + weight_field: NDArray[FloatT], hci: HCIArray, - out: RealArray, + out: NDArray[FloatT], *, metadata: dict[str, Any], ) -> None: ...