From 01be7f5e73fca7981ac3c49713c0ae0d05f41699 Mon Sep 17 00:00:00 2001 From: "Yi-Hsiang (Sean) Lai" Date: Wed, 29 Apr 2020 13:52:51 -0400 Subject: [PATCH] [API][Backend] Fix hcl.print with UInt supported (#184) --- python/heterocl/api.py | 5 +++-- tests/test_api_print.py | 2 +- tests/test_api_print_cases/print_expr.py | 25 +++++++++++++++++++++--- tvm/src/codegen/llvm/codegen_llvm.cc | 7 ++++--- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/heterocl/api.py b/python/heterocl/api.py index 9215291a6..9bffd44fa 100644 --- a/python/heterocl/api.py +++ b/python/heterocl/api.py @@ -384,8 +384,9 @@ def print(vals, format=""): def get_format(val): if isinstance(val, (TensorSlice, Scalar, _expr.Expr)): - if util.get_type(val.dtype)[0] == "int": - return "%d" + if (util.get_type(val.dtype)[0] == "int" + or util.get_type(val.dtype)[0] == "uint"): + return "%lld" else: return "%f" elif isinstance(val, int): diff --git a/tests/test_api_print.py b/tests/test_api_print.py index c76e8be4f..bed5525f0 100644 --- a/tests/test_api_print.py +++ b/tests/test_api_print.py @@ -22,7 +22,7 @@ def test_print_expr(): outputs = get_stdout("print_expr").split("\n") - N = 4 + N = 5 for i in range(0, N): assert outputs[i] == outputs[i+N] diff --git a/tests/test_api_print_cases/print_expr.py b/tests/test_api_print_cases/print_expr.py index 1e6d71db4..6a3f5da29 100644 --- a/tests/test_api_print_cases/print_expr.py +++ b/tests/test_api_print_cases/print_expr.py @@ -20,7 +20,26 @@ def kernel(A): print(hcl_A.asnumpy()[5]) -# case2: float +# case1: uint + +hcl.init(hcl.UInt(4)) + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A[5]) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(20, 30, size=(10,)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print(hcl_A.asnumpy()[5]) + +# case3: float hcl.init(hcl.Float()) @@ -39,7 +58,7 @@ def kernel(A): print("%.4f" % hcl_A.asnumpy()[5]) -# case3: fixed points +# case4: fixed points hcl.init(hcl.UFixed(6, 4)) @@ -58,7 +77,7 @@ def kernel(A): print("%.4f" % hcl_A.asnumpy()[5]) -# case4: two ints +# case5: two ints hcl.init() diff --git a/tvm/src/codegen/llvm/codegen_llvm.cc b/tvm/src/codegen/llvm/codegen_llvm.cc index b81185351..25afb3101 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.cc +++ b/tvm/src/codegen/llvm/codegen_llvm.cc @@ -1414,7 +1414,7 @@ void CodeGenLLVM::VisitStmt_(const Print* op) { values.push_back(MakeValue(v)); types.push_back(v.type()); if (v.type().is_int() || v.type().is_uint()) { - llvm_types.push_back(LLVMType(v.type())); + llvm_types.push_back(t_int64_); } else { llvm_types.push_back(llvm::Type::getDoubleTy(*ctx_)); } @@ -1423,14 +1423,15 @@ void CodeGenLLVM::VisitStmt_(const Print* op) { #if TVM_LLVM_VERSION <= 60 llvm::Function* printf_call = llvm::cast(module_->getOrInsertFunction("printf", call_ftype)); #else - llvm::Function *printf_call = llvm::cast(module_->getOrInsertFunction("printf", call_ftype).getCallee()); + llvm::Function* printf_call = llvm::cast(module_->getOrInsertFunction("printf", call_ftype).getCallee()); #endif std::vector printf_args; std::string format = op->format; printf_args.push_back(builder_->CreateGlobalStringPtr(format)); for (size_t i = 0; i < op->values.size(); i++) { if (types[i].is_int() || types[i].is_uint()) { - printf_args.push_back(values[i]); + llvm::Value* ivalue = CreateCast(types[i], Int(64), values[i]); + printf_args.push_back(ivalue); } else { // fixed or float llvm::Value* fvalue = CreateCast(types[i], Float(64), values[i]); printf_args.push_back(fvalue);