Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: fix invalid np.ndarray type annotations #290

Closed
wants to merge 9 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixup ?
neutrinoceros committed Jan 25, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 8564d0361506e16b8327e2ed1a921b167816721b
27 changes: 16 additions & 11 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
@@ -45,9 +45,13 @@
from _thread import LockType

if TYPE_CHECKING:
from typing import Any, Self
from typing import Any, Self, TypeVar

from gpgi._typing import FieldMap, HCIArray, Name, RealArray
from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name, RealArray, RealT

_FloatingT = TypeVar("_FloatingT", bound=np.floating)


BoundarySpec = tuple[tuple[str, str, str], ...]
@@ -116,8 +120,8 @@ def __init__(
self,
*,
geometry: Geometry,
cell_edges: FieldMap,
fields: FieldMap | None = None,
cell_edges: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
) -> None:
r"""
Define a Grid from cell left-edges and data fields.
@@ -144,6 +148,7 @@ def __init__(
self.dtype = self.coordinates[self.axes[0]].dtype

self._dx = np.full((3,), -1, dtype=self.coordinates[self.axes[0]].dtype)

for i, ax in enumerate(self.axes):
if self.size == 1 or np.diff(self.coordinates[ax]).std() < 1e-16:
# got a constant step in this direction, store it
@@ -172,17 +177,17 @@ def __repr__(self) -> str:
)

@property
def cell_edges(self) -> FieldMap:
def cell_edges(self) -> FieldMap[RealT]:
r"""An alias for self.coordinates."""
return self.coordinates

@cached_property
def cell_centers(self) -> FieldMap:
def cell_centers(self) -> FieldMap[RealT]:
r"""The positions of cell centers in each direction."""
return {ax: 0.5 * (arr[1:] + arr[:-1]) for ax, arr in self.coordinates.items()}

@cached_property
def cell_widths(self) -> FieldMap:
def cell_widths(self) -> FieldMap[RealT]:
r"""The width of cells, expressed as the difference between consecutive left edges."""
return {ax: np.diff(arr) for ax, arr in self.coordinates.items()}

@@ -206,7 +211,7 @@ def ndim(self) -> int:
return len(self.axes)

@property
def cell_volumes(self) -> RealArray:
def cell_volumes(self) -> RealArray[RealT]:
r"""
The generalized ND-volume of grid cells.

@@ -241,8 +246,8 @@ def __init__(
self,
*,
geometry: Geometry,
coordinates: FieldMap,
fields: FieldMap | None = None,
coordinates: FieldMap[RealT],
fields: FieldMap[RealT] | None = None,
) -> None:
r"""
Define a ParticleSet from point positions and data fields.
@@ -261,7 +266,7 @@ def __init__(

if fields is None:
fields = {}
self.fields: FieldMap = fields
self.fields = fields

self.axes = tuple(self.coordinates.keys())
self._validate()
8 changes: 4 additions & 4 deletions src/gpgi/_spatial_data.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
import numpy as np
from numpy.typing import NDArray

from gpgi._typing import FieldMap, Name, Real
from gpgi._typing import FieldMap, Name, RealT


class Geometry(StrEnum):
@@ -154,7 +154,7 @@ def check(
@staticmethod
def _validate_shape_equality(
name: str,
data: NDArray[Real],
data: NDArray[RealT],
ref_arr: NamedArray | None,
) -> NamedArray:
if ref_arr is not None and data.shape != ref_arr.data.shape:
@@ -165,7 +165,7 @@ def _validate_shape_equality(
return ref_arr or NamedArray(name, data)

@staticmethod
def _validate_sorted_state(name: str, data: NDArray[Real]) -> None:
def _validate_sorted_state(name: str, data: NDArray[RealT]) -> None:
a = data[0]
for i, b in enumerate(data[1:], start=1):
if a > b:
@@ -178,7 +178,7 @@ def _validate_sorted_state(name: str, data: NDArray[Real]) -> None:
@staticmethod
def _validate_required_attributes(
name: str,
data: NDArray[Real],
data: NDArray[RealT],
required_attrs: dict[str, Any],
) -> None:
for attr, expected in required_attrs.items():