From 415491f4d933c417a9f4a18547876c4770da63fa Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Fri, 8 Nov 2024 13:02:36 -0500 Subject: [PATCH] fixes zorder --- foveated_metamers/create_metamers.py | 20 +++++++++++++++++++- foveated_metamers/plotting.py | 3 ++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/foveated_metamers/create_metamers.py b/foveated_metamers/create_metamers.py index 812dddf..8ff4f3e 100644 --- a/foveated_metamers/create_metamers.py +++ b/foveated_metamers/create_metamers.py @@ -124,7 +124,7 @@ def find_figsizes(model_name, model, image_shape): images by in the plots we'll create """ - if model_name.startswith('RGC'): + if model_name.startswith('RGC') or model_name.startswith('Moments'): animate_figsize = ((3+(image_shape[-1] / image_shape[-2])) * 5 + 2, 5.5) # these values were selected at 72 dpi, so will need to be adjusted if # ours is different @@ -298,6 +298,24 @@ def setup_model(model_name, scaling, image, min_ecc, max_ecc, cache_dir, normali cache_dir=cache_dir, std_dev=std_dev, normalize_dict=normalize_dict) + elif model_name.startswith('Moments'): + if 'norm' not in model_name: + if normalize_dict: + raise Exception("Cannot normalize Moments model (must be Moments-#_norm)!") + normalize_dict = {} + if not normalize_dict and 'norm' in model_name: + raise Exception("If model_name is Moments-#_norm, normalize_dict must be set!") + moments = int(re.findall('Moments-([0-9]+)_', model_name)[0]) + moments = list(range(2, moments+1)) + model = pop.PooledMoments(scaling, image.shape[-2:], + min_eccentricity=min_ecc, + max_eccentricity=max_ecc, + window_type=window_type, + transition_region_width=t_width, + cache_dir=cache_dir, + std_dev=std_dev, + normalize_dict=normalize_dict, + moments=moments) elif model_name.startswith('V1'): if 'norm' not in model_name: if normalize_dict: diff --git a/foveated_metamers/plotting.py b/foveated_metamers/plotting.py index 8fcba04..99bd42a 100644 --- a/foveated_metamers/plotting.py +++ b/foveated_metamers/plotting.py @@ -2175,7 +2175,8 @@ def vertical_pointplot(data, x, y, norm_y=False, **kwargs): color = {'metamer_vs_reference': c, 'metamer_vs_metamer': 'w'} for n, g in data.groupby('trial_type'): ax.scatter(g[x].values, g[y].values, s=ms, marker=marker[n], - color=color[n], edgecolors=c, linewidths=lw, **kwargs) + color=color[n], edgecolors=c, linewidths=lw, zorder=2, + **kwargs) def image_heatmap_schematic():