Commit 93823e7

Cleanup branch
sagerpascal committed Sep 29, 2023
1 parent 0ae023a commit 93823e7
Showing 2 changed files with 17 additions and 15 deletions.
src/changing_line_demo.py (5 changes: 3 additions & 2 deletions)
@@ -41,6 +41,7 @@
}
}

+start_strategy = {"noise": [0.0], "black": [20], "line": [((2, 16), (30, 16))]}

class CustomImage:
"""
@@ -51,7 +52,7 @@ def __init__(self):
"""
Initialize the class.
"""
-self.img_size = 512
+self.img_size = 128
self.img_template = self.create_template_image()

def to_mask(self, mask: np.array) -> np.array:
@@ -214,7 +215,7 @@ def get_strategy(config: Dict[str, Any]) -> Dict[str, Any]:
:param config: Configuration of the demo, describing the strategy
:return: The strategy
"""
-strategy = {"noise": [0.0], "black": [20], "line": [((2, 16), (30, 16))]}
+strategy = start_strategy
for cycle in range(config["n_cycles"]):
if random.random() <= config["noise"]["probability"]:
noise = random.uniform(config["noise"]["min"], config["noise"]["max"])
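
For reference, a minimal sketch of how the refactored get_strategy could be driven, assuming the module is importable as changing_line_demo and using only the config keys visible in this diff (n_cycles and the noise block); the actual configuration schema in the repository may contain more fields:

    # Hypothetical usage sketch; any field not shown in the diff above is an assumption.
    from changing_line_demo import get_strategy

    demo_config = {
        "n_cycles": 5,
        "noise": {"probability": 0.3, "min": 0.0, "max": 0.1},
    }

    strategy = get_strategy(demo_config)
    # The returned strategy starts from the module-level start_strategy shown above
    # and is presumably extended once per cycle when the noise probability fires.
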
src/s1_toy_example.py (27 changes: 14 additions & 13 deletions)
@@ -202,6 +202,7 @@ def cycle(
x_view_features = features[:, view_idx, ...]
x_view_features = feature_extractor.binarize_features(x_view_features)

+# Add noise to the input features -> should be removed by net fragments
# x_view_features = np.array(x_view_features.detach().cpu())
# x_view_features = x_view_features + np.random.choice(2, x_view_features.shape, p=[1 - 0.005, 0.005])
# x_view_features = torch.from_numpy(x_view_features).cuda().float()
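
The commented-out lines above are the noise injection that the new comment refers to; a standalone sketch of the same idea, assuming x_view_features is a binarized torch tensor and CUDA is available (the helper name is hypothetical):

    import numpy as np
    import torch

    def add_input_noise(x_view_features: torch.Tensor, p: float = 0.005) -> torch.Tensor:
        # Draw a Bernoulli(p) mask and add it to the binarized features,
        # mirroring the commented-out snippet; entries that are already 1
        # become 2 when the mask fires, since no clipping is applied.
        x = np.array(x_view_features.detach().cpu())
        x = x + np.random.choice(2, x.shape, p=[1 - p, p])
        return torch.from_numpy(x).cuda().float()
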
@@ -353,8 +354,8 @@ def single_eval_epoch(
assert not wandb_b or wandb_b and store_plots, "Wandb logging requires storing the plots."

if plot or wandb_b or store_plots:
-# if epoch == 0:
-# feature_extractor.plot_model_weights(show_plot=plot)
+if epoch == 0:
+    feature_extractor.plot_model_weights(show_plot=plot)
#
plots_fp = lateral_network.plot_samples(plt_img,
plt_features,
Expand All @@ -364,17 +365,17 @@ def single_eval_epoch(
plot_input_features=epoch == 0,
show_plot=plot)
weights_fp = lateral_network.plot_model_weights(show_plot=plot)
-# plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
-# if epoch == config['run']['n_epochs']:
-# videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)
-#
-# if wandb_b:
-# logs = {str(pfp.name[:-4]): wandb.Image(str(pfp)) for pfp in plots_fp}
-# logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in weights_fp}
-# logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in plots_l2_fp}
-# if epoch == config['run']['n_epochs']:
-# logs |= {str(vfp.name[:-4]): wandb.Video(str(vfp)) for vfp in videos_fp}
-# wandb.log(logs | {"epoch": epoch, "trainer/global_step": epoch})
+plots_l2_fp = l2.plot_samples(plt_img, plt_activations_l2, show_plot=plot)
+if epoch == config['run']['n_epochs']:
+    videos_fp = lateral_network.create_activations_video(plt_img, plt_input_features, plt_activations)
+
+if wandb_b:
+    logs = {str(pfp.name[:-4]): wandb.Image(str(pfp)) for pfp in plots_fp}
+    logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in weights_fp}
+    logs |= {str(wfp.name[:-4]): wandb.Image(str(wfp)) for wfp in plots_l2_fp}
+    if epoch == config['run']['n_epochs']:
+        logs |= {str(vfp.name[:-4]): wandb.Video(str(vfp)) for vfp in videos_fp}
+    wandb.log(logs | {"epoch": epoch, "trainer/global_step": epoch})


def train(
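
The re-enabled block collects the returned plot paths into a single wandb log call; a minimal sketch of the same pattern in isolation, assuming a wandb run is active (project name and file paths are placeholders):

    import wandb
    from pathlib import Path

    wandb.init(project="toy-example")  # placeholder project name
    epoch = 0
    plots_fp = [Path("plots/sample_0.png"), Path("plots/weights_0.png")]  # placeholder files

    # Key each image by its file name without the ".png" suffix, as in the diff above,
    # then merge in the step counters with the dict-union operator (Python 3.9+).
    logs = {p.name[:-4]: wandb.Image(str(p)) for p in plots_fp}
    wandb.log(logs | {"epoch": epoch, "trainer/global_step": epoch})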
