Skip to content

Commit

Permalink
Optionally add noise to data
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 8, 2023
1 parent 5dea761 commit dcfd2c1
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion configs/lateral_connection_alternative_cells_arc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dataset:
num_workers: 0

run:
n_epochs: 15
n_epochs: 150
current_epoch: 0
plots:
enable: True
Expand Down
21 changes: 19 additions & 2 deletions src/data/custom_datasets/arc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import random
from pathlib import Path
from typing import List, Literal, Tuple

import torch.nn.functional as F
import numpy as np
import torch
from matplotlib.colors import LinearSegmentedColormap
Expand All @@ -10,7 +11,8 @@

class ArcDataset(Dataset):

def __init__(self):
def __init__(self, add_noise: bool = False):
self.add_noise = add_noise
self.path = Path('../data/arc_subset.json')
self.tasks = self.load_tasks()
self.tasks = self.load_tasks()
Expand Down Expand Up @@ -42,6 +44,21 @@ def get_item(self,
metadata['pad'] = (pad_left, pad_right, pad_top, pad_bottom)
data = torch.nn.functional.pad(data, (pad_left, pad_right, pad_top, pad_bottom), "constant", 0)

if self.add_noise:
# data = data.argmax(dim=2).float()

noise_probability = 0.01
noise = torch.rand_like(data)
noise_indices = random.sample(range(noise.numel()), int((1 - noise_probability) * noise.numel()))
noise_s = noise.shape
noise = noise.view(-1)
noise[noise_indices] = 0
noise = noise.view(noise_s)

data = (data + (noise * 10).round()) % 10
# data = F.one_hot(data.long(), num_classes=10)
# data = data.permute(0, 1, 4, 2, 3).float()

if one_hot:
data = torch.nn.functional.one_hot(data.to(torch.int64), num_classes=10).permute(2, 0, 1).float()

Expand Down
4 changes: 2 additions & 2 deletions src/data/from_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def _get_dataset(
valid_set = EightBitDataset(transform=transform, **dataset_config['valid_dataset_params'])
test_set = EightBitDataset(transform=transform, **dataset_config['test_dataset_params'])
elif dataset_name == "arc":
train_set = ArcDataset()
train_set = ArcDataset(add_noise=False)
valid_set = None
test_set = ArcDataset()
test_set = ArcDataset(add_noise=True)
else:
raise ValueError("Unknown dataset name: {}".format(dataset_name))

Expand Down
3 changes: 1 addition & 2 deletions src/main_lateral_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import wandb
from lightning import Fabric
from torch import Tensor
Expand Down Expand Up @@ -353,7 +352,7 @@ def single_eval_epoch(
plt_activations_f,
plot_input_features=epoch == 0,
show_plot=plot)
# weights_fp = lateral_network.plot_model_weights(show_plot=plot)
weights_fp = lateral_network.plot_model_weights(show_plot=plot)
# plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
if epoch == config['run']['n_epochs']:
videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)
Expand Down

0 comments on commit dcfd2c1

Please sign in to comment.