diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index 9b37a7b58b2..3767ad22fa5 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -153,5 +153,11 @@ Experimental features .. autofunction:: oneflow.experimental.nn.Upsample .. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d .. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d +.. autofunction:: oneflow.experimental.atanh +.. autofunction:: oneflow.experimental.Tensor.atanh +.. autofunction:: oneflow.experimental.arctanh +.. autofunction:: oneflow.experimental.Tensor.arctanh +.. autofunction:: oneflow.experimental.tan +.. autofunction:: oneflow.experimental.Tensor.tan .. autofunction:: oneflow.experimental.log1p .. autofunction:: oneflow.experimental.Tensor.log1p diff --git a/oneflow/python/nn/modules/atanh.py b/oneflow/python/nn/modules/atanh.py new file mode 100644 index 00000000000..57ee341f551 --- /dev/null +++ b/oneflow/python/nn/modules/atanh.py @@ -0,0 +1,95 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +from oneflow.python.oneflow_export import oneflow_export, experimental_api +from oneflow.python.nn.module import Module +from oneflow.python.framework.tensor import register_tensor_op + + +class Atanh(Module): + def __init__(self): + super().__init__() + self._op = flow.builtin_op("atanh").Input("x").Output("y").Build() + + def forward(self, x): + return self._op(x)[0] + + +@oneflow_export("atanh") +@experimental_api +def atanh_op(input): + r"""Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) + Args: + input (Tensor): the input tensor. + + For example: + + .. code-block:: python + + >>> import oneflow.experimental as flow + >>> import numpy as np + >>> flow.enable_eager_execution() + >>> np_arr = np.array([0.5, 0.6, 0.7]).astype(np.float32) + >>> input = flow.Tensor(np_arr) + >>> output = flow.atanh(input) + >>> print(output.numpy()) + [0.54930615 0.6931472 0.8673005 ] + """ + + return Atanh()(input) + + +@register_tensor_op("atanh") +@experimental_api +def atanh_op_tensor(x): + r""" + atanh() -> Tensor + See :func:`oneflow.experimental.atanh` + + """ + + return Atanh()(x) + + +@oneflow_export("arctanh") +@experimental_api +def arctanh_op(input): + r""" + + Alias for :func:`oneflow.experimental.atanh` + """ + + return Atanh()(input) + + +@register_tensor_op("arctanh") +@experimental_api +def arctanh_op_tensor(input): + r""" + + Alias for :func:`oneflow.experimental.atanh` + """ + + return Atanh()(input) + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/oneflow/python/nn/modules/tan.py b/oneflow/python/nn/modules/tan.py new file mode 100644 index 00000000000..118e6788734 --- /dev/null +++ b/oneflow/python/nn/modules/tan.py @@ -0,0 +1,73 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +from oneflow.python.oneflow_export import oneflow_export, experimental_api +from oneflow.python.nn.module import Module +from oneflow.python.framework.tensor import register_tensor_op + + +class Tan(Module): + def __init__(self): + super().__init__() + self._op = flow.builtin_op("tan").Input("x").Output("y").Build() + + def forward(self, x): + return self._op(x)[0] + + +@oneflow_export("tan") +@experimental_api +def tan_op(input): + r"""Returns the tan value of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan(\text{input}_{i}) + Args: + input (Tensor): the input tensor. + + For example: + + .. code-block:: python + + >>> import oneflow.experimental as flow + >>> import numpy as np + >>> flow.enable_eager_execution() + >>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32) + >>> input = flow.Tensor(np_arr) + >>> output = flow.tan(input) + >>> print(output.numpy()) + [-1. 0. 1.] + """ + + return Tan()(input) + + +@register_tensor_op("tan") +@experimental_api +def tan_op_tensor(input): + r""" + tan() -> Tensor + See :func:`oneflow.experimental.tan` + + """ + + return Tan()(input) + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/oneflow/python/test/modules/test_atanh.py b/oneflow/python/test/modules/test_atanh.py new file mode 100644 index 00000000000..cadedcbbf3f --- /dev/null +++ b/oneflow/python/test/modules/test_atanh.py @@ -0,0 +1,80 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict + +import numpy as np + +import oneflow.experimental as flow +from test_util import GenArgList + + +def _test_atanh_impl(test_case, shape, device): + np_input = np.random.random(shape) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + + of_out = flow.atanh(of_input) + np_out = np.arctanh(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1.0 / (1.0 - np.square(np_input)) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + + +def _test_arctanh_impl(test_case, shape, device): + np_input = np.random.random(shape) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + + of_out = flow.arctanh(of_input) + np_out = np.arctanh(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1.0 / (1.0 - np.square(np_input)) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestAtanh(flow.unittest.TestCase): + def test_atanh(test_case): + arg_dict = OrderedDict() + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + _test_atanh_impl(test_case, *arg) + _test_arctanh_impl(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_tan.py b/oneflow/python/test/modules/test_tan.py new file mode 100644 index 00000000000..25bd6b8ff4c --- /dev/null +++ b/oneflow/python/test/modules/test_tan.py @@ -0,0 +1,59 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict + +import numpy as np + +import oneflow.experimental as flow +from test_util import GenArgList + + +def _test_tan_impl(test_case, shape, device): + np_input = np.random.random(size=shape) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + + of_out = flow.tan(of_input) + np_out = np.tan(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1 + np.square(np_out) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestTan(flow.unittest.TestCase): + def test_tan(test_case): + arg_dict = OrderedDict() + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + _test_tan_impl(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/tensor/test_tensor.py b/oneflow/python/test/tensor/test_tensor.py index b6079b21d51..453ade57d3d 100644 --- a/oneflow/python/test/tensor/test_tensor.py +++ b/oneflow/python/test/tensor/test_tensor.py @@ -761,6 +761,69 @@ def test_pow_tensor_function(test_case): np_out = np.power(input.numpy(), 2.1) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + @unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + "numpy doesn't work in lazy mode", + ) + def test_tensor_atanh(test_case): + np_input = np.random.random((2, 3)) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + + of_out = of_input.atanh() + np_out = np.arctanh(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1.0 / (1.0 - np.square(np_input)) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + + @unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + "numpy doesn't work in lazy mode", + ) + def test_tensor_arctanh(test_case): + np_input = np.random.random((2, 3)) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + + of_out = of_input.arctanh() + np_out = np.arctanh(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1.0 / (1.0 - np.square(np_input)) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + + @unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + "numpy doesn't work in lazy mode", + ) + def test_tensor_tan(test_case): + np_input = np.random.random((2, 3)) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + + of_out = of_input.tan() + np_out = np.tan(np_input) + test_case.assertTrue( + np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True) + ) + + of_out = of_out.sum() + of_out.backward() + np_out_grad = 1 + np.square(np_out) + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True) + ) + if __name__ == "__main__": unittest.main()