def preprocess_datasets(dset):
    """Standardize a single dataset before ``xr.open_mfdataset`` combines it
    with the others.

    The following steps are applied, in order:

    1. If ``time`` is present and has more than one entry it is cast to
       ``int`` and written back onto ``dset`` (this assignment modifies the
       dataset that was passed in).
    2. ``latitude`` / ``longitude`` dimensions are swapped for
       ``south_north`` / ``west_east``.
    3. The ``expver`` and ``number`` variables are dropped when present.

    Parameters
    ----------
    dset : xr.Dataset
        Dataset handed to this function through the ``preprocess`` argument
        of ``xr.open_mfdataset``.

    Returns
    -------
    xr.Dataset
        The standardized dataset.
    """
    has_multiple_steps = 'time' in dset and dset.time.size > 1
    if has_multiple_steps:
        dset['time'] = dset.time.astype(int)

    dim_renames = (('latitude', 'south_north'), ('longitude', 'west_east'))
    for current_dim, standard_dim in dim_renames:
        if current_dim in dset.dims:
            dset = dset.swap_dims({current_dim: standard_dim})

    # temporary to handle downloaded era files
    for extra_var in ('expver', 'number'):
        if extra_var in dset:
            dset = dset.drop_vars(extra_var)

    return dset