Playground with MNIST...

sagerpascal committed Nov 2, 2023
1 parent b924cf0, commit 0b65484

Showing 7 changed files with 111 additions and 79 deletions.
30 changes: 15 additions & 15 deletions configs/data.yaml
@@ -49,21 +49,21 @@ mnist:
     train: False
     download: True
 
-mnist-subset:
-  dir: mnist/
-  beton_dir: beton/mnist/
-  mean: [ 0.1307 ]
-  std: [ 0.3081 ]
-  num_classes: 10
-  num_channels: 1
-  img_width: 28
-  img_height: 28
-  train_dataset_params:
-    train: True
-    download: True
-  test_dataset_params:
-    train: False
-    download: True
+mnist-subset:
+  dir: mnist/
+  beton_dir: beton/mnist/
+  mean: [ 0.1307 ]
+  std: [ 0.3081 ]
+  num_classes: 10
+  num_channels: 1
+  img_width: 32
+  img_height: 32
+  train_dataset_params:
+    train: True
+    download: True
+  test_dataset_params:
+    train: False
+    download: True
 
 imagenet:
   dir: imagenet/
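The new 32x32 size lines up with the `transforms.Pad(2)` call that `from_conf.py` applies to the raw 28x28 MNIST digits (see the `mnist-subset` branch further down). A quick sanity-check sketch, assuming a recent torchvision where transforms accept tensors:

```python
# Sanity-check sketch (not repo code): Pad(2) turns a 28x28 MNIST digit into
# the 32x32 image the updated config now declares.
import torch
from torchvision import transforms

pad = transforms.Pad(2)
x = torch.zeros(1, 28, 28)   # (channels, height, width)
print(pad(x).shape)          # torch.Size([1, 32, 32])
```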
14 changes: 7 additions & 7 deletions configs/lateral_connection_alternative_cells_mnist.yaml
@@ -6,7 +6,7 @@ dataset:
   num_workers: 0
 
 run:
-  n_epochs: 15
+  n_epochs: 80
   current_epoch: 0
   plots:
     enable: True

@@ -59,25 +59,25 @@ logging:
 feature_extractor:
   out_channels: 4
   add_bg_channel: False
-  bin_threshold: 0. # set to 0.5 to obtain better features
+  bin_threshold: 0.5 # set to 0.5 to obtain better features
   optimized_filter_lines: True # set to True to obtain better features
 
 n_alternative_cells: 10
 
 lateral_model:
   channels: 4
-  max_timesteps: 1 # 1 = only one forward pass without recurrent connection
+  max_timesteps: 5 # 1 = only one forward pass without recurrent connection
   min_k: 2
   max_k: 3
   l1_type: 'lateral_flex'
   l1_params:
-    locality_size: 3
+    locality_size: 2
     lr: 20.0
     hebbian_rule: 'vanilla'
     neg_corr: True
-    act_threshold: 0.15 # 'bernoulli'
-    square_factor: 1.5
-    support_factor: 5
+    act_threshold: 0.5 # 'bernoulli'
+    square_factor: 1
+    support_factor: 2.0
 
   l2:
     k: 1
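For reference, a minimal sketch of reading the tuned values back out of this file. The path and keys are taken from the diff, but this is a hypothetical loader; the repo's actual config-loading code may differ:

```python
# Hypothetical loader sketch (assumes PyYAML); not the repo's loading code.
import yaml

with open('configs/lateral_connection_alternative_cells_mnist.yaml') as f:
    conf = yaml.safe_load(f)

print(conf['run']['n_epochs'])                              # 80 after this commit
print(conf['lateral_model']['l1_params']['act_threshold'])  # 0.5
```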
9 changes: 7 additions & 2 deletions src/data/from_conf.py
@@ -36,9 +36,14 @@ def _get_dataset(
         test_set = MNIST(root=dataset_path, transform=transform, **dataset_config['test_dataset_params'])
     elif dataset_name == "mnist-subset":
         transform = transforms.Compose([transforms.ToTensor(), transforms.Pad(2)])
-        train_set = torch.utils.data.Subset(MNIST(root=dataset_path, transform=transform, **dataset_config['test_dataset_params']), list(range(500)))
+        dataset = MNIST(root=dataset_path, transform=transform, **dataset_config['test_dataset_params'])
+        idx = dataset.train_labels == 0
+        dataset.data = dataset.data[idx]
+        dataset.targets = dataset.targets[idx]
+
+        train_set = torch.utils.data.Subset(dataset, list(range(500)))
         valid_set = None
-        test_set = torch.utils.data.Subset(MNIST(root=dataset_path, transform=transform, **dataset_config['test_dataset_params']), [1, 2, 3, 4, 5])
+        test_set = torch.utils.data.Subset(dataset, [0, 1, 2, 3, 4])
     elif dataset_name == "cifar10":
         train_set = CIFAR10(root=dataset_path, transform=transform, **dataset_config['train_dataset_params'])
         valid_set = None
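The new branch builds one class-filtered dataset and slices both splits out of it. A self-contained sketch of the same idea (not the repo's code, with an assumed root path); note that `train_labels` is a deprecated torchvision alias for `targets`, and that test indices 0-4 also fall inside the first 500 training samples, so the two subsets overlap:

```python
# Standalone sketch of the subset-filtering idea above (assumed root path).
import torch
from torchvision import transforms
from torchvision.datasets import MNIST

transform = transforms.Compose([transforms.ToTensor(), transforms.Pad(2)])
dataset = MNIST(root='data/', train=False, download=True, transform=transform)

idx = dataset.targets == 0              # keep only the digit-0 samples
dataset.data = dataset.data[idx]        # filter the raw images in place
dataset.targets = dataset.targets[idx]

train_set = torch.utils.data.Subset(dataset, list(range(500)))
test_set = torch.utils.data.Subset(dataset, [0, 1, 2, 3, 4])
print(len(train_set), len(test_set))    # 500 5
```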
2 changes: 1 addition & 1 deletion src/data/utils/plot_image.py
@@ -130,7 +130,7 @@ def plot_images(
             fig.colorbar(im, cax=cax, orientation='vertical')
 
         if masks is not None and mask is not None:
-            ax.imshow(mask, alpha=0.6, cmap='jet', interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)
+            ax.imshow(mask, alpha=0.6, cmap='gist_ncar', interpolation=interpolation_, vmin=mask_vmin, vmax=mask_vmax)
 
         if lbl is not None:
             lbl = str(lbl.item()) if isinstance(lbl, torch.Tensor) else lbl
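'gist_ncar' cycles through far more distinct hues than 'jet', so neighboring class ids in the overlaid mask become easier to tell apart, presumably the same motivation as the `+10` background offset added in s1_lateral_connections.py below. A minimal, illustrative overlay with stand-in data:

```python
# Illustrative only: semi-transparent class mask over a grayscale image.
import matplotlib.pyplot as plt
import numpy as np

img = np.random.rand(32, 32)                    # stand-in grayscale image
mask = np.random.randint(0, 11, size=(32, 32))  # stand-in class-id mask

fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.imshow(mask, alpha=0.6, cmap='gist_ncar', vmin=0, vmax=20)
plt.show()
```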
@@ -35,28 +35,44 @@ def __init__(
         super(Conv2dFixedFilters, self).__init__()
 
         if optimized_filter_lines:
-            self.weight = torch.tensor([[[[+0, -1, +0, -1, +0],
-                                          [+0, -1, +2, -1, +0],
-                                          [+0, -1, +2, -1, +0],
-                                          [+0, -1, +2, -1, +0],
-                                          [+0, -1, +0, -1, +0]]],
-                                        [[[+0, +0, -1, -1, +0],
-                                          [+0, -1, +0, +2, -1],
-                                          [-1, +0, +2, +0, -1],
-                                          [-1, +2, +0, -1, +0],
-                                          [+0, -1, -1, +0, +0]]],
-                                        [[[+0, +0, +0, +0, +0],
-                                          [-1, -1, -1, -1, -1],
-                                          [+0, +2, +2, +2, +0],
-                                          [-1, -1, -1, -1, -1],
-                                          [+0, +0, +0, +0, +0]]],
-                                        [[[+0, -1, -1, +0, +0],
-                                          [-1, +2, +0, -1, +0],
-                                          [-1, +0, +2, +0, -1],
-                                          [+0, -1, +0, +2, -1],
-                                          [+0, +0, -1, -1, +0]]],  # Filter could be further improved by setting 4x +0 in the middle to -1
-                                        ], dtype=torch.float32, requires_grad=False).to(fabric.device)
-            self.weight = self.weight / 6
+            self.weight = torch.tensor([
+                [[[0.0, 0.2, 0.0],
+                  [0.0, 0.2, 0.0],
+                  [0.0, 0.2, 0.0]]],
+                [[[0.0, 0.0, 0.2],
+                  [0.0, 0.2, 0.0],
+                  [0.2, 0.0, 0.0]]],
+                [[[0.0, 0.0, 0.0],
+                  [0.2, 0.2, 0.2],
+                  [0.0, 0.0, 0.0]]],
+                [[[0.2, 0.0, 0.0],
+                  [0.0, 0.2, 0.0],
+                  [0.0, 0.0, 0.2]]],
+            ], dtype=torch.float32, requires_grad=False).to(fabric.device)
+
+
+            # self.weight = torch.tensor([[[[+0, -1, +0, -1, +0],
+            #                               [+0, -1, +2, -1, +0],
+            #                               [+0, -1, +2, -1, +0],
+            #                               [+0, -1, +2, -1, +0],
+            #                               [+0, -1, +0, -1, +0]]],
+            #                             [[[+0, +0, -1, -1, +0],
+            #                               [+0, -1, +0, +2, -1],
+            #                               [-1, +0, +2, +0, -1],
+            #                               [-1, +2, +0, -1, +0],
+            #                               [+0, -1, -1, +0, +0]]],
+            #                             [[[+0, +0, +0, +0, +0],
+            #                               [-1, -1, -1, -1, -1],
+            #                               [+0, +2, +2, +2, +0],
+            #                               [-1, -1, -1, -1, -1],
+            #                               [+0, +0, +0, +0, +0]]],
+            #                             [[[+0, -1, -1, +0, +0],
+            #                               [-1, +2, +0, -1, +0],
+            #                               [-1, +0, +2, +0, -1],
+            #                               [+0, -1, +0, +2, -1],
+            #                               [+0, +0, -1, -1, +0]]],  # Filter could be further improved by setting 4x +0 in the middle to -1
+            #                             ], dtype=torch.float32, requires_grad=False).to(fabric.device)
+            # self.weight = self.weight / 6
         else:
             self.weight = torch.tensor([[[[-1, +2, -1],
                                           [-1, +2, -1],
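The four new fixed kernels respond to vertical, rising-diagonal, horizontal, and falling-diagonal lines. A minimal sketch of how such fixed filters produce orientation maps (assumed usage; the repo's module wraps this differently):

```python
# Sketch only: apply the four fixed 3x3 orientation kernels from the diff.
import torch
import torch.nn.functional as F

weight = torch.tensor([
    [[[0.0, 0.2, 0.0], [0.0, 0.2, 0.0], [0.0, 0.2, 0.0]]],  # vertical |
    [[[0.0, 0.0, 0.2], [0.0, 0.2, 0.0], [0.2, 0.0, 0.0]]],  # diagonal /
    [[[0.0, 0.0, 0.0], [0.2, 0.2, 0.2], [0.0, 0.0, 0.0]]],  # horizontal -
    [[[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]]],  # diagonal \
])

x = torch.rand(1, 1, 32, 32)               # one 32x32 grayscale image
features = F.conv2d(x, weight, padding=1)  # -> (1, 4, 32, 32), one map per orientation
print(features.shape)
```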
71 changes: 41 additions & 30 deletions src/lateral_connections/s1_lateral_connections.py
@@ -83,6 +83,7 @@ def __init__(self,
         self.mask = None
         self.square_factor = square_factor
         self.support_factor = support_factor
+        self.step = 0
 
         assert self.hebbian_rule in ['vanilla'], \
             f"hebbian_rule must be 'vanilla', but is {self.hebbian_rule}"
@@ -232,31 +233,41 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]:
         max_support = self.support_factor * self.kernel_size[0]
         x_lateral_norm = torch.where(x_lateral < max_support, x_lateral, max_support - .5 * (x_lateral - max_support))
 
-        # Normalize by dividing through the sum of the weights
-        x_lateral_norm = x_lateral_norm / (1e-10 + torch.sum(self.W_lateral.data, dim=(1, 2, 3)).view(1, -1, 1, 1))
-
-        # TODO: Is the code below necessary?
-        # if self.ts > 0:
-        #     x_lateral_norm = (x_lateral_norm + self.ts * self.x_lateral_norm_prev) / (self.ts + 1)
-        # self.x_lateral_norm_prev = x_lateral_norm
-
-        if self.n_alternative_cells <= 1:  # TODO: is this if/else necessary?
-            # Bring activation in range [0, 1]
-            x_lateral_norm_s = x_lateral_norm.shape
-            x_lateral_norm /= (
-                    1e-10 + x_lateral_norm.view(-1, x_lateral_norm_s[2] * x_lateral_norm_s[3]).max(1)[0].view(
-                x_lateral_norm_s[:2] + (1, 1)))
-
-        else:
-            # Normalize per alternative channel
-            x_lateral_norm_s = x_lateral_norm.shape
-            x_lateral_norm = x_lateral_norm.reshape((x_lateral_norm.shape[0],
-                                                     x_lateral_norm.shape[1] // self.n_alternative_cells,
-                                                     self.n_alternative_cells) + x_lateral_norm.shape[2:])
-            x_lateral_norm_alt_max = x_lateral_norm.view(x_lateral_norm.shape[:2] + (-1,)).max(dim=2)[0]
-            x_lateral_norm = x_lateral_norm / (
-                    1e-10 + x_lateral_norm_alt_max.reshape(x_lateral_norm_alt_max.shape + (1, 1, 1)))
-            x_lateral_norm = x_lateral_norm.reshape(x_lateral_norm_s)
+        # Divide by mean support per channel ?
+
+
+        # self.step += 1
+        # max_step = 2000
+        # slow_start = min(self.step, max_step) / 2000
+        # slow_start = max(slow_start, 0.2)
+
+        min_support = x_lateral_norm.reshape(x_lateral_norm.shape[1], -1).max(dim=(1))[0] / 1.5
+        x_lateral_norm = torch.where(x_lateral_norm < min_support.view(1, -1, 1, 1), 0, 1)
+
+        # upper_support = torch.min(torch.ones_like(x_lateral_norm[:, 0, 0, 0]) * max_support, x_lateral_norm.reshape(x_lateral_norm.shape[0], -1).max(dim=(1))[0])
+        # x_lateral_norm /= upper_support.view(-1, 1, 1, 1)
+
+        # # Normalize by dividing through the sum of the weights
+        # x_lateral_norm = x_lateral_norm / (1e-10 + torch.sum(self.W_lateral.data, dim=(1, 2, 3)).view(1, -1, 1, 1))
+        #
+        # if self.n_alternative_cells <= 1:  # TODO: is this if/else necessary?
+        #     # Bring activation in range [0, 1]
+        #     x_lateral_norm_s = x_lateral_norm.shape
+        #     x_lateral_norm /= (
+        #             1e-10 + x_lateral_norm.view(-1, x_lateral_norm_s[2] * x_lateral_norm_s[3]).max(1)[0].view(
+        #         x_lateral_norm_s[:2] + (1, 1)))
+        #
+        # else:
+        #     # Normalize per alternative channel
+        #     x_lateral_norm_s = x_lateral_norm.shape
+        #     x_lateral_norm = x_lateral_norm.reshape((x_lateral_norm.shape[0],
+        #                                              x_lateral_norm.shape[1] // self.n_alternative_cells,
+        #                                              self.n_alternative_cells) + x_lateral_norm.shape[2:])
+        #     x_lateral_norm_alt_max = x_lateral_norm.view(x_lateral_norm.shape[:2] + (-1,)).max(dim=2)[0]
+        #     x_lateral_norm = x_lateral_norm / (
+        #             1e-10 + x_lateral_norm_alt_max.reshape(x_lateral_norm_alt_max.shape + (1, 1, 1)))
+        #     x_lateral_norm = x_lateral_norm.reshape(x_lateral_norm_s)
 
         if self.act_threshold == "bernoulli":
             x_lateral_bin = torch.bernoulli(torch.clip(x_lateral_norm ** self.square_factor, 0, 1))
@@ -625,11 +636,11 @@ def _plot_lateral_output(img,
             plt_titles.extend([f"Input view {view_idx}", f"Extracted Features {view_idx}"])
             background = torch.all((lateral_features[view_idx, -1] == 0), dim=0)
             foreground = torch.argmax(lateral_features[view_idx, -1], dim=0)
-            calc_mask = torch.where(~background, foreground + 1, 0.)
+            calc_mask = torch.where(~background, foreground + 10, 0.)  # +10 to make background very different
             plt_masks.extend([None, calc_mask])
         plt_images = self._normalize_image_list(plt_images)
         plot_images(images=plt_images, titles=plt_titles, masks=plt_masks, max_cols=2, plot_colorbar=True,
-                    vmin=0, vmax=1, mask_vmin=0, mask_vmax=lateral_features.shape[2] + 1, fig_fp=fig_fp,
+                    vmin=0, vmax=1, mask_vmin=0, mask_vmax=lateral_features.shape[2] + 10, fig_fp=fig_fp,
                     show_plot=show_plot)
 
         fig_fp = self.conf['run']['plots'].get('store_path', None)
@@ -652,10 +663,10 @@ def _plot_lateral_output(img,
                                          fig_fp=if_fp, show_plot=show_plot)
             elif if_fp is not None:
                 files.remove(if_fp)
-            _plot_lateral_activation_map(lateral_features_i[batch_idx],
-                                         fig_fp=am_fp, show_plot=show_plot)
-            _plot_lateral_heat_map(lateral_features_f_i[batch_idx],
-                                   fig_fp=hm_fp, show_plot=show_plot)
+            # _plot_lateral_activation_map(lateral_features_i[batch_idx],
+            #                              fig_fp=am_fp, show_plot=show_plot)
+            # _plot_lateral_heat_map(lateral_features_f_i[batch_idx],
+            #                        fig_fp=hm_fp, show_plot=show_plot)
             _plot_lateral_output(img_i[batch_idx], lateral_features_i[batch_idx],
                                  fig_fp=lo_fp, show_plot=show_plot)
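The replacement logic in `forward` binarizes per channel: each channel's threshold is its own maximum divided by 1.5, everything below it becomes 0 and everything else 1. An isolated sketch of that step (assuming, as the committed `reshape(channels, -1)` implicitly does, a batch size of 1 so each row lines up with one channel):

```python
# Isolated sketch of the new per-channel binarization (assumes batch size 1).
import torch

def binarize(x: torch.Tensor) -> torch.Tensor:
    # x: (1, channels, H, W); per-channel threshold = channel max / 1.5
    min_support = x.reshape(x.shape[1], -1).max(dim=1)[0] / 1.5
    return torch.where(x < min_support.view(1, -1, 1, 1),
                       torch.zeros_like(x), torch.ones_like(x))

x = torch.rand(1, 4, 32, 32)
print(binarize(x).unique())  # tensor([0., 1.])
```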
4 changes: 2 additions & 2 deletions src/main_lateral_connections.py
@@ -349,8 +349,8 @@ def single_eval_epoch(
                 plt_activations_f,
                 plot_input_features=epoch == 0,
                 show_plot=plot)
-            weights_fp = lateral_network.plot_model_weights(show_plot=plot)
-            plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
+            # weights_fp = lateral_network.plot_model_weights(show_plot=plot)
+            # plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
             if epoch == config['run']['n_epochs']:
                 videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)
