diff --git a/ogcore/parameter_plots.py b/ogcore/parameter_plots.py index 63f6b236c..e295bf32f 100644 --- a/ogcore/parameter_plots.py +++ b/ogcore/parameter_plots.py @@ -54,8 +54,8 @@ def plot_imm_rates( if include_title: plt.title("Immigration Rates") # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "imm_rates") + if path: + output_path = os.path.join(path, "imm_rates") plt.savefig(output_path, dpi=300) plt.close() else: @@ -335,8 +335,8 @@ def plot_fert_rates( ) plt.tight_layout(rect=(0, 0.035, 1, 1)) # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "fert_rates") + if path: + output_path = os.path.join(path, "fert_rates") plt.savefig(output_path, dpi=300) plt.close() else: @@ -386,8 +386,8 @@ def plot_mort_rates_data( ) plt.tight_layout(rect=(0, 0.035, 1, 1)) # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "mort_rates") + if path: + output_path = os.path.join(path, "mort_rates") plt.savefig(output_path, dpi=300) plt.close() else: @@ -431,8 +431,8 @@ def plot_omega_fixed( plt.xlim((0, E + S + 1)) plt.legend(loc="upper right") # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "OrigVsFixSSpop") + if path: + output_path = os.path.join(path, "OrigVsFixSSpop") plt.savefig(output_path, dpi=300) plt.close() else: @@ -472,8 +472,8 @@ def plot_imm_fixed( plt.xlim((0, E + S + 1)) plt.legend(loc="upper center") # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "OrigVsAdjImm") + if path: + output_path = os.path.join(path, "OrigVsAdjImm") plt.savefig(output_path, dpi=300) plt.close() else: @@ -538,8 +538,8 @@ def plot_population_path( plt.ylabel(r"Pop. dist'n $\omega_{s}$") plt.legend(loc="lower left") # Save or return figure - if output_dir: - output_path = os.path.join(output_dir, "PopDistPath") + if path: + output_path = os.path.join(path, "PopDistPath") plt.savefig(output_path, dpi=300) plt.close() else: @@ -861,15 +861,15 @@ def plot_income_data( J = abil_midp.shape[0] abil_mesh, age_mesh = np.meshgrid(abil_midp, ages) cmap1 = matplotlib.cm.get_cmap("summer") - if output_dir: + if path: # Make sure that directory is created - utils.mkdirs(output_dir) + utils.mkdirs(path) if J == 1: # Plot of 2D, J=1 in levels plt.figure() plt.plot(ages, emat[t, :, :]) filename = "ability_2D_lev" + filesuffix - fullpath = os.path.join(output_dir, filename) + fullpath = os.path.join(path, filename) plt.savefig(fullpath, dpi=300) plt.close() @@ -877,7 +877,7 @@ def plot_income_data( plt.figure() plt.plot(ages, np.log(emat[t, :, :])) filename = "ability_2D_log" + filesuffix - fullpath = os.path.join(output_dir, filename) + fullpath = os.path.join(path, filename) plt.savefig(fullpath, dpi=300) plt.close() else: @@ -895,7 +895,7 @@ def plot_income_data( ax10.set_ylabel(r"ability type -$j$") ax10.set_zlabel(r"ability $e_{j,s}$") filename = "ability_3D_lev" + filesuffix - fullpath = os.path.join(output_dir, filename) + fullpath = os.path.join(path, filename) plt.savefig(fullpath, dpi=300) plt.close() @@ -913,7 +913,7 @@ def plot_income_data( ax11.set_ylabel(r"ability type -$j$") ax11.set_zlabel(r"log ability $log(e_{j,s})$") filename = "ability_3D_log" + filesuffix - fullpath = os.path.join(output_dir, filename) + fullpath = os.path.join(path, filename) plt.savefig(fullpath, dpi=300) plt.close() @@ -961,7 +961,7 @@ def plot_income_data( ax.set_xlabel(r"age-$s$") ax.set_ylabel(r"log ability $log(e_{j,s})$") filename = "ability_2D_log" + filesuffix - fullpath = os.path.join(output_dir, filename) + fullpath = os.path.join(path, filename) plt.savefig(fullpath, dpi=300) plt.close() else: diff --git a/tests/test_parameter_plots.py b/tests/test_parameter_plots.py index 3f9defbb8..88fc10470 100644 --- a/tests/test_parameter_plots.py +++ b/tests/test_parameter_plots.py @@ -44,13 +44,13 @@ def test_plot_imm_rates(): - fig = parameter_plots.plot_imm_rates(base_params, include_title=True) + fig = parameter_plots.plot_imm_rates(base_params.imm_rates, base_params.start_year, [base_params.start_year], include_title=True) assert fig def test_plot_imm_rates_save_fig(tmpdir): - parameter_plots.plot_imm_rates(base_params, path=tmpdir) - img = mpimg.imread(os.path.join(tmpdir, "imm_rates_orig.png")) + parameter_plots.plot_imm_rates(base_params.imm_rates, base_params.start_year, [base_params.start_year], path=tmpdir) + img = mpimg.imread(os.path.join(tmpdir, "imm_rates.png")) assert isinstance(img, np.ndarray) @@ -173,9 +173,9 @@ def test_plot_fert_rates(): ) age_midp = np.array([9, 10, 12, 16, 18.5, 22, 27, 32, 37, 42, 47, 55, 56]) fert_func = si.interp1d(age_midp, fert_data, kind="cubic") - fert_rates = np.random.uniform(size=totpers) + fert_rates = np.random.uniform(size=totpers).reshape((1, totpers)) fig = parameter_plots.plot_fert_rates( - fert_func, age_midp, totpers, min_yr, max_yr, fert_data, fert_rates + fert_rates ) assert fig @@ -206,16 +206,10 @@ def test_plot_fert_rates_save_fig(tmpdir): ) age_midp = np.array([9, 10, 12, 16, 18.5, 22, 27, 32, 37, 42, 47, 55, 56]) fert_func = si.interp1d(age_midp, fert_data, kind="cubic") - fert_rates = np.random.uniform(size=totpers) + fert_rates = np.random.uniform(size=totpers).reshape((1, totpers)) parameter_plots.plot_fert_rates( - fert_func, - age_midp, - totpers, - min_yr, - max_yr, - fert_data, fert_rates, - output_dir=tmpdir, + path=tmpdir, ) img = mpimg.imread(os.path.join(tmpdir, "fert_rates.png")) @@ -224,42 +218,20 @@ def test_plot_fert_rates_save_fig(tmpdir): def test_plot_mort_rates_data(): totpers = base_params.S - 1 - min_yr = 21 - max_yr = 100 - age_year_all = np.arange(min_yr, max_yr) - mort_rates = base_params.rho[-1, 1:].squeeze() - mort_rates_all = base_params.rho[-1, 1:].squeeze() - infmort_rate = base_params.rho[0, 0] + mort_rates = base_params.rho[-1, 1:].reshape((1, totpers)) fig = parameter_plots.plot_mort_rates_data( - totpers, - min_yr, - max_yr, - age_year_all, - mort_rates_all, - infmort_rate, mort_rates, - output_dir=None, + path=None, ) assert fig def test_plot_mort_rates_data_save_fig(tmpdir): totpers = base_params.S - 1 - min_yr = 21 - max_yr = 100 - age_year_all = np.arange(min_yr, max_yr) - mort_rates = base_params.rho[-1, 1:].squeeze() - mort_rates_all = base_params.rho[-1, 1:].squeeze() - infmort_rate = base_params.rho[0, 0] + mort_rates = base_params.rho[-1, 1:].reshape((1, totpers)) parameter_plots.plot_mort_rates_data( - totpers, - min_yr, - max_yr, - age_year_all, - mort_rates_all, - infmort_rate, mort_rates, - output_dir=tmpdir, + path=tmpdir, ) img = mpimg.imread(os.path.join(tmpdir, "mort_rates.png")) @@ -285,7 +257,7 @@ def test_plot_omega_fixed_save_fig(tmpdir): omega_SS_orig = base_params.omega_SS omega_SSfx = base_params.omega_SS parameter_plots.plot_omega_fixed( - age_per_EpS, omega_SS_orig, omega_SSfx, E, S, output_dir=tmpdir + age_per_EpS, omega_SS_orig, omega_SSfx, E, S, path=tmpdir ) img = mpimg.imread(os.path.join(tmpdir, "OrigVsFixSSpop.png")) @@ -311,7 +283,7 @@ def test_plot_imm_fixed_save_fig(tmpdir): imm_rates_orig = base_params.imm_rates[0, :] imm_rates_adj = base_params.imm_rates[-1, :] parameter_plots.plot_imm_fixed( - age_per_EpS, imm_rates_orig, imm_rates_adj, E, S, output_dir=tmpdir + age_per_EpS, imm_rates_orig, imm_rates_adj, E, S, path=tmpdir ) img = mpimg.imread(os.path.join(tmpdir, "OrigVsAdjImm.png")) @@ -322,7 +294,7 @@ def test_plot_population_path(): S = base_params.S age_per_EpS = np.arange(21, S + 21) initial_pop_pct = base_params.omega[0, :] - omega_path_lev = base_params.omega.T + omega_path_lev = base_params.omega omega_SSfx = base_params.omega_SS data_year = base_params.start_year curr_year = base_params.start_year @@ -343,7 +315,7 @@ def test_plot_population_path_save_fig(tmpdir): S = base_params.S age_per_EpS = np.arange(21, S + 21) pop_2013_pct = base_params.omega[0, :] - omega_path_lev = base_params.omega.T + omega_path_lev = base_params.omega omega_SSfx = base_params.omega_SS curr_year = base_params.start_year parameter_plots.plot_population_path( @@ -354,7 +326,7 @@ def test_plot_population_path_save_fig(tmpdir): curr_year, E, S, - output_dir=tmpdir, + path=tmpdir, ) img = mpimg.imread(os.path.join(tmpdir, "PopDistPath.png")) @@ -385,7 +357,7 @@ def test_plot_income_data_save_fig(tmpdir): abil_pcts = np.array([0.25, 0.25, 0.2, 0.1, 0.1, 0.09, 0.01]) emat = p.e parameter_plots.plot_income_data( - ages, abil_midp, abil_pcts, emat, output_dir=tmpdir + ages, abil_midp, abil_pcts, emat, path=tmpdir ) img1 = mpimg.imread(os.path.join(tmpdir, "ability_3D_lev.png")) img2 = mpimg.imread(os.path.join(tmpdir, "ability_3D_log.png"))