diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index bbc78c05b35..ae8b8396457 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -86,6 +86,7 @@ Experimental features
 .. autofunction:: oneflow.experimental.lt
 .. autofunction:: oneflow.experimental.Tensor.lt
 .. autofunction:: oneflow.experimental.nn.Identity
+.. autofunction:: oneflow.experimental.nn.PixelShuffle
 .. autofunction:: oneflow.experimental.nn.Linear
 .. autofunction:: oneflow.experimental.nn.CrossEntropyLoss
 .. autofunction:: oneflow.experimental.nn.NLLLoss
diff --git a/oneflow/python/nn/modules/pixel_shuffle.py b/oneflow/python/nn/modules/pixel_shuffle.py
new file mode 100644
index 00000000000..bf01b040480
--- /dev/null
+++ b/oneflow/python/nn/modules/pixel_shuffle.py
@@ -0,0 +1,115 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from oneflow.python.framework.tensor import Tensor
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.nn.module import Module
+
+
+@oneflow_export("nn.PixelShuffle")
+@experimental_api
+class PixelShuffle(Module):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle
+
+    Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
+    to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
+
+    This is useful for implementing efficient sub-pixel convolution
+    with a stride of :math:`1/r`.
+
+    See the paper:
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
+    by Shi et al. (2016) for more details.
+
+    Args:
+        upscale_factor (int): factor to increase spatial resolution by
+
+    Shape:
+        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
+        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
+
+    .. math::
+        C_{out} = C_{in} \div \text{upscale\_factor}^2
+
+    .. math::
+        H_{out} = H_{in} \times \text{upscale\_factor}
+
+    .. math::
+        W_{out} = W_{in} \times \text{upscale\_factor}
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+
+        >>> m = flow.nn.PixelShuffle(upscale_factor=2)
+        >>> x = flow.Tensor(np.random.randn(3, 4, 5, 5))
+        >>> y = m(x)
+        >>> print(y.size())
+        flow.Size([3, 1, 10, 10])
+
+        >>> m = flow.nn.PixelShuffle(upscale_factor=3)
+        >>> x = flow.Tensor(np.random.randn(1, 18, 2, 2))
+        >>> y = m(x)
+        >>> print(y.size())
+        flow.Size([1, 2, 6, 6])
+
+    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
+        https://arxiv.org/abs/1609.05158
+    """
+
+    def __init__(self, upscale_factor: int) -> None:
+        super().__init__()
+        assert upscale_factor > 0, "The upscale factor must be greater than zero"
+        self.upscale_factor = upscale_factor
+
+    def forward(self, input: Tensor) -> Tensor:
+        assert len(input.shape) == 4, "Only 4-D tensors are accepted"
+
+        _batch, _channel, _height, _width = input.shape
+        assert (
+            _channel % (self.upscale_factor ** 2) == 0
+        ), "The channels of the input tensor must be divisible by (upscale_factor * upscale_factor)"
+
+        _new_c = int(_channel / (self.upscale_factor ** 2))
+
+        out = input.reshape(
+            [_batch, _new_c, self.upscale_factor ** 2, _height, _width,]
+        )
+        out = out.reshape(
+            [_batch, _new_c, self.upscale_factor, self.upscale_factor, _height, _width,]
+        )
+        out = out.permute(0, 1, 4, 2, 5, 3)
+        out = out.reshape(
+            [
+                _batch,
+                _new_c,
+                _height * self.upscale_factor,
+                _width * self.upscale_factor,
+            ]
+        )
+
+        return out
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/ops/nn_ops.py b/oneflow/python/ops/nn_ops.py
index b189ee85c68..fd875efb262 100644
--- a/oneflow/python/ops/nn_ops.py
+++ b/oneflow/python/ops/nn_ops.py
@@ -4289,6 +4289,7 @@ def _p_norm(x, p=2.0, name="p_norm"):
 
 
 @oneflow_export("nn.PixelShuffle")
+@stable_api
 def pixel_shuffle(
     input: oneflow._oneflow_internal.BlobDesc,
     upscale_factor: int,
diff --git a/oneflow/python/test/modules/test_pixel_shuffle.py b/oneflow/python/test/modules/test_pixel_shuffle.py
new file mode 100644
index 00000000000..f21c10fbc09
--- /dev/null
+++ b/oneflow/python/test/modules/test_pixel_shuffle.py
@@ -0,0 +1,91 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+
+
+def _np_pixel_shuffle(input, factor):
+    _batch, _channel, _height, _width = input.shape
+    assert (
+        _channel % (factor ** 2) == 0
+    ), "The channels of the input tensor must be divisible by (upscale_factor * upscale_factor)"
+    _new_c = int(_channel / (factor ** 2))
+
+    out = np.reshape(input, [_batch, _new_c, factor ** 2, _height, _width])
+    out = np.reshape(out, [_batch, _new_c, factor, factor, _height, _width])
+    out = np.transpose(out, [0, 1, 4, 2, 5, 3])
+    out = np.reshape(out, [_batch, _new_c, _height * factor, _width * factor])
+    return out
+
+
+def _np_pixel_shuffle_grad(input, factor):
+    _batch, _new_channel, _height_mul_factor, _width_mul_factor = input.shape
+    _channel = _new_channel * (factor ** 2)
+    _height = _height_mul_factor // factor
+    _width = _width_mul_factor // factor
+
+    out = np.ones(shape=(_batch, _channel, _height, _width))
+    return out
+
+
+def _test_pixel_shuffle_impl(test_case, device, shape, upscale_factor):
+    x = np.random.randn(*shape)
+    input = flow.Tensor(
+        x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
+    )
+
+    m = flow.nn.PixelShuffle(upscale_factor)
+    m = m.to(device)
+    of_out = m(input)
+    np_out = _np_pixel_shuffle(x, upscale_factor)
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = _np_pixel_shuffle_grad(np_out, upscale_factor)
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestPixelShuffleModule(flow.unittest.TestCase):
+    def test_pixel_shuffle(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_pixel_shuffle_impl,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+
+        arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)]
+        arg_dict["upscale_factor"] = [2, 3, 4]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+        arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)]
+        arg_dict["upscale_factor"] = [5]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()