Skip to content

Commit

Permalink
fixup 2
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jan 25, 2025
1 parent 8564d03 commit aa2483d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 58 deletions.
42 changes: 22 additions & 20 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from textwrap import indent
from threading import Lock
from time import monotonic_ns
from typing import TYPE_CHECKING, Literal, cast, final
from typing import TYPE_CHECKING, Generic, Literal, cast, final

import numpy as np

Expand All @@ -36,7 +36,7 @@
GeometryValidator,
Validator,
)
from gpgi._typing import FieldMap, Name
from gpgi._typing import FieldMap, FloatT, Name
from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT

if sys.version_info >= (3, 13):
Expand All @@ -47,9 +47,10 @@
if TYPE_CHECKING:
from typing import Any, Self, TypeVar

from numpy import dtype
from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name, RealArray, RealT
from gpgi._typing import FieldMap, HCIArray, Name, RealArray

_FloatingT = TypeVar("_FloatingT", bound=np.floating)

Expand Down Expand Up @@ -115,13 +116,13 @@ def check(cls, data: Grid) -> None:


@final
class Grid:
class Grid(Generic[FloatT]):
def __init__(
self,
*,
geometry: Geometry,
cell_edges: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
cell_edges: FieldMap[FloatT],
fields: FieldMap[FloatT] | None = None,
) -> None:
r"""
Define a Grid from cell left-edges and data fields.
Expand All @@ -137,18 +138,19 @@ def __init__(
fields (keyword-only, optional): gpgi.typing.FieldMap
"""
self.geometry = geometry
self.coordinates = cell_edges
self.coordinates: FieldMap[FloatT] = cell_edges

if fields is None:
fields = {}
self.fields: FieldMap = fields

self.axes = tuple(self.coordinates.keys())
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype

self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype)
self.dtype: dtype[FloatT] = self.coordinates[self.axes[0]].dtype

self._dx: NDArray[FloatT] = np.full(
(3,), -1, dtype=self.coordinates[self.axes[0]].dtype
)
for i, ax in enumerate(self.axes):
if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16:
# got a constant step in this direction, store it
Expand Down Expand Up @@ -177,17 +179,17 @@ def __repr__(self) -> str:
)

@property
def cell_edges(self) -> FieldMap[RealT]:
def cell_edges(self) -> FieldMap[FloatT]:
r"""An alias for self.coordinates."""
return self.coordinates

@cached_property
def cell_centers(self) -> FieldMap[RealT]:
def cell_centers(self) -> FieldMap[FloatT]:
r"""The positions of cell centers in each direction."""
return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()}

@cached_property
def cell_widths(self) -> FieldMap[RealT]:
def cell_widths(self) -> FieldMap[FloatT]:
r"""The width of cells, expressed as the difference between consecutive left edges."""
return {ax: np.diff(arr) for ax, arr in self.coordinates.items()}

Expand All @@ -211,7 +213,7 @@ def ndim(self) -> int:
return len(self.axes)

@property
def cell_volumes(self) -> RealArray[RealT]:
def cell_volumes(self) -> RealArray[FloatT]:
r"""
The generalized ND-volume of grid cells.
Expand Down Expand Up @@ -241,13 +243,13 @@ def check(cls, data: ParticleSet) -> None:


@final
class ParticleSet:
class ParticleSet(Generic[FloatT]):
def __init__(
self,
*,
geometry: Geometry,
coordinates: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
coordinates: FieldMap[FloatT],
fields: FieldMap[FloatT] | None = None,
) -> None:
r"""
Define a ParticleSet from point positions and data fields.
Expand All @@ -262,15 +264,15 @@ def __init__(
fields (keyword-only, optional): gpgi.typing.FieldMap
"""
self.geometry = geometry
self.coordinates = coordinates
self.coordinates: FieldMap[FloatT] = coordinates

if fields is None:
fields = {}
self.fields = fields
self.fields: FieldMap[FloatT] = fields

self.axes = tuple(self.coordinates.keys())
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype
self.dtype: dtype[FloatT] = self.coordinates[self.axes[0]].dtype

_validators: list[type[Validator[ParticleSet]]] = [
GeometryValidator,
Expand Down
8 changes: 4 additions & 4 deletions src/gpgi/_spatial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numpy.typing import NDArray

from gpgi._typing import FieldMap, Name, RealT
from gpgi._typing import FieldMap, FloatT, Name


class Geometry(StrEnum):
Expand Down Expand Up @@ -154,7 +154,7 @@ def check(
@staticmethod
def _validate_shape_equality(
name: str,
data: NDArray[RealT],
data: NDArray[FloatT],
ref_arr: NamedArray | None,
) -> NamedArray:
if ref_arr is not None and data.shape != ref_arr.data.shape:
Expand All @@ -165,7 +165,7 @@ def _validate_shape_equality(
return ref_arr or NamedArray(name, data)

@staticmethod
def _validate_sorted_state(name: str, data: NDArray[RealT]) -> None:
def _validate_sorted_state(name: str, data: NDArray[FloatT]) -> None:
a = data[0]
for i, b in enumerate(data[1:], start=1):
if a > b:
Expand All @@ -178,7 +178,7 @@ def _validate_sorted_state(name: str, data: NDArray[RealT]) -> None:
@staticmethod
def _validate_required_attributes(
name: str,
data: NDArray[RealT],
data: NDArray[FloatT],
required_attrs: dict[str, Any],
) -> None:
for attr, expected in required_attrs.items():
Expand Down
68 changes: 34 additions & 34 deletions src/gpgi/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,59 @@
import numpy as np
from numpy.typing import NDArray

RealT = TypeVar("RealT", np.float32, np.float64)
RealArray = NDArray[RealT]
FloatT = TypeVar("FloatT", np.float32, np.float64)
RealArray = NDArray[FloatT]
HCIArray = NDArray[np.uint16]


Name = str
FieldMap = dict[str, NDArray[RealT]]
FieldMap = dict[str, NDArray[FloatT]]


class CartesianCoordinates(TypedDict, Generic[RealT]):
x: NDArray[RealT]
y: NotRequired[NDArray[RealT]]
z: NotRequired[NDArray[RealT]]
class CartesianCoordinates(TypedDict, Generic[FloatT]):
x: NDArray[FloatT]
y: NotRequired[NDArray[FloatT]]
z: NotRequired[NDArray[FloatT]]


class CylindricalCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
azimuth: NotRequired[NDArray[RealT]]
z: NotRequired[NDArray[RealT]]
class CylindricalCoordinates(TypedDict, Generic[FloatT]):
radius: NDArray[FloatT]
azimuth: NotRequired[NDArray[FloatT]]
z: NotRequired[NDArray[FloatT]]


class PolarCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
z: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]
class PolarCoordinates(TypedDict, Generic[FloatT]):
radius: NDArray[FloatT]
z: NotRequired[NDArray[FloatT]]
azimuth: NotRequired[NDArray[FloatT]]


class SphericalCoordinates(TypedDict, Generic[RealT]):
colatitude: NDArray[RealT]
radius: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]
class SphericalCoordinates(TypedDict, Generic[FloatT]):
colatitude: NDArray[FloatT]
radius: NotRequired[NDArray[FloatT]]
azimuth: NotRequired[NDArray[FloatT]]


class EquatorialCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
latitude: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]
class EquatorialCoordinates(TypedDict, Generic[FloatT]):
radius: NDArray[FloatT]
latitude: NotRequired[NDArray[FloatT]]
azimuth: NotRequired[NDArray[FloatT]]


CoordMap = (
CartesianCoordinates[RealT]
| CylindricalCoordinates[RealT]
| PolarCoordinates[RealT]
| SphericalCoordinates[RealT]
| EquatorialCoordinates[RealT]
CartesianCoordinates[FloatT]
| CylindricalCoordinates[FloatT]
| PolarCoordinates[FloatT]
| SphericalCoordinates[FloatT]
| EquatorialCoordinates[FloatT]
)


class GridDict(TypedDict, Generic[RealT]):
cell_edges: CoordMap[RealT]
fields: NotRequired[FieldMap[RealT]]
class GridDict(TypedDict, Generic[FloatT]):
cell_edges: CoordMap[FloatT]
fields: NotRequired[FieldMap[FloatT]]


class ParticleSetDict(TypedDict, Generic[RealT]):
coordinates: CoordMap[RealT]
fields: NotRequired[FieldMap[RealT]]
class ParticleSetDict(TypedDict, Generic[FloatT]):
coordinates: CoordMap[FloatT]
fields: NotRequired[FieldMap[FloatT]]

0 comments on commit aa2483d

Please sign in to comment.