Skip to content

Commit

Permalink
chore: updates
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 committed Oct 28, 2024
1 parent fca16a5 commit df13856
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
6 changes: 5 additions & 1 deletion core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
});

TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
Expand Down Expand Up @@ -166,6 +166,10 @@ TORCH_LIBRARY(tensorrt, m) {
});
}

TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
m.impl("execute_engine", execute_engine);
}

} // namespace
} // namespace runtime
} // namespace core
Expand Down
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._decomp import _decomp_table_to_post_autograd_aten, register_decomposition
from torch._decomp import register_decomposition
from torch._export.utils import _decomp_table_to_post_autograd_aten
from torch._ops import OpOverload
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime.register_fake_class import FakeTRTEngine
from torch_tensorrt.dynamo.runtime.register_fake_class import *
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/runtime/register_fake_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@torch.library.register_fake("tensorrt::execute_engine")
def fake_execute_engine(inputs, trt_engine):
def execute_engine(inputs, trt_engine):
breakpoint()
return trt_engine(inputs)

Expand Down

0 comments on commit df13856

Please sign in to comment.