diff --git a/ccc/get_taxcalc_rates.py b/ccc/get_taxcalc_rates.py index c9de2308..f7ce02d3 100644 --- a/ccc/get_taxcalc_rates.py +++ b/ccc/get_taxcalc_rates.py @@ -39,7 +39,7 @@ def get_calculator( """ # create a calculator policy1 = Policy() - if data is not None and "cps" in data: + if data is not None and "cps" in str(data): print("Using CPS") records1 = Records.cps_constructor() # impute short and long term capital gains if using CPS data @@ -48,12 +48,12 @@ def get_calculator( records1.p23250 = (1 - 0.06587) * records1.e01100 # set total capital gains to zero records1.e01100 = np.zeros(records1.e01100.shape[0]) - elif data is None or "puf" in data: # pragma: no cover + elif data is None or "puf" in str(data): # pragma: no cover print("Using PUF") records1 = Records() - elif data is not None and "tmd" in data: # pragma: no cover + elif data is not None and "tmd" in str(data): # pragma: no cover print("Using TMD") - records1 = Records.tmd_constructor("tmd.csv.gz") + records1 = Records.tmd_constructor(data, weights, gfactors) elif data is not None: # pragma: no cover print("Data is ", data) print("Weights are ", weights) diff --git a/ccc/tests/test_get_taxcalc_rates.py b/ccc/tests/test_get_taxcalc_rates.py index 8019ea98..057d2be3 100644 --- a/ccc/tests/test_get_taxcalc_rates.py +++ b/ccc/tests/test_get_taxcalc_rates.py @@ -1,10 +1,15 @@ import numpy as np import pytest +import os +from pathlib import Path from ccc import get_taxcalc_rates as tc from ccc.parameters import Specification from ccc.utils import TC_LAST_YEAR +CUR_DIR = os.path.abspath(os.path.dirname(__file__)) + + @pytest.mark.parametrize( "reform", [(None), ({"FICA_ss_trt_employee": {2018: 0.0625}})], @@ -47,11 +52,17 @@ def test_get_calculator_puf(data): @pytest.mark.needs_tmd @pytest.mark.parametrize( - "data", - [("tmd.csv")], + "data,weights,growfactors", + [ + ( + Path(os.path.join(CUR_DIR, "tmd.csv")), + Path(os.path.join(CUR_DIR, "tmd_weights.csv.gz")), + Path(os.path.join(CUR_DIR, "tmd_growfactors.csv")), + ) + ], ids=["baseline,data=TMD"], ) -def test_get_calculator_tmd(data): +def test_get_calculator_tmd(data, weights, growfactors): """ Test the get_calculator() function """ @@ -60,6 +71,8 @@ def test_get_calculator_tmd(data): baseline_policy={"FICA_ss_trt_employee": {2021: 0.075}}, reform={"FICA_ss_trt_employee": {2022: 0.0625}}, data=data, + weights=weights, + gfactors=growfactors, ) assert calc1.current_year == 2021