Skip to content

Commit

Permalink
blockwise operation causing issues with compatible shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 19, 2024
1 parent dd69f42 commit 9e40015
Showing 1 changed file with 31 additions and 40 deletions.
71 changes: 31 additions & 40 deletions sup3r/bias/bias_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,6 @@ def _apply_qdm(
base_params,
bias_params,
bias_fut_params,
current_time_idx,
dist='empirical',
sampling='linear',
log_base=10,
Expand Down Expand Up @@ -514,8 +513,6 @@ def _apply_qdm(
Same requirements as params_oh. This input arg is for the **modeled
future distribution**. If this is None, this defaults to params_mh
(no future data, just corrected to modeled historical distribution)
current_time_idx : int
Time index for the current qdm application.
dist : str
Probability distribution name to use to model the data which
determines how the param args are used. This can "empirical" or any
Expand Down Expand Up @@ -572,22 +569,20 @@ def _apply_qdm(
Max number of workers to use for QDM process pool
"""

# Naming following the paper: observed historical
oh = base_params[:, :, current_time_idx]
# Modeled historical
mh = bias_params[:, :, current_time_idx]
# Modeled future
mf = bias_fut_params[:, :, current_time_idx]

# This satisfies the rex's QDM design
mf = None if no_trend else np.reshape(mf, (-1, mf.shape[-1]))
bias_fut_params = (
None
if no_trend
else np.reshape(bias_fut_params, (-1, bias_fut_params.shape[-1]))
)
# The distributions at this point, after selected the respective
# time window with `window_idx`, are 3D (space, space, N-params)
# Collapse 3D (space, space, N) into 2D (space**2, N)

QDM = QuantileDeltaMapping(
params_oh=np.reshape(oh, (-1, oh.shape[-1])),
params_mh=np.reshape(mh, (-1, mh.shape[-1])),
params_mf=mf,
params_oh=np.reshape(base_params, (-1, base_params.shape[-1])),
params_mh=np.reshape(bias_params, (-1, bias_params.shape[-1])),
params_mf=bias_fut_params,
dist=dist,
relative=relative,
sampling=sampling,
Expand All @@ -596,15 +591,11 @@ def _apply_qdm(
delta_denom_zero=delta_denom_zero,
delta_range=delta_range,
)

# input 3D shape (spatial, spatial, temporal)
# QDM expects input arr with shape (time, space)
tmp = np.reshape(subset.data, (-1, subset.shape[-1])).T
# Apply QDM correction
tmp = da.blockwise(
QDM, 'ij', tmp, 'ij', dtype=np.float32, max_workers=max_workers
)

tmp = QDM(tmp, max_workers=max_workers)
# Reorgnize array back from (time, space)
# to (spatial, spatial, temporal)
return da.reshape(tmp.T, subset.shape)
Expand Down Expand Up @@ -762,29 +753,29 @@ def local_qdm_bc(
cfg = params['cfg']
base_params = params['base']
bias_params = params['bias']
bias_fut_params = params['bias_fut']
bias_fut_params = params.get('bias_fut', None)

if lr_padded_slice is not None:
spatial_slice = (lr_padded_slice[0], lr_padded_slice[1])
base_params = base_params[spatial_slice]
bias_params = bias_params[spatial_slice]
bias_fut_params = bias_fut_params[spatial_slice]
if bias_fut_params is not None:
bias_fut_params = bias_fut_params[spatial_slice]

data_unbiased = da.full_like(data, np.nan)
closest_time_idx = abs(
cfg['time_window_center'][:, np.newaxis]
- np.array(time_index.day_of_year)
)
closest_time_idx = closest_time_idx.argmin(axis=0)
closest_time_idx = [
np.argmin(abs(d - cfg['time_window_center']))
for d in time_index.day_of_year
]

for nt in set(closest_time_idx):
subset_idx = closest_time_idx == nt
mf = None if bias_fut_params is None else bias_fut_params[:, :, nt]
subset = _apply_qdm(
subset=data[:, :, subset_idx],
base_params=base_params,
bias_params=bias_params,
bias_fut_params=bias_fut_params,
current_time_idx=nt,
base_params=base_params[:, :, nt],
bias_params=bias_params[:, :, nt],
bias_fut_params=mf,
dist=cfg.get('dist', 'empirical'),
sampling=cfg.get('sampling', 'linear'),
log_base=cfg.get('log_base', 10),
Expand Down Expand Up @@ -1073,23 +1064,23 @@ def local_presrat_bc(
spatial_slice = (lr_padded_slice[0], lr_padded_slice[1])
base_params = base_params[spatial_slice]
bias_params = bias_params[spatial_slice]
bias_fut_params = bias_fut_params[spatial_slice]
if bias_fut_params is not None:
bias_fut_params = bias_fut_params[spatial_slice]

data_unbiased = da.full_like(data, np.nan)
closest_time_idx = abs(
cfg['time_window_center'][:, np.newaxis]
- np.array(time_index.day_of_year)
)
closest_time_idx = closest_time_idx.argmin(axis=0)
closest_time_idx = [
np.argmin(abs(d - cfg['time_window_center']))
for d in time_index.day_of_year
]

for nt in set(closest_time_idx):
subset_idx = closest_time_idx == nt
mf = None if bias_fut_params is None else bias_fut_params[:, :, nt]
subset = _apply_qdm(
subset=data[:, :, subset_idx],
base_params=base_params,
bias_params=bias_params,
bias_fut_params=bias_fut_params,
current_time_idx=nt,
base_params=base_params[:, :, nt],
bias_params=bias_params[:, :, nt],
bias_fut_params=mf,
dist=cfg.get('dist', 'empirical'),
sampling=cfg.get('sampling', 'linear'),
log_base=cfg.get('log_base', 10),
Expand Down

0 comments on commit 9e40015

Please sign in to comment.