Skip to content

Commit

Permalink
Merge pull request #250 from NREL/gb/solar_experiment
Browse files Browse the repository at this point in the history
Gb/solar experiment
  • Loading branch information
grantbuster authored Dec 26, 2024
2 parents 94eebae + dc0f42c commit 651f396
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 42 deletions.
79 changes: 53 additions & 26 deletions sup3r/models/solar_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,31 @@ class SolarCC(Sup3rGan):
Note
----
*Modifications to standard Sup3rGan*
- Content loss is only on the n_days of the center 8 daylight hours of
the daily true+synthetic high res samples
- Discriminator only sees n_days of the center 8 daylight hours of the
daily true high res sample.
- Discriminator sees random n_days of 8-hour samples of the daily
synthetic high res sample.
- Pointwise content loss (MAE/MSE) is only on the center 2 daylight
hours (POINT_LOSS_HOURS) of the daily true + synthetic days and the
temporal mean of the 24hours of synthetic for n_days
(usually just 1 day)
- Discriminator only sees n_days of the center 8 daylight hours
(DAYLIGHT_HOURS and STARTING_HOUR) of the daily true high res sample.
- Discriminator sees random n_days of 8-hour samples (DAYLIGHT_HOURS)
of the daily synthetic high res sample.
- Includes padding on high resolution output of :meth:`generate` so
that forward pass always outputs a multiple of 24 hours.
"""

# starting hour is the hour that daylight starts at, daylight hours is the
# number of daylight hours to sample, so for example if 8 and 8, the
# daylight slice will be slice(8, 16). The stride length is the step size
# for sampling the temporal axis of the generated data to send to the
# discriminator for the adversarial loss component of the generator. For
# example, if the generator produces 24 timesteps and stride is 4 and the
# daylight hours is 8, slices of (0, 8) (4, 12), (8, 16), (12, 20), and
# (16, 24) will be sent to the disc.
STARTING_HOUR = 8
"""Starting hour is the hour that daylight starts at, typically
zero-indexed and rolled to local time"""

DAYLIGHT_HOURS = 8
STRIDE_LEN = 4
"""Daylight hours is the number of daylight hours to sample, so for example
if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be
slice(8, 16). """

POINT_LOSS_HOURS = 2
"""Number of hours from the center of the day to calculate pointwise loss
from, e.g., MAE/MSE based on data from the true 4km hourly high res
field."""

def __init__(self, *args, t_enhance=None, **kwargs):
"""Add optional t_enhance adjustment.
Expand Down Expand Up @@ -142,32 +146,55 @@ def calc_loss(

t_len = hi_res_true.shape[3]
n_days = int(t_len // 24)
day_slices = [

# slices for 24-hour full days
day_24h_slices = [slice(x, x + 24) for x in range(0, 24 * n_days, 24)]

# slices for middle-daylight-hours
sub_day_slices = [
slice(
self.STARTING_HOUR + x,
self.STARTING_HOUR + x + self.DAYLIGHT_HOURS,
)
for x in range(0, 24 * n_days, 24)
]

# slices for middle-pointwise-loss-hours
point_loss_slices = [
slice(
(24 - self.POINT_LOSS_HOURS) // 2 + x,
(24 - self.POINT_LOSS_HOURS) // 2 + x + self.POINT_LOSS_HOURS,
)
for x in range(0, 24 * n_days, 24)
]

# sample only daylight hours for disc training and gen content loss
disc_out_true = []
disc_out_gen = []
loss_gen_content = 0.0
for tslice in day_slices:
disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice, :])
gen_c = self.calc_loss_gen_content(
hi_res_true[:, :, :, tslice, :], hi_res_gen[:, :, :, tslice, :]
)
ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
for tslice_sub, tslice_ploss, tslice_24h in ziter:
hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :]
hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :]
hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :]

hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3)
hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3)

gen_c_sub = self.calc_loss_gen_content(hr_true_ploss, hr_gen_ploss)
gen_c_24h = self.calc_loss_gen_content(hr_true_mean, hr_gen_mean)
loss_gen_content += gen_c_24h + gen_c_sub

disc_t = self._tf_discriminate(hr_true_sub)
disc_out_true.append(disc_t)
loss_gen_content += gen_c

# Randomly sample daylight windows from generated data. Better than
# strided samples covering full day because the random samples will
# provide an evenly balanced training set for the disc
logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS)]
time_samples = tf.random.categorical(logits, len(day_slices))
for i in range(len(day_slices)):
logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS + 1)]
time_samples = tf.random.categorical(logits, n_days)
for i in range(n_days):
t0 = time_samples[0, i]
t1 = t0 + self.DAYLIGHT_HOURS
disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :])
Expand All @@ -177,7 +204,7 @@ def calc_loss(
disc_out_gen = tf.concat([disc_out_gen], axis=0)
loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)

loss_gen_content /= len(day_slices)
loss_gen_content /= len(sub_day_slices)
loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers

Expand Down
5 changes: 3 additions & 2 deletions sup3r/preprocessing/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ def compute(self, **kwargs):
logger.debug(f'Loading dataset into memory: {self._ds}')
logger.debug(f'Pre-loading: {_mem_check()}')

for f in self._ds.data_vars:
self._ds[f] = self._ds[f].compute(**kwargs)
for f in list(self._ds.data_vars) + list(self._ds.coords):
if hasattr(self._ds[f], 'compute'):
self._ds[f] = self._ds[f].compute(**kwargs)
logger.debug(
f'Loaded {f} into memory with shape '
f'{self._ds[f].shape}. {_mem_check()}'
Expand Down
4 changes: 2 additions & 2 deletions sup3r/preprocessing/derivers/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,6 @@ class TasMax(Tas):
class Sza(DerivedFeature):
"""Solar zenith angle derived feature."""

inputs = ()

@classmethod
def compute(cls, data):
"""Compute method for sza."""
Expand All @@ -402,6 +400,8 @@ def compute(cls, data):
'cloud_mask': CloudMask,
'clearsky_ratio': ClearSkyRatio,
'sza': Sza,
'latitude_feature': 'latitude',
'longitude_feature': 'longitude',
}

RegistryH5WindCC = {
Expand Down
5 changes: 3 additions & 2 deletions sup3r/preprocessing/rasterizers/extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def get_lat_lon(self):
return self._get_flat_data_lat_lon()

def _get_flat_data_lat_lon(self):
"""Get lat lon for flattened source data."""
"""Get lat lon for flattened source data. Output is shape (y, x, 2)
where 2 is (lat, lon)"""
if hasattr(self.full_lat_lon, 'vindex'):
return self.full_lat_lon.vindex[self.raster_index]
return self.full_lat_lon[self.raster_index.flatten]
return self.full_lat_lon[self.raster_index]
14 changes: 7 additions & 7 deletions sup3r/solar/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
sup3r_fps,
nsrdb_fp,
t_slice=slice(None),
tz=-6,
tz=-7,
agg_factor=1,
nn_threshold=0.5,
cloud_threshold=0.99,
Expand Down Expand Up @@ -64,8 +64,8 @@ def __init__(
tz : int
The timezone offset for the data in sup3r_fps. It is assumed that
the GAN is trained on data in local time and therefore the output
in sup3r_fps should be treated as local time. For example, -6 is
CST which is default for CONUS training data.
in sup3r_fps should be treated as local time. For example, -7 is
MST which is default for CONUS training data.
agg_factor : int
Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of
NSRDB spatial pixels to average for a single sup3r GAN output site.
Expand Down Expand Up @@ -585,7 +585,7 @@ def run_temporal_chunks(
fp_pattern,
nsrdb_fp,
fp_out_suffix='irradiance',
tz=-6,
tz=-7,
agg_factor=1,
nn_threshold=0.5,
cloud_threshold=0.99,
Expand All @@ -610,8 +610,8 @@ def run_temporal_chunks(
tz : int
The timezone offset for the data in sup3r_fps. It is assumed that
the GAN is trained on data in local time and therefore the output
in sup3r_fps should be treated as local time. For example, -6 is
CST which is default for CONUS training data.
in sup3r_fps should be treated as local time. For example, -7 is
MST which is default for CONUS training data.
agg_factor : int
Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of
NSRDB spatial pixels to average for a single sup3r GAN output site.
Expand Down Expand Up @@ -663,7 +663,7 @@ def _run_temporal_chunk(
fp_pattern,
nsrdb_fp,
fp_out_suffix='irradiance',
tz=-6,
tz=-7,
agg_factor=1,
nn_threshold=0.5,
cloud_threshold=0.99,
Expand Down
4 changes: 2 additions & 2 deletions tests/data_handlers/test_dh_nc_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ def test_data_handling_nc_cc():

handler = DataHandlerNCforCC(
pytest.FPS_GCM,
features=['u_100m', 'v_100m'],
features=['u_100m', 'v_100m', 'latitude_feature', 'longitude_feature'],
target=target,
shape=(20, 20),
)
assert handler.data.shape == (20, 20, 20, 2)
assert handler.data.shape == (20, 20, 20, 4)

# upper case features warning
features = [f'U_{int(plevel)}pa', f'V_{int(plevel)}pa']
Expand Down
12 changes: 12 additions & 0 deletions tests/rasterizers/test_rasterizer_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,15 @@ def test_topography_h5():
topo = res.get_meta_arr('elevation')[ri.flatten(),]
topo = topo.reshape((ri.shape[0], ri.shape[1]))
assert np.allclose(topo, rasterizer['topography'][..., 0])


def test_preloaded_h5():
"""Test preload of h5 file"""
rasterizer = Rasterizer(
file_paths=pytest.FP_WTK,
target=(39.01, -105.15),
shape=(20, 20),
chunks=None,
)
for f in list(rasterizer.data.data_vars) + list(Dimension.coords_2d()):
assert isinstance(rasterizer[f].data, np.ndarray)
1 change: 0 additions & 1 deletion tests/training/test_train_solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,3 @@ def test_solar_custom_loss():
)

assert loss1 > loss2
assert loss2 == 0

0 comments on commit 651f396

Please sign in to comment.