Add noise to feature channels during eval
sagerpascal committed Nov 9, 2023
1 parent eae40e7 commit bf9d539
Showing 4 changed files with 25 additions and 14 deletions.
configs/data.yaml (1 addition & 1 deletion)
@@ -132,7 +132,7 @@ straightline:
   split: "test"
   vertical_horizontal_only: False
   aug_range: 0
-  num_images: 8
+  num_images: 4
   num_aug_versions: 0
   noise: 0.0

configs/lateral_connection_alternative_cells.yaml (1 addition & 1 deletion)
@@ -62,7 +62,7 @@ feature_extractor:
   bin_threshold: 0.  # set to 0.5 to obtain better features
   optimized_filter_lines: True  # set to True to obtain better features

-n_alternative_cells: 10
+n_alternative_cells: 20

 lateral_model:
   channels: 4
src/lateral_connections/s1_lateral_connections.py (11 additions & 11 deletions)
@@ -73,6 +73,7 @@ def __init__(self,
         self.n_alternative_cells = n_alternative_cells
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.in_feature_channels = self.in_channels - self.out_channels
         self.locality_size = locality_size
         self.neib_size = 2 * self.locality_size + 1
         self.kernel_size = (self.neib_size, self.neib_size)
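
For intuition, the new attribute ties together the constants this commit removes. The concrete numbers below are inferred from those old hardcoded values (the `reshape(40, ...)` and `ci + 4 == co` this commit replaces) and are assumptions, not values read from the repo config:

# Hypothetical values, reverse-engineered from the constants this commit
# replaces: out_channels = 4 feature channels * 10 alternative cells = 40,
# in_channels = 4 feature channels + 40 lateral channels fed back = 44.
in_channels = 44
out_channels = 40
in_feature_channels = in_channels - out_channels
assert in_feature_channels == 4  # the value that was hardcoded before this commit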
@@ -166,15 +167,15 @@ def _init_lateral_weights(self) -> Tensor:
         if self.n_alternative_cells <= 1:  # TODO: is this if/else still necessary?
             for co in range(self.out_channels):
                 for ci in range(self.in_channels):
-                    if ci == co or ci + 4 == co:
+                    if ci == co or ci + self.in_feature_channels == co:
                         cii = ci * self.kernel_size[0] * self.kernel_size[1] + self.locality_size * self.kernel_size[
                             1] + self.locality_size
                         W_lateral[co, cii, 0, 0] = 1

         else:
             for co in range(self.out_channels):
                 for ci in range(self.in_channels):
-                    if (ci < 4 and co // self.n_alternative_cells == ci) or ci == co + 4:
+                    if (ci < self.in_feature_channels and co // self.n_alternative_cells == ci) or ci == co + self.in_feature_channels:
                         cii = ci * self.kernel_size[0] * self.kernel_size[1] + self.locality_size * self.kernel_size[
                             1] + self.locality_size
                         W_lateral[co, cii, 0, 0] = 1
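
The `cii` arithmetic above picks out the center tap of the flattened kernel axis. A minimal sketch with toy sizes (all values illustrative), assuming the weight's second axis flattens `(in_channels, K, K)` row-major as in the code:

import torch

locality_size = 2                  # example value only
K = 2 * locality_size + 1          # neib_size
in_channels, ci = 6, 3             # toy sizes

# cii addresses the center tap (locality_size, locality_size) of input
# channel ci within the flattened (in_channels * K * K) axis.
cii = ci * K * K + locality_size * K + locality_size

flat = torch.arange(in_channels * K * K).reshape(in_channels, K, K)
assert flat[ci, locality_size, locality_size].item() == cii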
@@ -273,15 +274,14 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]:
         # Shape w2: 5324, 40, 1
         # Shape x2: 5324, 1, 1024
         d2 = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
-        w2 = (self.W_lateral.reshape(40, d2, 1).permute(1, 0, 2) > 0).float()
+        w2 = (self.W_lateral.reshape(self.out_channels, d2, 1).permute(1, 0, 2) > 0).float()
         x2 = x_rearranged.reshape(d2, 1, 1024).permute(0, 1, 2)
-        pos_corr3 = torch.matmul(w2, x2)  # 40,5324,1 * 40,1,1024
-        neg_corr3 = torch.matmul(w2, 1 - x2) + torch.matmul(1 - w2, x2)
-        pos_corr3 = pos_corr3.permute(2, 1, 0).reshape(1024, 4, 10, d2)
-        neg_corr3 = neg_corr3.permute(2, 1, 0).reshape(1024, 4, 10, d2)
-
-        pos_corr = pos_corr3
-        neg_corr = neg_corr3
+        pos_corr = torch.matmul(w2, x2)  # 5324,40,1 * 5324,1,1024
+        neg_corr = torch.matmul(w2, 1 - x2) + torch.matmul(1 - w2, x2)
+        pos_corr = pos_corr.permute(2, 1, 0).reshape(1024, self.in_feature_channels,
+                                                     self.n_alternative_cells, d2)
+        neg_corr = neg_corr.permute(2, 1, 0).reshape(1024, self.in_feature_channels,
+                                                     self.n_alternative_cells, d2)

         # correlation shape: (batch_size*H*W, out_channels, alt_channels, in_channels*kernel_w*kernel_h)
         # Goal: for every position in channel 0, one of the alternative output channels should be active
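
A reduced restatement of the batched-matmul correlation trick above (toy shapes; `d2` stands in for `in_channels * kernel_h * kernel_w`, and all names are illustrative):

import torch

d2, n_out, n_pos = 6, 4, 5                       # toy sizes
w2 = (torch.randn(d2, n_out, 1) > 0).float()     # binarized lateral weights
x2 = (torch.randn(d2, 1, n_pos) > 0).float()     # binarized inputs per position

# Batched over the d2 taps: pos_corr[d, o, p] is 1 iff weight tap d of output o
# AND input tap d at position p are both active; neg_corr counts the mismatches.
pos_corr = torch.matmul(w2, x2)
neg_corr = torch.matmul(w2, 1 - x2) + torch.matmul(1 - w2, x2)

# Each (weight, input) pair is counted exactly once, as either an agreement
# or a mismatch, whenever at least one side is active.
assert torch.all(pos_corr + neg_corr == (w2 + x2 > 0).float())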
@@ -342,7 +342,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]]:
             best_channel = best_channel.reshape((x_lateral_bin.shape[0], ) + x_lateral_bin.shape[2:] + (-1, ))

             x_lateral_bin_reshaped = x_lateral_bin.reshape((x_lateral_bin.shape[0], x_lateral_bin.shape[1] // self.n_alternative_cells, self.n_alternative_cells) + x_lateral_bin.shape[2:]).permute(0, 3, 4, 1, 2)
-            self.mask = (torch.arange(0, 10).view(1, 1, 1, 1, -1).cuda() == best_channel.unsqueeze(-1))
+            self.mask = (torch.arange(0, self.n_alternative_cells).view(1, 1, 1, 1, -1).cuda() == best_channel.unsqueeze(-1))
             assert torch.all(torch.sum(self.mask, dim=4) == 1)

         else:
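The mask construction above, restated with toy shapes on CPU (a minimal sketch; the shapes are assumptions and the repo's `.cuda()` call is dropped for portability):

import torch

n_alt, feature_channels = 10, 4
best_channel = torch.randint(0, n_alt, (2, 8, 8, feature_channels))  # (B, H, W, C), hypothetical

# One-hot along the alternative-cell axis: exactly one winner per position.
mask = torch.arange(0, n_alt).view(1, 1, 1, 1, -1) == best_channel.unsqueeze(-1)
assert torch.all(mask.sum(dim=4) == 1)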
src/main_lateral_connections.py (12 additions & 1 deletion)
@@ -177,13 +177,24 @@ def cycle(
     with torch.no_grad():
         features = feature_extractor(batch)

+    features = feature_extractor.binarize_features(features)
+
+    if mode == "eval":
+        features_s = features.shape
+        num_elements = features.numel()
+        num_flips = int(0.005 * num_elements)
+        random_mask = torch.randperm(num_elements)[:num_flips]
+        random_mask = torch.zeros(num_elements, dtype=torch.bool).scatter(0, random_mask, 1)
+        features = features.view(-1)
+        features[random_mask] = 1.0 - features[random_mask]
+        features = features.view(features_s)
+
     lateral_network.new_sample()
     z = None

     input_features, lateral_features, lateral_features_f, l2_features, l2h_features = [], [], [], [], []
     for view_idx in range(features.shape[1]):
         x_view_features = features[:, view_idx, ...]
-        x_view_features = feature_extractor.binarize_features(x_view_features)

         # Add noise to the input features -> should be removed by net fragments
         # x_view_features = np.array(x_view_features.detach().cpu())
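The eval-time corruption above, restated as a standalone helper. A minimal sketch: the function name and the default flip rate are illustrative, and plain index assignment replaces the boolean scatter mask, which is equivalent here:

import torch

def flip_random_bits(features: torch.Tensor, p: float = 0.005) -> torch.Tensor:
    """Flip a fixed fraction p of the entries of a binary (0/1) float tensor."""
    flat = features.reshape(-1).clone()
    num_flips = int(p * flat.numel())
    idx = torch.randperm(flat.numel(), device=flat.device)[:num_flips]
    flat[idx] = 1.0 - flat[idx]              # 0 -> 1 and 1 -> 0
    return flat.view(features.shape)

# Roughly 0.5% of a binarized feature map gets corrupted:
x = (torch.rand(1, 4, 32, 32) > 0.5).float()
noisy = flip_random_bits(x)
print((noisy != x).float().mean().item())    # ~0.005

Note that the bit flips only make sense on binary features, which is why the commit also moves `binarize_features` out of the per-view loop to run before the noise is applied.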
