From aa2483da785d1e850b0a1efdeab214c6416f971b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sat, 25 Jan 2025 17:58:31 +0100 Subject: [PATCH] fixup 2 --- src/gpgi/_data_types.py | 42 ++++++++++++------------ src/gpgi/_spatial_data.py | 8 ++--- src/gpgi/_typing.py | 68 +++++++++++++++++++-------------------- 3 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/gpgi/_data_types.py b/src/gpgi/_data_types.py index 7a78283..8a080d3 100644 --- a/src/gpgi/_data_types.py +++ b/src/gpgi/_data_types.py @@ -11,7 +11,7 @@ from textwrap import indent from threading import Lock from time import monotonic_ns -from typing import TYPE_CHECKING, Literal, cast, final +from typing import TYPE_CHECKING, Generic, Literal, cast, final import numpy as np @@ -36,7 +36,7 @@ GeometryValidator, Validator, ) -from gpgi._typing import FieldMap, Name +from gpgi._typing import FieldMap, FloatT, Name from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT if sys.version_info >= (3, 13): @@ -47,9 +47,10 @@ if TYPE_CHECKING: from typing import Any, Self, TypeVar + from numpy import dtype from numpy.typing import NDArray - from gpgi._typing import FieldMap, HCIArray, Name, RealArray, RealT + from gpgi._typing import FieldMap, HCIArray, Name, RealArray _FloatingT = TypeVar("_FloatingT", bound=np.floating) @@ -115,13 +116,13 @@ def check(cls, data: Grid) -> None: @final -class Grid: +class Grid(Generic[FloatT]): def __init__( self, *, geometry: Geometry, - cell_edges: FieldMap[RealT], - fields: FieldMap[RealT] | None = None, + cell_edges: FieldMap[FloatT], + fields: FieldMap[FloatT] | None = None, ) -> None: r""" Define a Grid from cell left-edges and data fields. @@ -137,7 +138,7 @@ def __init__( fields (keyword-only, optional): gpgi.typing.FieldMap """ self.geometry = geometry - self.coordinates = cell_edges + self.coordinates: FieldMap[FloatT] = cell_edges if fields is None: fields = {} @@ -145,10 +146,11 @@ def __init__( self.axes = tuple(self.coordinates.keys()) self._validate() - self.dtype = self.coordinates[self.axes[0]].dtype - - self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype) + self.dtype: dtype[FloatT] = self.coordinates[self.axes[0]].dtype + self._dx: NDArray[FloatT] = np.full( + (3,), -1, dtype=self.coordinates[self.axes[0]].dtype + ) for i, ax in enumerate(self.axes): if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16: # got a constant step in this direction, store it @@ -177,17 +179,17 @@ def __repr__(self) -> str: ) @property - def cell_edges(self) -> FieldMap[RealT]: + def cell_edges(self) -> FieldMap[FloatT]: r"""An alias for self.coordinates.""" return self.coordinates @cached_property - def cell_centers(self) -> FieldMap[RealT]: + def cell_centers(self) -> FieldMap[FloatT]: r"""The positions of cell centers in each direction.""" return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()} @cached_property - def cell_widths(self) -> FieldMap[RealT]: + def cell_widths(self) -> FieldMap[FloatT]: r"""The width of cells, expressed as the difference between consecutive left edges.""" return {ax: np.diff(arr) for ax, arr in self.coordinates.items()} @@ -211,7 +213,7 @@ def ndim(self) -> int: return len(self.axes) @property - def cell_volumes(self) -> RealArray[RealT]: + def cell_volumes(self) -> RealArray[FloatT]: r""" The generalized ND-volume of grid cells. @@ -241,13 +243,13 @@ def check(cls, data: ParticleSet) -> None: @final -class ParticleSet: +class ParticleSet(Generic[FloatT]): def __init__( self, *, geometry: Geometry, - coordinates: FieldMap[RealT], - fields: FieldMap[RealT] | None = None, + coordinates: FieldMap[FloatT], + fields: FieldMap[FloatT] | None = None, ) -> None: r""" Define a ParticleSet from point positions and data fields. @@ -262,15 +264,15 @@ def __init__( fields (keyword-only, optional): gpgi.typing.FieldMap """ self.geometry = geometry - self.coordinates = coordinates + self.coordinates: FieldMap[FloatT] = coordinates if fields is None: fields = {} - self.fields = fields + self.fields: FieldMap[FloatT] = fields self.axes = tuple(self.coordinates.keys()) self._validate() - self.dtype = self.coordinates[self.axes[0]].dtype + self.dtype: dtype[FloatT] = self.coordinates[self.axes[0]].dtype _validators: list[type[Validator[ParticleSet]]] = [ GeometryValidator, diff --git a/src/gpgi/_spatial_data.py b/src/gpgi/_spatial_data.py index 44160c7..0cc8d53 100644 --- a/src/gpgi/_spatial_data.py +++ b/src/gpgi/_spatial_data.py @@ -7,7 +7,7 @@ import numpy as np from numpy.typing import NDArray -from gpgi._typing import FieldMap, Name, RealT +from gpgi._typing import FieldMap, FloatT, Name class Geometry(StrEnum): @@ -154,7 +154,7 @@ def check( @staticmethod def _validate_shape_equality( name: str, - data: NDArray[RealT], + data: NDArray[FloatT], ref_arr: NamedArray | None, ) -> NamedArray: if ref_arr is not None and data.shape != ref_arr.data.shape: @@ -165,7 +165,7 @@ def _validate_shape_equality( return ref_arr or NamedArray(name, data) @staticmethod - def _validate_sorted_state(name: str, data: NDArray[RealT]) -> None: + def _validate_sorted_state(name: str, data: NDArray[FloatT]) -> None: a = data[0] for i, b in enumerate(data[1:], start=1): if a > b: @@ -178,7 +178,7 @@ def _validate_sorted_state(name: str, data: NDArray[RealT]) -> None: @staticmethod def _validate_required_attributes( name: str, - data: NDArray[RealT], + data: NDArray[FloatT], required_attrs: dict[str, Any], ) -> None: for attr, expected in required_attrs.items(): diff --git a/src/gpgi/_typing.py b/src/gpgi/_typing.py index 2d0145a..1a16589 100644 --- a/src/gpgi/_typing.py +++ b/src/gpgi/_typing.py @@ -3,59 +3,59 @@ import numpy as np from numpy.typing import NDArray -RealT = TypeVar("RealT", np.float32, np.float64) -RealArray = NDArray[RealT] +FloatT = TypeVar("FloatT", np.float32, np.float64) +RealArray = NDArray[FloatT] HCIArray = NDArray[np.uint16] Name = str -FieldMap = dict[str, NDArray[RealT]] +FieldMap = dict[str, NDArray[FloatT]] -class CartesianCoordinates(TypedDict, Generic[RealT]): - x: NDArray[RealT] - y: NotRequired[NDArray[RealT]] - z: NotRequired[NDArray[RealT]] +class CartesianCoordinates(TypedDict, Generic[FloatT]): + x: NDArray[FloatT] + y: NotRequired[NDArray[FloatT]] + z: NotRequired[NDArray[FloatT]] -class CylindricalCoordinates(TypedDict, Generic[RealT]): - radius: NDArray[RealT] - azimuth: NotRequired[NDArray[RealT]] - z: NotRequired[NDArray[RealT]] +class CylindricalCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + azimuth: NotRequired[NDArray[FloatT]] + z: NotRequired[NDArray[FloatT]] -class PolarCoordinates(TypedDict, Generic[RealT]): - radius: NDArray[RealT] - z: NotRequired[NDArray[RealT]] - azimuth: NotRequired[NDArray[RealT]] +class PolarCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + z: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] -class SphericalCoordinates(TypedDict, Generic[RealT]): - colatitude: NDArray[RealT] - radius: NotRequired[NDArray[RealT]] - azimuth: NotRequired[NDArray[RealT]] +class SphericalCoordinates(TypedDict, Generic[FloatT]): + colatitude: NDArray[FloatT] + radius: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] -class EquatorialCoordinates(TypedDict, Generic[RealT]): - radius: NDArray[RealT] - latitude: NotRequired[NDArray[RealT]] - azimuth: NotRequired[NDArray[RealT]] +class EquatorialCoordinates(TypedDict, Generic[FloatT]): + radius: NDArray[FloatT] + latitude: NotRequired[NDArray[FloatT]] + azimuth: NotRequired[NDArray[FloatT]] CoordMap = ( - CartesianCoordinates[RealT] - | CylindricalCoordinates[RealT] - | PolarCoordinates[RealT] - | SphericalCoordinates[RealT] - | EquatorialCoordinates[RealT] + CartesianCoordinates[FloatT] + | CylindricalCoordinates[FloatT] + | PolarCoordinates[FloatT] + | SphericalCoordinates[FloatT] + | EquatorialCoordinates[FloatT] ) -class GridDict(TypedDict, Generic[RealT]): - cell_edges: CoordMap[RealT] - fields: NotRequired[FieldMap[RealT]] +class GridDict(TypedDict, Generic[FloatT]): + cell_edges: CoordMap[FloatT] + fields: NotRequired[FieldMap[FloatT]] -class ParticleSetDict(TypedDict, Generic[RealT]): - coordinates: CoordMap[RealT] - fields: NotRequired[FieldMap[RealT]] +class ParticleSetDict(TypedDict, Generic[FloatT]): + coordinates: CoordMap[FloatT] + fields: NotRequired[FieldMap[FloatT]]