Skip to content

Commit

Permalink
init atanh (#4960)
Browse files Browse the repository at this point in the history
* init atanh

* add atanh

* add /n in the tail'

* update atanh

* format

* add tan

* add test_square.py test_sqrt.py

* update test_atanh grad

* update test_tensor.py

* add doctest

* update

* split tensor

* add 3 dim testcase

* resolve conflict

* auto format by CI

* Update test_atanh.py

use arg_dict[test_fun"]

* Revert "Update test_atanh.py"

This reverts commit d726d7b.

* fix

Co-authored-by: oneflow-ci-bot <[email protected]>
Co-authored-by: oneflow-ci-bot <[email protected]>
Co-authored-by: aajjjtntn <[email protected]>
Co-authored-by: jackalcooper <[email protected]>
  • Loading branch information
5 people authored Jun 4, 2021
1 parent 22fbe76 commit 9583290
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,11 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.Upsample
.. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d
.. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d
.. autofunction:: oneflow.experimental.atanh
.. autofunction:: oneflow.experimental.Tensor.atanh
.. autofunction:: oneflow.experimental.arctanh
.. autofunction:: oneflow.experimental.Tensor.arctanh
.. autofunction:: oneflow.experimental.tan
.. autofunction:: oneflow.experimental.Tensor.tan
.. autofunction:: oneflow.experimental.log1p
.. autofunction:: oneflow.experimental.Tensor.log1p
95 changes: 95 additions & 0 deletions oneflow/python/nn/modules/atanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op


class Atanh(Module):
def __init__(self):
super().__init__()
self._op = flow.builtin_op("atanh").Input("x").Output("y").Build()

def forward(self, x):
return self._op(x)[0]


@oneflow_export("atanh")
@experimental_api
def atanh_op(input):
r"""Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`.
.. math::
\text{out}_{i} = \tanh^{-1}(\text{input}_{i})
Args:
input (Tensor): the input tensor.
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> np_arr = np.array([0.5, 0.6, 0.7]).astype(np.float32)
>>> input = flow.Tensor(np_arr)
>>> output = flow.atanh(input)
>>> print(output.numpy())
[0.54930615 0.6931472 0.8673005 ]
"""

return Atanh()(input)


@register_tensor_op("atanh")
@experimental_api
def atanh_op_tensor(x):
r"""
atanh() -> Tensor
See :func:`oneflow.experimental.atanh`
"""

return Atanh()(x)


@oneflow_export("arctanh")
@experimental_api
def arctanh_op(input):
r"""
Alias for :func:`oneflow.experimental.atanh`
"""

return Atanh()(input)


@register_tensor_op("arctanh")
@experimental_api
def arctanh_op_tensor(input):
r"""
Alias for :func:`oneflow.experimental.atanh`
"""

return Atanh()(input)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
73 changes: 73 additions & 0 deletions oneflow/python/nn/modules/tan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op


class Tan(Module):
def __init__(self):
super().__init__()
self._op = flow.builtin_op("tan").Input("x").Output("y").Build()

def forward(self, x):
return self._op(x)[0]


@oneflow_export("tan")
@experimental_api
def tan_op(input):
r"""Returns the tan value of the elements of :attr:`input`.
.. math::
\text{out}_{i} = \tan(\text{input}_{i})
Args:
input (Tensor): the input tensor.
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32)
>>> input = flow.Tensor(np_arr)
>>> output = flow.tan(input)
>>> print(output.numpy())
[-1. 0. 1.]
"""

return Tan()(input)


@register_tensor_op("tan")
@experimental_api
def tan_op_tensor(input):
r"""
tan() -> Tensor
See :func:`oneflow.experimental.tan`
"""

return Tan()(input)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
80 changes: 80 additions & 0 deletions oneflow/python/test/modules/test_atanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _test_atanh_impl(test_case, shape, device):
np_input = np.random.random(shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)

of_out = flow.atanh(of_input)
np_out = np.arctanh(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1.0 - np.square(np_input))
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)


def _test_arctanh_impl(test_case, shape, device):
np_input = np.random.random(shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)

of_out = flow.arctanh(of_input)
np_out = np.arctanh(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1.0 - np.square(np_input))
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestAtanh(flow.unittest.TestCase):
def test_atanh(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_atanh_impl(test_case, *arg)
_test_arctanh_impl(test_case, *arg)


if __name__ == "__main__":
unittest.main()
59 changes: 59 additions & 0 deletions oneflow/python/test/modules/test_tan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _test_tan_impl(test_case, shape, device):
np_input = np.random.random(size=shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)

of_out = flow.tan(of_input)
np_out = np.tan(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1 + np.square(np_out)
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTan(flow.unittest.TestCase):
def test_tan(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_tan_impl(test_case, *arg)


if __name__ == "__main__":
unittest.main()
63 changes: 63 additions & 0 deletions oneflow/python/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,69 @@ def test_pow_tensor_function(test_case):
np_out = np.power(input.numpy(), 2.1)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
)
def test_tensor_atanh(test_case):
np_input = np.random.random((2, 3))
of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)

of_out = of_input.atanh()
np_out = np.arctanh(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1.0 - np.square(np_input))
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)

@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
)
def test_tensor_arctanh(test_case):
np_input = np.random.random((2, 3))
of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)

of_out = of_input.arctanh()
np_out = np.arctanh(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1.0 - np.square(np_input))
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)

@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
)
def test_tensor_tan(test_case):
np_input = np.random.random((2, 3))
of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)

of_out = of_input.tan()
np_out = np.tan(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = 1 + np.square(np_out)
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9583290

Please sign in to comment.