Skip to content

Commit

Permalink
make tests consistent with new demog plots
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Jan 30, 2024
1 parent e099063 commit 46fc09a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 64 deletions.
38 changes: 19 additions & 19 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def plot_imm_rates(
if include_title:
plt.title("Immigration Rates")
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "imm_rates")
if path:
output_path = os.path.join(path, "imm_rates")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -335,8 +335,8 @@ def plot_fert_rates(
)
plt.tight_layout(rect=(0, 0.035, 1, 1))
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "fert_rates")
if path:
output_path = os.path.join(path, "fert_rates")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -386,8 +386,8 @@ def plot_mort_rates_data(
)
plt.tight_layout(rect=(0, 0.035, 1, 1))
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "mort_rates")
if path:
output_path = os.path.join(path, "mort_rates")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -431,8 +431,8 @@ def plot_omega_fixed(
plt.xlim((0, E + S + 1))
plt.legend(loc="upper right")
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "OrigVsFixSSpop")
if path:
output_path = os.path.join(path, "OrigVsFixSSpop")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -472,8 +472,8 @@ def plot_imm_fixed(
plt.xlim((0, E + S + 1))
plt.legend(loc="upper center")
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "OrigVsAdjImm")
if path:
output_path = os.path.join(path, "OrigVsAdjImm")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -538,8 +538,8 @@ def plot_population_path(
plt.ylabel(r"Pop. dist'n $\omega_{s}$")
plt.legend(loc="lower left")
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "PopDistPath")
if path:
output_path = os.path.join(path, "PopDistPath")
plt.savefig(output_path, dpi=300)
plt.close()
else:
Expand Down Expand Up @@ -861,23 +861,23 @@ def plot_income_data(
J = abil_midp.shape[0]
abil_mesh, age_mesh = np.meshgrid(abil_midp, ages)
cmap1 = matplotlib.cm.get_cmap("summer")
if output_dir:
if path:
# Make sure that directory is created
utils.mkdirs(output_dir)
utils.mkdirs(path)
if J == 1:
# Plot of 2D, J=1 in levels
plt.figure()
plt.plot(ages, emat[t, :, :])
filename = "ability_2D_lev" + filesuffix
fullpath = os.path.join(output_dir, filename)
fullpath = os.path.join(path, filename)
plt.savefig(fullpath, dpi=300)
plt.close()

# Plot of 2D, J=1 in logs
plt.figure()
plt.plot(ages, np.log(emat[t, :, :]))
filename = "ability_2D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
fullpath = os.path.join(path, filename)
plt.savefig(fullpath, dpi=300)
plt.close()
else:
Expand All @@ -895,7 +895,7 @@ def plot_income_data(
ax10.set_ylabel(r"ability type -$j$")
ax10.set_zlabel(r"ability $e_{j,s}$")
filename = "ability_3D_lev" + filesuffix
fullpath = os.path.join(output_dir, filename)
fullpath = os.path.join(path, filename)
plt.savefig(fullpath, dpi=300)
plt.close()

Expand All @@ -913,7 +913,7 @@ def plot_income_data(
ax11.set_ylabel(r"ability type -$j$")
ax11.set_zlabel(r"log ability $log(e_{j,s})$")
filename = "ability_3D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
fullpath = os.path.join(path, filename)
plt.savefig(fullpath, dpi=300)
plt.close()

Expand Down Expand Up @@ -961,7 +961,7 @@ def plot_income_data(
ax.set_xlabel(r"age-$s$")
ax.set_ylabel(r"log ability $log(e_{j,s})$")
filename = "ability_2D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
fullpath = os.path.join(path, filename)
plt.savefig(fullpath, dpi=300)
plt.close()
else:
Expand Down
62 changes: 17 additions & 45 deletions tests/test_parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@


def test_plot_imm_rates():
fig = parameter_plots.plot_imm_rates(base_params, include_title=True)
fig = parameter_plots.plot_imm_rates(base_params.imm_rates, base_params.start_year, [base_params.start_year], include_title=True)
assert fig


def test_plot_imm_rates_save_fig(tmpdir):
parameter_plots.plot_imm_rates(base_params, path=tmpdir)
img = mpimg.imread(os.path.join(tmpdir, "imm_rates_orig.png"))
parameter_plots.plot_imm_rates(base_params.imm_rates, base_params.start_year, [base_params.start_year], path=tmpdir)
img = mpimg.imread(os.path.join(tmpdir, "imm_rates.png"))

assert isinstance(img, np.ndarray)

Expand Down Expand Up @@ -173,9 +173,9 @@ def test_plot_fert_rates():
)
age_midp = np.array([9, 10, 12, 16, 18.5, 22, 27, 32, 37, 42, 47, 55, 56])
fert_func = si.interp1d(age_midp, fert_data, kind="cubic")
fert_rates = np.random.uniform(size=totpers)
fert_rates = np.random.uniform(size=totpers).reshape((1, totpers))
fig = parameter_plots.plot_fert_rates(
fert_func, age_midp, totpers, min_yr, max_yr, fert_data, fert_rates
fert_rates
)
assert fig

Expand Down Expand Up @@ -206,16 +206,10 @@ def test_plot_fert_rates_save_fig(tmpdir):
)
age_midp = np.array([9, 10, 12, 16, 18.5, 22, 27, 32, 37, 42, 47, 55, 56])
fert_func = si.interp1d(age_midp, fert_data, kind="cubic")
fert_rates = np.random.uniform(size=totpers)
fert_rates = np.random.uniform(size=totpers).reshape((1, totpers))
parameter_plots.plot_fert_rates(
fert_func,
age_midp,
totpers,
min_yr,
max_yr,
fert_data,
fert_rates,
output_dir=tmpdir,
path=tmpdir,
)
img = mpimg.imread(os.path.join(tmpdir, "fert_rates.png"))

Expand All @@ -224,42 +218,20 @@ def test_plot_fert_rates_save_fig(tmpdir):

def test_plot_mort_rates_data():
totpers = base_params.S - 1
min_yr = 21
max_yr = 100
age_year_all = np.arange(min_yr, max_yr)
mort_rates = base_params.rho[-1, 1:].squeeze()
mort_rates_all = base_params.rho[-1, 1:].squeeze()
infmort_rate = base_params.rho[0, 0]
mort_rates = base_params.rho[-1, 1:].reshape((1, totpers))
fig = parameter_plots.plot_mort_rates_data(
totpers,
min_yr,
max_yr,
age_year_all,
mort_rates_all,
infmort_rate,
mort_rates,
output_dir=None,
path=None,
)
assert fig


def test_plot_mort_rates_data_save_fig(tmpdir):
totpers = base_params.S - 1
min_yr = 21
max_yr = 100
age_year_all = np.arange(min_yr, max_yr)
mort_rates = base_params.rho[-1, 1:].squeeze()
mort_rates_all = base_params.rho[-1, 1:].squeeze()
infmort_rate = base_params.rho[0, 0]
mort_rates = base_params.rho[-1, 1:].reshape((1, totpers))
parameter_plots.plot_mort_rates_data(
totpers,
min_yr,
max_yr,
age_year_all,
mort_rates_all,
infmort_rate,
mort_rates,
output_dir=tmpdir,
path=tmpdir,
)
img = mpimg.imread(os.path.join(tmpdir, "mort_rates.png"))

Expand All @@ -285,7 +257,7 @@ def test_plot_omega_fixed_save_fig(tmpdir):
omega_SS_orig = base_params.omega_SS
omega_SSfx = base_params.omega_SS
parameter_plots.plot_omega_fixed(
age_per_EpS, omega_SS_orig, omega_SSfx, E, S, output_dir=tmpdir
age_per_EpS, omega_SS_orig, omega_SSfx, E, S, path=tmpdir
)
img = mpimg.imread(os.path.join(tmpdir, "OrigVsFixSSpop.png"))

Expand All @@ -311,7 +283,7 @@ def test_plot_imm_fixed_save_fig(tmpdir):
imm_rates_orig = base_params.imm_rates[0, :]
imm_rates_adj = base_params.imm_rates[-1, :]
parameter_plots.plot_imm_fixed(
age_per_EpS, imm_rates_orig, imm_rates_adj, E, S, output_dir=tmpdir
age_per_EpS, imm_rates_orig, imm_rates_adj, E, S, path=tmpdir
)
img = mpimg.imread(os.path.join(tmpdir, "OrigVsAdjImm.png"))

Expand All @@ -322,7 +294,7 @@ def test_plot_population_path():
S = base_params.S
age_per_EpS = np.arange(21, S + 21)
initial_pop_pct = base_params.omega[0, :]
omega_path_lev = base_params.omega.T
omega_path_lev = base_params.omega
omega_SSfx = base_params.omega_SS
data_year = base_params.start_year
curr_year = base_params.start_year
Expand All @@ -343,7 +315,7 @@ def test_plot_population_path_save_fig(tmpdir):
S = base_params.S
age_per_EpS = np.arange(21, S + 21)
pop_2013_pct = base_params.omega[0, :]
omega_path_lev = base_params.omega.T
omega_path_lev = base_params.omega
omega_SSfx = base_params.omega_SS
curr_year = base_params.start_year
parameter_plots.plot_population_path(
Expand All @@ -354,7 +326,7 @@ def test_plot_population_path_save_fig(tmpdir):
curr_year,
E,
S,
output_dir=tmpdir,
path=tmpdir,
)
img = mpimg.imread(os.path.join(tmpdir, "PopDistPath.png"))

Expand Down Expand Up @@ -385,7 +357,7 @@ def test_plot_income_data_save_fig(tmpdir):
abil_pcts = np.array([0.25, 0.25, 0.2, 0.1, 0.1, 0.09, 0.01])
emat = p.e
parameter_plots.plot_income_data(
ages, abil_midp, abil_pcts, emat, output_dir=tmpdir
ages, abil_midp, abil_pcts, emat, path=tmpdir
)
img1 = mpimg.imread(os.path.join(tmpdir, "ability_3D_lev.png"))
img2 = mpimg.imread(os.path.join(tmpdir, "ability_3D_log.png"))
Expand Down

0 comments on commit 46fc09a

Please sign in to comment.