diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py similarity index 100% rename from py/torch_tensorrt/dynamo/backend/_defaults.py rename to py/torch_tensorrt/dynamo/_defaults.py diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index 38e60fce41..4a48a29308 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -10,7 +10,7 @@ from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend -from torch_tensorrt.dynamo.backend._defaults import ( +from torch_tensorrt.dynamo._defaults import ( PRECISION, DEBUG, WORKSPACE_SIZE, diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index b97079948e..7afa821498 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -4,7 +4,7 @@ from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.common import CompilationSettings from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py index 425fb0941e..a4e25c5231 100644 --- a/py/torch_tensorrt/dynamo/backend/conversion.py +++ b/py/torch_tensorrt/dynamo/backend/conversion.py @@ -2,8 +2,8 @@ import torch import io from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.dynamo.backend._settings import CompilationSettings -from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import ( +from torch_tensorrt.dynamo.common import ( + CompilationSettings, InputTensorSpec, TRTInterpreter, ) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 4d82bf4be5..2158a940aa 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -3,8 +3,8 @@ import torch -from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py index 2af251adbc..678a5d9662 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py @@ -4,7 +4,7 @@ from copy import deepcopy from torch_tensorrt.dynamo import compile from utils import lower_graph_testing -from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT +from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT class TestTRTModuleNextCompilation(TestCase): diff --git a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py index d947c955e0..340797aa69 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py @@ -3,7 +3,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase import torch from torch_tensorrt.dynamo import compile -from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT +from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT class TestLowering(TestCase): diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py index 23a1cd4795..f51dac3b11 100644 --- a/py/torch_tensorrt/dynamo/backend/utils.py +++ b/py/torch_tensorrt/dynamo/backend/utils.py @@ -2,10 +2,9 @@ import logging from dataclasses import replace, fields -from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.common import CompilationSettings, use_python_runtime_parser from typing import Any, Union, Sequence, Dict from torch_tensorrt import _Input, Device -from ..common_utils import use_python_runtime_parser logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common/__init__.py similarity index 89% rename from py/torch_tensorrt/dynamo/common_utils/__init__.py rename to py/torch_tensorrt/dynamo/common/__init__.py index de0ce0a48a..3cf4f72270 100644 --- a/py/torch_tensorrt/dynamo/common_utils/__init__.py +++ b/py/torch_tensorrt/dynamo/common/__init__.py @@ -1,6 +1,10 @@ import logging from typing import Optional +from ._settings import CompilationSettings +from .input_tensor_spec import InputTensorSpec +from .fx2trt import TRTInterpreter, TRTInterpreterResult + logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/common/_settings.py similarity index 94% rename from py/torch_tensorrt/dynamo/backend/_settings.py rename to py/torch_tensorrt/dynamo/common/_settings.py index d074a6b079..75bcc5428b 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/common/_settings.py @@ -2,7 +2,7 @@ from typing import Optional, Sequence from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.backend._defaults import ( +from torch_tensorrt.dynamo._defaults import ( PRECISION, DEBUG, WORKSPACE_SIZE, diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/common/fx2trt.py similarity index 100% rename from py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py rename to py/torch_tensorrt/dynamo/common/fx2trt.py diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/common/input_tensor_spec.py similarity index 100% rename from py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py rename to py/torch_tensorrt/dynamo/common/input_tensor_spec.py diff --git a/py/torch_tensorrt/dynamo/common_utils/test_utils.py b/py/torch_tensorrt/dynamo/common/test_utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/common_utils/test_utils.py rename to py/torch_tensorrt/dynamo/common/test_utils.py diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py index 85ce01ef20..3c17701d5d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py @@ -6,8 +6,6 @@ NO_IMPLICIT_BATCH_DIM_SUPPORT, tensorrt_converter, ) -from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa -from .input_tensor_spec import InputTensorSpec # noqa from .lower_setting import LowerSetting # noqa from .lower import compile # usort: skip #noqa diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py index c0f1ae7870..51724851d4 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py @@ -10,11 +10,14 @@ import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer from torch.fx.passes.splitter_base import SplitResult -from .fx2trt import TRTInterpreter, TRTInterpreterResult +from torch_tensorrt.dynamo.common import ( + TRTInterpreter, + TRTInterpreterResult, + use_python_runtime_parser, +) from .lower_setting import LowerSetting from .passes.lower_pass_manager_builder import LowerPassManagerBuilder from .passes.pass_utils import PassFunc, validate_inference -from ..common_utils import use_python_runtime_parser from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting @@ -22,6 +25,17 @@ from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + WORKSPACE_SIZE, + MIN_BLOCK_SIZE, + PASS_THROUGH_BUILD_FAILURES, + MAX_AUX_STREAMS, + VERSION_COMPATIBLE, + OPTIMIZATION_LEVEL, + USE_PYTHON_RUNTIME, +) logger = logging.getLogger(__name__) @@ -35,24 +49,25 @@ def compile( disable_tf32=False, sparse_weights=False, enabled_precisions=set(), - min_block_size: int = 3, - workspace_size=0, + min_block_size: int = MIN_BLOCK_SIZE, + workspace_size=WORKSPACE_SIZE, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, calibrator=None, truncate_long_and_double=False, require_full_compilation=False, - debug=False, + explicit_batch_dimension=False, + debug=DEBUG, refit=False, timing_cache_prefix="", save_timing_cache=False, cuda_graph_batch_size=-1, is_aten=False, - use_python_runtime=None, - max_aux_streams=None, - version_compatible=False, - optimization_level=None, + use_python_runtime=USE_PYTHON_RUNTIME, + max_aux_streams=MAX_AUX_STREAMS, + version_compatible=VERSION_COMPATIBLE, + optimization_level=OPTIMIZATION_LEVEL, num_avg_timing_iters=1, torch_executed_ops=[], torch_executed_modules=[], diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py index 9301a2cd90..1f377efe78 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py @@ -4,7 +4,7 @@ from torch import nn from torch.fx.passes.pass_manager import PassManager -from .input_tensor_spec import InputTensorSpec +from torch_tensorrt.dynamo.common import InputTensorSpec from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_linear, fuse_permute_matmul, diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py index 0fd3777254..db7b8b9e4a 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py @@ -10,7 +10,7 @@ from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt import _Input -from ..input_tensor_spec import InputTensorSpec +from torch_tensorrt.dynamo.common import InputTensorSpec from ..lower_setting import LowerSetting from torch_tensorrt.fx.observer import Observer diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py index 0761b964f8..0575d55660 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py @@ -5,7 +5,7 @@ import torch import torch_tensorrt from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting +from torch_tensorrt.dynamo.common import InputTensorSpec class TestTRTModule(TestCase): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py index 334243fef4..a1857a7677 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py @@ -13,7 +13,7 @@ from torch.fx.passes import shape_prop from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter +from torch_tensorrt.dynamo.common import InputTensorSpec, TRTInterpreter from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( compose_bmm, compose_chunk, diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index f34aad6caf..fc9bf634c2 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -8,7 +8,7 @@ from transformers import BertModel -from torch_tensorrt.dynamo.common_utils.test_utils import ( +from torch_tensorrt.dynamo.common.test_utils import ( COSINE_THRESHOLD, cosine_similarity, )