From be6626fb1c200b6a39c3cfed48048404fbbde7a3 Mon Sep 17 00:00:00 2001
From: jloveric
Date: Sun, 23 Jun 2024 16:08:56 -0700
Subject: [PATCH] Fix normalization

---
 examples/block_mnist.py                     | 23 +++++++++------------
 examples/mnist.py                           |  6 +-----
 high_order_layers_torch/PolynomialLayers.py |  1 +
 3 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/examples/block_mnist.py b/examples/block_mnist.py
index c991d93..c73b21c 100644
--- a/examples/block_mnist.py
+++ b/examples/block_mnist.py
@@ -23,9 +23,6 @@
 transformStandard = transforms.Compose(
     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
 )
-transformPoly = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
-)
 
 normalization = {
     "max_abs": MaxAbsNormalization,
@@ -33,13 +30,13 @@
 }
 
 grid_x, grid_y = torch.meshgrid(
-    (torch.arange(28) - 14) / 14, (torch.arange(28) - 14) / 14, indexing="ij"
+    (torch.arange(28) - 13.5) / 13.5, (torch.arange(28) - 13.5) / 13.5, indexing="ij"
 )
 grid = torch.stack([grid_x, grid_y])
-print('grid', grid)
+
 
 def collate_fn(batch):
-
+
     input = []
     classification = []
     for element in batch:
@@ -70,12 +67,12 @@ def __init__(self, cfg: DictConfig):
 
         self._layer_type = cfg.layer_type
         self._train_fraction = cfg.train_fraction
-        self._transform = transformPoly
+        self._transform = transformStandard
 
         layer1 = high_order_fc_layers(
             layer_type=cfg.layer_type,
-            n=[3,n,n],
-            segments = cfg.segments,
+            n=[3, n, n],
+            segments=cfg.segments,
             in_features=1,
             out_features=10,
             intialization="constant_random",
@@ -86,14 +83,14 @@
         self.model = nn.Sequential(*[layer1, normalize])
 
     def forward(self, x):
-        #print("x.shape", x.shape)
+        # print("x.shape", x.shape)
         batch_size, inputs = x.shape[:2]
         xin = x.view(-1, 1, 3)
-        #print("xin.shape", xin.shape)
+        # print("xin.shape", xin.shape)
         res = self.model(xin)
         res = res.reshape(batch_size, inputs, -1)
-        output = torch.sum(res,dim=1)
-        #print("res.shape", output.shape)
+        output = torch.sum(res, dim=1)
+        # print("res.shape", output.shape)
         # xout = res.view(batch_size, )
         return output
 
diff --git a/examples/mnist.py b/examples/mnist.py
index 588c850..3e97ab4 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -19,10 +19,6 @@
 transformStandard = transforms.Compose(
     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
 )
-transformPoly = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
-)
-
 
 normalization = {
     "max_abs": MaxAbsNormalizationND,
@@ -46,7 +42,7 @@ def __init__(self, cfg: DictConfig):
         self._train_fraction = cfg.train_fraction
         segments = cfg.segments
 
-        self._transform = transformPoly
+        self._transform = transformStandard
 
         in_channels = cfg.channels[0]
         out_channels = cfg.channels[1]
diff --git a/high_order_layers_torch/PolynomialLayers.py b/high_order_layers_torch/PolynomialLayers.py
index 9637a91..b2461fc 100644
--- a/high_order_layers_torch/PolynomialLayers.py
+++ b/high_order_layers_torch/PolynomialLayers.py
@@ -436,6 +436,7 @@ def _constant_random_initialization(self, weight_magnitude):
         print("self.w.data", self.w)
 
     def which_segment(self, x: torch.Tensor) -> torch.Tensor:
+        print("segments x", torch.min(x[:, :, 0]), torch.max(x[:, :, 1]))
         return (
             (
                 (x + self._half)
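
A note on the meshgrid change: torch.arange(28) runs 0..27, so centering on 14 maps pixel coordinates to [-1, 13/14] and never reaches the top of the domain, while centering on 13.5 spans exactly [-1, 1]. A minimal standalone check (not part of the patch):

    import torch

    # Old mapping: (i - 14) / 14 for i in 0..27 reaches -1 but tops out at 13/14
    old = (torch.arange(28) - 14) / 14
    print(old.min().item(), old.max().item())  # -1.0, ~0.9286

    # Fixed mapping: (i - 13.5) / 13.5 covers the full [-1, 1] domain symmetrically
    new = (torch.arange(28) - 13.5) / 13.5
    print(new.min().item(), new.max().item())  # -1.0, 1.0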
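
The transform switch applies the same fix to the pixel values: the removed transformPoly used Normalize((0.0,), (1.0,)), an identity that leaves ToTensor() output in [0, 1], while transformStandard's Normalize((0.5,), (0.5,)) rescales it to [-1, 1]. A sketch of the two ranges, assuming torchvision is installed; the `pixels` tensor here is a stand-in for a real MNIST image:

    import torch
    from torchvision import transforms

    pixels = torch.rand(1, 28, 28)  # stand-in for ToTensor() output in [0, 1]

    # Removed transformPoly: (x - 0.0) / 1.0 changes nothing, inputs stay in [0, 1]
    poly = transforms.Normalize((0.0,), (1.0,))(pixels)
    print(poly.min().item(), poly.max().item())  # within [0, 1]

    # transformStandard: (x - 0.5) / 0.5 rescales to [-1, 1]
    standard = transforms.Normalize((0.5,), (0.5,))(pixels)
    print(standard.min().item(), standard.max().item())  # within [-1, 1]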
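
For context on the shapes in forward: each image becomes a set of per-pixel (grid_x, grid_y, intensity) triples, the layer scores each triple independently, and the per-pixel logits are summed over dim=1. A rough sketch of that data flow under the patched grid; the triple construction here is guessed from the truncated collate_fn, not copied from it:

    import torch

    grid_x, grid_y = torch.meshgrid(
        (torch.arange(28) - 13.5) / 13.5, (torch.arange(28) - 13.5) / 13.5, indexing="ij"
    )

    image = torch.rand(28, 28)  # stand-in for one normalized MNIST image
    # One (x, y, intensity) triple per pixel: shape (784, 3)
    triples = torch.stack([grid_x, grid_y, image], dim=-1).view(-1, 3)

    batch = triples.unsqueeze(0)  # (1, 784, 3), what forward() receives
    xin = batch.view(-1, 1, 3)    # (784, 1, 3), what the layer consumes
    print(xin.shape)              # torch.Size([784, 1, 3])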
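
Finally, the print added to which_segment verifies that inputs actually land in the domain the segment lookup assumes. The hunk truncates the lookup after `(x + self._half)`; the sketch below is a hypothetical reconstruction of a shift-scale-floor lookup (the scaling and clamping are assumptions, not the library's exact code), showing why values outside [-half, half] would otherwise index nonexistent segments:

    import torch

    def which_segment_sketch(x: torch.Tensor, segments: int, half: float = 1.0) -> torch.Tensor:
        # Shift [-half, half] to [0, 2*half], scale to [0, segments], floor to an index.
        idx = ((x + half) * segments / (2 * half)).floor().long()
        # Badly normalized inputs fall outside the domain and would index past the
        # last segment -- the clamp (and the min/max print above) guards against that.
        return idx.clamp(0, segments - 1)

    x = (torch.arange(28) - 13.5) / 13.5  # the fixed grid coordinates
    print(which_segment_sketch(x, segments=4))  # indices 0..3 across [-1, 1]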