Fixed issue with total_weight in nll_loss_forward_decomposition #829

Draft · wants to merge 1 commit into main
Conversation

@vfdev-5 (Contributor) commented May 22, 2022

Description:

@Chillee caught that the total_weight output is wrong for the nll_loss_forward_decomposition C++ implementation:

import torch
from torch import tensor
from torch._decomp import decomposition_table
from functorch import vmap
aten = torch.ops.aten

args = (
    tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
    tensor([1, 1]),
    None,
    0,
    -100
)
# Reference output from the eager C++ kernel
ref_out = aten.nll_loss_forward(*args)

# Same call routed through vmap, which goes through the
# nll_loss_forward_decomposition batching rule
decomp_out = vmap(
    aten.nll_loss_forward.default,
    in_dims=(0, None, None, None, None)
)(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])

torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])  # loss
torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])  # total_weight

Before this PR:

Traceback (most recent call last):
  File "/tmp/fth/repro_nll_loss_issue.py", line 73, in <module>
    torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1317, in assert_close
    assert_equal(
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1086, in assert_equal
    raise error_metas[0].to_error()
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2.0 at index (0,) (up to 1e-07 allowed)
Greatest relative difference: 1.0 at index (0,) (up to 1e-07 allowed)
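
For context, here is a minimal sketch of the semantics the decomposition has to match for total_weight (my own reference reading of aten.nll_loss_forward for 2-D input; the helper name and the assumption that total_weight stays zero for reduction=0 are mine, not part of this PR):

import torch

def nll_loss_forward_ref(logp, target, weight, reduction, ignore_index):
    # Reference reading of aten.nll_loss_forward for 2-D input of shape (N, C).
    # reduction: 0 = none, 1 = mean, 2 = sum; returns (output, total_weight).
    w = weight if weight is not None else logp.new_ones(logp.shape[-1])
    ignored = target == ignore_index
    safe_target = torch.where(ignored, torch.zeros_like(target), target)
    # Per-element class weight, zeroed out for ignored targets.
    elem_w = w[safe_target] * (~ignored).to(logp.dtype)
    loss = -logp.gather(1, safe_target.unsqueeze(1)).squeeze(1) * elem_w
    if reduction == 0:                               # 'none'
        return loss, logp.new_zeros(())              # assumption: total_weight stays 0
    total_weight = elem_w.sum()                      # sum of weights over non-ignored targets
    if reduction == 2:                               # 'sum'
        return loss.sum(), total_weight
    return loss.sum() / total_weight, total_weight   # 'mean'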

The PR is tested with the following script (right now there is no way to check total_weight on CI):

# Sweep ignore_index and reduction (0 = none, 1 = mean, 2 = sum), comparing eager vs. decomposition
for ignore_value in [-100, 0, 2]:
    for reduction in [0, 1, 2]:
        args = (
            tensor([[-4.8270, -7.5824, -0.6047], [-1.5412, -1.9719, -4.1460]], dtype=torch.float64, requires_grad=True),
            tensor([1, 1]),
            None,
            reduction,
            ignore_value
        )
        ref_out = aten.nll_loss_forward(*args)

        decomp_out = vmap(
            aten.nll_loss_forward.default,
            in_dims=(0, None, None, None, None)
        )(args[0].unsqueeze(0), args[1], args[2], args[3], args[4])

        print(ignore_value, reduction, ref_out, decomp_out)
        torch.testing.assert_close(ref_out[0].unsqueeze(0), decomp_out[0])
        torch.testing.assert_close(ref_out[1].unsqueeze(0), decomp_out[1])

@Chillee (Contributor) left a comment

Nice, LGTM!

Ideally, putting this in Python would be awesome, since we have more comprehensive testing for these kinds of decompositions in PyTorch Core (and also allows us to use this decomposition for things like meta tensors and such).
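
For what it's worth, the Python registration being suggested here would look roughly like this (a hypothetical sketch using torch._decomp.register_decomposition; the body just condenses the 2-D reference sketched above and is not the decomposition that eventually landed in PyTorch core):

import torch
from torch._decomp import register_decomposition

aten = torch.ops.aten

# Assumes no decomposition is registered for this op yet; otherwise
# register_decomposition may complain about a duplicate entry.
@register_decomposition(aten.nll_loss_forward)
def nll_loss_forward(self, target, weight, reduction, ignore_index):
    # 2-D-only sketch: reduction 0 = none, 1 = mean, 2 = sum.
    w = weight if weight is not None else self.new_ones(self.shape[-1])
    ignored = target == ignore_index
    safe_target = torch.where(ignored, torch.zeros_like(target), target)
    elem_w = w[safe_target] * (~ignored).to(self.dtype)
    loss = -self.gather(1, safe_target.unsqueeze(1)).squeeze(1) * elem_w
    if reduction == 0:
        return loss, self.new_zeros(())
    total_weight = elem_w.sum()
    out = loss.sum() if reduction == 2 else loss.sum() / total_weight
    return out, total_weight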

@vfdev-5 (Contributor, Author) commented May 27, 2022

@Chillee sure, as we discussed elsewhere, I'll be adding nll_loss_forward to PyTorch core as my next task.

EDIT: while coding the decomposition in PyTorch core, it looks like this code is still incorrect in another case. Marking as draft; I'll probably close it later.

vfdev-5 marked this pull request as draft May 27, 2022 11:56
@Chillee (Contributor) commented May 31, 2022

while coding the decomposition in PyTorch core, it looks like this code is still incorrect in another case. Marking as draft; I'll probably close it later.

Haha, I've run into this a couple of times when porting ops into Python :)
