Skip to content

Commit

Permalink
Cleanup code
Browse files Browse the repository at this point in the history
  • Loading branch information
sagerpascal committed Sep 30, 2023
1 parent 17fd577 commit b924cf0
Show file tree
Hide file tree
Showing 18 changed files with 106 additions and 102 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,29 @@ pip install ffcv
pip install -r requirements.txt
```

## Run experiments



```bash
python main_lateral_connections.py <config> --wandb --plot --store <store_path>
python main_visualization.py <config> --load <store_path>
```

For config, use one of the following:
- `lateral_connection_baseline.yaml` (4 straight lines without alternative cells)
- `lateral_connection_alternative_cells.yaml` (straight lines)
- `lateral_connection_alternative_cells_8bit.yaml` (straight line digits)
- `lateral_connection_alternative_cells_mnist.yaml` (mnist digits)


## Create plots published in thesis

```bash
python print_thesis.py
```


## Create documentation

Locally:
Expand Down
52 changes: 26 additions & 26 deletions configs/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,32 +109,32 @@ straightline-simple:
num_aug_versions: 0
noise: 0.0

straightline:
num_channels: 1
mean: [ 0. ]
std: [ 1. ]
img_width: 32
img_height: 32
train_dataset_params:
split: "train"
vertical_horizontal_only: False
aug_range: 0
num_images: 300
num_aug_versions: 0
valid_dataset_params:
split: "val"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_aug_versions: 0
noise: 0.0
test_dataset_params:
split: "test"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_aug_versions: 0
noise: 0.0
straightline:
num_channels: 1
mean: [ 0. ]
std: [ 1. ]
img_width: 32
img_height: 32
train_dataset_params:
split: "train"
vertical_horizontal_only: False
aug_range: 0
num_images: 300
num_aug_versions: 0
valid_dataset_params:
split: "val"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_aug_versions: 0
noise: 0.0
test_dataset_params:
split: "test"
vertical_horizontal_only: False
aug_range: 0
num_images: 4
num_aug_versions: 0
noise: 0.0

eight_bit_numbers:
num_channels: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ feature_extractor:
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

alternative_cells: 10
n_alternative_cells: 10

lateral_model:
channels: 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ feature_extractor:
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

alternative_cells: 10
n_alternative_cells: 10

lateral_model:
channels: 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ feature_extractor:
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

alternative_cells: 10
n_alternative_cells: 10

lateral_model:
channels: 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ dataset:
name: straightline-simple
augmentation: None
loader: torch
batch_size: 128
batch_size: 32
num_workers: 0

run:
n_epochs: 15
n_epochs: 30
current_epoch: 0
plots:
enable: True
Expand Down Expand Up @@ -62,7 +62,7 @@ feature_extractor:
bin_threshold: 0. # set to 0.5 to obtain better features
optimized_filter_lines: True # set to True to obtain better features

alternative_cells: 1
n_alternative_cells: 1

lateral_model:
channels: 4
Expand All @@ -74,11 +74,11 @@ lateral_model:
locality_size: 5
lr: 0.2
hebbian_rule: 'vanilla'
neg_corr: True
act_threshold: 0.45 # 'bernoulli'
square_factor: 1.2
neg_corr: False
act_threshold: 0.5 # 'bernoulli'
square_factor: 3
support_factor: 1.3

l2:
k: 5
n_hidden: 16
l2:
k: 5
n_hidden: 16
File renamed without changes.
2 changes: 2 additions & 0 deletions src/lateral_connections/feature_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from lateral_connections.feature_extractor.vq_vae_pl_modules import VQVAEFeatureExtractorPatchMode, \
VQVAEFeatureExtractorImageMode
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def __init__(self, conf: Dict[str, Optional[Any]], fabric: Fabric):
self.avg_value_meter = {}

lm_conf = self.conf["lateral_model"]
self.out_channels = self.conf["lateral_model"]["channels"] * conf['alternative_cells']
self.out_channels = self.conf["lateral_model"]["channels"] * conf['n_alternative_cells']
self.in_channels = self.conf["feature_extractor"]["out_channels"] + self.out_channels
if self.conf["feature_extractor"]["add_bg_channel"]:
self.in_channels += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def step(self, x: Tensor, log_prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tens
# loss = self.model.free_energy_hidden(pt) - self.model.free_energy_hidden(pt2)
v, v_gibb, h = self.forward(x)
loss = self.model.free_energy(v) - self.model.free_energy(v_gibb)
v = v.reshape(-1, self.conf['lateral_model']['channels'] * self.conf['alternative_cells'], self.conf['data']['img_width'], self.conf['data']['img_height'])
v_gibb = v_gibb.reshape(-1, self.conf['lateral_model']['channels'] * self.conf['alternative_cells'], self.conf['data']['img_width'], self.conf['data']['img_height'])
v = v.reshape(-1, self.conf['lateral_model']['channels'] * self.conf['n_alternative_cells'], self.conf['dataset']['img_width'], self.conf['dataset']['img_height'])
v_gibb = v_gibb.reshape(-1, self.conf['lateral_model']['channels'] * self.conf['n_alternative_cells'], self.conf['dataset']['img_width'], self.conf['dataset']['img_height'])
self.log_step(processed_values={"loss": loss}, metric_pairs=[(v, v_gibb)], prefix=log_prefix)
return v, v_gibb, h, loss

Expand All @@ -109,7 +109,7 @@ def configure_model(self, conf: Dict[str, Optional[Any]]) -> nn.Module:
:param conf: Configuration dictionary.
:return: A torch model.
"""
n_visible = conf['lateral_model']['channels'] * conf['data']['img_width'] * conf['data']['img_height'] * conf['alternative_cells']
n_visible = conf['lateral_model']['channels'] * conf['dataset']['img_width'] * conf['dataset']['img_height'] * conf['n_alternative_cells']
return RBM(n_visible=n_visible, n_hidden=conf['l2']['n_hidden'], k=conf['l2']['k'])

def configure_optimizers(self) -> Tuple[Optimizer, Optional[ReduceLROnPlateau]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from data import loaders_from_config
from models import BaseLitModule
from stage_1.feature_extractor import VQVAEFeatureExtractorImageMode, VQVAEFeatureExtractorPatchMode
from lateral_connections.feature_extractor import VQVAEFeatureExtractorImageMode, VQVAEFeatureExtractorPatchMode
from tools import loggers_from_conf
from utils import get_config, print_start, print_warn
from tools.callbacks.save_model import SaveBestModelCallback
Expand Down
26 changes: 4 additions & 22 deletions src/s1_toy_example.py → src/main_lateral_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from tqdm import tqdm

from data import loaders_from_config
from stage_1.feature_extractor.straight_line_pl_modules import FixedFilterFeatureExtractor
from stage_1.lateral.l2_rbm import L2RBM
from stage_1.lateral.lateral_connections_toy import LateralNetwork
from lateral_connections.feature_extractor.straight_line_pl_modules import FixedFilterFeatureExtractor
from lateral_connections.s2_rbm import L2RBM
from lateral_connections.s1_lateral_connections import LateralNetwork
from tools import loggers_from_conf
from tools.store_load_run import load_run, save_run
from utils import get_config, print_start, print_warn
Expand Down Expand Up @@ -73,24 +73,6 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None):
dest='run:plots:store_path',
help='Store the plotted results in the given path'
)
parser.add_argument("--train_noise",
type=float,
# default=0.,
dest="dataset:train_dataset_params:noise",
help="The noise added to the training data (default: 0.)"
)
parser.add_argument("--valid_noise",
type=float,
# default=0.005,
dest="dataset:valid_dataset_params:noise",
help="The noise added to the validation data (default: 0.005)"
)
parser.add_argument("--test_noise",
type=float,
# default=0.005,
dest="dataset:test_dataset_params:noise",
help="The noise added to the test data (default: 0.005)"
)
parser.add_argument('--store',
type=str,
dest='run:store_state_path',
Expand Down Expand Up @@ -444,7 +426,7 @@ def main():
"""
Run the model: Create modules, extract features from images and run the model leveraging lateral connections.
"""
print_start("Starting python script 's1_toy_example.py'...",
print_start("Starting python script 'main_lateral_connections.py'...",
title="Training S1: Lateral Connections Toy Example")
config = configure()
fabric = setup_fabric(config)
Expand Down
48 changes: 23 additions & 25 deletions src/changing_line_demo.py → src/main_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

from data.custom_datasets.eight_bit_numbers import EightBitDataset
from data.custom_datasets.straight_line import StraightLine
from s1_toy_example import configure, cycle, setup_fabric, setup_feature_extractor, setup_l2, setup_lateral_network
from stage_1.lateral.l2_rbm import L2RBM
from stage_1.lateral.lateral_connections_toy import LateralNetwork
from main_lateral_connections import configure, cycle, setup_fabric, setup_feature_extractor, setup_l2, setup_lateral_network
from lateral_connections.s2_rbm import L2RBM
from lateral_connections.s1_lateral_connections import LateralNetwork
from tools.store_load_run import load_run
from utils import print_start

count = 0
dataset = "straight_line" # eight-bit, mnist

config = {
"n_cycles": 200,
visu_config = {
"n_cycles": 88,
"cycle_length": 1,
"noise":
{
Expand All @@ -48,9 +47,6 @@
}
}

if dataset == "mnist":
config["n_cycles"] = 88

class CustomImage:
"""
Custom Image class to draw the current state of the network.
Expand Down Expand Up @@ -277,14 +273,14 @@ def get_dataset(config: Dict[str, Any], strategy: Dict[str, Any]) -> StraightLin
:param strategy: The strategy
:return: SrtaightLine dataset
"""
if dataset == "straight_line":
if config['dataset']['name'].startswith("straightline"):
return StraightLine(split="test",
num_images=len(strategy["line"]),
num_aug_versions=0,
)
elif dataset == "eight-bit":
elif config['dataset']['name'].startswith("eight_bit"):
return EightBitDataset(samples_per_class=config['n_cycles'])
elif dataset == "mnist":
elif config['dataset']['name'].startswith("mnist"):
transform = transforms.Compose([transforms.ToTensor(), transforms.Pad(2)])
return torch.utils.data.Subset(
MNIST(root='../data/mnist', transform=transform),
Expand All @@ -293,23 +289,25 @@ def get_dataset(config: Dict[str, Any], strategy: Dict[str, Any]) -> StraightLin
123, 124, 125, 127, 131, 132, 133, 136, 139, 140, 141, 142, 143, 144, 145, 148, 150, 151, 152, 153, 154,
155, 156, 158, 160, 162, 170, 174, 176, 178, 183, 186, 188, 191, 195, 196])
else:
raise ValueError(f"Invalid dataset {dataset}.")
raise ValueError(f"Invalid dataset.")

def get_data_gen(strategy: Dict[str, Any], dataset: StraightLine):
def get_data_gen(strategy: Dict[str, Any], dataset: StraightLine, config: Dict[str, Any]) -> Iterator[Tuple[Tensor, List[Dict[str, Any]]]]:
"""
Data generator for the given strategy and dataset
:param strategy: The strategy
:param dataset: The dataset
:return: Image generator
"""
for i in range(config["n_cycles"] + 1):
for i in range(visu_config["n_cycles"] + 1):
images, metas = [], []
for _ in range(config["cycle_length"]):
if dataset == "straight_line":
for _ in range(visu_config["cycle_length"]):
if config['dataset']['name'].startswith("straightline"):
img, meta = dataset.get_item(i, line_coords=strategy["line"][i], noise=strategy["noise"][i],
n_black_pixels=strategy["black"][i])
elif dataset == "eight-bit" or dataset == "mnist":
elif config['dataset']['name'].startswith("eight_bit") or config['dataset']['name'].startswith("mnist"):
img, meta = dataset[i]
else:
raise ValueError(f"Invalid dataset.")
images.append(img.unsqueeze(0))
metas.append(meta)
yield torch.vstack(images).unsqueeze(0), metas
Expand Down Expand Up @@ -341,14 +339,14 @@ def load_models() -> Tuple[Dict[str, Any], Fabric, pl.LightningModule, pl.Lightn
return config, fabric, feature_extractor, lateral_network, l2


def load_data_generator() -> Iterator[Tuple[Tensor, List[Dict[str, Any]]]]:
def load_data_generator(config: Dict[str, Optional[Any]]) -> Iterator[Tuple[Tensor, List[Dict[str, Any]]]]:
"""
Loads the data generator
:return: Data generator
"""
strategy = get_strategy(config)
strategy = get_strategy(visu_config)
dataset = get_dataset(config, strategy)
generator = get_data_gen(strategy, dataset)
generator = get_data_gen(strategy, dataset, config)
return generator


Expand Down Expand Up @@ -408,11 +406,11 @@ def process_data(
:param l2: L2 network
:param video_fp: Video file path
"""
fps = 1. if dataset == "mnist" else 10.
frames_last = 0 if dataset == "mnist" else int(fps // 2) # show the median activation for a longer time...
fps = 1. if config['dataset']['name'].startswith("mnist") else 10.
frames_last = 0 if config['dataset']['name'].startswith("mnist") else int(fps // 2) # show the median activation for a longer time...
ci = CustomImage()
out = cv2.VideoWriter(video_fp, cv2.VideoWriter_fourcc(*'mp4v'), fps, (ci.width, ci.height))
for i, img in tqdm(enumerate(generator), total=config["n_cycles"] + 1):
for i, img in tqdm(enumerate(generator), total=visu_config["n_cycles"] + 1):
inp_features, l1_act, l2_act, l2_h_act = predict_sample(config, fabric, feature_extractor, lateral_network, l2,
img, i)
for view in range(img[0].shape[1]):
Expand All @@ -432,7 +430,7 @@ def main():
print_start("Starting python script 'changing_line_demo.py'...",
title="Demo Lines: Creating a Video of a Changing Line")
config, fabric, feature_extractor, lateral_network, l2 = load_models()
generator = load_data_generator()
generator = load_data_generator(config)
process_data(generator, config, fabric, feature_extractor, lateral_network, l2,
video_fp=f"../tmp/demo/{Path(config['run']['load_state_path']).name}.mp4")

Expand Down
Loading

0 comments on commit b924cf0

Please sign in to comment.