Skip to content

Commit

Permalink
Merge pull request #43 from IMMM-SFA/feature/tests
Browse files Browse the repository at this point in the history
Tests to increase coverage of core functionality
  • Loading branch information
mcgrathc authored Oct 3, 2022
2 parents cc2c275 + 50b7fa8 commit 0842c90
Show file tree
Hide file tree
Showing 20 changed files with 48,438 additions and 10 deletions.
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ include tell/data/balancing_authority_modeled.yml
include tell/data/mlp_settings.yml
include tell/data/hyperparameters.csv
include tell/data/models/*
include tell/tests/data/*
include tell/tests/data/train_data/*
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
copyright = '2022, Batelle Memorial Institute'
author = 'Casey Burleyson, Casey McGrath'

version = 'v0.1.3'
version = 'v0.1.4'


# -- General configuration ---------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ sklearn>=0.0
threadpoolctl>=3.1.0
urllib3>=1.26.8
tqdm>=4.63.0
fastparquet>=0.8.3
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def readme():
'six>=1.16.0',
'sklearn>=0.0',
'threadpoolctl>=3.1.0',
'urllib3>=1.26.8'
'urllib3>=1.26.8',
'fastparquet>=0.8.3'
],
extras_require={
'dev': [
Expand Down
2 changes: 1 addition & 1 deletion tell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
from .visualization import *

# Set the current version of TELL:
__version__ = '0.1.3'
__version__ = '0.1.4'
3 changes: 2 additions & 1 deletion tell/install_forcing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class InstallForcingSample:
'0.1.0': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1',
'0.1.1': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1',
'0.1.2': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1',
'0.1.3': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1'}
'0.1.3': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1',
'0.1.4': 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1'}

DEFAULT_VERSION = 'https://zenodo.org/record/6354665/files/sample_forcing_data.zip?download=1'

Expand Down
3 changes: 2 additions & 1 deletion tell/install_quickstarter_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class InstallQuickstarterData:
'0.1.0': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1',
'0.1.1': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1',
'0.1.2': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1',
'0.1.3': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1'}
'0.1.3': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1',
'0.1.4': 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1'}

DEFAULT_VERSION = 'https://zenodo.org/record/6804242/files/tell_quickstarter_data.zip?download=1'

Expand Down
3 changes: 2 additions & 1 deletion tell/install_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class InstallRawData:
'0.1.0': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1',
'0.1.1': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1',
'0.1.2': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1',
'0.1.3': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1'}
'0.1.3': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1',
'0.1.4': 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1'}

DEFAULT_VERSION = 'https://zenodo.org/record/6378036/files/tell_raw_data.zip?download=1'

Expand Down
8,761 changes: 8,761 additions & 0 deletions tell/tests/data/ERCO_WRF_Hourly_Mean_Meteorology_2039.csv

Large diffs are not rendered by default.

Binary file added tell/tests/data/comp_predict.parquet
Binary file not shown.
Binary file added tell/tests/data/comp_train.parquet
Binary file not shown.
39,475 changes: 39,475 additions & 0 deletions tell/tests/data/train_data/ERCO_historical_data.csv

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
21 changes: 21 additions & 0 deletions tell/tests/test_install_forcing_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

import tell
import tell.install_forcing_data as td


class TestInstallForcingData(unittest.TestCase):

def test_instantiate(self):

zen = td.InstallForcingSample(data_dir="fake")

# ensure default version is set
self.assertEqual(str, type(zen.DEFAULT_VERSION))

# ensure urls present for current version
self.assertTrue(tell.__version__ in zen.DATA_VERSION_URLS)


if __name__ == '__main__':
unittest.main()
21 changes: 21 additions & 0 deletions tell/tests/test_install_quickstarter_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

import tell
from tell.install_quickstarter_data import InstallQuickstarterData


class TestInstallQuickstarterData(unittest.TestCase):

def test_instantiate(self):

zen = InstallQuickstarterData(data_dir="fake")

# ensure default version is set
self.assertEqual(str, type(zen.DEFAULT_VERSION))

# ensure urls present for current version
self.assertTrue(tell.__version__ in zen.DATA_VERSION_URLS)


if __name__ == '__main__':
unittest.main()
21 changes: 21 additions & 0 deletions tell/tests/test_install_raw_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

import tell
import tell.install_raw_data as td


class TestInstallRawData(unittest.TestCase):

def test_instantiate(self):

zen = td.InstallRawData(data_dir="fake")

# ensure default version is set
self.assertEqual(str, type(zen.DEFAULT_VERSION))

# ensure urls present for current version
self.assertTrue(tell.__version__ in zen.DATA_VERSION_URLS)


if __name__ == '__main__':
unittest.main()
29 changes: 25 additions & 4 deletions tell/tests/test_mlp_predict.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
import unittest
import pkg_resources

import pandas as pd

class TestPackageData(unittest.TestCase):
import tell.mlp_predict as mp


class TestMlpPredict(unittest.TestCase):
"""Tests for functionality within mlp_predict.py"""

def test_mlp_predict(self):
"""Test to ensure high level functionality of mlp_predict.py()"""
pass
COMP_PREDICT_DF = pd.read_parquet(pkg_resources.resource_filename("tell", "tests/data/comp_predict.parquet"))

def test_predict(self):
"""Test to ensure high level functionality of predict()"""

df = mp.predict(region="ERCO",
year=2039,
data_dir=pkg_resources.resource_filename("tell", "tests/data"))

pd.testing.assert_frame_equal(TestMlpPredict.COMP_PREDICT_DF, df)

def test_predict_batch(self):
"""Test to ensure high level functionality of predict_batch()"""

df = mp.predict_batch(target_region_list=["ERCO"],
year=2039,
data_dir=pkg_resources.resource_filename("tell", "tests/data"))

pd.testing.assert_frame_equal(TestMlpPredict.COMP_PREDICT_DF, df)


if __name__ == '__main__':
Expand Down
56 changes: 56 additions & 0 deletions tell/tests/test_mlp_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
import pkg_resources

import numpy as np
import pandas as pd

import tell.mlp_train as mp


class TestMlpTrain(unittest.TestCase):
"""Tests for functionality within mlp_train.py"""

COMP_TRAIN_ARR = pd.read_parquet(pkg_resources.resource_filename("tell", "tests/data/comp_train.parquet"))["a"].values
COMP_TRAIN_PRED_DF = pd.read_parquet(pkg_resources.resource_filename("tell", "tests/data/train_data/comp_train_pred.parquet"))
COMP_TRAIN_VALID_DF = pd.read_parquet(pkg_resources.resource_filename("tell", "tests/data/train_data/comp_train_valid.parquet"))

def test_train_mlp_model(self):
"""Test to ensure high level functionality of predict()"""

np.random.seed(123)

arr = mp.train_mlp_model(region="ERCO",
x_train=np.arange(0.1, 100.0, 1.1).reshape(-1, 1),
y_train=np.arange(0.1, 100.0, 1.1).reshape(-1, 1),
x_test=np.arange(0.1, 100.0, 1.1).reshape(-1, 1),
mlp_hidden_layer_sizes=10,
mlp_max_iter=2,
mlp_validation_fraction=0.4)

np.testing.assert_array_equal(TestMlpTrain.COMP_TRAIN_ARR, arr)

def test_train(self):
"""Test to ensure high level functionality of train()"""

np.random.seed(123)

prediction_df, validation_df = mp.train(region="ERCO",
data_dir=pkg_resources.resource_filename("tell", "tests/data/train_data"))

pd.testing.assert_frame_equal(TestMlpTrain.COMP_TRAIN_PRED_DF, prediction_df)
pd.testing.assert_frame_equal(TestMlpTrain.COMP_TRAIN_VALID_DF, validation_df)

def test_train_batch(self):
"""Test to ensure high level functionality of train()"""

np.random.seed(123)

prediction_df, validation_df = mp.train_batch(target_region_list=["ERCO"],
data_dir=pkg_resources.resource_filename("tell", "tests/data/train_data"))

pd.testing.assert_frame_equal(TestMlpTrain.COMP_TRAIN_PRED_DF, prediction_df)
pd.testing.assert_frame_equal(TestMlpTrain.COMP_TRAIN_VALID_DF, validation_df)


if __name__ == '__main__':
unittest.main()
45 changes: 45 additions & 0 deletions tell/tests/test_mlp_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import numpy as np
import pandas as pd

import tell.mlp_utils as mpu

Expand Down Expand Up @@ -28,6 +29,17 @@ class TestMlpUtils(unittest.TestCase):
[1.61, 0.09, -0.36]]),
'y_test_norm': np.array([1.36, 0.5, 0.09])}

COMP_DENORM_DF = pd.DataFrame({"datetime": [1.1, 2.2],
"predictions": [3.4, 7.8],
"ground_truth": [1.1, 2.2],
"region": ["alpha"]*2})

COMP_EVAL_DF = pd.DataFrame({"BA": "alpha",
"RMS_ABS": [0.070711],
"RMS_NORM": [0.041595],
"MAPE": [0.022727],
"R2": [0.983471]}).round(4)

def test_normalize_prediction_data(self):
"""Test to ensure high level functionality of normalize_prediction_data()"""

Expand Down Expand Up @@ -55,6 +67,39 @@ def test_normalize_features(self):
else:
np.testing.assert_array_equal(target, res[k].round(2))

def test_get_balancing_authority_to_model_dict(self):
"""Test to ensure high level functionality of get_balancing_authority_to_model_dict()"""

d = mpu.get_balancing_authority_to_model_dict()

self.assertEqual(54, len(d))

def test_denormalize_features(self):
"""Test to ensure high level functionality of denormalize_features()"""

norm_dict = {"max_y_train": np.array([3.1, 4.2]),
"min_y_train": np.array([0.1, 1.2])}

df = mpu.denormalize_features(region="alpha",
normalized_dict=norm_dict,
y_predicted_normalized=np.array([1.1, 2.2]),
y_comparison=np.array([1.1, 2.2]),
datetime_arr=np.array([1.1, 2.2]))

pd.testing.assert_frame_equal(TestMlpUtils.COMP_DENORM_DF, df)

def test_evaluate(self):
"""Test to ensure high level functionality of evaluate()"""

df = mpu.evaluate(region="alpha",
y_predicted=np.array([1.1, 2.2]),
y_comparison=np.array([1.1, 2.3])).round(4)

print(TestMlpUtils.COMP_EVAL_DF)
print(df)

pd.testing.assert_frame_equal(TestMlpUtils.COMP_EVAL_DF, df)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0842c90

Please sign in to comment.