Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[UX][TVMScript] Overload __neg__ for relax expr (#434)
Browse files Browse the repository at this point in the history
This PR overloads `__neg__` given that `relax.negative` is now supported and adds corresponding tests.
  • Loading branch information
SiriusNEO authored Feb 13, 2023
1 parent c9fb54b commit 5302925
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp":
return _op_ffi_api.astype(self, dtype) # type: ignore

def __neg__(self) -> "ExprWithOp":
raise ValueError("relax.negative is not supported yet.")
return _op_ffi_api.negative(self) # type: ignore

def __lt__(self, other: Expr) -> "ExprWithOp":
return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore
Expand Down
1 change: 1 addition & 0 deletions tests/python/relax/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _check_call(expr, op_name: str):
_check_call(x <= y, "less_equal")

# Arithmetic operators
_check_call(-x, "negative")
_check_call(x + y, "add")
_check_call(x - y, "subtract")
_check_call(x * y, "multiply")
Expand Down
26 changes: 14 additions & 12 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,11 +883,12 @@ def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined
def test_arith_operators():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
a0 = x + y
a1 = x - y
a2 = x * y
a3 = x / y
a4 = x // y
a0 = -x
a1 = x + y
a2 = x - y
a3 = x * y
a4 = x / y
a5 = x // y

c0 = x > y
c1 = x < y
Expand All @@ -898,19 +899,20 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
t0 = tuple_expr[0]
t1 = tuple_expr[1]
t2 = tuple_expr[0][0] # <= Will normalize to two bindings
return a0, a1, a2, a3, a4, c0, c1, c2, c3, t0, t1, t2
return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = relax.Var("x", relax.TensorStructInfo([m, n], "float32"))
y = relax.Var("y", relax.TensorStructInfo([m, n], "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x, y)):
a0 = bb.emit(relax.op.add(x, y))
a1 = bb.emit(relax.op.subtract(x, y))
a2 = bb.emit(relax.op.multiply(x, y))
a3 = bb.emit(relax.op.divide(x, y))
a4 = bb.emit(relax.op.floor_divide(x, y))
a0 = bb.emit(relax.op.negative(x))
a1 = bb.emit(relax.op.add(x, y))
a2 = bb.emit(relax.op.subtract(x, y))
a3 = bb.emit(relax.op.multiply(x, y))
a4 = bb.emit(relax.op.divide(x, y))
a5 = bb.emit(relax.op.floor_divide(x, y))

c0 = bb.emit(relax.op.greater(x, y))
c1 = bb.emit(relax.op.less(x, y))
Expand All @@ -922,7 +924,7 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1))
tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0))
t2 = bb.emit(relax.TupleGetItem(tmp, 0))
bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, c0, c1, c2, c3, t0, t1, t2)))
bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2)))

_check(foo, bb.get()["foo"])

Expand Down

0 comments on commit 5302925

Please sign in to comment.