Skip to content

Commit

Permalink
update tests to use 3.12 pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Aug 16, 2024
1 parent 5c1806d commit 8b3e80c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 26 deletions.
38 changes: 24 additions & 14 deletions tests/test_output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@
base_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
)
else:
elif sys.version_info[1] == 11:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
)
)
else:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
)
)
reform_ss = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
)
Expand All @@ -38,10 +44,14 @@
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_reform.pkl")
)
else:
elif sys.version_info[1] == 11:
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v311.pkl")
)
else:
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl")
)
reform_taxfunctions = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_reform.pkl")
)
Expand Down Expand Up @@ -147,7 +157,7 @@ def test_plot_aggregates(
plot_type=plot_type,
stationarized=stationarized,
num_years_to_plot=20,
start_year=2023,
start_year=int(base_params.start_year),
forecast_data=np.ones(20),
forecast_units="ones",
vertical_line_years=vertical_line_years,
Expand Down Expand Up @@ -198,7 +208,7 @@ def test_plot_industry_aggregates(
var_list=["Y_vec"],
plot_type=plot_type,
num_years_to_plot=20,
start_year=2023,
start_year=int(base_params.start_year),
forecast_data=np.ones(20),
forecast_units="ones",
vertical_line_years=vertical_line_years,
Expand All @@ -218,7 +228,7 @@ def test_plot_industry_aggregates(
def test_plot_aggregates_save_fig(tmpdir):
path = os.path.join(tmpdir, "test_plot.png")
output_plots.plot_aggregates(
base_tpi, base_params, start_year=2023, plot_type="levels", path=path
base_tpi, base_params, start_year=int(base_params.start_year), plot_type="levels", path=path
)
img = mpimg.imread(path)

Expand All @@ -228,7 +238,7 @@ def test_plot_aggregates_save_fig(tmpdir):
def test_plot_aggregates_not_a_type(tmpdir):
with pytest.raises(AssertionError):
output_plots.plot_aggregates(
base_tpi, base_params, start_year=2023, plot_type="levels2"
base_tpi, base_params, start_year=int(base_params.start_year), plot_type="levels2"
)


Expand Down Expand Up @@ -275,7 +285,7 @@ def test_plot_gdp_ratio(
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
start_year=2023,
start_year=int(base_params.start_year),
plot_type=plot_type,
vertical_line_years=vertical_line_years,
plot_title=plot_title,
Expand All @@ -289,7 +299,7 @@ def test_plot_gdp_ratio_save_fig(tmpdir):
base_tpi,
base_params,
reform_tpi=reform_tpi,
start_year=2023,
start_year=int(base_params.start_year),
reform_params=reform_params,
path=path,
)
Expand All @@ -304,7 +314,7 @@ def test_ability_bar():
base_params,
reform_tpi,
reform_params,
start_year=2023,
start_year=int(base_params.start_year),
plot_title=" Test Plot Title",
)
assert fig
Expand All @@ -317,7 +327,7 @@ def test_ability_bar_save_fig(tmpdir):
base_params,
reform_tpi,
reform_params,
start_year=2023,
start_year=int(base_params.start_year),
path=path,
)
img = mpimg.imread(path)
Expand Down Expand Up @@ -374,7 +384,7 @@ def test_tpi_profiles(by_j):
base_params,
reform_tpi,
reform_params,
start_year=2023,
start_year=int(base_params.start_year),
by_j=by_j,
plot_title=" Test Plot Title",
)
Expand Down Expand Up @@ -404,7 +414,7 @@ def test_tpi_profiles_save_fig(tmpdir):
base_params,
reform_tpi,
reform_params,
start_year=2023,
start_year=int(base_params.start_year),
path=path,
)
img = mpimg.imread(path)
Expand Down Expand Up @@ -515,7 +525,7 @@ def test_inequality_plot(
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
start_year=2023,
start_year=int(base_params.start_year),
ineq_measure=ineq_measure,
pctiles=pctiles,
plot_type=plot_type,
Expand All @@ -530,7 +540,7 @@ def test_inequality_plot_save_fig(tmpdir):
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
start_year=2023,
start_year=int(base_params.start_year),
path=path,
)
img = mpimg.imread(path)
Expand Down
21 changes: 17 additions & 4 deletions tests/test_output_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from ogcore import utils, output_tables


# TODO: for dynamic_revenue_decomposition, need to add 3.12 pickle of results
# for baseline and reform that have M = 1 as the 3.12 pickle file does

# Load in test results and parameters
CUR_PATH = os.path.abspath(os.path.dirname(__file__))
base_ss = utils.safe_read_pickle(
Expand All @@ -22,12 +25,18 @@
base_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
)
else:
elif sys.version_info[1] == 11:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
)
)
else:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
)
)
reform_ss = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
)
Expand All @@ -38,10 +47,14 @@
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_reform.pkl")
)
else:
elif sys.version_info[1] == 11:
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v311.pkl")
)
else:
reform_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl")
)
# add investment tax credit parameter that not in cached parameters
base_params.inv_tax_credit = np.zeros(
(base_params.T + base_params.S, base_params.M)
Expand Down Expand Up @@ -76,7 +89,7 @@ def test_macro_table(
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
start_year=2023,
start_year=int(base_params.start_year),
output_type=output_type,
stationarized=stationarized,
include_SS=True,
Expand Down Expand Up @@ -176,7 +189,7 @@ def test_dynamic_revenue_decomposition(include_business_tax, full_break_out):
reform_params,
reform_tpi,
reform_ss,
start_year=2023,
start_year=int(base_params.start_year),
include_business_tax=include_business_tax,
full_break_out=full_break_out,
)
Expand Down
21 changes: 14 additions & 7 deletions tests/test_parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
base_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
)
else:
elif sys.version_info[1] == 11:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
)
)
else:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
)
)
base_taxfunctions = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_baseline.pkl")
)
Expand All @@ -36,10 +42,11 @@
micro_data = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "micro_data_dict_for_tests.pkl")
)
base_params.rho = np.tile(
base_params.rho.reshape(1, base_params.S),
(base_params.T + base_params.S, 1),
)
if base_params.rho.ndim == 1:
base_params.rho = np.tile(
base_params.rho.reshape(1, base_params.S),
(base_params.T + base_params.S, 1),
)


def test_plot_imm_rates():
Expand Down Expand Up @@ -94,13 +101,13 @@ def test_plot_surv_rates_save_fig(tmpdir):

def test_plot_pop_growth():
fig = parameter_plots.plot_pop_growth(
base_params, start_year=2023, include_title=True
base_params, start_year=int(base_params.start_year), include_title=True
)
assert fig


def test_plot_pop_growth_rates_save_fig(tmpdir):
parameter_plots.plot_pop_growth(base_params, start_year=2023, path=tmpdir)
parameter_plots.plot_pop_growth(base_params, start_year=int(base_params.start_year), path=tmpdir)
img = mpimg.imread(os.path.join(tmpdir, "pop_growth_rates.png"))

assert isinstance(img, np.ndarray)
Expand Down
8 changes: 7 additions & 1 deletion tests/test_parameter_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
base_params = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
)
else:
elif sys.version_info[1] == 11:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
)
)
else:
base_params = utils.safe_read_pickle(
os.path.join(
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
)
)
base_taxfunctions = utils.safe_read_pickle(
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_baseline.pkl")
)
Expand Down

0 comments on commit 8b3e80c

Please sign in to comment.