Skip to content

Commit

Permalink
Paper code
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 15, 2023
1 parent 8333b1f commit 2d596bc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/data/from_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_dataset(

train_set = torch.utils.data.Subset(dataset, list(range(500)))
valid_set = None
test_set = torch.utils.data.Subset(dataset, [0, 1, 2, 3, 4])
test_set = torch.utils.data.Subset(dataset, list(range(20)))
elif dataset_name == "cifar10":
train_set = CIFAR10(root=dataset_path, transform=transform, **dataset_config['train_dataset_params'])
valid_set = None
Expand Down
2 changes: 1 addition & 1 deletion src/data/utils/plot_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def plot_images(
fig.colorbar(im, cax=cax, orientation='vertical')

if masks is not None and mask is not None:
cmap = plt.cm.get_cmap('tab20')
cmap = plt.cm.get_cmap('viridis')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplist[0] = (0, 0, 0, 1.0)
cmap = colors.LinearSegmentedColormap.from_list(f'tab20_modified', cmaplist, cmap.N)
Expand Down
25 changes: 16 additions & 9 deletions src/lateral_connections/s1_lateral_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,12 +690,18 @@ def _plot_lateral_output(img,
plt_images.extend([img[view_idx], img[view_idx]])
plt_titles.extend([f"Input view {view_idx}", f"Extracted Features {view_idx}"])
background = torch.all((lateral_features[view_idx, -1] == 0), dim=0)
foreground = torch.argmax(lateral_features[view_idx, -1], dim=0)
calc_mask = torch.where(~background, foreground + 10, 0.) # +10 to make background very different

channel_activations = (torch.sum(lateral_features[view_idx, -1].reshape(4, 10, 32, 32), dim=1) > 0).float()
result = torch.zeros_like(channel_activations)[0]
for i in range(4):
result += channel_activations[i] * 2**(i+1)

# foreground = torch.argmax(lateral_features[view_idx, -1], dim=0)
calc_mask = torch.where(~background, result + 10, 0.) # +10 to make background very different
plt_masks.extend([None, calc_mask])
plt_images = self._normalize_image_list(plt_images)
plot_images(images=plt_images, titles=plt_titles, masks=plt_masks, max_cols=2, plot_colorbar=False,
vmin=0, vmax=1, mask_vmin=0, mask_vmax=lateral_features.shape[2] + 10, fig_fp=fig_fp,
vmin=0, vmax=1, mask_vmin=0, mask_vmax=32 + 10, fig_fp=fig_fp,
show_plot=show_plot)

fig_fp = self.conf['run']['plots'].get('store_path', None)
Expand All @@ -715,19 +721,20 @@ def _plot_lateral_output(img,
else:
if_fp, am_fp, hm_fp, lo_fp = None, None, None, None
if plot_input_features:
_plot_input_features(img_i[batch_idx], features_i[batch_idx], input_features_i[batch_idx],
fig_fp=if_fp, show_plot=show_plot)
pass
# _plot_input_features(img_i[batch_idx], features_i[batch_idx], input_features_i[batch_idx],
# fig_fp=if_fp, show_plot=show_plot)
elif if_fp is not None:
files.remove(if_fp)
# _plot_lateral_activation_map(lateral_features_i[batch_idx],
# fig_fp=am_fp, show_plot=show_plot)
# _plot_lateral_heat_map(lateral_features_f_i[batch_idx],
# fig_fp=hm_fp, show_plot=show_plot)
# _plot_lateral_output(img_i[batch_idx], input_features_i[batch_idx], lateral_features_i[batch_idx],
# fig_fp=lo_fp, show_plot=show_plot)
_plot_lateral_output(img_i[batch_idx], lateral_features_i[batch_idx],
fig_fp=f"../tmp/mnist_stuff/{i}.png", show_plot=show_plot)

plot_alternative_cells(img_i[batch_idx], input_features_i[batch_idx], lateral_features_i[batch_idx],
self.conf["n_alternative_cells"], fig_fp=lo_fp, show_plot=show_plot)
# plot_alternative_cells(img_i[batch_idx], input_features_i[batch_idx], lateral_features_i[batch_idx],
# self.conf["n_alternative_cells"], fig_fp=lo_fp, show_plot=show_plot)

return files

Expand Down

0 comments on commit 2d596bc

Please sign in to comment.