diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 505a6a81f7..8f67cf56f1 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -119,7 +119,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = }); TORCH_LIBRARY(tensorrt, m) { - m.def("execute_engine", execute_engine); + m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]"); m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); }); m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; }); m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; }); @@ -166,6 +166,10 @@ TORCH_LIBRARY(tensorrt, m) { }); } +TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) { + m.impl("execute_engine", execute_engine); +} + } // namespace } // namespace runtime } // namespace core diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 534bc3eac5..d2bfeb501a 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -3,7 +3,8 @@ from typing import Any, Callable, Dict, List, Optional import torch -from torch._decomp import _decomp_table_to_post_autograd_aten, register_decomposition +from torch._decomp import register_decomposition +from torch._export.utils import _decomp_table_to_post_autograd_aten from torch._ops import OpOverload from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py index a72f26a36e..762345d50f 100644 --- a/py/torch_tensorrt/dynamo/runtime/__init__.py +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py @@ -4,4 +4,4 @@ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401 TorchTensorRTModule, ) -from torch_tensorrt.dynamo.runtime.register_fake_class import FakeTRTEngine +from torch_tensorrt.dynamo.runtime.register_fake_class import * diff --git a/py/torch_tensorrt/dynamo/runtime/register_fake_class.py b/py/torch_tensorrt/dynamo/runtime/register_fake_class.py index 2d25ddbd4f..73a38d5dd1 100644 --- a/py/torch_tensorrt/dynamo/runtime/register_fake_class.py +++ b/py/torch_tensorrt/dynamo/runtime/register_fake_class.py @@ -5,7 +5,7 @@ @torch.library.register_fake("tensorrt::execute_engine") -def fake_execute_engine(inputs, trt_engine): +def execute_engine(inputs, trt_engine): breakpoint() return trt_engine(inputs)