Skip to content

Commit

Permalink
TYP: fix invalid np.ndarray type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Dec 10, 2024
1 parent b39374c commit efee673
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
10 changes: 4 additions & 6 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ def _validate(self) -> None:
f"- from particles: {self.particles.dtype}\n"
)

def _get_padded_cell_edges(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _get_padded_cell_edges(self) -> tuple[RealArray, RealArray, RealArray]:
edges = iter(self.grid.cell_edges.values())

def pad(a: np.ndarray) -> np.ndarray:
def pad(a: RealArray) -> RealArray:
dx = a[1] - a[0]
return np.concatenate([[a[0] - dx], a, [a[-1] + dx]])

Expand All @@ -510,7 +510,7 @@ def pad(a: np.ndarray) -> np.ndarray:

return cell_edges_x1, cell_edges_x2, cell_edges_x3

def _get_3D_particle_coordinates(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _get_3D_particle_coordinates(self) -> tuple[RealArray, RealArray, RealArray]:
particle_coords = iter(self.particles.coordinates.values())
particles_x1 = next(particle_coords)
DTYPE = particles_x1.dtype
Expand Down Expand Up @@ -588,9 +588,7 @@ def _validate_sort_axes(self, axes: tuple[int, ...]) -> None:
if any(axis > self.grid.ndim - 1 for axis in axes):
raise ValueError(f"Expected all axes to be <{self.grid.ndim}, got {axes!r}")

def _get_sort_key(
self, axes: tuple[int, ...]
) -> np.ndarray[Any, np.dtype[np.uint16]]:
def _get_sort_key(self, axes: tuple[int, ...]) -> NDArray[np.uint16]:
self._validate_sort_axes(axes)

hci = self.host_cell_index
Expand Down
74 changes: 37 additions & 37 deletions src/gpgi/_typing.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,61 @@
from typing import NotRequired, TypedDict, TypeVar
from typing import Generic, NotRequired, TypedDict, TypeVar

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

Real = TypeVar("Real", np.float32, np.float64)
RealArray = npt.NDArray[Real]
HCIArray = npt.NDArray[np.uint16]
RealT = TypeVar("RealT", np.float32, np.float64)
RealArray = NDArray[RealT]
HCIArray = NDArray[np.uint16]


Name = str
FieldMap = dict[str, np.ndarray]
FieldMap = dict[str, NDArray[RealT]]


class CartesianCoordinates(TypedDict):
x: np.ndarray
y: NotRequired[np.ndarray]
z: NotRequired[np.ndarray]
class CartesianCoordinates(TypedDict, Generic[RealT]):
x: NDArray[RealT]
y: NotRequired[NDArray[RealT]]
z: NotRequired[NDArray[RealT]]


class CylindricalCoordinates(TypedDict):
radius: np.ndarray
azimuth: NotRequired[np.ndarray]
z: NotRequired[np.ndarray]
class CylindricalCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
azimuth: NotRequired[NDArray[RealT]]
z: NotRequired[NDArray[RealT]]


class PolarCoordinates(TypedDict):
radius: np.ndarray
z: NotRequired[np.ndarray]
azimuth: NotRequired[np.ndarray]
class PolarCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
z: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]


class SphericalCoordinates(TypedDict):
colatitude: np.ndarray
radius: NotRequired[np.ndarray]
azimuth: NotRequired[np.ndarray]
class SphericalCoordinates(TypedDict, Generic[RealT]):
colatitude: NDArray[RealT]
radius: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]


class EquatorialCoordinates(TypedDict):
radius: np.ndarray
latitude: NotRequired[np.ndarray]
azimuth: NotRequired[np.ndarray]
class EquatorialCoordinates(TypedDict, Generic[RealT]):
radius: NDArray[RealT]
latitude: NotRequired[NDArray[RealT]]
azimuth: NotRequired[NDArray[RealT]]


CoordMap = (
CartesianCoordinates
| CylindricalCoordinates
| PolarCoordinates
| SphericalCoordinates
| EquatorialCoordinates
CartesianCoordinates[RealT]
| CylindricalCoordinates[RealT]
| PolarCoordinates[RealT]
| SphericalCoordinates[RealT]
| EquatorialCoordinates[RealT]
)


class GridDict(TypedDict):
cell_edges: CoordMap
fields: NotRequired[FieldMap]
class GridDict(TypedDict, Generic[RealT]):
cell_edges: CoordMap[RealT]
fields: NotRequired[FieldMap[RealT]]


class ParticleSetDict(TypedDict):
coordinates: CoordMap
fields: NotRequired[FieldMap]
class ParticleSetDict(TypedDict, Generic[RealT]):
coordinates: CoordMap[RealT]
fields: NotRequired[FieldMap[RealT]]

0 comments on commit efee673

Please sign in to comment.