fix/feat: Add Dynamo-only converter registry (#1944)
gs-olive committed Jul 21, 2023
1 parent 95730fe commit 6920876
Showing 6 changed files with 408 additions and 23 deletions.
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/__init__.py
@@ -3,5 +3,9 @@
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._settings import *
-    from .compile import compile
     from .aten_tracer import trace
+    from .converter_registry import (
+        DYNAMO_CONVERTERS,
+        dynamo_tensorrt_converter,
+    )
+    from .compile import compile
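
With these exports, converters can target the Dynamo-only registry instead of the FX one. Below is a minimal registration sketch; the capability_validator keyword and the (network, target, args, kwargs, name) converter signature follow the existing FX converter convention, and aten_ops_relu is a hypothetical name, so treat the exact API as an assumption rather than a documented contract:

    import torch
    from torch_tensorrt.dynamo import dynamo_tensorrt_converter
    from torch_tensorrt.dynamo.conversion.converter_utils import dynamic_unsupported

    # Hypothetical converter: keyed on an ATen overload and gated by the
    # dynamic-shape validator added in this commit.
    @dynamo_tensorrt_converter(
        torch.ops.aten.relu.default, capability_validator=dynamic_unsupported
    )
    def aten_ops_relu(network, target, args, kwargs, name):
        # A real converter would add TensorRT layers to `network` and
        # return the resulting ITensor(s); elided here.
        ...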
30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -0,0 +1,30 @@
+import torch
+
+
+def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    # Validate that none of the inputs to the node have Dynamic shapes
+    assert isinstance(
+        node, torch.fx.Node
+    ), "Inputs to validator functions must be FX Nodes"
+
+    # Check node value itself
+    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+        return False
+
+    # Check node arguments individually
+    if any(
+        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        for arg in node.args
+        if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+
+    # Check node keyword arguments individually
+    if any(
+        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        for kwarg in node.kwargs.values()
+        if isinstance(kwarg, torch.fx.Node)
+    ):
+        return False
+
+    return True
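
dynamic_unsupported is a capability validator: it returns True only when neither the node's own value nor any torch.fx.Node argument carries symbolic sizes or strides, so converters gated on it are skipped for dynamic-shape graphs. A rough usage sketch, under the assumption that torch._dynamo.export(..., aten_graph=True) populates the node.meta["val"] FakeTensor entries the function inspects:

    import torch
    from torch_tensorrt.dynamo.conversion.converter_utils import dynamic_unsupported

    def fn(x):
        return torch.relu(x) + 1

    # Export to an ATen-level FX graph so nodes carry meta["val"].
    graph_module, _ = torch._dynamo.export(fn, torch.randn(2, 3), aten_graph=True)

    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # True here means the node has no symbolic (dynamic) dimensions.
            print(node.target, dynamic_unsupported(node))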
17 changes: 9 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/trt_interpreter.py
@@ -10,11 +10,10 @@
 import tensorrt as trt
 import torch
 import torch.fx
-from torch._ops import OpOverload
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 
-from torch_tensorrt.fx import CONVERTERS
+from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
 from torch_tensorrt import Input
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import (
@@ -69,6 +68,7 @@ def __init__(
         self.input_specs = input_specs
         self.input_specs_iter = 0
         self._cur_node_name: Optional[str] = None
+        self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
         self._itensor_to_tensor_meta: Dict[
@@ -82,14 +82,14 @@ def validate_conversion(self):
         missing_converter = set()
 
         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node.target):
+            if node.op == "call_function" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node.target):
+            elif node.op == "call_method" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(submod_type):
+                if not CONVERTERS.get(node):
                     missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converter
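
Note that validate_conversion now keys every lookup on the Node itself rather than on node.target or the submodule type. Since a Node carries its args, kwargs, and meta, a node-keyed registry can run per-converter capability validators (such as dynamic_unsupported) before declaring an op supported. The actual ConverterRegistry lives in converter_registry.py, which did not load on this page, so the following is only an illustrative sketch of the lookup pattern:

    from typing import Callable, Dict, List, Optional, Tuple

    import torch.fx

    class NodeKeyedRegistry:
        # Illustrative only: each key maps to (converter, validator) pairs,
        # and get() consults the validator against the concrete node.
        def __init__(self) -> None:
            self._store: Dict[object, List[Tuple[Callable, Callable]]] = {}

        def register(self, key, converter, validator=lambda node: True):
            self._store.setdefault(key, []).append((converter, validator))

        def get(self, node: torch.fx.Node) -> Optional[Callable]:
            for converter, validator in self._store.get(node.target, []):
                if validator(node):
                    return converter
            return None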
@@ -226,6 +226,7 @@ def run(
 
     def run_node(self, n):
         self._cur_node_name = str(n)
+        self._cur_node = n
         # add "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
@@ -276,7 +277,7 @@ def call_module(self, target, args, kwargs):
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(submod_type)
+        converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
             raise RuntimeError(
@@ -287,7 +288,7 @@
         return converter(self.network, submod, args, kwargs, self._cur_node_name)
 
     def call_function(self, target, args, kwargs):
-        converter = CONVERTERS.get(target)
+        converter = CONVERTERS.get(self._cur_node)
         if not converter:
             raise RuntimeError(
                 f"Conversion of function {torch.typename(target)} not currently supported!"

def call_method(self, target, args, kwargs):
assert isinstance(target, str)
converter = CONVERTERS.get(target)
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
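
The _cur_node bookkeeping exists because torch.fx.Interpreter dispatches to call_function, call_method, and call_module with only (target, args, kwargs), not the Node, while the node-keyed registry needs the Node itself. run_node fires before each dispatch, so recording the node there makes it available. The same pattern in isolation, as a sketch:

    import torch.fx

    class NodeAwareInterpreter(torch.fx.Interpreter):
        def run_node(self, n: torch.fx.Node):
            # run_node wraps every dispatch, so stash the current node here.
            self._cur_node = n
            return super().run_node(n)

        def call_function(self, target, args, kwargs):
            # A node-keyed lookup, e.g. CONVERTERS.get(self._cur_node),
            # would go here before falling back to the default behavior.
            return super().call_function(target, args, kwargs)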