diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 19f99d3cb..028cddf27 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -427,6 +427,11 @@ def test_normalization(log=False, t_enhance=t_enhance, val_split=0.1) + hr_means0 = np.mean(hr_handler.data, axis=(0, 1, 2)) + lr_means0 = np.mean(lr_handler.data, axis=(0, 1, 2)) + ddh_hr_means0 = np.mean(dual_handler.hr_data, axis=(0, 1, 2)) + ddh_lr_means0 = np.mean(dual_handler.lr_data, axis=(0, 1, 2)) + means = copy.deepcopy(lr_handler.means) stdevs = copy.deepcopy(lr_handler.stds) @@ -436,6 +441,11 @@ def test_normalization(log=False, t_enhance=t_enhance, n_batches=10) + hr_means1 = np.mean(hr_handler.data, axis=(0, 1, 2)) + lr_means1 = np.mean(lr_handler.data, axis=(0, 1, 2)) + ddh_hr_means1 = np.mean(dual_handler.hr_data, axis=(0, 1, 2)) + ddh_lr_means1 = np.mean(dual_handler.lr_data, axis=(0, 1, 2)) + assert all(means[k] == v for k, v in batch_handler.means.items()) assert all(stdevs[k] == v for k, v in batch_handler.stds.items()) @@ -446,6 +456,12 @@ def test_normalization(log=False, assert np.allclose(std, 1, atol=1e-3), str(std) assert np.allclose(mean, 0, atol=1e-3), str(mean) + fn = FEATURES[idf] + assert np.allclose(hr_means0[idf] - means[fn], hr_means1[idf]) + assert np.allclose(lr_means0[idf] - means[fn], lr_means1[idf]) + assert np.allclose(ddh_hr_means0[idf] - means[fn], ddh_hr_means1[idf]) + assert np.allclose(ddh_lr_means0[idf] - means[fn], ddh_lr_means1[idf]) + @pytest.mark.parametrize(['lr_features', 'hr_features', 'hr_exo_features'], [(['U_100m'], ['U_100m', 'V_100m'], ['V_100m']),