Skip to content

Commit

Permalink
fixes zorder
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Nov 8, 2024
1 parent e10c20c commit 415491f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
20 changes: 19 additions & 1 deletion foveated_metamers/create_metamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def find_figsizes(model_name, model, image_shape):
images by in the plots we'll create
"""
if model_name.startswith('RGC'):
if model_name.startswith('RGC') or model_name.startswith('Moments'):
animate_figsize = ((3+(image_shape[-1] / image_shape[-2])) * 5 + 2, 5.5)
# these values were selected at 72 dpi, so will need to be adjusted if
# ours is different
Expand Down Expand Up @@ -298,6 +298,24 @@ def setup_model(model_name, scaling, image, min_ecc, max_ecc, cache_dir, normali
cache_dir=cache_dir,
std_dev=std_dev,
normalize_dict=normalize_dict)
elif model_name.startswith('Moments'):
if 'norm' not in model_name:
if normalize_dict:
raise Exception("Cannot normalize Moments model (must be Moments-#_norm)!")
normalize_dict = {}
if not normalize_dict and 'norm' in model_name:
raise Exception("If model_name is Moments-#_norm, normalize_dict must be set!")
moments = int(re.findall('Moments-([0-9]+)_', model_name)[0])
moments = list(range(2, moments+1))
model = pop.PooledMoments(scaling, image.shape[-2:],
min_eccentricity=min_ecc,
max_eccentricity=max_ecc,
window_type=window_type,
transition_region_width=t_width,
cache_dir=cache_dir,
std_dev=std_dev,
normalize_dict=normalize_dict,
moments=moments)
elif model_name.startswith('V1'):
if 'norm' not in model_name:
if normalize_dict:
Expand Down
3 changes: 2 additions & 1 deletion foveated_metamers/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,8 @@ def vertical_pointplot(data, x, y, norm_y=False, **kwargs):
color = {'metamer_vs_reference': c, 'metamer_vs_metamer': 'w'}
for n, g in data.groupby('trial_type'):
ax.scatter(g[x].values, g[y].values, s=ms, marker=marker[n],
color=color[n], edgecolors=c, linewidths=lw, **kwargs)
color=color[n], edgecolors=c, linewidths=lw, zorder=2,
**kwargs)


def image_heatmap_schematic():
Expand Down

0 comments on commit 415491f

Please sign in to comment.