Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 9, 2023
1 parent f003c7f commit bdd18e7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
32 changes: 16 additions & 16 deletions configs/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@ mnist:
train: False
download: True

mnist-subset:
dir: mnist/
beton_dir: beton/mnist/
mean: [ 0.1307 ]
std: [ 0.3081 ]
num_classes: 10
num_channels: 1
img_width: 28
img_height: 28
train_dataset_params:
train: True
download: True
test_dataset_params:
train: False
download: True
mnist-subset:
dir: mnist/
beton_dir: beton/mnist/
mean: [ 0.1307 ]
std: [ 0.3081 ]
num_classes: 10
num_channels: 1
img_width: 28
img_height: 28
train_dataset_params:
train: True
download: True
test_dataset_params:
train: False
download: True

imagenet:
dir: imagenet/
Expand Down Expand Up @@ -132,7 +132,7 @@ straightline:
split: "test"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_images: 8
num_aug_versions: 0
noise: 0.0

Expand Down
18 changes: 9 additions & 9 deletions src/data/custom_datasets/straight_line.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
from typing import Callable, List, Literal, Optional, Tuple

import math
import numpy as np
import torch
from PIL import Image, ImageDraw
Expand Down Expand Up @@ -276,17 +277,16 @@ def _plot_some_samples():
transforms.ToTensor(),
])

dataset = StraightLine(split="test", img_h=32, img_w=32, num_images=12, num_aug_versions=9, num_channels=1,
transform=transform, vertical_horizontal_only=True, noise=0.00,
aug_strategy='trajectory', aug_range=15, n_black_pixels=5)
dataset = StraightLine(split="train", num_images=8, num_aug_versions=0, num_channels=1,
transform=transform, vertical_horizontal_only=False, noise=0.00)

fig, axs = plt.subplots(12, 10, figsize=(10, 12))
for i in range(12):
fig, axs = plt.subplots(1, 8, figsize=(16, 2))
for i in range(8):
img, meta = dataset.get_item(i, n_black_pixels=0)
for idx in range(img.shape[0]):
j = i * 10 + idx
axs[j // 10, j % 10].imshow(img[idx].squeeze(), vmin=0, vmax=1, cmap='gray')
axs[j // 10, j % 10].axis('off')
alpha = -round(math.atan((meta['line_coords'][1][1]-meta['line_coords'][0][1]) / (meta['line_coords'][1][0]-meta['line_coords'][0][0])) * 180 / math.pi, 2)
axs[i].set_title(f"{alpha:.2f}°")
axs[i].imshow(torch.where(img==1, 0, 1).squeeze(), vmin=0, vmax=1.2, cmap='gray')
axs[i].axis('off')
plt.tight_layout()
plt.show()

Expand Down
18 changes: 9 additions & 9 deletions src/lateral_connections/s1_lateral_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,15 +647,15 @@ def _plot_lateral_output(img,
files.extend([if_fp, am_fp, hm_fp, lo_fp])
else:
if_fp, am_fp, hm_fp, lo_fp = None, None, None, None
if plot_input_features:
_plot_input_features(img_i[batch_idx], features_i[batch_idx], input_features_i[batch_idx],
fig_fp=if_fp, show_plot=show_plot)
elif if_fp is not None:
files.remove(if_fp)
_plot_lateral_activation_map(lateral_features_i[batch_idx],
fig_fp=am_fp, show_plot=show_plot)
_plot_lateral_heat_map(lateral_features_f_i[batch_idx],
fig_fp=hm_fp, show_plot=show_plot)
# if plot_input_features:
# _plot_input_features(img_i[batch_idx], features_i[batch_idx], input_features_i[batch_idx],
# fig_fp=if_fp, show_plot=show_plot)
# elif if_fp is not None:
# files.remove(if_fp)
# _plot_lateral_activation_map(lateral_features_i[batch_idx],
# fig_fp=am_fp, show_plot=show_plot)
# _plot_lateral_heat_map(lateral_features_f_i[batch_idx],
# fig_fp=hm_fp, show_plot=show_plot)
_plot_lateral_output(img_i[batch_idx], lateral_features_i[batch_idx],
fig_fp=lo_fp, show_plot=show_plot)

Expand Down

0 comments on commit bdd18e7

Please sign in to comment.