Skip to content

Commit 647a526

Browse files
authored
[https://nvbugs/5443039][fix] Fix AutoDeploy pattern matcher for torch 2.8 (#7076)
Signed-off-by: Frida Hou <[email protected]>
1 parent cbcea33 commit 647a526

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ def _patch_unsupported_input_tensor():
4343
"""
4444
original_fn = lowering.unsupported_input_tensor
4545

46-
def patched_fn(t: torch.Tensor, parent=None, node=None):
46+
def patched_fn(t: torch.Tensor, *args, **kwargs):
4747
"""Bypass meta tensor check."""
4848
if t.is_meta:
4949
return False
50-
return original_fn(t, parent, node)
50+
return original_fn(
51+
t, *args, **kwargs
52+
) # a generic pass-through of the arguments to accommodate torch side change
5153

5254
lowering.unsupported_input_tensor = patched_fn
5355
try:

0 commit comments

Comments
 (0)