Skip to content

Commit

Permalink
updates for new tmd constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Dec 22, 2024
1 parent 53d8aee commit 55c867c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
8 changes: 4 additions & 4 deletions ccc/get_taxcalc_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_calculator(
"""
# create a calculator
policy1 = Policy()
if data is not None and "cps" in data:
if data is not None and "cps" in str(data):
print("Using CPS")
records1 = Records.cps_constructor()
# impute short and long term capital gains if using CPS data
Expand All @@ -48,12 +48,12 @@ def get_calculator(
records1.p23250 = (1 - 0.06587) * records1.e01100
# set total capital gains to zero
records1.e01100 = np.zeros(records1.e01100.shape[0])
elif data is None or "puf" in data: # pragma: no cover
elif data is None or "puf" in str(data): # pragma: no cover
print("Using PUF")
records1 = Records()
elif data is not None and "tmd" in data: # pragma: no cover
elif data is not None and "tmd" in str(data): # pragma: no cover
print("Using TMD")
records1 = Records.tmd_constructor("tmd.csv.gz")
records1 = Records.tmd_constructor(data, weights, gfactors)
elif data is not None: # pragma: no cover
print("Data is ", data)
print("Weights are ", weights)
Expand Down
19 changes: 16 additions & 3 deletions ccc/tests/test_get_taxcalc_rates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import numpy as np
import pytest
import os
from pathlib import Path
from ccc import get_taxcalc_rates as tc
from ccc.parameters import Specification
from ccc.utils import TC_LAST_YEAR


CUR_DIR = os.path.abspath(os.path.dirname(__file__))


@pytest.mark.parametrize(
"reform",
[(None), ({"FICA_ss_trt_employee": {2018: 0.0625}})],
Expand Down Expand Up @@ -47,11 +52,17 @@ def test_get_calculator_puf(data):

@pytest.mark.needs_tmd
@pytest.mark.parametrize(
"data",
[("tmd.csv")],
"data,weights,growfactors",
[
(
Path(os.path.join(CUR_DIR, "tmd.csv")),
Path(os.path.join(CUR_DIR, "tmd_weights.csv.gz")),
Path(os.path.join(CUR_DIR, "tmd_growfactors.csv")),
)
],
ids=["baseline,data=TMD"],
)
def test_get_calculator_tmd(data):
def test_get_calculator_tmd(data, weights, growfactors):
"""
Test the get_calculator() function
"""
Expand All @@ -60,6 +71,8 @@ def test_get_calculator_tmd(data):
baseline_policy={"FICA_ss_trt_employee": {2021: 0.075}},
reform={"FICA_ss_trt_employee": {2022: 0.0625}},
data=data,
weights=weights,
gfactors=growfactors,
)
assert calc1.current_year == 2021

Expand Down

0 comments on commit 55c867c

Please sign in to comment.