Skip to content

Commit

Permalink
More detailed evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Nov 22, 2023
1 parent 61c2cf9 commit 68fa7b8
Show file tree
Hide file tree
Showing 17 changed files with 1,390 additions and 33 deletions.
84 changes: 84 additions & 0 deletions configs/angle_prediction.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
dataset:
name: straightline
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 100
current_epoch: 0
store_state_path: ../checkpoints/s0/ae
# load_state_path: ../checkpoints/s1/vq_vae_full_image/s1_2023-05-04_15-48-50.ckpt
visualize_plots: True

optimizers:
opt1:
type: Adam
params:
lr: 0.001
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

metrics:
- mae:
type: MeanAbsoluteError
meter: avg
- mape:
type: MeanAbsolutePercentageError
meter: avg
- mse:
type: MeanSquaredError
meter: avg

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "s0"
log_model: True
group: "autoencoder"
job_type: "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

model:
type: angle-predictor
params:
kernel_size: 5

n_alternative_cells: 1

lateral_model:
channels: 4
max_timesteps: 2 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.5 # 'bernoulli'
square_factor: 1.2
support_factor: 1.3
2 changes: 1 addition & 1 deletion configs/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ straightline:
split: "test"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_images: 10
num_aug_versions: 0
noise: 0.0

Expand Down
12 changes: 9 additions & 3 deletions configs/lateral_connection_alternative_cells.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ feature_extractor:
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

n_alternative_cells: 20
n_alternative_cells: 10

lateral_model:
channels: 4
max_timesteps: 6 # 1 = only one forward pass without recurrent connection
max_timesteps: 2 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
Expand All @@ -76,7 +76,13 @@ lateral_model:
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.5 # 'bernoulli'
square_factor: 1.2
square_factor:
- 1.2
- 1.4
- 1.6
- 1.8
- 2.0
- 2.2
support_factor: 1.3

l2:
Expand Down
90 changes: 90 additions & 0 deletions configs/lateral_connection_alternative_cells_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
dataset:
name: straightline
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 101
current_epoch: 0
plots:
enable: True
only_last_epoch: False
store_path: "../tmp/test/"
store_state_path: None
load_state_path: None


metrics:
- mse:
type: MeanSquaredError
meter: avg
- mae:
type: MeanAbsoluteError
meter: avg

optimizers:
l2_opt:
type: Adam
params:
lr: 0.05
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "lateral_connections_toy_example"
log_model: True
group: null # "hebbian_learning"
job_type: null # "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

n_alternative_cells: 10

lateral_model:
channels: 4
max_timesteps: 2 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.5 # 'bernoulli'
square_factor:
- 2.0
- 2.1
- 2.2
- 2.3
- 2.4
- 2.5
support_factor: 1.3

l2:
k: 1
n_hidden: 16
90 changes: 90 additions & 0 deletions configs/lateral_connection_alternative_cells_3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
dataset:
name: straightline
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 101
current_epoch: 0
plots:
enable: True
only_last_epoch: False
store_path: "../tmp/test/"
store_state_path: None
load_state_path: None


metrics:
- mse:
type: MeanSquaredError
meter: avg
- mae:
type: MeanAbsoluteError
meter: avg

optimizers:
l2_opt:
type: Adam
params:
lr: 0.05
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "lateral_connections_toy_example"
log_model: True
group: null # "hebbian_learning"
job_type: null # "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

n_alternative_cells: 10

lateral_model:
channels: 4
max_timesteps: 2 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.5 # 'bernoulli'
square_factor:
- 0.7
- 0.9
- 1.1
- 1.3
- 1.5
- 1.7
support_factor: 1.3

l2:
k: 1
n_hidden: 16
90 changes: 90 additions & 0 deletions configs/lateral_connection_alternative_cells_4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
dataset:
name: straightline
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 101
current_epoch: 0
plots:
enable: True
only_last_epoch: False
store_path: "../tmp/test/"
store_state_path: None
load_state_path: None


metrics:
- mse:
type: MeanSquaredError
meter: avg
- mae:
type: MeanAbsoluteError
meter: avg

optimizers:
l2_opt:
type: Adam
params:
lr: 0.05
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "lateral_connections_toy_example"
log_model: True
group: null # "hebbian_learning"
job_type: null # "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

n_alternative_cells: 10

lateral_model:
channels: 4
max_timesteps: 2 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.3 # 'bernoulli'
square_factor:
- 1.2
- 1.4
- 1.6
- 1.8
- 2.0
- 2.2
support_factor: 1.3

l2:
k: 1
n_hidden: 16
Loading

0 comments on commit 68fa7b8

Please sign in to comment.