From 14f65619c9297792a7963cf9a273a7ddd2cd8f11 Mon Sep 17 00:00:00 2001 From: Steph Merritt Date: Wed, 24 Apr 2024 17:34:10 +0100 Subject: [PATCH] Unit tests and docstrings. --- src/adler/adler.py | 9 ++- src/adler/dataclasses/dataclass_utilities.py | 54 ++++++++++---- src/adler/utilities/AdlerCLIArguments.py | 28 ++++++- tests/adler/dataclasses/test_MPCORB.py | 2 +- .../dataclasses/test_dataclass_utilities.py | 14 ++++ .../adler/utilities/test_AdlerCLIArguments.py | 73 +++++++++++++++++++ 6 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 tests/adler/utilities/test_AdlerCLIArguments.py diff --git a/src/adler/adler.py b/src/adler/adler.py index f2b797c..7d29f86 100644 --- a/src/adler/adler.py +++ b/src/adler/adler.py @@ -39,9 +39,14 @@ def runAdler(cli_args): def main(): parser = argparse.ArgumentParser(description="Runs Adler for a select planetoid and given user input.") - parser.add_argument("-s", "--ssoid", help="SSObject ID of planetoid.", type=str, required=True) + parser.add_argument("-s", "--ssObjectId", help="SSObject ID of planetoid.", type=str, required=True) parser.add_argument( - "-f", "--filter_list", help="Filters required.", nargs="*", type=str, default=["u", "g", "r", "i", "z", "y"] + "-f", + "--filter_list", + help="Filters required.", + nargs="*", + type=str, + default=["u", "g", "r", "i", "z", "y"], ) parser.add_argument( "-d", diff --git a/src/adler/dataclasses/dataclass_utilities.py b/src/adler/dataclasses/dataclass_utilities.py index 20fb1d6..172a264 100644 --- a/src/adler/dataclasses/dataclass_utilities.py +++ b/src/adler/dataclasses/dataclass_utilities.py @@ -53,14 +53,21 @@ def get_from_table(data_table, column_name, data_type, table_name="default"): Parameters ----------- + data_table : DALResultsTable or Pandas dataframe + Data table containing columns of interest. + column_name : str Column name under which the data of interest is stored. - type : str - String delineating data type. Should be "str", "float", "int" or "array". + + data_type : type + Data type. Should be int, float, str or np.ndarray. + + table_name : str + Name of the table. This is mostly for more informative error messages. Default="default". Returns ----------- - data : any type + data_val : str, float, int or nd.array The data requested from the table cast to the type required. """ @@ -87,20 +94,41 @@ def get_from_table(data_table, column_name, data_type, table_name="default"): raise ValueError("Could not cast column name to type.") # here we alert the user if one of the values is unpopulated and change it to a NaN - data_val = check_value_for_nan(column_name, data_val, data_type, table_name) + data_val = check_value_populated(data_val, data_type, column_name, table_name) return data_val -def check_value_for_nan(column_name, data_val, data_type, table_name): - if data_type == np.ndarray and len(data_val) == 0: - print( - "WARNING: {} unpopulated in {} table for this object. Storing NaN instead.".format( - column_name, table_name - ) - ) - data_val = np.nan - elif data_type in [float, int] and np.isnan(data_val): +def check_value_populated(data_val, data_type, column_name, table_name): + """Checks to see if data_val populated properly and prints a helpful warning if it didn't. + Usually this will trigger because the RSP hasn't populated that field for this particular object. + + Parameters + ----------- + data_val : str, float, int or nd.array + The value to check. + + data_type: type + Data type. Should be int, float, str or np.ndarray. + + column_name: str + Column name under which the data of interest is stored. + + table_name : str + Name of the table. This is mostly for more informative error messages. Default="default". + + Returns + ----------- + data_val : str, float, int, nd.array or np.nan + Either returns the original data_val or an np.nan if it detected that the value was not populated. + + """ + + array_length_zero = data_type == np.ndarray and len(data_val) == 0 + number_is_nan = data_type in [float, int] and np.isnan(data_val) + str_is_empty = data_type == str and len(data_val) == 0 + + if array_length_zero or number_is_nan or str_is_empty: print( "WARNING: {} unpopulated in {} table for this object. Storing NaN instead.".format( column_name, table_name diff --git a/src/adler/utilities/AdlerCLIArguments.py b/src/adler/utilities/AdlerCLIArguments.py index 0f65a8b..7a30066 100644 --- a/src/adler/utilities/AdlerCLIArguments.py +++ b/src/adler/utilities/AdlerCLIArguments.py @@ -1,6 +1,16 @@ class AdlerCLIArguments: + """ + Class for storing abd validating Adler command-line arguments. + + Attributes: + ----------- + args : argparse.Namespace object + argparse.Namespace object created by calling parse_args(). + + """ + def __init__(self, args): - self.ssObjectId = args.ssoid + self.ssObjectId = args.ssObjectId self.filter_list = args.filter_list self.date_range = args.date_range @@ -16,7 +26,7 @@ def _validate_filter_list(self): if not set(self.filter_list).issubset(expected_filters): raise ValueError( - "Unexpected filters found in filter_list command-line argument. filter_list must be a comma-separated list of LSST filters." + "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." ) def _validate_ssObjectId(self): @@ -26,5 +36,15 @@ def _validate_ssObjectId(self): raise ValueError("ssoid command-line argument does not appear to be a valid ssObjectId.") def _validate_date_range(self): - if len(self.date_range) != 2: - raise ValueError("date_range command-line argument must be of length 2.") + for d in self.date_range: + try: + float(d) + except ValueError: + raise ValueError( + "One or both of the values for the date_range command-line argument do not seem to be valid numbers." + ) + + if any(d > 250000 for d in self.date_range): + raise ValueError( + "Dates for date_range command-line argument seem rather large. Did you input JD instead of MJD?" + ) diff --git a/tests/adler/dataclasses/test_MPCORB.py b/tests/adler/dataclasses/test_MPCORB.py index f94db75..d139d00 100644 --- a/tests/adler/dataclasses/test_MPCORB.py +++ b/tests/adler/dataclasses/test_MPCORB.py @@ -36,5 +36,5 @@ def test_construct_MPCORB_from_data_table(): assert_almost_equal(test_MPCORB.e, 0.7168805704972735, decimal=6) assert np.isnan(test_MPCORB.n) assert_almost_equal(test_MPCORB.q, 0.5898291078470536, decimal=6) - assert test_MPCORB.uncertaintyParameter == "" + assert np.isnan(test_MPCORB.uncertaintyParameter) assert test_MPCORB.flags == "0" diff --git a/tests/adler/dataclasses/test_dataclass_utilities.py b/tests/adler/dataclasses/test_dataclass_utilities.py index 6fca494..b087e0e 100644 --- a/tests/adler/dataclasses/test_dataclass_utilities.py +++ b/tests/adler/dataclasses/test_dataclass_utilities.py @@ -6,6 +6,7 @@ from adler.dataclasses.dataclass_utilities import get_data_table from adler.dataclasses.dataclass_utilities import get_from_table +from adler.dataclasses.dataclass_utilities import check_value_populated from adler.utilities.tests_utilities import get_test_data_filepath @@ -55,3 +56,16 @@ def test_get_from_table(): error_info_2.value.args[0] == "Type for argument data_type not recognised for column string_col in table default: must be str, float, int or np.ndarray." ) + + +def test_check_value_populated(): + populated_value = check_value_populated(3, int, "column", "table") + assert populated_value == 3 + + array_length_zero = check_value_populated(np.array([]), np.ndarray, "column", "table") + number_is_nan = check_value_populated(np.nan, float, "column", "table") + str_is_empty = check_value_populated("", str, "column", "table") + + assert np.isnan(array_length_zero) + assert np.isnan(number_is_nan) + assert np.isnan(str_is_empty) diff --git a/tests/adler/utilities/test_AdlerCLIArguments.py b/tests/adler/utilities/test_AdlerCLIArguments.py new file mode 100644 index 0000000..90c4e5f --- /dev/null +++ b/tests/adler/utilities/test_AdlerCLIArguments.py @@ -0,0 +1,73 @@ +import pytest +from adler.utilities.AdlerCLIArguments import AdlerCLIArguments + + +# AdlerCLIArguments object takes an object as input, so we define a quick one here +class args: + def __init__(self, ssObjectId, filter_list, date_range): + self.ssObjectId = ssObjectId + self.filter_list = filter_list + self.date_range = date_range + + +def test_AdlerCLIArguments(): + # test correct population + good_input_dict = {"ssObjectId": "666", "filter_list": ["g", "r", "i"], "date_range": [60000.0, 67300.0]} + good_arguments = args(**good_input_dict) + good_arguments_object = AdlerCLIArguments(good_arguments) + + assert good_arguments_object.__dict__ == good_input_dict + + # test that a bad ssObjectId triggers the right error + bad_ssoid_arguments = args("hello!", ["g", "r", "i"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_ssoid_error: + bad_ssoid_object = AdlerCLIArguments(bad_ssoid_arguments) + + assert ( + bad_ssoid_error.value.args[0] + == "ssoid command-line argument does not appear to be a valid ssObjectId." + ) + + # test that non-LSST or unexpected filters trigger the right error + bad_filter_arguments = args("666", ["g", "r", "i", "m"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_filter_error: + bad_filter_object = AdlerCLIArguments(bad_filter_arguments) + + assert ( + bad_filter_error.value.args[0] + == "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." + ) + + bad_filter_arguments_2 = args("666", ["pony"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_filter_error_2: + bad_filter_object = AdlerCLIArguments(bad_filter_arguments_2) + + assert ( + bad_filter_error_2.value.args[0] + == "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." + ) + + # test that overly-large dates trigger the right error + big_date_arguments = args("666", ["g", "r", "i"], [260000.0, 267300.0]) + + with pytest.raises(ValueError) as big_date_error: + big_date_object = AdlerCLIArguments(big_date_arguments) + + assert ( + big_date_error.value.args[0] + == "Dates for date_range command-line argument seem rather large. Did you input JD instead of MJD?" + ) + + # test that unexpected date values trigger the right error + bad_date_arguments = args("666", ["g", "r", "i"], [260000.0, "cheese"]) + + with pytest.raises(ValueError) as bad_date_error: + bad_date_object = AdlerCLIArguments(bad_date_arguments) + + assert ( + bad_date_error.value.args[0] + == "One or both of the values for the date_range command-line argument do not seem to be valid numbers." + )