Skip to content

Commit

Permalink
Merge pull request #142 from NREL/gb/trh_loss
Browse files Browse the repository at this point in the history
Gb/trh loss
  • Loading branch information
grantbuster authored Jan 20, 2023
2 parents 9cfa20c + 2ff434e commit 60fab96
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 87 deletions.
17 changes: 10 additions & 7 deletions sup3r/models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Simple models for super resolution such as linear interp models."""
import numpy as np
import logging
from inspect import signature
import os
import json
from sup3r.utilities.utilities import st_interp
Expand Down Expand Up @@ -45,7 +46,9 @@ def load(cls, model_dir, verbose=False):
Parameters
----------
model_dir : str
Directory to load LinearInterp model files from.
Directory to load LinearInterp model files from. Must
have a model_params.json file containing "meta" key with all of the
class init args.
verbose : bool
Flag to log information about the loaded model.
Expand All @@ -59,11 +62,10 @@ def load(cls, model_dir, verbose=False):
with open(fp_params, 'r') as f:
params = json.load(f)

meta = params.get('meta', {'class': 'Sup3rGan'})
model = cls(features=meta['training_features'],
s_enhance=meta['s_enhance'],
t_enhance=meta['t_enhance'],
t_centered=meta['t_centered'])
meta = params['meta']
args = signature(cls.__init__).parameters
kwargs = {k: v for k, v in meta.items() if k in args}
model = cls(**kwargs)

if verbose:
logger.info('Loading LinearInterp with meta data: {}'
Expand All @@ -74,7 +76,8 @@ def load(cls, model_dir, verbose=False):
@property
def meta(self):
"""Get meta data dictionary that defines the model params"""
return {'s_enhance': self._s_enhance,
return {'features': self._features,
's_enhance': self._s_enhance,
't_enhance': self._t_enhance,
't_centered': self._t_centered,
'training_features': self.training_features,
Expand Down
148 changes: 82 additions & 66 deletions sup3r/models/surface.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# -*- coding: utf-8 -*-
"""Special models for surface meteorological data."""
import os
import json
import logging
from inspect import signature
from fnmatch import fnmatch
import numpy as np
from PIL import Image
from sklearn import linear_model
from warnings import warn

from sup3r.models.abstract import AbstractInterface
from sup3r.models.linear import LinearInterp
from sup3r.utilities.utilities import spatial_coarsening

logger = logging.getLogger(__name__)


class SurfaceSpatialMetModel(AbstractInterface):
class SurfaceSpatialMetModel(LinearInterp):
"""Model to spatially downscale daily-average near-surface temperature,
relative humidity, and pressure
Expand Down Expand Up @@ -43,7 +46,8 @@ class SurfaceSpatialMetModel(AbstractInterface):

def __init__(self, features, s_enhance, noise_adders=None,
temp_lapse=None, w_delta_temp=None, w_delta_topo=None,
pres_div=None, pres_exp=None):
pres_div=None, pres_exp=None, interp_method='LANCZOS',
fix_bias=True):
"""
Parameters
----------
Expand Down Expand Up @@ -85,6 +89,15 @@ def __init__(self, features, s_enhance, noise_adders=None,
pres_div : None | float
Exponential factor in the pressure scale height equation. Defaults
to the cls.PRES_EXP attribute.
interp_method : str
Name of the interpolation method to use from PIL.Image.Resampling
(NEAREST, BILINEAR, BICUBIC, LANCZOS)
LANCZOS is default and has been tested to work best for
SurfaceSpatialMetModel.
fix_bias : bool
Some local bias can be introduced by the bilinear interp + lapse
rate, this flag will attempt to correct that bias by using the
low-resolution deviation from the input data
"""

self._features = features
Expand All @@ -95,6 +108,8 @@ def __init__(self, features, s_enhance, noise_adders=None,
self._w_delta_topo = w_delta_topo or self.W_DELTA_TOPO
self._pres_div = pres_div or self.PRES_DIV
self._pres_exp = pres_exp or self.PRES_EXP
self._fix_bias = fix_bias
self._interp_method = getattr(Image.Resampling, interp_method)

if isinstance(self._noise_adders, (int, float)):
self._noise_adders = [self._noise_adders] * len(self._features)
Expand All @@ -103,42 +118,6 @@ def __len__(self):
"""Get number of model steps (match interface of MultiStepGan)"""
return 1

@classmethod
def load(cls, features, s_enhance, verbose=False, **kwargs):
"""Load the GAN with its sub-networks from a previously saved-to output
directory.
Parameters
----------
features : list
List of feature names that this model will operate on for both
input and output. This must match the feature axis ordering in the
array input to generate(). Typically this is a list containing:
temperature_*m, relativehumidity_*m, and pressure_*m. The list can
contain multiple instances of each variable at different heights.
relativehumidity_*m entries must have corresponding temperature_*m
entires at the same hub height.
s_enhance : int
Integer factor by which the spatial axes are to be enhanced.
verbose : bool
Flag to log information about the loaded model.
kwargs : None | dict
Optional kwargs to initialize SurfaceSpatialMetModel
Returns
-------
out : SurfaceSpatialMetModel
Returns an initialized SurfaceSpatialMetModel
"""

model = cls(features, s_enhance, **kwargs)

if verbose:
logger.info('Loading SurfaceSpatialMetModel with meta data: {}'
.format(model.meta))

return model

@staticmethod
def _get_s_enhance(topo_lr, topo_hr):
"""Get the spatial enhancement factor given low-res and high-res
Expand Down Expand Up @@ -227,8 +206,39 @@ def _get_temp_rh_ind(self, idf_rh):

return idf_temp

def _fix_downscaled_bias(self, single_lr, single_hr,
method=Image.Resampling.LANCZOS):
"""Fix any bias introduced by the spatial downscaling with lapse rate.
Parameters
----------
single_lr : np.ndarray
Single timestep raster data with shape
(lat, lon) matching the low-resolution input data.
single_hr : np.ndarray
Single timestep downscaled raster data with shape
(lat, lon) matching the high-resolution input data.
method : Image.Resampling.LANCZOS
An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS).
NEAREST enforces zero bias but makes slightly more spatial seams.
Returns
-------
single_hr : np.ndarray
Single timestep downscaled raster data with shape
(lat, lon) matching the high-resolution input data.
"""

re_coarse = spatial_coarsening(np.expand_dims(single_hr, axis=-1),
s_enhance=self._s_enhance,
obs_axis=False)[..., 0]
bias = re_coarse - single_lr
bc = self.downscale_arr(bias, s_enhance=self._s_enhance, method=method)
single_hr -= bc
return single_hr

@staticmethod
def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
def downscale_arr(arr, s_enhance, method=Image.Resampling.LANCZOS):
"""Downscale a 2D array of data Image.resize() method
Parameters
Expand All @@ -238,9 +248,9 @@ def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
(lat, lon)
s_enhance : int
Integer factor by which the spatial axes are to be enhanced.
method : Image.Resampling.BILINEAR
method : Image.Resampling.LANCZOS
An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS).
BILINEAR is default and has been tested to work best for
LANCZOS is default and has been tested to work best for
SurfaceSpatialMetModel.
"""
im = Image.fromarray(arr)
Expand Down Expand Up @@ -284,9 +294,15 @@ def downscale_temp(self, single_lr_temp, topo_lr, topo_hr):
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'

lower_data = single_lr_temp.copy() + topo_lr * self._temp_lapse
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance)
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance,
method=self._interp_method)
hi_res_temp -= topo_hr * self._temp_lapse

if self._fix_bias:
hi_res_temp = self._fix_downscaled_bias(single_lr_temp,
hi_res_temp,
method=self._interp_method)

return hi_res_temp

def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
Expand Down Expand Up @@ -336,9 +352,12 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
assert len(topo_lr.shape) == 2, 'Bad shape for topo_lr'
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'

interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance)
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance)
interp_topo = self.downscale_arr(topo_lr, self._s_enhance)
interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance,
method=self._interp_method)
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance,
method=self._interp_method)
interp_topo = self.downscale_arr(topo_lr, self._s_enhance,
method=self._interp_method)

delta_temp = single_hr_temp - interp_temp
delta_topo = topo_hr - interp_topo
Expand All @@ -347,6 +366,10 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
+ self._w_delta_temp * delta_temp
+ self._w_delta_topo * delta_topo)

if self._fix_bias:
hi_res_rh = self._fix_downscaled_bias(single_lr_rh, hi_res_rh,
method=self._interp_method)

return hi_res_rh

def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
Expand Down Expand Up @@ -388,21 +411,28 @@ def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
warn(msg)

const = 101325 * (1 - (1 - topo_lr / self._pres_div)**self._pres_exp)
single_lr_pres = single_lr_pres.copy() + const
lr_pres_adj = single_lr_pres.copy() + const

if np.min(single_lr_pres) < 0.0:
if np.min(lr_pres_adj) < 0.0:
msg = ('Spatial interpolation of surface pressure '
'resulted in negative values. Incorrectly '
'scaled/unscaled values or incorrect units are '
'the most likely causes.')
'the most likely causes. All pressure data should be '
'in Pascals.')
logger.error(msg)
raise ValueError(msg)

hi_res_pres = self.downscale_arr(single_lr_pres, self._s_enhance)
hi_res_pres = self.downscale_arr(lr_pres_adj, self._s_enhance,
method=self._interp_method)

const = 101325 * (1 - (1 - topo_hr / self._pres_div)**self._pres_exp)
hi_res_pres -= const

if self._fix_bias:
hi_res_pres = self._fix_downscaled_bias(single_lr_pres,
hi_res_pres,
method=self._interp_method)

if np.min(hi_res_pres) < 0.0:
msg = ('Spatial interpolation of surface pressure '
'resulted in negative values. Incorrectly '
Expand Down Expand Up @@ -524,25 +554,11 @@ def meta(self):
'pressure_exponent': self._pres_exp,
'training_features': self.training_features,
'output_features': self.output_features,
'interp_method': str(self._interp_method),
'fix_bias': self._fix_bias,
'class': self.__class__.__name__,
}

@property
def training_features(self):
"""Get the list of input feature names that the generative model was
trained on.
Note that topography needs to be passed into generate() as an exogenous
data input.
"""
return self._features

@property
def output_features(self):
"""Get the list of output feature names that the generative model
outputs"""
return self._features

def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
"""This method trains the relative humidity linear model. The
temperature and surface lapse rate models are parameterizations taken
Expand Down
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ def preflight(self):
msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger '
'than the number of time steps in the raw data '
f'({len(self.raw_time_index)}).')
if len(self.raw_time_index) >= self.sample_shape[2]:
if len(self.raw_time_index) < self.sample_shape[2]:
logger.warning(msg)
warnings.warn(msg)

Expand Down
38 changes: 37 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Loss metrics for Sup3r"""

from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
import tensorflow as tf


Expand Down Expand Up @@ -171,3 +171,39 @@ def __call__(self, x1, x2):
x1_coarse = tf.reduce_mean(x1, axis=(1, 2))
x2_coarse = tf.reduce_mean(x2, axis=(1, 2))
return self.MSE_LOSS(x1_coarse, x2_coarse)


class TemporalExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
timeseries"""

MAE_LOSS = MeanAbsoluteError()

def __call__(self, x1, x2):
"""Custom content loss that encourages temporal min/max accuracy
Parameters
----------
x1 : tf.tensor
synthetic generator output
(n_observations, spatial_1, spatial_2, temporal, features)
x2 : tf.tensor
high resolution data
(n_observations, spatial_1, spatial_2, temporal, features)
Returns
-------
tf.tensor
0D tensor with loss value
"""
x1_min = tf.reduce_min(x1, axis=3)
x2_min = tf.reduce_min(x2, axis=3)

x1_max = tf.reduce_max(x1, axis=3)
x2_max = tf.reduce_max(x2, axis=3)

mae = self.MAE_LOSS(x1, x2)
mae_min = self.MAE_LOSS(x1_min, x2_min)
mae_max = self.MAE_LOSS(x1_max, x2_max)

return mae + mae_min + mae_max
2 changes: 1 addition & 1 deletion sup3r/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""SUP3R Version"""

__version__ = '0.0.8'
__version__ = '0.0.9'
Loading

0 comments on commit 60fab96

Please sign in to comment.