diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 00063c3e21..f9e9464a22 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -243,7 +243,7 @@ def remove_ops( module: torch.fx.GraphModule, ) -> torch.fx.GraphModule: """ - 1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable + 1. Remove clone, _unsafe_view, view node. #TODO Remove this func after functionalization is workable 2. Remove inefficient op getitem(index=slice) P561572458 """ modified = False @@ -258,6 +258,7 @@ def remove_ops( for n in module.graph.nodes: if n.op == "call_function" and n.target in ( torch.ops.aten._unsafe_view.default, + torch.ops.aten.view.default, ): modified = True node = n @@ -437,8 +438,10 @@ def compose_bmm( real_other = input_other_n.all_input_nodes[0] if len(real_other.meta["val"].size()) == 2: new_func = aten_compose_bmm_2d - if len(real_other.meta["val"].size()) == 3: + elif len(real_other.meta["val"].size()) == 3: new_func = aten_compose_bmm_3d + else: + new_func = torch.ops.aten.matmul with module.graph.inserting_after(node): new_args = (real_input, real_other)