Skip to content

Commit

Permalink
add test for output shape of R2Conv
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabri95 committed Nov 12, 2020
1 parent 25ed6f0 commit 734553a
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/nn/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from e2cnn.gspaces import *

import numpy as np
import math

import torch

Expand Down Expand Up @@ -179,6 +180,28 @@ def test_padding_mode_circular(self):
init.generalized_he_init(cl.weights.data, cl.basisexpansion)
cl.eval()
cl.check_equivariance()

def test_output_shape(self):
g = FlipRot2dOnR2(4, axis=np.pi / 2)

r1 = FieldType(g, [g.trivial_repr])
r2 = FieldType(g, [g.regular_repr])

S = 17

x = torch.randn(1, r1.size, S, S)
x = GeometricTensor(x, r1)

with torch.no_grad():
for k in [3, 5, 7, 9, 4, 8]:
for p in [0, 1, 2, 4]:
for s in [1, 2, 3]:
for mode in ['zeros', 'circular', 'reflect', 'replicate']:
cl = R2Conv(r1, r2, k, padding=p, stride=s, padding_mode=mode, initialize=False).eval()
y = cl(x)
_S = math.floor((S + 2*p - k) / s + 1)
self.assertEqual(y.shape, (1, r2.size, _S, _S))
self.assertEqual(y.shape, cl.evaluate_output_shape(x.shape))


if __name__ == '__main__':
Expand Down

0 comments on commit 734553a

Please sign in to comment.