diff --git a/src/gpgi/_data_types.py b/src/gpgi/_data_types.py index 67843d7..62af0ef 100644 --- a/src/gpgi/_data_types.py +++ b/src/gpgi/_data_types.py @@ -42,8 +42,11 @@ if TYPE_CHECKING: from typing import Any, Self + from numpy.typing import NDArray + from gpgi._typing import FieldMap, HCIArray, Name, RealArray + BoundarySpec = tuple[tuple[str, str, str], ...] @@ -227,15 +230,15 @@ def _validate_coordinates(self) -> None: self.coordinates[axis] = coord.astype(coord_dtype, copy=False) - def _get_safe_datatype(self, reference: np.ndarray | None = None) -> np.dtype: + def _get_safe_datatype( + self, reference: NDArray[np.floating] | None = None + ) -> np.dtype[np.floating]: if reference is None: reference = self.coordinates[self.axes[0]] dt = reference.dtype if dt.kind != "f": raise ValueError(f"Invalid data type {dt} (expected a float dtype)") - # return type should already be correct but - # this has the benefit of convincing mypy - return np.dtype(dt) + return dt class Grid(_CoordinateValidatorMixin):