From aff87b6801e3c46ab9cfb57681a311c6496a935e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Jan 2025 08:05:17 -0800 Subject: [PATCH 1/4] Refactor parallelization of bias calculations Fixes #227 --- sup3r/bias/bias_calc.py | 107 +++++------ sup3r/bias/presrat.py | 309 +++++++++++++++---------------- sup3r/bias/qdm.py | 399 +++++++++++++++++++--------------------- sup3r/bias/utilities.py | 32 ++++ 4 files changed, 417 insertions(+), 430 deletions(-) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 29c3c7277..9c4de60bd 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -1,16 +1,11 @@ """Utilities to calculate the bias correction factors for biased data that is going to be fed into the sup3r downscaling models. This is typically used to -bias correct GCM data vs. some historical record like the WTK or NSRDB. - -TODO: Generalize the ``with ProcessPoolExecutor() as exe: ...`` so we don't -need to duplicate this wherever we kickoff a process or thread pool -""" +bias correct GCM data vs. some historical record like the WTK or NSRDB.""" import copy import json import logging import os -from concurrent.futures import ProcessPoolExecutor, as_completed import h5py import numpy as np @@ -20,6 +15,7 @@ from .base import DataRetrievalBase from .mixins import FillAndSmoothMixin +from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -226,17 +222,17 @@ def run( if isinstance(self.base_dh, DataHandler): max_workers = 1 - if max_workers == 1: - logger.debug('Running serial calculation.') - for i, bias_gid in enumerate(self.bias_meta.index): - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - single_out = self._run_single( + task_args_list = [] + for bias_gid in self.bias_meta.index: + raster_loc = np.where(self.bias_gid_raster == bias_gid) + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + task_args_list.append( + ( bias_data, self.base_fps, self.bias_feature, @@ -246,66 +242,47 @@ def run( daily_reduction, self.bias_ti, self.decimals, - base_dh_inst=self.base_dh, - match_zero_rate=self.match_zero_rate, + self.base_dh, + self.match_zero_rate, ) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr + ) + if max_workers == 1: + logger.debug('Running serial calculation.') + for i, args in enumerate(task_args_list): + single_out = self._run_single(*args) + raster_loc = np.where(self.bias_gid_raster == args[0]) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr logger.info( - 'Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(self.bias_meta)) + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(task_args_list), ) - else: - logger.debug( - 'Running parallel calculation with {} workers.'.format( - max_workers - ) + logger.info( + 'Running parallel calculation with %s workers.', max_workers ) - with ProcessPoolExecutor(max_workers=max_workers) as exe: - futures = {} - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - future = exe.submit( - self._run_single, - bias_data, - self.base_fps, - 
self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - self.bias_ti, - self.decimals, - match_zero_rate=self.match_zero_rate, - ) - futures[future] = raster_loc - - logger.debug('Finished launching futures.') - for i, future in enumerate(as_completed(futures)): - raster_loc = futures[future] - single_out = future.result() - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info( - 'Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(futures)) - ) + results = run_in_parallel( + self._run_single, task_args_list, max_workers=max_workers + ) + for i, single_out in enumerate(results): + raster_loc = np.where( + self.bias_gid_raster == task_args_list[i][0] + ) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), + ) logger.info('Finished calculating bias correction factors.') self.out = self.fill_and_smooth( self.out, fill_extend, smooth_extend, smooth_interior ) - self.write_outputs(fp_out, self.out) return copy.deepcopy(self.out) diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py index dd20b94d1..bd2da642f 100644 --- a/sup3r/bias/presrat.py +++ b/sup3r/bias/presrat.py @@ -8,7 +8,6 @@ import json import logging import os -from concurrent.futures import ProcessPoolExecutor, as_completed from typing import Optional import h5py @@ -21,6 +20,7 @@ from .mixins import ZeroRateMixin from .qdm import QuantileDeltaMappingCorrection +from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -84,19 +84,26 @@ def _init_out(self): super()._init_out() shape = (*self.bias_gid_raster.shape, 1) - self.out[f'{self.base_dset}_zero_rate'] = np.full(shape, - np.nan, - np.float32) - self.out[f'{self.bias_feature}_tau_fut'] = np.full(shape, - np.nan, - np.float32) + self.out[f'{self.base_dset}_zero_rate'] = np.full( + shape, np.nan, np.float32 + ) + self.out[f'{self.bias_feature}_tau_fut'] = np.full( + shape, np.nan, np.float32 + ) shape = (*self.bias_gid_raster.shape, self.n_time_steps) self.out[f'{self.bias_feature}_k_factor'] = np.full( - shape, np.nan, np.float32) + shape, np.nan, np.float32 + ) @classmethod - def calc_tau_fut(cls, base_data, bias_data, bias_fut_data, - corrected_fut_data, zero_rate_threshold=1.157e-7): + def calc_tau_fut( + cls, + base_data, + bias_data, + bias_fut_data, + corrected_fut_data, + zero_rate_threshold=1.157e-7, + ): """Calculate a precipitation threshold (tau) that preserves the model-predicted changes in fraction of dry days at a single spatial location. @@ -134,12 +141,13 @@ def calc_tau_fut(cls, base_data, bias_data, bias_fut_data, # Step 1: Define zero rate from observations assert base_data.ndim == 1 obs_zero_rate = cls.zero_precipitation_rate( - base_data, zero_rate_threshold) + base_data, zero_rate_threshold + ) # Step 2: Find tau for each grid point # Removed NaN handling, thus reinforce finite-only data. 
- assert np.isfinite(bias_data).all(), "Unexpected invalid values" - assert bias_data.ndim == 1, "Assumed bias_data to be 1D" + assert np.isfinite(bias_data).all(), 'Unexpected invalid values' + assert bias_data.ndim == 1, 'Assumed bias_data to be 1D' n_threshold = round(obs_zero_rate * bias_data.size) n_threshold = min(n_threshold, bias_data.size - 1) tau = np.sort(bias_data)[n_threshold] @@ -147,20 +155,31 @@ def calc_tau_fut(cls, base_data, bias_data, bias_fut_data, # tau = max(tau, 0.01) # Step 3: Find Z_gf as the zero rate in mf - assert np.isfinite(bias_fut_data).all(), "Unexpected invalid values" + assert np.isfinite(bias_fut_data).all(), 'Unexpected invalid values' z_fg = (bias_fut_data < tau).astype('i').sum() / bias_fut_data.size # Step 4: Estimate tau_fut with corrected mf - tau_fut = np.sort(corrected_fut_data)[round( - z_fg * corrected_fut_data.size)] + tau_fut = np.sort(corrected_fut_data)[ + round(z_fg * corrected_fut_data.size) + ] return tau_fut, obs_zero_rate @classmethod - def calc_k_factor(cls, base_data, bias_data, bias_fut_data, - corrected_fut_data, base_ti, bias_ti, bias_fut_ti, - window_center, window_size, n_time_steps, - zero_rate_threshold): + def calc_k_factor( + cls, + base_data, + bias_data, + bias_fut_data, + corrected_fut_data, + base_ti, + bias_ti, + bias_fut_ti, + window_center, + window_size, + n_time_steps, + zero_rate_threshold, + ): """Calculate the K factor at a single spatial location that will preserve the original model-predicted mean change in precipitation @@ -213,8 +232,9 @@ def calc_k_factor(cls, base_data, bias_data, bias_fut_data, for nt, t in enumerate(window_center): base_idt = cls.window_mask(base_ti.day_of_year, t, window_size) bias_idt = cls.window_mask(bias_ti.day_of_year, t, window_size) - bias_fut_idt = cls.window_mask(bias_fut_ti.day_of_year, t, - window_size) + bias_fut_idt = cls.window_mask( + bias_fut_ti.day_of_year, t, window_size + ) oh = base_data[base_idt].mean() mh = bias_data[bias_idt].mean() @@ -233,29 +253,30 @@ def calc_k_factor(cls, base_data, bias_data, bias_fut_data, # pylint: disable=W0613 @classmethod - def _run_single(cls, - bias_data, - bias_fut_data, - base_fps, - bias_feature, - base_dset, - base_gid, - base_handler, - daily_reduction, - *, - bias_ti, - bias_fut_ti, - decimals, - dist, - relative, - sampling, - n_samples, - log_base, - n_time_steps, - window_size, - zero_rate_threshold, - base_dh_inst=None, - ): + def _run_single( + cls, + bias_data, + bias_fut_data, + base_fps, + bias_feature, + base_dset, + base_gid, + base_handler, + daily_reduction, + *, + bias_ti, + bias_fut_ti, + decimals, + dist, + relative, + sampling, + n_samples, + log_base, + n_time_steps, + window_size, + zero_rate_threshold, + base_dh_inst=None, + ): """Estimate probability distributions at a single site TODO! This should be refactored. 
There is too much redundancy in @@ -281,19 +302,21 @@ def _run_single(cls, # Define indices for which data goes in the current time window base_idx = cls.window_mask(base_ti.day_of_year, t, window_size) bias_idx = cls.window_mask(bias_ti.day_of_year, t, window_size) - bias_fut_idx = cls.window_mask(bias_fut_ti.day_of_year, - t, - window_size) + bias_fut_idx = cls.window_mask( + bias_fut_ti.day_of_year, t, window_size + ) if any(base_idx) and any(bias_idx) and any(bias_fut_idx): - tmp = cls.get_qdm_params(bias_data[bias_idx], - bias_fut_data[bias_fut_idx], - base_data[base_idx], - bias_feature, - base_dset, - sampling, - n_samples, - log_base) + tmp = cls.get_qdm_params( + bias_data[bias_idx], + bias_fut_data[bias_fut_idx], + base_data[base_idx], + bias_feature, + base_dset, + sampling, + n_samples, + log_base, + ) for k, v in tmp.items(): if k not in out: out[k] = template.copy() @@ -312,14 +335,26 @@ def _run_single(cls, subset = bias_fut_data[bias_fut_idx] corrected_fut_data[bias_fut_idx] = QDM(subset).squeeze() - tau_fut, obs_zero_rate = cls.calc_tau_fut(base_data, bias_data, - bias_fut_data, - corrected_fut_data, - zero_rate_threshold) - k = cls.calc_k_factor(base_data, bias_data, bias_fut_data, - corrected_fut_data, base_ti, bias_ti, - bias_fut_ti, window_center, window_size, - n_time_steps, zero_rate_threshold) + tau_fut, obs_zero_rate = cls.calc_tau_fut( + base_data, + bias_data, + bias_fut_data, + corrected_fut_data, + zero_rate_threshold, + ) + k = cls.calc_k_factor( + base_data, + bias_data, + bias_fut_data, + corrected_fut_data, + base_ti, + bias_ti, + bias_fut_ti, + window_center, + window_size, + n_time_steps, + zero_rate_threshold, + ) out[f'{bias_feature}_k_factor'] = k out[f'{base_dset}_zero_rate'] = obs_zero_rate @@ -393,24 +428,18 @@ def run( if isinstance(self.base_dh, DataHandler): max_workers = 1 - if max_workers == 1: - logger.debug('Running serial calculation.') - for i, bias_gid in enumerate(self.bias_meta.index): - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - logger.debug( - f'No base data for bias_gid: {bias_gid}. 
' - 'Adding it to bad_bias_gids' - ) - else: - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data( - bias_gid, self.bias_fut_dh - ) - single_out = self._run_single( + task_args_list = [] + for bias_gid in self.bias_meta.index: + raster_loc = np.where(self.bias_gid_raster == bias_gid) + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) + task_args_list.append( + ( bias_data, bias_fut_data, self.base_fps, @@ -419,81 +448,41 @@ def run( base_gid, self.base_handler, daily_reduction, - bias_ti=self.bias_dh.time_index, - bias_fut_ti=self.bias_fut_dh.time_index, - decimals=self.decimals, - dist=self.dist, - relative=self.relative, - sampling=self.sampling, - n_samples=self.n_quantiles, - log_base=self.log_base, - n_time_steps=self.n_time_steps, - window_size=self.window_size, - base_dh_inst=self.base_dh, - zero_rate_threshold=zero_rate_threshold, + self.bias_dh.time_index, + self.bias_fut_dh.time_index, + self.decimals, + self.dist, + self.relative, + self.sampling, + self.n_quantiles, + self.log_base, + self.n_time_steps, + self.window_size, + zero_rate_threshold, ) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info( - 'Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(self.bias_meta)) ) + if max_workers == 1: + logger.debug('Running serial calculation.') + results = [self._run_single(*args) for args in task_args_list] else: logger.debug( - 'Running parallel calculation with {} workers.'.format( - max_workers - ) + 'Running parallel calculation with %s workers.', max_workers + ) + results = run_in_parallel( + self._run_single, task_args_list, max_workers=max_workers + ) + + for i, single_out in enumerate(results): + raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0]) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr + + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), ) - with ProcessPoolExecutor(max_workers=max_workers) as exe: - futures = {} - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data( - bias_gid, self.bias_fut_dh - ) - future = exe.submit( - self._run_single, - bias_data, - bias_fut_data, - self.base_fps, - self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - bias_ti=self.bias_dh.time_index, - bias_fut_ti=self.bias_fut_dh.time_index, - decimals=self.decimals, - dist=self.dist, - relative=self.relative, - sampling=self.sampling, - n_samples=self.n_quantiles, - log_base=self.log_base, - n_time_steps=self.n_time_steps, - window_size=self.window_size, - zero_rate_threshold=zero_rate_threshold, - ) - futures[future] = raster_loc - - logger.debug('Finished launching futures.') - for i, future in enumerate(as_completed(futures)): - raster_loc = futures[future] - single_out = future.result() - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info( - 'Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(futures)) - ) logger.info('Finished calculating bias correction factors.') @@ -505,16 +494,20 @@ def run( 
'zero_rate_threshold': zero_rate_threshold, 'time_window_center': self.time_window_center, } - self.write_outputs(fp_out, - self.out, - extra_attrs=extra_attrs, - ) + self.write_outputs( + fp_out, + self.out, + extra_attrs=extra_attrs, + ) return copy.deepcopy(self.out) - def write_outputs(self, fp_out: str, - out: Optional[dict] = None, - extra_attrs: Optional[dict] = None): + def write_outputs( + self, + fp_out: str, + out: Optional[dict] = None, + extra_attrs: Optional[dict] = None, + ): """Write outputs to an .h5 file. Parameters diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py index 11b3c72f3..9546944d5 100644 --- a/sup3r/bias/qdm.py +++ b/sup3r/bias/qdm.py @@ -8,7 +8,6 @@ import json import logging import os -from concurrent.futures import ProcessPoolExecutor, as_completed import h5py import numpy as np @@ -23,6 +22,7 @@ from .base import DataRetrievalBase from .mixins import FillAndSmoothMixin +from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -46,31 +46,32 @@ class QuantileDeltaMappingCorrection(FillAndSmoothMixin, DataRetrievalBase): a dataset. """ - def __init__(self, - base_fps, - bias_fps, - bias_fut_fps, - base_dset, - bias_feature, - distance_upper_bound=None, - target=None, - shape=None, - base_handler='Resource', - bias_handler='DataHandlerNCforCC', - base_handler_kwargs=None, - bias_handler_kwargs=None, - bias_fut_handler_kwargs=None, - decimals=None, - match_zero_rate=False, - n_quantiles=101, - dist='empirical', - relative=True, - sampling='linear', - log_base=10, - n_time_steps=24, - window_size=120, - pre_load=True, - ): + def __init__( + self, + base_fps, + bias_fps, + bias_fut_fps, + base_dset, + bias_feature, + distance_upper_bound=None, + target=None, + shape=None, + base_handler='Resource', + bias_handler='DataHandlerNCforCC', + base_handler_kwargs=None, + bias_handler_kwargs=None, + bias_fut_handler_kwargs=None, + decimals=None, + match_zero_rate=False, + n_quantiles=101, + dist='empirical', + relative=True, + sampling='linear', + log_base=10, + n_time_steps=24, + window_size=120, + pre_load=True, + ): """ Parameters ---------- @@ -203,31 +204,34 @@ class with the bias_fps self.n_time_steps = n_time_steps self.window_size = window_size - super().__init__(base_fps=base_fps, - bias_fps=bias_fps, - base_dset=base_dset, - bias_feature=bias_feature, - distance_upper_bound=distance_upper_bound, - target=target, - shape=shape, - base_handler=base_handler, - bias_handler=bias_handler, - base_handler_kwargs=base_handler_kwargs, - bias_handler_kwargs=bias_handler_kwargs, - decimals=decimals, - match_zero_rate=match_zero_rate, - pre_load=False, - ) + super().__init__( + base_fps=base_fps, + bias_fps=bias_fps, + base_dset=base_dset, + bias_feature=bias_feature, + distance_upper_bound=distance_upper_bound, + target=target, + shape=shape, + base_handler=base_handler, + bias_handler=bias_handler, + base_handler_kwargs=base_handler_kwargs, + bias_handler_kwargs=bias_handler_kwargs, + decimals=decimals, + match_zero_rate=match_zero_rate, + pre_load=False, + ) self.bias_fut_fps = bias_fut_fps self.bias_fut_fps = expand_paths(self.bias_fut_fps) self.bias_fut_handler_kwargs = bias_fut_handler_kwargs or {} - self.bias_fut_dh = self.bias_handler(self.bias_fut_fps, - [self.bias_feature], - target=self.target, - shape=self.shape, - **self.bias_fut_handler_kwargs) + self.bias_fut_dh = self.bias_handler( + self.bias_fut_fps, + [self.bias_feature], + target=self.target, + shape=self.shape, + **self.bias_fut_handler_kwargs, + ) if pre_load: self.pre_load() @@ 
-249,12 +253,16 @@ def _init_out(self): probability distributions for the three datasets (see class documentation). """ - keys = [f'bias_{self.bias_feature}_params', - f'bias_fut_{self.bias_feature}_params', - f'base_{self.base_dset}_params', - ] - shape = (*self.bias_gid_raster.shape, self.n_time_steps, - self.n_quantiles) + keys = [ + f'bias_{self.bias_feature}_params', + f'bias_fut_{self.bias_feature}_params', + f'base_{self.base_dset}_params', + ] + shape = ( + *self.bias_gid_raster.shape, + self.n_time_steps, + self.n_quantiles, + ) arr = np.full(shape, np.nan, np.float32) self.out = {k: arr.copy() for k in keys} @@ -290,44 +298,47 @@ def _window_center(ntimes: int): ...]. It includes the fraction of a day, thus 15.5 is equivalent to January 15th, 12:00h. """ - assert ntimes > 0, "Requires a positive number of intervals" + assert ntimes > 0, 'Requires a positive number of intervals' dt = 365 / ntimes return np.arange(dt / 2, 366, dt) # pylint: disable=W0613 @classmethod - def _run_single(cls, - bias_data, - bias_fut_data, - base_fps, - bias_feature, - base_dset, - base_gid, - base_handler, - daily_reduction, - *, - bias_ti, - bias_fut_ti, - decimals, - dist, - relative, - sampling, - n_samples, - log_base, - n_time_steps, - window_size, - base_dh_inst=None, - ): + def _run_single( + cls, + bias_data, + bias_fut_data, + base_fps, + bias_feature, + base_dset, + base_gid, + base_handler, + daily_reduction, + *, + bias_ti, + bias_fut_ti, + decimals, + dist, + relative, + sampling, + n_samples, + log_base, + n_time_steps, + window_size, + base_dh_inst=None, + ): """Estimate probability distributions at a single site""" - base_data, base_ti = cls.get_base_data(base_fps, - base_dset, - base_gid, - base_handler, - daily_reduction=daily_reduction, - decimals=decimals, - base_dh_inst=base_dh_inst) + base_data, base_ti = cls.get_base_data( + base_fps, + base_dset, + base_gid, + base_handler, + daily_reduction=daily_reduction, + decimals=decimals, + base_dh_inst=base_dh_inst, + ) window_size = window_size or 365 / n_time_steps window_center = cls._window_center(n_time_steps) @@ -338,19 +349,21 @@ def _run_single(cls, for nt, idt in enumerate(window_center): base_idx = cls.window_mask(base_ti.day_of_year, idt, window_size) bias_idx = cls.window_mask(bias_ti.day_of_year, idt, window_size) - bias_fut_idx = cls.window_mask(bias_fut_ti.day_of_year, - idt, - window_size) + bias_fut_idx = cls.window_mask( + bias_fut_ti.day_of_year, idt, window_size + ) if any(base_idx) and any(bias_idx) and any(bias_fut_idx): - tmp = cls.get_qdm_params(bias_data[bias_idx], - bias_fut_data[bias_fut_idx], - base_data[base_idx], - bias_feature, - base_dset, - sampling, - n_samples, - log_base) + tmp = cls.get_qdm_params( + bias_data[bias_idx], + bias_fut_data[bias_fut_idx], + base_data[base_idx], + bias_feature, + base_dset, + sampling, + n_samples, + log_base, + ) for k, v in tmp.items(): if k not in out: out[k] = template.copy() @@ -359,15 +372,16 @@ def _run_single(cls, return out @staticmethod - def get_qdm_params(bias_data, - bias_fut_data, - base_data, - bias_feature, - base_dset, - sampling, - n_samples, - log_base, - ): + def get_qdm_params( + bias_data, + bias_fut_data, + base_data, + bias_feature, + base_dset, + sampling, + n_samples, + log_base, + ): """Get quantiles' cut point for given datasets Estimate the quantiles' cut points for each of the three given @@ -422,16 +436,18 @@ def get_qdm_params(bias_data, elif sampling == 'invlog': quantiles = sample_q_invlog(n_samples, log_base) else: - msg = ('sampling 
option must be linear, log, or invlog, but ' - 'received: {}'.format(sampling) - ) + msg = ( + 'sampling option must be linear, log, or invlog, but ' + 'received: {}'.format(sampling) + ) logger.error(msg) raise KeyError(msg) out = { f'bias_{bias_feature}_params': np.quantile(bias_data, quantiles), - f'bias_fut_{bias_feature}_params': np.quantile(bias_fut_data, - quantiles), + f'bias_fut_{bias_feature}_params': np.quantile( + bias_fut_data, quantiles + ), f'base_{base_dset}_params': np.quantile(base_data, quantiles), } @@ -466,24 +482,24 @@ def write_outputs(self, fp_out, out=None): for k, v in self.meta.items(): f.attrs[k] = json.dumps(v) - f.attrs["dist"] = self.dist - f.attrs["sampling"] = self.sampling - f.attrs["log_base"] = self.log_base - f.attrs["base_fps"] = self.base_fps - f.attrs["bias_fps"] = self.bias_fps - f.attrs["bias_fut_fps"] = self.bias_fut_fps - f.attrs["time_window_center"] = self.time_window_center - logger.info( - 'Wrote quantiles to file: {}'.format(fp_out)) - - def run(self, - fp_out=None, - max_workers=None, - daily_reduction='avg', - fill_extend=True, - smooth_extend=0, - smooth_interior=0, - ): + f.attrs['dist'] = self.dist + f.attrs['sampling'] = self.sampling + f.attrs['log_base'] = self.log_base + f.attrs['base_fps'] = self.base_fps + f.attrs['bias_fps'] = self.bias_fps + f.attrs['bias_fut_fps'] = self.bias_fut_fps + f.attrs['time_window_center'] = self.time_window_center + logger.info('Wrote quantiles to file: {}'.format(fp_out)) + + def run( + self, + fp_out=None, + max_workers=None, + daily_reduction='avg', + fill_extend=True, + smooth_extend=0, + smooth_interior=0, + ): """Estimate the statistical distributions for each location Parameters @@ -509,8 +525,11 @@ def run(self, logger.debug('Calculating CDF parameters for QDM') - logger.info('Initialized params with shape: {}' - .format(self.bias_gid_raster.shape)) + logger.info( + 'Initialized params with shape: {}'.format( + self.bias_gid_raster.shape + ) + ) self.bad_bias_gids = [] # sup3r DataHandler opening base files will load all data in parallel @@ -518,21 +537,18 @@ def run(self, if isinstance(self.base_dh, DataHandler): max_workers = 1 - if max_workers == 1: - logger.debug('Running serial calculation.') - for i, bias_gid in enumerate(self.bias_meta.index): - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - logger.debug(f'No base data for bias_gid: {bias_gid}. 
' - 'Adding it to bad_bias_gids') - else: - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data(bias_gid, - self.bias_fut_dh) - single_out = self._run_single( + task_args_list = [] + for bias_gid in self.bias_meta.index: + raster_loc = np.where(self.bias_gid_raster == bias_gid) + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) + task_args_list.append( + ( bias_data, bias_fut_data, self.base_fps, @@ -541,80 +557,49 @@ def run(self, base_gid, self.base_handler, daily_reduction, - bias_ti=self.bias_dh.time_index, - bias_fut_ti=self.bias_fut_dh.time_index, - decimals=self.decimals, - dist=self.dist, - relative=self.relative, - sampling=self.sampling, - n_samples=self.n_quantiles, - log_base=self.log_base, - n_time_steps=self.n_time_steps, - window_size=self.window_size, - base_dh_inst=self.base_dh, + self.bias_dh.time_index, + self.bias_fut_dh.time_index, + self.decimals, + self.dist, + self.relative, + self.sampling, + self.n_quantiles, + self.log_base, + self.n_time_steps, + self.window_size, + self.base_dh, ) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info('Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(self.bias_meta))) + ) + if max_workers == 1: + logger.debug('Running serial calculation.') + results = [self._run_single(*args) for args in task_args_list] else: logger.debug( - 'Running parallel calculation with {} workers.'.format( - max_workers)) - with ProcessPoolExecutor(max_workers=max_workers) as exe: - futures = {} - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data(bias_gid, - self.bias_fut_dh) - future = exe.submit( - self._run_single, - bias_data, - bias_fut_data, - self.base_fps, - self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - bias_ti=self.bias_dh.time_index, - bias_fut_ti=self.bias_fut_dh.time_index, - decimals=self.decimals, - dist=self.dist, - relative=self.relative, - sampling=self.sampling, - n_samples=self.n_quantiles, - log_base=self.log_base, - n_time_steps=self.n_time_steps, - window_size=self.window_size, - ) - futures[future] = raster_loc - - logger.debug('Finished launching futures.') - for i, future in enumerate(as_completed(futures)): - raster_loc = futures[future] - single_out = future.result() - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info('Completed bias calculations for {} out of {} ' - 'sites'.format(i + 1, len(futures))) + 'Running parallel calculation with %s workers.', max_workers + ) + results = run_in_parallel( + self._run_single, task_args_list, max_workers=max_workers + ) + + for i, single_out in enumerate(results): + raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0]) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr + + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), + ) logger.info('Finished calculating bias correction factors.') - self.out = self.fill_and_smooth(self.out, fill_extend, smooth_extend, - smooth_interior) + self.out = self.fill_and_smooth( + self.out, fill_extend, 
smooth_extend, smooth_interior + ) self.write_outputs(fp_out, self.out) - return copy.deepcopy(self.out) @staticmethod diff --git a/sup3r/bias/utilities.py b/sup3r/bias/utilities.py index 817620dde..e650163fa 100644 --- a/sup3r/bias/utilities.py +++ b/sup3r/bias/utilities.py @@ -2,6 +2,7 @@ import logging import os +from concurrent.futures import ProcessPoolExecutor, as_completed from inspect import signature from warnings import warn @@ -18,6 +19,37 @@ logger = logging.getLogger(__name__) +def run_in_parallel(task_function, task_args_list, max_workers=None): + """ + Execute a list of tasks in parallel using ``ProcessPoolExecutor``. + + Parameters + ---------- + task_function : callable + The function to execute in parallel. + task_args_list : list + A list of argument tuples, where each tuple contains the arguments + for a single call to ``task_function``. + max_workers : int, optional + The maximum number of workers to use. If None, it uses all available. + + Returns + ------- + results : list + A list of results from the executed tasks. + """ + results = [] + with ProcessPoolExecutor(max_workers=max_workers) as exe: + futures = { + exe.submit(task_function, *args): args + for args in task_args_list + } + for future in as_completed(futures): + result = future.result() + results.append(result) + return results + + def lin_bc(handler, bc_files, bias_feature=None, threshold=0.1): """Bias correct the data in this DataHandler in place using linear bias correction factors from files output by MonthlyLinearCorrection or From 585a32189fb6c9c71a3cc0c3fdeb3123de615e99 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Jan 2025 11:27:09 -0700 Subject: [PATCH 2/4] bias test fixes --- sup3r/bias/bias_calc.py | 67 ++++++++++++++++++--------------------- sup3r/bias/presrat.py | 56 +++++++++++++++++--------------- sup3r/bias/qdm.py | 55 +++++++++++++++++--------------- sup3r/bias/utilities.py | 10 +++--- sup3r/models/utilities.py | 7 ++-- 5 files changed, 98 insertions(+), 97 deletions(-) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 9c4de60bd..d60776349 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -222,7 +222,8 @@ def run( if isinstance(self.base_dh, DataHandler): max_workers = 1 - task_args_list = [] + task_kwargs_list = [] + bias_gids = [] for bias_gid in self.bias_meta.index: raster_loc = np.where(self.bias_gid_raster == bias_gid) _, base_gid = self.get_base_gid(bias_gid) @@ -231,52 +232,44 @@ def run( self.bad_bias_gids.append(bias_gid) else: bias_data = self.get_bias_data(bias_gid) - task_args_list.append( - ( - bias_data, - self.base_fps, - self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - self.bias_ti, - self.decimals, - self.base_dh, - self.match_zero_rate, - ) + bias_gids.append(bias_gid) + task_kwargs_list.append( + { + 'bias_data': bias_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'daily_reduction': daily_reduction, + 'bias_ti': self.bias_ti, + 'decimals': self.decimals, + 'match_zero_rate': self.match_zero_rate, + } ) if max_workers == 1: logger.debug('Running serial calculation.') - for i, args in enumerate(task_args_list): - single_out = self._run_single(*args) - raster_loc = np.where(self.bias_gid_raster == args[0]) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - 
len(task_args_list), - ) + results = [ + self._run_single(**kwargs, base_dh_inst=self.base_dh) + for kwargs in task_kwargs_list + ] else: logger.info( 'Running parallel calculation with %s workers.', max_workers ) results = run_in_parallel( - self._run_single, task_args_list, max_workers=max_workers + self._run_single, task_kwargs_list, max_workers=max_workers + ) + for i, single_out in enumerate(results): + raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), ) - for i, single_out in enumerate(results): - raster_loc = np.where( - self.bias_gid_raster == task_args_list[i][0] - ) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - len(results), - ) logger.info('Finished calculating bias correction factors.') diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py index bd2da642f..9f0876dee 100644 --- a/sup3r/bias/presrat.py +++ b/sup3r/bias/presrat.py @@ -428,7 +428,8 @@ def run( if isinstance(self.base_dh, DataHandler): max_workers = 1 - task_args_list = [] + task_kwargs_list = [] + bias_gids = [] for bias_gid in self.bias_meta.index: raster_loc = np.where(self.bias_gid_raster == bias_gid) _, base_gid = self.get_base_gid(bias_gid) @@ -436,45 +437,48 @@ def run( if not base_gid.any(): self.bad_bias_gids.append(bias_gid) else: + bias_gids.append(bias_gid) bias_data = self.get_bias_data(bias_gid) bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) - task_args_list.append( - ( - bias_data, - bias_fut_data, - self.base_fps, - self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - self.bias_dh.time_index, - self.bias_fut_dh.time_index, - self.decimals, - self.dist, - self.relative, - self.sampling, - self.n_quantiles, - self.log_base, - self.n_time_steps, - self.window_size, - zero_rate_threshold, - ) + task_kwargs_list.append( + { + 'bias_data': bias_data, + 'bias_fut_data': bias_fut_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'daily_reduction': daily_reduction, + 'bias_ti': self.bias_dh.time_index, + 'bias_fut_ti': self.bias_fut_dh.time_index, + 'decimals': self.decimals, + 'dist': self.dist, + 'relative': self.relative, + 'sampling': self.sampling, + 'n_samples': self.n_quantiles, + 'log_base': self.log_base, + 'n_time_steps': self.n_time_steps, + 'window_size': self.window_size, + 'zero_rate_threshold': zero_rate_threshold, + } ) if max_workers == 1: logger.debug('Running serial calculation.') - results = [self._run_single(*args) for args in task_args_list] + results = [ + self._run_single(**kwargs) for kwargs in task_kwargs_list + ] else: logger.debug( 'Running parallel calculation with %s workers.', max_workers ) results = run_in_parallel( - self._run_single, task_args_list, max_workers=max_workers + self._run_single, task_kwargs_list, max_workers=max_workers ) for i, single_out in enumerate(results): - raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0]) + raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) for key, arr in single_out.items(): self.out[key][raster_loc] = arr diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py index 9546944d5..abb37e9f8 100644 --- a/sup3r/bias/qdm.py +++ b/sup3r/bias/qdm.py @@ -537,7 +537,8 @@ 
def run( if isinstance(self.base_dh, DataHandler): max_workers = 1 - task_args_list = [] + task_kwargs_list = [] + bias_gids = [] for bias_gid in self.bias_meta.index: raster_loc = np.where(self.bias_gid_raster == bias_gid) _, base_gid = self.get_base_gid(bias_gid) @@ -545,45 +546,47 @@ def run( if not base_gid.any(): self.bad_bias_gids.append(bias_gid) else: + bias_gids.append(bias_gid) bias_data = self.get_bias_data(bias_gid) bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) - task_args_list.append( - ( - bias_data, - bias_fut_data, - self.base_fps, - self.bias_feature, - self.base_dset, - base_gid, - self.base_handler, - daily_reduction, - self.bias_dh.time_index, - self.bias_fut_dh.time_index, - self.decimals, - self.dist, - self.relative, - self.sampling, - self.n_quantiles, - self.log_base, - self.n_time_steps, - self.window_size, - self.base_dh, - ) + task_kwargs_list.append( + { + 'bias_data': bias_data, + 'bias_fut_data': bias_fut_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'daily_reduction': daily_reduction, + 'bias_ti': self.bias_dh.time_index, + 'bias_fut_ti': self.bias_fut_dh.time_index, + 'decimals': self.decimals, + 'dist': self.dist, + 'relative': self.relative, + 'sampling': self.sampling, + 'n_samples': self.n_quantiles, + 'log_base': self.log_base, + 'n_time_steps': self.n_time_steps, + 'window_size': self.window_size, + } ) if max_workers == 1: logger.debug('Running serial calculation.') - results = [self._run_single(*args) for args in task_args_list] + results = [ + self._run_single(**kwargs) for kwargs in task_kwargs_list + ] else: logger.debug( 'Running parallel calculation with %s workers.', max_workers ) results = run_in_parallel( - self._run_single, task_args_list, max_workers=max_workers + self._run_single, task_kwargs_list, max_workers=max_workers ) for i, single_out in enumerate(results): - raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0]) + raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) for key, arr in single_out.items(): self.out[key][raster_loc] = arr diff --git a/sup3r/bias/utilities.py b/sup3r/bias/utilities.py index e650163fa..e417f2736 100644 --- a/sup3r/bias/utilities.py +++ b/sup3r/bias/utilities.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -def run_in_parallel(task_function, task_args_list, max_workers=None): +def run_in_parallel(task_function, task_kwargs_list, max_workers=None): """ Execute a list of tasks in parallel using ``ProcessPoolExecutor``. @@ -28,8 +28,8 @@ def run_in_parallel(task_function, task_args_list, max_workers=None): task_function : callable The function to execute in parallel. task_args_list : list - A list of argument tuples, where each tuple contains the arguments - for a single call to ``task_function``. + A list of keyword argument dictionaries for a single call to + ``task_function``. max_workers : int, optional The maximum number of workers to use. If None, it uses all available. 
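A note on collection order in ``run_in_parallel``: because results are gathered with ``as_completed``, they come back in completion order, not submission order, so pairing the returned list with a separately built, index-aligned list of gids is only safe if that pairing is tracked explicitly. The following is a minimal sketch of the caveat, using a hypothetical ``task`` function and toy ``gid``/``delay`` arguments that are not part of sup3r:

    from concurrent.futures import ProcessPoolExecutor, as_completed
    import time


    def task(gid, delay):
        """Hypothetical stand-in for ``_run_single``: sleep, then echo gid."""
        time.sleep(delay)
        return gid


    if __name__ == '__main__':
        task_kwargs_list = [
            {'gid': 0, 'delay': 0.2},
            {'gid': 1, 'delay': 0.0},
        ]
        results = []
        with ProcessPoolExecutor(max_workers=2) as exe:
            futures = {
                exe.submit(task, **kwargs): kwargs
                for kwargs in task_kwargs_list
            }
            for future in as_completed(futures):
                results.append(future.result())
        # Likely prints [1, 0]: completion order, not submission order.
        print(results)

Keying each future by its bias_gid and collecting results into a dict, as the later patches in this series do, removes the ordering ambiguity entirely.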
@@ -41,8 +41,8 @@ def run_in_parallel(task_function, task_args_list, max_workers=None): results = [] with ProcessPoolExecutor(max_workers=max_workers) as exe: futures = { - exe.submit(task_function, *args): args - for args in task_args_list + exe.submit(task_function, **kwargs): kwargs + for kwargs in task_kwargs_list } for future in as_completed(futures): result = future.result() diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index b3b4037c1..8e825f124 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -38,8 +38,10 @@ def run(self): kwargs=self.kwargs, ) try: - logger.info('Starting training session.') - self.batch_handler.start() + logger.info( + 'Starting training session. Training for %s epochs', + self.kwargs['n_epoch'], + ) model_thread.start() except KeyboardInterrupt: logger.info('Ending training session.') @@ -53,7 +55,6 @@ def run(self): sys.exit() logger.info('Finished training') - self.batch_handler.stop() model_thread.join() From 7f9204924d982410d5fcae8fefdfbd02fc9ca998 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Jan 2025 15:22:12 -0700 Subject: [PATCH 3/4] additional bias refact: ``_run`` base method and ``_get_run_kwargs`` method. --- sup3r/bias/base.py | 118 +++++++++++++++++++++++++++++++++++++++- sup3r/bias/bias_calc.py | 70 ++---------------------- sup3r/bias/presrat.py | 116 ++++++++++++++------------------------- sup3r/bias/qdm.py | 112 ++++++++++++++------------------------ sup3r/bias/utilities.py | 21 +++---- 5 files changed, 216 insertions(+), 221 deletions(-) diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py index c264ae7bd..a184f37ec 100644 --- a/sup3r/bias/base.py +++ b/sup3r/bias/base.py @@ -20,6 +20,8 @@ from sup3r.utilities import VERSION_RECORD, ModuleName from sup3r.utilities.cli import BaseCLI +from .utilities import run_in_parallel + logger = logging.getLogger(__name__) @@ -43,7 +45,7 @@ def __init__( bias_handler_kwargs=None, decimals=None, match_zero_rate=False, - pre_load=True + pre_load=True, ): """ Parameters @@ -178,7 +180,7 @@ class is used, all data will be loaded in this class' self.nn_dist, self.nn_ind = self.bias_tree.query( self.base_meta[['latitude', 'longitude']], - distance_upper_bound=self.distance_upper_bound + distance_upper_bound=self.distance_upper_bound, ) if pre_load: @@ -777,3 +779,115 @@ def _reduce_base_data( assert base_data.shape == daily_ti.shape, msg return base_data, daily_ti + + def _get_run_kwargs(self, **kwargs_extras): + """Get dictionary of kwarg dictionaries to use for calls to + ``_run_single``. Each key-value pair is a bias_gid with the associated + ``_run_single`` arguments for that gid""" + task_kwargs = {} + for bias_gid in self.bias_meta.index: + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + task_kwargs[bias_gid] = { + 'bias_data': bias_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'bias_ti': self.bias_ti, + 'decimals': self.decimals, + 'match_zero_rate': self.match_zero_rate, + **kwargs_extras + } + return task_kwargs + + def _run( + self, + max_workers=None, + fill_extend=True, + smooth_extend=0, + smooth_interior=0, + **kwargs_extras + ): + """Run correction factor calculations for every site in the bias + dataset + + Parameters + ---------- + fp_out : str | None + Optional .h5 output file to write scalar and adder arrays. 
+ max_workers : int + Number of workers to run in parallel. 1 is serial and None is all + available. + daily_reduction : None | str + Option to do a reduction of the hourly+ source base data to daily + data. Can be None (no reduction, keep source time frequency), "avg" + (daily average), "max" (daily max), "min" (daily min), + "sum" (daily sum/total) + fill_extend : bool + Flag to fill data past distance_upper_bound using spatial nearest + neighbor. If False, the extended domain will be left as NaN. + smooth_extend : float + Option to smooth the scalar/adder data outside of the spatial + domain set by the distance_upper_bound input. This alleviates the + weird seams far from the domain of interest. This value is the + standard deviation for the gaussian_filter kernel + smooth_interior : float + Option to smooth the scalar/adder data within the valid spatial + domain. This can reduce the affect of extreme values within + aggregations over large number of pixels. + kwargs_extras: dict + Additional kwargs that get sent to ``_run_single`` e.g. + daily_reduction='avg', zero_rate_threshold=1.157e-7 + + Returns + ------- + out : dict + Dictionary of values defining the mean/std of the bias + base + data and the scalar + adder factors to correct the biased data + like: bias_data * scalar + adder. Each value is of shape + (lat, lon, time). + """ + self.bad_bias_gids = [] + + task_kwargs = self._get_run_kwargs(**kwargs_extras) + # sup3r DataHandler opening base files will load all data in parallel + # during the init and should not be passed in parallel to workers + if isinstance(self.base_dh, DataHandler): + max_workers = 1 + + if max_workers == 1: + logger.debug('Running serial calculation.') + results = { + bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh) + for bias_gid, kwargs in task_kwargs.items() + } + else: + logger.info( + 'Running parallel calculation with %s workers.', max_workers + ) + results = run_in_parallel( + self._run_single, task_kwargs, max_workers=max_workers + ) + for i, (bias_gid, single_out) in enumerate(results.items()): + raster_loc = np.where(self.bias_gid_raster == bias_gid) + for key, arr in single_out.items(): + self.out[key][raster_loc] = arr + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), + ) + + logger.info('Finished calculating bias correction factors.') + + self.out = self.fill_and_smooth( + self.out, fill_extend, smooth_extend, smooth_interior + ) + + return self.out diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index d60776349..13bf18dcc 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -11,11 +11,8 @@ import numpy as np from scipy import stats -from sup3r.preprocessing import DataHandler - from .base import DataRetrievalBase from .mixins import FillAndSmoothMixin -from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -214,67 +211,12 @@ def run( self.bias_gid_raster.shape ) ) - - self.bad_bias_gids = [] - - # sup3r DataHandler opening base files will load all data in parallel - # during the init and should not be passed in parallel to workers - if isinstance(self.base_dh, DataHandler): - max_workers = 1 - - task_kwargs_list = [] - bias_gids = [] - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - bias_gids.append(bias_gid) - 
task_kwargs_list.append( - { - 'bias_data': bias_data, - 'base_fps': self.base_fps, - 'bias_feature': self.bias_feature, - 'base_dset': self.base_dset, - 'base_gid': base_gid, - 'base_handler': self.base_handler, - 'daily_reduction': daily_reduction, - 'bias_ti': self.bias_ti, - 'decimals': self.decimals, - 'match_zero_rate': self.match_zero_rate, - } - ) - - if max_workers == 1: - logger.debug('Running serial calculation.') - results = [ - self._run_single(**kwargs, base_dh_inst=self.base_dh) - for kwargs in task_kwargs_list - ] - else: - logger.info( - 'Running parallel calculation with %s workers.', max_workers - ) - results = run_in_parallel( - self._run_single, task_kwargs_list, max_workers=max_workers - ) - for i, single_out in enumerate(results): - raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - len(results), - ) - - logger.info('Finished calculating bias correction factors.') - - self.out = self.fill_and_smooth( - self.out, fill_extend, smooth_extend, smooth_interior + self.out = self._run( + max_workers=max_workers, + daily_reduction=daily_reduction, + fill_extend=fill_extend, + smooth_extend=smooth_extend, + smooth_interior=smooth_interior, ) self.write_outputs(fp_out, self.out) diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py index 9f0876dee..b9694ac97 100644 --- a/sup3r/bias/presrat.py +++ b/sup3r/bias/presrat.py @@ -16,11 +16,8 @@ QuantileDeltaMapping, ) -from sup3r.preprocessing import DataHandler - from .mixins import ZeroRateMixin from .qdm import QuantileDeltaMappingCorrection -from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -362,6 +359,41 @@ def _run_single( return out + def _get_run_kwargs(self, **kwargs_extras): + """Get dictionary of kwarg dictionaries to use for calls to + ``_run_single``. 
Each key-value pair is a bias_gid with the associated + ``_run_single`` arguments for that gid""" + task_kwargs = {} + for bias_gid in self.bias_meta.index: + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) + task_kwargs[bias_gid] = { + 'bias_data': bias_data, + 'bias_fut_data': bias_fut_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'bias_ti': self.bias_dh.time_index, + 'bias_fut_ti': self.bias_fut_dh.time_index, + 'decimals': self.decimals, + 'dist': self.dist, + 'relative': self.relative, + 'sampling': self.sampling, + 'n_samples': self.n_quantiles, + 'log_base': self.log_base, + 'n_time_steps': self.n_time_steps, + 'window_size': self.window_size, + **kwargs_extras, + } + return task_kwargs + def run( self, fp_out=None, @@ -421,77 +453,13 @@ def run( self.bias_gid_raster.shape ) ) - self.bad_bias_gids = [] - - # sup3r DataHandler opening base files will load all data in parallel - # during the init and should not be passed in parallel to workers - if isinstance(self.base_dh, DataHandler): - max_workers = 1 - - task_kwargs_list = [] - bias_gids = [] - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_gids.append(bias_gid) - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) - task_kwargs_list.append( - { - 'bias_data': bias_data, - 'bias_fut_data': bias_fut_data, - 'base_fps': self.base_fps, - 'bias_feature': self.bias_feature, - 'base_dset': self.base_dset, - 'base_gid': base_gid, - 'base_handler': self.base_handler, - 'daily_reduction': daily_reduction, - 'bias_ti': self.bias_dh.time_index, - 'bias_fut_ti': self.bias_fut_dh.time_index, - 'decimals': self.decimals, - 'dist': self.dist, - 'relative': self.relative, - 'sampling': self.sampling, - 'n_samples': self.n_quantiles, - 'log_base': self.log_base, - 'n_time_steps': self.n_time_steps, - 'window_size': self.window_size, - 'zero_rate_threshold': zero_rate_threshold, - } - ) - - if max_workers == 1: - logger.debug('Running serial calculation.') - results = [ - self._run_single(**kwargs) for kwargs in task_kwargs_list - ] - else: - logger.debug( - 'Running parallel calculation with %s workers.', max_workers - ) - results = run_in_parallel( - self._run_single, task_kwargs_list, max_workers=max_workers - ) - - for i, single_out in enumerate(results): - raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - len(results), - ) - - logger.info('Finished calculating bias correction factors.') - - self.out = self.fill_and_smooth( - self.out, fill_extend, smooth_extend, smooth_interior + self.out = self._run( + max_workers=max_workers, + daily_reduction=daily_reduction, + fill_extend=fill_extend, + smooth_extend=smooth_extend, + smooth_interior=smooth_interior, + zero_rate_threshold=zero_rate_threshold, ) extra_attrs = { diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py index abb37e9f8..a1d2da23d 100644 --- a/sup3r/bias/qdm.py +++ b/sup3r/bias/qdm.py @@ -17,12 +17,10 @@ 
sample_q_log, ) -from sup3r.preprocessing import DataHandler from sup3r.preprocessing.utilities import expand_paths from .base import DataRetrievalBase from .mixins import FillAndSmoothMixin -from .utilities import run_in_parallel logger = logging.getLogger(__name__) @@ -491,6 +489,41 @@ def write_outputs(self, fp_out, out=None): f.attrs['time_window_center'] = self.time_window_center logger.info('Wrote quantiles to file: {}'.format(fp_out)) + def _get_run_kwargs(self, **kwargs_extras): + """Get dictionary of kwarg dictionaries to use for calls to + ``_run_single``. Each key-value pair is a bias_gid with the associated + ``_run_single`` arguments for that gid""" + task_kwargs = {} + for bias_gid in self.bias_meta.index: + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) + task_kwargs[bias_gid] = { + 'bias_data': bias_data, + 'bias_fut_data': bias_fut_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'bias_ti': self.bias_dh.time_index, + 'bias_fut_ti': self.bias_fut_dh.time_index, + 'decimals': self.decimals, + 'dist': self.dist, + 'relative': self.relative, + 'sampling': self.sampling, + 'n_samples': self.n_quantiles, + 'log_base': self.log_base, + 'n_time_steps': self.n_time_steps, + 'window_size': self.window_size, + **kwargs_extras, + } + return task_kwargs + def run( self, fp_out=None, @@ -530,76 +563,13 @@ def run( self.bias_gid_raster.shape ) ) - self.bad_bias_gids = [] - - # sup3r DataHandler opening base files will load all data in parallel - # during the init and should not be passed in parallel to workers - if isinstance(self.base_dh, DataHandler): - max_workers = 1 - task_kwargs_list = [] - bias_gids = [] - for bias_gid in self.bias_meta.index: - raster_loc = np.where(self.bias_gid_raster == bias_gid) - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_gids.append(bias_gid) - bias_data = self.get_bias_data(bias_gid) - bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh) - task_kwargs_list.append( - { - 'bias_data': bias_data, - 'bias_fut_data': bias_fut_data, - 'base_fps': self.base_fps, - 'bias_feature': self.bias_feature, - 'base_dset': self.base_dset, - 'base_gid': base_gid, - 'base_handler': self.base_handler, - 'daily_reduction': daily_reduction, - 'bias_ti': self.bias_dh.time_index, - 'bias_fut_ti': self.bias_fut_dh.time_index, - 'decimals': self.decimals, - 'dist': self.dist, - 'relative': self.relative, - 'sampling': self.sampling, - 'n_samples': self.n_quantiles, - 'log_base': self.log_base, - 'n_time_steps': self.n_time_steps, - 'window_size': self.window_size, - } - ) - - if max_workers == 1: - logger.debug('Running serial calculation.') - results = [ - self._run_single(**kwargs) for kwargs in task_kwargs_list - ] - else: - logger.debug( - 'Running parallel calculation with %s workers.', max_workers - ) - results = run_in_parallel( - self._run_single, task_kwargs_list, max_workers=max_workers - ) - - for i, single_out in enumerate(results): - raster_loc = np.where(self.bias_gid_raster == bias_gids[i]) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - len(results), - ) - - logger.info('Finished 
calculating bias correction factors.') - - self.out = self.fill_and_smooth( - self.out, fill_extend, smooth_extend, smooth_interior + self.out = self._run( + max_workers=max_workers, + daily_reduction=daily_reduction, + fill_extend=fill_extend, + smooth_extend=smooth_extend, + smooth_interior=smooth_interior, ) self.write_outputs(fp_out, self.out) diff --git a/sup3r/bias/utilities.py b/sup3r/bias/utilities.py index e417f2736..4c2e759e0 100644 --- a/sup3r/bias/utilities.py +++ b/sup3r/bias/utilities.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -def run_in_parallel(task_function, task_kwargs_list, max_workers=None): +def run_in_parallel(task_function, task_kwargs, max_workers=None): """ Execute a list of tasks in parallel using ``ProcessPoolExecutor``. @@ -27,26 +27,27 @@ def run_in_parallel(task_function, task_kwargs_list, max_workers=None): ---------- task_function : callable The function to execute in parallel. - task_args_list : list - A list of keyword argument dictionaries for a single call to + task_kwargs : dictionary + A dictionary of keyword argument dictionaries for a single call to ``task_function``. max_workers : int, optional The maximum number of workers to use. If None, it uses all available. Returns ------- - results : list - A list of results from the executed tasks. + results : dictionary + A dictionary of results from the executed tasks with the same keys as + ``task_kwargs``. """ - results = [] + results = {} with ProcessPoolExecutor(max_workers=max_workers) as exe: futures = { - exe.submit(task_function, **kwargs): kwargs - for kwargs in task_kwargs_list + exe.submit(task_function, **kwargs): bias_gid + for bias_gid, kwargs in task_kwargs.items() } for future in as_completed(futures): - result = future.result() - results.append(result) + bias_gid = futures[future] + results[bias_gid] = future.result() return results From f6d25eeb10b2dc8f9ac93f1d7cdfbf26a2e8c6a4 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Jan 2025 16:28:37 -0700 Subject: [PATCH 4/4] moved ``_run`` method to bias correction interface ``AbstractBiasCorrection`` --- sup3r/bias/abstract.py | 144 +++++++++++++++++++++ sup3r/bias/base.py | 114 ---------------- sup3r/bias/bias_calc.py | 32 ++++- sup3r/bias/presrat.py | 36 +----- sup3r/bias/qdm.py | 6 +- sup3r/bias/utilities.py | 33 ----- tests/bias/test_presrat_bias_correction.py | 48 +------ 7 files changed, 188 insertions(+), 225 deletions(-) create mode 100644 sup3r/bias/abstract.py diff --git a/sup3r/bias/abstract.py b/sup3r/bias/abstract.py new file mode 100644 index 000000000..8fc5487fd --- /dev/null +++ b/sup3r/bias/abstract.py @@ -0,0 +1,144 @@ +"""Bias correction class interface.""" + +import logging +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor, as_completed + +import numpy as np + +from sup3r.preprocessing import DataHandler + +logger = logging.getLogger(__name__) + + +class AbstractBiasCorrection(ABC): + """Minimal interface for bias correction classes""" + + @abstractmethod + def _get_run_kwargs(self, **run_single_kwargs): + """Get dictionary of kwarg dictionaries to use for calls to + ``_run_single``. 

From f6d25eeb10b2dc8f9ac93f1d7cdfbf26a2e8c6a4 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Sun, 5 Jan 2025 16:28:37 -0700
Subject: [PATCH 4/4] moved ``_run`` method to bias correction interface
 ``AbstractBiasCorrection``

---
 sup3r/bias/abstract.py                     | 144 +++++++++++++++++++++
 sup3r/bias/base.py                         | 114 ----------------
 sup3r/bias/bias_calc.py                    |  32 ++++-
 sup3r/bias/presrat.py                      |  36 +-----
 sup3r/bias/qdm.py                          |   6 +-
 sup3r/bias/utilities.py                    |  33 -----
 tests/bias/test_presrat_bias_correction.py |  48 +------
 7 files changed, 188 insertions(+), 225 deletions(-)
 create mode 100644 sup3r/bias/abstract.py

diff --git a/sup3r/bias/abstract.py b/sup3r/bias/abstract.py
new file mode 100644
index 000000000..8fc5487fd
--- /dev/null
+++ b/sup3r/bias/abstract.py
@@ -0,0 +1,144 @@
+"""Bias correction class interface."""
+
+import logging
+from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+import numpy as np
+
+from sup3r.preprocessing import DataHandler
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractBiasCorrection(ABC):
+    """Minimal interface for bias correction classes"""
+
+    @abstractmethod
+    def _get_run_kwargs(self, **run_single_kwargs):
+        """Get dictionary of kwarg dictionaries to use for calls to
+        ``_run_single``. Each key-value pair is a bias_gid with the associated
+        ``_run_single`` kwargs dict for that gid"""
+
+    def _run(
+        self,
+        out,
+        max_workers=None,
+        fill_extend=True,
+        smooth_extend=0,
+        smooth_interior=0,
+        **run_single_kwargs,
+    ):
+        """Run correction factor calculations for every site in the bias
+        dataset
+
+        Parameters
+        ----------
+        out : dict
+            Dictionary of arrays to fill with bias correction factors.
+        max_workers : int
+            Number of workers to run in parallel. 1 is serial and None is all
+            available.
+        daily_reduction : None | str
+            Option to do a reduction of the hourly+ source base data to daily
+            data, forwarded to ``_run_single`` through ``run_single_kwargs``.
+            Can be None (no reduction, keep source time frequency), "avg"
+            (daily average), "max" (daily max), "min" (daily min), or "sum"
+            (daily sum/total)
+        fill_extend : bool
+            Flag to fill data past distance_upper_bound using spatial nearest
+            neighbor. If False, the extended domain will be left as NaN.
+        smooth_extend : float
+            Option to smooth the scalar/adder data outside of the spatial
+            domain set by the distance_upper_bound input. This alleviates the
+            weird seams far from the domain of interest. This value is the
+            standard deviation for the gaussian_filter kernel
+        smooth_interior : float
+            Option to smooth the scalar/adder data within the valid spatial
+            domain. This can reduce the effect of extreme values within
+            aggregations over a large number of pixels.
+        run_single_kwargs : dict
+            Additional kwargs that get sent to ``_run_single`` e.g.
+            daily_reduction='avg', zero_rate_threshold=1.157e-7
+
+        Returns
+        -------
+        out : dict
+            Dictionary of values defining the mean/std of the bias + base
+            data and the correction factors to correct the biased data like:
+            bias_data * scalar + adder. Each value is of shape
+            (lat, lon, time).
+        """
+        self.bad_bias_gids = []
+
+        task_kwargs = self._get_run_kwargs(**run_single_kwargs)
+
+        # sup3r DataHandler opening base files will load all data in parallel
+        # during the init and should not be passed in parallel to workers
+        if isinstance(self.base_dh, DataHandler):
+            max_workers = 1
+
+        if max_workers == 1:
+            logger.debug('Running serial calculation.')
+            results = {
+                bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
+                for bias_gid, kwargs in task_kwargs.items()
+            }
+        else:
+            logger.info(
+                'Running parallel calculation with %s workers.', max_workers
+            )
+            results = {}
+            with ProcessPoolExecutor(max_workers=max_workers) as exe:
+                futures = {
+                    exe.submit(self._run_single, **kwargs): bias_gid
+                    for bias_gid, kwargs in task_kwargs.items()
+                }
+                for future in as_completed(futures):
+                    bias_gid = futures[future]
+                    results[bias_gid] = future.result()
+
+        for i, (bias_gid, single_out) in enumerate(results.items()):
+            raster_loc = np.where(self.bias_gid_raster == bias_gid)
+            for key, arr in single_out.items():
+                out[key][raster_loc] = arr
+            logger.info(
+                'Completed bias calculations for %s out of %s sites',
+                i + 1,
+                len(results),
+            )
+
+        logger.info('Finished calculating bias correction factors.')
+
+        return self.fill_and_smooth(
+            out, fill_extend, smooth_extend, smooth_interior
+        )
+
+    @abstractmethod
+    def run(
+        self,
+        fp_out=None,
+        max_workers=None,
+        daily_reduction='avg',
+        fill_extend=True,
+        smooth_extend=0,
+        smooth_interior=0,
+    ):
+        """Run correction factor calculations for every site in the bias
+        dataset"""

+    @classmethod
+    @abstractmethod
+    def _run_single(
+        cls,
+        bias_data,
+        base_fps,
+        bias_feature,
+        base_dset,
+        base_gid,
+        base_handler,
+        daily_reduction,
+        bias_ti,
+        decimals,
+        base_dh_inst=None,
+        match_zero_rate=False,
+    ):
+        """Run correction factor calculations for a single site"""
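``_run`` is a classic template method: the ABC owns the orchestration and
fan-out, while subclasses supply task construction and the per-site math. A
stripped-down sketch of the contract a concrete class signs up for (toy
names, serial only, not the real classes):

    from abc import ABC, abstractmethod


    class Correction(ABC):
        """Toy analogue of AbstractBiasCorrection."""

        @abstractmethod
        def _get_run_kwargs(self, **kwargs):
            """Build a {gid: kwargs} mapping for ``_run_single``."""

        @classmethod
        @abstractmethod
        def _run_single(cls, **kwargs):
            """Compute correction factors for one site."""

        def _run(self, out, **kwargs):  # template method
            for gid, kw in self._get_run_kwargs(**kwargs).items():
                out[gid] = self._run_single(**kw)
            return out


    class Linear(Correction):
        def _get_run_kwargs(self, **kwargs):
            return {0: {'bias': 2.0, 'base': 3.0}}

        @classmethod
        def _run_single(cls, bias, base):
            return {'scalar': base / bias, 'adder': 0.0}


    print(Linear()._run({}))  # {0: {'scalar': 1.5, 'adder': 0.0}}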
diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py
index a184f37ec..d7761a876 100644
--- a/sup3r/bias/base.py
+++ b/sup3r/bias/base.py
@@ -20,8 +20,6 @@
 from sup3r.utilities import VERSION_RECORD, ModuleName
 from sup3r.utilities.cli import BaseCLI
 
-from .utilities import run_in_parallel
-
 logger = logging.getLogger(__name__)
 
 
@@ -779,115 +777,3 @@ def _reduce_base_data(
         assert base_data.shape == daily_ti.shape, msg
 
         return base_data, daily_ti
-
-    def _get_run_kwargs(self, **kwargs_extras):
-        """Get dictionary of kwarg dictionaries to use for calls to
-        ``_run_single``. Each key-value pair is a bias_gid with the associated
-        ``_run_single`` arguments for that gid"""
-        task_kwargs = {}
-        for bias_gid in self.bias_meta.index:
-            _, base_gid = self.get_base_gid(bias_gid)
-
-            if not base_gid.any():
-                self.bad_bias_gids.append(bias_gid)
-            else:
-                bias_data = self.get_bias_data(bias_gid)
-                task_kwargs[bias_gid] = {
-                    'bias_data': bias_data,
-                    'base_fps': self.base_fps,
-                    'bias_feature': self.bias_feature,
-                    'base_dset': self.base_dset,
-                    'base_gid': base_gid,
-                    'base_handler': self.base_handler,
-                    'bias_ti': self.bias_ti,
-                    'decimals': self.decimals,
-                    'match_zero_rate': self.match_zero_rate,
-                    **kwargs_extras
-                }
-        return task_kwargs
-
-    def _run(
-        self,
-        max_workers=None,
-        fill_extend=True,
-        smooth_extend=0,
-        smooth_interior=0,
-        **kwargs_extras
-    ):
-        """Run correction factor calculations for every site in the bias
-        dataset
-
-        Parameters
-        ----------
-        fp_out : str | None
-            Optional .h5 output file to write scalar and adder arrays.
-        max_workers : int
-            Number of workers to run in parallel. 1 is serial and None is all
-            available.
-        daily_reduction : None | str
-            Option to do a reduction of the hourly+ source base data to daily
-            data. Can be None (no reduction, keep source time frequency), "avg"
-            (daily average), "max" (daily max), "min" (daily min),
-            "sum" (daily sum/total)
-        fill_extend : bool
-            Flag to fill data past distance_upper_bound using spatial nearest
-            neighbor. If False, the extended domain will be left as NaN.
-        smooth_extend : float
-            Option to smooth the scalar/adder data outside of the spatial
-            domain set by the distance_upper_bound input. This alleviates the
-            weird seams far from the domain of interest. This value is the
-            standard deviation for the gaussian_filter kernel
-        smooth_interior : float
-            Option to smooth the scalar/adder data within the valid spatial
-            domain. This can reduce the affect of extreme values within
-            aggregations over large number of pixels.
-        kwargs_extras: dict
-            Additional kwargs that get sent to ``_run_single`` e.g.
-            daily_reduction='avg', zero_rate_threshold=1.157e-7
-
-        Returns
-        -------
-        out : dict
-            Dictionary of values defining the mean/std of the bias + base
-            data and the scalar + adder factors to correct the biased data
-            like: bias_data * scalar + adder. Each value is of shape
-            (lat, lon, time).
-        """
-        self.bad_bias_gids = []
-
-        task_kwargs = self._get_run_kwargs(**kwargs_extras)
-        # sup3r DataHandler opening base files will load all data in parallel
-        # during the init and should not be passed in parallel to workers
-        if isinstance(self.base_dh, DataHandler):
-            max_workers = 1
-
-        if max_workers == 1:
-            logger.debug('Running serial calculation.')
-            results = {
-                bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
-                for bias_gid, kwargs in task_kwargs.items()
-            }
-        else:
-            logger.info(
-                'Running parallel calculation with %s workers.', max_workers
-            )
-            results = run_in_parallel(
-                self._run_single, task_kwargs, max_workers=max_workers
-            )
-        for i, (bias_gid, single_out) in enumerate(results.items()):
-            raster_loc = np.where(self.bias_gid_raster == bias_gid)
-            for key, arr in single_out.items():
-                self.out[key][raster_loc] = arr
-            logger.info(
-                'Completed bias calculations for %s out of %s sites',
-                i + 1,
-                len(results),
-            )
-
-        logger.info('Finished calculating bias correction factors.')
-
-        self.out = self.fill_and_smooth(
-            self.out, fill_extend, smooth_extend, smooth_interior
-        )
-
-        return self.out
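One subtlety the moved code keeps intact: ``base_dh_inst=self.base_dh`` is
injected only in the serial branch. ``ProcessPoolExecutor`` pickles every
submitted argument, and handler objects wrapping open HDF5/netCDF resources
generally cannot be pickled, which is also why an in-memory ``DataHandler``
forces ``max_workers = 1``. A quick illustration (``OpenFileHandle`` is
contrived for the demo, not a sup3r class):

    import pickle


    class OpenFileHandle:
        """Emulates an object holding an unpicklable OS resource."""

        def __reduce__(self):
            raise TypeError('cannot pickle open file handles')


    def submittable(kwargs):
        """Mimic the pickling ProcessPoolExecutor applies to task args."""
        try:
            pickle.dumps(kwargs)
            return True
        except TypeError:
            return False


    assert submittable({'decimals': 3})
    assert not submittable({'base_dh_inst': OpenFileHandle()})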
- """ - self.bad_bias_gids = [] - - task_kwargs = self._get_run_kwargs(**kwargs_extras) - # sup3r DataHandler opening base files will load all data in parallel - # during the init and should not be passed in parallel to workers - if isinstance(self.base_dh, DataHandler): - max_workers = 1 - - if max_workers == 1: - logger.debug('Running serial calculation.') - results = { - bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh) - for bias_gid, kwargs in task_kwargs.items() - } - else: - logger.info( - 'Running parallel calculation with %s workers.', max_workers - ) - results = run_in_parallel( - self._run_single, task_kwargs, max_workers=max_workers - ) - for i, (bias_gid, single_out) in enumerate(results.items()): - raster_loc = np.where(self.bias_gid_raster == bias_gid) - for key, arr in single_out.items(): - self.out[key][raster_loc] = arr - logger.info( - 'Completed bias calculations for %s out of %s sites', - i + 1, - len(results), - ) - - logger.info('Finished calculating bias correction factors.') - - self.out = self.fill_and_smooth( - self.out, fill_extend, smooth_extend, smooth_interior - ) - - return self.out diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 13bf18dcc..4e3f40f72 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -11,13 +11,16 @@ import numpy as np from scipy import stats +from .abstract import AbstractBiasCorrection from .base import DataRetrievalBase from .mixins import FillAndSmoothMixin logger = logging.getLogger(__name__) -class LinearCorrection(FillAndSmoothMixin, DataRetrievalBase): +class LinearCorrection( + AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase +): """Calculate linear correction *scalar +adder factors to bias correct data This calculation operates on single bias sites for the full time series of @@ -159,6 +162,32 @@ def write_outputs(self, fp_out, out): 'Wrote scalar adder factors to file: {}'.format(fp_out) ) + def _get_run_kwargs(self, **kwargs_extras): + """Get dictionary of kwarg dictionaries to use for calls to + ``_run_single``. Each key-value pair is a bias_gid with the associated + ``_run_single`` arguments for that gid""" + task_kwargs = {} + for bias_gid in self.bias_meta.index: + _, base_gid = self.get_base_gid(bias_gid) + + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: + bias_data = self.get_bias_data(bias_gid) + task_kwargs[bias_gid] = { + 'bias_data': bias_data, + 'base_fps': self.base_fps, + 'bias_feature': self.bias_feature, + 'base_dset': self.base_dset, + 'base_gid': base_gid, + 'base_handler': self.base_handler, + 'bias_ti': self.bias_ti, + 'decimals': self.decimals, + 'match_zero_rate': self.match_zero_rate, + **kwargs_extras, + } + return task_kwargs + def run( self, fp_out=None, @@ -212,6 +241,7 @@ def run( ) ) self.out = self._run( + out=self.out, max_workers=max_workers, daily_reduction=daily_reduction, fill_extend=fill_extend, diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py index b9694ac97..da09ad478 100644 --- a/sup3r/bias/presrat.py +++ b/sup3r/bias/presrat.py @@ -359,41 +359,6 @@ def _run_single( return out - def _get_run_kwargs(self, **kwargs_extras): - """Get dictionary of kwarg dictionaries to use for calls to - ``_run_single``. 
diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py
index b9694ac97..da09ad478 100644
--- a/sup3r/bias/presrat.py
+++ b/sup3r/bias/presrat.py
@@ -359,41 +359,6 @@ def _run_single(
 
         return out
 
-    def _get_run_kwargs(self, **kwargs_extras):
-        """Get dictionary of kwarg dictionaries to use for calls to
-        ``_run_single``. Each key-value pair is a bias_gid with the associated
-        ``_run_single`` arguments for that gid"""
-        task_kwargs = {}
-        for bias_gid in self.bias_meta.index:
-            _, base_gid = self.get_base_gid(bias_gid)
-
-            if not base_gid.any():
-                self.bad_bias_gids.append(bias_gid)
-            else:
-                bias_data = self.get_bias_data(bias_gid)
-                bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
-                task_kwargs[bias_gid] = {
-                    'bias_data': bias_data,
-                    'bias_fut_data': bias_fut_data,
-                    'base_fps': self.base_fps,
-                    'bias_feature': self.bias_feature,
-                    'base_dset': self.base_dset,
-                    'base_gid': base_gid,
-                    'base_handler': self.base_handler,
-                    'bias_ti': self.bias_dh.time_index,
-                    'bias_fut_ti': self.bias_fut_dh.time_index,
-                    'decimals': self.decimals,
-                    'dist': self.dist,
-                    'relative': self.relative,
-                    'sampling': self.sampling,
-                    'n_samples': self.n_quantiles,
-                    'log_base': self.log_base,
-                    'n_time_steps': self.n_time_steps,
-                    'window_size': self.window_size,
-                    **kwargs_extras,
-                }
-        return task_kwargs
-
     def run(
         self,
         fp_out=None,
@@ -454,6 +419,7 @@
             )
         )
         self.out = self._run(
+            out=self.out,
             max_workers=max_workers,
             daily_reduction=daily_reduction,
             fill_extend=fill_extend,
diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py
index a1d2da23d..a635ee811 100644
--- a/sup3r/bias/qdm.py
+++ b/sup3r/bias/qdm.py
@@ -19,13 +19,16 @@
 
 from sup3r.preprocessing.utilities import expand_paths
 
+from .abstract import AbstractBiasCorrection
 from .base import DataRetrievalBase
 from .mixins import FillAndSmoothMixin
 
 logger = logging.getLogger(__name__)
 
 
-class QuantileDeltaMappingCorrection(FillAndSmoothMixin, DataRetrievalBase):
+class QuantileDeltaMappingCorrection(
+    AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase
+):
     """Estimate probability distributions required by Quantile Delta Mapping
 
     The main purpose of this class is to estimate the probability
@@ -565,6 +568,7 @@ def run(
         )
 
         self.out = self._run(
+            out=self.out,
             max_workers=max_workers,
             daily_reduction=daily_reduction,
             fill_extend=fill_extend,
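Since ``QuantileDeltaMappingCorrection`` only estimates the distributions
here, it may help to see the correction those distributions feed. A minimal,
self-contained quantile delta mapping sketch in its relative form,
independent of sup3r's implementation details (synthetic data with a
multiplicative bias of 1.3):

    import numpy as np

    rng = np.random.default_rng(0)
    obs = rng.lognormal(0.0, 1.0, 5000)             # observed historical
    mod_hist = 1.3 * rng.lognormal(0.0, 1.0, 5000)  # biased modeled historical
    mod_fut = 1.3 * rng.lognormal(0.1, 1.0, 5000)   # biased modeled future

    q = np.linspace(0, 1, 101)
    # empirical quantile (CDF value) of each future sample in its own dist
    tau = np.interp(mod_fut, np.quantile(mod_fut, q), q)
    # scale by the obs/model quantile ratio at that same quantile
    corrected = mod_fut * np.quantile(obs, tau) / np.quantile(mod_hist, tau)

    # prints ~exp(0.1): the 1.3 bias is removed, the change signal is kept
    print(np.median(corrected) / np.median(obs))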
- """ - results = {} - with ProcessPoolExecutor(max_workers=max_workers) as exe: - futures = { - exe.submit(task_function, **kwargs): bias_gid - for bias_gid, kwargs in task_kwargs.items() - } - for future in as_completed(futures): - bias_gid = futures[future] - results[bias_gid] = future.result() - return results - - def lin_bc(handler, bc_files, bias_feature=None, threshold=0.1): """Bias correct the data in this DataHandler in place using linear bias correction factors from files output by MonthlyLinearCorrection or diff --git a/tests/bias/test_presrat_bias_correction.py b/tests/bias/test_presrat_bias_correction.py index 7137965ef..a9af3e8ff 100644 --- a/tests/bias/test_presrat_bias_correction.py +++ b/tests/bias/test_presrat_bias_correction.py @@ -82,27 +82,17 @@ def fp_resource(tmpdir_factory): """ fn = tmpdir_factory.mktemp('data').join('precip_oh.h5') - # Reproducing FP_NSRDB before I can change it. time = pd.date_range( - '2018-01-01 00:00:00+0000', '2018-03-26 23:30:00+0000', freq='30m' - ) - time = pd.DatetimeIndex( - np.arange( - np.datetime64('2018-01-01 00:00:00+00:00'), - np.datetime64('2019-01-01 00:00:00+00:00'), - np.timedelta64(6, 'h'), - ) + '2018-01-01 00:00:00', '2019-01-01 00:00:00', freq='6h' ) lat = np.arange(39.77, 39.00, -0.04) lon = np.arange(-105.14, -104.37, 0.04) - rng = np.random.default_rng() - ghi = rng.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) + ghi = RANDOM_GENERATOR.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) ds = xr.Dataset( data_vars={'ghi': (['time', 'lat', 'lon'], ghi)}, coords={ 'time': ('time', time), - # "time_bnds": (["time", "bnds"], time_bnds), 'lat': ('lat', lat), 'lon': ('lon', lon), }, @@ -142,40 +132,16 @@ def fp_resource(tmpdir_factory): @pytest.fixture(scope='module') def precip(): - """Synthetic historical modeled dataset - - Note - ---- - There are different expected patterns in different components of the - processing. For instance, lon might be expected as 0-360 in some places - but -180 to 180 in others, and expect a certain order that does not - necessarily match latitutde. So changes in the coordinates shall be - done carefullly. - """ - # first value must conform with TARGET[0] - # n values must conform with SHAPE[0] - # dlat = -0.70175216 + """Synthetic historical modeled dataset""" lat = np.array( [40.3507847105177, 39.649032596592, 38.9472804370071, 38.2455282337738] ) - # assert np.allclose(lat[0], TARGET[0]) - # assert lat.size == SHAPE[0] - - # lon = np.linspace(254.4, 255.1, 10) - # first value must conform with TARGET[1] - # n values must conform with SHAPE[1] lon = np.array([254.53125, 255.234375, 255.9375, 256.640625]) - # assert np.allclose(lat[1], 360 + TARGET[0]) - # assert lon.size == SHAPE[0] - t0 = np.datetime64('2015-01-01T12:00:00') - time = t0 + np.arange( - 0, SAMPLE_TIME_DURATION, SAMPLE_TIME_RESOLUTION, dtype='timedelta64[D]' + time = pd.date_range( + '2015-01-01T12:00:00', '2016-12-31T12:00:00', freq='D' ) - # bnds = (-np.timedelta64(12, 'h'), np.timedelta64(12, 'h')) - # time_bnds = time[:, np.newaxis] + bnds - rng = np.random.default_rng() - pr = rng.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) + pr = RANDOM_GENERATOR.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) # Transform the upper tail into negligible to guarantee some 'zero # precipiation days'. @@ -190,7 +156,7 @@ def precip(): data=pr, dims=['time', 'lat', 'lon'], coords={ - 'time': ('time', time), + 'time': ('time', pd.DatetimeIndex(time)), 'lat': ('lat', lat), 'lon': ('lon', lon), },