Skip to content

Commit

Permalink
fixup ?
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Dec 10, 2024
1 parent efee673 commit 55602db
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from textwrap import indent
from threading import Lock
from time import monotonic_ns
from typing import TYPE_CHECKING, Literal, assert_never, cast
from typing import TYPE_CHECKING, Generic, Literal, assert_never, cast, overload

import numpy as np

Expand All @@ -31,7 +31,7 @@
_deposit_tsc_3D,
_index_particles,
)
from gpgi._typing import FieldMap, Name
from gpgi._typing import FieldMap, Name, RealT
from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT

if sys.version_info >= (3, 13):
Expand All @@ -40,12 +40,14 @@
from _thread import LockType

if TYPE_CHECKING:
from typing import Any, Self
from typing import Any, Self, TypeVar

from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name, RealArray

_FloatingT = TypeVar("_FloatingT", bound=np.floating)


BoundarySpec = tuple[tuple[str, str, str], ...]

Expand Down Expand Up @@ -94,11 +96,11 @@ class GeometricData(ABC):
axes: tuple[Name, ...]


class CoordinateData(ABC):
class CoordinateData(ABC, Generic[RealT]):
geometry: Geometry
axes: tuple[Name, ...]
coordinates: FieldMap
fields: FieldMap
coordinates: FieldMap[RealT]
fields: FieldMap[RealT]


class ValidatorMixin(GeometricData, ABC):
Expand Down Expand Up @@ -206,7 +208,7 @@ def _validate_coordinates(self) -> None:
coord = self.coordinates[axis]
if len(coord) == 0:
continue
coord_dtype = self._get_safe_datatype(coord)
coord_dtype = self._get_safe_datatype(reference=coord)
dt = coord_dtype.type
xmin, xmax = (dt(_) for _ in _AXES_LIMITS[axis])
if (cmin := dt(np.min(coord))) < xmin or not math.isfinite(cmin):
Expand All @@ -230,9 +232,15 @@ def _validate_coordinates(self) -> None:

self.coordinates[axis] = coord.astype(coord_dtype, copy=False)

@overload
def _get_safe_datatype(self, *, reference: None) -> np.dtype[np.floating]: ...

@overload
def _get_safe_datatype(
self, reference: NDArray[np.floating] | None = None
) -> np.dtype[np.floating]:
self, *, reference: NDArray[_FloatingT]
) -> np.dtype[_FloatingT]: ...

def _get_safe_datatype(self, reference): # type: ignore[no-untyped-def]
if reference is None:
reference = self.coordinates[self.axes[0]]
dt = reference.dtype
Expand All @@ -241,13 +249,13 @@ def _get_safe_datatype(
return dt


class Grid(_CoordinateValidatorMixin):
class Grid(_CoordinateValidatorMixin, Generic[RealT]):
def __init__(
self,
*,
geometry: Geometry,
cell_edges: FieldMap,
fields: FieldMap | None = None,
cell_edges: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
) -> None:
r"""
Define a Grid from cell left-edges and data fields.
Expand All @@ -272,7 +280,9 @@ def __init__(
self.axes = tuple(self.coordinates.keys())
super().__init__()

self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype)
dt = self._get_safe_datatype(reference=None)
self._dx = np.full((3,), -1, dtype=dt)

for i, ax in enumerate(self.axes):
if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16:
# got a constant step in this direction, store it
Expand Down Expand Up @@ -300,17 +310,17 @@ def _validate(self) -> None:
)

@property
def cell_edges(self) -> FieldMap:
def cell_edges(self) -> FieldMap[RealT]:
r"""An alias for self.coordinates."""
return self.coordinates

@cached_property
def cell_centers(self) -> FieldMap:
def cell_centers(self) -> FieldMap[RealT]:
r"""The positions of cell centers in each direction."""
return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()}

@cached_property
def cell_widths(self) -> FieldMap:
def cell_widths(self) -> FieldMap[RealT]:
r"""The width of cells, expressed as the difference between consecutive left edges."""
return {ax: np.diff(arr) for ax, arr in self.coordinates.items()}

Expand All @@ -334,7 +344,7 @@ def ndim(self) -> int:
return len(self.axes)

@property
def cell_volumes(self) -> RealArray:
def cell_volumes(self) -> RealArray[RealT]:
r"""
The generalized ND-volume of grid cells.
Expand All @@ -357,8 +367,8 @@ def __init__(
self,
*,
geometry: Geometry,
coordinates: FieldMap,
fields: FieldMap | None = None,
coordinates: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
) -> None:
r"""
Define a ParticleSet from point positions and data fields.
Expand All @@ -377,7 +387,7 @@ def __init__(

if fields is None:
fields = {}
self.fields: FieldMap = fields
self.fields = fields

self.axes = tuple(self.coordinates.keys())
super().__init__()
Expand Down Expand Up @@ -442,7 +452,7 @@ def __init__(
self.geometry = geometry

if particles is None:
dt = grid._get_safe_datatype()
dt = grid._get_safe_datatype(reference=None)
particles = ParticleSet(
geometry=grid.geometry,
coordinates={ax: np.array([], dtype=dt) for ax in grid.axes},
Expand Down

0 comments on commit 55602db

Please sign in to comment.