bias refact: run_in_parallel function to remove duplicate calls t… #253

Merged · 4 commits · Jan 9, 2025
144 changes: 144 additions & 0 deletions sup3r/bias/abstract.py
@@ -0,0 +1,144 @@
"""Bias correction class interface."""

import logging
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np

from sup3r.preprocessing import DataHandler

logger = logging.getLogger(__name__)


class AbstractBiasCorrection(ABC):
"""Minimal interface for bias correction classes"""

@abstractmethod
def _get_run_kwargs(self, **run_single_kwargs):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` kwargs dict for that gid"""

def _run(
self,
out,
max_workers=None,
fill_extend=True,
smooth_extend=0,
smooth_interior=0,
**run_single_kwargs,
):
"""Run correction factor calculations for every site in the bias
dataset

Parameters
----------
out : dict
Dictionary of arrays to fill with bias correction factors.
max_workers : int
Number of workers to run in parallel. 1 is serial and None is all
available.
        daily_reduction : None | str
            Option to do a reduction of the hourly+ source base data to daily
            data. Can be None (no reduction, keep source time frequency),
            "avg" (daily average), "max" (daily max), "min" (daily min), or
            "sum" (daily sum/total). Forwarded to ``_run_single`` through
            ``run_single_kwargs``.
fill_extend : bool
Flag to fill data past distance_upper_bound using spatial nearest
neighbor. If False, the extended domain will be left as NaN.
smooth_extend : float
Option to smooth the scalar/adder data outside of the spatial
domain set by the distance_upper_bound input. This alleviates the
weird seams far from the domain of interest. This value is the
standard deviation for the gaussian_filter kernel
smooth_interior : float
Option to smooth the scalar/adder data within the valid spatial
            domain. This can reduce the effect of extreme values within
            aggregations over a large number of pixels.
        run_single_kwargs : dict
Additional kwargs that get sent to ``_run_single`` e.g.
daily_reduction='avg', zero_rate_threshold=1.157e-7

Returns
-------
out : dict
Dictionary of values defining the mean/std of the bias + base data
and correction factors to correct the biased data like: bias_data *
scalar + adder. Each value is of shape (lat, lon, time).
"""
self.bad_bias_gids = []

task_kwargs = self._get_run_kwargs(**run_single_kwargs)

# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

if max_workers == 1:
logger.debug('Running serial calculation.')
results = {
bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
for bias_gid, kwargs in task_kwargs.items()
}
else:
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
results = {}
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {
exe.submit(self._run_single, **kwargs): bias_gid
for bias_gid, kwargs in task_kwargs.items()
}
for future in as_completed(futures):
bias_gid = futures[future]
results[bias_gid] = future.result()

for i, (bias_gid, single_out) in enumerate(results.items()):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
for key, arr in single_out.items():
out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

return self.fill_and_smooth(
out, fill_extend, smooth_extend, smooth_interior
)

@abstractmethod
def run(
self,
fp_out=None,
max_workers=None,
daily_reduction='avg',
fill_extend=True,
smooth_extend=0,
smooth_interior=0,
):
"""Run correction factor calculations for every site in the bias
dataset"""

@classmethod
@abstractmethod
def _run_single(
cls,
bias_data,
base_fps,
bias_feature,
base_dset,
base_gid,
base_handler,
daily_reduction,
bias_ti,
decimals,
base_dh_inst=None,
match_zero_rate=False,
):
"""Run correction factor calculations for a single site"""
4 changes: 2 additions & 2 deletions sup3r/bias/base.py
@@ -43,7 +43,7 @@ def __init__(
bias_handler_kwargs=None,
decimals=None,
match_zero_rate=False,
pre_load=True
pre_load=True,
):
"""
Parameters
@@ -178,7 +178,7 @@ class is used, all data will be loaded in this class'

self.nn_dist, self.nn_ind = self.bias_tree.query(
self.base_meta[['latitude', 'longitude']],
distance_upper_bound=self.distance_upper_bound
distance_upper_bound=self.distance_upper_bound,
)

if pre_load:
134 changes: 38 additions & 96 deletions sup3r/bias/bias_calc.py
@@ -1,30 +1,26 @@
"""Utilities to calculate the bias correction factors for biased data that is
going to be fed into the sup3r downscaling models. This is typically used to
bias correct GCM data vs. some historical record like the WTK or NSRDB.

TODO: Generalize the ``with ProcessPoolExecutor() as exe: ...`` so we don't
need to duplicate this wherever we kickoff a process or thread pool
"""
bias correct GCM data vs. some historical record like the WTK or NSRDB."""

import copy
import json
import logging
import os
from concurrent.futures import ProcessPoolExecutor, as_completed

import h5py
import numpy as np
from scipy import stats

from sup3r.preprocessing import DataHandler

from .abstract import AbstractBiasCorrection
from .base import DataRetrievalBase
from .mixins import FillAndSmoothMixin

logger = logging.getLogger(__name__)


class LinearCorrection(FillAndSmoothMixin, DataRetrievalBase):
class LinearCorrection(
AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase
):
"""Calculate linear correction *scalar +adder factors to bias correct data

This calculation operates on single bias sites for the full time series of
@@ -166,6 +162,32 @@ def write_outputs(self, fp_out, out):
'Wrote scalar adder factors to file: {}'.format(fp_out)
)

def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""
task_kwargs = {}
for bias_gid in self.bias_meta.index:
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_kwargs[bias_gid] = {
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
**kwargs_extras,
}
return task_kwargs

def run(
self,
fp_out=None,
@@ -218,94 +240,14 @@ def run(
self.bias_gid_raster.shape
)
)

self.bad_bias_gids = []

# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

if max_workers == 1:
logger.debug('Running serial calculation.')
for i, bias_gid in enumerate(self.bias_meta.index):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
single_out = self._run_single(
bias_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_ti,
self.decimals,
base_dh_inst=self.base_dh,
match_zero_rate=self.match_zero_rate,
)
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

logger.info(
'Completed bias calculations for {} out of {} '
'sites'.format(i + 1, len(self.bias_meta))
)

else:
logger.debug(
'Running parallel calculation with {} workers.'.format(
max_workers
)
)
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {}
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
future = exe.submit(
self._run_single,
bias_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_ti,
self.decimals,
match_zero_rate=self.match_zero_rate,
)
futures[future] = raster_loc

logger.debug('Finished launching futures.')
for i, future in enumerate(as_completed(futures)):
raster_loc = futures[future]
single_out = future.result()
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

logger.info(
'Completed bias calculations for {} out of {} '
'sites'.format(i + 1, len(futures))
)

logger.info('Finished calculating bias correction factors.')

self.out = self.fill_and_smooth(
self.out, fill_extend, smooth_extend, smooth_interior
self.out = self._run(
out=self.out,
max_workers=max_workers,
daily_reduction=daily_reduction,
fill_extend=fill_extend,
smooth_extend=smooth_extend,
smooth_interior=smooth_interior,
)

self.write_outputs(fp_out, self.out)

return copy.deepcopy(self.out)
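
To see the refactor from the caller's side, here is a hedged usage sketch. The ``run`` keyword arguments come straight from the signature in this diff; the constructor arguments, dataset names, and file paths are assumptions based on the ``DataRetrievalBase`` pattern and should be checked against the real class signature before use.

from sup3r.bias.bias_calc import LinearCorrection

lc = LinearCorrection(
    base_fps=['/path/to/nsrdb_2018.h5'],  # placeholder historical record
    bias_fps=['/path/to/gcm_rsds.nc'],    # placeholder biased GCM files
    base_dset='ghi',                      # placeholder base dataset name
    bias_feature='rsds',                  # placeholder biased feature name
)

# ``run`` now builds per-site kwargs via ``_get_run_kwargs`` and hands
# execution to the shared ``AbstractBiasCorrection._run`` template instead
# of duplicating the ProcessPoolExecutor loop locally.
out = lc.run(
    fp_out='./rsds_bias_factors.h5',
    max_workers=4,
    daily_reduction='avg',
    fill_extend=True,
    smooth_extend=0,
    smooth_interior=0,
)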