diff --git a/e2cnn/__about__.py b/e2cnn/__about__.py
index b56cbd24..b28aaba3 100644
--- a/e2cnn/__about__.py
+++ b/e2cnn/__about__.py
@@ -12,7 +12,7 @@
 __title__ = "e2cnn"
 __summary__ = "E(2)-Equivariant CNNs Library for PyTorch"
 __url__ = 'https://github.com/QUVA-Lab/e2cnn'
-__version__ = "0.1"
+__version__ = "0.1.1"
 __author__ = "Gabriele Cesa, Maurice Weiler"
 __email__ = "cesa.gabriele@gmail.com"
 __license__ = "BSD 3-Clause Clear"
diff --git a/e2cnn/nn/modules/r2_conv/r2convolution.py b/e2cnn/nn/modules/r2_conv/r2convolution.py
index c88a1065..b4042644 100644
--- a/e2cnn/nn/modules/r2_conv/r2convolution.py
+++ b/e2cnn/nn/modules/r2_conv/r2convolution.py
@@ -1,5 +1,5 @@
 
-from torch.nn.functional import conv2d
+from torch.nn.functional import conv2d, pad
 
 from e2cnn.nn import init
 from e2cnn.nn import FieldType
@@ -31,6 +31,7 @@ def __init__(self,
                  padding: int = 0,
                  stride: int = 1,
                  dilation: int = 1,
+                 padding_mode: str = 'zeros',
                  groups: int = 1,
                  bias: bool = True,
                  basisexpansion: str = 'blocks',
@@ -110,9 +111,10 @@ def __init__(self,
             in_type (FieldType): the type of the input field, specifying its transformation law
             out_type (FieldType): the type of the output field, specifying its transformation law
             kernel_size (int): the size of the (square) filter
-            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
-            stride(int, optional): the stride of the kernel. Default: ``1``
-            dilation(int, optional): the spacing between kernel elements. Default: ``1``
+            padding (int, optional): implicit paddings on both sides of the input, filled according to ``padding_mode``. Default: ``0``
+            padding_mode (str, optional): ``zeros``, ``reflect``, ``replicate`` or ``circular``. Default: ``zeros``
+            stride (int, optional): the stride of the kernel. Default: ``1``
+            dilation (int, optional): the spacing between kernel elements. Default: ``1``
             groups (int, optional): number of blocked connections from input channels to output channels.
                                     It allows depthwise convolution. When used, the input and output types need to be
                                     divisible in ``groups`` groups, all equal to each other.
@@ -160,8 +162,21 @@ def __init__(self,
         self.stride = stride
         self.dilation = dilation
         self.padding = padding
+        self.padding_mode = padding_mode
         self.groups = groups
 
+        if isinstance(padding, tuple) and len(padding) == 2:
+            _padding = padding
+        elif isinstance(padding, int):
+            _padding = (padding, padding)
+        else:
+            raise ValueError('padding needs to be either an integer or a tuple containing two integers but {} found'.format(padding))
+
+        padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
+        if padding_mode not in padding_modes:
+            raise ValueError("padding_mode must be one of [{}], but got padding_mode='{}'".format(padding_modes, padding_mode))
+        self._reversed_padding_repeated_twice = tuple(x for x in reversed(_padding) for _ in range(2))
+
         if groups > 1:
             # Check the input and output classes can be split in `groups` groups, all equal to each other
             # first, check that the number of fields is divisible by `groups`
@@ -310,13 +325,25 @@ def forward(self, input: GeometricTensor):
             filter, bias = self.expand_parameters()
 
         # use it for convolution and return the result
-        output = conv2d(input.tensor, filter,
-                        padding=self.padding,
-                        stride=self.stride,
-                        dilation=self.dilation,
-                        groups=self.groups,
-                        bias=bias)
 
+        if self.padding_mode == 'zeros':
+            # zero padding is handled natively by conv2d through its `padding` argument
+            output = conv2d(input.tensor, filter,
+                            stride=self.stride,
+                            padding=self.padding,
+                            dilation=self.dilation,
+                            groups=self.groups,
+                            bias=bias)
+        else:
+            # any other mode is applied explicitly with `pad`; conv2d must then not pad again
+            output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
+                            filter,
+                            stride=self.stride,
+                            padding=0,
+                            dilation=self.dilation,
+                            groups=self.groups,
+                            bias=bias)
+
         return GeometricTensor(output, self.out_type)
 
     def train(self, mode=True):
diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py
index 43d8b3b9..1cc4f019 100644
--- a/test/nn/test_convolution.py
+++ b/test/nn/test_convolution.py
@@ -150,6 +150,36 @@ def test_flip(self):
             cl.eval()
             cl.check_equivariance()
 
+    def test_padding_mode_reflect(self):
+        g = Flip2dOnR2(axis=np.pi / 2)
+
+        r1 = FieldType(g, [g.trivial_repr])
+        r2 = FieldType(g, [g.regular_repr])
+
+        s = 3
+        cl = R2Conv(r1, r2, s, bias=True, padding=1, padding_mode='reflect', initialize=False)
+
+        for _ in range(32):
+            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
+            cl.eval()
+            cl.check_equivariance()
+
+    def test_padding_mode_circular(self):
+        g = FlipRot2dOnR2(4, axis=np.pi / 2)
+
+        r1 = FieldType(g, [g.trivial_repr])
+        r2 = FieldType(g, [g.regular_repr])
+
+        for mode in ['circular', 'reflect', 'replicate']:
+            for s in [3, 5, 7]:
+                padding = s // 2
+                cl = R2Conv(r1, r2, s, bias=True, padding=padding, padding_mode=mode, initialize=False)
+
+                for _ in range(10):
+                    init.generalized_he_init(cl.weights.data, cl.basisexpansion)
+                    cl.eval()
+                    cl.check_equivariance()
+
 
 if __name__ == '__main__':
     unittest.main()