diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 93f97829bc..5989b9017f 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -125,7 +125,7 @@ def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": return _op_ffi_api.astype(self, dtype) # type: ignore def __neg__(self) -> "ExprWithOp": - raise ValueError("relax.negative is not supported yet.") + return _op_ffi_api.negative(self) # type: ignore def __lt__(self, other: Expr) -> "ExprWithOp": return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py index 3a5697d3e3..65772baadf 100644 --- a/tests/python/relax/test_op.py +++ b/tests/python/relax/test_op.py @@ -62,6 +62,7 @@ def _check_call(expr, op_name: str): _check_call(x <= y, "less_equal") # Arithmetic operators + _check_call(-x, "negative") _check_call(x + y, "add") _check_call(x - y, "subtract") _check_call(x * y, "multiply") diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index bb28674ba0..68cf0ea3b9 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -883,11 +883,12 @@ def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined def test_arith_operators(): @R.function def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): - a0 = x + y - a1 = x - y - a2 = x * y - a3 = x / y - a4 = x // y + a0 = -x + a1 = x + y + a2 = x - y + a3 = x * y + a4 = x / y + a5 = x // y c0 = x > y c1 = x < y @@ -898,7 +899,7 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): t0 = tuple_expr[0] t1 = tuple_expr[1] t2 = tuple_expr[0][0] # <= Will normalize to two bindings - return a0, a1, a2, a3, a4, c0, c1, c2, c3, t0, t1, t2 + return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2 m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -906,11 +907,12 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - a0 = bb.emit(relax.op.add(x, y)) - a1 = bb.emit(relax.op.subtract(x, y)) - a2 = bb.emit(relax.op.multiply(x, y)) - a3 = bb.emit(relax.op.divide(x, y)) - a4 = bb.emit(relax.op.floor_divide(x, y)) + a0 = bb.emit(relax.op.negative(x)) + a1 = bb.emit(relax.op.add(x, y)) + a2 = bb.emit(relax.op.subtract(x, y)) + a3 = bb.emit(relax.op.multiply(x, y)) + a4 = bb.emit(relax.op.divide(x, y)) + a5 = bb.emit(relax.op.floor_divide(x, y)) c0 = bb.emit(relax.op.greater(x, y)) c1 = bb.emit(relax.op.less(x, y)) @@ -922,7 +924,7 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1)) tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0)) t2 = bb.emit(relax.TupleGetItem(tmp, 0)) - bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, c0, c1, c2, c3, t0, t1, t2))) + bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2))) _check(foo, bb.get()["foo"])