Skip to content

Commit

Permalink
Implement 8bit numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Sep 14, 2023
1 parent d3a3287 commit a98b748
Show file tree
Hide file tree
Showing 8 changed files with 753 additions and 6 deletions.
16 changes: 16 additions & 0 deletions configs/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ straightline:
num_aug_versions: 0
noise: 0.0

eight_bit_numbers:
num_channels: 1
mean: [ 0. ]
std: [ 1. ]
img_width: 32
img_height: 32
train_dataset_params:
samples_per_class: 50
include_noise: False
valid_dataset_params:
samples_per_class: 1
include_noise: False
test_dataset_params:
samples_per_class: 1
include_noise: False

augmentation_v1:
aug1:
- probability: 0.8
Expand Down
76 changes: 76 additions & 0 deletions configs/s1_toy_sample2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
dataset:
name: eight_bit_numbers
augmentation: None
loader: torch
batch_size: 1
num_workers: 0

run:
n_epochs: 15
current_epoch: 0
plots:
enable: True
only_last_epoch: False
store_path: "../tmp/test/"
store_state_path: None
load_state_path: None


metrics:
- mse:
type: MeanSquaredError
meter: avg
- mae:
type: MeanAbsoluteError
meter: avg

optimizers:
l2_opt:
type: Adam
params:
lr: 0.05
betas: [ 0.9, 0.999 ]
eps: 0.00000001
weight_decay: 0
amsgrad: False
scheduler:
type: ReduceLROnPlateau
params:
mode: min
factor: 0.1
patience: 5
threshold: 0.0001
threshold_mode: rel
cooldown: 0
min_lr: 0.000001
eps: 0.00000001

logging:
wandb:
active: False,
save_dir: "../wandb"
project: "lateral_connections_toy_example"
log_model: True
group: null # "hebbian_learning"
job_type: null # "train"
console:
active: True

feature_extractor:
out_channels: 4
add_bg_channel: False
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

lateral_model:
channels: 4
max_timesteps: 6 # 1 = only one forward pass without recurrent connection
min_k: 2
max_k: 3
l1_type: 'lateral_flex'
l1_params:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.4 # 'bernoulli'
8 changes: 4 additions & 4 deletions src/changing_line_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ def create_image(self, img: Tensor, in_features: Tensor, l1_act: Tensor, l2_act:
l2_act = Image.fromarray(self.to_mask((l2_act * 255).squeeze().cpu().numpy()))

# Resize Images
img = img.resize((self.img_size, self.img_size), Image.Resampling.LANCZOS)
in_features = in_features.resize((self.img_size, self.img_size), Image.Resampling.LANCZOS)
l1_act = l1_act.resize((self.img_size, self.img_size), Image.Resampling.LANCZOS)
l2_act = l2_act.resize((self.img_size, self.img_size), Image.Resampling.LANCZOS)
img = img.resize((self.img_size, self.img_size), Image.Resampling.NEAREST)
in_features = in_features.resize((self.img_size, self.img_size), Image.Resampling.NEAREST)
l1_act = l1_act.resize((self.img_size, self.img_size), Image.Resampling.NEAREST)
l2_act = l2_act.resize((self.img_size, self.img_size), Image.Resampling.NEAREST)

output = self.img_template.copy()

Expand Down
Loading

0 comments on commit a98b748

Please sign in to comment.