[Unity][Frontend] FX translator returning weights with `keep_params_as_input` (#14197)

PR #14067 introduced the flag `keep_params_as_input` to the FX translator, so that
the model weights can be handled outside of the translated Relax function.

This PR takes a further step by returning the model weights as NDArrays when
`keep_params_as_input` is true. With this change, the translator can return the
weights upon request; otherwise, the weights of the given PyTorch model would be
lost after the import.
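
Below is a minimal usage sketch of how the returned weights can be consumed after this change (the `TinyModel` module and the shapes are made up for illustration and are not part of this commit):

```python
import torch
from torch import fx, nn

from tvm.relax.frontend.torch import from_fx


class TinyModel(nn.Module):
    """A stand-in model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


graph_module = fx.symbolic_trace(TinyModel())
output = from_fx(graph_module, [([1, 4], "float32")], keep_params_as_input=True)

mod = output.mod                 # translated IRModule; weights stay as function inputs
weights = output.params["main"]  # list of tvm.nd.NDArray, ordered like the kept inputs
print(len(weights))              # 2: the linear layer's weight and bias
```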
MasterJH5574 committed Mar 5, 2023
1 parent 70ea70f commit 58db106
Showing 6 changed files with 107 additions and 24 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/__init__.py
@@ -17,3 +17,5 @@
"""
Frontends for constructing Relax programs, with the model importers
"""
from . import torch
from .common import ImporterOutput
48 changes: 48 additions & 0 deletions python/tvm/relax/frontend/common.py
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Commons for Relax frontend."""
from typing import Dict, List, Optional

import tvm


class ImporterOutput:
"""The data structure representing the result of frontend imports.
Attributes
----------
mod : tvm.IRModule
The IRModule imported from frontend.
params : Optional[Dict[str, List[tvm.nd.NDArray]]]
The weights of the imported model, when the weights of the model are
requested to be kept as parameters of functions in the IRModule. (e.g.,
when the `keep_params_as_input` flag of `frontend.torch.from_fx` is set to
True.)
- `params` is defined to be None when not requested.
- The keys of `params` are the names of the Relax functions in the IRModule.
- Each weight tensor is in the form of TVM NDArray on device CPU.
- The order of the returned weights is in accordance with the order of
the kept Relax function input variables.
"""

mod: tvm.IRModule
params: Optional[Dict[str, List[tvm.nd.NDArray]]]

def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str, List[tvm.nd.NDArray]]]):
self.mod = mod
self.params = params
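
As a rough sketch of how `params` lines up with the kept function inputs, assuming `output` is an `ImporterOutput` obtained from a `from_fx(..., keep_params_as_input=True)` call:

```python
# Assuming `output` was produced by, e.g.,
#   output = from_fx(graph_module, input_info, keep_params_as_input=True)
# the first `num_input` function parameters are the real model inputs; the kept
# weights follow them, in the same order as the NDArrays in `output.params`.
func = output.mod["main"]
num_input = int(func.attrs["num_input"])
for var, ndarray in zip(func.params[num_input:], output.params["main"]):
    print(var.name_hint, ndarray.shape, ndarray.dtype)
```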
28 changes: 20 additions & 8 deletions python/tvm/relax/frontend/torch/dynamo.py
@@ -24,7 +24,9 @@

import tvm
from tvm.relax import build as relax_build
from tvm.relax.frontend.torch.fx_translator import from_fx

from .fx_translator import from_fx
from ..common import ImporterOutput


def device_from_inputs(example_inputs):
@@ -72,7 +74,7 @@ def to_tvm_tensor(torch_tensor):

device = device_from_inputs(example_inputs)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
mod = from_fx(graph_module, input_info)
mod = from_fx(graph_module, input_info).mod

if device.type == "cuda":
dev = tvm.cuda(device.index)
@@ -114,7 +116,7 @@ def exec_tvm(*i_args):
return _relax_backend


def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
def dynamo_capture_subgraphs(model, *params, **kwargs) -> ImporterOutput:
"""Capture subgraphs of the PyTorch model using torch.compile into an IRModule.
Parameters
@@ -125,28 +127,38 @@ def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
params : List[torch.Tensor]
The parameters of the PyTorch model.
keep_params_as_input : bool
Whether to keep model parameters as input variables of the captured Relax functions.
Returns
-------
mod : tvm.ir.IRModule
The IRModule that contains captured subgraphs.
output : ImporterOutput
The output of translation, including the translated IRModule, and
the weights of the input model when `keep_params_as_input` is true.
"""
import torch # type: ignore[import]
from torch import fx # type: ignore[import]
from torch import _dynamo as dynamo # type: ignore[import]

keep_params_as_input = "keep_params_as_input" in kwargs and kwargs["keep_params_as_input"]

mod = tvm.IRModule()
params_ndarray = dict() if keep_params_as_input else None

def _capture(graph_module: fx.GraphModule, example_inputs):
assert isinstance(graph_module, torch.fx.GraphModule)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
subgraph = from_fx(graph_module, input_info)
mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"]
trace_output = from_fx(graph_module, input_info, keep_params_as_input)
func_name = f"subgraph_{len(mod.get_global_vars())}"
mod[func_name] = trace_output.mod["main"]
if keep_params_as_input:
params_ndarray[func_name] = trace_output.params["main"]
return graph_module.forward

dynamo.reset()
compiled_model = torch.compile(model, backend=_capture)
compiled_model(*params)
return mod
return ImporterOutput(mod, params_ndarray)


@functools.lru_cache(None)
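
A hedged sketch of the updated `dynamo_capture_subgraphs` path; the `torch.nn.Linear` stand-in and the input shapes below are assumptions for illustration:

```python
import torch

from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs

model = torch.nn.Linear(16, 8)  # stand-in model, not from this PR
output = dynamo_capture_subgraphs(model, torch.randn(2, 16), keep_params_as_input=True)

# Each captured Relax function gets its own entry in `output.params`,
# keyed by the subgraph name (e.g. "subgraph_0").
for gv in output.mod.get_global_vars():
    name = gv.name_hint
    print(name, [w.shape for w in output.params[name]])
```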
22 changes: 15 additions & 7 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,6 +24,8 @@
import tvm
from tvm import relax

from ..common import ImporterOutput


class TorchFXImporter:
"""An importer from PyTorch FX to Relax."""
@@ -843,7 +845,7 @@ def create_convert_map(self):

def from_fx(
self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool
) -> tvm.IRModule:
) -> ImporterOutput:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx

@@ -860,18 +862,23 @@ def from_fx(
)

# Initialize the block builder with a function and a dataflow block.
func_name = "main"
self.block_builder = relax.BlockBuilder()
if keep_params_as_input:
params_ = []
func_attrs = {"num_input": len(inputs)}
for name, param in model.named_parameters():
shape = param.data.shape
dtype = self._convert_data_type(str(param.data.dtype))
inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype)))
self.params[param] = inputs[-1]
params_.append(tvm.nd.array(param.data.cpu().numpy()))
params = {func_name: params_}
else:
params = None
func_attrs = None

with self.block_builder.function(name="main", params=inputs.copy(), attrs=func_attrs):
with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs):
output = None
with self.block_builder.dataflow():
# Translate model parameters.
Expand Down Expand Up @@ -916,12 +923,12 @@ def from_fx(
assert output is not None
self.block_builder.emit_func_output(output)

return self.block_builder.get()
return ImporterOutput(self.block_builder.get(), params)


def from_fx(
model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False
) -> tvm.IRModule:
) -> ImporterOutput:
"""Convert a PyTorch FX GraphModule to a Relax program
Parameters
Expand All @@ -937,8 +944,9 @@ def from_fx(
Returns
-------
module : tvm.IRModule
The converted Relax program.
output : ImporterOutput
The output of translation, including the translated IRModule, and
the weights of the input model when `keep_params_as_input` is true.
Examples
--------
@@ -981,7 +989,7 @@ def forward(self, input):
raise RuntimeError("Failed to export the PyTorch model to FX.")
# Use the importer to import the PyTorch model to Relax.
mod: tvm.IRModule = from_fx(graph_module, input_info)
mod: tvm.IRModule = from_fx(graph_module, input_info).mod
# Print out the imported model.
print(mod.script())
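
One way the returned weights could be folded back into the module, e.g. for standalone compilation, is via `BindParams`; a sketch assuming `trace_output` came from `from_fx(..., keep_params_as_input=True)`:

```python
from tvm import relax

# Assuming: trace_output = from_fx(graph_module, input_info, keep_params_as_input=True)
func = trace_output.mod["main"]
num_input = int(func.attrs["num_input"])
# Map each kept input variable's name to its NDArray and bind them as constants.
binding = {
    var.name_hint: ndarray
    for var, ndarray in zip(func.params[num_input:], trace_output.params["main"])
}
bound_mod = relax.transform.BindParams("main", binding)(trace_output.mod)
```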
4 changes: 2 additions & 2 deletions tests/python/relax/test_frontend_dynamo.py
@@ -147,7 +147,7 @@ def subgraph_0(
return gv

model = Input1()
mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)).mod
binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
@@ -190,7 +190,7 @@ def subgraph_1(
R.output(gv1)
return gv1

mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)).mod
tvm.ir.assert_structural_equal(mod, Expected2)


27 changes: 20 additions & 7 deletions tests/python/relax/test_frontend_from_fx.py
@@ -27,7 +27,7 @@ def verify_model(torch_model, input_info, binding, expected, keep_params_as_inpu
from tvm.relax.frontend.torch import from_fx

graph_model = fx.symbolic_trace(torch_model)
mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input)
mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input).mod
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
tvm.ir.assert_structural_equal(mod, expected)
Expand Down Expand Up @@ -1683,7 +1683,7 @@ def forward(self, input):
return torch.arange(0, 20, dtype=torch.int32)

graph_model = fx.symbolic_trace(Arange())
mod = from_fx(graph_model, [([10, 10], "float32")])
mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1707,7 +1707,7 @@ def forward(self, input):
return torch.empty((10, 10), dtype=torch.float32)

graph_model = fx.symbolic_trace(Empty())
mod = from_fx(graph_model, [([10, 10], "float32")])
mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1734,15 +1734,15 @@ def forward(self, input):
return torch.tensor(3)

graph_model1 = fx.symbolic_trace(Empty1())
mod1 = from_fx(graph_model1, [([10, 10], "float32")])
mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
assert len(mod1["main"].body.blocks) == 1
assert len(mod1["main"].body.blocks[0].bindings) == 1
assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant)
assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == ()
assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"

graph_model2 = fx.symbolic_trace(Empty2())
mod2 = from_fx(graph_model2, [([10, 10], "float32")])
mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
assert len(mod2["main"].body.blocks) == 1
assert len(mod2["main"].body.blocks[0].bindings) == 1
assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -2096,7 +2096,9 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype=
@tvm.testing.requires_gpu
def test_keep_params():
import torch
from torch import fx
from torch.nn import Module
from tvm.relax.frontend.torch import from_fx

class Conv2D1(Module):
def __init__(self):
@@ -2135,8 +2137,19 @@ def main(
return gv

model = Conv2D1()
input_info = [([1, 3, 10, 10], "float32")]
verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
graph_model = fx.symbolic_trace(model)
trace_output = from_fx(graph_model, [([1, 3, 10, 10], "float32")], keep_params_as_input=True)
tvm.ir.assert_structural_equal(trace_output.mod, expected1)
func = trace_output.mod["main"]
params = trace_output.params["main"]

assert len(params) == len(func.params) - 1
for param_var, param_ndarray in zip(func.params[1:], params):
assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
assert param_var.struct_info.dtype == param_ndarray.dtype

tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())


@tvm.testing.requires_gpu
