Using preprocessing function in call to xr.open_mfdataset
bnb32 committed Nov 9, 2024
1 parent 298e551 commit ece79e7
Showing 1 changed file with 19 additions and 44 deletions.
sup3r/utilities/utilities.py
@@ -5,7 +5,6 @@
 import random
 import string
 import time
-from warnings import warn
 
 import numpy as np
 import pandas as pd
@@ -14,44 +13,27 @@
 from packaging import version
 from scipy import ndimage as nd
 
-from sup3r.preprocessing.utilities import get_class_kwargs
-
 logger = logging.getLogger(__name__)
 
 RANDOM_GENERATOR = np.random.default_rng(seed=42)
 
 
-def merge_datasets(files, **kwargs):
-    """Merge xr.Datasets after some standardization. This is useful when
-    xr.open_mfdataset fails due to different time index formats or coordinate
-    names, for example."""
-    dsets = [xr.open_mfdataset(f, **kwargs) for f in files]
-    time_indices = []
-    for i, dset in enumerate(dsets):
-        if 'time' in dset and dset.time.size > 1:
-            ti = dset.time.astype(int)
-            dset['time'] = ti
-            dsets[i] = dset
-            time_indices.append(ti.to_series())
-        if 'latitude' in dset.dims:
-            dset = dset.swap_dims({'latitude': 'south_north'})
-            dsets[i] = dset
-        if 'longitude' in dset.dims:
-            dset = dset.swap_dims({'longitude': 'west_east'})
-            dsets[i] = dset
-        # temporary to handle downloaded era files
-        if 'expver' in dset:
-            dset.drop_vars('expver')
-        if 'number' in dset:
-            dset.drop_vars('number')
-    out = xr.merge(dsets, **get_class_kwargs(xr.merge, kwargs))
-    msg = ('Merged time index does not have the same number of time steps '
-           '(%s) as the sum of the individual time index steps (%s).')
-    if hasattr(out, 'time'):
-        merged_size = out.time.size
-        summed_size = pd.concat(time_indices).drop_duplicates().size
-        assert merged_size == summed_size, msg % (merged_size, summed_size)
-    return out
+def preprocess_datasets(dset):
+    """Standardization preprocessing applied before datasets are concatenated
+    by ``xr.open_mfdataset``"""
+    if 'time' in dset and dset.time.size > 1:
+        ti = dset.time.astype(int)
+        dset['time'] = ti
+    if 'latitude' in dset.dims:
+        dset = dset.swap_dims({'latitude': 'south_north'})
+    if 'longitude' in dset.dims:
+        dset = dset.swap_dims({'longitude': 'west_east'})
+    # temporary to handle downloaded era files
+    if 'expver' in dset:
+        dset = dset.drop_vars('expver')
+    if 'number' in dset:
+        dset = dset.drop_vars('number')
+    return dset


 def xr_open_mfdataset(files, **kwargs):
@@ -60,16 +42,9 @@ def xr_open_mfdataset(files, **kwargs):
     default_kwargs.update(kwargs)
     if isinstance(files, str):
         files = [files]
-    try:
-        return xr.open_mfdataset(files, **default_kwargs)
-    except Exception as e:
-        msg = 'Could not use xr.open_mfdataset to open %s. %s'
-        if len(files) == 1:
-            raise RuntimeError(msg % (files, e)) from e
-        msg += ' Trying to open them separately and merge.'
-        logger.warning(msg, files, e)
-        warn(msg % (files, e))
-        return merge_datasets(files, **default_kwargs)
+    return xr.open_mfdataset(
+        files, preprocess=preprocess_datasets, **default_kwargs
+    )


 def safe_cast(o):
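
``xr.open_mfdataset`` applies the function passed as its ``preprocess`` keyword to each per-file dataset before concatenation, so the standardization formerly done in ``merge_datasets`` now happens file-by-file at open time. Below is a minimal sketch of the hook's effect on a toy ERA5-style dataset; the variable names and values are hypothetical, and only ``preprocess_datasets`` itself comes from this commit.

```python
import numpy as np
import pandas as pd
import xarray as xr

from sup3r.utilities.utilities import preprocess_datasets

# Hypothetical toy dataset mimicking a downloaded ERA5 file: a datetime
# index plus the 'expver' bookkeeping variable that the hook strips out.
dset = xr.Dataset(
    data_vars={
        'u': (('time', 'latitude', 'longitude'), np.zeros((3, 2, 2))),
        'expver': ('time', np.ones(3)),
    },
    coords={
        'time': pd.date_range('2024-01-01', periods=3, freq='h'),
        'latitude': [40.0, 39.0],
        'longitude': [-105.0, -104.0],
    },
)

out = preprocess_datasets(dset)
assert 'expver' not in out          # ERA5 bookkeeping variable dropped
assert 'south_north' in out.dims    # 'latitude' dim swapped
assert 'west_east' in out.dims      # 'longitude' dim swapped
assert out.time.dtype.kind == 'i'   # datetimes cast to integers
```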

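With the try/except fallback to ``merge_datasets`` removed, opening multiple files reduces to a single call. A hedged usage sketch follows; the file paths are hypothetical, and ``default_kwargs`` (defined above the visible hunk) is omitted from the equivalent direct call.

```python
import xarray as xr

from sup3r.utilities.utilities import preprocess_datasets, xr_open_mfdataset

# Hypothetical ERA5 files whose raw time dtypes and coordinate names may
# differ; each file is standardized by the hook before concatenation.
files = ['era5_2023.nc', 'era5_2024.nc']

ds = xr_open_mfdataset(files)

# Roughly equivalent direct call, minus sup3r's default kwargs:
ds_direct = xr.open_mfdataset(files, preprocess=preprocess_datasets)
```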