Fix normalization
jloveric committed Jun 23, 2024
1 parent 84895a2 commit be6626f
Showing 3 changed files with 12 additions and 18 deletions.
examples/block_mnist.py: 10 additions & 13 deletions
@@ -23,23 +23,20 @@
 transformStandard = transforms.Compose(
     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
 )
-transformPoly = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
-)
 
 normalization = {
     "max_abs": MaxAbsNormalization,
     "max_center": MaxCenterNormalization,
 }
 
 grid_x, grid_y = torch.meshgrid(
-    (torch.arange(28) - 14) / 14, (torch.arange(28) - 14) / 14, indexing="ij"
+    (torch.arange(28) - 13.5) / 13.5, (torch.arange(28) - 13.5) / 13.5, indexing="ij"
 )
 grid = torch.stack([grid_x, grid_y])
+print('grid', grid)
 
 
 def collate_fn(batch):
-
     input = []
     classification = []
     for element in batch:
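
The meshgrid change is the normalization fix named in the commit title: with 28 pixels indexed 0..27, centering on 14 gives coordinates in [-1, 13/14], slightly asymmetric, while centering on 13.5 spans [-1, 1] exactly, presumably the domain the piecewise-polynomial layers expect. A quick check of the arithmetic (an illustration, not part of the commit):

    import torch

    old = (torch.arange(28) - 14) / 14      # spans [-1.0, 0.9286...]
    new = (torch.arange(28) - 13.5) / 13.5  # spans [-1.0, 1.0] exactly
    print(old.min().item(), old.max().item())
    print(new.min().item(), new.max().item())
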
@@ -70,12 +67,12 @@ def __init__(self, cfg: DictConfig):
         self._layer_type = cfg.layer_type
         self._train_fraction = cfg.train_fraction
 
-        self._transform = transformPoly
+        self._transform = transformStandard
 
         layer1 = high_order_fc_layers(
             layer_type=cfg.layer_type,
-            n=[3,n,n],
-            segments = cfg.segments,
+            n=[3, n, n],
+            segments=cfg.segments,
             in_features=1,
             out_features=10,
             intialization="constant_random",
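
The transform swap is the other half of the fix. Torchvision's Normalize computes out = (in - mean) / std, so transformStandard maps ToTensor()'s [0, 1] pixel values onto [-1, 1], consistent with the grid coordinates above, whereas the deleted transformPoly was an identity that left pixels in [0, 1]. A minimal illustration (not from the repository):

    import torch

    pixels = torch.tensor([0.0, 0.5, 1.0])  # ToTensor() output range
    standard = (pixels - 0.5) / 0.5         # tensor([-1., 0., 1.])
    poly = (pixels - 0.0) / 1.0             # unchanged: stays in [0, 1]
    print(standard, poly)
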
@@ -86,14 +83,14 @@
         self.model = nn.Sequential(*[layer1, normalize])
 
     def forward(self, x):
-        #print("x.shape", x.shape)
+        # print("x.shape", x.shape)
         batch_size, inputs = x.shape[:2]
         xin = x.view(-1, 1, 3)
-        #print("xin.shape", xin.shape)
+        # print("xin.shape", xin.shape)
         res = self.model(xin)
         res = res.reshape(batch_size, inputs, -1)
-        output = torch.sum(res,dim=1)
-        #print("res.shape", output.shape)
+        output = torch.sum(res, dim=1)
+        # print("res.shape", output.shape)
         # xout = res.view(batch_size, )
         return output

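The forward() edits are whitespace cleanups, but the shape bookkeeping is worth tracing. Assuming each MNIST pixel arrives as an (x, y, intensity) triple, so that x is (batch, pixels, 3), a standalone sketch of the reshapes, with a random stand-in for self.model:

    import torch

    batch_size, inputs = 4, 28 * 28
    x = torch.randn(batch_size, inputs, 3)
    xin = x.view(-1, 1, 3)                     # (batch * pixels, 1, 3)
    res = torch.randn(xin.shape[0], 1, 10)     # stand-in for self.model(xin)
    res = res.reshape(batch_size, inputs, -1)  # (batch, pixels, 10)
    output = torch.sum(res, dim=1)             # (batch, 10): per-class sums
    print(output.shape)                        # torch.Size([4, 10])
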
examples/mnist.py: 1 addition & 5 deletions

@@ -19,10 +19,6 @@
 transformStandard = transforms.Compose(
     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
 )
-transformPoly = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
-)
-
 
 normalization = {
     "max_abs": MaxAbsNormalizationND,
@@ -46,7 +42,7 @@ def __init__(self, cfg: DictConfig):
         self._train_fraction = cfg.train_fraction
         segments = cfg.segments
 
-        self._transform = transformPoly
+        self._transform = transformStandard
 
         in_channels = cfg.channels[0]
         out_channels = cfg.channels[1]
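This is the same transformPoly to transformStandard swap as in block_mnist.py, so both examples now feed the network pixel values normalized to [-1, 1].
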
high_order_layers_torch/PolynomialLayers.py: 1 addition & 0 deletions

@@ -436,6 +436,7 @@ def _constant_random_initialization(self, weight_magnitude):
print("self.w.data", self.w)

def which_segment(self, x: torch.Tensor) -> torch.Tensor:
print('segments x', torch.min(x[:,:,0]), torch.max(x[:,:,1]))
return (
(
(x + self._half)
Expand Down
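The added print is a debugging check that inputs reaching which_segment actually sit in the expected domain (it logs the min of the first coordinate and the max of the second). The hunk is truncated here, so as an illustration only, here is a self-contained sketch of what a which_segment-style bucketing does, assuming `segments` equal pieces over [-1, 1]; the real method's arithmetic may differ:

    import torch

    def which_segment_sketch(x: torch.Tensor, segments: int) -> torch.Tensor:
        half = segments / 2.0
        # Map [-1, 1] onto [0, segments], floor to an integer bucket, and
        # clamp so x == 1.0 falls in the last segment rather than one past it.
        return (x * half + half).floor().clamp(0, segments - 1).long()

    x = torch.tensor([-1.0, -0.5, 0.0, 0.99, 1.0])
    print(which_segment_sketch(x, segments=4))  # tensor([0, 1, 2, 3, 3])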
