diff --git a/.circleci/install_tvm.sh b/.circleci/install_tvm.sh index 4c603d4bf..e3d00f44b 100644 --- a/.circleci/install_tvm.sh +++ b/.circleci/install_tvm.sh @@ -1,6 +1,7 @@ #!/bin/bash git clone --recursive https://github.com/apache/incubator-tvm tvm-origin cd tvm-origin +git checkout 07ac7712ea2bfecb3c8d21d9358a24c7df585925 mkdir build cp ../.circleci/config.cmake build cd build diff --git a/.gitignore b/.gitignore index 65f3dfcf8..3a5a25a92 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ out *.params *.gz +# Generated files +project diff --git a/hlib/python/hlib/__init__.py b/hlib/python/hlib/__init__.py index 35ba4b106..dd7bf5f0e 100644 --- a/hlib/python/hlib/__init__.py +++ b/hlib/python/hlib/__init__.py @@ -1,3 +1,4 @@ from . import op +from . import ip from . import frontend from . import utils diff --git a/hlib/python/hlib/ip/__init__.py b/hlib/python/hlib/ip/__init__.py new file mode 100644 index 000000000..16411fa56 --- /dev/null +++ b/hlib/python/hlib/ip/__init__.py @@ -0,0 +1 @@ +from .fft import single_fft_hls diff --git a/hlib/python/hlib/ip/fft.py b/hlib/python/hlib/ip/fft.py new file mode 100644 index 000000000..5b5d34bff --- /dev/null +++ b/hlib/python/hlib/ip/fft.py @@ -0,0 +1,98 @@ +import heterocl as hcl +import numpy as np +from heterocl import util +from heterocl.tvm import make as _make +from heterocl.tvm import stmt as _stmt +from heterocl.tvm import ir_pass as _pass +from heterocl.tvm._api_internal import _ExternOp +from heterocl.schedule import Schedule, Stage +from heterocl.mutator import Mutator +from collections import OrderedDict +import os +from hlib.op.extern import * + +dtype = hcl.Int() + +@register_extern_ip(vendor="xilinx") +def single_fft_hls(X_real, X_imag, F_real=None, F_imag=None, name=None): + + if name is None: name = "hls::fft" + L = X_real.shape[0] + assert X_real.shape == X_imag.shape + assert np.log2(L) % 1 == 0, "length must be power of 2: " + str(L) + + # functional behavior + with hcl.Stage("ExternModule") as Module: + num_stages = int(np.log2(L)) + bit_width = int(np.log2(L)) + IndexTable = np.zeros((L), dtype='int') + for i in range(L): + b = '{:0{width}b}'.format(i, width=bit_width) + IndexTable[i] = int(b[::-1], 2) + + return_tensors = False + Table = hcl.copy(IndexTable, "table", dtype=hcl.Int()) + if (F_real is None) and (F_imag is None): + return_tensors = True + F_real = hcl.compute((L,), lambda i: X_real[Table[i]], name='F_real') + F_imag = hcl.compute((L,), lambda i: X_imag[Table[i]], name='F_imag') + else: # use passed-in tensors + hcl.update(F_real, lambda i: X_real[Table[i]], name='F_real_update') + hcl.update(F_imag, lambda i: X_imag[Table[i]], name='F_imag_update') + + with hcl.Stage("Out"): + one = hcl.scalar(1, dtype="int32") + with hcl.for_(0, num_stages) as stage: + DFTpts = one[0] << (stage + 1) + numBF = DFTpts / 2 + e = -2 * np.pi / DFTpts + a = hcl.scalar(0) + with hcl.for_(0, numBF) as j: + c = hcl.scalar(hcl.cos(a[0])) + s = hcl.scalar(hcl.sin(a[0])) + a[0] = a[0] + e + with hcl.for_(j, L + DFTpts - 1, DFTpts) as i: + i_lower = i + numBF + temp_r = hcl.scalar(F_real[i_lower] * c - F_imag[i_lower] * s) + temp_i = hcl.scalar(F_imag[i_lower] * c + F_real[i_lower] * s) + F_real[i_lower] = F_real[i] - temp_r[0] + F_imag[i_lower] = F_imag[i] - temp_i[0] + F_real[i] = F_real[i] + temp_r[0] + F_imag[i] = F_imag[i] + temp_i[0] + + dicts = {} + dicts["name"] = name + tensors = [X_real, X_imag, F_real, F_imag] + dicts["args"] = [(_.name, _.dtype) for _ in tensors] + + # declare headers and typedef + dicts["header"] = """ +#include \"hls_fft.h\" +#include +struct config : hls::ip_fft::params_t { + static const unsigned ordering_opt = hls::ip_fft::natural_order; + static const unsigned config_width = 16; // FFT_CONFIG_WIDTH +}; +typedef std::complex> fxpComplex; +""" + # extern ip function + dicts["func"] = """ + hls::ip_fft::config_t fft_config; + hls::ip_fft::config_t fft_status; + fft_config.setDir(0); + fft_config.setSch(0x2AB); + complex> xn[{}]; + complex> xk[{}]; + for (int i = 0; i < {}; i++) + xn[i] = fxpComplex({}[i], {}[i]); + hls::fft(xn, xk, &fft_config, &fft_status); + for (int i = 0; i < {}; i++) {{ + {}[i] = xk.real(); + {}[i] = xk.imag(); + }} +""".format(L, L, L, X_real.name, X_imag.name, + L, F_real.name, F_imag.name) + + create_extern_module(Module, dicts, ip_type="hls") + if return_tensors: return F_real, F_imag + diff --git a/hlib/python/hlib/op/__init__.py b/hlib/python/hlib/op/__init__.py index 68e1abfee..409279fcc 100644 --- a/hlib/python/hlib/op/__init__.py +++ b/hlib/python/hlib/op/__init__.py @@ -1,3 +1,4 @@ from . import math from . import nn from . import op +from . import extern diff --git a/hlib/python/hlib/op/extern.py b/hlib/python/hlib/op/extern.py new file mode 100644 index 000000000..74c1173ad --- /dev/null +++ b/hlib/python/hlib/op/extern.py @@ -0,0 +1,117 @@ +import heterocl as hcl +import numpy as np +from heterocl import util +from heterocl.tvm import make as _make +from heterocl.tvm import stmt as _stmt +from heterocl.tvm import ir_pass as _pass +from heterocl.tvm._api_internal import _ExternOp +from heterocl.schedule import Schedule, Stage +from heterocl.mutator import Mutator +from collections import OrderedDict +import os + +dtype = hcl.Int() + +class ModuleMarker(Mutator): + """ create extern module used at inner-loop level""" + def __init__(self, axis, info, args): + self.axis = axis + self.info = info + self.args = args + self.count = 0 + self.range_ = {} + self.index_map = {} + + def record_index(self, var, index): + for key, value in self.range_.items(): + if value == 1: + sub = {key : 0} + index = _pass.Simplify(_pass.Substitute(index, sub)) + self.index_map[var] = index + + def mutate_Load(self, node): + buffer_var = self.mutate(node.buffer_var) + index = self.mutate(node.index) + index = util.CastRemover().mutate(index) + self.record_index(buffer_var, index) + + predicate = self.mutate(node.predicate) + return _make.Load(node.dtype, buffer_var, index, predicate) + + def mutate_Store(self, node): + buffer_var = self.mutate(node.buffer_var) + index = self.mutate(node.index) + index = util.CastRemover().mutate(index) + self.record_index(buffer_var, index) + + value = self.mutate(node.value) + predicate = self.mutate(node.predicate) + return _make.Store(buffer_var, value, index, predicate) + + def mutate_For(self, node): + self.count += 1 + loop_var = self.mutate(node.loop_var) + _min = self.mutate(node.min) + extent = self.mutate(node.extent) + self.range_[loop_var] = extent.value - _min.value + body = self.mutate(node.body) + + if (self.count == self.axis): + self.count = 0 + if isinstance(body, _stmt.AttrStmt): + body = body.body + # insert index map + index_map = { k.name : v for k, v in self.index_map.items() } + for i in range(len(self.args)): + self.info["index" + str(i)] = str(index_map[self.args[i]]) + + body = _make.ExternModule( + "rtl", + _make.StringImm("test"), body, + list(self.info.keys()), list(self.info.values())) + + return _make.For(loop_var, _min, extent, node.for_type, node.device_api, body) + + +def register_extern_ip(**attrs): + def with_attrs(f): + for k,v in attrs.items(): + setattr(f, k, v) + return f + return with_attrs + + +# create hls ip invoked within the top function +def create_hls_ip(op, name, args, ip_type="hls", path=None): + # must be called within a superstage + assert Stage._current + curr = Schedule.last_stages[-1] + input_ops = [i._op for i in curr.input_stages] + output_bufs = [curr._buf] + + +# include external ip files +def create_extern_module(op, dicts, ip_type="hls", path=None): + curr = Schedule.last_stages[-1] + input_ops = [i._op for i in curr.input_stages] + input_bufs = [i._buf for i in curr.input_stages] + output_bufs = [curr._buf] + + # input and output arguments + assert "args" in dicts.keys() + annotate_dict = dicts + for name, dtype in dicts["args"]: + annotate_dict["input::" + name] = dtype + del annotate_dict["args"] + + op = op._op.op + assert ip_type in ["rtl", "hls", "host"] + body = _make.ExternModule( + "top", _make.StringImm(ip_type), op.body, + list(annotate_dict.keys()), list(annotate_dict.values())) + + new_op = _ExternOp( + op.name, op.tag, op.axis, + input_ops, input_bufs, output_bufs, body) + curr._op = new_op.output(0) + diff --git a/hlib/tests/test_extern_ip.py b/hlib/tests/test_extern_ip.py new file mode 100644 index 000000000..b4a0831cd --- /dev/null +++ b/hlib/tests/test_extern_ip.py @@ -0,0 +1,92 @@ +import heterocl as hcl +import numpy as np +import numpy.testing as tst +import hlib +import os +from itertools import permutations + +dtype = hcl.Float(64) + +def test_fft_hls(): + + def _test_llvm(length): + hcl.init(hcl.Float()) + X_real = hcl.placeholder((length,), name="X_real") + X_imag = hcl.placeholder((length,), name="X_imag") + + def math_func(A, B): + return hlib.ip.single_fft_hls(A, B) + + s = hcl.create_schedule([X_real, X_imag], math_func) + f = hcl.build(s) + + x_real_np = np.random.random((length)) + x_imag_np = np.random.random((length)) + x_np = x_real_np + 1j * x_imag_np + + out_np = np.fft.fft(x_np) + out_real_np = out_np.real + out_imag_np = out_np.imag + + x_real_hcl = hcl.asarray(x_real_np) + x_imag_hcl = hcl.asarray(x_imag_np) + + out_real_hcl = hcl.asarray(np.zeros((length))) + out_imag_hcl = hcl.asarray(np.zeros((length))) + + f(x_real_hcl, x_imag_hcl, out_real_hcl, out_imag_hcl) + + np.testing.assert_allclose(out_real_np, out_real_hcl.asnumpy(), rtol=1e-02, atol=1e-3) + np.testing.assert_allclose(out_imag_np, out_imag_hcl.asnumpy(), rtol=1e-02, atol=1e-3) + + _test_llvm(32) + _test_llvm(512) + _test_llvm(1024) + + def _test_sim(length): + hcl.init(hcl.Float()) + X_real = hcl.placeholder((length,), name="X_real") + X_imag = hcl.placeholder((length,), name="X_imag") + + def math_func(A, B): + real, imag = hlib.ip.single_fft_hls(A, B) + return hcl.compute((length,), lambda x: + hcl.sqrt(real[x] * real[x] + imag[x] * imag[x]), name="abs") + + s = hcl.create_schedule([X_real, X_imag], math_func) + target = hcl.platform.aws_f1 + target.config(compile="vitis", backend="vhls") + s.to([X_real, X_imag], target.xcel) + s.to(math_func.abs, target.host) + ir = str(hcl.lower(s)) + pattern = "test({}.channel, {}.channel, abs.channel)" + combination = [ pattern.format(*_) + for _ in list(permutations(["X_real", "X_imag"])) ] + assert any([_ in ir for _ in combination]) + # f = hcl.build(s, target) + + # x_real_np = np.random.random((length)) + # x_imag_np = np.random.random((length)) + # x_np = x_real_np + 1j * x_imag_np + # + # out_np = np.fft.fft(x_np) + # out_real_np = out_np.real + # out_imag_np = out_np.imag + # + # x_real_hcl = hcl.asarray(x_real_np) + # x_imag_hcl = hcl.asarray(x_imag_np) + # + # out_real_hcl = hcl.asarray(np.zeros((length))) + # out_imag_hcl = hcl.asarray(np.zeros((length))) + + # f(x_real_hcl, x_imag_hcl, out_real_hcl, out_imag_hcl) + + # np.testing.assert_allclose(out_real_np, out_real_hcl.asnumpy(), rtol=1e-02, atol=1e-3) + # np.testing.assert_allclose(out_imag_np, out_imag_hcl.asnumpy(), rtol=1e-02, atol=1e-3) + + _test_sim(32) + _test_sim(512) + _test_sim(1024) + +if __name__ == '__main__': + test_fft_hls() diff --git a/python/heterocl/mutator.py b/python/heterocl/mutator.py index 7d49f1e76..4bbb2f176 100644 --- a/python/heterocl/mutator.py +++ b/python/heterocl/mutator.py @@ -88,6 +88,8 @@ def mutate(self, node): return self.mutate_AssertStmt(node) elif isinstance(node, _stmt.ProducerConsumer): return self.mutate_ProducerConsumer(node) + elif isinstance(node, _stmt.ExternModule): + return self.mutate_ExternModule(node) elif isinstance(node, _stmt.For): return self.mutate_For(node) elif isinstance(node, _stmt.Store): @@ -270,9 +272,14 @@ def mutate_AssertStmt(self, node): return _make.AssertStmt(condition, message, body) def mutate_ProducerConsumer(self, node): - body = self.mutate(body) + body = self.mutate(node.body) return _make.ProducerConsumer(node.func, node.is_producer, body) + def mutate_ExternModule(self, node): + body = self.mutate(node.body) + return _make.ExternModule(node.attr_key, node.value, body, + node.annotate_keys, node.annotate_values) + def mutate_For(self, node): loop_var = self.mutate(node.loop_var) _min = self.mutate(node.min) diff --git a/python/heterocl/tvm/build_module.py b/python/heterocl/tvm/build_module.py index de2f6da9a..97ebabf68 100755 --- a/python/heterocl/tvm/build_module.py +++ b/python/heterocl/tvm/build_module.py @@ -45,7 +45,7 @@ def tvm_callback_exec_evaluate(platform, mode): qor = dict() if platform == "vivado": - out = run_process("cd __tmp__; make vivado 2>&1") + out = run_process("cd project; make vivado 2>&1") print(out) elif platform == "vivado_hls": @@ -57,11 +57,11 @@ def tvm_callback_exec_evaluate(platform, mode): "g++ version too old {}.{}.{}".format(ver[0], ver[1], ver[2]) # for host only mode - if not os.path.isfile("__tmp__/kernel.cpp"): - replace_text("__tmp__/Makefile", "kernel.cpp", "") - replace_text("__tmp__/host.cpp", "#include \"kernel.h\"", "") + if not os.path.isfile("project/kernel.cpp"): + replace_text("project/Makefile", "kernel.cpp", "") + replace_text("project/host.cpp", "#include \"kernel.h\"", "") - cmd = "cd __tmp__; make " + cmd = "cd project; make " if mode == "sw_sim": cmd += "csim" else: assert False @@ -73,7 +73,7 @@ def tvm_callback_exec_evaluate(platform, mode): elif platform == "sdsoc": assert os.system("which sds++ >> /dev/null") == 0, \ "cannot find sds++ on system path" - out = run_process("cd __tmp__; make sdsoc") + out = run_process("cd project; make sdsoc") print(out) elif platform == "sdaccel": @@ -81,20 +81,20 @@ def tvm_callback_exec_evaluate(platform, mode): "cannot find xocc on system path" if mode == "sw_sim": - cmd = "cd __tmp__; " +\ + cmd = "cd project; " +\ "export XCL_EMULATION_MODE=sw_emu; " +\ "./top_function_0_host.exe -f top_function_0.sw_emu.xclbin" out = run_process(cmd) elif mode == "hw_sim": - cmd = "cd __tmp__; " +\ + cmd = "cd project; " +\ "export XCL_EMULATION_MODE=hw_emu; " +\ "./top_function_0_host.exe -f top_function_0.hw_emu.xclbin" out = run_process(cmd) - os.system("cat __tmp__/profile_summary.csv") + os.system("cat project/profile_summary.csv") elif mode == "hw": - cmd = "cd __tmp__; " +\ + cmd = "cd project; " +\ "export XCL_EMULATION_MODE=hw; " +\ "./top_function_0_host.exe -f top_function_0.hw.xclbin" out = run_process(cmd) @@ -104,13 +104,13 @@ def tvm_callback_exec_evaluate(platform, mode): "cannot find v++ on system path" device = os.environ["XDEVICE"].split("/")[-1] device = device.replace(".xpfm", "") - cmd = "cd __tmp__; " + \ + cmd = "cd project; " + \ "XCL_EMULATION_MODE=sw_emu ./host build_dir" + \ ".sw_emu." + device + "/kernel.xclbin" out = run_process(cmd) elif platform == "aocl": - cmd = "cd __tmp__; " + \ + cmd = "cd project; " + \ "env CL_CONTEXT_EMULATOR_DEVICE_INTELFPGA=1 ./host " + \ " kernel.aocx" out = run_process(cmd) @@ -121,7 +121,7 @@ def tvm_callback_exec_evaluate(platform, mode): return str(qor) @register_func -def copy_and_compile(platform, mode, backend): +def copy_and_compile(platform, mode, backend, cfg): """ create necessary files and compile into binary """ path = api.__file__ path = os.path.join(path[0:path.find("python")], "tvm/src/template/") @@ -152,25 +152,25 @@ def copy_and_compile(platform, mode, backend): # copy tcl and testbench elif platform == "vivado_hls" or platform == "vivado": - os.system("cp " + path + "vivado/* __tmp__/") - os.system("cp " + path + "harness.mk __tmp__/") + os.system("cp " + path + "vivado/* project/") + os.system("cp " + path + "harness.mk project/") return "success" # copy sdsoc makefile elif platform == "sdsoc": - os.system("cp " + path + "sdsoc/* __tmp__/") - os.system("cp " + path + "harness.mk __tmp__/") + os.system("cp " + path + "sdsoc/* project/") + os.system("cp " + path + "harness.mk project/") return "success" elif platform == "sdaccel": - os.system("cp " + path + "sdaccel/* __tmp__/") - os.system("cp " + path + "harness.mk __tmp__/") - replace_text("__tmp__/Makefile", "App", "top_function_0") - replace_text("__tmp__/utils.h", + os.system("cp " + path + "sdaccel/* project/") + os.system("cp " + path + "harness.mk project/") + replace_text("project/Makefile", "App", "top_function_0") + replace_text("project/utils.h", "xilinx_aws-vu9p-f1-04261818_dynamic_5_0", "xilinx_vcu1525_dynamic_5_1") if backend == "vhls": - replace_text("__tmp__/Makefile", "kernel.cl", "kernel.cpp") + replace_text("project/Makefile", "kernel.cl", "kernel.cpp") # compile the program assert os.system("which xocc >> /dev/null") == 0, \ @@ -183,16 +183,16 @@ def copy_and_compile(platform, mode, backend): # re-compile host only (reuse context ?) if False and os.path.isfile("top_function_0.sw_emu.xclbin"): - run_process("cd __tmp__; make clean; make host") - run_process("cp top_function_0.sw_emu.xclbin __tmp__/") + run_process("cd project; make clean; make host") + run_process("cp top_function_0.sw_emu.xclbin project/") else: # config & compile env["XCL_EMULATION_MODE"] = "sw_emu" - cmd = "cd __tmp__; make clean;" + cmd = "cd project; make clean;" cmd += "emconfigutil --platform=$AWS_PLATFORM;" cmd += "make ocl OCL_TARGET=sw_emu \ OCL_PLATFORM=$AWS_PLATFORM \ - APPLICATION_DIR=" + os.getcwd() + "/__tmp__/" + APPLICATION_DIR=" + os.getcwd() + "/project/" out = run_process(cmd, env=env) # enable profiler @@ -202,11 +202,11 @@ def copy_and_compile(platform, mode, backend): "aws platform info missing" env["XCL_EMULATION_MODE"] = "hw_emu" - cmd = "cd __tmp__; make clean;" + cmd = "cd project; make clean;" cmd += "emconfigutil --platform=$AWS_PLATFORM;" cmd += "make ocl OCL_TARGET=hw_emu \ OCL_PLATFORM=$AWS_PLATFORM \ - APPLICATION_DIR=" + os.getcwd() + "/__tmp__/" + APPLICATION_DIR=" + os.getcwd() + "/project/" out = run_process(cmd, env=env) elif mode == "hw": @@ -215,11 +215,11 @@ def copy_and_compile(platform, mode, backend): "aws platform info missing" env["XCL_EMULATION_MODE"] = "hw" - cmd = "cd __tmp__; make clean;" + cmd = "cd project; make clean;" cmd += "emconfigutil --platform=$AWS_PLATFORM;" cmd += "make ocl OCL_TARGET=hw \ OCL_PLATFORM=$AWS_PLATFORM \ - APPLICATION_DIR=" + os.getcwd() + "/__tmp__/" + APPLICATION_DIR=" + os.getcwd() + "/project/" out = run_process(cmd, env=env) return "success" @@ -228,8 +228,9 @@ def copy_and_compile(platform, mode, backend): env = os.environ.copy() assert "XDEVICE" in os.environ, \ "vitis platform info missing" - os.system("cp " + path + "vitis/* __tmp__/") - cmd = "cd __tmp__; make clean;" + os.system("cp " + path + "vitis/* project/") + + cmd = "cd project; make clean;" cmd += "make all TARGET=sw_emu DEVICE=$XDEVICE" out = run_process(cmd) return "success" @@ -239,8 +240,8 @@ def copy_and_compile(platform, mode, backend): assert "INTELFPGAOCLSDKROOT" in os.environ, \ "cannot find aocl sdk for fpga on path" - os.system("cp " + path + "aocl/* __tmp__/") - cmd = "cd __tmp__; make clean; make;" + os.system("cp " + path + "aocl/* project/") + cmd = "cd project; make clean; make;" # compile kernel for xcel device cmd += " aoc" if mode == "sw_sim": diff --git a/python/heterocl/tvm/stmt.py b/python/heterocl/tvm/stmt.py index d5c2d0a18..99820cbf0 100644 --- a/python/heterocl/tvm/stmt.py +++ b/python/heterocl/tvm/stmt.py @@ -55,6 +55,10 @@ class Allocate(Stmt): class AttrStmt(Stmt): pass +@register_node +class ExternModule(Stmt): + pass + @register_node class Free(Stmt): pass diff --git a/tvm/HalideIR/src/ir/Expr.h b/tvm/HalideIR/src/ir/Expr.h index c505bd921..48f07d891 100644 --- a/tvm/HalideIR/src/ir/Expr.h +++ b/tvm/HalideIR/src/ir/Expr.h @@ -95,7 +95,9 @@ enum class IRNodeType : int { StreamExpr, StreamStmt, /** for stencil analysis **/ - Stencil + Stencil, + /** for external module **/ + ExternModule }; /** The abstract base classes for a node in the Halide IR. */ diff --git a/tvm/HalideIR/src/ir/IR.cpp b/tvm/HalideIR/src/ir/IR.cpp index decba20fc..6bf0021af 100644 --- a/tvm/HalideIR/src/ir/IR.cpp +++ b/tvm/HalideIR/src/ir/IR.cpp @@ -898,6 +898,19 @@ Stmt Stencil::make(Array inputs, Array outputs, Stmt body, return Stmt(node); } +Stmt ExternModule::make(std::string attr_key, Expr value, Stmt body, + Array annotate_keys, Array annotate_values) { + internal_assert(body.defined()) << "undefined body\n"; + + std::shared_ptr node = std::make_shared(); + node->attr_key = std::move(attr_key); + node->value = std::move(value); + node->body = std::move(body); + node->annotate_keys = std::move(annotate_keys); + node->annotate_values = std::move(annotate_values); + return Stmt(node); +} + namespace { // Helper function to determine if a sequence of indices is a @@ -969,6 +982,7 @@ template<> void ExprNode::accept(IRVisitor *v, const Expr &e) const { v->v template<> void ExprNode::accept(IRVisitor *v, const Expr &e) const { v->visit((const Let *)this, e); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const LetStmt *)this, s); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const AttrStmt *)this, s); } +template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const ExternModule *)this, s); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const AssertStmt *)this, s); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const ProducerConsumer *)this, s); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const For *)this, s); } diff --git a/tvm/HalideIR/src/ir/IR.h b/tvm/HalideIR/src/ir/IR.h index d3d08d8ee..a5cc7147e 100644 --- a/tvm/HalideIR/src/ir/IR.h +++ b/tvm/HalideIR/src/ir/IR.h @@ -1290,6 +1290,30 @@ struct Stencil : public StmtNode { static constexpr const char* _type_key = "Stencil"; }; +/** The function block to save external module **/ +struct ExternModule : public StmtNode { + std::string attr_key; + Expr value; + Stmt body; + Array annotate_keys; + Array annotate_values; + + EXPORT static Stmt make(std::string attr_key, + Expr value, Stmt body, + Array annotate_keys, + Array annotate_values); + + void VisitAttrs(IR::AttrVisitor* v) final { + v -> Visit("attr_key", &attr_key); + v -> Visit("value", &value); + v -> Visit("body", &body); + v -> Visit("annotate_keys", &annotate_keys); + v -> Visit("annotate_values", &annotate_values); + } + static const IRNodeType _type_info = IRNodeType::ExternModule; + static constexpr const char* _type_key = "ExternModule"; +}; + } // inline functions diff --git a/tvm/HalideIR/src/ir/IRMutator.cpp b/tvm/HalideIR/src/ir/IRMutator.cpp index 2a8109b05..82904a62b 100644 --- a/tvm/HalideIR/src/ir/IRMutator.cpp +++ b/tvm/HalideIR/src/ir/IRMutator.cpp @@ -202,6 +202,18 @@ void IRMutator::visit(const AttrStmt *op, const Stmt &s) { } } +void IRMutator::visit(const ExternModule *op, const Stmt &s) { + Expr value = mutate(op->value); + Stmt body = mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + stmt = s; + } else { + stmt = ExternModule::make(op->attr_key, value, body, + op->annotate_keys, op->annotate_values); + } +} + void IRMutator::visit(const AssertStmt *op, const Stmt &s) { Expr condition = mutate(op->condition); Expr message = mutate(op->message); diff --git a/tvm/HalideIR/src/ir/IRMutator.h b/tvm/HalideIR/src/ir/IRMutator.h index 4088ae5ea..aba20df9f 100644 --- a/tvm/HalideIR/src/ir/IRMutator.h +++ b/tvm/HalideIR/src/ir/IRMutator.h @@ -72,6 +72,7 @@ class IRMutator : public IRVisitor { EXPORT virtual void visit(const Let *, const Expr &); EXPORT virtual void visit(const LetStmt *, const Stmt &); EXPORT virtual void visit(const AttrStmt *, const Stmt &); + EXPORT virtual void visit(const ExternModule *, const Stmt &); EXPORT virtual void visit(const AssertStmt *, const Stmt &); EXPORT virtual void visit(const ProducerConsumer *, const Stmt &); EXPORT virtual void visit(const For *, const Stmt &); diff --git a/tvm/HalideIR/src/ir/IRPrinter.cpp b/tvm/HalideIR/src/ir/IRPrinter.cpp index bba0b56b2..f8cd8259f 100644 --- a/tvm/HalideIR/src/ir/IRPrinter.cpp +++ b/tvm/HalideIR/src/ir/IRPrinter.cpp @@ -409,6 +409,22 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->print(op->body); }); +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ExternModule *op, IRPrinter *p) { + p->do_indent(); + p->stream << "// extern module ("; + p->stream << op->attr_key; + p->stream << ") "; + for (size_t i = 0; i < op->annotate_keys.size(); i++) { + p->stream << " "; + p->print(op->annotate_keys[i]); + p->stream << "="; + p->print(op->annotate_values[i]); + } + p->stream << '\n'; + p->print(op->body); +}); + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const AssertStmt *op, IRPrinter* p) { p->do_indent(); diff --git a/tvm/HalideIR/src/ir/IRVisitor.cpp b/tvm/HalideIR/src/ir/IRVisitor.cpp index e5e803005..39f94113e 100644 --- a/tvm/HalideIR/src/ir/IRVisitor.cpp +++ b/tvm/HalideIR/src/ir/IRVisitor.cpp @@ -150,6 +150,11 @@ void IRVisitor::visit(const AttrStmt *op, const Stmt &) { op->body.accept(this); } +void IRVisitor::visit(const ExternModule *op, const Stmt &) { + op->value.accept(this); + op->body.accept(this); +} + void IRVisitor::visit(const AssertStmt *op, const Stmt &) { op->condition.accept(this); op->message.accept(this); diff --git a/tvm/HalideIR/src/ir/IRVisitor.h b/tvm/HalideIR/src/ir/IRVisitor.h index a4faa4aba..384598bb6 100644 --- a/tvm/HalideIR/src/ir/IRVisitor.h +++ b/tvm/HalideIR/src/ir/IRVisitor.h @@ -53,6 +53,7 @@ class IRVisitor { EXPORT virtual void visit(const Shuffle *, const Expr &); EXPORT virtual void visit(const LetStmt *, const Stmt &); EXPORT virtual void visit(const AttrStmt *, const Stmt &); + EXPORT virtual void visit(const ExternModule *, const Stmt &); EXPORT virtual void visit(const AssertStmt *, const Stmt &); EXPORT virtual void visit(const ProducerConsumer *, const Stmt &); EXPORT virtual void visit(const For *, const Stmt &); diff --git a/tvm/include/tvm/ir.h b/tvm/include/tvm/ir.h index 8a26e551c..dd9a2ceac 100644 --- a/tvm/include/tvm/ir.h +++ b/tvm/include/tvm/ir.h @@ -513,6 +513,7 @@ using Halide::Internal::While; using Halide::Internal::Reuse; using Halide::Internal::Partition; using Halide::Internal::Stencil; +using Halide::Internal::ExternModule; // ir functions using Halide::Internal::is_const_power_of_two_integer; diff --git a/tvm/include/tvm/ir_functor_ext.h b/tvm/include/tvm/ir_functor_ext.h index 39ce6d2b8..149382158 100644 --- a/tvm/include/tvm/ir_functor_ext.h +++ b/tvm/include/tvm/ir_functor_ext.h @@ -247,6 +247,7 @@ class StmtFunctor { virtual R VisitStmt_(const KernelDef* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const KernelStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StreamStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ExternModule* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Return* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Break* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const While* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -279,6 +280,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(KernelDef); IR_STMT_FUNCTOR_DISPATCH(KernelStmt); IR_STMT_FUNCTOR_DISPATCH(StreamStmt); + IR_STMT_FUNCTOR_DISPATCH(ExternModule); IR_STMT_FUNCTOR_DISPATCH(Return); IR_STMT_FUNCTOR_DISPATCH(Break); IR_STMT_FUNCTOR_DISPATCH(While); diff --git a/tvm/include/tvm/ir_mutator.h b/tvm/include/tvm/ir_mutator.h index 200534644..3ad2ace3d 100644 --- a/tvm/include/tvm/ir_mutator.h +++ b/tvm/include/tvm/ir_mutator.h @@ -57,6 +57,7 @@ class TVM_DLL IRMutator { // The underscore allows Mutate not to be shadowed by inheritance virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); + virtual Stmt Mutate_(const ExternModule* op, const Stmt& s); virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s); diff --git a/tvm/include/tvm/ir_visitor.h b/tvm/include/tvm/ir_visitor.h index 21ef77c32..87a424f16 100644 --- a/tvm/include/tvm/ir_visitor.h +++ b/tvm/include/tvm/ir_visitor.h @@ -83,6 +83,7 @@ class TVM_DLL IRVisitor { virtual void Visit_(const Variable* op); virtual void Visit_(const LetStmt* op); virtual void Visit_(const AttrStmt* op); + virtual void Visit_(const ExternModule* op); virtual void Visit_(const IfThenElse* op); virtual void Visit_(const For* op); virtual void Visit_(const Allocate* op); diff --git a/tvm/src/api/api_ir.cc b/tvm/src/api/api_ir.cc index 4d9ab21cf..1795365c9 100644 --- a/tvm/src/api/api_ir.cc +++ b/tvm/src/api/api_ir.cc @@ -251,6 +251,7 @@ REGISTER_MAKE1(Return); REGISTER_MAKE2(While); REGISTER_MAKE2(Reuse); REGISTER_MAKE6(Stencil); +REGISTER_MAKE5(ExternModule); } // namespace ir } // namespace TVM diff --git a/tvm/src/api/api_pass.cc b/tvm/src/api/api_pass.cc index 1728b0c23..2f152a892 100644 --- a/tvm/src/api/api_pass.cc +++ b/tvm/src/api/api_pass.cc @@ -30,6 +30,17 @@ TVM_REGISTER_API("ir_pass.Simplify") } }); +TVM_REGISTER_API("ir_pass.Substitute") +.set_body([](TVMArgs args, TVMRetValue *ret) { + if (args[0].IsNodeType()) { + const Map& map = args[1]; + *ret = Substitute(args[0].operator Stmt(), map); + } else { + const Map& map = args[1]; + *ret = Substitute(args[0].operator Expr(), map); + } + }); + TVM_REGISTER_API("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsNodeType()) { diff --git a/tvm/src/codegen/build_common.cc b/tvm/src/codegen/build_common.cc index abbdec3f1..43aa70de0 100644 --- a/tvm/src/codegen/build_common.cc +++ b/tvm/src/codegen/build_common.cc @@ -35,12 +35,13 @@ class SimModuleNode final : public ModuleNode { SimModuleNode(LoweredFunc func, std::string host_code, std::vector arg_names, - std::string dev_code, std::string platform, + std::string dev_code, std::string cfg_code, std::string platform, std::unordered_map options) : func_(func), host_(host_code), - arg_names_(arg_names), dev_(dev_code), + cfg_(cfg_code), + arg_names_(arg_names), platform_(platform), options_(options) {} @@ -80,7 +81,7 @@ class SimModuleNode final : public ModuleNode { GenSharedMem(args, shmids, arg_sizes); LOG(CLEAN) << "Generating harness files ..."; - system("rm -rf __tmp__; mkdir __tmp__"); + system("rm -rf project; mkdir project"); GenHostCode(args, shmids, arg_types, func_, platform_, host_, arg_names_); @@ -92,7 +93,7 @@ class SimModuleNode final : public ModuleNode { CHECK(options_.count("mode")) << "mode mot set"; auto mode = options_["mode"]; auto backend = options_["backend"]; - (*f)(platform_, mode, backend).operator std::string(); + (*f)(platform_, mode, backend, cfg_).operator std::string(); } } @@ -139,9 +140,8 @@ class SimModuleNode final : public ModuleNode { private: LoweredFunc func_; - std::string host_; + std::string host_, dev_, cfg_; std::vector arg_names_; - std::string dev_; std::string platform_; std::unordered_map options_; std::vector shmids; @@ -149,16 +149,14 @@ class SimModuleNode final : public ModuleNode { }; Module CreateSimModule( - LoweredFunc func, - std::string host_code, - std::string dev_code, - std::vector arg_names, - std::string platform, - std::unordered_map options) { + LoweredFunc func, std::string host_code, + std::string dev_code, std::string cfg_code, std::vector arg_names, + std::string platform, std::unordered_map options) { + std::shared_ptr n = - std::make_shared(func, host_code, - arg_names, dev_code, - platform, options); + std::make_shared( + func, host_code, arg_names, dev_code, + cfg_code, platform, options); return Module(n); } } // namespace runtime @@ -202,11 +200,9 @@ runtime::Module BuildSimModule(Array funcs, auto val = values[k].as()->value; options[key] = val; } - return runtime::CreateSimModule(funcs[0], - cg_host.GetHost(), - cg_dev.GetDevice(), - cg_host.arg_names, - platform, options); + return runtime::CreateSimModule( + funcs[0], cg_host.GetHost(), cg_dev.GetDevice(), + cg_dev.GetConfig(), cg_host.arg_names, platform, options); } TVM_REGISTER_API("codegen.build_sim") diff --git a/tvm/src/codegen/build_util.cc b/tvm/src/codegen/build_util.cc index e775d6743..e76762f01 100644 --- a/tvm/src/codegen/build_util.cc +++ b/tvm/src/codegen/build_util.cc @@ -336,8 +336,9 @@ void GenKernelCode(std::string& test_file, std::string kernel_ext = "cpp"; if (platform == "sdaccel" && backend == "sdaccel") kernel_ext = "cl"; if (platform == "aocl") kernel_ext = "cl"; - stream.open("__tmp__/kernel." + kernel_ext); + stream.open("project/kernel." + kernel_ext); + // create typedef and header if (platform == "vivado" || platform == "vivado_hls" || platform == "sdsoc") { @@ -360,7 +361,7 @@ void GenKernelCode(std::string& test_file, // generate header file std::ofstream header; - header.open("__tmp__/kernel.h"); + header.open("project/kernel.h"); header << "#ifndef __KERNEL_H__\n" << "#define __KERNEL_H__\n\n"; header << "#include \n"; @@ -386,6 +387,10 @@ void GenKernelCode(std::string& test_file, } else if (platform == "aocl") { stream << "#include \"ihc_apint.h\"\n"; + + } else if (platform == "vitis") { + stream << "#include \n"; + stream << "#include \n"; } stream << test_file; @@ -476,7 +481,7 @@ void GenHostCode(TVMArgs& args, std::vector arg_names) { int indent = 0; std::ofstream stream; - stream.open("__tmp__/host.cpp"); + stream.open("project/host.cpp"); GenHostHeaders(stream, platform); auto code = SplitHostCode(host_code); CHECK((signed)arg_names.size() == args.size()); diff --git a/tvm/src/codegen/codegen_c.cc b/tvm/src/codegen/codegen_c.cc index 1228a48a0..d048bd427 100644 --- a/tvm/src/codegen/codegen_c.cc +++ b/tvm/src/codegen/codegen_c.cc @@ -129,12 +129,17 @@ void CodeGenC::AddFunction(LoweredFunc f, this->stream << "}\n\n"; } +std::string CodeGenC::GetConfig() { + return this->cfg_stream.str(); +} + std::string CodeGenC::GetHost() { return this->stream.str(); } std::string CodeGenC::GetDevice() { - return module_stream.str(); + return decl_stream.str() + + module_stream.str(); } std::string CodeGenC::Finish() { @@ -957,6 +962,10 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) { this->PrintStmt(op->body); } +void CodeGenC::VisitStmt_(const ExternModule* op) { + LOG(FATAL) << "does not support ExternModule in C"; +} + void CodeGenC::VisitStmt_(const AssertStmt* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); diff --git a/tvm/src/codegen/codegen_c.h b/tvm/src/codegen/codegen_c.h index 1f0f4d164..6b95ca4ec 100644 --- a/tvm/src/codegen/codegen_c.h +++ b/tvm/src/codegen/codegen_c.h @@ -68,6 +68,7 @@ class CodeGenC : * \return The device code. */ std::string GetDevice(); + std::string GetConfig(); /*! * \brief Print the Stmt n to CodeGenC->stream * \param n The statement to be printed. @@ -139,6 +140,7 @@ class CodeGenC : void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const ExternModule* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Block* op) override; @@ -223,15 +225,15 @@ class CodeGenC : const std::string& target, const std::string& src, Type t) final; /*! \brief restrict keyword */ std::string restrict_keyword_{""}; - /*! \brief the func arg decl stream */ - std::ostringstream arg_stream; + /*! \brief the Makefile target object list */ + std::ostringstream cfg_stream; /*! \brief the storage scope of allocation */ std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; std::unordered_map buf_length_map_; - // save for kernel gen + // save for kernel generation std::unordered_map alloc_storage_scope_save; std::unordered_map handle_data_type_save; std::unordered_map var_idmap_save; diff --git a/tvm/src/codegen/hlsc/codegen_vhls.cc b/tvm/src/codegen/hlsc/codegen_vhls.cc index b7233fcd6..bdb0b581f 100644 --- a/tvm/src/codegen/hlsc/codegen_vhls.cc +++ b/tvm/src/codegen/hlsc/codegen_vhls.cc @@ -179,27 +179,27 @@ void CodeGenVivadoHLS::VisitStmt_(const For* op) { std::ostringstream os; if (ptr_mode) { - if (const For* for_op = op->body.as()) { - while (for_op->body.as()) - for_op = for_op->body.as(); - if (auto s = for_op->body.as()) { - if (s->buffer_var.get()->name_hint.find("channel") + Stmt stmt = op->body; + while (const For* for_op = stmt.as()) + stmt = for_op->body; + + if (auto s = stmt.as()) { + if (s->buffer_var.get()->name_hint.find("channel") + != std::string::npos) return; + } else if (auto st = stmt.as()) { + if (auto e = st->value.as()) { + if (e->buffer_var.get()->name_hint.find("channel") != std::string::npos) return; - } else if (auto st = for_op->body.as()) { - if (auto e = st->value.as()) { - if (e->buffer_var.get()->name_hint.find("channel") - != std::string::npos) return; - - } else { - auto value = st->value; - if (auto c = value.as()) value = c->value; - if (auto v = value.as()) { - if (v->value == 0) return; - } else if (auto v = value.as()) { - if (v->value == 0) return; - } else if (auto v = value.as()) { - if (v->value == 0) return; - } + + } else { + auto value = st->value; + if (auto c = value.as()) value = c->value; + if (auto v = value.as()) { + if (v->value == 0) return; + } else if (auto v = value.as()) { + if (v->value == 0) return; + } else if (auto v = value.as()) { + if (v->value == 0) return; } } } @@ -280,6 +280,75 @@ void CodeGenVivadoHLS::VisitExpr_(const StreamExpr* op, std::ostream& os) { os << vid << ".read()"; } +// generate the module as blackbox +void CodeGenVivadoHLS::VisitStmt_(const ExternModule* op) { + std::string ip_name, config, spec, decl; + std::vector args_in, args_out, indices; + + PrintIndent(); + for (size_t i = 0; i < op->annotate_keys.size(); i++) { + auto key = op->annotate_keys[i].as()->value; + if (key == "name") { + ip_name = op->annotate_values[i].as()->value; + } else if (key == "json") { + config = op->annotate_values[i].as()->value; + } else if (key == "decl") { + decl = op->annotate_values[i].as()->value; + } else if (key == "spec") { + spec = op->annotate_values[i].as()->value; + } else if (key.find("input") != std::string::npos) { + auto arg = op->annotate_values[i].as()->value; + args_in.push_back(arg); + } else if (key.find("output") != std::string::npos) { + auto arg = op->annotate_values[i].as()->value; + args_out.push_back(arg); + } else if (key.find("index") != std::string::npos) { + auto idx = op->annotate_values[i].as()->value; + indices.push_back(idx); + } + } + + // generate external ip core + if (indices.size() > 0) { + CHECK(indices.size() == args_in.size() + args_out.size()); + // initialize temp values + for (auto arg : args_out) { + stream << "ap_int<32> " << arg << "_temp;\n"; + PrintIndent(); + } + + stream << ip_name << "("; + auto index = 0; + for (auto arg : args_in) { + if (index > 0) stream << ", "; + stream << arg << "[" << indices[index] << "]"; + index++; + } + for (auto arg : args_out) { + if (index > 0) stream << ", "; + stream << arg << "_temp"; index++; + } + stream << ");\n"; + + // assign temp value back + index = args_in.size(); + for (auto arg : args_out) { + PrintIndent(); + stream << arg << "[" << indices[index++] + << "] = " << arg << "_temp;\n"; + } + + } else { + stream << ip_name << "("; + } + + // generate TCL and Makefile + if (op->attr_key == "rtl") { + cfg_stream << "add_files -blackbox " << config; + decl_stream << decl << "\n"; + } +} + void CodeGenVivadoHLS::VisitStmt_(const StreamStmt* op) { std::string vid = GetVarID(op->buffer_var.get()); switch (op->stream_type) { diff --git a/tvm/src/codegen/hlsc/codegen_vhls.h b/tvm/src/codegen/hlsc/codegen_vhls.h index 0f8c2b3de..8ab63f41f 100644 --- a/tvm/src/codegen/hlsc/codegen_vhls.h +++ b/tvm/src/codegen/hlsc/codegen_vhls.h @@ -33,6 +33,7 @@ class CodeGenVivadoHLS final : public CodeGenHLSC { void VisitStmt_(const StreamStmt* op) override; void VisitStmt_(const KernelDef* op) override; void VisitStmt_(const KernelStmt* op) override; + void VisitStmt_(const ExternModule* op) override; private: std::ofstream soda_header_; diff --git a/tvm/src/codegen/llvm/codegen_llvm.cc b/tvm/src/codegen/llvm/codegen_llvm.cc index 6c8d257e7..6677721f5 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.cc +++ b/tvm/src/codegen/llvm/codegen_llvm.cc @@ -1310,6 +1310,10 @@ void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) { this->VisitStmt(op->body); } +void CodeGenLLVM::VisitStmt_(const ExternModule* op) { + this->VisitStmt(op->body); +} + void CodeGenLLVM::VisitStmt_(const KernelDef* op) { this->SaveFuncState(); const UIntImm* is_void = op->ret_void.as(); diff --git a/tvm/src/codegen/llvm/codegen_llvm.h b/tvm/src/codegen/llvm/codegen_llvm.h index 3a525ff69..b3dcdb612 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.h +++ b/tvm/src/codegen/llvm/codegen_llvm.h @@ -129,6 +129,7 @@ class CodeGenLLVM : void VisitStmt_(const Block* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; + void VisitStmt_(const ExternModule* op) override; void VisitStmt_(const KernelDef* op) override; void VisitStmt_(const KernelStmt* op) override; void VisitStmt_(const Return* op) override; diff --git a/tvm/src/codegen/merlinc/codeanalys_merlinc.cc b/tvm/src/codegen/merlinc/codeanalys_merlinc.cc index d6fa1c6ba..b1afb7800 100644 --- a/tvm/src/codegen/merlinc/codeanalys_merlinc.cc +++ b/tvm/src/codegen/merlinc/codeanalys_merlinc.cc @@ -796,6 +796,10 @@ void CodeAnalysMerlinC::VisitStmt_(const AttrStmt* op) { this->PrintStmt(op->body); } +void CodeAnalysMerlinC::VisitStmt_(const ExternModule* op) { + this->PrintStmt(op->body); +} + void CodeAnalysMerlinC::VisitStmt_(const AssertStmt* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); diff --git a/tvm/src/codegen/merlinc/codeanalys_merlinc.h b/tvm/src/codegen/merlinc/codeanalys_merlinc.h index 421f0d96f..ae705a99e 100644 --- a/tvm/src/codegen/merlinc/codeanalys_merlinc.h +++ b/tvm/src/codegen/merlinc/codeanalys_merlinc.h @@ -120,6 +120,7 @@ class CodeAnalysMerlinC : void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const ExternModule* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Block* op) override; diff --git a/tvm/src/codegen/opencl/codegen_xocl.cc b/tvm/src/codegen/opencl/codegen_xocl.cc index f361e9f4a..8666f7c62 100644 --- a/tvm/src/codegen/opencl/codegen_xocl.cc +++ b/tvm/src/codegen/opencl/codegen_xocl.cc @@ -190,6 +190,10 @@ void CodeGenXOCL::VisitStmt_(const Partition* op) { } } +void CodeGenXOCL::VisitStmt_(const ExternModule* op) { + this->PrintStmt(op->body); +} + void CodeGenXOCL::VisitStmt_(const StreamStmt* op) { std::string vid = GetVarID(op->buffer_var.get()); PrintIndent(); diff --git a/tvm/src/codegen/opencl/codegen_xocl.h b/tvm/src/codegen/opencl/codegen_xocl.h index 468a64c38..742a26847 100755 --- a/tvm/src/codegen/opencl/codegen_xocl.h +++ b/tvm/src/codegen/opencl/codegen_xocl.h @@ -19,6 +19,7 @@ class CodeGenXOCL : public CodeGenOpenCL { void VisitStmt_(const For* op) override; //NOLINT(*) void VisitStmt_(const Partition* op) override; //NOLINT(*) void VisitStmt_(const StreamStmt* op) override; //NOLINT(*) + void VisitStmt_(const ExternModule* op) override; //NOLINT(*) void VisitExpr_(const StreamExpr* op, std::ostream& os) override; //NOLINT(*) diff --git a/tvm/src/codegen/opencl/codegen_xocl_host.cc b/tvm/src/codegen/opencl/codegen_xocl_host.cc index c0895e523..3505d032d 100644 --- a/tvm/src/codegen/opencl/codegen_xocl_host.cc +++ b/tvm/src/codegen/opencl/codegen_xocl_host.cc @@ -68,16 +68,32 @@ void CodeGenXOCLHost::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(* } void CodeGenXOCLHost::VisitStmt_(const For* op) { - // ignore the data tranmission for stmts - if (const For* for_op = op->body.as()) { - while (for_op->body.as()) - for_op = for_op->body.as(); - if (for_op->body.as()) { - return; - } else if (auto st = for_op->body.as()) { - if (st->value.as()) return; + + Stmt stmt = op->body; + while (const For* for_op = stmt.as()) + stmt = for_op->body; + + if (auto s = stmt.as()) { + if (s->buffer_var.get()->name_hint.find("channel") + != std::string::npos) return; + } else if (auto st = stmt.as()) { + if (auto e = st->value.as()) { + if (e->buffer_var.get()->name_hint.find("channel") + != std::string::npos) return; + + } else { + auto value = st->value; + if (auto c = value.as()) value = c->value; + if (auto v = value.as()) { + if (v->value == 0) return; + } else if (auto v = value.as()) { + if (v->value == 0) return; + } else if (auto v = value.as()) { + if (v->value == 0) return; + } } } + CodeGenC::VisitStmt_(op); } @@ -181,6 +197,10 @@ void CodeGenXOCLHost::VisitStmt_(const Allocate* op) { vid.replace(vid.find("_channel"), 8, ""); if (alloc_set.find(vid) != alloc_set.end()) { not_alloc = true; + } else { + for (auto& name : arg_names) { + if (name == vid) not_alloc = true; + } } } @@ -303,8 +323,18 @@ void CodeGenXOCLHost::VisitStmt_(const KernelStmt* op) { } stream << ");\n"; } +} - +void CodeGenXOCLHost::VisitStmt_(const ExternModule* op) { + std::string name; + for (size_t i = 0; i < op->annotate_keys.size(); i++) { + auto key = op->annotate_keys[i].as()->value; + auto value = op->annotate_values[i].as()->value; + if (key == "name") { + name = value; + } + } + this->PrintStmt(op->body); } } // namespace codegen diff --git a/tvm/src/codegen/opencl/codegen_xocl_host.h b/tvm/src/codegen/opencl/codegen_xocl_host.h index e61b1cbbd..9b1d2d3da 100644 --- a/tvm/src/codegen/opencl/codegen_xocl_host.h +++ b/tvm/src/codegen/opencl/codegen_xocl_host.h @@ -28,6 +28,7 @@ class CodeGenXOCLHost : public CodeGenC { void VisitStmt_(const Allocate* op) override; void VisitStmt_(const KernelStmt* op) override; void VisitStmt_(const Store* op) override; + void VisitStmt_(const ExternModule* op) override; void GenForStmt(const For* op, std::string pragma, bool before); diff --git a/tvm/src/lang/ir.cc b/tvm/src/lang/ir.cc index c88f8ea94..5994a5dd8 100644 --- a/tvm/src/lang/ir.cc +++ b/tvm/src/lang/ir.cc @@ -157,6 +157,7 @@ TVM_REGISTER_NODE_TYPE(While); TVM_REGISTER_NODE_TYPE(Reuse); TVM_REGISTER_NODE_TYPE(Partition); TVM_REGISTER_NODE_TYPE(Stencil); +TVM_REGISTER_NODE_TYPE(ExternModule); } // namespace ir } // namespace TVM diff --git a/tvm/src/pass/ir_mutator.cc b/tvm/src/pass/ir_mutator.cc index c7b47066e..38d4de635 100644 --- a/tvm/src/pass/ir_mutator.cc +++ b/tvm/src/pass/ir_mutator.cc @@ -115,6 +115,18 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { } } +Stmt IRMutator::Mutate_(const ExternModule* op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return ExternModule::make(op->attr_key, value, body, + op->annotate_keys, op->annotate_values); + } +} + Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { Expr value = this->Mutate(op->value); Stmt body = this->Mutate(op->body); @@ -401,6 +413,7 @@ Stmt IRMutator::Mutate_(const Stencil *op, const Stmt &s) { TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt) +.DISPATCH_TO_MUTATE_STMT(ExternModule) .DISPATCH_TO_MUTATE_STMT(IfThenElse) .DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(Allocate) diff --git a/tvm/src/pass/ir_visitor.cc b/tvm/src/pass/ir_visitor.cc index 6346c6262..8c549f56c 100644 --- a/tvm/src/pass/ir_visitor.cc +++ b/tvm/src/pass/ir_visitor.cc @@ -60,6 +60,11 @@ void IRVisitor::Visit_(const AttrStmt* op) { this->Visit(op->body); } +void IRVisitor::Visit_(const ExternModule* op) { + this->Visit(op->value); + this->Visit(op->body); +} + void IRVisitor::Visit_(const For *op) { IRVisitor* v = this; v->Visit(op->min); @@ -352,7 +357,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(While) .DISPATCH_TO_VISIT(Reuse) .DISPATCH_TO_VISIT(Partition) -.DISPATCH_TO_VISIT(Stencil); +.DISPATCH_TO_VISIT(Stencil) +.DISPATCH_TO_VISIT(ExternModule); } // namespace ir } // namespace TVM diff --git a/tvm/src/schedule/schedule_reorder.cc b/tvm/src/schedule/schedule_reorder.cc index 96f6b76e0..b741bb545 100644 --- a/tvm/src/schedule/schedule_reorder.cc +++ b/tvm/src/schedule/schedule_reorder.cc @@ -69,6 +69,9 @@ std::vector ExtractSubGraph( const ReadGraph& g, const Schedule& sch, std::unordered_map& dev, + // module map recording super stage attachment + std::unordered_map> atts_map, std::vector& boundary, Array>& inputs, Array>& outputs, @@ -117,6 +120,7 @@ std::vector ExtractSubGraph( visited.insert(t->op.get()); } + std::unordered_set shared; while (!stack.empty()) { Operation op = stack.back(); stack.pop_back(); @@ -136,6 +140,8 @@ std::vector ExtractSubGraph( if (visited.count(t->op.get()) == 0) { visited.insert(t->op.get()); if (!reach_bound) stack.push_back(t->op); + } else { // visited ancestor + shared.insert(t->op.get()); } } } @@ -165,15 +171,61 @@ std::vector ExtractSubGraph( } } + CHECK(subgraph.size() > 0); + std::vector new_subgraph; + size_t op_count = 0; + for (auto& op : subgraph) { + auto name = op->name; + if (name.find(".new") != std::string::npos) { + new_subgraph.insert(new_subgraph.begin(), op); + op_count += 1; + } else { // ordinary ops + if (shared.find(op.get()) != shared.end()) { + new_subgraph.insert(new_subgraph.begin() + op_count, op); + continue; + } + new_subgraph.push_back(op); + } + } + + std::unordered_map inserted; + // for(auto op : new_subgraph) LOG(INFO) << op; + // for(auto kv : atts_map) { + // inserted[kv.first] = false; + // LOG(INFO) << kv.first << ":------------"; + // for (auto& k : kv.second) LOG(INFO) << k; + // } Stmt body = Evaluate::make(0); - for (Operation op : subgraph) { + for (Operation op : new_subgraph) { CHECK(op.as()) << op; if (auto extern_op = op.as()) { - CHECK(extern_op->output_placeholders.size()); - Buffer out_buf = extern_op->output_placeholders[0]; - Stmt attr = AttrStmt::make(VarExpr(out_buf.node_), - "attach_scope", StringImm::make("test"), no_op); - body = Block::make(body, attr); + + // check if subgraph op in extern module inputs + bool in_extern_mod = false; + for (auto& kv : atts_map) { + if (kv.second.count(op->name) && + op->name.find(".new") == std::string::npos) { + in_extern_mod = true; + if (!inserted[kv.first]) { + inserted[kv.first] = true; + auto mod_op = kv.first.as(); + Buffer mod_buf = mod_op->output_placeholders[0]; + // LOG(INFO) << "insert " << kv.first << ":" << mod_buf; + Stmt attr = AttrStmt::make(VarExpr(mod_buf.node_), + "attach_scope", StringImm::make("test"), no_op); + body = Block::make(body, attr); + } + } + } + + // insert standalone subgraph op + if (!in_extern_mod) { + CHECK(extern_op->output_placeholders.size()); + Buffer out_buf = extern_op->output_placeholders[0]; + Stmt attr = AttrStmt::make(VarExpr(out_buf.node_), + "attach_scope", StringImm::make("test"), no_op); + body = Block::make(body, attr); + } } } @@ -195,11 +247,30 @@ std::vector ExtractSubGraph( aggregate->body = AttrStmt::make( VarExpr(), attr::device_scope, scope, body); merged_ops.push_back(Operation(aggregate)); - return subgraph; + return new_subgraph; } static int bound_index = 0; + +// extract the bounded op arrays from subgraph root +void PostDFSBoundary(const Operation& op, + const ReadGraph& g, + std::unordered_set* visited, + Array* bounded_ops) { + if (visited->count(op)) return; + visited->insert(op); + + CHECK(op.as()); + for (const auto& t : g.at(op)) { + if (op->name.find(".new") == std::string::npos) + PostDFSBoundary(t->op, g, visited, bounded_ops); + } + // record ops before .new ops + bounded_ops->push_back(op); +} + // schedule the ops with subgraphs +// store ops that are not in subgraph void PostDFSSplit(const Operation& op, const ReadGraph& g, std::unordered_set* visited, @@ -210,13 +281,15 @@ void PostDFSSplit(const Operation& op, visited->insert(op); CHECK(dev.count(op.get())) << "not found " << op; - // push op into separate list + // visit from root to source and record break point + // push op into array if it is outside the subgraph bool reach_bound = false; for (auto& node : subgraphs) { if (op.same_as(node)) { // insert subgraph ops index - if (bound_index == 0) + if (bound_index == 0) { bound_index = visited->size(); + } reach_bound = true; } } @@ -236,8 +309,25 @@ Array PostDFSSplit( std::vector boundary; std::unordered_set visited; - for (Operation op : roots) + std::unordered_map> extern_mods; + // check the external module + for (Operation op : roots) { dev[op.get()] = DeviceType::devHost; + if (auto extern_op = op.as()) { + if (extern_op->body.as()) { + CHECK(g.count(op)) << "not found " << op; + for (const auto& t : g.at(op)) { + extern_mods[op].insert(t->op->name); + // TODO: find a better abstraction + if (g.count(t->op) && t->op->name.find(".new") == std::string::npos) { + for (auto& pt : g.at(t->op)) + extern_mods[op].insert(pt->op->name); + } + } + } + } + } for (Stage stage : sch->stages) { if (dev.count(stage->op.get())) @@ -254,23 +344,73 @@ Array PostDFSSplit( // are required to form an enclosed subgraph Array> inputs, outputs; std::vector merged_ops; - auto subgraph = ExtractSubGraph(roots, g, sch, dev, + + // not create aggregate for extern module + // note: the subgraph does not exactly descibe the compute flow + // e.g. if there are some otehr super stages modifying the tensor + // before we use the tensor, the read graph does not capture that + auto subgraph = ExtractSubGraph(roots, g, sch, dev, extern_mods, boundary, inputs, outputs, merged_ops); + // for (auto& op : subgraph) LOG(INFO) << op; Array post_order; + Array bounded_ops; for (Operation op : roots) { - PostDFSSplit(op, g, &visited, &post_order, dev, subgraph); + if (extern_mods.count(op)) { + // create op array of extern module (from .new to super stage root) + // return inner ops inside extern module (must be bounded by .new ops) + bool dev_scope = false; + for (auto& input : extern_mods.at(op)) { + if (input.find(".new") != std::string::npos) + dev_scope = true; + } + + if (dev_scope) { + std::unordered_set visited_ops; + PostDFSBoundary(op, g, &visited_ops, &bounded_ops); + } else { // in host scope, for sim + PostDFSSplit(op, g, &visited, &post_order, dev, subgraph); + } + } else { + // without extern module (subgraph & post_order) + // return post_order with op out of subgraph + PostDFSSplit(op, g, &visited, &post_order, dev, subgraph); + } } // op array index to insert subgraph if (bound_index > 0) { Array results; for (size_t k = 0; k < post_order.size(); k++) { - if (k == post_order.size() - bound_index + 1) { - for (auto& sub_op : subgraph) - results.push_back(sub_op); - for (auto& sub_op : merged_ops) - results.push_back(sub_op); + // scope switching right after index-th last op + if (k == post_order.size() - (bound_index - 1)) { + + if (extern_mods.size() == 0) { + for (auto& sub_op : subgraph) + results.push_back(sub_op); + for (auto& sub_op : merged_ops) + results.push_back(sub_op); + + // replace the modfied ops with extern module + // i.e. ops in the keys of corresponding module + } else { + CHECK(bounded_ops.size() > 0); + for (Operation op : bounded_ops) { + results.push_back(op); + } + for (auto& sub_op : subgraph) { + bool found_in_module = false; + for (auto& kv : extern_mods) { + if (kv.second.count(sub_op->name)) { + found_in_module = true; + } + } + if (!found_in_module) + results.push_back(sub_op); + } + for (auto& sub_op : merged_ops) + results.push_back(sub_op); + } } Operation op = post_order[k]; results.push_back(op); @@ -279,6 +419,7 @@ Array PostDFSSplit( << "schedule op array error " << results; return results; } + return post_order; } diff --git a/tvm/src/template/vitis/Makefile b/tvm/src/template/vitis/Makefile index 4b767e0af..74475d16f 100755 --- a/tvm/src/template/vitis/Makefile +++ b/tvm/src/template/vitis/Makefile @@ -67,10 +67,10 @@ endif # Kernel compiler global settings CLFLAGS += -t $(TARGET) --platform $(DEVICE) --save-temps -ifneq ($(TARGET), hw) - CLFLAGS += -g -endif - +# ifneq ($(TARGET), hw) +# CLFLAGS += -g +# endif +CLFLAGS += --config config.ini EXECUTABLE = host @@ -95,7 +95,7 @@ build: $(BINARY_CONTAINERS) # Building kernel $(TEMP_DIR)/kernel.xo: kernel.cpp mkdir -p $(TEMP_DIR) - $(VPP) $(CLFLAGS) --temp_dir $(TEMP_DIR) -c -k test -I'$(