ThunderFX's splitter looks a tad conservative for custom torch.autograd.Functions by pushing them to the fallback path #1539
Comments
lightning-thunder/thunder/dynamo/splitter.py Line 121 in 9de5434

The splitter decides whether an operator is supported by thunder via a pre-run; it seems there's an operator not registered in thunder: torch.autograd.function.FunctionCtx.
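For context, that pre-run essentially walks the FX graph and asks whether each node's target has a mapping to a thunder operation. Below is a minimal, self-contained sketch of that idea; the support table and helper are illustrative stand-ins, not thunder's actual tables or API:

```python
# Sketch of a pre-run support check over an FX graph.
# SUPPORTED_TARGETS stands in for thunder's torch->thunder operator mapping.
import operator
import torch
import torch.fx as fx


class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


# Hypothetical support table; in the real splitter the lookup is against thunder's op registry.
SUPPORTED_TARGETS = {operator.mul, torch.mul}


def unsupported_nodes(gm: fx.GraphModule):
    """Return call nodes whose target has no entry in the support table."""
    bad = []
    for node in gm.graph.nodes:
        if node.op in ("call_function", "call_method", "call_module"):
            if node.target not in SUPPORTED_TARGETS:
                bad.append(node)
    return bad


gm = fx.symbolic_trace(Double())
print(unsupported_nodes(gm))  # [] -> the whole graph could go to thunder
```

Under this framing, the function_ctx node's target (torch.autograd.function.FunctionCtx) is exactly the kind of target with no mapping, which is why it is reported as MISSING_OP_SUPPORT.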
It doesn't feel like the missing operator flagged in lightning-thunder/thunder/dynamo/utils.py Line 349 in 9de5434 is the real problem; the relevant target is torch.ops.higher_order.autograd_function_apply, which has a lookaside for it in lightning-thunder/thunder/core/jit_ext.py Line 777 in 9de5434.
This is the Dynamo FX graph that we see:

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "bf16[2, 2]"):
        l_x_ = L_x_

        # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:227 in forward, code: return MyFunction.apply(x)
        function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply: "bf16[2, 2]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = l_x_ = None
        return (autograd_function_apply,)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx : torch.autograd.function.Function, x: "bf16[2, 2]"):
            # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:214 in forward, code: return x * 2.0
            mul: "bf16[2, 2]" = x * 2.0;  x = None
            return (mul, [])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx : torch.autograd.function.Function, grad_output: "bf16[2, 2]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return grad_output
```

I see 2 split reasons, one for the function_ctx node and one for autograd_function_apply:

```python
[SplitReason(reason_type=<SplitReasonType.MISSING_OP_SUPPORT: 2>, info="node with name: function_ctx and target: <class 'torch.autograd.function.FunctionCtx'> didn't have any mapping in thunder.", exception=None),
 SplitReason(reason_type=<SplitReasonType.EXCEPTION_META_THUNDER_OP: 4>, info='Failed while running meta for node with name: autograd_function_apply and target: autograd_function_apply, see exception field', exception=TypeError("'Node' object is not callable"))]
```

For reference, see lightning-thunder/thunder/dynamo/utils.py Line 147 in 9de5434.
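A hedged sketch of how these split reasons can be reproduced and inspected from the ThunderFX side; the ThunderCompiler backend and the subgraph_infos / split_reasons attribute names are written from memory and may differ between lightning-thunder versions:

```python
import torch
from thunder.dynamo import ThunderCompiler  # assumed import path for the ThunderFX backend


class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class Module(torch.nn.Module):
    def forward(self, x):
        return MyFunction.apply(x)


backend = ThunderCompiler()
model = torch.compile(Module(), backend=backend)
out = model(torch.randn(2, 2, dtype=torch.bfloat16, requires_grad=True))

# Each captured FX graph should have an associated subgraph info whose split
# reasons explain why parts of the graph were sent to the fallback path.
for info in backend.subgraph_infos:
    print(info.split_reasons)
```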
🐛 Bug
ThunderFX does not seem to support some custom torch.autograd.Functions that thunder.jit could handle.

To Reproduce
A simpler torch.compile custom backend that calls thunder.jit handles the following snippet (the program is the same as lightning-thunder/thunder/tests/test_jit_general.py Line 1144 in 9de5434).
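The snippet itself did not survive the copy, so here is a minimal sketch in the same spirit, assuming a trivial custom torch.autograd.Function (names and function body are illustrative, not the original test): thunder.jit traces it directly, and so does a bare torch.compile backend that simply hands the FX graph to thunder.jit.

```python
import torch
import thunder


class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def fn(x):
    return MyFunction.apply(x)


x = torch.randn(2, 2, dtype=torch.bfloat16, requires_grad=True)

# thunder.jit handles the custom Function directly ...
jfn = thunder.jit(fn)
out = jfn(x)


# ... and so does a torch.compile backend that just passes each FX graph to thunder.jit,
# which is what the "simpler custom backend" above refers to.
def thunder_backend(gm, example_inputs):
    return thunder.jit(gm)


cfn = torch.compile(fn, backend=thunder_backend)
out2 = cfn(x)
```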
ThunderFX, however, seems to have the fallback handle the program, as in #1538.
What follows is the output of the snippet above.
Expected behavior
Environment

How installed (conda, pip, source):

Additional context