Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: fix invalid np.ndarray type annotations #290

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 52 additions & 53 deletions src/gpgi/_boundaries.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from __future__ import annotations

from collections.abc import Callable
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import Any, Literal, cast

from numpy.typing import NDArray

if TYPE_CHECKING:
from gpgi._typing import RealArray
from gpgi._typing import FloatT

BoundaryRecipeT = Callable[
[
"RealArray",
"RealArray",
"RealArray",
"RealArray",
"RealArray",
"RealArray",
"RealArray",
"RealArray",
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
NDArray[FloatT],
Literal["left", "right"],
dict[str, Any],
],
"RealArray",
NDArray[FloatT],
]


Expand Down Expand Up @@ -126,64 +125,64 @@ def __contains__(self, key: str) -> bool:

# basic recipes
def open_boundary(
same_side_active_layer: RealArray,
same_side_ghost_layer: RealArray,
opposite_side_active_layer: RealArray,
opposite_side_ghost_layer: RealArray,
weight_same_side_active_layer: RealArray,
weight_same_side_ghost_layer: RealArray,
weight_opposite_side_active_layer: RealArray,
weight_opposite_side_ghost_layer: RealArray,
same_side_active_layer: NDArray[FloatT],
same_side_ghost_layer: NDArray[FloatT],
opposite_side_active_layer: NDArray[FloatT],
opposite_side_ghost_layer: NDArray[FloatT],
weight_same_side_active_layer: NDArray[FloatT],
weight_same_side_ghost_layer: NDArray[FloatT],
weight_opposite_side_active_layer: NDArray[FloatT],
weight_opposite_side_ghost_layer: NDArray[FloatT],
side: Literal["left", "right"],
metadata: dict[str, Any],
) -> RealArray:
) -> NDArray[FloatT]:
# return the active layer unchanged
return same_side_active_layer


def wall_boundary(
same_side_active_layer: RealArray,
same_side_ghost_layer: RealArray,
opposite_side_active_layer: RealArray,
opposite_side_ghost_layer: RealArray,
weight_same_side_active_layer: RealArray,
weight_same_side_ghost_layer: RealArray,
weight_opposite_side_active_layer: RealArray,
weight_opposite_side_ghost_layer: RealArray,
same_side_active_layer: NDArray[FloatT],
same_side_ghost_layer: NDArray[FloatT],
opposite_side_active_layer: NDArray[FloatT],
opposite_side_ghost_layer: NDArray[FloatT],
weight_same_side_active_layer: NDArray[FloatT],
weight_same_side_ghost_layer: NDArray[FloatT],
weight_opposite_side_active_layer: NDArray[FloatT],
weight_opposite_side_ghost_layer: NDArray[FloatT],
side: Literal["left", "right"],
metadata: dict[str, Any],
) -> RealArray:
return cast("RealArray", same_side_active_layer + same_side_ghost_layer)
) -> NDArray[FloatT]:
return cast(NDArray[FloatT], same_side_active_layer + same_side_ghost_layer)


def antisymmetric_boundary(
same_side_active_layer: RealArray,
same_side_ghost_layer: RealArray,
opposite_side_active_layer: RealArray,
opposite_side_ghost_layer: RealArray,
weight_same_side_active_layer: RealArray,
weight_same_side_ghost_layer: RealArray,
weight_opposite_side_active_layer: RealArray,
weight_opposite_side_ghost_layer: RealArray,
same_side_active_layer: NDArray[FloatT],
same_side_ghost_layer: NDArray[FloatT],
opposite_side_active_layer: NDArray[FloatT],
opposite_side_ghost_layer: NDArray[FloatT],
weight_same_side_active_layer: NDArray[FloatT],
weight_same_side_ghost_layer: NDArray[FloatT],
weight_opposite_side_active_layer: NDArray[FloatT],
weight_opposite_side_ghost_layer: NDArray[FloatT],
side: Literal["left", "right"],
metadata: dict[str, Any],
) -> RealArray:
return cast("RealArray", same_side_active_layer - same_side_ghost_layer)
) -> NDArray[FloatT]:
return cast(NDArray[FloatT], same_side_active_layer - same_side_ghost_layer)


def periodic_boundary(
same_side_active_layer: RealArray,
same_side_ghost_layer: RealArray,
opposite_side_active_layer: RealArray,
opposite_side_ghost_layer: RealArray,
weight_same_side_active_layer: RealArray,
weight_same_side_ghost_layer: RealArray,
weight_opposite_side_active_layer: RealArray,
weight_opposite_side_ghost_layer: RealArray,
same_side_active_layer: NDArray[FloatT],
same_side_ghost_layer: NDArray[FloatT],
opposite_side_active_layer: NDArray[FloatT],
opposite_side_ghost_layer: NDArray[FloatT],
weight_same_side_active_layer: NDArray[FloatT],
weight_same_side_ghost_layer: NDArray[FloatT],
weight_opposite_side_active_layer: NDArray[FloatT],
weight_opposite_side_ghost_layer: NDArray[FloatT],
side: Literal["left", "right"],
metadata: dict[str, Any],
) -> RealArray:
return cast("RealArray", same_side_active_layer + opposite_side_ghost_layer)
) -> NDArray[FloatT]:
return cast(NDArray[FloatT], same_side_active_layer + opposite_side_ghost_layer)


_base_registry: dict[str, BoundaryRecipeT] = {
Expand Down
75 changes: 44 additions & 31 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from textwrap import indent
from threading import Lock
from time import monotonic_ns
from typing import TYPE_CHECKING, Literal, cast, final
from typing import TYPE_CHECKING, Generic, Literal, cast, final

import numpy as np

Expand All @@ -36,7 +36,7 @@
GeometryValidator,
Validator,
)
from gpgi._typing import FieldMap, Name
from gpgi._typing import FieldMap, FloatT, Name
from gpgi.typing import DepositionMethodT, DepositionMethodWithMetadataT

if sys.version_info >= (3, 13):
Expand All @@ -45,9 +45,13 @@
from _thread import LockType

if TYPE_CHECKING:
from typing import Any, Self
from typing import Any, Self, TypeVar

from gpgi._typing import FieldMap, HCIArray, Name, RealArray
from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name

_FloatingT = TypeVar("_FloatingT", bound=np.floating)


BoundarySpec = tuple[tuple[str, str, str], ...]
Expand Down Expand Up @@ -111,13 +115,13 @@ def check(cls, data: Grid) -> None:


@final
class Grid:
class Grid(Generic[FloatT]):
def __init__(
self,
*,
geometry: Geometry,
cell_edges: FieldMap,
fields: FieldMap | None = None,
cell_edges: FieldMap[FloatT],
fields: FieldMap[FloatT] | None = None,
) -> None:
r"""
Define a Grid from cell left-edges and data fields.
Expand All @@ -133,17 +137,19 @@ def __init__(
fields (keyword-only, optional): gpgi.typing.FieldMap
"""
self.geometry = geometry
self.coordinates = cell_edges
self.coordinates: FieldMap[FloatT] = cell_edges

if fields is None:
fields = {}
self.fields: FieldMap = fields

self.axes = tuple(self.coordinates.keys())
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype
self.dtype: np.dtype[FloatT] = self.coordinates[self.axes[0]].dtype

self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype)
self._dx: NDArray[FloatT] = np.full(
(3,), -1, dtype=self.coordinates[self.axes[0]].dtype
)
for i, ax in enumerate(self.axes):
if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16:
# got a constant step in this direction, store it
Expand Down Expand Up @@ -172,17 +178,17 @@ def __repr__(self) -> str:
)

@property
def cell_edges(self) -> FieldMap:
def cell_edges(self) -> FieldMap[FloatT]:
r"""An alias for self.coordinates."""
return self.coordinates

@cached_property
def cell_centers(self) -> FieldMap:
def cell_centers(self) -> FieldMap[FloatT]:
r"""The positions of cell centers in each direction."""
return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()}
return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()} # type: ignore [misc]

@cached_property
def cell_widths(self) -> FieldMap:
def cell_widths(self) -> FieldMap[FloatT]:
r"""The width of cells, expressed as the difference between consecutive left edges."""
return {ax: np.diff(arr) for ax, arr in self.coordinates.items()}

Expand All @@ -206,7 +212,7 @@ def ndim(self) -> int:
return len(self.axes)

@property
def cell_volumes(self) -> RealArray:
def cell_volumes(self) -> NDArray[FloatT]:
r"""
The generalized ND-volume of grid cells.

Expand All @@ -217,7 +223,7 @@ def cell_volumes(self) -> RealArray:
widths = list(self.cell_widths.values())
if self.geometry is Geometry.CARTESIAN:
raw = np.prod(np.meshgrid(*widths), axis=0)
return cast("RealArray", np.swapaxes(raw, 0, 1))
return np.swapaxes(raw, 0, 1)
else:
raise NotImplementedError(
f"cell_volumes property is not implemented for {self.geometry} geometry"
Expand All @@ -236,13 +242,13 @@ def check(cls, data: ParticleSet) -> None:


@final
class ParticleSet:
class ParticleSet(Generic[FloatT]):
def __init__(
self,
*,
geometry: Geometry,
coordinates: FieldMap,
fields: FieldMap | None = None,
coordinates: FieldMap[FloatT],
fields: FieldMap[FloatT] | None = None,
) -> None:
r"""
Define a ParticleSet from point positions and data fields.
Expand All @@ -257,15 +263,15 @@ def __init__(
fields (keyword-only, optional): gpgi.typing.FieldMap
"""
self.geometry = geometry
self.coordinates = coordinates
self.coordinates: FieldMap[FloatT] = coordinates

if fields is None:
fields = {}
self.fields: FieldMap = fields
self.fields: FieldMap[FloatT] = fields

self.axes = tuple(self.coordinates.keys())
self._validate()
self.dtype = self.coordinates[self.axes[0]].dtype
self.dtype: np.dtype[FloatT] = self.coordinates[self.axes[0]].dtype

_validators: list[type[Validator[ParticleSet]]] = [
GeometryValidator,
Expand Down Expand Up @@ -380,26 +386,35 @@ def _validate(self) -> None:
f"- from particles: {self.particles.dtype}\n"
)

def _get_padded_cell_edges(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _get_padded_cell_edges(
self,
) -> tuple[NDArray[FloatT], NDArray[FloatT], NDArray[FloatT]]:
edges = iter(self.grid.cell_edges.values())

def pad(a: np.ndarray) -> np.ndarray:
def pad(a: NDArray[FloatT]) -> NDArray[FloatT]:
dx = a[1] - a[0]
return np.concatenate([[a[0] - dx], a, [a[-1] + dx]])

x1 = next(edges)
cell_edges_x1 = pad(x1)
DTYPE = cell_edges_x1.dtype

cell_edges_x2 = cell_edges_x3 = np.empty(0, DTYPE)
cell_edges_x2: np.ndarray[tuple[int, ...], np.dtype[FloatT]] = np.empty(
0, DTYPE
)
cell_edges_x3: np.ndarray[tuple[int, ...], np.dtype[FloatT]] = np.empty(
0, DTYPE
)
if self.grid.ndim >= 2:
cell_edges_x2 = pad(next(edges))
if self.grid.ndim == 3:
cell_edges_x3 = pad(next(edges))

return cell_edges_x1, cell_edges_x2, cell_edges_x3

def _get_3D_particle_coordinates(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _get_3D_particle_coordinates(
self,
) -> tuple[NDArray[FloatT], NDArray[FloatT], NDArray[FloatT]]:
particle_coords = iter(self.particles.coordinates.values())
particles_x1 = next(particle_coords)
DTYPE = particles_x1.dtype
Expand Down Expand Up @@ -477,9 +492,7 @@ def _validate_sort_axes(self, axes: tuple[int, ...]) -> None:
if any(axis > self.grid.ndim - 1 for axis in axes):
raise ValueError(f"Expected all axes to be <{self.grid.ndim}, got {axes!r}")

def _get_sort_key(
self, axes: tuple[int, ...]
) -> np.ndarray[Any, np.dtype[np.uint16]]:
def _get_sort_key(self, axes: tuple[int, ...]) -> NDArray[np.uint16]:
self._validate_sort_axes(axes)

hci = self.host_cell_index
Expand Down Expand Up @@ -780,9 +793,9 @@ def _sanitize_boundaries(self, boundaries: dict[Name, tuple[Name, Name]]) -> Non

def _apply_boundary_conditions(
self,
array: RealArray,
array: NDArray[FloatT],
boundaries: dict[Name, tuple[Name, Name]],
weight_array: RealArray | None,
weight_array: NDArray[FloatT] | None,
) -> None:
axes = list(self.grid.axes)
for ax, bv in boundaries.items():
Expand Down
8 changes: 4 additions & 4 deletions src/gpgi/_spatial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numpy.typing import NDArray

from gpgi._typing import FieldMap, Name, Real
from gpgi._typing import FieldMap, FloatT, Name


class Geometry(StrEnum):
Expand Down Expand Up @@ -154,7 +154,7 @@ def check(
@staticmethod
def _validate_shape_equality(
name: str,
data: NDArray[Real],
data: NDArray[FloatT],
ref_arr: NamedArray | None,
) -> NamedArray:
if ref_arr is not None and data.shape != ref_arr.data.shape:
Expand All @@ -165,7 +165,7 @@ def _validate_shape_equality(
return ref_arr or NamedArray(name, data)

@staticmethod
def _validate_sorted_state(name: str, data: NDArray[Real]) -> None:
def _validate_sorted_state(name: str, data: NDArray[FloatT]) -> None:
a = data[0]
for i, b in enumerate(data[1:], start=1):
if a > b:
Expand All @@ -178,7 +178,7 @@ def _validate_sorted_state(name: str, data: NDArray[Real]) -> None:
@staticmethod
def _validate_required_attributes(
name: str,
data: NDArray[Real],
data: NDArray[FloatT],
required_attrs: dict[str, Any],
) -> None:
for attr, expected in required_attrs.items():
Expand Down
Loading
Loading