
Commit

some cleanup in bias_transforms.py. Enabled use of dask for these methods.
bnb32 committed Nov 19, 2024
1 parent 402b530 commit dd69f42
Showing 8 changed files with 355 additions and 278 deletions.
522 changes: 266 additions & 256 deletions sup3r/bias/bias_transforms.py

Large diffs are not rendered by default.

34 changes: 33 additions & 1 deletion sup3r/bias/utilities.py
@@ -107,6 +107,11 @@ def qdm_bc(
relative=True,
threshold=0.1,
no_trend=False,
delta_denom_min=None,
delta_denom_zero=None,
delta_range=None,
out_range=None,
max_workers=1
):
"""Bias Correction using Quantile Delta Mapping
@@ -149,6 +154,27 @@ def qdm_bc(
Note that this assumes that "bias_{feature}_params"
(``params_mh``) is representative of the data distribution for the
target data.
delta_denom_min : float | None
Option to specify a minimum value for the denominator term in the
calculation of a relative delta value. This prevents division by a
very small number, which can make the delta blow up and produce very
large bias-corrected output values. See equation 4 of Cannon et al.,
2015 for the delta term.
delta_denom_zero : float | None
Option to specify a value to replace zeros in the denominator term in
the calculation of a relative delta value. This prevents division by
zero, which would make the delta blow up and produce very large
bias-corrected output values. See equation 4 of Cannon et al., 2015
for the delta term.
delta_range : tuple | None
Option to set a (min, max) on the delta term in QDM. This can help
prevent QDM from making unrealistic increases or decreases in
otherwise physical values. See equation 4 of Cannon et al., 2015 for
the delta term.
out_range : tuple | None
Option to set floor/ceiling values on the output data.
max_workers : int | None
Max number of workers to use for the QDM process pool.
"""

if isinstance(bc_files, str):
@@ -172,7 +198,7 @@
)
)
handler.data[feature] = local_qdm_bc(
handler.data[feature][...],
handler.data[feature],
handler.lat_lon,
bias_feature,
feature,
@@ -181,6 +207,12 @@
threshold=threshold,
relative=relative,
no_trend=no_trend,
delta_denom_min=delta_denom_min,
delta_denom_zero=delta_denom_zero,
delta_range=delta_range,
out_range=out_range,
max_workers=max_workers
)
completed.append(feature)

46 changes: 37 additions & 9 deletions sup3r/preprocessing/cachers/base.py
@@ -208,7 +208,19 @@ def parse_chunks(feature, chunks, dims)
@classmethod
def get_chunksizes(cls, dset, data, chunks):
"""Get chunksizes after rechunking (could be undetermined beforehand
if ``chunks == 'auto'``) and return rechunked data."""
if ``chunks == 'auto'``) and return rechunked data.

Parameters
----------
dset : str
Name of feature to get chunksizes for.
data : Sup3rX | xr.Dataset
``Sup3rX`` or ``xr.Dataset`` containing data to be cached.
chunks : dict | None | 'auto'
Dictionary of chunksizes either to use for all features or, if the
dictionary includes feature keys, feature specific chunksizes. Can
also be None or 'auto'.
"""
data_var = data.coords[dset] if dset in data.coords else data[dset]
fchunk = cls.parse_chunks(dset, chunks, data_var.dims)
if isinstance(fchunk, dict):
@@ -233,10 +245,23 @@ def get_chunksizes(cls, dset, data, chunks):
return data_var, chunksizes

@classmethod
def add_coord_meta(cls, out_file, data):
def add_coord_meta(cls, out_file, data, meta=None):
"""Add flattened coordinate meta to out_file. This is used for h5
caching."""
meta = pd.DataFrame()
caching.

Parameters
----------
out_file : str
Name of output file.
data : Sup3rX | xr.Dataset
Data being written to the given ``out_file``.
meta : pd.DataFrame | None
Optional additional meta information to be written to the given
``out_file``. If this is None then only coordinate info will be
included in the meta written to the ``out_file``.
"""
if meta is None or (isinstance(meta, dict) and not meta):
meta = pd.DataFrame()
for coord in Dimension.coords_2d():
if coord in data:
meta[coord] = data[coord].data.flatten()
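A brief usage sketch of the new ``meta`` argument, assuming ``Cacher`` is the class defined in this module, the import path shown here, and an in-memory dataset ``ds``; the ``elevation`` column is hypothetical, and lat/lon columns are still added automatically from the flattened coordinates:

import pandas as pd

from sup3r.preprocessing import Cacher  # assumed import path

extra = pd.DataFrame({'elevation': [10.0, 12.5, 9.8]})
Cacher.add_coord_meta(out_file='./cache.h5', data=ds, meta=extra)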
@@ -280,18 +305,21 @@ def write_h5(
attrs : dict | None
Optional attributes to write to file. Can specify dataset specific
attributes by adding a dictionary with the dataset name as a key.
e.g. {**global_attrs, dset: {...}}
e.g. {**global_attrs, dset: {...}}. Can also include a global meta
dataframe that will then be added to the coordinate meta.
verbose : bool
Dummy arg to match ``write_netcdf`` signature
"""
if len(data.dims) == 3:
if len(data.dims) == 3 and Dimension.TIME in data.dims:
data = data.transpose(Dimension.TIME, *Dimension.dims_2d())
if features == 'all':
features = list(data.data_vars)
features = features if isinstance(features, list) else [features]
chunks = chunks or 'auto'
global_attrs = data.attrs.copy()
global_attrs.update(attrs or {})
attrs = attrs or {}
meta = attrs.pop('meta', {})
global_attrs.update(attrs)
attrs = {k: safe_cast(v) for k, v in global_attrs.items()}
with h5py.File(out_file, mode) as f:
for k, v in attrs.items():
@@ -335,7 +363,7 @@
scheduler='threads',
num_workers=max_workers,
)
cls.add_coord_meta(out_file=out_file, data=data)
cls.add_coord_meta(out_file=out_file, data=data, meta=meta)
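Tying the hunks above together, a sketch of routing a global meta dataframe through ``write_h5`` (the keyword call signature is assumed; ``extra`` and ``ds`` are from the previous sketch):

# the 'meta' entry is popped from attrs and forwarded to
# add_coord_meta instead of being written as an h5 attribute
attrs = {'source': 'ERA5', 'meta': extra}
Cacher.write_h5(out_file='./cache.h5', data=ds, features='all', attrs=attrs)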

@staticmethod
def get_chunk_slices(chunks, shape):
@@ -383,7 +411,7 @@ def write_netcdf_chunks(
_mem_check(),
)
for i, chunk_slice in enumerate(chunk_slices):
msg = f'Writing chunk {i} / {len(chunk_slices)} to {out_file}'
msg = f'Writing chunk {i + 1} / {len(chunk_slices)} to {out_file}'
msg = None if not verbose else msg
chunk = data_var.data[chunk_slice]
task = dask.delayed(cls.write_chunk)(
3 changes: 2 additions & 1 deletion sup3r/preprocessing/data_handlers/factory.py
@@ -111,7 +111,8 @@ def __init__(
Keyword arguments for nan handling. If 'mask', time steps with nans
will be dropped. Otherwise this should be a dict of kwargs which
will be passed to
:py:meth:`sup3r.preprocessing.accessor.Sup3rX.interpolate_na`.
:py:meth:`sup3r.preprocessing.accessor.Sup3rX.interpolate_na`. e.g.
{'method': 'linear', 'dim': 'time'}
BaseLoader : Callable
Base level file loader wrapped by
:class:`~sup3r.preprocessing.loaders.Loader`. This is usually
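As a hedged example of the nan-handling kwargs documented in the hunk above (the parameter name ``nan_method_kwargs`` and the import path are assumptions):

from sup3r.preprocessing import DataHandler

dh = DataHandler(
    file_paths,  # placeholder list of input files
    features=['u_100m'],
    nan_method_kwargs={'method': 'linear', 'dim': 'time'},
)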
5 changes: 4 additions & 1 deletion sup3r/preprocessing/loaders/nc.py
@@ -124,7 +124,10 @@ def _rechunk_dsets(self, res):
res.data_vars."""
for dset in [*list(res.coords), *list(res.data_vars)]:
chunks = self._parse_chunks(dims=res[dset].dims, feature=dset)
if chunks != 'auto':

# specifying chunks to xarray.open_mfdataset doesn't automatically
# apply to coordinates, so we do that here
if chunks != 'auto' or dset in Dimension.coords_2d():
res[dset] = res[dset].chunk(chunks)
return res
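The xarray behavior worked around here can be sketched as follows (file paths, dimension, and coordinate names are illustrative):

import xarray as xr

files = ['era5_2020_01.nc', 'era5_2020_02.nc']  # placeholder paths
ds = xr.open_mfdataset(files, chunks={'time': 100})
# data_vars come back dask-backed, but a 2D coordinate like latitude
# can remain an un-chunked numpy array until .chunk() is applied
ds['latitude'] = ds['latitude'].chunk({'south_north': 10, 'west_east': 10})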

7 changes: 6 additions & 1 deletion sup3r/utilities/era_downloader.py
@@ -323,6 +323,11 @@ def download_file(
Whether to overwrite existing file
"""
if os.path.exists(out_file) and not cls._can_skip_file(out_file):
logger.info(
'Previous download of %s failed. Removing %s.',
out_file,
out_file,
)
os.remove(out_file)

if not cls._can_skip_file(out_file) or overwrite:
@@ -807,7 +812,7 @@ def _can_skip_file(cls, file):
try:
_ = Loader(file)
except Exception as e:
msg = 'Could not open %s. %s Will redownload.'
msg = 'Could not open %s. %s. Will redownload.'
logger.warning(msg, file, e)
warn(msg % (file, e))
openable = False
3 changes: 2 additions & 1 deletion tests/bias/test_presrat_bias_correction.py
@@ -642,7 +642,8 @@ def test_presrat_transform(presrat_params, precip_fut):
- unbiased zero rate is not smaller than the input zero rate
"""
# local_presrat_bc expects time in the last dimension.
data = precip_fut.transpose('lat', 'lon', 'time').values

data = precip_fut.transpose('lat', 'lon', 'time')
time = pd.to_datetime(precip_fut.time)
latlon = np.stack(
xr.broadcast(precip_fut['lat'], precip_fut['lon'] - 360), axis=-1
13 changes: 5 additions & 8 deletions tests/bias/test_qdm_bias_correction.py
@@ -312,23 +312,20 @@ def test_qdm_transform_notrend(tmp_path, dist_params):
assert np.allclose(corrected, unbiased, equal_nan=True)


def test_handler_qdm_bc(fp_fut_cc, dist_params):
"""qdm_bc() method from DataHandler
WIP: Confirm it runs, but don't verify much yet.
"""
def test_qdm_bc_method(fp_fut_cc, dist_params):
"""Tesat qdm_bc standalone method"""
Handler = DataHandler(fp_fut_cc, 'rsds')
original = Handler.data.as_array().copy()
qdm_bc(Handler, dist_params, 'ghi')
corrected = Handler.data.as_array()

original = compute_if_dask(original)
corrected = compute_if_dask(corrected)
assert not np.isnan(corrected).all(), "Can't compare if only NaN"

idx = ~(np.isnan(original) | np.isnan(corrected))
# Where it is not NaN, it must have differences.
assert not np.allclose(
compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]
)
assert not np.allclose(original[idx], corrected[idx])


def test_bc_identity(tmp_path, fp_fut_cc, dist_params):
