diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 1cc4f019..99b6d6fe 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -6,6 +6,7 @@ from e2cnn.gspaces import * import numpy as np +import math import torch @@ -179,6 +180,28 @@ def test_padding_mode_circular(self): init.generalized_he_init(cl.weights.data, cl.basisexpansion) cl.eval() cl.check_equivariance() + + def test_output_shape(self): + g = FlipRot2dOnR2(4, axis=np.pi / 2) + + r1 = FieldType(g, [g.trivial_repr]) + r2 = FieldType(g, [g.regular_repr]) + + S = 17 + + x = torch.randn(1, r1.size, S, S) + x = GeometricTensor(x, r1) + + with torch.no_grad(): + for k in [3, 5, 7, 9, 4, 8]: + for p in [0, 1, 2, 4]: + for s in [1, 2, 3]: + for mode in ['zeros', 'circular', 'reflect', 'replicate']: + cl = R2Conv(r1, r2, k, padding=p, stride=s, padding_mode=mode, initialize=False).eval() + y = cl(x) + _S = math.floor((S + 2*p - k) / s + 1) + self.assertEqual(y.shape, (1, r2.size, _S, _S)) + self.assertEqual(y.shape, cl.evaluate_output_shape(x.shape)) if __name__ == '__main__':