Skip to content

Commit

Permalink
Improve plot
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 3, 2023
1 parent 0b65484 commit 8333b1f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
7 changes: 6 additions & 1 deletion src/data/utils/plot_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torchvision.transforms as T
from PIL.Image import Image
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as colors

def undo_norm(img: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
"""
Expand Down Expand Up @@ -130,7 +131,11 @@ def plot_images(
fig.colorbar(im, cax=cax, orientation='vertical')

if masks is not None and mask is not None:
ax.imshow(mask, alpha=0.6, cmap='gist_ncar', interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)
cmap = plt.cm.get_cmap('tab20')
cmaplist = [cmap(i) for i in range(cmap.N)]
cmaplist[0] = (0, 0, 0, 1.0)
cmap = colors.LinearSegmentedColormap.from_list(f'tab20_modified', cmaplist, cmap.N)
ax.imshow(mask, alpha=0.7, cmap=cmap, interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)

if lbl is not None:
lbl = str(lbl.item()) if isinstance(lbl, torch.Tensor) else lbl
Expand Down
73 changes: 66 additions & 7 deletions src/lateral_connections/s1_lateral_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self,
self.n_alternative_cells = n_alternative_cells
self.in_channels = in_channels
self.out_channels = out_channels
self.in_feature_channels = self.in_channels - self.out_channels
self.locality_size = locality_size
self.neib_size = 2 * self.locality_size + 1
self.kernel_size = (self.neib_size, self.neib_size)
Expand Down Expand Up @@ -284,12 +285,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]:
# Shape w2: 5324, 40, 1
# Shape x2: 5324, 1, 1024
d2 = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
w2 = (self.W_lateral.reshape(40, d2, 1).permute(1, 0, 2) > 0).float()
w2 = (self.W_lateral.reshape(self.out_channels, d2, 1).permute(1, 0, 2) > 0).float()
x2 = x_rearranged.reshape(d2, 1, 1024).permute(0, 1, 2)
pos_corr3 = torch.matmul(w2, x2) # 40,5324,1 * 40,1,1024
neg_corr3 = torch.matmul(w2, 1 - x2) + torch.matmul(1 - w2, x2)
pos_corr3 = pos_corr3.permute(2, 1, 0).reshape(1024, 4, 10, d2)
neg_corr3 = neg_corr3.permute(2, 1, 0).reshape(1024, 4, 10, d2)
pos_corr3 = pos_corr3.permute(2, 1, 0).reshape(1024, self.in_feature_channels, self.n_alternative_cells, d2)
neg_corr3 = neg_corr3.permute(2, 1, 0).reshape(1024, self.in_feature_channels, self.n_alternative_cells, d2)

pos_corr = pos_corr3
neg_corr = neg_corr3
Expand Down Expand Up @@ -353,7 +354,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]:
best_channel = best_channel.reshape((x_lateral_bin.shape[0], ) + x_lateral_bin.shape[2:] + (-1, ))

x_lateral_bin_reshaped = x_lateral_bin.reshape((x_lateral_bin.shape[0], x_lateral_bin.shape[1] // self.n_alternative_cells, self.n_alternative_cells) + x_lateral_bin.shape[2:]).permute(0, 3, 4, 1, 2)
self.mask = (torch.arange(0, 10).view(1, 1, 1, 1, -1).cuda() == best_channel.unsqueeze(-1))
self.mask = (torch.arange(0, self.n_alternative_cells).view(1, 1, 1, 1, -1).cuda() == best_channel.unsqueeze(-1))
assert torch.all(torch.sum(self.mask, dim=4) == 1)

else:
Expand Down Expand Up @@ -625,6 +626,60 @@ def _plot_lateral_heat_map(lateral_features_f, fig_fp: Optional[str] = None, sho
plot_images(images=plt_images, titles=plt_titles, max_cols=lateral_features_f.shape[2], plot_colorbar=True,
fig_fp=fig_fp, cmap='hot', interpolation='nearest', vmin=v_min, vmax=v_max, show_plot=show_plot)


def plot_alternative_cells(
img,
input_features,
lateral_features,
n_alternative_cells,
fig_fp: Optional[str] = None,
show_plot: Optional[bool] = False
):
max_views = 1
n_channels = input_features.shape[1]

plt_images, plt_titles, plt_masks = [], [], []
for view_idx in range(min(max_views, lateral_features.shape[0])):
img_norm = self._normalize_image_list([img[view_idx]])

plt_titles.append(f"Input")
plt_images.extend(img_norm * (n_channels + 1))

# input features
plt_titles.extend(["Input Features"] * n_channels)
masks = [input_features[view_idx, c] for c in range(n_channels)]
plt_masks.extend([None] + masks)

for time_idx in range(lateral_features.shape[1]):
ti = "avg" if time_idx == lateral_features.shape[1] - 1 else time_idx
plt_titles.append(f"Input T={ti}")
plt_images.extend(img_norm * (n_channels + 1))
plt_masks.append(None)

# lateral features
plt_titles.extend(["Lat. Features"] * n_channels)
lf = lateral_features[view_idx, time_idx].reshape(n_channels, n_alternative_cells, *lateral_features.shape[-2:])
for c in range(n_channels):
background = torch.all((lf[c] == 0), dim=0)
foreground = torch.argmax(lf[c], dim=0)
assert torch.sum(lf[c], dim=0).max() <= 1, "Only one cell should be active"
calc_mask = torch.where(~background, foreground + 1, 0.)
plt_masks.append(calc_mask)

plot_images(images=plt_images, titles=plt_titles, masks=plt_masks, max_cols=n_channels+1, plot_colorbar=False,
vmin=0, vmax=1, mask_vmin=0, mask_vmax=n_alternative_cells + 1, fig_fp=fig_fp,
show_plot=show_plot)









# TODO: Dieser Plot ist fürn Arsch weil jeweils nur 1 Channel pro Stelle geplottet wird.
# Plotte jeweils eine Farbe pro alt channel?
def _plot_lateral_output(img,
lateral_features,
fig_fp: Optional[str] = None,
Expand All @@ -639,7 +694,7 @@ def _plot_lateral_output(img,
calc_mask = torch.where(~background, foreground + 10, 0.) # +10 to make background very different
plt_masks.extend([None, calc_mask])
plt_images = self._normalize_image_list(plt_images)
plot_images(images=plt_images, titles=plt_titles, masks=plt_masks, max_cols=2, plot_colorbar=True,
plot_images(images=plt_images, titles=plt_titles, masks=plt_masks, max_cols=2, plot_colorbar=False,
vmin=0, vmax=1, mask_vmin=0, mask_vmax=lateral_features.shape[2] + 10, fig_fp=fig_fp,
show_plot=show_plot)

Expand All @@ -655,6 +710,7 @@ def _plot_lateral_output(img,
am_fp = fig_fp / f'{base_name}_lateral_act_maps.png'
hm_fp = fig_fp / f'{base_name}_lateral_heat_maps.png'
lo_fp = fig_fp / f'{base_name}_lateral_output.png'
lo_fp = fig_fp / f'{base_name}_alt_cells.png'
files.extend([if_fp, am_fp, hm_fp, lo_fp])
else:
if_fp, am_fp, hm_fp, lo_fp = None, None, None, None
Expand All @@ -667,8 +723,11 @@ def _plot_lateral_output(img,
# fig_fp=am_fp, show_plot=show_plot)
# _plot_lateral_heat_map(lateral_features_f_i[batch_idx],
# fig_fp=hm_fp, show_plot=show_plot)
_plot_lateral_output(img_i[batch_idx], lateral_features_i[batch_idx],
fig_fp=lo_fp, show_plot=show_plot)
# _plot_lateral_output(img_i[batch_idx], input_features_i[batch_idx], lateral_features_i[batch_idx],
# fig_fp=lo_fp, show_plot=show_plot)

plot_alternative_cells(img_i[batch_idx], input_features_i[batch_idx], lateral_features_i[batch_idx],
self.conf["n_alternative_cells"], fig_fp=lo_fp, show_plot=show_plot)

return files

Expand Down

0 comments on commit 8333b1f

Please sign in to comment.