Skip to content
This repository has been archived by the owner on May 20, 2024. It is now read-only.

Commit

Permalink
fix broken tests (#8)
Browse files Browse the repository at this point in the history
* fix test for min_num_obs

* debug network_lists test
  • Loading branch information
amy-defnet authored Oct 18, 2023
1 parent 837ec1e commit 2c1f92b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/hf_point_data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
import xarray as xr

NETWORK_LISTS_PATH = 'network_lists'


def check_inputs(data_source, variable, temporal_resolution, aggregation, **kwargs):
"""
Expand Down Expand Up @@ -320,7 +322,7 @@ def get_network_site_list(data_source, variable, site_networks):
for network in site_networks:
try:
assert network in network_options[data_source][variable]
df = pd.read_csv(f'network_lists/{data_source}/{variable}/{network}.csv',
df = pd.read_csv(f'{NETWORK_LISTS_PATH}/{data_source}/{variable}/{network}.csv',
dtype=str, header=None, names=['site_id'])
site_list += list(df['site_id'])
except:
Expand Down
27 changes: 17 additions & 10 deletions tests/test_hf_point_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,20 @@ def test_filter_min_num_obs():
"""Test functionality for filtering DataFrame on minimum non-NaN values."""
df = pd.DataFrame(
{
"site_id": ["101", "102", "103", "104", "105"],
"date1": [1, 5, 3, 4, 8],
"date2": [np.nan, 4, 2, 9, 4],
"date3": [np.nan, 9, 2, np.nan, 9],
"site1": [1, 5, 3, 4],
"site2": [np.nan, 4, 2, 9],
"site3": [np.nan, 9, 2, np.nan],
}
)

assert len(utils.filter_min_num_obs(df, 1)) == 5
assert len(utils.filter_min_num_obs(df, 2)) == 4
assert len(utils.filter_min_num_obs(df, 3)) == 3
df1 = utils.filter_min_num_obs(df, 1)
assert list(df1.columns) == ['site1', 'site2', 'site3']
df2 = utils.filter_min_num_obs(df, 2)
assert list(df2.columns) == ['site1', 'site2', 'site3']
df3 = utils.filter_min_num_obs(df, 3)
assert list(df3.columns) == ['site1', 'site2']
df4 = utils.filter_min_num_obs(df, 4)
assert list(df4.columns) == ['site1']


def test_no_sites_error_message():
Expand Down Expand Up @@ -698,6 +702,9 @@ def test_get_data_site_filter():

def test_site_networks_filter():
"""Test for using site_networks filter"""
utils.NETWORK_LISTS_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../src/hf_point_data/network_lists")
)
data_df = hf_point_data.get_data(
"usgs_nwis",
"streamflow",
Expand Down Expand Up @@ -740,7 +747,7 @@ def test_get_data_min_num_obs_filter():
site_ids=['01377500', '01378500', '01445000'],
min_num_obs=5
)
assert list(df.columns == ['date', '01377500', '01378500'])
assert list(df.columns) == ['date', '01377500', '01378500']

df = hf_point_data.get_data(
"usgs_nwis",
Expand All @@ -752,7 +759,7 @@ def test_get_data_min_num_obs_filter():
site_ids=['01377500', '01378500', '01445000'],
min_num_obs=1
)
assert list(df.columns == ['date', '01377500', '01378500'])
assert list(df.columns) == ['date', '01377500', '01378500']

# If no min_num_obs filter supplied, all three sites returned
df = hf_point_data.get_data(
Expand All @@ -764,7 +771,7 @@ def test_get_data_min_num_obs_filter():
date_end="2002-01-05",
site_ids=['01377500', '01378500', '01445000']
)
assert list(df.columns == ['date', '01377500', '01378500', '01445000'])
assert list(df.columns) == ['date', '01377500', '01378500', '01445000']


if __name__ == "__main__":
Expand Down

0 comments on commit 2c1f92b

Please sign in to comment.