diff --git a/mdtk/degradations.py b/mdtk/degradations.py index 0e6b279..cf306dc 100644 --- a/mdtk/degradations.py +++ b/mdtk/degradations.py @@ -104,8 +104,18 @@ def pre_process(df, sort=False): ------- df : pd.DataFrame The postprocessed dataframe. + + Raises + ------ + ValueError + If the given df does not have all of the necessary columns. """ - df = df.loc[:, NOTE_DF_SORT_ORDER] + try: + df = df.loc[:, NOTE_DF_SORT_ORDER] + except KeyError: # df has incorrect columns + raise ValueError( + f"Input note_df must have all of the columns: {NOTE_DF_SORT_ORDER}" + ) if sort: df = df.sort_values(NOTE_DF_SORT_ORDER) df = df.reset_index(drop=True) diff --git a/mdtk/tests/test_degradations.py b/mdtk/tests/test_degradations.py index ca0273c..0043d3b 100644 --- a/mdtk/tests/test_degradations.py +++ b/mdtk/tests/test_degradations.py @@ -90,6 +90,11 @@ def test_pre_process(): f"instead of \n{float_res}" ) + # Check not correct columns raises ValueError + invalid_df = pd.DataFrame({"track": [0, 1], "onset": [0, 50], "pitch": [10, 20]}) + with pytest.raises(ValueError): + deg.pre_process(invalid_df) + def test_post_process(): basic_res = pd.DataFrame(