Skip to content

Commit

Permalink
Merge pull request #1358 from metno/fix-pydantic-model-validator
Browse files: browse the repository at this point in the history
Fix model validator for pyaro config and ColocatedData
  • Loading branch information
lewisblake authored Oct 1, 2024
2 parents c4ddcfc + af3a797 commit dcdcc3f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 32 deletions.
34 changes: 18 additions & 16 deletions pyaerocom/colocation/colocated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,18 @@
logger = logging.getLogger(__name__)


def ensure_correct_dimensions(data: np.ndarray | xr.DataArray):
def ensure_correct_dimensions(data: xr.DataArray):
"""
Ensure the dimensions on either a numpy aray or xarray passed to ColocatedData.
Ensure the dimensions on an xarray.DataArray passed to ColocatedData.
If a ColocatedData object is created outside of pyaerocom, this checking is needed.
This function is used as part of the model validator.
"""
shape = data.shape[0]
if isinstance(data, np.ndarray):
num_dims = data.ndim
elif isinstance(data, xr.DataArray):
num_dims = len(data.dims)
else:
if not isinstance(data, xr.DataArray):
raise ValueError("Could not interpret data")

shape = data.shape[0]
num_dims = len(data.dims)

if num_dims not in (2, 3, 4):
raise DataDimensionError("invalid input, need 2D, 3D or 4D numpy array")
elif not shape == 2:
Expand Down Expand Up @@ -124,19 +123,19 @@ class ColocatedData(BaseModel):

@model_validator(mode="after")
def validate_data(self):
if self.data is None:
return self
if isinstance(self.data, Path):
# make sure path is str instance
self.data = str(self.data)
if isinstance(self.data, str):
assert self.data.endswith("nc"), ValueError(
"Invalid data filepath str, must point to a .nc file"
)
if not self.data.endswith("nc"):
raise ValueError(
f"Invalid data filepath str, must point to a .nc file. Got {self.data}"
)
self.open(self.data)
elif isinstance(self.data, xr.DataArray):
ensure_correct_dimensions(self.data)
return self.data
elif isinstance(self.data, np.ndarray):
ensure_correct_dimensions(self.data)
return self
if isinstance(self.data, np.ndarray):
if hasattr(self, "model_extra"):
da_keys = dir(xr.DataArray)
extra_args_from_class_initialization = {
Expand All @@ -146,6 +145,9 @@ def validate_data(self):
extra_args_from_class_initialization = {}
data = xr.DataArray(self.data, **extra_args_from_class_initialization)
self.data = data
# self.data should be xr.DataArray at this stage
ensure_correct_dimensions(self.data)
return self

# Override __init__ to allow for positional arguments
def __init__(
Expand Down
30 changes: 16 additions & 14 deletions pyaerocom/colocation/colocation_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ def validate_no_forbidden_keys(self):
for key in self.FORBIDDEN_KEYS:
if key in self.model_fields:
raise ValidationError
return self

@cached_property
def basedir_logfiles(self):
Expand All @@ -488,25 +489,26 @@ def basedir_logfiles(self):
return str(p)

@model_validator(mode="after")
@classmethod
def validate_obs_config(cls, v: PyaroConfig):
if v is not None and cls.obs.config.name != cls.obs_id:
def validate_obs_config(self):
if self.obs_config is None:
return self
if self.obs_config.name != self.obs_id:
logger.info(
f"Data ID in Pyaro config {v.name} does not match obs_id {cls.obs_id}. Setting Pyaro config to None!"
f"Data ID in Pyaro config {self.obs_config.name} does not match obs_id {self.obs_id}. Setting Pyaro config to None!"
)
v = None
if v is not None:
if isinstance(v, dict):
self.obs_config = None
if self.obs_config is not None:
if isinstance(self.obs_config, dict):
logger.info("Obs config was given as dict. Will try to convert to PyaroConfig")
v = PyaroConfig(**v)
if v.name != cls.obs_id:
self.obs_config = PyaroConfig(**self.obs_config)
if self.obs_config.name != self.obs_id:
logger.info(
f"Data ID in Pyaro config {v.name} does not match obs_id {cls.obs_id}. Setting Obs ID to match Pyaro Config!"
f"Data ID in Pyaro config {self.obs_config.name} does not match obs_id {self.obs_id}. Setting Obs ID to match Pyaro Config!"
)
cls.obs_id = v.name
if cls.obs_id is None:
cls.obs_id = v.name
return v
self.obs_id = self.obs_config.name
if self.obs_id is None:
self.obs_id = self.obs_config.name
return self

def add_glob_meta(self, **kwargs):
"""
Expand Down
12 changes: 10 additions & 2 deletions tests/io/test_readungridded.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,17 @@ def test_ReadUngridded___init__(data_ids, ignore_cache):
(dict(station_name="La_Paz"), 1, 1),
(dict(station_name=["La_Paz", "AAO*"]), 2, 2),
(dict(altitude=[1000, 10000]), 3, 3),
(dict(altitude=[1000, 10000], ignore_station_names=dict(od550aer="La_Paz")), 2, 2),
(
dict(altitude=[1000, 10000], ignore_station_names=dict(od550aer="La_Paz")),
2,
2,
),
(dict(altitude=[1000, 10000], ignore_station_names="La_*"), 2, 2),
(dict(altitude=[1000, 10000], ignore_station_names=["La_*", "Mauna_Loa"]), 1, 1),
(
dict(altitude=[1000, 10000], ignore_station_names=["La_*", "Mauna_Loa"]),
1,
1,
),
],
)
@pytest.mark.parametrize(
Expand Down

0 comments on commit dcdcc3f

Please sign in to comment.