diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py
index 6c9c188aaad0..f3c0ed23ebb1 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/__init__.py
@@ -17,3 +17,5 @@
 """
 Frontends for constructing Relax programs, with the model importers
 """
+from . import torch
+from .common import ImporterOutput
diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
new file mode 100644
index 000000000000..cdb88cd12c08
--- /dev/null
+++ b/python/tvm/relax/frontend/common.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Commons for Relax frontend."""
+from typing import Dict, List, Optional
+
+import tvm
+
+
+class ImporterOutput:
+    """The data structure representing the result of frontend imports.
+
+    Attributes
+    ----------
+    mod : tvm.IRModule
+        The IRModule imported from the frontend.
+
+    params : Optional[Dict[str, List[tvm.nd.NDArray]]]
+        The weights of the imported model, when the weights are requested to
+        be kept as parameters of functions in the IRModule (e.g., when the
+        `keep_params_as_input` flag of `frontend.torch.from_fx` is set to
+        True).
+        - `params` is None when not requested.
+        - The keys of `params` are the names of the Relax functions in the IRModule.
+        - Each weight tensor is a TVM NDArray on device CPU.
+        - The order of the returned weights is in accordance with the order of
+          the kept Relax function input variables.
+    """
+
+    mod: tvm.IRModule
+    params: Optional[Dict[str, List[tvm.nd.NDArray]]]
+
+    def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str, List[tvm.nd.NDArray]]]):
+        self.mod = mod
+        self.params = params
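
The `ImporterOutput` container above is intentionally thin; the following is a minimal sketch of how it is meant to be populated and consumed, assuming only the class defined in this file (the empty module and zero tensor are made-up placeholders):

import numpy as np

import tvm
from tvm.relax.frontend import ImporterOutput

# `params` may be None (weights not requested) or a per-function list of
# CPU NDArrays, keyed by the Relax function name.
empty_mod = tvm.IRModule()
weights = {"main": [tvm.nd.array(np.zeros((3, 3), dtype="float32"))]}

without_params = ImporterOutput(empty_mod, None)
with_params = ImporterOutput(empty_mod, weights)

assert without_params.params is None
assert with_params.params["main"][0].shape == (3, 3)
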
+ """ + + mod: tvm.IRModule + params: Optional[Dict[str, List[tvm.nd.NDArray]]] + + def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str, List[tvm.nd.NDArray]]]): + self.mod = mod + self.params = params diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index 589c6be3b5b5..3f30044bb8b9 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -24,7 +24,9 @@ import tvm from tvm.relax import build as relax_build -from tvm.relax.frontend.torch.fx_translator import from_fx + +from .fx_translator import from_fx +from ..common import ImporterOutput def device_from_inputs(example_inputs): @@ -72,7 +74,7 @@ def to_tvm_tensor(torch_tensor): device = device_from_inputs(example_inputs) input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] - mod = from_fx(graph_module, input_info) + mod = from_fx(graph_module, input_info).mod if device.type == "cuda": dev = tvm.cuda(device.index) @@ -114,7 +116,7 @@ def exec_tvm(*i_args): return _relax_backend -def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule: +def dynamo_capture_subgraphs(model, *params, **kwargs) -> ImporterOutput: """Capture subgraphs of the PyTorch model using torch.compile into an IRModule. Parameters @@ -125,28 +127,38 @@ def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule: params : List[torch.Tensor] The parameters of the PyTorch model. + keep_params_as_input : bool + Whether to keep model parameters as input variables of the captured Relax functions. + Returns ------- - mod : tvm.ir.IRModule - The IRModule that contains captured subgraphs. + output : ImporterOutput + The output of translation, including the translated IRModule, and + the weights of the input model when `keep_params_as_input` is true. 
""" import torch # type: ignore[import] from torch import fx # type: ignore[import] from torch import _dynamo as dynamo # type: ignore[import] + keep_params_as_input = "keep_params_as_input" in kwargs and kwargs["keep_params_as_input"] + mod = tvm.IRModule() + params_ndarray = dict() if keep_params_as_input else None def _capture(graph_module: fx.GraphModule, example_inputs): assert isinstance(graph_module, torch.fx.GraphModule) input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] - subgraph = from_fx(graph_module, input_info) - mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"] + trace_output = from_fx(graph_module, input_info, keep_params_as_input) + func_name = f"subgraph_{len(mod.get_global_vars())}" + mod[func_name] = trace_output.mod["main"] + if keep_params_as_input: + params_ndarray[func_name] = trace_output.params["main"] return graph_module.forward dynamo.reset() compiled_model = torch.compile(model, backend=_capture) compiled_model(*params) - return mod + return ImporterOutput(mod, params_ndarray) @functools.lru_cache(None) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index b580e1679b90..a73bc9d0db8c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -24,6 +24,8 @@ import tvm from tvm import relax +from ..common import ImporterOutput + class TorchFXImporter: """An importer from PyTorch FX to Relax.""" @@ -843,7 +845,7 @@ def create_convert_map(self): def from_fx( self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool - ) -> tvm.IRModule: + ) -> ImporterOutput: """Convert a PyTorch FX GraphModule to a Relax program.""" from torch import fx @@ -860,18 +862,23 @@ def from_fx( ) # Initialize the block builder with a function and a dataflow block. + func_name = "main" self.block_builder = relax.BlockBuilder() if keep_params_as_input: + params_ = [] func_attrs = {"num_input": len(inputs)} for name, param in model.named_parameters(): shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) self.params[param] = inputs[-1] + params_.append(tvm.nd.array(param.data.cpu().numpy())) + params = {func_name: params_} else: + params = None func_attrs = None - with self.block_builder.function(name="main", params=inputs.copy(), attrs=func_attrs): + with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): output = None with self.block_builder.dataflow(): # Translate model parameters. @@ -916,12 +923,12 @@ def from_fx( assert output is not None self.block_builder.emit_func_output(output) - return self.block_builder.get() + return ImporterOutput(self.block_builder.get(), params) def from_fx( model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False -) -> tvm.IRModule: +) -> ImporterOutput: """Convert a PyTorch FX GraphModule to a Relax program Parameters @@ -937,8 +944,9 @@ def from_fx( Returns ------- - module : tvm.IRModule - The converted Relax program. + output : ImporterOutput + The output of translation, including the translated IRModule, and + the weights of the input model when `keep_params_as_input` is true. Examples -------- @@ -981,7 +989,7 @@ def forward(self, input): raise RuntimeError("Failed to export the PyTorch model to FX.") # Use the importer to import the PyTorch model to Relax. 
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index b47e3e22bd71..14d1e48fb5ec 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -147,7 +147,7 @@ def subgraph_0(
         return gv
 
     model = Input1()
-    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
+    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)).mod
     binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
@@ -190,7 +190,7 @@ def subgraph_1(
         R.output(gv1)
         return gv1
 
-    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
+    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)).mod
     tvm.ir.assert_structural_equal(mod, Expected2)
 
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9ab0b3304c0d..e28483dc2fab 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -27,7 +27,7 @@ def verify_model(torch_model, input_info, binding, expected, keep_params_as_inpu
     from tvm.relax.frontend.torch import from_fx
 
     graph_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input)
+    mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input).mod
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
     tvm.ir.assert_structural_equal(mod, expected)
@@ -2096,7 +2096,9 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype=
 @tvm.testing.requires_gpu
 def test_keep_params():
     import torch
+    from torch import fx
     from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
 
     class Conv2D1(Module):
         def __init__(self):
@@ -2135,8 +2137,19 @@ def main(
         return gv
 
     model = Conv2D1()
-    input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
+    graph_model = fx.symbolic_trace(model)
+    trace_output = from_fx(graph_model, [([1, 3, 10, 10], "float32")], keep_params_as_input=True)
+    tvm.ir.assert_structural_equal(trace_output.mod, expected1)
+    func = trace_output.mod["main"]
+    params = trace_output.params["main"]
+
+    assert len(params) == len(func.params) - 1
+    for param_var, param_ndarray in zip(func.params[1:], params):
+        assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
+        assert param_var.struct_info.dtype == param_ndarray.dtype
+
+    tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
+    tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
 
 
 @tvm.testing.requires_gpu
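
The kept weights can also be folded back into the module with the same `BindParams` pattern these tests use; `bind_kept_params` below is a hypothetical helper, sketched under the assumption that the `num_input` function attribute written by the translator is present:

from tvm import relax

def bind_kept_params(output, func_name="main"):
    """Bind an ImporterOutput's kept weights back into its module."""
    func = output.mod[func_name]
    num_input = int(func.attrs["num_input"])
    # Weight order matches the kept Relax input variables, so zip is safe.
    binding = {
        param.name_hint: ndarray
        for param, ndarray in zip(func.params[num_input:], output.params[func_name])
    }
    return relax.transform.BindParams(func_name, binding)(output.mod)
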