Skip to content

Commit

Permalink
RFC: spread field maps validation logic into smaller functions
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Dec 10, 2024
1 parent b39374c commit 7ef95cf
Showing 1 changed file with 54 additions and 35 deletions.
89 changes: 54 additions & 35 deletions src/gpgi/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property, partial, reduce
from itertools import chain
from textwrap import indent
Expand Down Expand Up @@ -44,7 +45,7 @@

from numpy.typing import NDArray

from gpgi._typing import FieldMap, HCIArray, Name, RealArray
from gpgi._typing import FieldMap, HCIArray, Name, Real, RealArray


BoundarySpec = tuple[tuple[str, str, str], ...]
Expand Down Expand Up @@ -101,6 +102,12 @@ class CoordinateData(ABC):
fields: FieldMap


@dataclass
class ReferenceArray:
name: Name
data: NDArray


class ValidatorMixin(GeometricData, ABC):
def __init__(self) -> None:
self._validate()
Expand All @@ -115,40 +122,52 @@ def _validate_FieldMaps(
require_sorted: bool = False,
**required_attrs: Any,
) -> None:
_reference_shape: tuple[int, ...] | None = None
_reference_field_name: str
for fmap in fmaps:
if fmap is None:
continue # pragma: no cover
for name, data in fmap.items():
if require_shape_equality:
if _reference_shape is None:
_reference_shape = data.shape
_reference_field_name = name
elif data.shape != _reference_shape:
raise ValueError(
f"Fields {name!r} and {_reference_field_name!r} "
f"have mismatching shapes {data.shape} and {_reference_shape}"
)

if require_sorted:
_a = data[0]
for i, _b in enumerate(data[1:], start=1):
if _a > _b:
raise ValueError(
f"Field {name!r} is not properly sorted by ascending order. "
f"Got {_a} (index {i-1}) > {_b} (index {i})"
)
_a = _b

if not required_attrs:
continue # pragma: no cover
for attr, expected in required_attrs.items():
if (actual := getattr(data, attr)) != expected:
raise ValueError(
f"Field {name!r} has incorrect {attr} {actual} "
f"(expected {expected})"
)
ref_arr: ReferenceArray | None = None
for name, data in chain.from_iterable(
fm.items() for fm in fmaps if fm is not None
):
if require_shape_equality:
ref_arr = self._validate_shape_equality(name, data, ref_arr)
if require_sorted:
self._validate_sorted_state(name, data)
if required_attrs:
self._validate_required_attributes(name, data, **required_attrs)

def _validate_shape_equality(
self,
name: str,
data: NDArray[Real],
ref_arr: ReferenceArray | None,
) -> ReferenceArray:
if ref_arr is not None and data.shape != ref_arr.data.shape:
raise ValueError(
f"Fields {name!r} and {ref_arr.name!r} "
f"have mismatching shapes {data.shape} and {ref_arr.data.shape}"
)
return ReferenceArray(name, data)

def _validate_sorted_state(self, name: str, data: NDArray[Real]) -> None:
a = data[0]
for i, b in enumerate(data[1:], start=1):
if a > b:
raise ValueError(
f"Field {name!r} is not properly sorted by ascending order. "
f"Got {a} (index {i-1}) > {b} (index {i})"
)
a = b

def _validate_required_attributes(
self,
name: str,
data: NDArray[Real],
**required_attrs: Any,
) -> None:
for attr, expected in required_attrs.items():
if (actual := getattr(data, attr)) != expected:
raise ValueError(
f"Field {name!r} has incorrect {attr} {actual} "
f"(expected {expected})"
)

def _validate_geometry(self) -> None:
match self.geometry:
Expand Down

0 comments on commit 7ef95cf

Please sign in to comment.