diff --git a/tests/test_feature_validation.py b/tests/test_feature_validation.py index 89634fa9f..31700e8be 100644 --- a/tests/test_feature_validation.py +++ b/tests/test_feature_validation.py @@ -238,12 +238,28 @@ def test_validate_top_level(test_df): assert list(test_df.exclude) == [False, True] -def test_category_none(test_df): - test_df.categorize("category", "Testing", {"Primary Energy": {"up": 0.8}}) +# include args for deprecated legacy signature +@pytest.mark.parametrize( + "args", + ( + dict(variable="Primary Energy", upper_bound=0), + dict(criteria={"Primary Energy": {"up": 0}}), + ), +) +def test_category_no_match(test_df, args): + test_df.categorize("category", "foo", **args) assert "category" not in test_df.meta.columns -def test_category_pass(test_df): +# include args for deprecated legacy signature +@pytest.mark.parametrize( + "args", + ( + dict(variable="Primary Energy", upper_bound=6), + dict(criteria={"Primary Energy": {"up": 6}}), + ), +) +def test_category_match(test_df, args): dct = { "model": ["model_a", "model_a"], "scenario": ["scen_a", "scen_b"], @@ -251,7 +267,7 @@ def test_category_pass(test_df): } exp = pd.DataFrame(dct).set_index(["model", "scenario"])["category"] - test_df.categorize("category", "foo", {"Primary Energy": {"up": 6, "year": 2010}}) + test_df.categorize("category", "foo", **args) obs = test_df["category"] pd.testing.assert_series_equal(obs, exp)