-
Notifications
You must be signed in to change notification settings - Fork 695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add mseloss module #5116
add mseloss module #5116
Conversation
YongtaoShi
commented
Jun 7, 2021
•
edited
Loading
edited
@doombeaker @Flowingsun007 @BBuf |
oneflow/python/nn/modules/loss.py
Outdated
"none", | ||
"mean", | ||
None, | ||
], "only 'sum', 'mean' and None supported by now" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里修改一下,类似于reduction parameter only support 'sum'/'mean'/'none'/None value now!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) | ||
|
||
|
||
def _test_mseloss_one_elem_input_backward(test_case, device, reduction): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
既然numpy实现了forward和backward,那么这些测试样例都可以合并,通过设置shape来统一测试。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
} | ||
|
||
|
||
def _test_mseloss_backward(test_case, device, shape, reduction): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
名字改成_test_mseloss_impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
…nc/oneflow into shiyongtao/dev_mseloss