[Unity][Frontend] FX translator supporting more ops (#14196)
This PR improves the Torch FX translator in the following ways (a usage sketch follows the list):
* support the unary ops `sigmoid` and `round`,
* support the in-place ops `fill_`, `triu_` and `tril_`,
* support `tensor`, `arange` and `empty`,
* support `bmm` (batch matrix multiplication),
* support `astype`,
* support `chunk` and `squeeze`.
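
For context, a minimal sketch of how the newly supported ops reach the translator, assuming the unity branch's `from_fx` entry point and its list-of-(shape, dtype) input convention; `SmallModel` and the shapes are made up for illustration:

    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class SmallModel(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            a = torch.sigmoid(x)                 # newly supported unary op
            b = torch.round(a)                   # newly supported unary op
            c = torch.bmm(b, y)                  # batch matrix multiplication
            front, _ = torch.chunk(c, 2, dim=0)  # chunk, translated via relax split
            return torch.squeeze(front, 0)       # squeeze

    graph_model = fx.symbolic_trace(SmallModel())
    # Translate the FX graph to a Relax IRModule; each entry gives one input's shape and dtype.
    mod = from_fx(graph_model, [([2, 4, 4], "float32"), ([2, 4, 4], "float32")])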

This PR also fixes `Embedding`. Previously the translation assumed that the
input to `Embedding` is 1-dimensional, and threw an exception when the input
had more than one dimension (i.e., was batched). This PR adds support for
batched input.
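
The batched case is handled by flattening the indices, gathering rows along axis 0, and restoring the batch shape. A plain-PyTorch sketch of the same flatten-take-reshape trick (the helper name is illustrative):

    import torch

    def batched_embedding_lookup(weight: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # weight: (vocab_size, emb_size); indices: any integer shape, e.g. (batch, seq_len).
        flat = indices.reshape(-1)             # collapse all batch dimensions
        rows = weight.index_select(0, flat)    # gather rows, like relax.op.take with axis=0
        return rows.reshape(*indices.shape, weight.shape[-1])  # restore (..., emb_size)

    weight = torch.randn(10, 4)
    idx = torch.tensor([[1, 2], [3, 4]])
    expected = torch.nn.Embedding.from_pretrained(weight)(idx)
    assert torch.equal(batched_embedding_lookup(weight, idx), expected)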
MasterJH5574 authored Mar 5, 2023
1 parent 0c64959 commit 70ea70f
Showing 2 changed files with 496 additions and 13 deletions.
165 changes: 155 additions & 10 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -63,13 +63,15 @@ def _convert_data_type(input_type):
         """converts the PyTorch scalar type input_type to a TVM dtype."""
         import torch  # type: ignore
 
-        input_type = input_type.lower()
+        input_type = input_type.lower() if isinstance(input_type, str) else input_type
         if input_type in ["float", "float32", "torch.float32", torch.float32]:
             return "float32"
         elif input_type in ["float16", "torch.float16", torch.float16]:
             return "float16"
         elif input_type in ["int64", "torch.int64", torch.int64]:
             return "int64"
+        elif input_type in ["int32", "torch.int32", torch.int32]:
+            return "int32"
         else:
             raise NotImplementedError("input_type {} is not handled yet".format(input_type))
 
@@ -134,12 +136,21 @@ def _cos(self, node: fx.node.Node) -> relax.Var:
     def _sin(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
 
+    def _sigmoid(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+
     def _sqrt(self, node: fx.node.Node) -> relax.Expr:
         arg = self.env[node.args[0]]
         if isinstance(arg, (int, float)):
             arg = relax.const(arg, "float32")
         return self.block_builder.emit(relax.op.sqrt(arg))
 
+    def _round(self, node: fx.node.Node) -> relax.Expr:
+        if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
     def _add(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -200,11 +211,93 @@ def _lt(self, node: fx.node.Node) -> relax.Expr:
 
     ########## Creation ##########
 
-    def _tril(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        k = node.args[1] if len(node.args) > 1 else 0
-        assert isinstance(k, int)
-        return self.block_builder.emit(relax.op.create.tril(x, k))
+    def _arange(self, node: fx.node.Node) -> relax.Var:
+        import torch
+        import numpy as np
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+
+        return relax.const(np.arange(*start_end_step, dtype=dtype))
+
+    def _empty(self, node: fx.node.Node) -> relax.Var:
+        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+
+    def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
+    def _tensor(self, node: fx.node.Node) -> relax.Var:
+        dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
+        if isinstance(node.args[0], float):
+            return relax.const(node.args[0], dtype if dtype is not None else "float64")
+        elif isinstance(node.args[0], int):
+            return relax.const(node.args[0], dtype if dtype is not None else "int64")
+        raise ValueError("torch.tensor with value not a float or int is not accepted")
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
+    def _inplace_tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+
+            mutated = self.block_builder.emit(op(x, k))
+            self.env[node.args[0]] = mutated
+            return mutated
+
+        return convert
 
     def _new_ones(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -238,8 +331,9 @@ def _half(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
 
     def _type(self, node: fx.node.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+        x = self.env[node.args[0]]
+        dtype = self._convert_data_type(node.args[1])
+        return self.block_builder.emit(relax.op.astype(x, dtype))
 
     ########## Linear Algebra ##########
 
@@ -313,12 +407,35 @@ def _split(self, node: fx.node.Node) -> relax.Var:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
+    def _chunk(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        chunks = node.args[1]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 2:
+            dim = node.args[2]
+        else:
+            dim = 0
+        return self.block_builder.emit(relax.op.split(x, chunks, dim))
+
     def _transpose(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
 
+    def _squeeze(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 1:
+            dim = node.args[1]
+        else:
+            dim = None
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
@@ -521,7 +638,16 @@ def _embedding(self, node: fx.node.Node) -> relax.Var:
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         x = self.block_builder.emit(relax.op.astype(x, "int32"))
-        return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+        ndim = x.struct_info.ndim
+        if ndim == 1:
+            return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+        else:
+            x_shape = x.struct_info.shape.values
+            emb_size = weight.struct_info.shape.values[-1]
+            x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+            embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+            return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _interpolate(self, node: fx.node.Node) -> relax.Var:
         # torch.nn.functional.interpolate(
@@ -620,13 +746,17 @@ def _getitem(self, node: fx.node.Node) -> relax.Var:
             while i < len(shape):
                 begin.append(0)
                 end.append(shape[i])
+                stride.append(1)
                 axes.append(i)
                 i = i + 1
             sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
             sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
             return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+        elif isinstance(x, relax.Constant):
+            dtype = x.struct_info.dtype
+            return relax.const(x.data.numpy()[node.args[1]], dtype)
         else:
             assert False
 
@@ -660,31 +790,45 @@ def create_convert_map(self):
             "mul": self._mul,
             "sub": self._sub,
             "pow": self._pow,
+            "sigmoid": self._sigmoid,
             "sqrt": self._sqrt,
+            "round": self._round,
             "lt": self._lt,
             "truediv": self._truediv,
+            "fill_": self._inplace_fill,
             "new_ones": self._new_ones,
-            "tril": self._tril,
+            "arange": self._arange,
+            "empty": self._empty,
+            "tensor": self._tensor,
+            "tril": self._tril_triu(relax.op.tril),
+            "triu": self._tril_triu(relax.op.triu),
+            "tril_": self._inplace_tril_triu(relax.op.tril),
+            "triu_": self._inplace_tril_triu(relax.op.triu),
             "sum": self._sum,
             "float": self._float,
             "half": self._half,
             "type": self._type,
+            "astype": self._type,
             "matmul": self._matmul,
             "addmm": self._addmm,
+            "bmm": self._matmul,
             "cat": self._cat,
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
             "reshape": self._reshape,
             "split": self._split,
+            "chunk": self._chunk,
             "transpose": self._transpose,
+            "squeeze": self._squeeze,
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
             "view": self._reshape,
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
             "softmax": self._softmax,
             "dropout": lambda node: self.env[node.args[0]],
             "clamp": self._clamp,
             "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
             "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
@@ -693,6 +837,7 @@ def create_convert_map(self):
             "getattr": self._getattr,
             "getitem": self._getitem,
             "contiguous": lambda node: self.env[node.args[0]],
+            "to": lambda node: self.env[node.args[0]],
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
         }
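
Note on the in-place handlers (`fill_`, `tril_`, `triu_`): Relax is a functional IR, so mutation is emulated by emitting a fresh value and rebinding the FX node's entry in `self.env`; later uses of the "mutated" tensor then pick up the new value. A toy, TVM-free sketch of that rebinding idea:

    # env maps FX values (here just names) to their current IR value.
    env = {"x": [[1, 2], [3, 4]]}

    def inplace_fill(env, name, value):
        # Build a new value instead of mutating the old one...
        filled = [[value] * len(row) for row in env[name]]
        # ...then rebind, so downstream consumers read the filled value.
        env[name] = filled
        return filled

    inplace_fill(env, "x", 0)
    print(env["x"])  # [[0, 0], [0, 0]]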

