Skip to content

Commit

Permalink
Change the style of the plots
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed May 7, 2024
1 parent d84d7a2 commit 3b251fd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 4 additions & 3 deletions workflows/projects/Mixtures/distinct_profiles.smk
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ rule plot_samples:
output:
"figure_distinct_profiles.pdf"
run:
fig, axs = plt.subplots(1, 4, figsize=(7, 2))
fig, axs = plt.subplots(1, 4, figsize=(7, 1.5), dpi=500)

color1 = "navy"
color2 = "salmon"
color1 = "mediumblue"
color2 = "forestgreen"

# Plot normal distribution
ax = axs[0]
Expand Down Expand Up @@ -163,6 +163,7 @@ rule plot_samples:
ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture")
ax.set_title("PMI profiles")
ax.set_xlabel("PMI")
ax.set_xlim(-1, 2)
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)
Expand Down
8 changes: 6 additions & 2 deletions workflows/projects/Mixtures/figure_bmm_vs_other.smk
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ rule generate_figure:
data_v2 = pd.read_csv(input.v2)
data_bmm = pd.read_csv(input.bmm)

fig, axs = subplots_from_axsize(1, len(TASKS), (2.3, 0.8), left=0.8, right=0.05, top=0.3, bottom=0.3, dpi=350, wspace=0.0)
fig, axs = subplots_from_axsize(1, len(TASKS), (2.3, 0.8), left=0.8, right=0.05, top=0.3, bottom=0.3, dpi=350, wspace=0.05)

y_scaler = YScaler(estimator_ids=["BMM"] + [config.id for config in POINT_ESTIMATORS], eps=0.12)

Expand All @@ -119,6 +119,7 @@ rule generate_figure:
ax.set_xlim(*task_config.xlim)
ax.set_xticks(task_config.xticks)
ax.set_yticks([])
ax.spines[["top", "left", "right"]].set_visible(False)

mi_true = data_v2.groupby("task_id")["mi_true"].mean()[task_id]
ax.axvline(mi_true, linestyle="--", color="black", linewidth=2)
Expand All @@ -141,5 +142,8 @@ rule generate_figure:
y = y_scaler.get_y(estimator_id=estimator_id, n_points=len(estimates))
ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE)

axs[0].set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS])
ax = axs[0]
ax.set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS])
ax.spines["left"].set_visible(True)

fig.savefig(str(output))

0 comments on commit 3b251fd

Please sign in to comment.