Skip to content

Commit

Permalink
Merge MNIST manually
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Sep 29, 2023
1 parent f30657f commit 87d4840
Show file tree
Hide file tree
Showing 7 changed files with 527 additions and 11 deletions.
18 changes: 17 additions & 1 deletion configs/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ mnist:
train: False
download: True

mnist-subset:
dir: mnist/
beton_dir: beton/mnist/
mean: [ 0.1307 ]
std: [ 0.3081 ]
num_classes: 10
num_channels: 1
img_width: 28
img_height: 28
train_dataset_params:
train: True
download: True
test_dataset_params:
train: False
download: True

imagenet:
dir: imagenet/
beton_dir: beton/imagenet/
Expand Down Expand Up @@ -76,7 +92,7 @@ straightline:
split: "train"
vertical_horizontal_only: True
aug_range: 0
num_images: 600
num_images: 300
num_aug_versions: 0
valid_dataset_params:
split: "val"
Expand Down
3 changes: 2 additions & 1 deletion configs/s1_toy_sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,5 @@ lateral_model:
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.5 # 'bernoulli'
square_factor: 1.2
square_factor: 1.2
support_factor: 1.3
3 changes: 2 additions & 1 deletion configs/s1_toy_sample2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,5 @@ lateral_model:
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.4 # 'bernoulli'
square_factor: 1.0
square_factor: 1.0
support_factor: 1.3
78 changes: 78 additions & 0 deletions configs/s1_toy_sample3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
dataset:
name: mnist-subset
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 15
current_epoch: 0
plots:
enable: True
only_last_epoch: False
store_path: "../tmp/test/"
store_state_path: None
load_state_path: None


metrics:
- mse:
type: MeanSquaredError
meter: avg
- mae:
type: MeanAbsoluteError
meter: avg

optimizers:
l2_opt:
type: Adam
params:
lr: 0.05
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "lateral_connections_toy_example"
log_model: True
group: null # "hebbian_learning"
job_type: null # "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

lateral_model:
channels: 4
max_timesteps: 1 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 3
lr: 20.0
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.15 # 'bernoulli'
square_factor: 1.5
support_factor: 5
Loading

0 comments on commit 87d4840

Please sign in to comment.