
ThunderFX's splitter looks a tad conservative for custom torch.autograd.Functions by pushing them to the fallback path #1539

Open · crcrpar opened this issue Dec 11, 2024 · 3 comments · May be fixed by #1548
Labels: thunderfx (for things that could be applicable to the dynamo+thunder frontend)

crcrpar (Collaborator) commented Dec 11, 2024

🐛 Bug

ThunderFX does not seem to support some custom torch.autograd.Functions that thunder.jit could handle.

To Reproduce

A simpler torch.compile custom backend that calls thunder.jit directly handles the following snippet (the program is the same class MyFunction(torch.autograd.Function) as in #1538). ThunderFX, however, seems to have the fallback handle the program, as observed in #1538.

import torch

import thunder
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast


class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    # This backward is intentionally wrong: the correct gradient would be 2 * grad_output.
    @staticmethod
    def backward(ctx, grad_output) -> torch.Tensor:
        return grad_output


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x) -> torch.Tensor:
        return MyFunction.apply(x)


def thunderfx_no_fallback(gm: torch.fx.GraphModule, sample_args):
    # Mirror ThunderFX's preprocessing, but hand the whole graph to thunder.jit
    # without going through the splitter/fallback logic.
    gm = remove_empty_autocast(gm)
    recompile_graph(gm)
    jitted = thunder.jit(gm)
    jitted_fns.append(jitted)
    return jitted


# Keep references to the thunder.jit callables so their traces can be inspected later.
jitted_fns = []


if __name__ == "__main__":

    device = torch.device("cuda")
    dtype = torch.bfloat16

    model = Model().to(device=device, dtype=dtype)
    jitted = torch.compile(model, backend=thunderfx_no_fallback)

    x = torch.randn((2, 2), device=device, dtype=dtype, requires_grad=True)
    out = jitted(x)

    # Inspect the thunder.jit module captured by the custom backend.
    jitted = jitted_fns[0]
    print(thunder.last_traces(jitted)[-1])
    print(thunder.last_backward_traces(jitted)[-1])

What follows is the output of the snippet above.

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(L_x_):
  # L_x_: "cuda:0 bf16[2, 2]"
  [autograd_function_apply] = nvFusion0(L_x_)
    # t11 = prims.convert_element_type(L_x_, dtypes.float32)  # t11: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t11, 2.0)  # t12: "cuda:0 f32[2, 2]"
    # autograd_function_apply = prims.convert_element_type(t12, dtypes.bfloat16)  # autograd_function_apply: "cuda:0 bf16[2, 2]"
  return {'output': (autograd_function_apply,), 'flat_args': [L_x_], 'flat_output': (autograd_function_apply,)}, ((), ())
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # cotangents: "Collection"
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  [t17] = nvFusion0(t0)
    # t14 = prims.convert_element_type(t0, dtypes.float32)  # t14: "cuda:0 f32[2, 2]"
    # t16 = prims.mul(2.0, t14)  # t16: "cuda:0 f32[2, 2]"
    # t17 = prims.convert_element_type(t16, dtypes.bfloat16)  # t17: "cuda:0 bf16[2, 2]"
  del t0
  return (t17,)
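For comparison, here is a sketch of how the split decision could be inspected with the stock ThunderFX backend, reusing Model from the snippet above. It assumes thunder.dynamo.ThunderCompiler records a subgraph_infos list whose entries carry the splitter's SplitReasons; treat it as illustrative rather than as the exact API.

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()
fx_jitted = torch.compile(Model().to(device="cuda", dtype=torch.bfloat16), backend=backend)
fx_jitted(torch.randn((2, 2), device="cuda", dtype=torch.bfloat16, requires_grad=True))

# Assumed attribute: one SubgraphInfo per Dynamo graph, each with its SplitReasons.
for subgraph_info in backend.subgraph_infos:
    for split_reason in subgraph_info.split_reasons:
        print(split_reason)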

Expected behavior

ThunderFX should compile this program through thunder.jit end to end, as the no-fallback backend above does, instead of pushing the custom torch.autograd.Function to the fallback path.


crcrpar added the thunderfx label on Dec 11, 2024
crcrpar retitled the issue on Dec 11, 2024 to attribute the conservative behavior specifically to ThunderFX's splitter
kiya00 (Collaborator) commented Dec 11, 2024

The splitter decides whether an operator is supported by Thunder via a pre-run check:

is_thunder_supported, split_reason = is_node_supported_by_thunder(node)

It seems there is an operator that is not registered in Thunder: torch.autograd.function.FunctionCtx.
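To see which nodes that pre-run inspects, the FX graph Dynamo hands to the splitter can be dumped with standard torch.fx APIs (a minimal sketch; dump_nodes is a hypothetical helper, not part of Thunder):

import torch


def dump_nodes(gm: torch.fx.GraphModule) -> None:
    # Print op/name/target for every node the splitter will check; the unused
    # function_ctx node (target torch.autograd.function.FunctionCtx) shows up
    # here alongside the autograd_function_apply higher-order call.
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target)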

crcrpar (Collaborator, Author) commented Dec 11, 2024

My gut feeling is that the missing FunctionCtx mapping is not the major cause.
The target, i.e. the function the node calls,

target = node.target  # Target is the function to call.

should be torch.ops.higher_order.autograd_function_apply, for which there is already a lookaside:

def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_args, **fwd_kwargs):

kshitij12345 (Collaborator) commented

This is the Dynamo FX graph that we see:

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "bf16[2, 2]"):
        l_x_ = L_x_
        
         # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:227 in forward, code: return MyFunction.apply(x)
        function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply: "bf16[2, 2]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = l_x_ = None
        return (autograd_function_apply,)
        
    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx : torch.autograd.function.Function, x: "bf16[2, 2]"):
             # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:214 in forward, code: return x * 2.0
            mul: "bf16[2, 2]" = x * 2.0;  x = None
            return (mul, [])
            
    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx : torch.autograd.function.Function, grad_output: "bf16[2, 2]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return grad_output

I see two split reasons: one related to torch.autograd.function.FunctionCtx and the other related to torch.ops.higher_order.autograd_function_apply.

[SplitReason(reason_type=<SplitReasonType.MISSING_OP_SUPPORT: 2>, info="node with name: function_ctx and target: <class 'torch.autograd.function.FunctionCtx'> didn't have any mapping in thunder.", exception=None),

SplitReason(reason_type=<SplitReasonType.EXCEPTION_META_THUNDER_OP: 4>, info='Failed while running meta for node with name: autograd_function_apply and target: autograd_function_apply, see exception field', exception=TypeError("'Node' object is not callable"))]

For autograd_function_apply, the problem is that get_proxy_inputs_from_node doesn't correctly handle the case where inputs can be callables (like fwd_body_0 or bwd_body_0 here). If the input node has op=get_attr, then we should fetch the attribute it points to:

def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
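A minimal sketch of that idea, using a hypothetical helper (not the actual change in #1548): resolve a get_attr node to the attribute it names on the owning GraphModule, so the proxied input is the callable submodule (e.g. fwd_body_0) rather than the fx.Node itself.

from functools import reduce

import torch


def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
    # Hypothetical helper: a get_attr node's target is a (possibly dotted)
    # attribute path on the owning GraphModule, e.g. "fwd_body_0".
    # Returning the real attribute avoids the "'Node' object is not callable"
    # failure when the meta for autograd_function_apply tries to call fwd/bwd.
    assert node.op == "get_attr"
    return reduce(getattr, str(node.target).split("."), gm)

get_proxy_inputs_from_node could then substitute the resolved attribute for any get_attr input before building proxies.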
