diff --git a/src/gpgi/_data_types.py b/src/gpgi/_data_types.py index 4b4fc9b..5964417 100644 --- a/src/gpgi/_data_types.py +++ b/src/gpgi/_data_types.py @@ -14,7 +14,7 @@ from textwrap import indent from threading import Lock from time import monotonic_ns -from typing import TYPE_CHECKING, Literal, assert_never, cast +from typing import TYPE_CHECKING, Generic, Literal, assert_never, cast, overload import numpy as np @@ -31,7 +31,7 @@ _deposit_tsc_3D, _index_particles, ) -from gpgi._typing import FieldMap, Name +from gpgi._typing import FieldMap, Name, RealT from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT if sys.version_info >= (3, 13): @@ -40,12 +40,14 @@ from _thread import LockType if TYPE_CHECKING: - from typing import Any, Self + from typing import Any, Self, TypeVar from numpy.typing import NDArray from gpgi._typing import FieldMap, HCIArray, Name, RealArray + _FloatingT = TypeVar("_FloatingT", bound=np.floating) + BoundarySpec = tuple[tuple[str, str, str], ...] @@ -94,11 +96,11 @@ class GeometricData(ABC): axes: tuple[Name, ...] -class CoordinateData(ABC): +class CoordinateData(ABC, Generic[RealT]): geometry: Geometry axes: tuple[Name, ...] - coordinates: FieldMap - fields: FieldMap + coordinates: FieldMap[RealT] + fields: FieldMap[RealT] class ValidatorMixin(GeometricData, ABC): @@ -206,7 +208,7 @@ def _validate_coordinates(self) -> None: coord = self.coordinates[axis] if len(coord) == 0: continue - coord_dtype = self._get_safe_datatype(coord) + coord_dtype = self._get_safe_datatype(reference=coord) dt = coord_dtype.type xmin, xmax = (dt(_) for _ in _AXES_LIMITS[axis]) if (cmin := dt(np.min(coord))) < xmin or not math.isfinite(cmin): @@ -230,9 +232,15 @@ def _validate_coordinates(self) -> None: self.coordinates[axis] = coord.astype(coord_dtype, copy=False) + @overload + def _get_safe_datatype(self, *, reference: None) -> np.dtype[np.floating]: ... + + @overload def _get_safe_datatype( - self, reference: NDArray[np.floating] | None = None - ) -> np.dtype[np.floating]: + self, *, reference: NDArray[_FloatingT] + ) -> np.dtype[_FloatingT]: ... + + def _get_safe_datatype(self, reference): # type: ignore[no-untyped-def] if reference is None: reference = self.coordinates[self.axes[0]] dt = reference.dtype @@ -241,13 +249,13 @@ def _get_safe_datatype( return dt -class Grid(_CoordinateValidatorMixin): +class Grid(_CoordinateValidatorMixin, Generic[RealT]): def __init__( self, *, geometry: Geometry, - cell_edges: FieldMap, - fields: FieldMap | None = None, + cell_edges: FieldMap[RealT], + fields: FieldMap[RealT] | None = None, ) -> None: r""" Define a Grid from cell left-edges and data fields. @@ -272,7 +280,9 @@ def __init__( self.axes = tuple(self.coordinates.keys()) super().__init__() - self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype) + dt = self._get_safe_datatype(reference=None) + self._dx = np.full((3,), -1, dtype=dt) + for i, ax in enumerate(self.axes): if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16: # got a constant step in this direction, store it @@ -300,17 +310,17 @@ def _validate(self) -> None: ) @property - def cell_edges(self) -> FieldMap: + def cell_edges(self) -> FieldMap[RealT]: r"""An alias for self.coordinates.""" return self.coordinates @cached_property - def cell_centers(self) -> FieldMap: + def cell_centers(self) -> FieldMap[RealT]: r"""The positions of cell centers in each direction.""" return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()} @cached_property - def cell_widths(self) -> FieldMap: + def cell_widths(self) -> FieldMap[RealT]: r"""The width of cells, expressed as the difference between consecutive left edges.""" return {ax: np.diff(arr) for ax, arr in self.coordinates.items()} @@ -334,7 +344,7 @@ def ndim(self) -> int: return len(self.axes) @property - def cell_volumes(self) -> RealArray: + def cell_volumes(self) -> RealArray[RealT]: r""" The generalized ND-volume of grid cells. @@ -357,8 +367,8 @@ def __init__( self, *, geometry: Geometry, - coordinates: FieldMap, - fields: FieldMap | None = None, + coordinates: FieldMap[RealT], + fields: FieldMap[RealT] | None = None, ) -> None: r""" Define a ParticleSet from point positions and data fields. @@ -377,7 +387,7 @@ def __init__( if fields is None: fields = {} - self.fields: FieldMap = fields + self.fields = fields self.axes = tuple(self.coordinates.keys()) super().__init__() @@ -442,7 +452,7 @@ def __init__( self.geometry = geometry if particles is None: - dt = grid._get_safe_datatype() + dt = grid._get_safe_datatype(reference=None) particles = ParticleSet( geometry=grid.geometry, coordinates={ax: np.array([], dtype=dt) for ax in grid.axes},