diff --git a/configs/angle_prediction.yaml b/configs/angle_prediction.yaml new file mode 100644 index 0000000..bd5a45d --- /dev/null +++ b/configs/angle_prediction.yaml @@ -0,0 +1,84 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 100 + current_epoch: 0 + store_state_path: ../checkpoints/s0/ae + # load_state_path: ../checkpoints/s1/vq_vae_full_image/s1_2023-05-04_15-48-50.ckpt + visualize_plots: True + +optimizers: + opt1: + type: Adam + params: + lr: 0.001 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +metrics: + - mae: + type: MeanAbsoluteError + meter: avg + - mape: + type: MeanAbsolutePercentageError + meter: avg + - mse: + type: MeanSquaredError + meter: avg + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "s0" + log_model: True + group: "autoencoder" + job_type: "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +model: + type: angle-predictor + params: + kernel_size: 5 + +n_alternative_cells: 1 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.5 # 'bernoulli' + square_factor: 1.2 + support_factor: 1.3 \ No newline at end of file diff --git a/configs/data.yaml b/configs/data.yaml index 10184b5..f9fe6bc 100644 --- a/configs/data.yaml +++ b/configs/data.yaml @@ -132,7 +132,7 @@ straightline: split: "test" vertical_horizontal_only: False aug_range: 0 - num_images: 4 + num_images: 10 num_aug_versions: 0 noise: 0.0 diff --git a/configs/lateral_connection_alternative_cells.yaml b/configs/lateral_connection_alternative_cells.yaml index b023666..02afdb2 100644 --- a/configs/lateral_connection_alternative_cells.yaml +++ b/configs/lateral_connection_alternative_cells.yaml @@ -62,11 +62,11 @@ feature_extractor: bin_threshold: 0. 
# set to 0.5 to obtain better features optimized_filter_lines: True # set to True to obtain better features -n_alternative_cells: 20 +n_alternative_cells: 10 lateral_model: channels: 4 - max_timesteps: 6 # 1 = only one forward pass without recurrent connection + max_timesteps: 2 # 1 = only one forward pass without recurrent connection min_k: 2 max_k: 3 l1_type: 'lateral_flex' @@ -76,7 +76,13 @@ lateral_model: hebbian_rule: 'vanilla' neg_corr: True act_threshold: 0.5 # 'bernoulli' - square_factor: 1.2 + square_factor: + - 1.2 + - 1.4 + - 1.6 + - 1.8 + - 2.0 + - 2.2 support_factor: 1.3 l2: diff --git a/configs/lateral_connection_alternative_cells_2.yaml b/configs/lateral_connection_alternative_cells_2.yaml new file mode 100644 index 0000000..8ac2e7f --- /dev/null +++ b/configs/lateral_connection_alternative_cells_2.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.5 # 'bernoulli' + square_factor: + - 2.0 + - 2.1 + - 2.2 + - 2.3 + - 2.4 + - 2.5 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_3.yaml b/configs/lateral_connection_alternative_cells_3.yaml new file mode 100644 index 0000000..5b08d3a --- /dev/null +++ b/configs/lateral_connection_alternative_cells_3.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.5 # 'bernoulli' + square_factor: + - 0.7 + - 0.9 + - 1.1 + - 1.3 + - 1.5 + - 1.7 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_4.yaml b/configs/lateral_connection_alternative_cells_4.yaml new file mode 100644 index 0000000..d6b6df7 --- /dev/null +++ b/configs/lateral_connection_alternative_cells_4.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.3 # 'bernoulli' + square_factor: + - 1.2 + - 1.4 + - 1.6 + - 1.8 + - 2.0 + - 2.2 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_5.yaml b/configs/lateral_connection_alternative_cells_5.yaml new file mode 100644 index 0000000..04671b4 --- /dev/null +++ b/configs/lateral_connection_alternative_cells_5.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.3 # 'bernoulli' + square_factor: + - 2.0 + - 2.1 + - 2.2 + - 2.3 + - 2.4 + - 2.5 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_6.yaml b/configs/lateral_connection_alternative_cells_6.yaml new file mode 100644 index 0000000..f350493 --- /dev/null +++ b/configs/lateral_connection_alternative_cells_6.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.3 # 'bernoulli' + square_factor: + - 0.7 + - 0.9 + - 1.1 + - 1.3 + - 1.5 + - 1.7 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_7.yaml b/configs/lateral_connection_alternative_cells_7.yaml new file mode 100644 index 0000000..b7afe3f --- /dev/null +++ b/configs/lateral_connection_alternative_cells_7.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.7 # 'bernoulli' + square_factor: + - 1.2 + - 1.4 + - 1.6 + - 1.8 + - 2.0 + - 2.2 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_8.yaml b/configs/lateral_connection_alternative_cells_8.yaml new file mode 100644 index 0000000..2d858b5 --- /dev/null +++ b/configs/lateral_connection_alternative_cells_8.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.7 # 'bernoulli' + square_factor: + - 2.0 + - 2.1 + - 2.2 + - 2.3 + - 2.4 + - 2.5 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/configs/lateral_connection_alternative_cells_9.yaml b/configs/lateral_connection_alternative_cells_9.yaml new file mode 100644 index 0000000..f744bd6 --- /dev/null +++ b/configs/lateral_connection_alternative_cells_9.yaml @@ -0,0 +1,90 @@ +dataset: + name: straightline + augmentation: None + loader: torch + batch_size: 1 + num_workers: 0 + +run: + n_epochs: 101 + current_epoch: 0 + plots: + enable: True + only_last_epoch: False + store_path: "../tmp/test/" + store_state_path: None + load_state_path: None + + +metrics: + - mse: + type: MeanSquaredError + meter: avg + - mae: + type: MeanAbsoluteError + meter: avg + +optimizers: + l2_opt: + type: Adam + params: + lr: 0.05 + betas: [ 0.9, 0.999 ] + eps: 0.00000001 + weight_decay: 0 + amsgrad: False + scheduler: + type: ReduceLROnPlateau + params: + mode: min + factor: 0.1 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.000001 + eps: 0.00000001 + +logging: + wandb: + active: False, + save_dir: "../wandb" + project: "lateral_connections_toy_example" + log_model: True + group: null # "hebbian_learning" + job_type: null # "train" + console: + active: True + +feature_extractor: + out_channels: 4 + add_bg_channel: False + bin_threshold: 0. 
# set to 0.5 to obtain better features + optimized_filter_lines: True # set to True to obtain better features + +n_alternative_cells: 10 + +lateral_model: + channels: 4 + max_timesteps: 2 # 1 = only one forward pass without recurrent connection + min_k: 2 + max_k: 3 + l1_type: 'lateral_flex' + l1_params: + locality_size: 5 + lr: 0.2 + hebbian_rule: 'vanilla' + neg_corr: True + act_threshold: 0.7 # 'bernoulli' + square_factor: + - 0.7 + - 0.9 + - 1.1 + - 1.3 + - 1.5 + - 1.7 + support_factor: 1.3 + +l2: + k: 1 + n_hidden: 16 \ No newline at end of file diff --git a/src/data/custom_datasets/straight_line.py b/src/data/custom_datasets/straight_line.py index 2204afe..9a6d584 100644 --- a/src/data/custom_datasets/straight_line.py +++ b/src/data/custom_datasets/straight_line.py @@ -212,11 +212,11 @@ def _create_image( img = Image.fromarray(img.astype(np.uint8)) # add noise - if noise > 0. or (self.split == 'val' or self.split == 'test') and (idx == 4 or idx == 5 or idx == 6 or idx == 7): - noise = noise if noise > 0. else 0.005 - img = np.array(img) - img = img + np.random.choice(2, img.shape, p=[1 - noise, noise]) * 255 - img = Image.fromarray(img.astype(np.uint8)) + # if noise > 0. or (self.split == 'val' or self.split == 'test') and (idx == 4 or idx == 5 or idx == 6 or idx == 7): + # noise = noise if noise > 0. 
else 0.005 + # img = np.array(img) + # img = img + np.random.choice(2, img.shape, p=[1 - noise, noise]) * 255 + # img = Image.fromarray(img.astype(np.uint8)) if self.transform: img = self.transform(img) @@ -258,11 +258,11 @@ def get_item( images.append(self._create_image(idx, aug_coords, noise=noise, n_black_pixels=n_black_pixels)) images = torch.stack(images, dim=0) if self.num_aug_versions > 0 else images[0] - return images, {'line_coords': line_coords, 'aug_line_coords': aug_line_coords} + return images, {'line_coords': line_coords, 'aug_line_coords': aug_line_coords, 'angle': math.atan((line_coords[1][1]-line_coords[0][1]) / (1e-10+line_coords[1][0]-line_coords[0][0]))} def __getitem__(self, idx: int): - if (self.split == 'val' or self.split == 'test') and (idx == 8 or idx == 9 or idx == 10 or idx == 11): - return self.get_item(idx, n_black_pixels=1) + # if (self.split == 'val' or self.split == 'test') and (idx == 8 or idx == 9 or idx == 10 or idx == 11): + # return self.get_item(idx, n_black_pixels=1) return self.get_item(idx) diff --git a/src/eval_noise.sh b/src/eval_noise.sh new file mode 100644 index 0000000..539e125 --- /dev/null +++ b/src/eval_noise.sh @@ -0,0 +1,13 @@ +#!/bin/bash +echo "Bash version ${BASH_VERSION}..." 
+ +for config in lateral_connection_alternative_cells lateral_connection_alternative_cells_2 lateral_connection_alternative_cells_3 lateral_connection_alternative_cells_4 lateral_connection_alternative_cells_5 lateral_connection_alternative_cells_6 lateral_connection_alternative_cells_7 lateral_connection_alternative_cells_8 lateral_connection_alternative_cells_9 +do + for noise in $(seq 0.0 .01 0.2) + do + for li in {0..7} + do + python main_evaluation.py $config --load alternative_final.ckp --noise $noise --line_interrupt $li --load_baseline_activations_path ../tmp/alternative_final_baseline.pt; + done + done +done \ No newline at end of file diff --git a/src/lateral_connections/s1_lateral_connections.py b/src/lateral_connections/s1_lateral_connections.py index 517b4e7..4a9a730 100644 --- a/src/lateral_connections/s1_lateral_connections.py +++ b/src/lateral_connections/s1_lateral_connections.py @@ -225,7 +225,7 @@ def hebbian_update(self, x: Tensor, y: Tensor): def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]: with torch.no_grad(): x_rearranged = self.rearrange_input(x) - + x_rearranged = torch.where(x_rearranged > 0.5, 1., 0.) assert torch.all( (x_rearranged == 0.) 
| (x_rearranged == 1.)), "x_rearranged not binary -> Torch Config Error" x_lateral = F.conv2d(x_rearranged, self.W_lateral, padding="same", ) @@ -346,12 +346,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]: 1e-10 + x_lateral_norm_alt_max.reshape(x_lateral_norm_alt_max.shape + (1, 1, 1))) x_lateral_norm = x_lateral_norm.reshape(x_lateral_norm_s) - square_factor = [1, 1.2, 1.4, 1.6, 1.8, 2.0] + # square_factor = [1.2, 1.4, 1.6, 2.2, 2.4, 2.6] # square_factor = [2.1, 2.2, 2.3, 2.4, 2.5, 2.6] if self.act_threshold == "bernoulli": - x_lateral_bin = torch.bernoulli(torch.clip(x_lateral_norm ** square_factor[self.ts], 0, 1)) + x_lateral_bin = torch.bernoulli(torch.clip(x_lateral_norm ** self.square_factor[self.ts], 0, 1)) else: - x_lateral_bin = (x_lateral_norm ** square_factor[self.ts] >= self.act_threshold).float() + x_lateral_bin = (x_lateral_norm ** self.square_factor[self.ts] >= self.act_threshold).float() stats = { diff --git a/src/main_angle_prediction.py b/src/main_angle_prediction.py new file mode 100644 index 0000000..9af6fae --- /dev/null +++ b/src/main_angle_prediction.py @@ -0,0 +1,453 @@ +import argparse +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import lightning.pytorch as pl +import pytz +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +from lightning.fabric import Fabric +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader +from tqdm import tqdm + +from data import loaders_from_config +from lateral_connections.feature_extractor.straight_line_pl_modules import FixedFilterFeatureExtractor +from lateral_connections.s1_lateral_connections import LateralNetwork +from models import BaseLitModule +from tools import loggers_from_conf, torch_optim_from_conf +from tools.callbacks.save_model import SaveBestModelCallback +from 
tools.store_load_run import load_run +from utils import get_config, print_start, print_warn + + +class AnglePredictionTorch(nn.Module): + + def __init__(self, kernel_size: int): + super().__init__() + self.kernel_size = kernel_size + self.model = self.setup_model() + + def setup_model(self) -> nn.Sequential: + return nn.Sequential(*[ + nn.Conv2d(4, 32, kernel_size=self.kernel_size, stride=1, padding="same"), + nn.ReLU(True), + nn.Flatten(), + nn.Linear(32 * 32 * 32, 1) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.model(x) + return x + + +class AnglePredictor(BaseLitModule): + """ + Extract features from images using an autoencoder. + """ + + def __init__(self, conf: Dict[str, Optional[Any]], fabric: Fabric): + """ + Constructor. + :param conf: Configuration dictionary. + :param fabric: Fabric instance. + """ + super().__init__(conf, fabric, logging_prefixes=["train", "val"]) + self.model = self.configure_model() + + def preprocess_data_(self, batch: Tensor) -> Tuple[Tensor, Tensor]: + """ + Preprocess the data batch. + :param batch: Data batch, containing input data and labels. + :return: Preprocessed data batch. + """ + x, y = batch + return x, y + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """ + Forward pass through the model. + :param x: + :return: Loss, reconstructed image, perplexity, and encodings. + """ + return self.model(x) + + def step(self, batch: Tensor, batch_idx: int, mode_prefix: str) -> Tensor: + """ + Forward step: Forward pass, and logging. + :param batch: Data batch, containing input data and labels. + :param batch_idx: Index of the batch. + :param mode_prefix: Prefix for the mode (train, val, test). + :return: Loss of the training step. 
+ """ + x, y = self.preprocess_data_(batch) + angle_pred = self.forward(x).squeeze(1) + angle_gt = y['angle'].float() + loss = F.mse_loss(angle_pred, angle_gt) + self.log_step( + processed_values={"loss": loss}, + metric_pairs=[(angle_pred, angle_gt)], + prefix=mode_prefix + ) + + return loss + + def training_step(self, batch: Tensor, batch_idx: int) -> Tensor: + """ + Forward training step: Forward pass, and logging. + :param batch: Data batch, containing input data and labels. + :param batch_idx: Index of the batch. + :return: Loss of the training step. + """ + return self.step(batch, batch_idx, "train") + + def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor: + """ + Forward validation step: Forward pass, and logging. + :param batch: Data batch, containing input data and labels. + :param batch_idx: Index of the batch. + :return: Loss of the validation step. + """ + return self.step(batch, batch_idx, "val") + + def configure_model(self): + """ + Configure the model, i.e. create an VQ-VAE instance. + :return: + """ + model_conf = self.conf["model"] + if model_conf["type"] == "angle-predictor": + params = model_conf["params"] + return AnglePredictionTorch( + **params + ) + else: + raise NotImplementedError(f"Model {model_conf['type']} not implemented") + + def configure_optimizers(self) -> Tuple[Optimizer, Optional[ReduceLROnPlateau]]: + """ + Configure (create instance) the optimizer. + :return: A torch optimizer and scheduler. + """ + return torch_optim_from_conf(self.parameters(), 'opt1', self.conf) + + +def parse_args(parser: Optional[argparse.ArgumentParser] = None): + """ + Parse arguments from command line. + :param parser: Optional ArgumentParser instance. + :return: Parsed arguments. 
+ """ + if parser is None: + parser = argparse.ArgumentParser(description="Feature Extractor Stage 1") + parser.add_argument("config", + type=str, + help="Path to the config file", + ) + parser.add_argument("--batch-size", + type=int, + # default=64, + metavar="N", + dest="dataset:batch_size", + help="input batch size for training (default: 64)" + ) + parser.add_argument("--epochs", + type=int, + # default=20, + metavar="N", + dest="run:n_epochs", + help="number of epochs to train (default: 10)" + ) + parser.add_argument("--lr", + type=float, + # default=0.001, + metavar="LR", + dest="optimizers:opt1:params:lr", + help="learning rate (default: 0.001)" + ) + parser.add_argument('--wandb', + action='store_true', + default=False, + dest='logging:wandb:active', + help='Log to wandb' + ) + parser.add_argument('--plot', + action='store_true', + default=False, + dest='run:visualize_plots', + help='Plot results' + ) + parser.add_argument('--store', + type=str, + dest='run:store_state_path', + help='Path where the model will be stored' + ) + parser.add_argument('--load', + type=str, + dest='run:load_state_path', + help='Path from where the model will be loaded' + ) + + args = parser.parse_args() + return args + + +def get_model(config: Dict[str, Optional[Any]], fabric: Fabric) -> BaseLitModule: + """ + Get the model according to configuration. + :param config: Configuration dict + :param fabric: Fabric instance + :return: The model + """ + return AnglePredictor(config, fabric) + + +def setup_fabric(config: Optional[Dict[str, Optional[Any]]] = None) -> Fabric: + """ + Setup the Fabric instance. 
+ :param config: Configuration dict + :return: Fabric instance + """ + if config is None: + callbacks, loggers = [], [] + else: + callbacks = [] + if "store_state_path" in config["run"] and config["run"]["store_state_path"] != 'None': + callbacks.append(SaveBestModelCallback(metric_key="val/loss", mode="min")) + loggers = loggers_from_conf(config) + fabric = Fabric(accelerator="auto", devices=1, loggers=loggers, callbacks=callbacks) + fabric.launch() + fabric.seed_everything(1) + return fabric + + +def setup_components(config: Dict[str, Optional[Any]], fabric: Fabric) -> ( + BaseLitModule, Optimizer, Optional[ReduceLROnPlateau]): + """ + Setup components for training. + :param config: Configuration dict + :param fabric: Fabric instance + :return: Returns the model and the optimizer + """ + model = get_model(config, fabric) + optimizer, scheduler = model.configure_optimizers() + model, optimizer = fabric.setup(model, optimizer) + return model, optimizer, scheduler + + +def setup_dataloader(config: Dict[str, Optional[Any]], fabric: Fabric) -> (DataLoader, DataLoader): + """ + Setup the dataloaders for training and testing. + :param config: Configuration dict + :param fabric: Fabric instance + :return: Returns the training and testing dataloader + """ + train_loader, _, test_loader = loaders_from_config(config) + if isinstance(train_loader, DataLoader): + train_loader = fabric.setup_dataloaders(train_loader) + test_loader = fabric.setup_dataloaders(test_loader) + else: + print_warn("Train and test loader not setup with fabric.", "Fabric Warning:") + + return train_loader, test_loader + + +def configure() -> Dict[str, Optional[Any]]: + """ + Load the config based on the given console args. 
+ :return: + """ + args = parse_args() + config = get_config(args.config, args) + if config['run']['store_state_path'] != 'None' and Path(config['run']['store_state_path']).is_dir(): + f_name = f"s1_{datetime.now(pytz.timezone('Europe/Zurich')).strftime('%Y-%m-%d_%H-%M-%S')}.ckpt" + config['run']['store_state_path'] = config['run']['store_state_path'] + f"/{f_name}" + if not torch.cuda.is_available(): + print_warn("CUDA is not available.", title="Slow training expected.") + return config + + +def setup_feature_extractor(config: Dict[str, Optional[Any]], fabric: Fabric) -> pl.LightningModule: + """ + Setup the feature extractor model that is used to extract features from images before they are fed into the model + leveraging lateral connections. + :param config: Configuration dict + :param fabric: Fabric instance + :return: Feature extractor model. + """ + feature_extractor = FixedFilterFeatureExtractor(config, fabric) + feature_extractor = fabric.setup(feature_extractor) + return feature_extractor + + +def setup_modules(config: Dict[str, Optional[Any]]) -> Tuple[ + Fabric, pl.LightningModule, BaseLitModule, Optimizer, Optional[ReduceLROnPlateau], + DataLoader, DataLoader]: + """ + Setup the modules for training. 
+ :param config: Configuration dict + :return: Returns the fabric, model, optimizer, scheduler, training dataloader and testing dataloader + """ + fabric = setup_fabric(config) + feature_extractor = setup_feature_extractor(config, fabric) + model, optimizer, scheduler = setup_components(config, fabric) + train_dataloader, test_dataloader = setup_dataloader(config, fabric) + # if 'load_state_path' in config['run'] and config['run']['load_state_path'] != 'None': + # config, components = load_run(config, fabric) + # model.load_state_dict(components['model']) + # optimizer.load_state_dict(components['optimizer']) + # scheduler.load_state_dict(components['scheduler']) + return fabric, feature_extractor, model, optimizer, scheduler, train_dataloader, test_dataloader + + +def single_train_epoch( + config: Dict[str, Optional[Any]], + fabric: Fabric, +lateral_network: LateralNetwork, + feature_extractor: pl.LightningModule, + model: BaseLitModule, + optimizer: Optimizer, + train_dataloader: DataLoader, + epoch: int, +): + """ + Train a single epoch. 
+ :param config: Configuration dict + :param fabric: Fabric instance + :param model: Model to train + :param optimizer: Optimizer to use + :param train_dataloader: Training dataloader + :param epoch: Current epoch + :return: Returns the training logs + """ + model.train() + for i, batch in tqdm(enumerate(train_dataloader), + total=len(train_dataloader), + colour="GREEN", + desc=f"Train Epoch {epoch + 1}/{config['run']['n_epochs']}"): + with torch.no_grad(): + batch[0] = feature_extractor.binarize_features(feature_extractor(batch[0]).squeeze(1)) + + if False: + lateral_network.new_sample() + z = torch.zeros((batch[0].shape[0], lateral_network.model.out_channels, batch[0].shape[2], + batch[0].shape[3]), device=batch[0].device) + + for t in range(config["lateral_model"]["max_timesteps"]): + lateral_network.model.update_ts(t) + x_in = torch.cat([batch[0], z], dim=1) + z_float, z = lateral_network(x_in) + + batch[0] = z + + optimizer.zero_grad() + loss = model.training_step(batch, i) + fabric.backward(loss) + optimizer.step() + + +def single_eval_epoch( + config: Dict[str, Optional[Any]], + lateral_network: LateralNetwork, + feature_extractor: pl.LightningModule, + model: BaseLitModule, + test_dataloader: DataLoader, + epoch: int, +): + """ + Evaluate a single epoch. 
+    :param config: Configuration dict
+    :param lateral_network: Lateral network for feature denoising (denoising pass currently disabled below)
+    :param feature_extractor: Feature extractor producing binarized features
+    :param model: The model to evaluate
+    :param test_dataloader: Testing dataloader
+    :param epoch: Current epoch
+    :return: None
+    """
+    model.eval()
+    with torch.no_grad():
+        for i, batch in tqdm(enumerate(test_dataloader),
+                             total=len(test_dataloader),
+                             colour="GREEN",
+                             desc=f"Validate Epoch {epoch + 1}/{config['run']['n_epochs']}"):
+            batch[0] = feature_extractor.binarize_features(feature_extractor(batch[0]).squeeze(1))
+
+            if True:  # NOTE(review): unconditional 3% random bit-flip noise at eval — consider making configurable
+                features_s = batch[0].shape
+                num_elements = batch[0].numel()
+                num_flips = int(0.03 * num_elements)
+                random_mask = torch.randperm(num_elements)[:num_flips]
+                random_mask = torch.zeros(num_elements, dtype=torch.bool).scatter(0, random_mask, 1)
+                batch[0] = batch[0].view(-1)
+                batch[0][random_mask] = 1.0 - batch[0][random_mask]
+                batch[0] = batch[0].view(features_s)
+
+            if False:  # disabled toggle: lateral-network denoising pass
+                lateral_network.new_sample()
+                z = torch.zeros((batch[0].shape[0], lateral_network.model.out_channels, batch[0].shape[2],
+                                 batch[0].shape[3]), device=batch[0].device)
+
+                for t in range(config["lateral_model"]["max_timesteps"]):
+                    lateral_network.model.update_ts(t)
+                    x_in = torch.cat([batch[0], z], dim=1)
+                    z_float, z = lateral_network(x_in)
+
+                batch[0] = z
+            model.validation_step(batch, i)
+
+
+def train(
+        config: Dict[str, Optional[Any]],
+        fabric: Fabric,
+        lateral_network: LateralNetwork,
+        feature_extractor: pl.LightningModule,
+        model: BaseLitModule,
+        optimizer: Optimizer,
+        scheduler: Optional[ReduceLROnPlateau],
+        train_dataloader: DataLoader,
+        test_dataloader: DataLoader,
+):
+    """
+    Train the feature extractor for multiple epochs.
+    :param config: Configuration dict
+    :param fabric: Fabric instance
+    :param lateral_network: Lateral network passed through to the epoch functions
+    :param feature_extractor: Feature extractor passed through to the epoch functions
+    :param model: Model to train
+    :param optimizer: Optimizer to use
+    :param scheduler: LR scheduler to use (may be None)
+    :param train_dataloader: Training dataloader
+    :param test_dataloader: Testing dataloader
+    :return: None
+    """
+    start_epoch = config['run']['current_epoch']
+    for epoch in range(start_epoch, config['run']['n_epochs']):
+        config['run']['current_epoch'] = epoch
+        single_train_epoch(config, fabric, lateral_network, feature_extractor, model, optimizer, train_dataloader,
+                           epoch)
+        single_eval_epoch(config, lateral_network, feature_extractor, model, test_dataloader, epoch)
+        logs = model.on_epoch_end()
+        if scheduler is not None:
+            scheduler.step(logs["val/loss"])
+        # Guard state_dict(): scheduler may be None (see the check above).
+        # NOTE(review): model/optimizer are passed as objects but scheduler as a state_dict — confirm callback expectations.
+        fabric.call("on_epoch_end", config=config, logs=logs, fabric=fabric,
+                    components={"model": model, "optimizer": optimizer,
+                                "scheduler": scheduler.state_dict() if scheduler is not None else None})
+    # Fire once after training, not once per epoch.
+    fabric.call("on_train_end")
+
+
+def main():
+    """
+    Run the model and store the model with the lowest loss.
+ """ + print_start("Starting python script 'main_autoencoder.py'...", + title="Training S0: Autoencoder Feature Extractor") + config = configure() + fabric, feature_extractor, model, optimizer, scheduler, train_dataloader, test_dataloader = setup_modules(config) + lateral_network = fabric.setup(LateralNetwork(config, fabric)) + config2, state = load_run(config, fabric) + lateral_network.load_state_dict(state['lateral_network']) + train(config, fabric, lateral_network, feature_extractor, model, optimizer, scheduler, train_dataloader, test_dataloader) + + +if __name__ == '__main__': + main() diff --git a/src/main_evaluation.py b/src/main_evaluation.py index 7146c61..df6b8dd 100644 --- a/src/main_evaluation.py +++ b/src/main_evaluation.py @@ -1,4 +1,5 @@ import argparse +import json from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -11,6 +12,7 @@ from lightning import Fabric from torch import Tensor from tqdm import tqdm +import torch.nn.functional as F from data.custom_datasets.straight_line import StraightLine from lateral_connections.s1_lateral_connections import LateralNetwork @@ -32,7 +34,7 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> argparse.Arg parser.add_argument("--n_samples", type=int, metavar="N", - default=300, + default=180, help="Number of samples to evaluate." ) parser.add_argument('--simplified', @@ -40,10 +42,11 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> argparse.Arg default=False, help='Use simple dataset only containing lines with angels of 0°, 45°, -45°, and 90°.' ) - parser.add_argument('--add_noise', - action='store_true', - default=False, - help='Add noise to evaluation samples.' + parser.add_argument("--noise", + type=float, + metavar="N", + default=0.0, + help="Ratio of noise to add." 
) parser.add_argument("--line_interrupt", type=int, @@ -57,7 +60,16 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> argparse.Arg default=10, help="Number of samples to evaluate." ) - + parser.add_argument("--store_baseline_activations_path", + type=str, + default=None, + help="Store baseline activations to compare models to." + ) + parser.add_argument("--load_baseline_activations_path", + type=str, + default=None, + help="Load baseline activations to compare models to." + ) return parser @@ -319,6 +331,40 @@ def analyze_noise(noise: Tensor, random_mask: Tensor, lateral_features: List[Ten return removed_noise_ratio.item() +def analyze_recon_error(lateral_features: List[Tensor], baseline_lateral_features: List[Tensor]) -> Tuple[float, float, float]: + """ + Analyzes the reconstruction error of the lateral features. + + :param lateral_features: + :param baseline_lateral_features: + :return: The reconstruction error + """ + lateral_features = lateral_features[-1].view(-1) + baseline_lateral_features = baseline_lateral_features[-1].view(-1) + accuracy = 1. - F.l1_loss(lateral_features, baseline_lateral_features) + recall = 1. - F.l1_loss(lateral_features[baseline_lateral_features > 0.], baseline_lateral_features[baseline_lateral_features > 0.]) + precision = 1. - F.l1_loss(lateral_features[lateral_features > 0.], baseline_lateral_features[lateral_features > 0.]) + return accuracy.item(), recall.item(), precision.item() + + +def analyze_interrupt_line(img: Tensor, baseline_img: Tensor, lateral_features: List[Tensor], baseline_lateral_features: List[Tensor]) -> float: + """ + Analyzes the reconstruction error of the lateral features. 
+    :param img: The input image (possibly with an interrupted line)
+    :param baseline_img: The baseline input image (uninterrupted line)
+    :param lateral_features: The lateral features
+    :param baseline_lateral_features: The baseline lateral features
+    :return: Reconstruction accuracy restricted to the pixels where the two images differ
+    """
+    baseline_img = baseline_img[-1].squeeze()
+    # Mask of pixels where the interrupted image deviates from the baseline, broadcast over channels.
+    mask = torch.where(img != baseline_img, True, False).unsqueeze(0).repeat(lateral_features[-1].shape[1], 1, 1)
+    baseline_lateral_features = baseline_lateral_features[-1].squeeze(0)[mask]
+    lateral_features = lateral_features[-1].squeeze(0)[mask]
+    accuracy = 1. - F.l1_loss(lateral_features, baseline_lateral_features)
+    return accuracy.item()
+
+
 def predict_sample(
     config: Dict[str, Optional[Any]],
     fabric: Fabric,
@@ -326,7 +372,7 @@
     lateral_network: LateralNetwork,
     batch: Tensor,
     batch_idx: int,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, float, float]:
     """
     Predicts the features for a given sample
     :param config: Configuration
@@ -344,10 +390,10 @@
     features = feature_extractor(batch.unsqueeze(0))
     features = feature_extractor.binarize_features(features).squeeze(1)
 
-    if config['add_noise']:
+    if config['noise'] > 0.:
         features_s = features.shape
         num_elements = features.numel()
-        num_flips = int(0.005 * num_elements)
+        num_flips = int(config['noise'] * num_elements)
         random_mask = torch.randperm(num_elements)[:num_flips]
         random_mask = torch.zeros(num_elements, dtype=torch.bool).scatter(0, random_mask, 1)
         features = features.view(-1)
@@ -371,9 +417,16 @@
     lateral_features = merge_alt_channels(config, lateral_features)
     lateral_features_float = merge_alt_channels(config, lateral_features_float)
-    removed_noise = analyze_noise(noise, random_mask, lateral_features) if config['add_noise'] else 0
+    removed_noise = analyze_noise(noise, random_mask, lateral_features) if config['noise'] > 0. else 0
+    if 'load_baseline_activations_path' in config and config['load_baseline_activations_path'] is not None:
+        t = torch.load(config['load_baseline_activations_path'])  # t = [stacked l1 activations, stacked input images]
+        recon_error = analyze_recon_error(lateral_features, t[0][batch_idx])
+        interrupt_line_recon = analyze_interrupt_line(batch[0], t[1][batch_idx], lateral_features, t[0][batch_idx])
+    else:
+        # Sentinel values when no baseline activations are available.
+        recon_error = (-1, -1, -1)
+        interrupt_line_recon = -1
     return (torch.stack(input), torch.stack(input_features), torch.stack(lateral_features),
-            torch.stack(lateral_features_float), removed_noise)
+            torch.stack(lateral_features_float), removed_noise, interrupt_line_recon, recon_error)
 
 
 def process_data(
@@ -392,23 +445,52 @@
     :param feature_extractor: Feature extractor
     :param lateral_network: Lateral network (L1)
     """
+    imgs_, l1_acts = [], []
     ci = CustomImage()
     avg_noise_meter = AverageMeter()
-    fp = f"../tmp/v2/{config['run']['load_state_path'].split('.')[0]}_{'noise' if config['add_noise'] else 'no-noise'}_li-{config['line_interrupt']}.mp4"
+    avg_line_recon_accuracy_meter = AverageMeter()
+    avg_recon_accuracy_meter, avg_recon_recall_meter, avg_recon_precision_meter = AverageMeter(), AverageMeter(), AverageMeter()
+    # NOTE(review): 'noise:' puts a colon into the filename — invalid on Windows filesystems; consider 'noise-'.
+    fp = f"../tmp/v2/{config['run']['load_state_path'].split('.')[0]}_{'noise:'+str(config['noise']) if config['noise'] > 0 else 'no-noise'}_li-{config['line_interrupt']}.mp4"
     if Path(fp).exists():
         Path(fp).unlink()
     out = cv2.VideoWriter(fp, cv2.VideoWriter_fourcc(*'mp4v'), config['fps'], (ci.width, ci.height))
     for i, img in tqdm(enumerate(generator), total=config["n_samples"]):
-        inp, inp_features, l1_act, l1_act_prob, removed_noise = predict_sample(config, fabric, feature_extractor, lateral_network, img, i)
+        inp, inp_features, l1_act, l1_act_prob, removed_noise, interrupt_line_recon, recon_error = predict_sample(config, fabric, feature_extractor, lateral_network, img, i)
         l1_act_prob = torch.where((l1_act > 0.) | (inp_features > 0.), l1_act_prob, torch.zeros_like(l1_act_prob))
+        l1_acts.append(l1_act)
+        imgs_.append(inp)
         avg_noise_meter(removed_noise)
+        avg_line_recon_accuracy_meter(interrupt_line_recon)
+        avg_recon_accuracy_meter(recon_error[0])
+        avg_recon_recall_meter(recon_error[1])
+        avg_recon_precision_meter(recon_error[2])
+
         for timestep in range(l1_act.shape[0]):
             result = ci.create_image(inp[timestep], inp_features[timestep, 0], l1_act[timestep, 0],
                                      l1_act_prob[timestep, 0])
             out.write(cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
     out.release()
+    if 'store_baseline_activations_path' in config and config['store_baseline_activations_path'] is not None:
+        torch.save([torch.stack(l1_acts), torch.stack(imgs_)], config['store_baseline_activations_path'])
     print("Video stored at", fp)
     print(f"Average Noise Reduction: {avg_noise_meter.mean}")
+    print(f"Average Interrupt Line Reconstruction Accuracy: {avg_line_recon_accuracy_meter.mean}")
+    print(f"Average Reconstruction Accuracy: {avg_recon_accuracy_meter.mean}")
+    print(f"Average Reconstruction Recall: {avg_recon_recall_meter.mean}")
+    print(f"Average Reconstruction Precision: {avg_recon_precision_meter.mean}")
+    return avg_noise_meter.mean, avg_line_recon_accuracy_meter.mean, avg_recon_accuracy_meter.mean, avg_recon_recall_meter.mean, avg_recon_precision_meter.mean
+
+
+def store_noise_results(noise_reduction: float, avg_line_recon_accuracy_meter: float, recon_accuracy: float, recon_recall: float, recon_precision: float, config: Dict[str, Any]):
+    """
+    Stores the noise reduction results in a json file (appended as one JSON object per call).
+    :param noise_reduction: The noise reduction
+    :param avg_line_recon_accuracy_meter: The interrupted-line reconstruction accuracy
+    :param recon_accuracy: The reconstruction accuracy
+    :param recon_recall: The reconstruction recall
+    :param recon_precision: The reconstruction precision
+    :param config: Configuration
+    """
+    fp = "../tmp/noise_reduction.json"
+    with open(fp, "a") as f:
+        json.dump({'config': config, 'noise_reduction': noise_reduction,
+                   'avg_line_recon_accuracy_meter': avg_line_recon_accuracy_meter,
+                   'recon_accuracy': recon_accuracy, 'recon_recall': recon_recall,
+                   'recon_precision': recon_precision}, f)
+        f.write("\n")  # newline-delimit records so the appended file is valid JSON Lines
 
 
 def main():
@@ -418,10 +500,9 @@
     print_start("Starting python script 'main_evaluation.py'...",
                 title="Evaluating Model and Print activations")
     config, fabric, feature_extractor, lateral_network = load_models()
-    args = parse_args()
-    config = config | vars(args)
     generator = get_data_generator(config)
-    process_data(generator, config, fabric, feature_extractor, lateral_network)
+    noise_reduction, avg_line_recon_accuracy_meter, recon_accuracy, recon_recall, recon_precision = process_data(generator, config, fabric, feature_extractor, lateral_network)
+    store_noise_results(noise_reduction, avg_line_recon_accuracy_meter, recon_accuracy, recon_recall, recon_precision, config)
 
 
 if __name__ == "__main__":
     main()
diff --git a/src/tools/store_load_run.py b/src/tools/store_load_run.py
index ac8a32f..85913e0 100644
--- a/src/tools/store_load_run.py
+++ b/src/tools/store_load_run.py
@@ -77,5 +77,5 @@
     state = fabric.load(config['run']['load_state_path'])
     print_info(f"Loaded run from {config['run']['load_state_path']}", "Run State Loaded")
     config_old = state.pop("config")
-    config = merge_configs(config, config_old)
+    # config = merge_configs(config, config_old)  # disabled: keep the caller's config instead of the stored one
     return config, state