Skip to content

Commit

Permalink
adds figure to nyc,llama focus outlier figure
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Jan 23, 2025
1 parent 8dcec55 commit 1cc919f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
15 changes: 12 additions & 3 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -2189,6 +2189,7 @@ rule mcmc_figure:
inf_data['subject_name'] = 'all subjects'
kwargs['palette'] = pal
kwargs['hue_kws'] = {'zorder': zorder}
kwargs["partial_legend"] = "hue"
elif 'focus-subject' in wildcards.plot_type:
col = None
if 'one-ax' in wildcards.plot_type:
Expand Down Expand Up @@ -2223,8 +2224,16 @@ rule mcmc_figure:
raise Exception(f"Don't know how to handle plot type {wildcards.plot_type}!")
if 'focus-outlier' in wildcards.plot_type or 'one-ax' in wildcards.plot_type:
# don't need the legend here, it's not doing much
warnings.warn("Removing legend, because it's not doing much.")
fig.legends[0].remove()
if 'focus-outlier-legend' in wildcards.plot_type:
# in this plot, we end up with three images labeled: the
# outliers plus gnarled. want to replace the text of the gnarled
# legend with "other images"
gnarled_text = [txt for txt in fig.legends[0].texts
if txt.get_text() == "gnarled"]
gnarled_text[0].set_text("other images")
else:
warnings.warn("Removing legend, because it's not doing much.")
fig.legends[0].remove()
if 'V1' in wildcards.model_name:
warnings.warn("Removing ylabel so we don't have redundant labels when composing figure")
fig.axes[0].set_ylabel('')
Expand Down Expand Up @@ -3957,7 +3966,7 @@ def get_compose_figures_input(wildcards):
paths = [path_template.format(wildcards.fig_name.replace('performance_', '').replace('scaling-extended_', ''))]
if 'performance' in wildcards.fig_name:
if 'scaling-extended' in wildcards.fig_name:
paths.append(path_template.format('V1_norm_s6_gaussian/task-split_comp-ref_mcmc_scaling-extended_partially-pooled_performance_focus-outlier'))
paths.append(path_template.format('V1_norm_s6_gaussian/task-split_comp-ref_mcmc_scaling-extended_partially-pooled_performance_focus-outlier-legend'))
else:
paths.append(path_template.format('V1_norm_s6_gaussian/task-split_comp-ref_mcmc_partially-pooled_performance_focus-outlier'))
if 'all_comps_summary' in wildcards.fig_name:
Expand Down
11 changes: 10 additions & 1 deletion foveated_metamers/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,8 @@ def posterior_predictive_check(inf_data, col=None, row=None, hue=None,
style=None, col_wrap=5, comparison='ref',
logscale_xaxis=False, hdi=.95, query_str=None,
tabular_trial_type_legend=False,
increase_size=True, markersize=None, **kwargs):
increase_size=True, markersize=None,
partial_legend=False, **kwargs):
"""Plot posterior predictive check.
In order to make sure that our MCMC gave us a reasonable fit, we plot the
Expand Down Expand Up @@ -945,6 +946,10 @@ def posterior_predictive_check(inf_data, col=None, row=None, hue=None,
pointplot. Else, use lines.linewidth.
markersize : float or None, optional
size of data points.
partial_legend: False, "hue", or "style", optional
By default (partial_legend=False), the legend we create reflects hue and style,
if not None. By setting partial_legend="hue", we only include hue in the legend,
and correspondingly for "style".
kwargs :
passed to sns.FacetGrid
Expand Down Expand Up @@ -1035,6 +1040,10 @@ def posterior_predictive_check(inf_data, col=None, row=None, hue=None,
kwargs.get('height', 5),
col)
# create the legend
if partial_legend == "hue":
style = None
elif partial_legend == "style":
hue = None
plotting._add_legend(df, g.fig, hue, style,
kwargs.get('palette', {}),
final_markers, dashes_dict,
Expand Down

0 comments on commit 1cc919f

Please sign in to comment.