diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py
index 9c10b85f5..8dd06ad0b 100644
--- a/sup3r/preprocessing/batch_handling.py
+++ b/sup3r/preprocessing/batch_handling.py
@@ -618,7 +618,8 @@ def _parallel_normalization(self):
         max_workers = self.norm_workers
         if max_workers == 1:
             for dh in self.data_handlers:
-                dh.normalize(self.means, self.stds)
+                dh.normalize(self.means, self.stds,
+                             max_workers=dh.norm_workers)
         else:
             with ThreadPoolExecutor(max_workers=max_workers) as exe:
                 futures = {}
@@ -691,7 +692,8 @@ def _get_stats(self):
                     future = exe.submit(dh._get_stats)
                     futures[future] = idh

-                for i, _ in enumerate(as_completed(futures)):
+                for i, future in enumerate(as_completed(futures)):
+                    _ = future.result()
                     logger.debug(f'{i+1} out of {len(self.data_handlers)} '
                                  'means calculated.')

@@ -731,10 +733,10 @@ def check_cached_stats(self):
         means_check = means_check and os.path.exists(self.means_file)
         if stdevs_check and means_check:
             logger.info(f'Loading stdevs from {self.stdevs_file}')
-            with open(self.stdevs_file, 'r') as fh:
+            with open(self.stdevs_file) as fh:
                 self.stds = json.load(fh)
             logger.info(f'Loading means from {self.means_file}')
-            with open(self.means_file, 'r') as fh:
+            with open(self.means_file) as fh:
                 self.means = json.load(fh)

             msg = ('The training features and cached statistics are '
@@ -777,8 +779,7 @@ def _get_feature_means(self, feature):
         feature : str
             Feature to get mean for
         """
-
-        logger.debug(f'Calculating mean for {feature}')
+        logger.debug(f'Calculating multi-handler mean for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             self.means[feature] += (self.handler_weights[idh]
                                     * dh.means[feature])
@@ -798,7 +799,7 @@ def _get_feature_stdev(self, feature):
             Feature to get stdev for
         """

-        logger.debug(f'Calculating stdev for {feature}')
+        logger.debug(f'Calculating multi-handler stdev for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             variance = dh.stds[feature]**2
             self.stds[feature] += (variance * self.handler_weights[idh])
@@ -823,6 +824,9 @@ def normalize(self, means=None, stds=None):
             feature names and values: standard deviations. if None, this
             will be calculated. if norm is true these will be used for
             data normalization
+        features : list | None
+            Optional list of features used to index data array during
+            normalization. If this is None, self.features will be used.
""" if means is None or stds is None: self.get_stats() diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 414be799d..acbbb05c2 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -47,6 +47,7 @@ estimate_max_workers, get_chunk_slices, get_raster_shape, + nn_fill_array, np_to_pd_times, spatial_coarsening, uniform_box_sampler, @@ -449,17 +450,6 @@ def load_workers(self): n_procs) return load_workers - @property - def norm_workers(self): - """Get upper bound on workers used for normalization.""" - if self.data is not None: - norm_workers = estimate_max_workers(self._norm_workers, - 2 * self.feature_mem, - self.shape[-1]) - else: - norm_workers = self._norm_workers - return norm_workers - @property def time_chunks(self): """Get time chunks which will be extracted from source data @@ -543,8 +533,7 @@ def get_handle_features(cls, file_paths): handle_features = [] for f in file_paths: handle = cls.source_handler([f]) - for r in handle: - handle_features.append(Feature.get_basename(r)) + handle_features += [Feature.get_basename(r) for r in handle] return list(set(handle_features)) @property @@ -921,72 +910,6 @@ def get_cache_file_names(self, target, features) - @property - def means(self): - """Get the mean values for each feature. - - Returns - ------- - dict - """ - self._get_stats() - return self._means - - @property - def stds(self): - """Get the standard deviation values for each feature. - - Returns - ------- - dict - """ - self._get_stats() - return self._stds - - def _get_stats(self): - if self._means is None or self._stds is None: - msg = (f'DataHandler has {len(self.features)} features ' - f'and mismatched shape of {self.shape}') - assert len(self.features) == self.shape[-1], msg - self._stds = {} - self._means = {} - for idf, fname in enumerate(self.features): - self._means[fname] = np.nanmean(self.data[..., idf]) - self._stds[fname] = np.nanstd(self.data[..., idf]) - - def normalize(self, means=None, stds=None, max_workers=None): - """Normalize all data features. - - Parameters - ---------- - means : dict | none - Dictionary of means for all features with keys: feature names and - values: mean values. If this is None, the self.means attribute will - be used. If this is not None, this DataHandler object means - attribute will be updated. - stds : dict | none - dictionary of standard deviation values for all features with keys: - feature names and values: standard deviations. If this is None, the - self.stds attribute will be used. If this is not None, this - DataHandler object stds attribute will be updated. - max_workers : None | int - Max workers to perform normalization. if None, self.norm_workers - will be used - """ - if means is not None: - self._means = means - if stds is not None: - self._stds = stds - - max_workers = max_workers or self.norm_workers - if self._is_normalized: - logger.info('Skipping DataHandler, already normalized') - else: - self._normalize(self.data, - self.val_data, - max_workers=max_workers) - self._is_normalized = True - def get_next(self): """Get data for observation using random observation index. 

         Loops repeatedly over randomized time index
@@ -1159,7 +1082,7 @@ def run_all_data_init(self):
             self.run_data_compute()

         logger.info('Building final data array')
-        self.parallel_data_fill(shifted_time_chunks, self.extract_workers)
+        self.data_fill(shifted_time_chunks, self.extract_workers)

         if self.invert_lat:
             self.data = self.data[::-1]
@@ -1182,8 +1105,16 @@ def run_all_data_init(self):
         logger.info(f'Finished extracting data for {self.input_file_info} in '
                     f'{dt.now() - now}')
+
+        self.run_nn_fill()
         return self.data

+    def run_nn_fill(self):
+        """Run nearest-neighbor nan fill on the full data array."""
+        for i in range(self.data.shape[-1]):
+            if np.isnan(self.data[..., i]).any():
+                self.data[..., i] = nn_fill_array(self.data[..., i])
+
     def run_data_extraction(self):
         """Run the raw dataset extraction process from disk to raw
         un-manipulated datasets.
@@ -1238,7 +1169,7 @@ def run_data_compute(self):
         logger.info(f'Finished computing {self.derive_features} for '
                     f'{self.input_file_info}')

-    def data_fill(self, t, t_slice, f_index, f):
+    def _single_data_fill(self, t, t_slice, f_index, f):
         """Place single extracted / computed chunk in final data array

         Parameters
@@ -1269,14 +1200,12 @@ def serial_data_fill(self, shifted_time_chunks):
         for t, ts in enumerate(shifted_time_chunks):
             for _, f in enumerate(self.noncached_features):
                 f_index = self.features.index(f)
-                self.data_fill(t, ts, f_index, f)
-            interval = int(np.ceil(len(shifted_time_chunks) / 10))
-            if t % interval == 0:
-                logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} '
-                            'chunks to final data array')
+                self._single_data_fill(t, ts, f_index, f)
+            logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} '
+                        'chunks to final data array')
             self._raw_data.pop(t)

-    def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
+    def data_fill(self, shifted_time_chunks, max_workers=None):
         """Fill final data array with extracted / computed chunks

         Parameters
@@ -1304,13 +1233,13 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
             for t, ts in enumerate(shifted_time_chunks):
                 for _, f in enumerate(self.noncached_features):
                     f_index = self.features.index(f)
-                    future = exe.submit(self.data_fill, t, ts, f_index, f)
+                    future = exe.submit(self._single_data_fill,
+                                        t, ts, f_index, f)
                     futures[future] = {'t': t, 'fidx': f_index}

             logger.info(f'Started adding {len(futures)} chunks '
                         f'to data array in {dt.now() - now}.')

-            interval = int(np.ceil(len(futures) / 10))
             for i, future in enumerate(as_completed(futures)):
                 try:
                     future.result()
@@ -1320,9 +1249,8 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None):
                         'final data array.')
                     logger.exception(msg)
                     raise RuntimeError(msg) from e
-                if i % interval == 0:
-                    logger.debug(f'Added {i+1} out of {len(futures)} '
-                                 'chunks to final data array')
+                logger.debug(f'Added {i+1} out of {len(futures)} '
+                             'chunks to final data array')

         logger.info('Finished building data array')

     @abstractmethod
diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py
index b040e78cb..affcc95dc 100644
--- a/sup3r/preprocessing/data_handling/dual_data_handling.py
+++ b/sup3r/preprocessing/data_handling/dual_data_handling.py
@@ -72,28 +72,30 @@ def __init__(self,
         self.t_enhance = t_enhance
         self.lr_dh = lr_handler
         self.hr_dh = hr_handler
-        self._cache_pattern = cache_pattern
-        self._cached_features = None
-        self._noncached_features = None
         self.overwrite_cache = overwrite_cache
         self.val_split = val_split
         self.current_obs_index = None
         self.load_cached = load_cached
         self.regrid_workers = regrid_workers
         self.shuffle_time = shuffle_time
-        self._lr_lat_lon = None
-        self._hr_lat_lon = None
-        self._lr_input_data = None
         self.hr_data = None
         self.lr_val_data = None
         self.hr_val_data = None
-        self.lr_time_index = None
-        self.hr_time_index = None
-        self.lr_val_time_index = None
-        self.hr_val_time_index = None
-
-        lr_data_shape = (*self.lr_required_shape, len(self.lr_dh.features))
-        self.lr_data = np.zeros(lr_data_shape, dtype=np.float32)
+        self.lr_data = np.zeros(self.shape, dtype=np.float32)
+        self.lr_time_index = lr_handler.time_index
+        self.hr_time_index = hr_handler.time_index
+        self.lr_val_time_index = lr_handler.val_time_index
+        self.hr_val_time_index = hr_handler.val_time_index
+        self._lr_lat_lon = None
+        self._hr_lat_lon = None
+        self._lr_input_data = None
+        self._cache_pattern = cache_pattern
+        self._cached_features = None
+        self._noncached_features = None
+        self._means = None
+        self._stds = None
+        self._is_normalized = False
+        self._norm_workers = self.lr_dh.norm_workers

         if self.try_load and self.load_cached:
             self.load_cached_data()
@@ -163,7 +165,7 @@ def _val_split_check(self):

     def _get_stats(self):
         """Get mean/stdev stats for HR and LR data handlers"""
-        self.lr_dh._get_stats()
+        super()._get_stats(features=self.lr_dh.features)
         self.hr_dh._get_stats()

     @property
@@ -177,7 +179,7 @@ def means(self):
         dict
         """
         out = copy.deepcopy(self.hr_dh.means)
-        out.update(self.lr_dh.means)
+        out.update(super().means)
         return out

     @property
@@ -191,9 +193,10 @@ def stds(self):
         dict
         """
         out = copy.deepcopy(self.hr_dh.stds)
-        out.update(self.lr_dh.stds)
+        out.update(super().stds)
         return out

+    # pylint: disable=unused-argument
     def normalize(self, means=None, stds=None, max_workers=None):
         """Normalize low_res and high_res data

@@ -210,8 +213,7 @@ def normalize(self, means=None, stds=None, max_workers=None):
             self.stds attribute will be used. If this is not None, this
             DataHandler object stds attribute will be updated.
         max_workers : None | int
-            Max workers to perform normalization. if None, self.norm_workers
-            will be used
+            Has no effect. Used to match MixIn class signature.
         """
         if means is None:
             means = self.means
@@ -219,10 +221,14 @@ def normalize(self, means=None, stds=None, max_workers=None):
             stds = self.stds
         logger.info('Normalizing low resolution data features='
                     f'{self.lr_dh.features}')
-        self.lr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
+        super().normalize(means=means, stds=stds,
+                          features=self.lr_dh.features,
+                          max_workers=self.lr_dh.norm_workers)
         logger.info('Normalizing high resolution data features='
                     f'{self.hr_dh.features}')
-        self.hr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
+        self.hr_dh.normalize(means=means, stds=stds,
+                             features=self.hr_dh.features,
+                             max_workers=self.hr_dh.norm_workers)

     @property
     def features(self):
@@ -269,8 +275,8 @@ def _shape_check(self):
             self.hr_dh.load_cached_data(with_split=False)

         msg = (f'hr_handler.shape {self.hr_dh.shape[:-1]} is not divisible '
-               f'by s_enhance. Using shape = {self.hr_required_shape} '
-               'instead.')
+               f'by s_enhance ({self.s_enhance}). Using shape = '
+               f'{self.hr_required_shape} instead.')
         if self.hr_dh.shape[:-1] != self.hr_required_shape:
             logger.warning(msg)
             warn(msg)
@@ -364,9 +370,15 @@ def hr_sample_shape(self):
     @property
     def data(self):
         """Get low res data. Same as self.lr_data but used to match property
-        used by batch handler for computing means and stdevs"""
+        used for computing means and stdevs"""
         return self.lr_data

+    @property
+    def val_data(self):
+        """Get low res validation data. Same as self.lr_val_data but used to
+        match property used by normalization routine."""
+        return self.lr_val_data
+
     @property
     def lr_input_data(self):
         """Get low res data used as input to regridding routine"""
@@ -406,11 +418,6 @@ def lr_grid_shape(self):
         """Return grid shape for regridded low_res data"""
         return (self.lr_required_shape[0], self.lr_required_shape[1])

-    @property
-    def lr_requested_shape(self):
-        """Return requested shape for low_res data"""
-        return (*self.lr_required_shape, len(self.features))
-
     @property
     def lr_lat_lon(self):
         """Get low_res lat lon array"""
@@ -472,10 +479,10 @@ def load_lr_cached_data(self):
         """Load low_res cache data"""

         logger.info(
-            f'Loading cache with requested_shape={self.lr_requested_shape}.')
+            f'Loading cache with requested_shape={self.shape}.')
         self._load_cached_data(self.lr_data,
                                self.cache_files,
-                               self.features,
+                               self.lr_dh.features,
                                max_workers=self.hr_dh.load_workers)

     def load_cached_data(self):
diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py
index 0d071a082..2d3037c5d 100644
--- a/sup3r/preprocessing/data_handling/mixin.py
+++ b/sup3r/preprocessing/data_handling/mixin.py
@@ -17,6 +17,7 @@
 from scipy.stats import mode

 from sup3r.utilities.utilities import (
+    estimate_max_workers,
     get_source_type,
     ignore_case_path_fetch,
     uniform_box_sampler,
@@ -887,7 +888,8 @@ def time_index_file(self):
             return None

         if self.cache_pattern is not None and self._time_index_file is None:
-            basename = self.cache_pattern.replace('{times}', '')
+            basename = self.cache_pattern.replace('_{times}', '')
+            basename = basename.replace('{times}', '')
             basename = basename.replace('{shape}', str(len(self.file_paths)))
             basename = basename.replace('_{target}', '')
             basename = basename.replace('{feature}', 'time_index')
@@ -923,8 +925,14 @@ class TrainingPrepMixIn:
     def __init__(self):
         """Initialize common attributes"""
         self.features = None
-        self.means = None
-        self.stds = None
+        self.data = None
+        self.val_data = None
+        self.feature_mem = None
+        self.shape = None
+        self._means = None
+        self._stds = None
+        self._is_normalized = False
+        self._norm_workers = None

     @classmethod
     def _split_data_indices(cls,
@@ -1030,7 +1038,7 @@ def _normalize_data(self, data, val_data, feature_index, mean, std):
         logger.debug(f'Finished normalizing {self.features[feature_index]} '
                      f'with mean {mean:.3e} and std {std:.3e}.')

-    def _normalize(self, data, val_data, max_workers=None):
+    def _normalize(self, data, val_data, features=None, max_workers=None):
         """Normalize all data features

         Parameters
@@ -1041,27 +1049,31 @@ def _normalize(self, data, val_data, max_workers=None):
         val_data : np.ndarray
             Array of validation data.
             (spatial_1, spatial_2, temporal, n_features)
+        features : list | None
+            List of features used for indexing data array during normalization.
         max_workers : int | None
             Number of workers to use in thread pool for normalization.
""" + if features is None: + features = self.features - msg1 = (f'Not all feature names {self.features} were found in ' + msg1 = (f'Not all feature names {features} were found in ' f'self.means: {list(self.means.keys())}') - msg2 = (f'Not all feature names {self.features} were found in ' + msg2 = (f'Not all feature names {features} were found in ' f'self.stds: {list(self.stds.keys())}') - assert all(fn in self.means for fn in self.features), msg1 - assert all(fn in self.stds for fn in self.features), msg2 + assert all(fn in self.means for fn in features), msg1 + assert all(fn in self.stds for fn in features), msg2 - logger.info(f'Normalizing {data.shape[-1]} features: {self.features}') + logger.info(f'Normalizing {data.shape[-1]} features: {features}') if max_workers == 1: - for idf, feature in enumerate(self.features): + for idf, feature in enumerate(features): self._normalize_data(data, val_data, idf, self.means[feature], self.stds[feature]) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = [] - for idf, feature in enumerate(self.features): + for idf, feature in enumerate(features): future = exe.submit(self._normalize_data, data, val_data, idf, self.means[feature], @@ -1076,3 +1088,88 @@ def _normalize(self, data, val_data, max_workers=None): f'{futures[future]}.') logger.exception(msg) raise RuntimeError(msg) from e + + @property + def means(self): + """Get the mean values for each feature. + + Returns + ------- + dict + """ + self._get_stats() + return self._means + + @property + def stds(self): + """Get the standard deviation values for each feature. + + Returns + ------- + dict + """ + self._get_stats() + return self._stds + + def _get_stats(self, features=None): + """Get the mean/stdev for each feature in the data handler.""" + if features is None: + features = self.features + if self._means is None or self._stds is None: + msg = (f'DataHandler has {len(features)} features ' + f'and mismatched shape of {self.shape}') + assert len(features) == self.shape[-1], msg + self._stds = {} + self._means = {} + for idf, fname in enumerate(features): + self._means[fname] = np.nanmean( + self.data[..., idf].astype(np.float64)) + self._stds[fname] = np.nanstd( + self.data[..., idf].astype(np.float64)) + + def normalize(self, means=None, stds=None, features=None, + max_workers=None): + """Normalize all data features. + + Parameters + ---------- + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. If this is None, the self.means attribute will + be used. If this is not None, this DataHandler object means + attribute will be updated. + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. If this is None, the + self.stds attribute will be used. If this is not None, this + DataHandler object stds attribute will be updated. + features : list | None + List of features used for indexing data array during normalization. + max_workers : None | int + Max workers to perform normalization. 
+            will be used
+        """
+        if means is not None:
+            self._means = means
+        if stds is not None:
+            self._stds = stds
+
+        if self._is_normalized:
+            logger.info('Skipping DataHandler, already normalized')
+        else:
+            self._normalize(self.data,
+                            self.val_data,
+                            features=features,
+                            max_workers=max_workers)
+            self._is_normalized = True
+
+    @property
+    def norm_workers(self):
+        """Get upper bound on workers used for normalization."""
+        if self.data is not None:
+            norm_workers = estimate_max_workers(self._norm_workers,
+                                                2 * self.feature_mem,
+                                                self.shape[-1])
+        else:
+            norm_workers = self._norm_workers
+        return norm_workers
diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py
index fee3b8827..af2b39941 100644
--- a/sup3r/utilities/era_downloader.py
+++ b/sup3r/utilities/era_downloader.py
@@ -54,7 +54,11 @@ class EraDownloader:
         'total_precipitation', "convective_available_potential_energy",
         "2m_dewpoint_temperature", "convective_inhibition",
         "surface_latent_heat_flux", "instantaneous_moisture_flux",
-        "mean_total_precipitation_rate"
+        "mean_total_precipitation_rate", "mean_sea_level_pressure",
+        "friction_velocity", "lake_cover", "high_vegetation_cover",
+        "land_sea_mask", "k_index", "forecast_surface_roughness",
+        "northward_turbulent_surface_stress",
+        "eastward_turbulent_surface_stress",
     ]

     # variables available on multiple pressure levels
diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py
index bb5c53296..21d87d5f0 100644
--- a/sup3r/utilities/loss_metrics.py
+++ b/sup3r/utilities/loss_metrics.py
@@ -1,5 +1,6 @@
 """Loss metrics for Sup3r"""

+import numpy as np
 import tensorflow as tf
 from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError

@@ -144,7 +145,7 @@ def __call__(self, x1, x2, sigma=1.0):
         """
         mmd = self.MMD_LOSS(x1, x2, sigma=sigma)
         mse = self.MSE_LOSS(x1, x2)
-        return mmd + mse
+        return (mmd + mse) / 2


 class CoarseMseLoss(tf.keras.losses.Loss):
@@ -175,11 +176,47 @@ def __call__(self, x1, x2):
         return self.MSE_LOSS(x1_coarse, x2_coarse)


+class SpatialExtremesOnlyLoss(tf.keras.losses.Loss):
+    """Loss class that encourages accuracy of the min/max values in the
+    spatial domain. This does not include an additional MAE term"""
+
+    MAE_LOSS = MeanAbsoluteError()
+
+    def __call__(self, x1, x2):
+        """Custom content loss that encourages spatial min/max accuracy
+
+        Parameters
+        ----------
+        x1 : tf.tensor
+            synthetic generator output
+            (n_observations, spatial_1, spatial_2, features)
+        x2 : tf.tensor
+            high resolution data
+            (n_observations, spatial_1, spatial_2, features)
+
+        Returns
+        -------
+        tf.tensor
+            0D tensor with loss value
+        """
+        x1_min = tf.reduce_min(x1, axis=(1, 2))
+        x2_min = tf.reduce_min(x2, axis=(1, 2))
+
+        x1_max = tf.reduce_max(x1, axis=(1, 2))
+        x2_max = tf.reduce_max(x2, axis=(1, 2))
+
+        mae_min = self.MAE_LOSS(x1_min, x2_min)
+        mae_max = self.MAE_LOSS(x1_max, x2_max)
+
+        return (mae_min + mae_max) / 2
+
+
 class SpatialExtremesLoss(tf.keras.losses.Loss):
     """Loss class that encourages accuracy of the min/max values in the
     spatial domain"""

     MAE_LOSS = MeanAbsoluteError()
+    EX_LOSS = SpatialExtremesOnlyLoss()

     def __init__(self, weight=1.0):
         """Initialize the loss with given weight
@@ -210,17 +247,45 @@ def __call__(self, x1, x2):
         tf.tensor
             0D tensor with loss value
         """
-        x1_min = tf.reduce_min(x1, axis=(1, 2))
-        x2_min = tf.reduce_min(x2, axis=(1, 2))
+        mae = self.MAE_LOSS(x1, x2)
+        ex_mae = self.EX_LOSS(x1, x2)

-        x1_max = tf.reduce_max(x1, axis=(1, 2))
-        x2_max = tf.reduce_max(x2, axis=(1, 2))
+        return (mae + 2 * self._weight * ex_mae) / 3
+
+
+class TemporalExtremesOnlyLoss(tf.keras.losses.Loss):
+    """Loss class that encourages accuracy of the min/max values in the
+    timeseries. This does not include an additional MAE term"""
+
+    MAE_LOSS = MeanAbsoluteError()
+
+    def __call__(self, x1, x2):
+        """Custom content loss that encourages temporal min/max accuracy
+
+        Parameters
+        ----------
+        x1 : tf.tensor
+            synthetic generator output
+            (n_observations, spatial_1, spatial_2, temporal, features)
+        x2 : tf.tensor
+            high resolution data
+            (n_observations, spatial_1, spatial_2, temporal, features)
+
+        Returns
+        -------
+        tf.tensor
+            0D tensor with loss value
+        """
+        x1_min = tf.reduce_min(x1, axis=3)
+        x2_min = tf.reduce_min(x2, axis=3)
+
+        x1_max = tf.reduce_max(x1, axis=3)
+        x2_max = tf.reduce_max(x2, axis=3)

-        mae = self.MAE_LOSS(x1, x2)
         mae_min = self.MAE_LOSS(x1_min, x2_min)
         mae_max = self.MAE_LOSS(x1_max, x2_max)

-        return mae + self._weight * (mae_min + mae_max)
+        return (mae_min + mae_max) / 2


 class TemporalExtremesLoss(tf.keras.losses.Loss):
@@ -228,6 +293,7 @@ class TemporalExtremesLoss(tf.keras.losses.Loss):
     timeseries"""

     MAE_LOSS = MeanAbsoluteError()
+    EX_LOSS = TemporalExtremesOnlyLoss()

     def __init__(self, weight=1.0):
         """Initialize the loss with given weight
@@ -258,23 +324,20 @@ def __call__(self, x1, x2):
         tf.tensor
             0D tensor with loss value
         """
-        x1_min = tf.reduce_min(x1, axis=3)
-        x2_min = tf.reduce_min(x2, axis=3)
-
-        x1_max = tf.reduce_max(x1, axis=3)
-        x2_max = tf.reduce_max(x2, axis=3)
-
         mae = self.MAE_LOSS(x1, x2)
-        mae_min = self.MAE_LOSS(x1_min, x2_min)
-        mae_max = self.MAE_LOSS(x1_max, x2_max)
+        ex_mae = self.EX_LOSS(x1, x2)

-        return mae + self._weight * (mae_min + mae_max)
+        return (mae + 2 * self._weight * ex_mae) / 3


 class SpatiotemporalExtremesLoss(tf.keras.losses.Loss):
     """Loss class that encourages accuracy of the min/max values across both
     space and time"""

+    MAE_LOSS = MeanAbsoluteError()
+    S_EX_LOSS = SpatialExtremesOnlyLoss()
+    T_EX_LOSS = TemporalExtremesOnlyLoss()
+
     def __init__(self,
@@ -286,13 +349,11 @@ def __init__(self,
                 spatial_weight=1.0, temporal_weight=1.0):
             Weight for temporal min/max loss terms.
         """
         super().__init__()
-        self.sp_ex_loss = SpatialExtremesLoss(2 * temporal_weight)
-        self.temp_ex_loss = TemporalExtremesLoss(2 * spatial_weight)
+        self.s_weight = spatial_weight
+        self.t_weight = temporal_weight

     def __call__(self, x1, x2):
         """Custom content loss that encourages spatiotemporal min/max
         accuracy.
-        This is computed as 1/2 times the sum of spatial and temporal extremes
-        loss functions with doubled weights.

         Parameters
@@ -308,4 +369,146 @@ def __call__(self, x1, x2):
         tf.tensor
             0D tensor with loss value
         """
-        return 0.5 * (self.sp_ex_loss(x1, x2) + self.temp_ex_loss(x1, x2))
+        mae = self.MAE_LOSS(x1, x2)
+        s_ex_mae = self.S_EX_LOSS(x1, x2)
+        t_ex_mae = self.T_EX_LOSS(x1, x2)
+        return (mae + 2 * self.s_weight * s_ex_mae
+                + 2 * self.t_weight * t_ex_mae) / 5
+
+
+class SpatialFftOnlyLoss(tf.keras.losses.Loss):
+    """Loss class that encourages accuracy of the spatial frequency spectrum"""
+
+    MAE_LOSS = MeanAbsoluteError()
+
+    @staticmethod
+    def _freq_weights(x):
+        """Get product of squared frequencies to weight frequency amplitudes"""
+        k0 = np.array([k**2 for k in range(x.shape[1])])
+        k1 = np.array([k**2 for k in range(x.shape[2])])
+        freqs = np.multiply.outer(k0, k1)
+        freqs = tf.convert_to_tensor(freqs[np.newaxis, ..., np.newaxis])
+        return tf.cast(freqs, x.dtype)
+
+    def _fft(self, x):
+        """Apply needed transpositions and fft operation."""
+        x_hat = tf.transpose(x, perm=[3, 0, 1, 2])
+        x_hat = tf.signal.fft2d(tf.cast(x_hat, tf.complex64))
+        x_hat = tf.transpose(x_hat, perm=[1, 2, 3, 0])
+        x_hat = tf.cast(tf.abs(x_hat), x.dtype)
+        x_hat = tf.math.multiply(self._freq_weights(x), x_hat)
+        return tf.math.log(1 + x_hat)
+
+    def __call__(self, x1, x2):
+        """Custom content loss that encourages frequency domain accuracy
+
+        Parameters
+        ----------
+        x1 : tf.tensor
+            synthetic generator output
+            (n_observations, spatial_1, spatial_2, features)
+        x2 : tf.tensor
+            high resolution data
+            (n_observations, spatial_1, spatial_2, features)
+
+        Returns
+        -------
+        tf.tensor
+            0D tensor with loss value
+        """
+        x1_hat = self._fft(x1)
+        x2_hat = self._fft(x2)
+        return self.MAE_LOSS(x1_hat, x2_hat)
+
+
+class SpatiotemporalFftOnlyLoss(tf.keras.losses.Loss):
+    """Loss class that encourages accuracy of the spatiotemporal frequency
+    spectrum"""
+
+    MAE_LOSS = MeanAbsoluteError()
+
+    @staticmethod
+    def _freq_weights(x):
+        """Get product of squared frequencies to weight frequency amplitudes"""
+        k0 = np.array([k**2 for k in range(x.shape[1])])
+        k1 = np.array([k**2 for k in range(x.shape[2])])
+        f = np.array([f**2 for f in range(x.shape[3])])
+        freqs = np.multiply.outer(k0, k1)
+        freqs = np.multiply.outer(freqs, f)
+        freqs = tf.convert_to_tensor(freqs[np.newaxis, ..., np.newaxis])
+        return tf.cast(freqs, x.dtype)
+
+    def _fft(self, x):
+        """Apply needed transpositions and fft operation."""
+        x_hat = tf.transpose(x, perm=[4, 0, 1, 2, 3])
+        x_hat = tf.signal.fft3d(tf.cast(x_hat, tf.complex64))
+        x_hat = tf.transpose(x_hat, perm=[1, 2, 3, 4, 0])
+        x_hat = tf.cast(tf.abs(x_hat), x.dtype)
+        x_hat = tf.math.multiply(self._freq_weights(x), x_hat)
+        return tf.math.log(1 + x_hat)
+
+    def __call__(self, x1, x2):
+        """Custom content loss that encourages frequency domain accuracy
+
+        Parameters
+        ----------
+        x1 : tf.tensor
+            synthetic generator output
+            (n_observations, spatial_1, spatial_2, temporal, features)
+        x2 : tf.tensor
+            high resolution data
+            (n_observations, spatial_1, spatial_2, temporal, features)
+
+        Returns
+        -------
+        tf.tensor
+            0D tensor with loss value
+        """
+        x1_hat = self._fft(x1)
+        x2_hat = self._fft(x2)
+        return self.MAE_LOSS(x1_hat, x2_hat)
+
+
+class StExtremesFftLoss(tf.keras.losses.Loss):
+    """Loss class that encourages accuracy of the min/max values across both
+    space and time as well as frequency domain accuracy."""
+
+    def __init__(self, spatial_weight=1.0, temporal_weight=1.0,
+                 fft_weight=1.0):
+        """Initialize the loss with given weights
+
+        Parameters
+        ----------
+        spatial_weight : float
+            Weight for spatial min/max loss terms.
+        temporal_weight : float
+            Weight for temporal min/max loss terms.
+        fft_weight : float
+            Weight for the fft loss term.
+        """
+        super().__init__()
+        self.st_ex_loss = SpatiotemporalExtremesLoss(spatial_weight,
+                                                     temporal_weight)
+        self.fft_loss = SpatiotemporalFftOnlyLoss()
+        self.fft_weight = fft_weight
+
+    def __call__(self, x1, x2):
+        """Custom content loss that encourages spatiotemporal min/max accuracy
+        and fft accuracy.
+
+        Parameters
+        ----------
+        x1 : tf.tensor
+            synthetic generator output
+            (n_observations, spatial_1, spatial_2, temporal, features)
+        x2 : tf.tensor
+            high resolution data
+            (n_observations, spatial_1, spatial_2, temporal, features)
+
+        Returns
+        -------
+        tf.tensor
+            0D tensor with loss value
+        """
+        return (5 * self.st_ex_loss(x1, x2)
+                + self.fft_weight * self.fft_loss(x1, x2)) / 6
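
Note on `run_nn_fill` in base.py above: it delegates per-feature gap filling to the newly imported `nn_fill_array` utility. As a hedged sketch of the idea only (the actual sup3r implementation may differ in its details), a nearest-neighbor NaN fill can be built on `scipy.ndimage.distance_transform_edt`, which can return the index of the closest valid cell for every position:

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def nn_fill_sketch(array):
    """Replace every NaN with the value of its nearest non-NaN neighbor.

    Hypothetical stand-in for sup3r's nn_fill_array, for illustration only.
    """
    nan_mask = np.isnan(array)
    # For each cell, indices of the nearest cell where nan_mask is False
    indices = distance_transform_edt(nan_mask,
                                     return_distances=False,
                                     return_indices=True)
    return array[tuple(indices)]
```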
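The statistics and normalization logic consolidated into `TrainingPrepMixIn` amounts to per-feature standardization, with means/stdevs accumulated in float64 before the float32 data is scaled. A minimal sketch of those semantics on a toy array (the shapes and feature names below are assumed for illustration):

```python
import numpy as np

# Toy stand-in for a handler's data array:
# (spatial_1, spatial_2, temporal, n_features)
data = np.random.rand(4, 4, 10, 2).astype(np.float32)
features = ['U_100m', 'V_100m']  # hypothetical feature names

# Mirrors _get_stats(): stats computed in float64 to avoid float32
# precision loss on large arrays
means = {f: np.nanmean(data[..., i].astype(np.float64))
         for i, f in enumerate(features)}
stds = {f: np.nanstd(data[..., i].astype(np.float64))
        for i, f in enumerate(features)}

# Mirrors _normalize(): each feature channel standardized independently
for i, f in enumerate(features):
    data[..., i] = (data[..., i] - means[f]) / stds[f]
```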
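On the loss refactor: each composite extremes loss is now a weighted average whose coefficients sum to the divisor (1 + 2w = 3 at w = 1 for the spatial/temporal losses, 1 + 2ws + 2wt = 5 at unit weights for the spatiotemporal loss, and 5 + wfft = 6 in `StExtremesFftLoss`), so the default result stays on the same scale as a plain MAE instead of growing with each added term. A quick eager-mode check of that identity using the classes from this diff:

```python
import tensorflow as tf
from sup3r.utilities.loss_metrics import (TemporalExtremesLoss,
                                          TemporalExtremesOnlyLoss)

# (n_observations, spatial_1, spatial_2, temporal, features)
x1 = tf.random.uniform((2, 4, 4, 8, 1))
x2 = tf.random.uniform((2, 4, 4, 8, 1))

mae = tf.keras.losses.MeanAbsoluteError()(x1, x2)
ex_mae = TemporalExtremesOnlyLoss()(x1, x2)
composite = TemporalExtremesLoss(weight=1.0)(x1, x2)

# composite == (mae + 2 * weight * ex_mae) / 3 per the new __call__
assert abs(float(composite) - float((mae + 2 * ex_mae) / 3)) < 1e-5
```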
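The FFT losses compare log-compressed amplitude spectra weighted by the product of squared frequency indices, which emphasizes mismatch at higher spatial frequencies. The following standalone snippet walks through the `SpatialFftOnlyLoss._fft` pipeline from this diff on a toy tensor:

```python
import numpy as np
import tensorflow as tf

# (n_observations, spatial_1, spatial_2, features)
x = tf.random.uniform((2, 8, 8, 1))

# Product of squared frequency indices, as in _freq_weights()
k0 = np.array([k**2 for k in range(x.shape[1])])
k1 = np.array([k**2 for k in range(x.shape[2])])
weights = np.multiply.outer(k0, k1)[np.newaxis, ..., np.newaxis]

# Channels-first so fft2d runs over the two spatial axes
x_hat = tf.transpose(x, perm=[3, 0, 1, 2])
x_hat = tf.signal.fft2d(tf.cast(x_hat, tf.complex64))
x_hat = tf.transpose(x_hat, perm=[1, 2, 3, 0])  # back to channels-last

# Frequency-weighted, log-compressed amplitudes; an MAE over two such
# spectra is the SpatialFftOnlyLoss value
x_hat = tf.cast(tf.abs(x_hat), x.dtype) * tf.cast(weights, x.dtype)
x_hat = tf.math.log(1 + x_hat)
```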