diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 1d132c855ed9..b580e1679b90 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -63,13 +63,15 @@ def _convert_data_type(input_type):
         """converts the PyTorch scalar type input_type to a TVM dtype."""
         import torch  # type: ignore
 
-        input_type = input_type.lower()
+        input_type = input_type.lower() if isinstance(input_type, str) else input_type
         if input_type in ["float", "float32", "torch.float32", torch.float32]:
             return "float32"
         elif input_type in ["float16", "torch.float16", torch.float16]:
             return "float16"
         elif input_type in ["int64", "torch.int64", torch.int64]:
             return "int64"
+        elif input_type in ["int32", "torch.int32", torch.int32]:
+            return "int32"
         else:
             raise NotImplementedError("input_type {} is not handled yet".format(input_type))
 
@@ -134,12 +136,21 @@ def _cos(self, node: fx.node.Node) -> relax.Var:
     def _sin(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
 
+    def _sigmoid(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+
     def _sqrt(self, node: fx.node.Node) -> relax.Expr:
         arg = self.env[node.args[0]]
         if isinstance(arg, (int, float)):
             arg = relax.const(arg, "float32")
         return self.block_builder.emit(relax.op.sqrt(arg))
 
+    def _round(self, node: fx.node.Node) -> relax.Expr:
+        if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
     def _add(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -200,11 +211,93 @@ def _lt(self, node: fx.node.Node) -> relax.Expr:
 
     ########## Creation ##########
 
-    def _tril(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        k = node.args[1] if len(node.args) > 1 else 0
-        assert isinstance(k, int)
-        return self.block_builder.emit(relax.op.create.tril(x, k))
+    def _arange(self, node: fx.node.Node) -> relax.Var:
+        import torch
+        import numpy as np
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+
+        return relax.const(np.arange(*start_end_step, dtype=dtype))
+
+    def _empty(self, node: fx.node.Node) -> relax.Var:
+        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+
+    def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
+    def _tensor(self, node: fx.node.Node) -> relax.Var:
+        dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
+        if isinstance(node.args[0], float):
+            return relax.const(node.args[0], dtype if dtype is not None else "float64")
+        elif isinstance(node.args[0], int):
+            return relax.const(node.args[0], dtype if dtype is not None else "int64")
+        raise ValueError("torch.tensor with value not a float or int is not accepted")
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
+    def _inplace_tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+
+            mutated = self.block_builder.emit(op(x, k))
+            self.env[node.args[0]] = mutated
+            return mutated
+
+        return convert
 
     def _new_ones(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -238,8 +331,9 @@ def _half(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
 
     def _type(self, node: fx.node.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+        x = self.env[node.args[0]]
+        dtype = self._convert_data_type(node.args[1])
+        return self.block_builder.emit(relax.op.astype(x, dtype))
 
     ########## Linear Algebra ##########
 
@@ -313,12 +407,35 @@ def _split(self, node: fx.node.Node) -> relax.Var:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
+    def _chunk(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        chunks = node.args[1]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 2:
+            dim = node.args[2]
+        else:
+            dim = 0
+        return self.block_builder.emit(relax.op.split(x, chunks, dim))
+
     def _transpose(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
 
+    def _squeeze(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 1:
+            dim = node.args[1]
+        else:
+            dim = None
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
@@ -521,7 +638,16 @@ def _embedding(self, node: fx.node.Node) -> relax.Var:
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         x = self.block_builder.emit(relax.op.astype(x, "int32"))
-        return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+        ndim = x.struct_info.ndim
+        if ndim == 1:
+            return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+        else:
+            x_shape = x.struct_info.shape.values
+            emb_size = weight.struct_info.shape.values[-1]
+            x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+            embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+            return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _interpolate(self, node: fx.node.Node) -> relax.Var:
         # torch.nn.functional.interpolate(
@@ -620,6 +746,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var:
             while i < len(shape):
                 begin.append(0)
                 end.append(shape[i])
+                stride.append(1)
                 axes.append(i)
                 i = i + 1
             sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
@@ -627,6 +754,9 @@ def _getitem(self, node: fx.node.Node) -> relax.Var:
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
             return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+        elif isinstance(x, relax.Constant):
+            dtype = x.struct_info.dtype
+            return relax.const(x.data.numpy()[node.args[1]], dtype)
         else:
             assert False
 
@@ -660,24 +790,37 @@ def create_convert_map(self):
             "mul": self._mul,
             "sub": self._sub,
             "pow": self._pow,
+            "sigmoid": self._sigmoid,
             "sqrt": self._sqrt,
+            "round": self._round,
             "lt": self._lt,
             "truediv": self._truediv,
+            "fill_": self._inplace_fill,
             "new_ones": self._new_ones,
-            "tril": self._tril,
+            "arange": self._arange,
+            "empty": self._empty,
+            "tensor": self._tensor,
+            "tril": self._tril_triu(relax.op.tril),
+            "triu": self._tril_triu(relax.op.triu),
+            "tril_": self._inplace_tril_triu(relax.op.tril),
+            "triu_": self._inplace_tril_triu(relax.op.triu),
             "sum": self._sum,
             "float": self._float,
             "half": self._half,
             "type": self._type,
+            "astype": self._type,
             "matmul": self._matmul,
             "addmm": self._addmm,
+            "bmm": self._matmul,
             "cat": self._cat,
            "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
             "reshape": self._reshape,
             "split": self._split,
+            "chunk": self._chunk,
             "transpose": self._transpose,
+            "squeeze": self._squeeze,
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
@@ -685,6 +828,7 @@ def create_convert_map(self):
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
             "softmax": self._softmax,
+            "dropout": lambda node: self.env[node.args[0]],
             "clamp": self._clamp,
             "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
             "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
@@ -693,6 +837,7 @@ def create_convert_map(self):
             "getattr": self._getattr,
             "getitem": self._getitem,
             "contiguous": lambda node: self.env[node.args[0]],
+            "to": lambda node: self.env[node.args[0]],
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
         }
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 84fc97be27dc..9ab0b3304c0d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -222,6 +222,45 @@ def main(
     )
 
 
+@tvm.testing.requires_gpu
+def test_bmm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class BMM(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.bmm(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((4, 128, 256), dtype="float32"),
+            input_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tensor((4, 128, 512), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tensor((4, 128, 512), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        BMM(),
+        [((4, 128, 256), "float32"), ((4, 256, 512), "float32")],
+        {},
+        Expected,
+    )
+
+
 @tvm.testing.requires_gpu
 def test_relu():
     import torch
@@ -576,7 +615,7 @@ def test_dropout():
 
     input_info = [([1, 3, 10, 10], "float32")]
 
-    class Dropout(Module):
+    class Dropout1(Module):
         def __init__(self):
             super().__init__()
             self.dropout = torch.nn.Dropout(0.5)
@@ -584,6 +623,10 @@ def __init__(self):
         def forward(self, input):
             return self.dropout(input)
 
+    class Dropout2(Module):
+        def forward(self, input):
+            return torch.dropout(input, 0.5, train=True)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -596,7 +639,8 @@ def main(
                 R.output(gv)
             return gv
 
-    verify_model(Dropout(), input_info, {}, expected1)
+    verify_model(Dropout1(), input_info, {}, expected1)
+    verify_model(Dropout2(), input_info, {}, expected1)
 
 
 @tvm.testing.requires_gpu
@@ -1078,6 +1122,52 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 1
     verify_model(Size(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_squeeze():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([3, 1, 4, 1], "float32")]
+
+    class Squeeze1(Module):
+        def forward(self, input):
+            return input.squeeze(1)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tensor((3, 4, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
+                gv: R.Tensor((3, 4, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Squeeze2(Module):
+        def forward(self, input):
+            return input.squeeze()
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Squeeze1(), input_info, {}, Expected1)
+    verify_model(Squeeze2(), input_info, {}, Expected2)
+
+
 @tvm.testing.requires_gpu
 def test_unsqueeze():
     import torch
@@ -1260,6 +1350,46 @@ def main(
 
     verify_model(Sqrt(), input_info, {}, expected3)
 
+    # sigmoid
+    class Sigmoid(Module):
+        def forward(self, input):
+            return torch.sigmoid(input)
+
+    @tvm.script.ir_module
+    class expected4:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sigmoid(), input_info, {}, expected4)
+
+    # round
+    class Round(Module):
+        def forward(self, input):
+            return torch.round(input)
+
+    @tvm.script.ir_module
+    class expected5:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Round(), input_info, {}, expected5)
+
 
 @tvm.testing.requires_gpu
 def test_gelu():
@@ -1467,6 +1597,159 @@ def main(
     verify_model(Split(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_chunk():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Chunk(Module):
+        def forward(self, input):
+            return torch.chunk(input, 3, dim=1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Chunk(), input_info, {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_inplace_fill():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class InplaceFill(Module):
+        def forward(self, input):
+            input.fill_(1.5)
+            return input
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_arange():
+    import numpy as np
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Arange(Module):
+        def forward(self, input):
+            return torch.arange(0, 20, dtype=torch.int32)
+
+    graph_model = fx.symbolic_trace(Arange())
+    mod = from_fx(graph_model, [([10, 10], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(), np.arange(0, 20, dtype="int32")
+    )
+
+
+@tvm.testing.requires_gpu
+def test_empty():
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Empty(Module):
+        def forward(self, input):
+            return torch.empty((10, 10), dtype=torch.float32)
+
+    graph_model = fx.symbolic_trace(Empty())
+    mod = from_fx(graph_model, [([10, 10], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod["main"].body.blocks[0].bindings[0].value.data.shape == (10, 10)
+    assert mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
+
+
+@tvm.testing.requires_gpu
+def test_tensor():
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Empty1(Module):
+        def forward(self, input):
+            return torch.tensor(3, dtype=torch.float32)
+
+    class Empty2(Module):
+        def forward(self, input):
+            return torch.tensor(3)
+
+    graph_model1 = fx.symbolic_trace(Empty1())
+    mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+    assert len(mod1["main"].body.blocks) == 1
+    assert len(mod1["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == ()
+    assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
+
+    graph_model2 = fx.symbolic_trace(Empty2())
+    mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+    assert len(mod2["main"].body.blocks) == 1
+    assert len(mod2["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod2["main"].body.blocks[0].bindings[0].value.data.shape == ()
+    assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64"
+
+
 @tvm.testing.requires_gpu
 def test_tril():
     import torch
@@ -1481,6 +1764,11 @@ class Tril(Module):
         def forward(self, input):
             return torch.tril(input, 1)
 
+    class InplaceTril(Module):
+        def forward(self, input):
+            input.tril_(1)
+            return input
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -1495,6 +1783,43 @@ def main(
             return gv
 
     verify_model(Tril(), input_info, {}, expected1)
+    verify_model(InplaceTril(), input_info, {}, expected1)
+
+
+@tvm.testing.requires_gpu
+def test_triu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([10, 10], "float32")]
+
+    class Triu(Module):
+        def forward(self, input):
+            return torch.triu(input, 1)
+
+    class InplaceTriu(Module):
+        def forward(self, input):
+            input.triu_(1)
+            return input
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Triu(), input_info, {}, expected1)
+    verify_model(InplaceTriu(), input_info, {}, expected1)
 
 
 @tvm.testing.requires_gpu
@@ -1589,7 +1914,7 @@ def main(
 
 
 @tvm.testing.requires_gpu
-def test_to():
+def test_datatype():
     import torch
     from torch.nn import Module
 
@@ -1638,6 +1963,19 @@ def main(
 
     verify_model(ToHalf(), input_info, {}, expected2)
 
+    # type
+    class Type(Module):
+        def forward(self, x):
+            return x.type(torch.float32)
+
+    # astype
+    class AsType(Module):
+        def forward(self, x):
+            return x.astype(torch.float32)
+
+    verify_model(Type(), input_info, {}, expected1)
+    verify_model(AsType(), input_info, {}, expected1)
+
 
 @tvm.testing.requires_gpu
 def test_permute():