Skip to content

Commit

Permalink
Remove L2
Browse files (browse the repository at this point in the history)
  • Loading branch information
sagerpascal committed Nov 9, 2023
1 parent bdd18e7 commit eae40e7
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions src/main_lateral_connections.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -204,25 +204,25 @@ def cycle(
x_in = torch.cat([x_view_features, z], dim=1)
z_float, z = lateral_network(x_in)

z2, z2_feedback, h, loss = l2.eval_step(z)

if epoch > 10:
mask_active = (z > 0) | (z2_feedback > 0)
if F.mse_loss(z[mask_active], z2_feedback[mask_active]) < .1:
z = z2_feedback
# z2, z2_feedback, h, loss = l2.eval_step(z)
#
# if epoch > 10:
# mask_active = (z > 0) | (z2_feedback > 0)
# if F.mse_loss(z[mask_active], z2_feedback[mask_active]) < .1:
# z = z2_feedback

features_lat.append(z)
if store_tensors:
features_lat_float.append(z_float)
features_l2.append(z2_feedback)
features_l2_h.append(h)
# features_l2.append(z2_feedback)
# features_l2_h.append(h)

features_lat = torch.stack(features_lat, dim=1)
features_lat_median = torch.median(features_lat, dim=1)[0]
if store_tensors:
features_lat_float = torch.stack(features_lat_float, dim=1)
features_l2 = torch.stack(features_l2, dim=1)
features_l2_h = torch.stack(features_l2_h, dim=1)
# features_l2 = torch.stack(features_l2, dim=1)
# features_l2_h = torch.stack(features_l2_h, dim=1)

if mode == "train": # TODO: Train at the end after all timesteps (use median activation per cell),
# also update L1 after training
Expand All @@ -232,27 +232,27 @@ def cycle(
lateral_network.model.l1.hebbian_update(x_rearranged, features_lat_median)

# Train L2
l2_opt.zero_grad()
z2, z2_feedback, h, loss = l2.train_step(features_lat_median)
fabric.backward(loss)
l2_opt.step()
# l2_opt.zero_grad()
# z2, z2_feedback, h, loss = l2.train_step(features_lat_median)
# fabric.backward(loss)
# l2_opt.step()

if store_tensors:
features_lat_float_median = torch.median(features_lat_float, dim=1)[0]
features_l2_median = torch.median(features_l2, dim=1)[0]
l2h_features_median = torch.median(features_l2_h, dim=1)[0]
# features_l2_median = torch.median(features_l2, dim=1)[0]
# l2h_features_median = torch.median(features_l2_h, dim=1)[0]
features_lat = torch.cat([features_lat, features_lat_median.unsqueeze(1)], dim=1)
features_lat_float = torch.cat([features_lat_float, features_lat_float_median.unsqueeze(1)], dim=1)
features_l2 = torch.cat([features_l2, features_l2_median.unsqueeze(1)], dim=1)
features_l2_h = torch.cat([features_l2_h, l2h_features_median.unsqueeze(1)], dim=1)
# features_l2 = torch.cat([features_l2, features_l2_median.unsqueeze(1)], dim=1)
# features_l2_h = torch.cat([features_l2_h, l2h_features_median.unsqueeze(1)], dim=1)
lateral_features.append(features_lat)
lateral_features_f.append(features_lat_float)
l2_features.append(features_l2)
l2h_features.append(features_l2_h)
# l2_features.append(features_l2)
# l2h_features.append(features_l2_h)

if store_tensors:
return features, torch.stack(input_features, dim=1), torch.stack(lateral_features, dim=1), torch.stack(
lateral_features_f, dim=1), torch.stack(l2_features, dim=1), torch.stack(l2h_features, dim=1)
lateral_features_f, dim=1),None, None


def single_train_epoch(
Expand Down Expand Up @@ -350,16 +350,16 @@ def single_eval_epoch(
plot_input_features=epoch == 0,
show_plot=plot)
weights_fp = lateral_network.plot_model_weights(show_plot=plot)
plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
if epoch == config['run']['n_epochs']:
videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)
# plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
# if epoch == config['run']['n_epochs']:
# videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)

if wandb_b:
logs = {str(pfp.name[:-4]): wandb.Image(str(pfp)) for pfp in plots_fp}
logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in weights_fp}
logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in plots_l2_fp}
if epoch == config['run']['n_epochs']:
logs |= {str(vfp.name[:-4]): wandb.Video(str(vfp)) for vfp in videos_fp}
# logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in plots_l2_fp}
# if epoch == config['run']['n_epochs']:
# logs |= {str(vfp.name[:-4]): wandb.Video(str(vfp)) for vfp in videos_fp}
wandb.log(logs | {"epoch": epoch, "trainer/global_step": epoch})


Expand Down Expand Up @@ -396,9 +396,9 @@ def train(
single_train_epoch(config, feature_extractor, lateral_network, l2, train_loader, epoch + 1, fabric, l2_opt)
single_eval_epoch(config, feature_extractor, lateral_network, l2, test_loader, epoch + 1)
lateral_network.on_epoch_end()
l2_logs = l2.on_epoch_end()
if l2_sched is not None:
l2_sched.step(l2_logs["l2/val/loss"])
# l2_logs = l2.on_epoch_end()
# if l2_sched is not None:
# l2_sched.step(l2_logs["l2/val/loss"])
config['run']['current_epoch'] = epoch + 1


Expand Down

0 comments on commit eae40e7

Please sign in to comment.