add mseloss module #5116

Merged (18 commits) on Jun 8, 2021
Changes from 2 commits
1 change: 1 addition & 0 deletions docs/source/experimental.rst
@@ -85,6 +85,7 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.Linear
.. autofunction:: oneflow.experimental.nn.CrossEntropyLoss
.. autofunction:: oneflow.experimental.nn.NLLLoss
.. autofunction:: oneflow.experimental.nn.MSELoss
.. autofunction:: oneflow.experimental.masked_fill
.. autofunction:: oneflow.experimental.Tensor.masked_fill
.. autofunction:: oneflow.experimental.sum
117 changes: 117 additions & 0 deletions oneflow/python/nn/modules/loss.py
@@ -18,6 +18,7 @@
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.nn.modules.math_ops import Subtract, Square, Sum, Mean


@oneflow_export("nn.CrossEntropyLoss")
@@ -296,6 +297,122 @@ def forward(self, input, target):
return res.mean()


@oneflow_export("nn.MSELoss")
@experimental_api
class MSELoss(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html?highlight=mseloss#torch.nn.MSELoss

Creates a criterion that measures the mean squared error (squared L2 norm) between
each element in the input :math:`x` and target :math:`y`.

The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = \left( x_n - y_n \right)^2,

where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:

.. math::
\ell(x, y) =
\begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}

:math:`x` and :math:`y` are tensors of arbitrary shapes with a total
of :math:`n` elements each.

The mean operation still operates over all the elements, and divides by :math:`n`.

The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.

Args:
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when :attr:`reduce` is ``False``. Default: ``True``
reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

Shape:
        - Input: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
- Target: :math:`(N, *)`, same shape as the input

For example:

.. code-block:: python

>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()

>>> input = flow.Tensor(
... [[-0.02557137, 0.03101675, 1.37493674],
... [0.25599439, -1.08372561, -0.21006816]], dtype=flow.float32)
>>> target = flow.Tensor(
... [[-1.53105064, -0.68137555, 0.5931354],
... [-0.49158347, 0.93673637, 0.1324141]], dtype=flow.float32)
>>> m = flow.nn.MSELoss(reduction="none")
>>> out = m(input, target)
>>> print(out.numpy())
[[2.266468 0.50750285 0.61121327]
[0.55887264 4.082267 0.1172941 ]]
>>> m = flow.nn.MSELoss(reduction="mean")
>>> out = m(input, target)
>>> print(out.numpy())
[1.3572696]
>>> m = flow.nn.MSELoss(reduction="sum")
>>> out = m(input, target)
>>> print(out.numpy())
[8.143618]

"""

def __init__(self, reduction: str = "mean", size_average=True, reduce=True) -> None:
super().__init__()
if size_average is not None and not size_average:
raise ValueError("Argument size_average is not supported yet")
if reduce is not None and not reduce:
raise ValueError("Argument reduce is not supported yet")
assert reduction in [
"sum",
"none",
"mean",
None,
], "only 'sum', 'mean' and None supported by now"
Contributor:

Please revise this message, e.g. to something like: "reduction parameter only supports 'sum'/'mean'/'none'/None values now!"

Contributor Author:

Fixed.
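A minimal sketch of the revised check (hypothetical wording; the follow-up commit is not part of this two-commit view):

    assert reduction in ["sum", "none", "mean", None], (
        "reduction parameter only supports 'sum'/'mean'/'none'/None for now!"
    )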


self.reduction = reduction
self.square_op = Square()
self.subtract_op = Subtract()
self.sum_op = Sum()
self.mean_op = Mean()

def forward(self, input, target):
        squared_difference = self.square_op(self.subtract_op(input, target))
        if self.reduction == "mean":
            return self.mean_op(squared_difference)
        elif self.reduction == "sum":
            return self.sum_op(squared_difference)
        else:
            # No reduction: return the elementwise squared differences
            return squared_difference


if __name__ == "__main__":
    import doctest

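As a quick sanity check on the docstring example above, the same numbers can be reproduced with plain numpy (assumes only numpy; values match the float32 outputs shown in the docstring up to rounding):

    import numpy as np

    x = np.array([[-0.02557137, 0.03101675, 1.37493674],
                  [0.25599439, -1.08372561, -0.21006816]], dtype=np.float32)
    y = np.array([[-1.53105064, -0.68137555, 0.5931354],
                  [-0.49158347, 0.93673637, 0.1324141]], dtype=np.float32)

    sq = np.square(x - y)  # reduction="none": elementwise squared error
    print(sq.mean())       # ~1.3572696 (reduction="mean")
    print(sq.sum())        # ~8.143618  (reduction="sum")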
1 change: 1 addition & 0 deletions oneflow/python/ops/nn_ops.py
@@ -3917,6 +3917,7 @@ def bce_with_logits_loss_job(input: tp.Numpy.Placeholder(shape=(2, 3)),


@oneflow_export("nn.MSELoss")
@stable_api
def mse_loss(
input: oneflow._oneflow_internal.BlobDesc,
target: oneflow._oneflow_internal.BlobDesc,
128 changes: 128 additions & 0 deletions oneflow/python/test/modules/test_mseloss.py
@@ -0,0 +1,128 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _np_mseloss(np_input, np_target):
np_mse = np.square(np_target - np_input)
np_mse_mean = np.mean(np_mse)
np_mse_sum = np.sum(np_mse)

return {
"none": np_mse,
"mean": np_mse_mean,
"sum": np_mse_sum,
}


def _np_mseloss_grad(np_input, np_target):
elem_cnt = np_input.size
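    # d/d(input) of (target - input)^2 is -2 * (target - input); reduction="mean"
    # additionally scales each element's gradient by 1/elem_cnt. "none" matches
    # "sum" here because the tests reduce the unreduced output with
    # of_out.sum() before calling backward().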
np_mse_grad_sum = -2 * (np_target - np_input)
np_mse_grad_mean = np_mse_grad_sum / elem_cnt

return {
"none": np_mse_grad_sum,
"mean": np_mse_grad_mean,
"sum": np_mse_grad_sum,
}


def _test_mseloss_backward(test_case, device, reduction):
x = np.random.randn(3, 5)
y = np.random.randn(3, 5)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

loss = flow.nn.MSELoss(reduction=reduction)
loss = loss.to(device)
of_out = loss(input, target)
np_out = _np_mseloss(x, y)[reduction]
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

of_out = of_out.sum()
of_out.backward()
np_grad = _np_mseloss_grad(x, y)[reduction]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))


def _test_mseloss_high_dim_input_backward(test_case, device, reduction):
x = np.random.randn(3, 2, 4, 16, 5)
y = np.random.randn(3, 2, 4, 16, 5)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

loss = flow.nn.MSELoss(reduction=reduction)
loss = loss.to(device)
of_out = loss(input, target)
np_out = _np_mseloss(x, y)[reduction]
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

of_out = of_out.sum()
of_out.backward()
np_grad = _np_mseloss_grad(x, y)[reduction]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))


def _test_mseloss_one_elem_input_backward(test_case, device, reduction):
Contributor:

Since numpy implements both the forward and the backward, these test cases can all be merged and tested uniformly by parameterizing the shape.

Contributor Author:

Fixed.

    x = np.array([0]).astype(np.float64)
    y = np.array([-1]).astype(np.float64)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

loss = flow.nn.MSELoss(reduction=reduction)
loss = loss.to(device)
of_out = loss(input, target)
np_out = _np_mseloss(x, y)[reduction]
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

of_out = of_out.sum()
of_out.backward()
np_grad = _np_mseloss_grad(x, y)[reduction]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
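A minimal sketch of the merge the reviewer suggests (hypothetical helper name and wiring; the follow-up commit is not part of this two-commit view). One body parameterized by shape replaces the three near-identical functions:

    def _test_mseloss_impl(test_case, device, shape, reduction):
        # Covers the 2-D, high-dimensional, and single-element cases via `shape`.
        x = np.random.randn(*shape)
        y = np.random.randn(*shape)
        input = flow.Tensor(
            x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
        )
        target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

        loss = flow.nn.MSELoss(reduction=reduction).to(device)
        of_out = loss(input, target)
        np_out = _np_mseloss(x, y)[reduction]
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

        of_out.sum().backward()
        np_grad = _np_mseloss_grad(x, y)[reduction]
        test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))

The test driver below would then add arg_dict["shape"] = [(3, 5), (3, 2, 4, 16, 5), (1,)] so GenArgList covers every shape/device/reduction combination.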


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestMSELossModule(flow.unittest.TestCase):
def test_mseloss(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_mseloss_backward,
_test_mseloss_high_dim_input_backward,
_test_mseloss_one_elem_input_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["reduction"] = ["none", "mean", "sum"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])


if __name__ == "__main__":
unittest.main()