extra protection and tests against accidentally casting to float64
grantbuster committed Nov 23, 2023

Commit c1812fe (parent 5fbc3e1)
Showing 5 changed files with 59 additions and 14 deletions.
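
Before the per-file diffs, a minimal sketch (not from the sup3r codebase; shapes and values are illustrative) of the failure mode the commit message refers to: a single float64 operand silently promotes float32 training data to float64 and doubles its memory footprint.

import numpy as np

batch = np.ones((16, 64, 64, 2), dtype=np.float32)   # float32 training data
stdevs = np.array([1.5, 2.0])                         # defaults to float64

normed = batch / stdevs                      # broadcasting with float64 stats
print(normed.dtype, normed.nbytes)           # float64, twice the memory of batch

normed = batch / stdevs.astype(np.float32)   # cast the stats first
print(normed.dtype, normed.nbytes)           # float32, same memory as batch
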
12 changes: 10 additions & 2 deletions sup3r/models/abstract.py
@@ -578,8 +578,8 @@ def set_norm_stats(self, new_means, new_stdevs):
logger.info("Model's previous data stdev values: {}".format(
self._stdevs))

self._means = new_means
self._stdevs = new_stdevs
self._means = {k: np.float32(v) for k, v in new_means.items()}
self._stdevs = {k: np.float32(v) for k, v in new_stdevs.items()}

if (not isinstance(self._means, dict)
or not isinstance(self._stdevs, dict)):
@@ -794,6 +794,14 @@ def load_saved_params(out_dir, verbose=True):
'following package versions: \n{}'.format(
pprint.pformat(version_record, indent=2)))

means = params.get('means', None)
stdevs = params.get('stdevs', None)
if means is not None and stdevs is not None:
means = {k: np.float32(v) for k, v in means.items()}
stdevs = {k: np.float32(v) for k, v in stdevs.items()}
params['means'] = means
params['stdevs'] = stdevs

return params

def get_high_res_exo_input(self, high_res):
10 changes: 8 additions & 2 deletions sup3r/models/base.py
@@ -328,6 +328,12 @@ def model_params(self):
dict
"""

means = self._means
stdevs = self._stdevs
if means is not None and stdevs is not None:
means = {k: float(v) for k, v in means.items()}
stdevs = {k: float(v) for k, v in stdevs.items()}

config_optm_g = self.get_optimizer_config(self.optimizer)
config_optm_d = self.get_optimizer_config(self.optimizer_disc)

@@ -337,8 +343,8 @@ def model_params(self):
'version_record': self.version_record,
'optimizer': config_optm_g,
'optimizer_disc': config_optm_d,
'means': self._means,
'stdevs': self._stdevs,
'means': means,
'stdevs': stdevs,
'meta': self.meta,
}
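
The base.py change above is the save-side counterpart of the load_saved_params change in abstract.py: stats are written out as Python floats (JSON friendly) and cast back to float32 on load. A hedged round-trip sketch; the helper names and feature values are illustrative, not part of the sup3r API.

import numpy as np

def stats_to_params(means, stdevs):
    # Save side, mirroring model_params: Python floats serialize cleanly.
    return {'means': {k: float(v) for k, v in means.items()},
            'stdevs': {k: float(v) for k, v in stdevs.items()}}

def stats_from_params(params):
    # Load side, mirroring load_saved_params: restore float32 so no float64
    # value leaks back into normalization.
    return {key: {k: np.float32(v) for k, v in d.items()}
            for key, d in params.items()}

saved = stats_to_params({'u_100m': np.float32(3.5)}, {'u_100m': np.float32(1.25)})
restored = stats_from_params(saved)
print(type(saved['means']['u_100m']).__name__)   # float
print(restored['means']['u_100m'].dtype)         # float32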

11 changes: 7 additions & 4 deletions sup3r/preprocessing/batch_handling.py
@@ -515,6 +515,7 @@ def handler_weights(self):
relative sizes"""
sizes = [dh.size for dh in self.data_handlers]
weights = sizes / np.sum(sizes)
weights = weights.astype(np.float32)
return weights

def get_handler_index(self):
@@ -680,8 +681,8 @@ def _get_stats(self):
'features.')
for feature in self.features:
logger.debug(f'Calculating mean/stdev for "{feature}"')
self.means[feature] = 0
self.stds[feature] = 0
self.means[feature] = np.float32(0)
self.stds[feature] = np.float32(0)
max_workers = self.stats_workers

if max_workers is None or max_workers >= 1:
@@ -755,7 +756,9 @@ def cache_stats(self):
logger.info(f'Saving stats to {fp}')
os.makedirs(os.path.dirname(fp), exist_ok=True)
with open(fp, 'w') as fh:
json.dump(data, fh)
# need to convert numpy float32 type to python float to be
# serializable in json
json.dump({k: float(v) for k, v in data.items()}, fh)

def get_stats(self):
"""Get standard deviations and means for all data features"""
@@ -803,7 +806,7 @@ def _get_feature_stdev(self, feature):
variance = dh.stds[feature]**2
self.stds[feature] += (variance * self.handler_weights[idh])

self.stds[feature] = np.sqrt(self.stds[feature])
self.stds[feature] = np.sqrt(self.stds[feature]).astype(np.float32)

return self.stds[feature]
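
The comment added in cache_stats above points at a concrete failure: the standard-library json encoder rejects numpy float32 values outright. A small sketch with a made-up feature name and value:

import json
import numpy as np

stats = {'windspeed': np.float32(7.25)}

try:
    json.dumps(stats)
except TypeError as err:
    print(err)   # Object of type float32 is not JSON serializable

# The fix applied in cache_stats: cast to Python float before dumping.
print(json.dumps({k: float(v) for k, v in stats.items()}))   # {"windspeed": 7.25}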

36 changes: 30 additions & 6 deletions tests/training/test_train_gan.py
@@ -70,6 +70,8 @@ def test_train_spatial(log=False, full_shape=(20, 20),
assert 'test_1' in os.listdir(td)
assert 'model_gen.pkl' in os.listdir(td + '/test_1')
assert 'model_disc.pkl' in os.listdir(td + '/test_1')
assert model.means is not None
assert model.stdevs is not None

# make an un-trained dummy model
dummy = Sup3rGan(fp_gen, fp_disc, learning_rate=2e-5,
@@ -85,9 +87,15 @@ def test_train_spatial(log=False, full_shape=(20, 20),
assert isinstance(loaded.loss_fun, tf.keras.losses.MeanAbsoluteError)

for batch in batch_handler:
out_og = model._tf_generate(batch.low_res)
out_dummy = dummy._tf_generate(batch.low_res)
out_loaded = loaded._tf_generate(batch.low_res)
out_og = model.generate(batch.low_res, norm_in=True,
un_norm_out=True)
out_dummy = dummy.generate(batch.low_res, norm_in=True,
un_norm_out=True)
out_loaded = loaded.generate(batch.low_res, norm_in=True,
un_norm_out=True)
assert out_og.dtype == np.float32
assert out_dummy.dtype == np.float32
assert out_loaded.dtype == np.float32

# make sure the loaded model generates the same data as the saved
# model but different than the dummy
@@ -96,6 +104,10 @@ def test_train_spatial(log=False, full_shape=(20, 20),
tf.assert_equal(out_og, out_dummy)

# make sure the trained model has less loss than dummy
out_og = model.generate(batch.low_res, norm_in=False,
un_norm_out=False)
out_dummy = dummy.generate(batch.low_res, norm_in=False,
un_norm_out=False)
loss_og = model.calc_loss(batch.high_res, out_og)[0]
loss_dummy = dummy.calc_loss(batch.high_res, out_dummy)[0]
assert loss_og.numpy() < loss_dummy.numpy()
@@ -306,6 +318,8 @@ def test_train_st(n_epoch=2, log=False):
assert 'test_1' in os.listdir(td)
assert 'model_gen.pkl' in os.listdir(td + '/test_1')
assert 'model_disc.pkl' in os.listdir(td + '/test_1')
assert model.means is not None
assert model.stdevs is not None

# test save/load functionality
out_dir = os.path.join(td, 'st_gan')
@@ -330,9 +344,15 @@ def test_train_st(n_epoch=2, log=False):
learning_rate_disc=2e-5)

for batch in batch_handler:
out_og = model._tf_generate(batch.low_res)
out_dummy = dummy._tf_generate(batch.low_res)
out_loaded = loaded._tf_generate(batch.low_res)
out_og = model.generate(batch.low_res, norm_in=True,
un_norm_out=True)
out_dummy = dummy.generate(batch.low_res, norm_in=True,
un_norm_out=True)
out_loaded = loaded.generate(batch.low_res, norm_in=True,
un_norm_out=True)
assert out_og.dtype == np.float32
assert out_dummy.dtype == np.float32
assert out_loaded.dtype == np.float32

# make sure the loaded model generates the same data as the saved
# model but different than the dummy
@@ -341,6 +361,10 @@ def test_train_st(n_epoch=2, log=False):
tf.assert_equal(out_og, out_dummy)

# make sure the trained model has less loss than dummy
out_og = model.generate(batch.low_res, norm_in=False,
un_norm_out=False)
out_dummy = dummy.generate(batch.low_res, norm_in=False,
un_norm_out=False)
loss_og = model.calc_loss(batch.high_res, out_og)[0]
loss_dummy = dummy.calc_loss(batch.high_res, out_dummy)[0]
assert loss_og.numpy() < loss_dummy.numpy()
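
A hedged sketch of the assertion pattern these test changes introduce: calling generate with norm_in=True and un_norm_out=True exercises the stored means/stdevs, so a float64 stat anywhere in the pipeline would surface as a float64 output. The helper name below is illustrative, not part of the test suite.

import numpy as np

def assert_float32_output(model, low_res):
    # Full normalize -> generate -> un-normalize path; any float64 stat
    # would promote the output away from float32.
    out = model.generate(low_res, norm_in=True, un_norm_out=True)
    assert out.dtype == np.float32, f'expected float32, got {out.dtype}'
    return out
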
4 changes: 4 additions & 0 deletions tests/training/test_train_gan_exo.py
@@ -125,6 +125,7 @@ def test_wind_hi_res_topo_with_train_only(CustomLayer, log=False):
{'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
y = model.generate(x, exogenous_data=exo_tmp)

assert y.dtype == np.float32
assert y.shape[0] == x.shape[0]
assert y.shape[1] == x.shape[1] * 2
assert y.shape[2] == x.shape[2] * 2
@@ -219,6 +220,7 @@ def test_wind_hi_res_topo(CustomLayer, log=False):
{'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
y = model.generate(x, exogenous_data=exo_tmp)

assert y.dtype == np.float32
assert y.shape[0] == x.shape[0]
assert y.shape[1] == x.shape[1] * 2
assert y.shape[2] == x.shape[2] * 2
@@ -312,6 +314,7 @@ def test_wind_non_cc_hi_res_topo(CustomLayer, log=False):
{'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
y = model.generate(x, exogenous_data=exo_tmp)

assert y.dtype == np.float32
assert y.shape[0] == x.shape[0]
assert y.shape[1] == x.shape[1] * 2
assert y.shape[2] == x.shape[2] * 2
@@ -405,6 +408,7 @@ def test_wind_dc_hi_res_topo(CustomLayer, log=False):
{'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
y = model.generate(x, exogenous_data=exo_tmp)

assert y.dtype == np.float32
assert y.shape[0] == x.shape[0]
assert y.shape[1] == x.shape[1] * 2
assert y.shape[2] == x.shape[2] * 2
