diff --git a/src/gpgi/_data_types.py b/src/gpgi/_data_types.py index 62af0ef..4b4fc9b 100644 --- a/src/gpgi/_data_types.py +++ b/src/gpgi/_data_types.py @@ -491,10 +491,10 @@ def _validate(self) -> None: f"- from particles: {self.particles.dtype}\n" ) - def _get_padded_cell_edges(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _get_padded_cell_edges(self) -> tuple[RealArray, RealArray, RealArray]: edges = iter(self.grid.cell_edges.values()) - def pad(a: np.ndarray) -> np.ndarray: + def pad(a: RealArray) -> RealArray: dx = a[1] - a[0] return np.concatenate([[a[0] - dx], a, [a[-1] + dx]]) @@ -510,7 +510,7 @@ def pad(a: np.ndarray) -> np.ndarray: return cell_edges_x1, cell_edges_x2, cell_edges_x3 - def _get_3D_particle_coordinates(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _get_3D_particle_coordinates(self) -> tuple[RealArray, RealArray, RealArray]: particle_coords = iter(self.particles.coordinates.values()) particles_x1 = next(particle_coords) DTYPE = particles_x1.dtype @@ -588,9 +588,7 @@ def _validate_sort_axes(self, axes: tuple[int, ...]) -> None: if any(axis > self.grid.ndim - 1 for axis in axes): raise ValueError(f"Expected all axes to be <{self.grid.ndim}, got {axes!r}") - def _get_sort_key( - self, axes: tuple[int, ...] - ) -> np.ndarray[Any, np.dtype[np.uint16]]: + def _get_sort_key(self, axes: tuple[int, ...]) -> NDArray[np.uint16]: self._validate_sort_axes(axes) hci = self.host_cell_index diff --git a/src/gpgi/_typing.py b/src/gpgi/_typing.py index 5cfdc94..2d0145a 100644 --- a/src/gpgi/_typing.py +++ b/src/gpgi/_typing.py @@ -1,61 +1,61 @@ -from typing import NotRequired, TypedDict, TypeVar +from typing import Generic, NotRequired, TypedDict, TypeVar import numpy as np -import numpy.typing as npt +from numpy.typing import NDArray -Real = TypeVar("Real", np.float32, np.float64) -RealArray = npt.NDArray[Real] -HCIArray = npt.NDArray[np.uint16] +RealT = TypeVar("RealT", np.float32, np.float64) +RealArray = NDArray[RealT] +HCIArray = NDArray[np.uint16] Name = str -FieldMap = dict[str, np.ndarray] +FieldMap = dict[str, NDArray[RealT]] -class CartesianCoordinates(TypedDict): - x: np.ndarray - y: NotRequired[np.ndarray] - z: NotRequired[np.ndarray] +class CartesianCoordinates(TypedDict, Generic[RealT]): + x: NDArray[RealT] + y: NotRequired[NDArray[RealT]] + z: NotRequired[NDArray[RealT]] -class CylindricalCoordinates(TypedDict): - radius: np.ndarray - azimuth: NotRequired[np.ndarray] - z: NotRequired[np.ndarray] +class CylindricalCoordinates(TypedDict, Generic[RealT]): + radius: NDArray[RealT] + azimuth: NotRequired[NDArray[RealT]] + z: NotRequired[NDArray[RealT]] -class PolarCoordinates(TypedDict): - radius: np.ndarray - z: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class PolarCoordinates(TypedDict, Generic[RealT]): + radius: NDArray[RealT] + z: NotRequired[NDArray[RealT]] + azimuth: NotRequired[NDArray[RealT]] -class SphericalCoordinates(TypedDict): - colatitude: np.ndarray - radius: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class SphericalCoordinates(TypedDict, Generic[RealT]): + colatitude: NDArray[RealT] + radius: NotRequired[NDArray[RealT]] + azimuth: NotRequired[NDArray[RealT]] -class EquatorialCoordinates(TypedDict): - radius: np.ndarray - latitude: NotRequired[np.ndarray] - azimuth: NotRequired[np.ndarray] +class EquatorialCoordinates(TypedDict, Generic[RealT]): + radius: NDArray[RealT] + latitude: NotRequired[NDArray[RealT]] + azimuth: NotRequired[NDArray[RealT]] CoordMap = ( - CartesianCoordinates - | CylindricalCoordinates - | PolarCoordinates - | SphericalCoordinates - | EquatorialCoordinates + CartesianCoordinates[RealT] + | CylindricalCoordinates[RealT] + | PolarCoordinates[RealT] + | SphericalCoordinates[RealT] + | EquatorialCoordinates[RealT] ) -class GridDict(TypedDict): - cell_edges: CoordMap - fields: NotRequired[FieldMap] +class GridDict(TypedDict, Generic[RealT]): + cell_edges: CoordMap[RealT] + fields: NotRequired[FieldMap[RealT]] -class ParticleSetDict(TypedDict): - coordinates: CoordMap - fields: NotRequired[FieldMap] +class ParticleSetDict(TypedDict, Generic[RealT]): + coordinates: CoordMap[RealT] + fields: NotRequired[FieldMap[RealT]]