Skip to content

Commit

Permalink
Paper code
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 13, 2023
1 parent bf9d539 commit 61c2cf9
Show file tree
Hide file tree
Showing 11 changed files with 646 additions and 123 deletions.
2 changes: 1 addition & 1 deletion configs/lateral_connection_alternative_cells.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dataset:
num_workers: 0

run:
n_epochs: 15
n_epochs: 101
current_epoch: 0
plots:
enable: True
Expand Down
2 changes: 1 addition & 1 deletion configs/lateral_connection_baseline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dataset:
num_workers: 0

run:
n_epochs: 30
n_epochs: 20
current_epoch: 0
plots:
enable: True
Expand Down
7 changes: 6 additions & 1 deletion src/data/utils/plot_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torchvision.transforms as T
from PIL.Image import Image
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable

def undo_norm(img: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
Expand Down Expand Up @@ -130,7 +131,11 @@ def plot_images(
fig.colorbar(im, cax=cax, orientation='vertical')

if masks is not None and mask is not None:
ax.imshow(mask, alpha=0.6, cmap='jet', interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)
cmap = plt.cm.get_cmap('jet')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplist[0] = (0, 0, 0, 1.0)
cmap = colors.LinearSegmentedColormap.from_list(f'jet_modified', cmaplist, cmap.N)
ax.imshow(mask, alpha=0.7, cmap=cmap, interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)

if lbl is not None:
lbl = str(lbl.item()) if isinstance(lbl, torch.Tensor) else lbl
Expand Down
3 changes: 3 additions & 0 deletions src/eval_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
python main_evaluation.py lateral_connection_alternative_cells --load alternative2.ckp --line_interrupt 5 &&
python main_evaluation.py lateral_connection_alternative_cells --load alternative2.ckp &&
python main_evaluation.py lateral_connection_alternative_cells --load alternative2.ckp --add_noise
2 changes: 1 addition & 1 deletion src/lateral_connections/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from stage_1.feature_extractor import *
from lateral_connections.feature_extractor import *
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn.functional as F
from lightning import Fabric
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import Tensor
from torchvision import utils

Expand Down Expand Up @@ -207,20 +208,28 @@ def _hist_plot(ax, weight, title):
ax.set_xlabel(f'Bins form {min:.4f} to {max:.4f}')
ax.set_title(title)

def _plot_weights(ax, weight, title):
def _plot_weights(fig, ax, weight, title):
weight_img_list = [weight[i, j].unsqueeze(0) for j in range(weight.shape[1]) for i in
range(weight.shape[0])]
# Order is [(0, 0), (1, 0), ..., (3, 0), (0, 1), ..., (3, 7)]
# The columns show the output channels, the rows the input channels
grid = utils.make_grid(weight_img_list, nrow=weight.shape[0], normalize=True, scale_each=True, pad_value=1)
ax.imshow(grid.permute(1, 2, 0), interpolation='none')
#grid = grid / 2 - 1/6 # Normalize to [-1/6, 1/3]
im = ax.imshow(grid[:, 2:-2, 2:-2].permute(1, 2, 0), interpolation='none', cmap="gray")#, vmin=-1/6, vmax=1/3)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
ax.set_title(title)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])

files = []
for layer, weight in [('feature extractor', self.model.weight)]:
fig, axs = plt.subplots(1, 2, figsize=(8, 5))
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
_hist_plot(axs[0], weight.detach().cpu(), f"Weight distribution ({layer})")
_plot_weights(axs[1], weight[:20, :20, ...].detach().cpu(), f"Weight matrix ({layer})")
_plot_weights(fig, axs[1], weight[:20, :20, ...].detach().cpu(), f"Weight matrix ({layer})")
plt.tight_layout()

fig_fp = self.conf['run']['plots'].get('store_path', None)
Expand Down
229 changes: 123 additions & 106 deletions src/lateral_connections/s1_lateral_connections.py

Large diffs are not rendered by default.

Loading

0 comments on commit 61c2cf9

Please sign in to comment.