From e322a080e755f6bc1178965c61c30ef7afe524e5 Mon Sep 17 00:00:00 2001 From: Shaojie Xiang Date: Sun, 3 May 2020 18:49:50 -0400 Subject: [PATCH] [API] Data Movement Support Enhancement (#171) * [API] Enable building a function directly from IR (#133) * add a pass for building a function directly from IR * remove redundant print statement * [API] Enable select API to accept Python expressions * [API] Fixed incorrect casting for select in CastRemover * [API][Backend] Streaming and OpenCL Backends (#138) * add sdaccel, aocl for heterocl * fpga * Create codeanalys_openclc.cc * Update target.py * run * can run successfully * Create codegen_opencl.cc * now * all done * Update codegen_sdaccel.cc * Update codegen_sdaccel.cc * modified: python/heterocl/tvm/target.py * new file: samples/ppac/gemm/csrcPrint.py new file: samples/ppac/gemm/data.py new file: samples/ppac/gemm/gemm_ppac.py new file: samples/ppac/gemm/headcode.txt new file: samples/ppac/gemm/ppac_common.py new file: tvm/src/codegen/build_ppac.cc new file: tvm/src/codegen/codegen_rv64_ppac.cc new file: tvm/src/codegen/codegen_rv64_ppac.h * all * remove tvm check code from kernel * opencl-backend * all * fix ppac module build * support ppac MVPb pragma * fix ignoring ppac pragma in cpu backend * opencl-backend * aocl-backend * move ppac codegen to ppac folder; fix argument name with merlinc analyser * discard the new for-loop type; include ppac in hlib * discard some previous changes * Use int64_t as return type of GeMM on ppac * [add] codegenc kernedef + stream init * [add] var_shape_map * [update] kerneldef struct shape * [update] use noderef and restore * [fix] return op * [add] hcl device & kernelstmt printer * [fix] def workaround * [update] stream example * [add] stream expr & stmt ir * [fix] kernel arg location for stream * opt1 * opencl-general * new-version * no bug * a * test+unroll+pipeline * pragma * new * type has fixed * new_test * test_reorder_split_fuse * target * order * simplified by rui * analysis * bug fixed * [delete] all of the code about opencl * [ADD] new opencl back-end including xilinx & intel * fixed __local * fixed data_type for xilinx opencl * add makefile for SDAccel_runtime * add the runtime for sdaccel * create the sdaccel host * fixed the indent problem partly * test the zhang-05 server * add indent to the host.cpp * automatically generate makefile * delete common folder from opencl * add shmat to sdaccel runtime * fixed bug for sdaccel runtime seg fault * fixed the bug of host.cpp multiple * fixed host.cpp multiple bug * fixed endif for makefile * modify sdaccel_sw_emu -> sdaccel_csim * fix the __local and __global for intel opencl back-end * Fix the arbitrary integer precision for aocl * [add] ir visitor & functor for codegen * [add] aocl stream codegen * [add] aocl stream support * [fix] aocl type conversion * [fix] aocl channel syntax * [add] sch.stream_to * [fix] add stream annotation * [add] host device codegen * [add] stream ir mutator * [Add] Interface prag,a for SDx sim * [add] host xcel codegen * [update] build interface * [update] new build interface * [fix] temp update * [update] stream example * [add] rocc-ppac sim * [rm] submodule * [update] rocc ppac hlib * [add] unified sim & kernel updater * re-organize build common util * [update] stream in codegen c * [update] codegen construct for streaming * [update] code post-processing * [fix] test cases * [fix] python compatibility * [update] future * [fix] metaclass * [fix] test import issue * Revert "[API][Backend] Streaming and OpenCL Backends (#138)" (#139) This reverts commit 2c7534468af660b857878aa38bb49b02d19d4af3. * [API] Remove support for Python 2 (#143) * remove support for Python 2 * switch from Python 3.7 to 3.6 * [Backend] Fix LLVM CodeGen for intrinsics (#147) * Fix llvm codegen for instrinsic log, pow, and sqrt * fix test case * [Backend] Fix LLVM power intrinsic with large integer (#151) * fix llvm power with large integer * fix test case * [API] Fix Wrong Index Calculation in API reuse_at (#156) * test cases * rename * systolic * self loopback * fix test * fork and join * fix host only * sobel * [API] Adding printing function to HeteroCL (#178) * initial attempt for hcl.print * enable better printing * finish a version fo hcl.print * add tests * [Backend] Fixed Incorrect Behavior When Casting Constants to Very Long Int (#179) * fixed incorrect behavior in Halide * add test * map reduce example * reshape check * [backend] Add LLVM 9.0 support (#182) * [API][Backend] Fix hcl.print with UInt supported (#184) * memory * clean up stream type * codegen update * hbm support * host codegen update * fix auto-merge issue * fix extern ip * fix hls ip * update * if * join api * comment Co-authored-by: Yi-Hsiang (Sean) Lai Co-authored-by: HZ Chen <982270930@qq.com> --- Makefile.config | 4 +- hlib/python/hlib/ip/fft.py | 12 +- hlib/python/hlib/op/extern.py | 21 +- pkgs/Makefile.pkg.config | 2 +- python/heterocl/api.py | 101 ++++- python/heterocl/compute_api.py | 1 + python/heterocl/devices.py | 119 +++++- python/heterocl/schedule.py | 134 ++++++- python/heterocl/tvm/build_module.py | 11 +- python/heterocl/tvm/expr.py | 5 +- python/heterocl/tvm/schedule.py | 49 ++- python/heterocl/tvm/stmt.py | 4 + samples/mapreduce/mapreduce.py | 56 +++ samples/sobel/sobel.py | 91 ----- samples/sobel/sobel_main.py | 56 +++ samples/sobel/sobel_stream.py | 61 +++ samples/systolic_array/systolic_array_main.py | 65 +++ .../systolic_array/systolic_array_stream.py | 60 +++ .../systolic_array/systolic_array_vitis.py | 65 +++ tests/test_api_print.py | 41 ++ tests/test_api_print_cases/print_expr.py | 97 +++++ tests/test_api_print_cases/print_number.py | 18 + tests/test_api_print_cases/print_tensor_1D.py | 25 ++ tests/test_api_print_cases/print_tensor_2D.py | 30 ++ tests/test_dsl_basic.py | 107 +++-- tests/test_dtype.py | 17 + tests/test_schedule_stream.py | 278 +++++++++++-- tvm/HalideIR/src/ir/Expr.h | 16 +- tvm/HalideIR/src/ir/IR.cpp | 23 +- tvm/HalideIR/src/ir/IR.h | 24 +- tvm/HalideIR/src/ir/IRMutator.cpp | 19 + tvm/HalideIR/src/ir/IRMutator.h | 1 + tvm/HalideIR/src/ir/IRPrinter.cpp | 11 + tvm/HalideIR/src/ir/IRVisitor.cpp | 13 + tvm/HalideIR/src/ir/IRVisitor.h | 2 + tvm/include/tvm/ir.h | 2 + tvm/include/tvm/ir_functor_ext.h | 2 + tvm/include/tvm/ir_mutator.h | 1 + tvm/include/tvm/ir_visitor.h | 1 + tvm/include/tvm/schedule.h | 9 +- tvm/src/api/api_ir.cc | 1 + tvm/src/api/api_lang.cc | 16 +- tvm/src/codegen/build_common.cc | 13 +- tvm/src/codegen/build_util.cc | 18 +- tvm/src/codegen/build_util.h | 3 +- tvm/src/codegen/codegen_c.cc | 52 ++- tvm/src/codegen/codegen_c.h | 5 +- tvm/src/codegen/hlsc/codegen_vhls.cc | 66 ++- tvm/src/codegen/hlsc/codegen_vhls.h | 1 + tvm/src/codegen/llvm/codegen_llvm.cc | 36 ++ tvm/src/codegen/llvm/codegen_llvm.h | 2 + tvm/src/codegen/llvm/llvm_module.cc | 32 ++ tvm/src/codegen/merlinc/codeanalys_merlinc.cc | 1 - tvm/src/codegen/opencl/codegen_aocl.cc | 30 +- tvm/src/codegen/opencl/codegen_aocl_host.cc | 4 +- tvm/src/codegen/opencl/codegen_xocl.cc | 11 - tvm/src/codegen/opencl/codegen_xocl_host.cc | 115 ++++-- tvm/src/lang/ir.cc | 1 + tvm/src/pass/ir_mutator.cc | 13 +- tvm/src/pass/ir_visitor.cc | 9 +- tvm/src/pass/storage_flatten.cc | 4 +- tvm/src/pass/stream_inference.cc | 376 ++++++++++++++++-- tvm/src/schedule/schedule_dataflow_rewrite.cc | 352 +++++++++++----- tvm/src/schedule/schedule_lang.cc | 1 - tvm/src/schedule/schedule_ops.cc | 20 +- tvm/src/schedule/schedule_reorder.cc | 332 ++++++++++++---- 66 files changed, 2554 insertions(+), 614 deletions(-) create mode 100644 samples/mapreduce/mapreduce.py delete mode 100644 samples/sobel/sobel.py create mode 100644 samples/sobel/sobel_main.py create mode 100644 samples/sobel/sobel_stream.py create mode 100644 samples/systolic_array/systolic_array_main.py create mode 100644 samples/systolic_array/systolic_array_stream.py create mode 100644 samples/systolic_array/systolic_array_vitis.py create mode 100644 tests/test_api_print.py create mode 100644 tests/test_api_print_cases/print_expr.py create mode 100644 tests/test_api_print_cases/print_number.py create mode 100644 tests/test_api_print_cases/print_tensor_1D.py create mode 100644 tests/test_api_print_cases/print_tensor_2D.py diff --git a/Makefile.config b/Makefile.config index 60d1cfd3e..7a99b6224 100644 --- a/Makefile.config +++ b/Makefile.config @@ -3,7 +3,9 @@ LLVM_CONFIG = $(shell which llvm-config-4.0 2>/dev/null || \ which llvm-config-5.0 2>/dev/null || \ which llvm-config-6.0 2>/dev/null || \ - which llvm-config-7.0 2>/dev/null) + which llvm-config-7.0 2>/dev/null || \ + which llvm-config-8.0 2>/dev/null || \ + which llvm-config-9.0 2>/dev/null) # set your own path to cmake CMAKE_CONFIG = $(shell which cmake 2> /dev/null) diff --git a/hlib/python/hlib/ip/fft.py b/hlib/python/hlib/ip/fft.py index 5b5d34bff..4383641e8 100644 --- a/hlib/python/hlib/ip/fft.py +++ b/hlib/python/hlib/ip/fft.py @@ -41,20 +41,20 @@ def single_fft_hls(X_real, X_imag, F_real=None, F_imag=None, name=None): hcl.update(F_imag, lambda i: X_imag[Table[i]], name='F_imag_update') with hcl.Stage("Out"): - one = hcl.scalar(1, dtype="int32") + one = hcl.scalar(1, dtype="int32", name="one") with hcl.for_(0, num_stages) as stage: DFTpts = one[0] << (stage + 1) numBF = DFTpts / 2 e = -2 * np.pi / DFTpts - a = hcl.scalar(0) + a = hcl.scalar(0, "a") with hcl.for_(0, numBF) as j: - c = hcl.scalar(hcl.cos(a[0])) - s = hcl.scalar(hcl.sin(a[0])) + c = hcl.scalar(hcl.cos(a[0]), name="cos") + s = hcl.scalar(hcl.sin(a[0]), name="sin") a[0] = a[0] + e with hcl.for_(j, L + DFTpts - 1, DFTpts) as i: i_lower = i + numBF - temp_r = hcl.scalar(F_real[i_lower] * c - F_imag[i_lower] * s) - temp_i = hcl.scalar(F_imag[i_lower] * c + F_real[i_lower] * s) + temp_r = hcl.scalar(F_real[i_lower] * c - F_imag[i_lower] * s, "temp_r") + temp_i = hcl.scalar(F_imag[i_lower] * c + F_real[i_lower] * s, "temp_i") F_real[i_lower] = F_real[i] - temp_r[0] F_imag[i_lower] = F_imag[i] - temp_i[0] F_real[i] = F_real[i] + temp_r[0] diff --git a/hlib/python/hlib/op/extern.py b/hlib/python/hlib/op/extern.py index 74c1173ad..303c1e01b 100644 --- a/hlib/python/hlib/op/extern.py +++ b/hlib/python/hlib/op/extern.py @@ -82,20 +82,17 @@ def with_attrs(f): # create hls ip invoked within the top function -def create_hls_ip(op, name, args, ip_type="hls", path=None): +def create_hls_ip(stage, name, args, ip_type="hls", path=None): # must be called within a superstage - assert Stage._current - curr = Schedule.last_stages[-1] - input_ops = [i._op for i in curr.input_stages] - output_bufs = [curr._buf] + input_ops = [i._op for i in stage.input_stages] + output_bufs = [stage._buf] # include external ip files -def create_extern_module(op, dicts, ip_type="hls", path=None): - curr = Schedule.last_stages[-1] - input_ops = [i._op for i in curr.input_stages] - input_bufs = [i._buf for i in curr.input_stages] - output_bufs = [curr._buf] +def create_extern_module(stage, dicts, ip_type="hls", path=None): + input_ops = [i._op for i in stage.input_stages] + input_bufs = [i._buf for i in stage.input_stages] + output_bufs = [stage._buf] # input and output arguments assert "args" in dicts.keys() @@ -104,7 +101,7 @@ def create_extern_module(op, dicts, ip_type="hls", path=None): annotate_dict["input::" + name] = dtype del annotate_dict["args"] - op = op._op.op + op = stage._op.op assert ip_type in ["rtl", "hls", "host"] body = _make.ExternModule( "top", _make.StringImm(ip_type), op.body, @@ -113,5 +110,5 @@ def create_extern_module(op, dicts, ip_type="hls", path=None): new_op = _ExternOp( op.name, op.tag, op.axis, input_ops, input_bufs, output_bufs, body) - curr._op = new_op.output(0) + stage._op = new_op.output(0) diff --git a/pkgs/Makefile.pkg.config b/pkgs/Makefile.pkg.config index 0704af435..be0c4bd38 100644 --- a/pkgs/Makefile.pkg.config +++ b/pkgs/Makefile.pkg.config @@ -21,7 +21,7 @@ else CMAKE_OK = yes endif -ifeq ("", "$(findstring $(LLVM_VERSION), 40 50 60 70)") +ifeq ("", "$(findstring $(LLVM_VERSION), 40 50 60 70 80 90)") DIRS += llvm endif diff --git a/python/heterocl/api.py b/python/heterocl/api.py index b0c9351b5..af177b31e 100644 --- a/python/heterocl/api.py +++ b/python/heterocl/api.py @@ -1,13 +1,14 @@ """This module contains all HeteroCL APIs""" #pylint: disable=no-member +import numbers from ordered_set import OrderedSet from .tvm.build_module import build as _build, lower as _lower -from .tvm.api import convert +from .tvm.api import convert, _IterVar from .tvm import _api_internal as tvm_api from .tvm import schedule as _schedule -from .tvm import make as _make from .tvm import call_intrin -from .tensor import Scalar, Tensor +from .tvm import expr as _expr, stmt as _stmt, make as _make +from .tensor import Scalar, Tensor, TensorSlice from .schedule import Stage, Schedule from .scheme import Scheme from . import util @@ -142,8 +143,9 @@ def algo(A): """ if not isinstance(inputs, list): inputs = [inputs] - func(*inputs) - for op in Schedule.stage_ops: + with Stage("_top") as top: + func(*inputs) + for op in top.substages: func.__setattr__(op.name, op) return Scheme(inputs, func) @@ -201,7 +203,8 @@ def algo(A): Schedule.stage_ops = [] Schedule.last_stages = OrderedSet([]) # execute the algorithm - ret = func(*inputs) + with Stage("_top") as top: + ret = func(*inputs) # append the output tensors to the input list if ret is not None: if isinstance(ret, tuple): @@ -209,7 +212,8 @@ def algo(A): else: inputs.append(ret) # let each stage be an attribute of the function - for op in Schedule.stage_ops: + for op in top.substages: + #op = stage._op func.__setattr__(op.name, op) t = Schedule.last_stages ops = [t_._op.op for t_ in t] @@ -360,3 +364,86 @@ def select(cond, true, false): Expr """ return _make.Select(convert(cond), convert(true), convert(false)) + +def print(vals, format=""): + """Print a HeteroCL object. + + Parameters + ---------- + vals : Expr or list of Expr + The values to be printed + + format : string, optional + The printing format similar to printf + + Returns + ------- + None + """ + if not isinstance(vals, (tuple, list)): + vals = [vals] + + def get_format(val): + if isinstance(val, (TensorSlice, Scalar, _expr.Expr)): + if (util.get_type(val.dtype)[0] == "int" + or util.get_type(val.dtype)[0] == "uint"): + return "%lld" + else: + return "%f" + elif isinstance(val, int): + return "%d" + elif isinstance(val, float): + return "%f" + + def print_tensor(val, ivs, i, ndim): + if i == 0: #inner-most + iv = ivs[ndim-1] + stmt = _make.Print([], "[") + value = val[tuple(ivs)] + body = _make.Print([value], get_format(value)) + ite = _make.IfThenElse(iv < iv.dom.extent-1, + _make.Print([], ", "), + _make.Evaluate(0)) + body = _make.Block(body, ite) + loop = _make.For(iv.var, iv.dom.min, iv.dom.extent, 0, 0, body) + stmt = _make.Block(stmt, loop) + stmt = _make.Block(stmt, _make.Print([], "]")) + return stmt + else: + iv = ivs[ndim-1-i] + stmt = _make.Print([], "[") + body = print_tensor(val, ivs, i-1, ndim) + ite = _make.IfThenElse(iv < iv.dom.extent-1, + _make.Print([], ",\n"), + _make.Evaluate(0)) + body = _make.Block(body, ite) + loop = _make.For(iv.var, iv.dom.min, iv.dom.extent, 0, 0, body) + stmt = _make.Block(stmt, loop) + stmt = _make.Block(stmt, _make.Print([], "]")) + return stmt + + def print_val(val): + stage = Stage.get_current() + if isinstance(val, (Scalar, _expr.Expr, numbers.Number)): + stage.emit(_make.Print([val], get_format(val) + "\n")) + elif isinstance(val, TensorSlice) \ + and len(val.indices) == len(val.tensor.shape): + stage.emit(_make.Print([val], get_format(val) + "\n")) + else: # we are dealing with tensors + nshape = len(val.tensor.shape) + ndim = nshape + if isinstance(val, TensorSlice): + ndim = nshape - len(val.indices) + args = ["print_"+str(n) for n in range(0, ndim)] + ivs = [_IterVar((0, val.tensor.shape[nshape-n-1]), args[n], 0) \ + for n in range(0, ndim)] + import builtins + stage.emit(print_tensor(val, ivs, ndim-1, ndim)) + stage.emit(_make.Print([], "\n")) + + if format == "": + for val in vals: + print_val(val) + else: + stage = Stage.get_current() + stage.emit(_make.Print(vals, format)) diff --git a/python/heterocl/compute_api.py b/python/heterocl/compute_api.py index 05539e26f..8d178cab9 100644 --- a/python/heterocl/compute_api.py +++ b/python/heterocl/compute_api.py @@ -670,6 +670,7 @@ def assign_val(*indices): return compute(tuple(new_shape), assign_val, name, dtype) + def reduce_axis(lower, upper, name=None): """Create a reduction axis for reduction operations. diff --git a/python/heterocl/devices.py b/python/heterocl/devices.py index 3f1333b4d..16772fd79 100644 --- a/python/heterocl/devices.py +++ b/python/heterocl/devices.py @@ -57,6 +57,65 @@ def __repr__(self): "llvm" : tool("llvm", *option_table["llvm"]) } +class Memory(object): + """The base class for memory modules""" + def __init__(self, types, cap, channels): + self.types = types + self.capacity = cap + self.channels = channels + self.port = 0 + + def __getitem__(self, key): + if not isinstance(key, int): + raise DeviceError("port must be integer") + if key > self.channels: + raise DeviceError("port must be within \ + the channel range %d", self.channels) + self.port = key + return self + + def __str__(self): + return str(self.types) + ":" + \ + str(self.port) + +class DRAM(Memory): + def __init__(self, cap=16, channels=4): + super(DRAM, self).__init__("DRAM", cap, channels) + +class HBM(Memory): + def __init__(self, cap=32, channels=32): + super(HBM, self).__init__("HBM", cap, channels) + +class PLRAM(Memory): + def __init__(self, cap=32, channels=32): + super(PLRAM, self).__init__("PLRAM", cap, channels) + +class DevMediaPair(object): + def __init__(self, dev, media): + self.xcel = dev + self.memory = media + + @property + def dev(self): + return self.xcel + + @property + def media(self): + return self.memory + + def __getitem__(self, key): + if not isinstance(key, int): + raise DeviceError("port must be integer") + if key > self.media.channels: + raise DeviceError("port must be within \ + the channel range %d", self.media.channels) + self.media.port = key + return self + + def __str__(self): + return str(self.xcel) + ":" + \ + str(self.media) + class Device(object): """The base class for all device types @@ -69,18 +128,23 @@ class Device(object): model: str Model of device to place date """ - def __init__(self, types, vendor, - model, **kwargs): + def __init__(self, types, vendor, model, **kwargs): self.vendor = vendor self.types = types self.model = model - self.impls = {"lang": ""} + self.impls = { "lang": "" } for key, value in kwargs.items(): self.impls[key] = value + # connect to ddr by default + self.storage = { "ddr" : DRAM() } def __getattr__(self, key): """ device hierarchy """ - return self.impls[key] + if key in self.impls.keys(): + return self.impls[key] + else: # return attached memory + media = self.storage[key] + return DevMediaPair(self, media) def set_lang(self, lang): assert lang in \ @@ -89,6 +153,7 @@ def set_lang(self, lang): self.impls["lang"] = lang return self + class CPU(Device): """cpu device with different models""" def __init__(self, vendor, model, **kwargs): @@ -97,6 +162,7 @@ def __init__(self, vendor, model, **kwargs): assert "cpu_" + model in model_table[vendor], \ model + " not supported yet" super(CPU, self).__init__("CPU", vendor, model, **kwargs) + def __repr__(self): return "cpu-" + self.vendor + "-" + str(self.model) + \ ":" + self.impls["lang"] @@ -109,6 +175,11 @@ def __init__(self, vendor, model, **kwargs): assert "fpga_" + model in model_table[vendor], \ model + " not supported yet" super(FPGA, self).__init__("FPGA", vendor, model, **kwargs) + # attach supported memory modules + if vendor == "xilinx" and "xcvu19p" in model: + self.storage["hbm"] = HBM() + self.storage["plram"] = PLRAM() + def __repr__(self): return "fpga-" + self.vendor + "-" + str(self.model) + \ ":" + self.impls["lang"] @@ -121,6 +192,7 @@ def __init__(self, vendor, model, **kwargs): assert "gpu_" + model in model_table[vendor], \ model + " not supported yet" super(GPU, self).__init__("GPU", vendor, model, **kwargs) + def __repr__(self): return "gpu-" + self.vendor + "-" + str(self.model) + \ ":" + self.impls["lang"] @@ -177,7 +249,7 @@ def __getattr__(cls, key): host = devs[0].set_lang("c") xcel = None else: # unsupported device - raise DeviceError("not supported") + raise DeviceError(key + " not supported") tool = tool_table[key] return cls(key, devs, host, xcel, tool) @@ -213,6 +285,10 @@ def config(self, compile=None, mode=None, backend=None): "not support backend lang " + backend self.xcel.lang = backend + # check correctness of device attribute + if self.host.lang == "": + self.host.lang = "xocl" + def __getattr__(self, key): """ return tool options """ return self.tool.__getattr__(key) @@ -233,6 +309,39 @@ def __repr__(self): str(self.host) + " : " + \ str(self.xcel) + ")" + @classmethod + def custom(cls, config): + assert isinstance(config, dict) + assert "host" in config.keys() + if "xcel" not in config.keys(): + print("\33[1;34m[HeteroCL Warning]\33[0m" + "empty xcel slots") + host = config["host"] + xcel = None if not "xcel" in config.keys() else config["xcel"] + devs = [ host ] + xcel + # TODO: support multiple xcel devs + if isinstance(xcel, list): + xcel = xcel[0] + tool = None + return cls("custom", devs, host, xcel, tool) + + +class dev(object): + def __init__(self, types, vendor, model): + self.types = types + + @classmethod + def cpu(cls, vendor, model): + return CPU(vendor, model) + + @classmethod + def fpga(cls, vendor, model): + return FPGA(vendor, model) + + @classmethod + def gpu(cls, vendor, model): + return GPU(vendor, model) + + def device_to_str(dtype): """Convert a device type to string format. diff --git a/python/heterocl/schedule.py b/python/heterocl/schedule.py index 0fbfc33ca..22b016305 100644 --- a/python/heterocl/schedule.py +++ b/python/heterocl/schedule.py @@ -12,7 +12,7 @@ from .tvm._api_internal import _ExternOp from .debug import DSLError, APIError from . import util -from .devices import Device +from .devices import Device, DevMediaPair class Schedule(object): """Create a compute schedule. @@ -98,6 +98,34 @@ def gen_graph(stage, y): return graph, op_map + + def subgraph(self, inputs, outputs): + assert len(inputs) > 0, "empty inputs" + assert len(outputs) > 0, "empty outputs" + graph, op_map = self.dataflow_graph() + + # check availability + inputs = [ _.name.replace(".new", "") for _ in inputs ] + outputs = [ _.name.replace(".new", "") for _ in outputs ] + + # from root to parents + stack = outputs + subgraph = set() + while len(stack) > 0: + op = stack.pop() + if op in subgraph: continue + subgraph.add(op) + if op not in graph.nodes: + op = "_top." + op + assert op in graph.nodes, \ + "cannot find node " + op + " in " + str(graph.nodes) + for _ in graph.predecessors(op): + if not op in inputs: + stack.append(_) + + return op_map + + def reuse_at(self, target, parent, axis, name=None): """Create a reuse buffer reusing the output of current stage @@ -136,9 +164,44 @@ def reuse_at(self, target, parent, axis, name=None): name = target.name + ".reuse" return self.sch.reuse_at(target, parent, axis, name) + + def join(self, srcs, dest=None): + """ join multiple tensors to single dest """ + assert len(srcs) > 0, "joined tensors should be " + \ + "collectde from more than one srcs" + + # create channels and collector stage + if dest is not None: + if isinstance(dest, tuple): + dest, target = dest + dest = self[dest] + elif isinstance(dest, Stage): + target = dest._op + elif isinstance(dest, tuple): + src, target = dest + else: # target tensor + target = dest.tensor + else: target = dest + + for src in srcs: + if isinstance(src, tuple): + src, tensor = src + assert tensor == target, + \ + "inconsistent tensor joining" + self.sch.join(target, dest, self[src]) + + + def fork(self, tensor, dests, axis=0): + """ fork tensor to multiple dests """ + assert len(dests) > 0, "forked tensor should be " + \ + "broadcast to more than one dest" + # dest as tvm stages + for dest in dests: + self.to(tensor, self[dest]) + + def to(self, tensors, dst, src=None, axis=0, - stream_type=_expr.StreamExpr.Channel, - depth=1, name=None): + stream_type=_expr.StreamExpr.FIFO, depth=1, name=None): """Stream a list of Tensors to dst devices Parameters @@ -166,25 +229,42 @@ def to(self, tensors, dst, src=None, axis=0, tensors = [tensors] for tensor in tensors: try: - target = tensor.tensor - except (AttributeError, ValueError): - try: + if isinstance(tensor, Stage): target = tensor._op - except AttributeError: - target = tensor + # unpack tuple of src stage and tensor + elif isinstance(tensor, tuple): + src, target = tensor + # from hcl stage to tvm stage + src = self.__getitem__(src) + else: # target tensor + target = tensor.tensor + except (AttributeError, ValueError): + target = tensor + + # convert hcl stage + try: dst = self[dst] + except: pass - # record the placement op.output if src is None: - if isinstance(dst, Device): - self.placement[target] = dst - else: # auto-complete + # move to device + if isinstance(dst, Device) or \ + isinstance(dst, DevMediaPair): + if axis == 0: + self.placement[target] = dst + else: + assert isinstance(tensor, Stage) + target = self[tensor] + + else: # inter-stage src = self[tensor] # target can be stage or tensor ret = self.sch.to(target, dst, src, axis, stream_type, depth) rets.append(ret) - return rets + + if len(rets) == 1: return rets[0] + else: return rets def partition(self, target, partition_type=_stmt.Partition.Complete, dim=0, factor=0): """Partition a Tensor into smaller Tensors or even registers @@ -331,6 +411,7 @@ def __init__(self, name=None, dtype=None, shape=()): self.ret_dtype = None self.for_level = 0 self.for_ID = 0 + self.substages = [] # Attributes for cross-stage relation self.input_stages = set([]) self.lhs_tensors = set([]) @@ -384,8 +465,10 @@ def __exit__(self, ptype, value, trace): superstage.var_dict[self.name] = self # update prefix self.name_with_prefix = superstage.name_with_prefix + "." + self.name - - else: # otherwise update the list of stages globally + # update superstage's substages + superstage.substages.append(self) + # Otherwise update the list of stages globally + else: Schedule.stage_ops.append(self) Schedule.last_stages.add(self) Schedule.last_stages -= self.input_stages @@ -395,7 +478,26 @@ def __repr__(self): def __getattr__(self, name): try: - return self.var_dict[name] + if name in self.var_dict: + return self.var_dict[name] + else: + # return stage and target tensor op + for tensor in self.lhs_tensors: + if tensor.name == name: + return (self, tensor._tensor) + # check tensors in input stages + for stage in self.input_stages: + if stage.name == name: + return (self, stage._op) + # check tensors in input_stage.lhs + for stage in self.input_stages: + lhs = stage.lhs_tensors + for tensor in lhs: + if tensor.name == name: + return (self, tensor._tensor) + raise ValueError("Member " + name + \ + " not found in " + str(self.lhs_tensors) + " or " + \ + str(self.input_stages)) except KeyError: raise ValueError("Uknown member " + name + " of " + self.name) diff --git a/python/heterocl/tvm/build_module.py b/python/heterocl/tvm/build_module.py index 97ebabf68..561c2128a 100755 --- a/python/heterocl/tvm/build_module.py +++ b/python/heterocl/tvm/build_module.py @@ -40,7 +40,7 @@ def run_process(cmd, pattern=None, env=None): return out.decode("utf-8") @register_func -def tvm_callback_exec_evaluate(platform, mode): +def tvm_callback_exec_evaluate(platform, mode, host_only): # perform simulation and extract qor qor = dict() @@ -107,6 +107,7 @@ def tvm_callback_exec_evaluate(platform, mode): cmd = "cd project; " + \ "XCL_EMULATION_MODE=sw_emu ./host build_dir" + \ ".sw_emu." + device + "/kernel.xclbin" + if host_only: cmd = "cd project; ./host" out = run_process(cmd) elif platform == "aocl": @@ -121,7 +122,7 @@ def tvm_callback_exec_evaluate(platform, mode): return str(qor) @register_func -def copy_and_compile(platform, mode, backend, cfg): +def copy_and_compile(platform, mode, backend, host_only, cfg): """ create necessary files and compile into binary """ path = api.__file__ path = os.path.join(path[0:path.find("python")], "tvm/src/template/") @@ -229,9 +230,11 @@ def copy_and_compile(platform, mode, backend, cfg): assert "XDEVICE" in os.environ, \ "vitis platform info missing" os.system("cp " + path + "vitis/* project/") - cmd = "cd project; make clean;" - cmd += "make all TARGET=sw_emu DEVICE=$XDEVICE" + + if not host_only: + cmd += "make all TARGET=sw_emu DEVICE=$XDEVICE" + else: cmd += "make host" out = run_process(cmd) return "success" diff --git a/python/heterocl/tvm/expr.py b/python/heterocl/tvm/expr.py index d1ea4ae75..af23b8c30 100644 --- a/python/heterocl/tvm/expr.py +++ b/python/heterocl/tvm/expr.py @@ -385,6 +385,5 @@ class KernelExpr(Expr): @register_node class StreamExpr(Expr): - Channel = 0 - Pipe = 1 - FIFO = 2 + FIFO = 0 + DoubleBuffer = 1 diff --git a/python/heterocl/tvm/schedule.py b/python/heterocl/tvm/schedule.py index 7fe630e5d..15b333d92 100644 --- a/python/heterocl/tvm/schedule.py +++ b/python/heterocl/tvm/schedule.py @@ -3,7 +3,7 @@ from ._ffi.base import string_types from ._ffi.node import NodeBase, register_node from ._ffi.function import _init_api -from ..devices import Device +from ..devices import Device, DevMediaPair from . import _api_internal from . import tensor as _tensor from . import expr as _expr @@ -334,36 +334,42 @@ def reuse_at(self, target, parent, axis, name): def partition(self, target, partition_type, dim, factor): return _api_internal._SchedulePartition(self, target, dim, factor, partition_type) - def to(self, tensor, dst, src, index=0, - types=_expr.StreamExpr.Channel, depth=1): + def join(self, target, dst, src, type=_expr.StreamExpr.FIFO, depth=1): + """ join multiple writes to target tensor """ + return _api_internal._ScheduleJoin(self, target, src, dst, type, depth) + + def to(self, tensor, dst, src, axis=0, + type=_expr.StreamExpr.FIFO, depth=1): """ Stream data to devices or on-chip module Parameters ---------- tensor : list of Tensors Tensor to be streamed. - dst : hcl device or dst stage - The device or module for streaming - type : channel type - The streaming type (e.g. fifo or pipe) Returns ------- Tensor """ # create producer and consumer for stream - if isinstance(dst, Device): + if isinstance(dst, Device) or isinstance(dst, DevMediaPair): + pair = False if isinstance(dst, Device) else True + media = dst.media if pair else dst.ddr.media dst = 1 if 'fpga' in str(dst) else 0 if isinstance(tensor, _Stage): # move data within stage return _api_internal._ScheduleInStageMove( - self, tensor, dst, types, depth, index) + self, tensor, dst, type, depth, axis) else: # move placeholder or extern op assert isinstance(tensor, _tensor._Tensor), \ "input " + str(tensor) + " not a tensor" - return _api_internal._ScheduleMove( - self, tensor, dst, - types, depth, index) + if media.types == "DRAM": dev = 0 + else: # move to hetero-storage-dev + dev = 1 if media.types == "HBM" else 2 + + dev_port = [dev, media.port] + return _api_internal._ScheduleMove(self, tensor, src, dst, + type, depth, dev_port) else: # inter-stage streaming assert isinstance(dst, _Stage), "dst not a stage " @@ -379,7 +385,7 @@ def to(self, tensor, dst, src, index=0, index = index + 1 if len(match) > 1: - names = [str(n).replace(dst.op.name + ".", "") for n in dst.op.body.args] + names = [str(n).replace("_top." + dst.op.name + ".", "") for n in dst.op.body.args] assert str(tensor.op.name) in names, \ "unknwon arg, please specify id " + \ str(names) + ":" + str(tensor.op.name) @@ -395,27 +401,28 @@ def to(self, tensor, dst, src, index=0, index = index + 1 if len(match) > 2: # use name for matching - names = [str(n).replace(src.op.name + ".", "") + names = [str(n).replace("_top." + src.op.name + ".", "") for n in src.op.body.args] assert str(tensor.op.name) in names, \ "unknwon arg, please specify id" + \ str(names) + ":" + str(tensor.op.name) match = [match[0], names.index(str(tensor.op.name))] + # stream between two kernel defs _api_internal._ScheduleStream( self, tensor, dst, src, - match, types, depth, "link") + match, type, depth, "link") - else: # multi-cast from local buffer to kernel + else: # from local buffer to kernel _api_internal._ScheduleMoveToStage( self, tensor, dst, match[0], - types, depth, "stream") + type, depth, "stream") else: # inter-stage streaming channel - index_list = [] - _api_internal._ScheduleStream( - self, tensor, dst, src, - index_list, types, depth, "link") + index_list = [] + _api_internal._ScheduleStream( + self, tensor, dst, src, + index_list, type, depth, "link") @register_node("Stage") class _Stage(NodeBase): diff --git a/python/heterocl/tvm/stmt.py b/python/heterocl/tvm/stmt.py index 99820cbf0..d340c90e1 100644 --- a/python/heterocl/tvm/stmt.py +++ b/python/heterocl/tvm/stmt.py @@ -120,3 +120,7 @@ class Stencil(Stmt): @register_node class StreamStmt(Stmt): pass + +@register_node +class Print(Stmt): + pass diff --git a/samples/mapreduce/mapreduce.py b/samples/mapreduce/mapreduce.py new file mode 100644 index 000000000..079885784 --- /dev/null +++ b/samples/mapreduce/mapreduce.py @@ -0,0 +1,56 @@ +import heterocl as hcl +st = hcl.Struct({"key": hcl.Int(32), "val": hcl.Int(32)}) +size = 1024 +class_number = 6 +compute_units = 4 + +hcl.init(hcl.UInt(32)) +inputs = hcl.placeholder((size,), dtype=st, name="input") + +def kernel(inputs): + + def split(inputs, number): + cus = [] + size = inputs.shape[0] + for i in range(number): + base = i * (size/number) + name = "batch_" + str(i) + ret = hcl.compute((int(size/number),), + lambda x: inputs[base+x], dtype=st, name=name) + cus.append(ret) + return cus + + # ret is the input slice { (key, value)...} + # res is the intermediate result + def count(res, ret, x): + res[ret[x].key] += ret[x].val + + def reducer(ress, output, x): + for res in ress: + output[x] += res[x] + + rets = split(inputs, compute_units) + + ress = [] + for ret in rets: + name = "map_batch_" + str(rets.index(ret)) + res = hcl.compute((class_number,), lambda *args: 0, name=name) + # mapping (accumulate quality scores in each batch) + hcl.mutate((int(size/compute_units),), + lambda x: count(res, ret, x), name="mutate_" + name) + ress.append(res) + + # shuffle and reduce the ress into output + output = hcl.compute((class_number, ), lambda x: 0, name="output") + hcl.mutate((class_number,), lambda x: reducer(ress, output, x), "reducer") + return output + +target = hcl.platform.aws_f1 +s = hcl.create_schedule([inputs], kernel) + +# new_inputs = s.to(inputs, target.xcel) +# s.to(kernel.reducer.output, target.host) + +# consumers = [getattr(kernel, "batch_" + str(_)) for _ in range(compute_units)] +# s.multicast(new_inputs, consumers) +print(hcl.lower(s)) diff --git a/samples/sobel/sobel.py b/samples/sobel/sobel.py deleted file mode 100644 index a4299d8ae..000000000 --- a/samples/sobel/sobel.py +++ /dev/null @@ -1,91 +0,0 @@ -import heterocl as hcl -import hlib -import numpy as np -from PIL import Image -from urllib.request import urlopen - -batch_size = 1 -hcl.init(hcl.UInt(32)) -dtype = hcl.UInt(32) -image_size = () -kernel_size = 3 - -# setup target using vivado -tool = hcl.tool.vivado("csim") -target = hcl.platform.zc706 - -def sobel(): - image = hcl.placeholder((batch_size, 1, 256, 256), "input_image") - k1 = hcl.placeholder((1, 1, 3, 3), "kernel_1") - k2 = hcl.placeholder((1, 1, 3, 3), "kernel_2") - - def kernel(input_image, kernel_1, kernel_2): - - def absolute(image, *args): - with hcl.if_(image[args] > 0): - hcl.return_(image[args]) - with hcl.else_(): - hcl.return_(-1 * image[args]) - - def dev(gx, gy, org): - assert gx.shape == gy.shape, "mismatch" - rx = hcl.reduce_axis(0, 255, "rx") - ry = hcl.reduce_axis(0, 255, "ry") - mat_sum = hcl.compute(gx.shape, lambda nn, ff, xx, yy: - gx[nn, ff, xx, yy] + gy[nn, ff, xx, yy], name="add") - return hcl.compute(mat_sum.shape, lambda nn, ff, xx, yy: - mat_sum[nn, ff, xx, yy] * 255.0 / hcl.max(mat_sum[nn, ff, rx, ry], axis=[rx, ry]), - name = "derv") - - # make the conv op a kernel on fpga. - # return tensor required (cannot do def_()) - output_shape = (1,1,254,254) - - # make compute wrapped in hcl def - module1 = hcl.def_([input_image.shape, kernel_1.shape, output_shape], name="conv1")(hlib.nn.conv2d_nchw_imp) - module2 = hcl.def_([input_image.shape, kernel_1.shape, output_shape], name="conv2")(hlib.nn.conv2d_nchw_imp) - conv1 = hcl.compute(output_shape, lambda *args: 0) - conv2 = hcl.compute(output_shape, lambda *args: 0) - module1(input_image, kernel_1, conv1) - module2(input_image, kernel_2, conv2) - - abs1 = hcl.compute(conv1.shape, - lambda *args: absolute(conv1, *args)) - abs2 = hcl.compute(conv2.shape, - lambda *args: absolute(conv2, *args)) - - # derivative module for normalization - return dev(abs1, abs2, input_image) - - s = hcl.create_schedule([image, k1, k2], kernel) - - # data moved to local - i0, k10 = s.to([image, k1], target.fpga) - s.to([i0, k10], s[kernel.conv1]) - s.to(kernel.derv, target.cpu) - - # create stream channel between modules - print(type(target.fpga), hcl.lower(s)) - return hcl.build(s, target) - -# Load sample data -img = Image.open(urlopen('http://i.stack.imgur.com/8zINU.gif')) -kernel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) -kernel_y = np.flip(kernel_x.T.T, axis=0) -img = np.array(img) - -img = img[np.newaxis, ...] -img = img[np.newaxis, ...] -kernel_x = kernel_x[np.newaxis, ...] -kernel_x = kernel_x[np.newaxis, ...] -kernel_y = kernel_y[np.newaxis, ...] -kernel_y = kernel_y[np.newaxis, ...] - -hcl_input = hcl.asarray(img, dtype) -kernel_x = hcl.asarray(kernel_x, dtype) -kernel_y = hcl.asarray(kernel_y, dtype) -hcl_output = hcl.asarray(np.zeros((1,1,254,254)), dtype) - -f = sobel() -f(hcl_input, kernel_x, kernel_y, hcl_output) - diff --git a/samples/sobel/sobel_main.py b/samples/sobel/sobel_main.py new file mode 100644 index 000000000..552b37352 --- /dev/null +++ b/samples/sobel/sobel_main.py @@ -0,0 +1,56 @@ +from PIL import Image +import heterocl as hcl +import numpy as np +import math +import imageio +from urllib.request import urlopen + +hcl.init(init_dtype=hcl.Float()) +img = Image.open(urlopen('http://i.stack.imgur.com/8zINU.gif')) +width, height = img.size + +A = hcl.placeholder((height,width), "A", dtype=hcl.Float()) +Gx = hcl.placeholder((3,3), "Gx",dtype=hcl.Float()) +Gy = hcl.placeholder((3,3), "Gy",dtype=hcl.Float()) + +def sobel(A, Gx, Gy): + + r = hcl.reduce_axis(0,3) + c = hcl.reduce_axis(0,3) + B = hcl.compute((height-2,width-2), + lambda x,y: hcl.sum(A[x+r,y+c]*Gx[r,c], axis=[r,c]), + name="B", dtype=hcl.Float()) + t = hcl.reduce_axis(0,3) + g = hcl.reduce_axis(0,3) + + C = hcl.compute((height-2,width-2), + lambda x,y: hcl.sum(A[x+t,y+g]*Gy[t,g], axis=[t,g]), + name="C", dtype=hcl.Float()) + return hcl.compute((height-2,width-2), + lambda x, y :hcl.sqrt(B[x,y]*B[x,y] + C[x,y]*C[x,y])/4328*255, + dtype=hcl.Float()) + +s = hcl.create_schedule([A,Gx,Gy],sobel) +f = hcl.build(s) + +npA = np.array(img) +npGx = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]) +npGy = np.array([[1,2,1],[0,0,0],[-1,-2,-1]]) +hcl_A = hcl.asarray(npA) +hcl_Gx = hcl.asarray(npGx) +hcl_Gy = hcl.asarray(npGy) + +npF = np.zeros((height-2,width-2)) +hcl_F = hcl.asarray(npF) + +f(hcl_A, hcl_Gx,hcl_Gy, hcl_F) +npF = hcl_F.asnumpy() + +newimg = np.zeros((height-2,width-2,3)) +for x in range(0, height-2): + for y in range(0, width-2): + for z in range(0,3): + newimg[x,y,z] = npF[x,y] + +newimg = newimg.astype(np.uint8) +# imageio.imsave("pic_sobel.jpg",newimg) diff --git a/samples/sobel/sobel_stream.py b/samples/sobel/sobel_stream.py new file mode 100644 index 000000000..90fe2f892 --- /dev/null +++ b/samples/sobel/sobel_stream.py @@ -0,0 +1,61 @@ +from PIL import Image +import heterocl as hcl +import numpy as np +import math +import imageio +from urllib.request import urlopen + +hcl.init(init_dtype=hcl.Float()) +img = Image.open(urlopen('http://i.stack.imgur.com/8zINU.gif')) +width, height = img.size + +A = hcl.placeholder((height,width), "A", dtype=hcl.Float()) +Gx = hcl.placeholder((3,3), "Gx",dtype=hcl.Float()) +Gy = hcl.placeholder((3,3), "Gy",dtype=hcl.Float()) + +def sobel(A, Gx, Gy): + + r = hcl.reduce_axis(0,3) + c = hcl.reduce_axis(0,3) + B = hcl.compute((height-2,width-2), + lambda x,y: hcl.sum(A[x+r,y+c]*Gx[r,c], axis=[r,c], name="sum1"), + name="B", dtype=hcl.Float()) + t = hcl.reduce_axis(0,3) + g = hcl.reduce_axis(0,3) + + C = hcl.compute((height-2,width-2), + lambda x,y: hcl.sum(A[x+t,y+g]*Gy[t,g], axis=[t,g], name="sum2"), + name="C", dtype=hcl.Float()) + return hcl.compute((height-2,width-2), + lambda x, y :hcl.sqrt(B[x,y]*B[x,y] + C[x,y]*C[x,y])/4328*255, + name="output", dtype=hcl.Float()) + +target = hcl.platform.aws_f1 +target.config(compile="vitis", backend="vhls") + +s = hcl.create_schedule([A,Gx,Gy], sobel) +s.to([Gx, Gy, A], target.xcel) +s.to(sobel.output, target.host) +f = hcl.build(s, target) + +npA = np.array(img) +npGx = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]) +npGy = np.array([[1,2,1],[0,0,0],[-1,-2,-1]]) +hcl_A = hcl.asarray(npA) +hcl_Gx = hcl.asarray(npGx) +hcl_Gy = hcl.asarray(npGy) + +npF = np.zeros((height-2,width-2)) +hcl_F = hcl.asarray(npF) + +f(hcl_A, hcl_Gx,hcl_Gy, hcl_F) +npF = hcl_F.asnumpy() + +newimg = np.zeros((height-2,width-2,3)) +for x in range(0, height-2): + for y in range(0, width-2): + for z in range(0,3): + newimg[x,y,z] = npF[x,y] + +newimg = newimg.astype(np.uint8) +# imageio.imsave("pic_sobel.jpg", newimg) diff --git a/samples/systolic_array/systolic_array_main.py b/samples/systolic_array/systolic_array_main.py new file mode 100644 index 000000000..25d221224 --- /dev/null +++ b/samples/systolic_array/systolic_array_main.py @@ -0,0 +1,65 @@ +import heterocl as hcl +import numpy as np +import time + +m = k = n = 16 +x_max = y_max = 16 + +def gemm(m=16, n=16, k=16, dtype=hcl.Int(), target=None): + matrix_1 = hcl.placeholder((m, k), dtype=dtype) + matrix_2 = hcl.placeholder((k, n), dtype=dtype) + + def kernel(matrix_1, matrix_2): + r = hcl.reduce_axis(0, k, 'k') + return hcl.compute((m, n), + lambda x, y: hcl.sum(matrix_1[x, r] * matrix_2[r, y], + axis=r, dtype=dtype), + dtype=dtype, + name="out_matrix") + + s = hcl.create_schedule([matrix_1, matrix_2], kernel) + out_matrix = kernel.out_matrix + block_size = 8 + y0, y1 = s[out_matrix].split(out_matrix.axis[0], factor=block_size) + x0, x1 = s[out_matrix].split(out_matrix.axis[1], factor=block_size) + s[out_matrix].reorder(y0, x0, y1, x1) + + f = hcl.build(s, target=target) + return f + +def systolic(m=16, k=16, n=16, dtype=hcl.Int(), target=None): + hcl.init(dtype) + + dim_x, dim_y = 16, 16 + m_A = hcl.placeholder((m, k), dtype=dtype, name="m_A") + m_B = hcl.placeholder((k, n), dtype=dtype, name="m_B") + m_output = hcl.placeholder((m, n), dtype=dtype, name="m_output") + + # k (time) and y/x (spatial) dim + def kernel(k, y, x): + last = hcl.scalar( + hcl.select(k==0, 0, m_output[y, x]), "last") + m_output[y, x] = last.v + m_A[y, k] * m_B[k, x] + + hcl.mutate((m, dim_y, dim_x), + lambda k, y, x: kernel(k, y, x)) + s = hcl.create_schedule([m_A, m_B, m_output]) + f = hcl.build(s, target=target) + return f + +dtype = hcl.Int() +fg = gemm(m, n, k, dtype=hcl.Int(), target="llvm") +fs = systolic(m, n, k, dtype=hcl.Int(), target="llvm") + +np_1 = np.random.randint(10, size=(m, k)) +np_2 = np.random.randint(10, size=(k, n)) +np_3 = np.matmul(np_1, np_2) + +hcl_m1 = hcl.asarray(np_1, dtype=dtype) +hcl_m2 = hcl.asarray(np_2, dtype=dtype) +hcl_m3 = hcl.asarray(np.zeros((m, n)), dtype=dtype) +hcl_m4 = hcl.asarray(np.zeros((m, n)), dtype=dtype) + +fg(hcl_m1, hcl_m2, hcl_m3) +fs(hcl_m1, hcl_m2, hcl_m4) +assert np.array_equal(hcl_m3.asnumpy(), hcl_m4.asnumpy()) diff --git a/samples/systolic_array/systolic_array_stream.py b/samples/systolic_array/systolic_array_stream.py new file mode 100644 index 000000000..bb609f240 --- /dev/null +++ b/samples/systolic_array/systolic_array_stream.py @@ -0,0 +1,60 @@ +import heterocl as hcl +import numpy as np +import time +from systolic_array_main import gemm + +m = k = n = 16 +x_max = y_max = 16 +dtype = hcl.Int() + +def systolic(m=16, k=16, n=16, dtype=hcl.Int(), target=None): + hcl.init(dtype) + + dim_x, dim_y = 16, 16 + A = hcl.placeholder((m, k), dtype=dtype, name="A") + B = hcl.placeholder((k, n), dtype=dtype, name="B") + output = hcl.placeholder((m, n), dtype=dtype, name="output") + + def kernel(A, B, O): + + localA = hcl.compute((m, k-1), lambda *args: 0, "localA") + localB = hcl.compute((k-1, n), lambda *args: 0, "localB") + + def update(k, y, x): + last = hcl.scalar( + hcl.select(k==0, 0, O[y, x]), "last") + + localA[y, x] = hcl.select(x>0, localA[y, x-1], A[y, k]) + localB[y, x] = hcl.select(y>0, localB[y-1, x], B[k, x]) + O[y, x] = last.v + localA[y, x] * localB[y, x] + + hcl.mutate((m, dim_y, dim_x), + lambda k, y, x: update(k, y, x), name="update") + + s = hcl.create_schedule([A, B, output], kernel) + + k = kernel.update + s[k].pipeline(k.axis[0]) + + # self loopback streaming + s.to(k.localA, kernel.update) + s.to(k.localB, kernel.update) + + f = hcl.build(s, target=target) + return f + +np_1 = np.random.randint(10, size=(m, k)) +np_2 = np.random.randint(10, size=(k, n)) +np_3 = np.matmul(np_1, np_2) + +hcl_m1 = hcl.asarray(np_1, dtype=dtype) +hcl_m2 = hcl.asarray(np_2, dtype=dtype) +hcl_m3 = hcl.asarray(np.zeros((m, n)), dtype=dtype) +hcl_m4 = hcl.asarray(np.zeros((m, n)), dtype=dtype) + +fg = gemm(m, n, k, dtype=hcl.Int(), target="llvm") +fs = systolic(m, n, k, dtype=hcl.Int(), target="llvm") + +fg(hcl_m1, hcl_m2, hcl_m3) +fs(hcl_m1, hcl_m2, hcl_m4) +assert np.array_equal(hcl_m3.asnumpy(), hcl_m4.asnumpy()) diff --git a/samples/systolic_array/systolic_array_vitis.py b/samples/systolic_array/systolic_array_vitis.py new file mode 100644 index 000000000..b39bbe927 --- /dev/null +++ b/samples/systolic_array/systolic_array_vitis.py @@ -0,0 +1,65 @@ +import heterocl as hcl +import numpy as np +import time + +m = k = n = 16 +x_max = y_max = 16 +dtype = hcl.Int() +host_only = False + +def systolic(m=16, k=16, n=16, dtype=hcl.Int(), target=None): + hcl.init(dtype) + + dim_x, dim_y = 16, 16 + A = hcl.placeholder((m, k), dtype=dtype, name="A") + B = hcl.placeholder((k, n), dtype=dtype, name="B") + + def kernel(A, B): + + localA = hcl.compute((m, k-1), lambda *args: 0, "localA") + localB = hcl.compute((k-1, n), lambda *args: 0, "localB") + output = hcl.compute((m, n), lambda *args: 0, "output") + + def update(k, y, x): + + localA[y, x] = hcl.select(x>0, localA[y, x-1], A[y, k]) + localB[y, x] = hcl.select(y>0, localB[y-1, x], B[k, x]) + output[y, x] = hcl.select(k==0, 0, output[y, x]) + localA[y, x] * localB[y, x] + + hcl.mutate((m, dim_y, dim_x), + lambda k, y, x: update(k, y, x), name="update") + return output + + s = hcl.create_schedule([A, B], kernel) + + k = kernel.update + s[k].pipeline(k.axis[0]) + + # self loopback streaming + s.to(k.localA, kernel.update) + s.to(k.localB, kernel.update) + + # move to xcel scope + if not host_only: + s.to([A, B], target.xcel) + s.to(k.output, target.host) + + print(hcl.lower(s)) + f = hcl.build(s, target=target) + return f + +np_1 = np.random.randint(10, size=(m, k)) +np_2 = np.random.randint(10, size=(k, n)) +np_3 = np.matmul(np_1, np_2) + +hcl_m1 = hcl.asarray(np_1, dtype=dtype) +hcl_m2 = hcl.asarray(np_2, dtype=dtype) +hcl_m3 = hcl.asarray(np.zeros((m, n)), dtype=dtype) + +target = hcl.platform.aws_f1 +target.config(compile="vitis", backend="vhls") +fs = systolic(m, n, k, dtype=hcl.Int(), target=target) +fs(hcl_m1, hcl_m2, hcl_m3) + +print(hcl_m3.asnumpy()) +print(np.matmul(np_1, np_2)) diff --git a/tests/test_api_print.py b/tests/test_api_print.py new file mode 100644 index 000000000..bed5525f0 --- /dev/null +++ b/tests/test_api_print.py @@ -0,0 +1,41 @@ +import heterocl as hcl +import numpy as np +import pathlib +import subprocess + +def get_stdout(filename): + path = pathlib.Path(__file__).parent.absolute() + path = str(path) + "/test_api_print_cases/" + filename + ".py" + p = subprocess.run(['python', path], stdout=subprocess.PIPE) + output = p.stdout.decode('utf-8') + return str(output) + +def test_print_number(): + + output = get_stdout("print_number") + + golden = "5\n2.500000\n" + + assert str(output) == golden + +def test_print_expr(): + + outputs = get_stdout("print_expr").split("\n") + + N = 5 + for i in range(0, N): + assert outputs[i] == outputs[i+N] + +def test_print_tensor_1D(): + + outputs = get_stdout("print_tensor_1D").split("\n") + + assert outputs[0] == outputs[1] + +def test_print_tensor_2D(): + + outputs = get_stdout("print_tensor_2D").split("\n") + + N = 10 + for i in range(0, N): + assert outputs[i] == outputs[i+N] diff --git a/tests/test_api_print_cases/print_expr.py b/tests/test_api_print_cases/print_expr.py new file mode 100644 index 000000000..6a3f5da29 --- /dev/null +++ b/tests/test_api_print_cases/print_expr.py @@ -0,0 +1,97 @@ +import heterocl as hcl +import numpy as np + +# case1: int + +hcl.init() + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A[5]) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(0, 10, size=(10,)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print(hcl_A.asnumpy()[5]) + +# case1: uint + +hcl.init(hcl.UInt(4)) + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A[5]) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(20, 30, size=(10,)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print(hcl_A.asnumpy()[5]) + +# case3: float + +hcl.init(hcl.Float()) + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A[5], "%.4f\n") + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.rand(10) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print("%.4f" % hcl_A.asnumpy()[5]) + +# case4: fixed points + +hcl.init(hcl.UFixed(6, 4)) + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A[5], "%.4f\n") + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.rand(10) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print("%.4f" % hcl_A.asnumpy()[5]) + +# case5: two ints + +hcl.init() + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print((A[5], A[6]), "%d %d\n") + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(0, 10, size=(10,)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +print(str(np_A[5]) + " " + str(np_A[6])) diff --git a/tests/test_api_print_cases/print_number.py b/tests/test_api_print_cases/print_number.py new file mode 100644 index 000000000..ae0d6f007 --- /dev/null +++ b/tests/test_api_print_cases/print_number.py @@ -0,0 +1,18 @@ +import heterocl as hcl +import numpy as np + +A = hcl.placeholder((10,)) + +hcl.init() + +def kernel(A): + hcl.print(5) + hcl.print(2.5) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.rand(10) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) diff --git a/tests/test_api_print_cases/print_tensor_1D.py b/tests/test_api_print_cases/print_tensor_1D.py new file mode 100644 index 000000000..69b198b2d --- /dev/null +++ b/tests/test_api_print_cases/print_tensor_1D.py @@ -0,0 +1,25 @@ +import heterocl as hcl +import numpy as np + +hcl.init() + +A = hcl.placeholder((10,)) + +def kernel(A): + hcl.print(A) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(0, 10, size=(10,)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +s = "[" +for i in range(0, 10): + s += str(np_A[i]) + if i < 9: + s += ", " +s += "]" +print(s) diff --git a/tests/test_api_print_cases/print_tensor_2D.py b/tests/test_api_print_cases/print_tensor_2D.py new file mode 100644 index 000000000..99599bd4a --- /dev/null +++ b/tests/test_api_print_cases/print_tensor_2D.py @@ -0,0 +1,30 @@ +import heterocl as hcl +import numpy as np + +hcl.init() + +A = hcl.placeholder((10, 10)) + +def kernel(A): + hcl.print(A) + +s = hcl.create_schedule([A], kernel) +f = hcl.build(s) + +np_A = np.random.randint(0, 10, size=(10, 10)) +hcl_A = hcl.asarray(np_A) + +f(hcl_A) + +s = "[" +for i in range(0, 10): + s += "[" + for j in range(0, 10): + s += str(np_A[i][j]) + if j < 9: + s += ", " + s += "]" + if i < 9: + s += ",\n" +s += "]" +print(s) diff --git a/tests/test_dsl_basic.py b/tests/test_dsl_basic.py index 4607f49d4..f40334b06 100644 --- a/tests/test_dsl_basic.py +++ b/tests/test_dsl_basic.py @@ -55,9 +55,8 @@ def test_or(): def test_if(): def kernel(A): - with hcl.Stage(): - with hcl.if_(A[0] > 5): - A[0] = 5 + with hcl.if_(A[0] > 5): + A[0] = 5 A = hcl.placeholder((1,)) s = hcl.create_schedule(A, kernel) @@ -76,11 +75,10 @@ def kernel(A): def test_else(): def kernel(A): - with hcl.Stage(): - with hcl.if_(A[0] > 5): - A[0] = 5 - with hcl.else_(): - A[0] = -1 + with hcl.if_(A[0] > 5): + A[0] = 5 + with hcl.else_(): + A[0] = -1 A = hcl.placeholder((1,)) s = hcl.create_schedule(A, kernel) @@ -99,11 +97,10 @@ def kernel(A): def test_elif(): def kernel(A): - with hcl.Stage(): - with hcl.if_(A[0] > 5): - A[0] = 5 - with hcl.elif_(A[0] > 3): - A[0] = 3 + with hcl.if_(A[0] > 5): + A[0] = 5 + with hcl.elif_(A[0] > 3): + A[0] = 3 A = hcl.placeholder((1,)) s = hcl.create_schedule(A, kernel) @@ -122,13 +119,12 @@ def kernel(A): def test_cond_all(): def kernel(A): - with hcl.Stage(): - with hcl.if_(A[0] > 5): - A[0] = 5 - with hcl.elif_(A[0] > 3): - A[0] = 3 - with hcl.else_(): - A[0] = 0 + with hcl.if_(A[0] > 5): + A[0] = 5 + with hcl.elif_(A[0] > 3): + A[0] = 3 + with hcl.else_(): + A[0] = 0 A = hcl.placeholder((1,)) s = hcl.create_schedule(A, kernel) @@ -146,11 +142,10 @@ def kernel(A): def test_elif(): def kernel(A): - with hcl.Stage(): - with hcl.if_(A[0] > 5): - A[0] = 5 - with hcl.elif_(A[0] > 3): - A[0] = 3 + with hcl.if_(A[0] > 5): + A[0] = 5 + with hcl.elif_(A[0] > 3): + A[0] = 3 A = hcl.placeholder((1,)) s = hcl.create_schedule(A, kernel) @@ -168,9 +163,8 @@ def kernel(A): def test_for_basic(): def kernel(A): - with hcl.Stage(): - with hcl.for_(0, 10) as i: - A[i] = i + with hcl.for_(0, 10) as i: + A[i] = i A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -189,9 +183,8 @@ def kernel(A): def test_for_irregular_bound(): def kernel(A): - with hcl.Stage(): - with hcl.for_(4, 8) as i: - A[i] = i + with hcl.for_(4, 8) as i: + A[i] = i A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -212,9 +205,8 @@ def kernel(A): def test_for_step_non_one(): def kernel(A): - with hcl.Stage(): - with hcl.for_(0, 10, 2) as i: - A[i] = i + with hcl.for_(0, 10, 2) as i: + A[i] = i A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -235,9 +227,8 @@ def kernel(A): def test_for_step_negative(): def kernel(A): - with hcl.Stage(): - with hcl.for_(9, -1, -1) as i: - A[i] = i + with hcl.for_(9, -1, -1) as i: + A[i] = i A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -256,11 +247,10 @@ def kernel(A): def test_while_basic(): def kernel(A): - with hcl.Stage(): - a = hcl.scalar(0) - with hcl.while_(a[0] < 10): - A[a[0]] = a[0] - a[0] += 1 + a = hcl.scalar(0) + with hcl.while_(a[0] < 10): + A[a[0]] = a[0] + a[0] += 1 A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -279,11 +269,10 @@ def kernel(A): def test_break_in_for(): def kernel(A): - with hcl.Stage(): - with hcl.for_(0, 10) as i: - with hcl.if_(i > 5): - hcl.break_() - A[i] = i + with hcl.for_(0, 10) as i: + with hcl.if_(i > 5): + hcl.break_() + A[i] = i A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -304,13 +293,12 @@ def kernel(A): def test_break_in_while(): def kernel(A): - with hcl.Stage(): - i = hcl.scalar(0) - with hcl.while_(True): - with hcl.if_(i[0] > 5): - hcl.break_() - A[i[0]] = i[0] - i[0] += 1 + i = hcl.scalar(0) + with hcl.while_(True): + with hcl.if_(i[0] > 5): + hcl.break_() + A[i[0]] = i[0] + i[0] += 1 A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) @@ -331,12 +319,11 @@ def kernel(A): def test_break_multi_level(): def kernel(A): - with hcl.Stage(): - with hcl.for_(0, 10) as i: - with hcl.for_(0, 10) as j: - with hcl.if_(j >= i): - hcl.break_() - A[i] += j + with hcl.for_(0, 10) as i: + with hcl.for_(0, 10) as j: + with hcl.if_(j >= i): + hcl.break_() + A[i] += j A = hcl.placeholder((10,)) s = hcl.create_schedule(A, kernel) diff --git a/tests/test_dtype.py b/tests/test_dtype.py index ad7d60f9b..a68e92f7f 100644 --- a/tests/test_dtype.py +++ b/tests/test_dtype.py @@ -325,3 +325,20 @@ def kernel(A, B, C, O): # hcl_B = hcl.asarray(np_B, dtype=hcl.Fixed(13, 11)) # f(hcl_A, hcl_B) +def test_dtype_const_long_int(): + + hcl.init(hcl.Int()) + r = np.random.randint(0, 10, size=(1,)) + + def kernel(): + A = hcl.compute((1,), lambda x: r[0], dtype=hcl.Int(128)) + B = hcl.compute((1,), lambda x: A[x]) + return B + + s = hcl.create_schedule([], kernel) + f = hcl.build(s) + np_B = np.zeros((1,)) + hcl_B = hcl.asarray(np_B) + f(hcl_B) + + assert np.array_equal(r, hcl_B.asnumpy()) diff --git a/tests/test_schedule_stream.py b/tests/test_schedule_stream.py index bd0146f2e..3d2b1081a 100644 --- a/tests/test_schedule_stream.py +++ b/tests/test_schedule_stream.py @@ -2,23 +2,63 @@ from itertools import permutations def test_placeholders(): - hcl.init() - A = hcl.placeholder((10, 32), "A") - B = hcl.placeholder((10, 32), "B") - C = hcl.placeholder((10, 32), "C") - D = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "D") - E = hcl.compute(C.shape, lambda i, j: C[i][j] * D[i][j], "E") - F = hcl.compute(C.shape, lambda i, j: E[i][j] + 1, "F") - target = hcl.platform.aws_f1 - s = hcl.create_schedule([A, B, C, D, E, F]) - s.to([A, B, C], target.xcel) - s.to(E, target.host) - code = str(hcl.lower(s)) - pattern = "test({}.channel, {}.channel, {}.channel, E.channel)" - combination = [ pattern.format(*_) for _ in list(permutations(["A", "B", "C"])) ] - cond = any([_ in code for _ in combination]) - assert cond + def move_inputs(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + C = hcl.placeholder((10, 32), "C") + D = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "D") + E = hcl.compute(C.shape, lambda i, j: C[i][j] * D[i][j], "E") + F = hcl.compute(C.shape, lambda i, j: E[i][j] + 1, "F") + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A, B, C, D, E, F]) + s.to([A, B, C], target.xcel) + s.to(E, target.host) + code = str(hcl.lower(s)) + pattern = "test({}.channel, {}.channel, {}.channel, E.channel)" + combination = [ pattern.format(*_) for _ in list(permutations(["A", "B", "C"])) ] + assert any([_ in code for _ in combination]) + + def move_outputs(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + + def kernel(A): + B = hcl.compute(A.shape, lambda i, j: A[i, j] * 2, "B") + hcl.update(B, lambda i, j: B[i, j] + 1, "update1") + hcl.update(B, lambda i, j: B[i, j] * 2, "update2") + return B + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A], kernel) + s.to(A, target.xcel) + s.to(kernel.update1.B, target.host) + + code = str(hcl.lower(s)) + assert "test(A.channel, B.update.channel)" in code + + def self_move_back(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + + def kernel(A): + hcl.update(A, lambda i, j: A[i, j] + 1, "update1") + hcl.update(A, lambda i, j: A[i, j] * 2, "update2") + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A], kernel) + s.to(A, target.xcel) + s.to(kernel.update1.A, target.host) + + code = str(hcl.lower(s)) + assert "test(A.channel, A.update.channel)" in code + + move_inputs() + move_outputs() + self_move_back() + def test_extern_ops(): hcl.init() @@ -35,31 +75,69 @@ def kernel(A): s.to(kernel.C, target.host) code = str(hcl.lower(s)) assert "test(B.channel, C.channel)" in code - -def test_imperative_loops(): - hcl.init() - A = hcl.placeholder((10, 32), "A") - B = hcl.placeholder((10, 32), "B") - def kernel(A, B): - C = hcl.compute(A.shape, lambda *args : 0, "C") - with hcl.Stage("stage"): - with hcl.for_(0, 10, name="i") as i: - with hcl.for_(0, 32, name="j") as j: - B[i, j] = A[i, j] + B[i, j] - C[i, j] = 2 * B[i, j] - return C - - target = hcl.platform.aws_f1 - s = hcl.create_schedule([A, B], kernel) - stage = kernel.stage - s.to(s[stage], target.xcel, axis=0) - code = str(hcl.lower(s)) - pattern = "test({}, {}, {}, {})" - combination = [ pattern.format(*_) - for _ in list(permutations(["A", "B", "C", "i"])) ] - cond = any([_ in code for _ in combination]) - assert cond, code + +def test_inner_loops(): + + def imperative_loop(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + def kernel(A, B): + C = hcl.compute(A.shape, lambda *args : 0, "C") + with hcl.Stage("stage"): + with hcl.for_(0, 10, name="i") as i: + with hcl.for_(0, 32, name="j") as j: + B[i, j] = A[i, j] + B[i, j] + C[i, j] = 2 * B[i, j] + return C + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A, B], kernel) + + stage = kernel.stage + s.to(stage, target.xcel, axis=1) + code = str(hcl.lower(s)) + pattern = "test({}, {}, {}, {})" + combination = [ pattern.format(*_) + for _ in list(permutations(["A", "B", "C", "i"])) ] + cond = any([_ in code for _ in combination]) + assert cond, code + + def declarative_loop(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + def kernel(A): + C = hcl.compute(A.shape, lambda *args : A[args] * 4, "C") + return C + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A], kernel) + s.to(kernel.C, target.xcel, axis=1) + code = str(hcl.lower(s)) + assert "test(C, A, args)" in code + + def inner_loop_tile(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + def kernel(A): + C = hcl.compute(A.shape, lambda *args : A[args] * 4, "C") + return C + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A], kernel) + + stage = kernel.C + yo, yi = s[stage].split(stage.axis[0], factor=3) + xo, xi = s[stage].split(stage.axis[1], factor=3) + s.to(kernel.C, target.xcel, axis=1) + code = str(hcl.lower(s)) + assert "test(args.outer, C, A)" in code + + imperative_loop() + declarative_loop() + inner_loop_tile() + def test_kernel(): hcl.init() @@ -85,6 +163,7 @@ def mul(B, C): assert "c_buf_1.write" in code assert "c_buf_1.read" in code + def test_inter_stage(): A = hcl.placeholder((10, 32), "A") B = hcl.placeholder((10, 32), "B") @@ -103,6 +182,7 @@ def kernel(A, B): assert "C.pipe1.write" in code assert "C.pipe1.read" in code + def test_extern_op_multicast(): A = hcl.placeholder((10, 32), "A") B = hcl.placeholder((10, 32), "B") @@ -126,6 +206,7 @@ def kernel(A, B): assert "C.pipe2.write" in code assert "C.pipe2.read" in code + def test_kernel_multicast(): hcl.init() A = hcl.placeholder((10, 32), "A") @@ -156,6 +237,7 @@ def mul(A, C): # print(code) # assert "test(A.channel, D.channel)" in code + def test_mixed_stream(): A = hcl.placeholder((10, 32), "A") B = hcl.placeholder((10, 32), "B") @@ -179,12 +261,124 @@ def kernel(A, B): assert "C.pipe1.write" in code assert "C.pipe1.read" in code + +def test_fork_join(): + + def inter_stage_fork(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + + def kernel(A, B): + C = hcl.compute(A.shape, lambda i, j: A[i,j] + B[i,j], "C") + D = hcl.compute(C.shape, lambda i, j: C[i,j] + 1, "D") + E = hcl.compute(C.shape, lambda i, j: C[i,j] * 2, "E") + return D, E + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A, B], kernel) + s.fork(kernel.C, [kernel.D, kernel.E]) + code = str(hcl.lower(s)) + assert "C.pipe1.write" in code + assert "C.pipe1.read" in code + assert "C.pipe2.write" in code + + def inter_stage_join(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + + def kernel(A, B): + C = hcl.compute(A.shape, lambda i, j: 0, "C") + hcl.update(C, lambda i, j: A[i,j] + 1, "s1") + hcl.update(C, lambda i, j: B[i,j] * 2, "s2") + return hcl.compute(C.shape, lambda *args: C[args] + 3, "ret") + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A, B], kernel) + s.join([kernel.s1.C, kernel.s2.C], kernel.ret.C) + code = str(hcl.lower(s)) + assert "C.pipe1.read" in code + assert "C.pipe2.write" in code + + inter_stage_fork() + inter_stage_join() + +def test_kernel_duplicate(): + + def extract_subgraph(combine=False): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + + def kernel(A, B): + C = hcl.compute(A.shape, lambda i, j: 0, "C") + hcl.update(C, lambda i, j: A[i,j] + 1, "s1") + hcl.update(C, lambda i, j: B[i,j] * 2, "s2") + return hcl.compute(C.shape, lambda *args: C[args] + 3, "ret") + + target = hcl.platform.aws_f1 + s = hcl.create_schedule([A, B], kernel) + + A_, B_ = s.to([A, B], target.xcel) + ret_ = s.to(kernel.ret, target.host) + + if combine == True: + # merge the channel stages into + s[A_].compute_at(s[B_], 1) + s[B_].compute_at(s[kernel.C], 1) + + # merge stages from top to bottom + s[kernel.C].compute_at(s[kernel.s1], kernel.s1.axis[1]) + s[kernel.s1].compute_at(s[kernel.s2], kernel.s2.axis[1]) + s[kernel.s2].compute_at(s[kernel.ret], kernel.ret.axis[1]) + + nodes = s.subgraph(inputs=[A_, B_], outputs=[ret_]) + code = str(hcl.lower(s)) + + extract_subgraph(True) + + +def test_custom_device(): + + def custom_target(): + hcl.init() + A = hcl.placeholder((10, 32), "A") + B = hcl.placeholder((10, 32), "B") + + def kernel(A, B): + C = hcl.compute(A.shape, lambda i, j: A[i,j] + B[i,j], "C") + D = hcl.compute(C.shape, lambda i, j: C[i,j] + 1, "D") + return D + + config = { + "host" : hcl.dev.cpu("intel", "e5"), + "xcel" : [ + hcl.dev.fpga("xilinx", "xcvu19p") + ] + } + + p = hcl.platform.custom(config) + s = hcl.create_schedule([A, B], kernel) + s.to(A, p.xcel.hbm[0]) + s.to(B, p.xcel.hbm[1]) + s.to(kernel.D, p.host) + p.config(compile="vitis", mode="debug", backend="vhls") + code = hcl.build(s, p) + assert "MAX_HBM_BANKCOUNT" in code + + custom_target() + + if __name__ == '__main__': test_placeholders() test_extern_ops() - test_imperative_loops() + test_inner_loops() test_kernel() test_inter_stage() test_extern_op_multicast() - # test_kernel_multicast() + test_kernel_multicast() test_mixed_stream() + test_fork_join() + test_kernel_duplicate() + test_custom_device() diff --git a/tvm/HalideIR/src/ir/Expr.h b/tvm/HalideIR/src/ir/Expr.h index 48f07d891..c21d323ad 100644 --- a/tvm/HalideIR/src/ir/Expr.h +++ b/tvm/HalideIR/src/ir/Expr.h @@ -97,7 +97,9 @@ enum class IRNodeType : int { /** for stencil analysis **/ Stencil, /** for external module **/ - ExternModule + ExternModule, + /** for debuggin **/ + Print }; /** The abstract base classes for a node in the Halide IR. */ @@ -309,9 +311,8 @@ enum class PartitionType : int { /** An enum describing the stream type */ enum class StreamType : int { - Channel = 0, - Pipe = 1, - FIFO = 2 + FIFO = 0, + DoubleBuffer = 1 }; /** An enum class for device type */ @@ -321,6 +322,13 @@ enum class DeviceType : int { devGPU = 2 }; +/* An enum class for storage type*/ +enum class StorageType : int { + devDRAM = 0, + devHBM = 1, + devPLRAM = 2 +}; + /** A reference-counted handle to a statement node. */ struct Stmt : public IRHandle { Stmt() : IRHandle() {} diff --git a/tvm/HalideIR/src/ir/IR.cpp b/tvm/HalideIR/src/ir/IR.cpp index 6bf0021af..ea20f2645 100644 --- a/tvm/HalideIR/src/ir/IR.cpp +++ b/tvm/HalideIR/src/ir/IR.cpp @@ -29,9 +29,9 @@ Expr IntImm::make(Type t, int64_t value) { // << "IntImm must be 8, 16, 32, or 64-bit\n"; // Normalize the value by dropping the high bits - value <<= (64 - t.bits()); + //value <<= (64 - t.bits()); // Then sign-extending to get them back - value >>= (64 - t.bits()); + //value >>= (64 - t.bits()); std::shared_ptr node = std::make_shared(); node->type = t; @@ -46,8 +46,8 @@ Expr UIntImm::make(Type t, uint64_t value) { // << "UIntImm must be 1, 8, 16, 32, or 64-bit\n"; // Normalize the value by dropping the high bits - value <<= (64 - t.bits()); - value >>= (64 - t.bits()); + //value <<= (64 - t.bits()); + //value >>= (64 - t.bits()); std::shared_ptr node = std::make_shared(); node->type = t; @@ -695,7 +695,8 @@ Expr Quantize::make(Expr body, Expr bitwidth) { Stmt KernelDef::make(Array args, Array> arg_shapes, Array arg_types, Array arg_tensors, Stmt body, Expr ret_void, - Type ret_type, std::string name, Array channels) { + Type ret_type, std::string name, + Array> channels) { internal_assert(arg_shapes.size() == arg_types.size()) << "KernelDef of unmatched args\n"; for (size_t i = 0; i < args.size(); i++) { internal_assert(args[i].defined()) << "KernelDef of undefined arg\n"; @@ -911,6 +912,17 @@ Stmt ExternModule::make(std::string attr_key, Expr value, Stmt body, return Stmt(node); } +Stmt Print::make(Array values, std::string format) { + for (size_t i = 0; i < values.size(); i++) { + internal_assert(values[i].defined()) << "Print of undefined value\n"; + } + + std::shared_ptr node = std::make_shared(); + node->values = std::move(values); + node->format = std::move(format); + return Stmt(node); +} + namespace { // Helper function to determine if a sequence of indices is a @@ -1012,6 +1024,7 @@ template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const Stencil *)this, s); } template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const StreamStmt *)this, s); } template<> void ExprNode::accept(IRVisitor *v, const Expr &e) const { v->visit((const StreamExpr *)this, e); } +template<> void StmtNode::accept(IRVisitor *v, const Stmt &s) const { v->visit((const Print *)this, s); } Call::ConstString Call::debug_to_file = "debug_to_file"; Call::ConstString Call::reinterpret = "reinterpret"; diff --git a/tvm/HalideIR/src/ir/IR.h b/tvm/HalideIR/src/ir/IR.h index a5cc7147e..ef04da8ab 100644 --- a/tvm/HalideIR/src/ir/IR.h +++ b/tvm/HalideIR/src/ir/IR.h @@ -1050,23 +1050,22 @@ struct Quantize : public ExprNode { /** The imperative function definition */ struct KernelDef : public StmtNode { Array args; - Array> arg_shapes; + Array > arg_shapes; Array arg_types; Array arg_tensors; Stmt body; Expr ret_void; Type ret_type; std::string name; - // args to stream data - Array channels; + Array > channels; EXPORT static Stmt make(Array args, - Array> arg_shapes, + Array > arg_shapes, Array arg_types, Array arg_tensors, Stmt body, Expr ret_void, Type ret_type, std::string name, - Array channels); + Array > channels); void VisitAttrs(IR::AttrVisitor* v) final { v -> Visit("args", &args); @@ -1314,6 +1313,21 @@ struct ExternModule : public StmtNode { static constexpr const char* _type_key = "ExternModule"; }; +struct Print : public StmtNode { + Array values; + std::string format; + + EXPORT static Stmt make(Array values, std::string format); + + void VisitAttrs(IR::AttrVisitor* v) final { + v -> Visit("values", &values); + v -> Visit("format", &format); + } + + static const IRNodeType _type_info = IRNodeType::Print; + static constexpr const char* _type_key = "Print"; +}; + } // inline functions diff --git a/tvm/HalideIR/src/ir/IRMutator.cpp b/tvm/HalideIR/src/ir/IRMutator.cpp index 82904a62b..639850225 100644 --- a/tvm/HalideIR/src/ir/IRMutator.cpp +++ b/tvm/HalideIR/src/ir/IRMutator.cpp @@ -602,6 +602,25 @@ void IRMutator::visit(const Stencil *op, const Stmt &s) { } } +void IRMutator::visit(const Print *op, const Stmt &s) { + vector new_values(op->values.size()); + bool changed = false; + + for (size_t i = 0; i < op->values.size(); i++) { + Expr old_value = op->values[i]; + Expr new_value = mutate(old_value); + if (!new_value.same_as(old_value)) changed = true; + new_values[i] = new_value; + } + + if (!changed) { + stmt = s; + } + else { + stmt = Print::make(new_values, op->format); + } +} + Stmt IRGraphMutator::mutate(Stmt s) { auto iter = stmt_replacements.find(s); if (iter != stmt_replacements.end()) { diff --git a/tvm/HalideIR/src/ir/IRMutator.h b/tvm/HalideIR/src/ir/IRMutator.h index aba20df9f..2c8847a9b 100644 --- a/tvm/HalideIR/src/ir/IRMutator.h +++ b/tvm/HalideIR/src/ir/IRMutator.h @@ -102,6 +102,7 @@ class IRMutator : public IRVisitor { EXPORT virtual void visit(const Stencil *, const Stmt &); EXPORT virtual void visit(const StreamExpr *, const Expr &); EXPORT virtual void visit(const StreamStmt *, const Stmt &); + EXPORT virtual void visit(const Print *, const Stmt &); }; diff --git a/tvm/HalideIR/src/ir/IRPrinter.cpp b/tvm/HalideIR/src/ir/IRPrinter.cpp index f8cd8259f..626e58d7f 100644 --- a/tvm/HalideIR/src/ir/IRPrinter.cpp +++ b/tvm/HalideIR/src/ir/IRPrinter.cpp @@ -888,5 +888,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "}\n"; }); +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const Print *op, IRPrinter* p) { + p->do_indent(); + p->stream << "print:"; + for (size_t i = 0; i < op->values.size(); i++) { + p->stream << " "; + p->print(op->values[i]); + } + p->stream << "\n"; +}); + } } diff --git a/tvm/HalideIR/src/ir/IRVisitor.cpp b/tvm/HalideIR/src/ir/IRVisitor.cpp index 39f94113e..9d94a6e97 100644 --- a/tvm/HalideIR/src/ir/IRVisitor.cpp +++ b/tvm/HalideIR/src/ir/IRVisitor.cpp @@ -318,6 +318,13 @@ void IRVisitor::visit(const Stencil *op, const Stmt &) { op->body.accept(this); } +void IRVisitor::visit(const Print *op, const Stmt &) { + for (size_t i = 0; i < op->values.size(); i++) { + op->values[i].accept(this); + } +} + + void IRGraphVisitor::include(const Expr &e) { if (visited.count(e.get())) { return; @@ -635,5 +642,11 @@ void IRGraphVisitor::visit(const Stencil *op, const Stmt &) { include(op->body); } +void IRGraphVisitor::visit(const Print *op, const Stmt &) { + for (size_t i = 0; i < op->values.size(); i++) { + include(op->values[i]); + } +} + } } diff --git a/tvm/HalideIR/src/ir/IRVisitor.h b/tvm/HalideIR/src/ir/IRVisitor.h index 384598bb6..ab2820bba 100644 --- a/tvm/HalideIR/src/ir/IRVisitor.h +++ b/tvm/HalideIR/src/ir/IRVisitor.h @@ -82,6 +82,7 @@ class IRVisitor { EXPORT virtual void visit(const Stencil *, const Stmt &); EXPORT virtual void visit(const StreamStmt *, const Stmt &); EXPORT virtual void visit(const StreamExpr *, const Expr &); + EXPORT virtual void visit(const Print *, const Stmt &); }; /** A base class for algorithms that walk recursively over the IR @@ -164,6 +165,7 @@ class IRGraphVisitor : public IRVisitor { EXPORT virtual void visit(const Stencil *, const Stmt &); EXPORT virtual void visit(const StreamExpr *, const Expr &); EXPORT virtual void visit(const StreamStmt *, const Stmt &); + EXPORT virtual void visit(const Print *, const Stmt &); // @} }; diff --git a/tvm/include/tvm/ir.h b/tvm/include/tvm/ir.h index dd9a2ceac..872488d2c 100644 --- a/tvm/include/tvm/ir.h +++ b/tvm/include/tvm/ir.h @@ -23,6 +23,7 @@ using Halide::Internal::ForType; using Halide::Internal::PartitionType; using Halide::Internal::StreamType; using Halide::Internal::DeviceType; +using Halide::Internal::StorageType; using Halide::DeviceAPI; // Node container for CommReducer @@ -514,6 +515,7 @@ using Halide::Internal::Reuse; using Halide::Internal::Partition; using Halide::Internal::Stencil; using Halide::Internal::ExternModule; +using Halide::Internal::Print; // ir functions using Halide::Internal::is_const_power_of_two_integer; diff --git a/tvm/include/tvm/ir_functor_ext.h b/tvm/include/tvm/ir_functor_ext.h index 149382158..aa2956473 100644 --- a/tvm/include/tvm/ir_functor_ext.h +++ b/tvm/include/tvm/ir_functor_ext.h @@ -254,6 +254,7 @@ class StmtFunctor { virtual R VisitStmt_(const Reuse* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Partition* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Stencil* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Print* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Node* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); return R(); @@ -287,6 +288,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(Reuse); IR_STMT_FUNCTOR_DISPATCH(Partition); IR_STMT_FUNCTOR_DISPATCH(Stencil); + IR_STMT_FUNCTOR_DISPATCH(Print); return vtable; } }; diff --git a/tvm/include/tvm/ir_mutator.h b/tvm/include/tvm/ir_mutator.h index 3ad2ace3d..8e706ca48 100644 --- a/tvm/include/tvm/ir_mutator.h +++ b/tvm/include/tvm/ir_mutator.h @@ -79,6 +79,7 @@ class TVM_DLL IRMutator { virtual Stmt Mutate_(const Partition* op, const Stmt& s); virtual Stmt Mutate_(const Stencil* op, const Stmt& s); virtual Stmt Mutate_(const StreamStmt* op, const Stmt& s); + virtual Stmt Mutate_(const Print* op, const Stmt& s); virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e); diff --git a/tvm/include/tvm/ir_visitor.h b/tvm/include/tvm/ir_visitor.h index 87a424f16..a8ca8af00 100644 --- a/tvm/include/tvm/ir_visitor.h +++ b/tvm/include/tvm/ir_visitor.h @@ -140,6 +140,7 @@ class TVM_DLL IRVisitor { virtual void Visit_(const Reuse* op); virtual void Visit_(const Partition* op); virtual void Visit_(const Stencil* op); + virtual void Visit_(const Print* op); }; /*! diff --git a/tvm/include/tvm/schedule.h b/tvm/include/tvm/schedule.h index 24768ca09..a74f4d114 100644 --- a/tvm/include/tvm/schedule.h +++ b/tvm/include/tvm/schedule.h @@ -357,6 +357,12 @@ class Schedule : public NodeRef { IterVar axis, std::string name); + EXPORT void join_to(const Tensor& target, + Stage source, + Stage destiny, + ir::StreamType stream_type, + int channel_depth); + EXPORT void to_stage(const Tensor& target, Stage dest, int arg_pos, @@ -371,10 +377,11 @@ class Schedule : public NodeRef { int occur_index); EXPORT Tensor move_to(const Tensor& target, + Stage parent, ir::DeviceType device_type, ir::StreamType stream_type, int channel_depth, - int occurrence); + Array dev_ports); EXPORT void stream_to(const Tensor& target, Stage dest, diff --git a/tvm/src/api/api_ir.cc b/tvm/src/api/api_ir.cc index 1795365c9..03307fc52 100644 --- a/tvm/src/api/api_ir.cc +++ b/tvm/src/api/api_ir.cc @@ -252,6 +252,7 @@ REGISTER_MAKE2(While); REGISTER_MAKE2(Reuse); REGISTER_MAKE6(Stencil); REGISTER_MAKE5(ExternModule); +REGISTER_MAKE2(Print); } // namespace ir } // namespace TVM diff --git a/tvm/src/api/api_lang.cc b/tvm/src/api/api_lang.cc index 3f0555f53..6af04f301 100644 --- a/tvm/src/api/api_lang.cc +++ b/tvm/src/api/api_lang.cc @@ -472,10 +472,10 @@ TVM_REGISTER_API("_ScheduleMoveToStage") TVM_REGISTER_API("_ScheduleMove") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = args[0].operator Schedule() - .move_to(args[1], - static_cast(args[2].operator int()), - static_cast(args[3].operator int()), - args[4], args[5]); + .move_to(args[1], args[2], + static_cast(args[3].operator int()), + static_cast(args[4].operator int()), + args[5], args[6]); }); TVM_REGISTER_API("_ScheduleInStageMove") @@ -495,6 +495,14 @@ TVM_REGISTER_API("_ScheduleStream") args[6], args[7]); }); +TVM_REGISTER_API("_ScheduleJoin") + .set_body([](TVMArgs args, TVMRetValue *ret) { + args[0].operator Schedule() + .join_to(args[1], args[2], args[3], + static_cast(args[4].operator int()), + args[5]); + }); + TVM_REGISTER_API("_ScheduleReshape") .set_body([](TVMArgs args, TVMRetValue *ret) { args[0].operator Schedule().reshape(args[1], args[2]); diff --git a/tvm/src/codegen/build_common.cc b/tvm/src/codegen/build_common.cc index 43aa70de0..aeef63e96 100644 --- a/tvm/src/codegen/build_common.cc +++ b/tvm/src/codegen/build_common.cc @@ -68,8 +68,11 @@ class SimModuleNode final : public ModuleNode { LOG(FATAL) << "The function should take in " << func_->args.size() << " inputs but get " << args.size(); - // check whether init needed - bool init = true; + bool init = true; // check whether init needed + bool empty = false; // whether kernel is empty + if (dev_.find_first_not_of(" \t\n") + == std::string::npos) empty = true; + if (shmids.size() > 0) { init = false; // requires mem update CHECK(shmids.size() == (unsigned)args.size()) @@ -84,7 +87,7 @@ class SimModuleNode final : public ModuleNode { system("rm -rf project; mkdir project"); GenHostCode(args, shmids, arg_types, func_, - platform_, host_, arg_names_); + platform_, host_, arg_names_, empty); GenKernelCode(dev_, platform_, options_["backend"]); // copy files and compile tp binary @@ -93,7 +96,7 @@ class SimModuleNode final : public ModuleNode { CHECK(options_.count("mode")) << "mode mot set"; auto mode = options_["mode"]; auto backend = options_["backend"]; - (*f)(platform_, mode, backend, cfg_).operator std::string(); + (*f)(platform_, mode, backend, empty, cfg_).operator std::string(); } } @@ -122,7 +125,7 @@ class SimModuleNode final : public ModuleNode { if (const auto* f = Registry::Get("tvm_callback_exec_evaluate")) { std::string code; std::string mode = options_["mode"]; - code = (*f)(platform_, mode).operator std::string(); + code = (*f)(platform_, mode, empty).operator std::string(); LOG(CLEAN) << "Execution complete \n"; } diff --git a/tvm/src/codegen/build_util.cc b/tvm/src/codegen/build_util.cc index e76762f01..2ac30fab0 100644 --- a/tvm/src/codegen/build_util.cc +++ b/tvm/src/codegen/build_util.cc @@ -391,6 +391,7 @@ void GenKernelCode(std::string& test_file, } else if (platform == "vitis") { stream << "#include \n"; stream << "#include \n"; + stream << "#include \n"; } stream << test_file; @@ -478,7 +479,8 @@ void GenHostCode(TVMArgs& args, const std::vector& arg_types, LoweredFunc lowered_func,std::string platform, std::string host_code, - std::vector arg_names) { + std::vector arg_names, + bool kernel_is_empty) { int indent = 0; std::ofstream stream; stream.open("project/host.cpp"); @@ -540,8 +542,9 @@ void GenHostCode(TVMArgs& args, stream << "\n"; } - if (platform == "sdaccel" || platform == "vitis") { - stream << R"( + if (!kernel_is_empty) { + if (platform == "sdaccel" || platform == "vitis") { + stream << R"( if (argc != 2) { std::cout << "Usage: " << argv[0] << " " << std::endl; return EXIT_FAILURE; @@ -567,8 +570,8 @@ void GenHostCode(TVMArgs& args, )"; - } else if (platform == "aocl") { - stream << R"( + } else if (platform == "aocl") { + stream << R"( cl_int status; cl_uint numDevices = 0; @@ -617,9 +620,10 @@ void GenHostCode(TVMArgs& args, CHECK(status); )"; - } + } - stream << "\n"; + stream << "\n"; + } PrintIndent(stream, indent); stream << "// compute and kernel call from host"; stream << code << "\n"; diff --git a/tvm/src/codegen/build_util.h b/tvm/src/codegen/build_util.h index ad5679732..5da8142cc 100644 --- a/tvm/src/codegen/build_util.h +++ b/tvm/src/codegen/build_util.h @@ -57,7 +57,8 @@ void GenHostCode(TVMArgs& args, LoweredFunc func, std::string platform, std::string host_code, - std::vector arg_names); + std::vector arg_names, + bool kernel_is_empty); } // namespace runtime } // namespace TVM #endif // TVM_CODEGEN_BUILD_HELPER_H_ diff --git a/tvm/src/codegen/codegen_c.cc b/tvm/src/codegen/codegen_c.cc index d048bd427..0b70d4a21 100644 --- a/tvm/src/codegen/codegen_c.cc +++ b/tvm/src/codegen/codegen_c.cc @@ -79,6 +79,7 @@ void CodeGenC::Init(bool output_ssa) { } void CodeGenC::InitFuncState(LoweredFunc f) { + alloc_set_.clear(); alloc_storage_scope_.clear(); handle_data_type_.clear(); var_shape_map_.clear(); @@ -134,7 +135,8 @@ std::string CodeGenC::GetConfig() { } std::string CodeGenC::GetHost() { - return this->stream.str(); + return decl_stream.str() + + this->stream.str(); } std::string CodeGenC::GetDevice() { @@ -627,10 +629,45 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) os << "("; PrintExpr(op->args[0], os); os << " ? "; - PrintExpr(op->args[1], os); - os << " : "; - PrintExpr(op->args[2], os); - os << ")"; + // type casting when mismatching + auto& v1 = op->args[1]; + auto& v2 = op->args[2]; + bool cast_value = false; + if (v1.as() || v1.as() || v1.as()) { + if (auto var = v2.as()) { + cast_value = true; + Type type = handle_data_type_[var->buffer_var.get()]; + std::stringstream value; + this->PrintExpr(v1, value); + os << "(("; + this->PrintType(type, os); + os << ")" << value.str() << ")"; + + os << " : "; + PrintExpr(op->args[2], os); + os << ")"; + } + } else if (v2.as() || v2.as() || v2.as()) { + if (auto var = v1.as()) { + cast_value = true; + PrintExpr(op->args[1], os); + os << " : "; + + Type type = handle_data_type_[var->buffer_var.get()]; + std::stringstream value; + this->PrintExpr(v2, value); + os << "(("; + this->PrintType(type, os); + os << ")" << value.str() << ")"; + os << ")"; + } + } + if (!cast_value) { + PrintExpr(op->args[1], os); + os << " : "; + PrintExpr(op->args[2], os); + os << ")"; + } } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); CHECK(op->args.size() == 1 && l); @@ -899,6 +936,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) { } else if (value.find("data") != std::string::npos || value.substr(0, 3) == "arg") { arg_names.push_back(vid); + alloc_set_.insert(vid); } PrintStmt(op->body); } @@ -1104,7 +1142,6 @@ void CodeGenC::VisitStmt_(const KernelDef* op) { this->stream.clear(); this->stream << save.str(); RestoreFuncState(f); - alloc_set.clear(); } void CodeGenC::VisitStmt_(const KernelStmt *op) { @@ -1145,11 +1182,13 @@ void CodeGenC::VisitStmt_(const Partition* op) { void CodeGenC::SaveFuncState(LoweredFunc f) { // clear save info copy + alloc_set_save.clear(); alloc_storage_scope_save.clear(); handle_data_type_save.clear(); var_shape_map_save.clear(); range_save.clear(); // backup func info and clear + alloc_set_save = alloc_set_; alloc_storage_scope_save = alloc_storage_scope_; handle_data_type_save = handle_data_type_; var_shape_map_save = var_shape_map_; @@ -1159,6 +1198,7 @@ void CodeGenC::SaveFuncState(LoweredFunc f) { void CodeGenC::RestoreFuncState(LoweredFunc f) { this->InitFuncState(f); + alloc_set_ = alloc_set_save; alloc_storage_scope_ = alloc_storage_scope_save; handle_data_type_ = handle_data_type_save; var_shape_map_ = var_shape_map_save; diff --git a/tvm/src/codegen/codegen_c.h b/tvm/src/codegen/codegen_c.h index 6b95ca4ec..316ee3d7d 100644 --- a/tvm/src/codegen/codegen_c.h +++ b/tvm/src/codegen/codegen_c.h @@ -188,15 +188,16 @@ class CodeGenC : std::map > var_shape_map_; std::unordered_map range_; str2tupleMap map_arg_type_; + // allocated buffer names + std::unordered_set alloc_set_; // save for kernel std::map > var_shape_map_save; std::unordered_map range_save; + std::unordered_set alloc_set_save; // top function argument names std::vector arg_names; - // record allocated buffer names - std::unordered_set alloc_set; protected: void SaveFuncState(LoweredFunc f); diff --git a/tvm/src/codegen/hlsc/codegen_vhls.cc b/tvm/src/codegen/hlsc/codegen_vhls.cc index bdb0b581f..d259d0608 100644 --- a/tvm/src/codegen/hlsc/codegen_vhls.cc +++ b/tvm/src/codegen/hlsc/codegen_vhls.cc @@ -109,6 +109,22 @@ void CodeGenVivadoHLS::VisitStmt_(const Store* op) { } } +void CodeGenVivadoHLS::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) + if ((op->call_type == Call::Extern || + op->call_type == Call::PureExtern) && op->name == "sqrtf") { + os << "sqrt("; + for (size_t i = 0; i < op->args.size(); i++) { + this->PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + void CodeGenVivadoHLS::VisitStmt_(const Allocate* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); @@ -138,7 +154,7 @@ void CodeGenVivadoHLS::VisitStmt_(const Allocate* op) { } // not allocate buffer for channel or moved data - if (!(ptr_mode && alloc_set.find(vid) != alloc_set.end())) { + if (!(ptr_mode && alloc_set_.find(vid) != alloc_set_.end())) { this->PrintIndent(); // allocate stream channels @@ -353,17 +369,12 @@ void CodeGenVivadoHLS::VisitStmt_(const StreamStmt* op) { std::string vid = GetVarID(op->buffer_var.get()); switch (op->stream_type) { case StreamType::FIFO: - PrintIndent(); - stream << "#pragma HLS stream variable=" - << vid << " depth=" << op->depth << "\n"; - break; - case StreamType::Channel: PrintIndent(); stream << vid << ".write("; PrintExpr(op->value, stream); stream << ");\n"; break; - case StreamType::Pipe: + case StreamType::DoubleBuffer: PrintIndent(); stream << vid << " << "; PrintExpr(op->value, stream); @@ -402,10 +413,6 @@ void CodeGenVivadoHLS::VisitStmt_(const KernelStmt *op) { } } for (size_t i = 0; i < op->args.size(); i++) { - if (arg_info.find(i) != arg_info.end()) { - if (arg_info[i] == 0 && !sdsoc_mode) - stream << "fd_"; - } PrintExpr(op->args[i], stream); if (i < op->args.size() - 1) stream << ", "; } @@ -432,15 +439,31 @@ void CodeGenVivadoHLS::VisitStmt_(const KernelDef* op) { // collect argument information std::unordered_map arg_info; - for (size_t i = 0; i < op->channels.size(); i=i+2) { - auto pos = op->channels[i].as()->value; - auto idx = op->channels[i+1].as()->value; + for (size_t i = 0; i < op->channels.size(); i++) { + auto info = op->channels[i]; + auto pos = info[0].as()->value; + auto idx = info[1].as()->value; if (idx > 0) arg_info[pos] = idx; } // print kernel function if (op->name.find("test") != std::string::npos) { + // extract the memory port information + std::unordered_map> mem_mapping; + CHECK(op->channels.size() == op->args.size()); + for (size_t i = 0; i < op->channels.size();i++) { + auto info = op->channels[i]; + CHECK(info.size() == 6); + auto pos = info[0].as()->value; + // auto channel = info[1].as()->value; + // auto depth = info[2].as()->value; + // auto is_sender = info[3].as()->value; + int mem = info[4].as()->value; + int port = info[5].as()->value; + mem_mapping[pos] = {mem, port}; + } + // used as OpenCL kernel if (ptr_mode) { int extern_scope = BeginScope(); @@ -457,8 +480,14 @@ void CodeGenVivadoHLS::VisitStmt_(const KernelDef* op) { CHECK(vid.find("_channel")) << vid << " not a channel"; vid.replace(vid.find("_channel"), 8, ""); - alloc_set.insert(vid); - alloc_set.insert(vid + "_new"); + + // handle output-update-in-kernel case + if (vid.find("_update") != std::string::npos) { + vid.replace(vid.find("_update"), 7, ""); + } + + alloc_set_.insert(vid); + alloc_set_.insert(vid + "_new"); kernel_args.push_back(vid); if (i != 0) stream << ", "; @@ -484,9 +513,12 @@ void CodeGenVivadoHLS::VisitStmt_(const KernelDef* op) { continue; } else { PrintIndent(); + CHECK(mem_mapping.count(i)); + CHECK(mem_mapping.at(i).size() == 2); + auto port = mem_mapping[i][1]; stream << "#pragma HLS INTERFACE m_axi port=" << kernel_args[i] << " " - << "offset=slave bundle=gmem" << i << "\n"; + << "offset=slave bundle=gmem" << port << "\n"; } } // block-level control interface diff --git a/tvm/src/codegen/hlsc/codegen_vhls.h b/tvm/src/codegen/hlsc/codegen_vhls.h index 8ab63f41f..14ee15352 100644 --- a/tvm/src/codegen/hlsc/codegen_vhls.h +++ b/tvm/src/codegen/hlsc/codegen_vhls.h @@ -24,6 +24,7 @@ class CodeGenVivadoHLS final : public CodeGenHLSC { void VisitExpr_(const GetBit* op, std::ostream& os) override; void VisitExpr_(const GetSlice* op, std::ostream& os) override; void VisitExpr_(const StreamExpr* op, std::ostream& os) override; + void VisitExpr_(const Call *op, std::ostream& os) override; void VisitStmt_(const Allocate* op) override; void VisitStmt_(const Store* op) override; diff --git a/tvm/src/codegen/llvm/codegen_llvm.cc b/tvm/src/codegen/llvm/codegen_llvm.cc index 6677721f5..b70b118d4 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.cc +++ b/tvm/src/codegen/llvm/codegen_llvm.cc @@ -43,6 +43,7 @@ void CodeGenLLVM::Init(const std::string& module_name, t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(); t_int_ = llvm::Type::getInt32Ty(*ctx_); t_char_ = llvm::Type::getInt8Ty(*ctx_); + t_char_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(); t_int8_ = llvm::Type::getInt8Ty(*ctx_); t_int16_ = llvm::Type::getInt16Ty(*ctx_); t_int32_ = llvm::Type::getInt32Ty(*ctx_); @@ -1414,6 +1415,41 @@ void CodeGenLLVM::VisitStmt_(const Stencil* op) { this->VisitStmt(op->body); } +void CodeGenLLVM::VisitStmt_(const Print* op) { + std::vector values; + std::vector types; + std::vector llvm_types; + for (size_t i = 0; i < op->values.size(); i++) { + Expr v = op->values[i]; + values.push_back(MakeValue(v)); + types.push_back(v.type()); + if (v.type().is_int() || v.type().is_uint()) { + llvm_types.push_back(t_int64_); + } else { + llvm_types.push_back(llvm::Type::getDoubleTy(*ctx_)); + } + } + llvm::FunctionType* call_ftype = llvm::FunctionType::get(t_int_, true); +#if TVM_LLVM_VERSION <= 60 + llvm::Function* printf_call = llvm::cast(module_->getOrInsertFunction("printf", call_ftype)); +#else + llvm::Function* printf_call = llvm::cast(module_->getOrInsertFunction("printf", call_ftype).getCallee()); +#endif + std::vector printf_args; + std::string format = op->format; + printf_args.push_back(builder_->CreateGlobalStringPtr(format)); + for (size_t i = 0; i < op->values.size(); i++) { + if (types[i].is_int() || types[i].is_uint()) { + llvm::Value* ivalue = CreateCast(types[i], Int(64), values[i]); + printf_args.push_back(ivalue); + } else { // fixed or float + llvm::Value* fvalue = CreateCast(types[i], Float(64), values[i]); + printf_args.push_back(fvalue); + } + } + builder_->CreateCall(printf_call, printf_args); +} + } // namespace codegen } // namespace TVM #endif // TVM_LLVM_VERSION diff --git a/tvm/src/codegen/llvm/codegen_llvm.h b/tvm/src/codegen/llvm/codegen_llvm.h index b3dcdb612..388183de6 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.h +++ b/tvm/src/codegen/llvm/codegen_llvm.h @@ -137,6 +137,7 @@ class CodeGenLLVM : void VisitStmt_(const While* op) override; void VisitStmt_(const Partition* op) override {}; void VisitStmt_(const Stencil* op) override; + void VisitStmt_(const Print* op) override; protected: /*! \brief The storage information */ @@ -240,6 +241,7 @@ class CodeGenLLVM : llvm::PointerType* t_void_p_{nullptr}; llvm::Type* t_int_{nullptr}; llvm::Type* t_char_{nullptr}; + llvm::PointerType* t_char_p_{nullptr}; llvm::Type* t_int8_{nullptr}; llvm::Type* t_int16_{nullptr}; llvm::Type* t_int32_{nullptr}; diff --git a/tvm/src/codegen/llvm/llvm_module.cc b/tvm/src/codegen/llvm/llvm_module.cc index 1ceda2dd1..46f6f465f 100644 --- a/tvm/src/codegen/llvm/llvm_module.cc +++ b/tvm/src/codegen/llvm/llvm_module.cc @@ -68,25 +68,57 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { +#if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(mptr_); +#else + std::unique_ptr m = llvm::CloneModule(*mptr_); +#endif llvm::legacy::PassManager pass; CHECK(tm_); +#if TVM_LLVM_VERSION <= 60 CHECK(tm_->addPassesToEmitFile( pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; +#elif TVM_LLVM_VERSION <= 90 + CHECK(tm_->addPassesToEmitFile( + pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; +#else + CHECK(tm_->addPassesToEmitFile( + pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; +#endif pass.run(*m); } else if (fmt == "s" || fmt == "asm") { +#if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(mptr_); +#else + std::unique_ptr m = llvm::CloneModule(*mptr_); +#endif llvm::legacy::PassManager pass; CHECK(tm_); +#if TVM_LLVM_VERSION <= 60 CHECK(tm_->addPassesToEmitFile( pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; +#elif TVM_LLVM_VERSION <= 90 + CHECK(tm_->addPassesToEmitFile( + pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#else + CHECK(tm_->addPassesToEmitFile( + pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#endif pass.run(*m); } else if (fmt == "ll") { mptr_->print(dest, nullptr); } else if (fmt == "bc") { +#if TVM_LLVM_VERSION <= 60 llvm::WriteBitcodeToFile(mptr_, dest); +#else + llvm::WriteBitcodeToFile(*mptr_, dest); +#endif } else { LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'"<< format << "\'"; diff --git a/tvm/src/codegen/merlinc/codeanalys_merlinc.cc b/tvm/src/codegen/merlinc/codeanalys_merlinc.cc index b1afb7800..526368b37 100644 --- a/tvm/src/codegen/merlinc/codeanalys_merlinc.cc +++ b/tvm/src/codegen/merlinc/codeanalys_merlinc.cc @@ -246,7 +246,6 @@ void CodeAnalysMerlinC::PrintStorageSync(const Call* op) { // NOLINT(*) } void CodeAnalysMerlinC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - CHECK_EQ(scope, "global"); } std::string CodeAnalysMerlinC::GetType(Type t) { // NOLINT(*) diff --git a/tvm/src/codegen/opencl/codegen_aocl.cc b/tvm/src/codegen/opencl/codegen_aocl.cc index a5e977f4c..29d673b0a 100644 --- a/tvm/src/codegen/opencl/codegen_aocl.cc +++ b/tvm/src/codegen/opencl/codegen_aocl.cc @@ -204,7 +204,7 @@ void CodeGenAOCL::VisitStmt_(const Allocate* op) { } // not allocate buffer for channel or moved data - if (alloc_set.find(vid) == alloc_set.end()) { + if (alloc_set_.find(vid) == alloc_set_.end()) { this->PrintIndent(); // allocate stream channels @@ -324,16 +324,12 @@ void CodeGenAOCL::VisitExpr_(const StreamExpr* op, std::ostream& os) { i++; } switch (op->stream_type) { - case StreamType::Channel: - os << "read_channel_intel("; - os << vid << ")"; - break; - case StreamType::Pipe: + case StreamType::DoubleBuffer: os << "read_pipe("; break; case StreamType::FIFO: - // buffered channel - os << "fifo"; + os << "read_channel_intel("; + os << vid << ")"; break; } } @@ -361,9 +357,10 @@ void CodeGenAOCL::VisitStmt_(const KernelDef* op) { // streamed arg position to channel index std::unordered_map stream_args; - for (size_t j = 0; j < op->channels.size(); j=j+2) { - int pos = op->channels[j].as()->value; - int idx = op->channels[j+1].as()->value; + for (size_t j = 0; j < op->channels.size(); j=j++) { + auto info = op->channels[j]; + int pos = info[0].as()->value; + int idx = info[1].as()->value; stream_args[pos] = idx; } @@ -375,8 +372,8 @@ void CodeGenAOCL::VisitStmt_(const KernelDef* op) { // for top kernel functions if (vid.find("_channel")) { vid.replace(vid.find("_channel"), 8, ""); - alloc_set.insert(vid); - alloc_set.insert(vid + "_new"); + alloc_set_.insert(vid); + alloc_set_.insert(vid + "_new"); } if (stream_args.count(i)) { @@ -456,17 +453,14 @@ void CodeGenAOCL::VisitStmt_(const StreamStmt* op) { i++; } switch (op->stream_type) { - case StreamType::Channel: + case StreamType::FIFO: stream << "write_channel_intel("; stream << vid << ", "; break; - case StreamType::Pipe: + case StreamType::DoubleBuffer: stream << "write_pipe("; stream << vid << ", "; break; - case StreamType::FIFO: - stream << "fifo("; - break; } PrintExpr(op->value, stream); stream << ");\n"; diff --git a/tvm/src/codegen/opencl/codegen_aocl_host.cc b/tvm/src/codegen/opencl/codegen_aocl_host.cc index f2904c8fe..c33b5936f 100644 --- a/tvm/src/codegen/opencl/codegen_aocl_host.cc +++ b/tvm/src/codegen/opencl/codegen_aocl_host.cc @@ -199,7 +199,7 @@ void CodeGenAOCLHost::VisitStmt_(const Allocate* op) { // skip if buffer allocated in host scope } else if (vid.find("_channel") != std::string::npos) { vid.replace(vid.find("_channel"), 8, ""); - if (alloc_set.find(vid) != alloc_set.end()) { + if (alloc_set_.find(vid) != alloc_set_.end()) { not_alloc = true; } } @@ -207,7 +207,7 @@ void CodeGenAOCLHost::VisitStmt_(const Allocate* op) { // not allocate for moved data if (!not_alloc) { PrintType(op->type, stream); - alloc_set.insert(vid); + alloc_set_.insert(vid); stream << ' '<< vid; if (constant_size > 1) {// Transfer length one array to scalar stream << "["; diff --git a/tvm/src/codegen/opencl/codegen_xocl.cc b/tvm/src/codegen/opencl/codegen_xocl.cc index 8666f7c62..bea52fb8a 100644 --- a/tvm/src/codegen/opencl/codegen_xocl.cc +++ b/tvm/src/codegen/opencl/codegen_xocl.cc @@ -198,17 +198,6 @@ void CodeGenXOCL::VisitStmt_(const StreamStmt* op) { std::string vid = GetVarID(op->buffer_var.get()); PrintIndent(); stream << vid; - switch (op->stream_type) { - case StreamType::Channel: - stream << "[channel]"; - break; - case StreamType::FIFO: - stream << "[fifo]"; - break; - case StreamType::Pipe: - stream << "[pipe]"; - break; - } stream << ".write"; PrintExpr(op->value, stream); stream << ";\n"; diff --git a/tvm/src/codegen/opencl/codegen_xocl_host.cc b/tvm/src/codegen/opencl/codegen_xocl_host.cc index 3505d032d..1b3c6a75a 100644 --- a/tvm/src/codegen/opencl/codegen_xocl_host.cc +++ b/tvm/src/codegen/opencl/codegen_xocl_host.cc @@ -19,19 +19,7 @@ void CodeGenXOCLHost::AddFunction(LoweredFunc f, } void CodeGenXOCLHost::PrintType(Type t, std::ostream& os) { - if (t.is_uint() || t.is_int() || t.is_fixed() || t.is_ufixed()) { - if (t.is_uint()) { - os << "ap_uint<" << t.bits() << ">"; - } else if (t.is_int()) { - os << "ap_int<" << t.bits() << ">"; - } else if (t.is_ufixed()) { - os << "ap_ufixed<" << t.bits() << ", " << t.bits() - t.fracs() << ">"; - } else { - os << "ap_fixed<" << t.bits() << ", " << t.bits() - t.fracs() << ">"; - } - } else { - CodeGenC::PrintType(t, os); - } + CodeGenC::PrintType(t, os); } std::string CodeGenXOCLHost::GetBufferRef(Type t, const Variable* buffer, Expr index) { @@ -195,7 +183,16 @@ void CodeGenXOCLHost::VisitStmt_(const Allocate* op) { // skip if buffer allocated in host scope } else if (vid.find("_channel") != std::string::npos) { vid.replace(vid.find("_channel"), 8, ""); - if (alloc_set.find(vid) != alloc_set.end()) { + + // handle output-update-in-kernel case + if (vid.find("_update") != std::string::npos) { + auto name = var_idmap_[op->buffer_var.get()]; + name.replace(name.find("_update"), 7, ""); + vid.replace(vid.find("_update"), 7, ""); + var_idmap_[op->buffer_var.get()] = name; + } + + if (alloc_set_.find(vid) != alloc_set_.end()) { not_alloc = true; } else { for (auto& name : arg_names) { @@ -207,7 +204,7 @@ void CodeGenXOCLHost::VisitStmt_(const Allocate* op) { // not allocate for moved data if (!not_alloc) { PrintType(op->type, stream); - alloc_set.insert(vid); + alloc_set_.insert(vid); stream << ' '<< vid; if (constant_size > 1) {// Transfer length one array to scalar stream << "["; @@ -230,8 +227,16 @@ void CodeGenXOCLHost::VisitStmt_(const Allocate* op) { void CodeGenXOCLHost::VisitStmt_(const KernelStmt* op) { std::string name = op->name; // extract annotation information + std::unordered_map> mem_mapping; + CHECK(op->annotate_values.size() == 3 * op->args.size()); + for (size_t i = 0; i < op->args.size(); i++) { + int pos = op->annotate_values[3*i+0].as()->value; + int mem = op->annotate_values[3*i+1].as()->value; + int port = op->annotate_values[3*i+2].as()->value; + mem_mapping[pos] = {mem, port}; + } + // initialize buffers and opencl kernel - int mem_count = 0; if (name.find("test") != std::string::npos) { // create kernels @@ -258,22 +263,72 @@ void CodeGenXOCLHost::VisitStmt_(const KernelStmt* op) { arg_name.replace(arg_name.find("_channel"), 8, ""); kernel_args.push_back(arg_name); + // check buffer types + CHECK(mem_mapping.count(k)); + CHECK(mem_mapping.at(k).size() == 2); + auto type = static_cast(mem_mapping[k][0]); + unsigned int port = mem_mapping[k][1]; PrintIndent(); - stream << "cl::Buffer buffer_" - << arg_name - << "(context, " - << "CL_MEM_USE_HOST_PTR | CL_MEM_READ_WRITE, " - << "sizeof("; - PrintType(handle_data_type_[v], stream); - stream << ")*"; - for (size_t i = 0; i < shape.size(); i++) { - if (i != 0) stream << "*"; - stream << shape[i]; - } - stream << ", " << arg_name - << ", &err);\n"; - mem_count += 1; + if (type == StorageType::devDRAM) { + stream << "cl::Buffer buffer_" + << arg_name + << "(context, " + << "CL_MEM_USE_HOST_PTR | CL_MEM_READ_WRITE, " + << "sizeof("; + PrintType(handle_data_type_[v], stream); + stream << ")*"; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) stream << "*"; + stream << shape[i]; + } + + stream << ", " << arg_name + << ", &err);\n"; + + // high bandwidth memory + } else if (type == StorageType::devHBM) { + if (decl_stream.str().find("HBM") == std::string::npos) { + decl_stream << R"( +#define MAX_HBM_BANKCOUNT 32 +#define BANK(n) n | XCL_MEM_TOPOLOGY +const int bank[MAX_HBM_BANKCOUNT] = { + BANK(0), BANK(1), BANK(2), BANK(3), BANK(4), + BANK(5), BANK(6), BANK(7), BANK(8), BANK(9), + BANK(10), BANK(11), BANK(12), BANK(13), BANK(14), + BANK(15), BANK(16), BANK(17), BANK(18), BANK(19), + BANK(20), BANK(21), BANK(22), BANK(23), BANK(24), + BANK(25), BANK(26), BANK(27), BANK(28), BANK(29), + BANK(30), BANK(31) +}; +)"; + // create tcl script + cfg_stream << "[connectivity]\n"; + } + auto name = "BufExt_" + arg_name; + // create external mem pointer + stream << "cl_mem_ext_ptr_t " << name << ";\n"; + stream << " " << name << ".flags = bank[" << port << "];\n"; + stream << " " << name << ".parameter = 0;\n"; + stream << " " << name << ".obj = &" << arg_name << "[0];\n"; + PrintIndent(); + stream << "cl::Buffer buffer_" + << arg_name + << "(context, " + << "CL_MEM_EXT_PTR_XILINX | " + << "CL_MEM_USE_HOST_PTR | CL_MEM_READ_WRITE, " + << "sizeof("; + PrintType(handle_data_type_[v], stream); + stream << ")*"; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) stream << "*"; + stream << shape[i]; + } + stream << ", &" << name << ", &err);\n\n"; + // assign memory channel ports + cfg_stream << "sp=" << op->name << "." + << arg_name << ":HBM[" << port << "]\n"; + } } // set kernel arguments diff --git a/tvm/src/lang/ir.cc b/tvm/src/lang/ir.cc index 5994a5dd8..90dd6445c 100644 --- a/tvm/src/lang/ir.cc +++ b/tvm/src/lang/ir.cc @@ -158,6 +158,7 @@ TVM_REGISTER_NODE_TYPE(Reuse); TVM_REGISTER_NODE_TYPE(Partition); TVM_REGISTER_NODE_TYPE(Stencil); TVM_REGISTER_NODE_TYPE(ExternModule); +TVM_REGISTER_NODE_TYPE(Print); } // namespace ir } // namespace TVM diff --git a/tvm/src/pass/ir_mutator.cc b/tvm/src/pass/ir_mutator.cc index 38d4de635..aa6ca64e2 100644 --- a/tvm/src/pass/ir_mutator.cc +++ b/tvm/src/pass/ir_mutator.cc @@ -410,6 +410,16 @@ Stmt IRMutator::Mutate_(const Stencil *op, const Stmt &s) { } } +Stmt IRMutator::Mutate_(const Print *op, const Stmt &s) { + auto new_values = MutateArray(op->values, this); + + if (op->values.same_as(new_values)) { + return s; + } else { + return Print::make(new_values, op->format); + } +} + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt) @@ -434,7 +444,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(While) .DISPATCH_TO_MUTATE_STMT(Reuse) .DISPATCH_TO_MUTATE_STMT(Partition) -.DISPATCH_TO_MUTATE_STMT(Stencil); +.DISPATCH_TO_MUTATE_STMT(Stencil) +.DISPATCH_TO_MUTATE_STMT(Print); // Mutate Expr diff --git a/tvm/src/pass/ir_visitor.cc b/tvm/src/pass/ir_visitor.cc index 8c549f56c..104de36ef 100644 --- a/tvm/src/pass/ir_visitor.cc +++ b/tvm/src/pass/ir_visitor.cc @@ -285,6 +285,12 @@ void IRVisitor::Visit_(const Stencil *op) { this->Visit(op->body); } +void IRVisitor::Visit_(const Print *op) { + for (size_t i = 0; i < op->values.size(); i++) { + this->Visit(op->values[i]); + } +} + #define DEFINE_OP_NO_VISIT_(OP) \ void IRVisitor::Visit_(const OP* op) {} @@ -358,7 +364,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Reuse) .DISPATCH_TO_VISIT(Partition) .DISPATCH_TO_VISIT(Stencil) -.DISPATCH_TO_VISIT(ExternModule); +.DISPATCH_TO_VISIT(ExternModule) +.DISPATCH_TO_VISIT(Print); } // namespace ir } // namespace TVM diff --git a/tvm/src/pass/storage_flatten.cc b/tvm/src/pass/storage_flatten.cc index 4e9d4da15..9f208147c 100644 --- a/tvm/src/pass/storage_flatten.cc +++ b/tvm/src/pass/storage_flatten.cc @@ -148,7 +148,8 @@ class StorageFlattener : public IRMutator { Stmt Mutate_(const Realize* op, const Stmt& s) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { - CHECK(buf_map_.at(key).external); + // CHECK(buf_map_.at(key).external) + // << key.f << " not in external buffer bindings"; return this->Mutate(op->body); } else { // create a buffer entry @@ -451,6 +452,7 @@ class StorageFlattener : public IRMutator { return this->Mutate(op->body); } TensorKey key{tensor->op, tensor->value_index}; + if (!buf_map_.count(key)) return this->Mutate(op->body); CHECK(buf_map_.count(key)) << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; const BufferEntry& be = buf_map_.at(key); diff --git a/tvm/src/pass/stream_inference.cc b/tvm/src/pass/stream_inference.cc index 6a57ecfb2..79f91ff5a 100644 --- a/tvm/src/pass/stream_inference.cc +++ b/tvm/src/pass/stream_inference.cc @@ -243,6 +243,85 @@ class AccessCollector : public ir::IRMutator { } }; +// create streaming channels across loop iterations +class LoopbackMutator : public ir::IRMutator { + public: + explicit LoopbackMutator( + const VarExpr& target_buf, const Array& shape, + const std::unordered_map& range, + Type type) + : target_buf_(target_buf), shape_(shape), + range_(range), type_(type) {} + + // FIXME: buffer mismatch + Stmt Mutate_(const Store* op, const Stmt& s) { + if (op->buffer_var->name_hint == target_buf_->name_hint) { + if (store_count == 0) { + store_count += 1; + CHECK(!temp_.defined()); + temp_ = VarExpr("temp_" + target_buf_->name_hint); + auto index = IntImm::make(Int(32), 0); + Expr load_expr = Load::make(type_, + temp_, index, op->predicate); + save_stmt = Store::make(op->buffer_var, + load_expr, op->index, op->predicate); + + Stmt stmt = Store::make(temp_, op->value, index, op->predicate); + stmt = Allocate::make(temp_, type_, Array(), + make_const(Bool(type_.lanes()), true), stmt); + stmt = AttrStmt::make(temp_, attr::storage_scope, + StringImm::make("local"), stmt); + return stmt; + + } else { + store_count += 1; + auto index = IntImm::make(Int(32), 0); + return Store::make(temp_, op->value, index, op->predicate); + } + } + return IRMutator::Mutate_(op, s); + } + + Expr Mutate_(const Load* op, const Expr& e) { + if (op->buffer_var->name_hint == target_buf_->name_hint) { + if (store_count > 0) { + auto index = IntImm::make(Int(32), 0); + return Load::make(op->type, temp_, index, op->predicate); + } + } + return e; + } + + // create stream array + Stmt Mutate_(const For* op, const Stmt& s) { + + if (op->body.as() == nullptr) { + Stmt stmt = this->Mutate(op->body); + stmt = Block::make(stmt, save_stmt); + return For::make( + op->loop_var, op->min, op->extent, op->for_type, + op->device_api, stmt, op->annotate_keys, + op->annotate_values); + + } else { + Stmt stmt = this->Mutate(op->body); + return For::make( + op->loop_var, op->min, op->extent, op->for_type, + op->device_api, stmt, op->annotate_keys, + op->annotate_values); + } + } + + private: + const VarExpr& target_buf_; + const Array& shape_; + const std::unordered_map& range_; + Type type_; + VarExpr temp_; + int store_count{0}; + Stmt save_stmt; +}; + /*! * \brief An IRVisitor to collect information * of undefined variables @@ -419,6 +498,18 @@ class StreamAnalyzer final : public IRMutator { Stmt Mutate_(const StreamStmt* op, const Stmt& s) final { + // TODO add config info to ir node + if (auto val = op->value.as()) { + if (val->value == "config") { + CHECK(op->annotate_values.size() == 2); + Array dev_port(op->annotate_values); + auto buffer = op->buffer_var.as(); + CHECK(buffer != nullptr); + mem_ports[buffer->name] = dev_port; + return Evaluate::make(0); + } + } + if (auto buf = op->buffer_var.as()) { std::string name = buf->name; VarExpr buf_var(bind_buffer_map_[name].node_); @@ -454,6 +545,7 @@ class StreamAnalyzer final : public IRMutator { if (!new_var_nodes.count(var.get())) { new_vars.push_back(var); new_var_nodes.insert(var.get()); + return; } } auto it = use_count_.find(var.get()); @@ -482,6 +574,8 @@ class StreamAnalyzer final : public IRMutator { std::unordered_set new_var_nodes; std::unordered_map> shape_; std::unordered_map dtype_; + // extract memory interface information + std::unordered_map> mem_ports; }; @@ -491,8 +585,8 @@ class StreamMutator : public IRMutator { Stmt Mutate_(const KernelDef *op, const Stmt& s) final { // check the kernel channels - CHECK(op->channels.size() % 4 == 0) - << "wrong index number in channels"; + CHECK(op->channels.size() <= op->args.size()) + << "conflicting entries in op->channels"; // TODO: match buffer to extract graph for (auto& arg : op->args) { std::string name = arg.get()->name_hint; @@ -515,9 +609,11 @@ class StreamMutator : public IRMutator { } // insert (position, channel idx) into map - for (size_t i = 0; i < op->channels.size(); i+=4) { - auto pos = op->channels[i].as()->value; - auto idx = op->channels[i+1].as()->value; + for (size_t i = 0; i < op->channels.size(); i++) { + Array info = op->channels[i]; + CHECK(info.size() == 6); + auto pos = info[0].as()->value; + auto idx = info[1].as()->value; kernel_arg_map[op->name].push_back(pos); kernel_arg_map[op->name].push_back(idx); kernel_channel_map[op->name].insert(idx); @@ -823,6 +919,58 @@ class InfoUpdater final : public IRMutator { const std::vector& marked_buffer_; }; +// create local copy and sync with data copy +class MultiLoadMutator : public IRMutator { + public: + explicit MultiLoadMutator( + std::string& target, + std::vector& channels, Type type) + : target_(target), channels_(channels), type_(type) {} + + Stmt Mutate(Stmt stmt) final { + Stmt ret = IRMutator::Mutate(stmt); + if (found && !alloc) { + for (auto& channel : channels_) { + auto stream_expr = StreamExpr::make(type_, + VarExpr(channel.node_), StreamType::FIFO, + 1, Array(), Array()); + + auto store = Store::make(temp_, + stream_expr, Expr(0), const_true()); + ret = Block::make(store, ret); + } + ret = Allocate::make(temp_, type_, Array(), + make_const(Bool(type_.lanes()), true), ret); + ret = AttrStmt::make(temp_, attr::storage_scope, + StringImm::make("local"), ret); + alloc = true; + } + return ret; + } + + Expr Mutate_(const Load *op, const Expr& e) final { + Expr index = op->index; + std::string target_name = op->buffer_var.get()->name_hint; + + Stmt stmt; + if (target_name == target_) { + found = true; + temp_ = VarExpr("temp_" + target_); + return Load::make(op->type, temp_, index, op->predicate); + } else { + return Load::make(op->type, op->buffer_var, index, op->predicate); + } + } + + private: + std::string& target_; + std::vector& channels_; + Type type_; + VarExpr temp_; + bool found{false}; + bool alloc{false}; +}; + // create local copy and multiple streaming channels class MultiCastMutator : public IRMutator { public: @@ -841,7 +989,7 @@ class MultiCastMutator : public IRMutator { for (auto& channel : channels_) { auto stream_stmt = StreamStmt::make( VarExpr(channel.node_), temp, - StreamType::Channel, 1, Array(), Array()); + StreamType::FIFO, 1, Array(), Array()); stmt = Block::make(stmt, stream_stmt); } stmt = Allocate::make(temp, type_, Array(), @@ -925,10 +1073,23 @@ class StmtGrpReplacer final : public IRMutator { } body = Substitute(body, subst); + // create buffers if api_args used in kernel body + auto undefs = UndefinedVars(body, Array()); + for (auto& var : undefs) { + if (var->name_hint.find(".channel") == std::string::npos) { + auto name = var->name_hint; + Type type = dtype_[name]; + Array shape = shape_[name]; + body = Allocate::make(var, type, shape, + make_const(Bool(type.lanes()), true), body); + body = AttrStmt::make(var, attr::storage_scope, + StringImm::make("global"), body); + } + } auto kernel = KernelDef::make(new_vars, shapes, types, Array(), body, UIntImm::make(UInt(1), 1), - UInt(32), "test", Array()); + UInt(32), "test", Array>()); kernel_defs_.push_back(kernel); Stmt stmt = KernelStmt::make(func_call_args, "test"); @@ -958,7 +1119,22 @@ class StmtGrpReplacer final : public IRMutator { auto target = VarExpr(op->node.node_)->name_hint; int index = op->value.as()->value; - CHECK(index) << "invalid attr value " << op->value; + + // self-loopback + if (index == 0) { + Stmt stmt = this->Mutate(op->body); + Type dtype = dtype_[target]; + Array shape = shape_[target]; + auto range_ = CollectIterRange(op->body); + + // replace with local temp + auto target_buf = VarExpr(op->node.node_); + LoopbackMutator mutator( + target_buf, shape, range_, dtype); + stmt = mutator.Mutate(stmt); + return stmt; + } + bool data_load = index < 0 ? false : true; if (index < 0) index = -1 * index; @@ -976,12 +1152,13 @@ class StmtGrpReplacer final : public IRMutator { // check wrapped attr stmt in multi-cast if (map[new_target].size() > 0) { - CHECK(!new_load); + // CHECK(!new_load) << "only support multiple writing in nest attrs"; for (auto& v : map[new_target]) - CHECK(!v.data_load) - << "cannot support loading " << new_target - << " from multiple sources"; + if (v.data_load) { + LOG(WARNING) << "joining target tensor " << new_target; + } } + if (new_index < 0) new_index = -1 * new_index; map[new_target].push_back({new_index, new_load}); } @@ -1006,7 +1183,7 @@ class StmtGrpReplacer final : public IRMutator { << "cannot find channel buffer " << name; auto channel_buf = VarExpr(channel_map_[name].node_); LoadToStreamExprConverter mutator( - target, StreamType::Channel, + target, StreamType::FIFO, channel_buf, 1, index, shape, range_); return mutator.Mutate(body); @@ -1015,7 +1192,7 @@ class StmtGrpReplacer final : public IRMutator { channel_map_[name] = channel_buf; StoreToStreamStmtConverter mutator( - target, StreamType::Channel, + target, StreamType::FIFO, channel_buf, 1, index, shape, range_); Stmt stmt = mutator.Mutate(body); @@ -1029,20 +1206,33 @@ class StmtGrpReplacer final : public IRMutator { return stmt; } - } else { // multi-cast + } else { target = kv.first; std::vector channels; // create channel buffers + size_t load_count = 0; for (auto& v : kv.second) { + if (v.data_load) load_count += 1; std::string name = target + ".pipe" + std::to_string(v.index); auto channel_buf = VarExpr(name); channel_map_[name] = channel_buf; channels.push_back(channel_buf); } - MultiCastMutator mutator(target, channels, dtype); - Stmt stmt = mutator.Mutate(body); + Stmt stmt; + // multi-casting data + if (load_count == 0) { + MultiCastMutator mutator(target, channels, dtype); + stmt = mutator.Mutate(body); + // multi-loading data + } else if (load_count == kv.second.size()){ + MultiLoadMutator mutator(target, channels, dtype); + stmt = mutator.Mutate(body); + } + + // allocate channel buffers + CHECK(stmt.defined()); for (auto& channel : channels) { stmt = Allocate::make( VarExpr(channel.node_), dtype, shape, @@ -1055,7 +1245,6 @@ class StmtGrpReplacer final : public IRMutator { return stmt; } } - } } return IRMutator::Mutate_(op, s); @@ -1088,29 +1277,56 @@ class KernelAnnotator final : public IRMutator { public: KernelAnnotator( std::unordered_map> map, + std::unordered_map> mem_ports, Array& api_args) : - arg_scope_map_(map) {} + arg_scope_map_(map), mem_ports_(mem_ports) {} Stmt Mutate_(const Allocate* op, const Stmt& s) { Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as(); std::string target_name = op->buffer_var.get()->name_hint; - if (target_name == "test") return op->body; + if (target_name == "test") { + return this->Mutate(op->body); + } return stmt; } Stmt Mutate_(const KernelDef *op, const Stmt& s) final { Stmt body = this->Mutate(op->body); - Array channels = op->channels; + Array> channels = op->channels; + + // insert annotation for top function + if (op->name == "test") { + int count = 0; + for (auto& arg : op->args) { + auto name = arg->name_hint; + // skip inner loop movement case + if (!mem_ports_.count(name)) { + LOG(INFO) << "device function within loop"; + break; + } + auto dev_port = mem_ports_[name]; + CHECK(dev_port.size() == 2); + // pos, channel index, depth, is_sedner, dev_type, mem_port + Array info = {count, -1, -1, -1, dev_port[0], dev_port[1]}; + count = count + 1; + channels.push_back(info); + } + return KernelDef::make( + op->args, op->arg_shapes, op->arg_types, + op->arg_tensors, body, op->ret_void, + op->ret_type, op->name, channels); + } // mutate kernel def body - CHECK(channels.size() % 4 == 0); if (channels.size() > 0) { - for (size_t i = 0; i < channels.size(); i+=4) { - auto pos = op->channels[i].as()->value; - auto channel = op->channels[i+1].as()->value; - auto depth = op->channels[i+2].as()->value; - auto is_sender = op->channels[i+3].as()->value; + for (size_t i = 0; i < channels.size(); i++) { + auto info = channels[i]; + CHECK(info.size() == 6); + auto pos = info[0].as()->value; + auto channel = info[1].as()->value; + auto depth = info[2].as()->value; + auto is_sender = info[3].as()->value; // create shared channel buffer VarExpr channel_buf; @@ -1136,10 +1352,14 @@ class KernelAnnotator final : public IRMutator { for (size_t i = 0; i < op->args.size(); i++) { if (set.find(i) != set.end()) { // position, channel index and depth - channels.push_back(IntImm::make(Int(32), i)); - channels.push_back(IntImm::make(Int(32), -1)); - channels.push_back(IntImm::make(Int(32), -1)); - channels.push_back(IntImm::make(Int(32), -1)); + Array info_new; + info_new.push_back(IntImm::make(Int(32), i)); + info_new.push_back(IntImm::make(Int(32), -1)); + info_new.push_back(IntImm::make(Int(32), -1)); + info_new.push_back(IntImm::make(Int(32), -1)); + info_new.push_back(IntImm::make(Int(32), -1)); + info_new.push_back(IntImm::make(Int(32), -1)); + channels.push_back(info_new); } } } @@ -1149,10 +1369,39 @@ class KernelAnnotator final : public IRMutator { op->ret_type, op->name, channels); } + // attach atributes to kernel function calls + Stmt Mutate_(const KernelStmt* op, const Stmt& s) final { + if (op->name == "test") { + int count = 0; + Array keys, values; + for (auto& arg : op->args) { + auto name = arg.as()->name_hint; + // skip inner loop movement case + if (!mem_ports_.count(name)) { + LOG(INFO) << "device function within loop"; + break; + } + auto dev_port = mem_ports_[name]; + CHECK(dev_port.size() == 2); + // pos, channel index, depth, is_sedner, dev_type, mem_port + keys.push_back(StringImm::make("pos")); + values.push_back(IntImm::make(Int(32), count)); + keys.push_back(StringImm::make("mem")); + values.push_back(dev_port[0]); + keys.push_back(StringImm::make("port")); + values.push_back(dev_port[1]); + count = count + 1; + } + return KernelStmt::make(op->args, op->name, keys, values); + } + return IRMutator::Mutate_(op, s); + } + private: std::unordered_map> arg_scope_map_; std::unordered_map channel_map_; + std::unordered_map> mem_ports_; // mutate kernel def body Stmt KernelRebuild(const VarExpr& channel_buf, @@ -1178,7 +1427,7 @@ class KernelAnnotator final : public IRMutator { } else if (is_sender == 1) { if (ac.reg_store && ac.store_num == 1) { StoreToStreamStmtConverter mutator( - target, StreamType::Channel, + target, StreamType::FIFO, channel_buf, depth, index, shape, range_); stmt = mutator.Mutate(body); @@ -1196,7 +1445,7 @@ class KernelAnnotator final : public IRMutator { VarExpr buf_var(ac.store_var.node_); stmt = BufferInserter( body, shape, buf_var, channel_buf, false, - StreamType::Channel, depth); + StreamType::FIFO, depth); } else { LOG(FATAL) << "target variable " << target << " not found; " @@ -1208,7 +1457,7 @@ class KernelAnnotator final : public IRMutator { if (ac.reg_load && ac.load_num == 1) { LoadToStreamExprConverter mutator( - target, StreamType::Channel, + target, StreamType::FIFO, channel_buf, depth, index, shape, range_); stmt = mutator.Mutate(body); @@ -1225,7 +1474,7 @@ class KernelAnnotator final : public IRMutator { VarExpr buf_var(ac.load_var.node_); stmt = BufferInserter( body, shape, buf_var, channel_buf, true, - StreamType::Channel, depth); + StreamType::FIFO, depth); } else { LOG(FATAL) << "target variable " << target << " not found; " @@ -1250,11 +1499,66 @@ class KernelAnnotator final : public IRMutator { } }; +// replace the mismatched buffers +class BufferReplacer final : public IRMutator { + public: + BufferReplacer( + std::unordered_map& bind_buffer_map, + Array& undefined_vars) + : bind_buffer_map_(bind_buffer_map), + undefined_vars_(undefined_vars) {} + + Stmt Mutate_(const Allocate* op, const Stmt& s) { + auto name = op->buffer_var->name_hint; + CHECK(bind_buffer_map_.count(name)) << name; + CHECK(bind_buffer_map_[name].get() == op->buffer_var.get()); + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + return stmt; + } + + Stmt Mutate_(const Store* op, const Stmt& s) { + auto name = op->buffer_var->name_hint; + CHECK(bind_buffer_map_.count(name)) << name; + if (bind_buffer_map_[name].get() != op->buffer_var.get()) { + Expr index = op->index; + Expr value = this->Mutate(op->value); + auto new_buf = VarExpr(bind_buffer_map_[name].node_); + CHECK(bind_buffer_map_[name].get() == new_buf.get()); + return Store::make(new_buf, value, index, op->predicate); + } + Stmt stmt = IRMutator::Mutate_(op, s); + return stmt; + } + + Expr Mutate_(const Load* op, const Expr& e) { + auto name = op->buffer_var->name_hint; + CHECK(bind_buffer_map_.count(name)) << name; + if (bind_buffer_map_[name].get() != op->buffer_var.get()) { + Expr index = op->index; + auto new_buf = VarExpr(bind_buffer_map_[name].node_); + CHECK(bind_buffer_map_[name].get() == new_buf.get()); + return Load::make(op->type, new_buf, index, op->predicate); + } + return IRMutator::Mutate_(op, e); + } + + private: + std::unordered_map& bind_buffer_map_; + Array& undefined_vars_; +}; + Stmt InferStream(Stmt stmt, Array api_args) { StreamAnalyzer analyzer(api_args); stmt = analyzer.Mutate(stmt); + // FIXME: var buffer binding error + if (analyzer.undefined_.size() > 0 ) { + stmt = BufferReplacer(analyzer.bind_buffer_map_, + analyzer.undefined_).Mutate(stmt); + } + StreamMutator mutator; stmt = mutator.Mutate(stmt); @@ -1273,7 +1577,7 @@ Stmt InferStream(Stmt stmt, // mark kernel def with storage scope stmt = KernelAnnotator(analyzer.kernel_arg_scope_, - api_args).Mutate(stmt); + analyzer.mem_ports, api_args).Mutate(stmt); return stmt; } diff --git a/tvm/src/schedule/schedule_dataflow_rewrite.cc b/tvm/src/schedule/schedule_dataflow_rewrite.cc index 7cf666a9d..f419ae1a1 100644 --- a/tvm/src/schedule/schedule_dataflow_rewrite.cc +++ b/tvm/src/schedule/schedule_dataflow_rewrite.cc @@ -37,14 +37,8 @@ class InStageMover : public ir::IRMutator { Stmt Mutate_(const For* op, const Stmt& s) { if (counter == index_) { - Stmt attr_stmt = AttrStmt::make( - VarExpr(), - attr::device_scope, - scope_, - op->body); - return For::make( - op->loop_var, op->min, op->extent, op->for_type, op->device_api, - attr_stmt, op->annotate_keys, op->annotate_values); + return AttrStmt::make( + VarExpr(), attr::device_scope, scope_, s); } else { counter += 1; return For::make( @@ -273,14 +267,19 @@ class InfoUpdater final : public IRMutator { channel_index_(channel_index), is_sender_(is_sender) { } + // add information into kernel def Stmt Mutate_(const KernelDef* op, const Stmt& s) { - Array arr = op->channels; - CHECK(op->channels.size() % 4 == 0) - << "(pos, channel index, depth) pair number mismatch"; - arr.push_back(IntImm::make(Int(32), arg_pos_)); - arr.push_back(IntImm::make(Int(32), channel_index_)); - arr.push_back(IntImm::make(Int(32), channel_depth_)); - arr.push_back(IntImm::make(Int(32), is_sender_)); + Array> arr = op->channels; + CHECK(op->channels.size() <= op->args.size()); + // (pos, channel index, depth, memory, port) pair + Array info; + info.push_back(IntImm::make(Int(32), arg_pos_)); + info.push_back(IntImm::make(Int(32), channel_index_)); + info.push_back(IntImm::make(Int(32), channel_depth_)); + info.push_back(IntImm::make(Int(32), is_sender_)); + info.push_back(IntImm::make(Int(32), -1)); // storage dev + info.push_back(IntImm::make(Int(32), -1)); // storage port + arr.push_back(info); return KernelDef::make(op->args, op->arg_shapes, op->arg_types, op->arg_tensors, op->body, op->ret_void, @@ -377,33 +376,48 @@ void Schedule::stream_to(const Tensor& target, // inter-stage data movement if (stream_pos.size() == 0) { - VarExpr node(target_buffer->data.node_); - - // create common channel buffer - InfoUpdater::channelCount += 1; - auto ch_index = InfoUpdater::channelCount; - Stmt dest_body = AttrStmt::make( - node, - attr::device_scope, - IntImm::make(Int(32), ch_index), - destOp->body); - dest->op = ExternOpNode::make(destOp->name, destOp->tag, - destOp->axis, destOp->inputs, - destOp->input_placeholders, - destOp->output_placeholders, - dest_body); - - Stmt src_body = AttrStmt::make( - node, - attr::device_scope, - IntImm::make(Int(32), -1 * ch_index), - srcOp->body); - source->op = ExternOpNode::make(srcOp->name, srcOp->tag, - srcOp->axis, srcOp->inputs, - srcOp->input_placeholders, - srcOp->output_placeholders, - src_body); + if (destOp == srcOp) { + // mutate loop body (attr_value indicates self-loop) + VarExpr node(target_buffer->data.node_); + Stmt dest_body = AttrStmt::make( + node, + attr::device_scope, + IntImm::make(Int(32), 0), + destOp->body); + dest->op = ExternOpNode::make(destOp->name, destOp->tag, + destOp->axis, destOp->inputs, + destOp->input_placeholders, + destOp->output_placeholders, + dest_body); + } else { + // create common channel buffer + VarExpr node(target_buffer->data.node_); + InfoUpdater::channelCount += 1; + auto ch_index = InfoUpdater::channelCount; + + Stmt dest_body = AttrStmt::make( + node, + attr::device_scope, + IntImm::make(Int(32), ch_index), + destOp->body); + dest->op = ExternOpNode::make(destOp->name, destOp->tag, + destOp->axis, destOp->inputs, + destOp->input_placeholders, + destOp->output_placeholders, + dest_body); + + Stmt src_body = AttrStmt::make( + node, + attr::device_scope, + IntImm::make(Int(32), -1 * ch_index), + srcOp->body); + source->op = ExternOpNode::make(srcOp->name, srcOp->tag, + srcOp->axis, srcOp->inputs, + srcOp->input_placeholders, + srcOp->output_placeholders, + src_body); + } } else { // streaming between kernel defs CHECK(stream_pos.size() == 2) << "missing pos index"; @@ -483,6 +497,7 @@ void Schedule::stage_move( } CHECK(scope.defined()) << "unsopport device "; const ExternOpNode* op = parent->op.as(); + CHECK(op) << parent << " not a extern op"; Stmt body = InStageMover(scope, occur_index).Mutate(op->body); @@ -497,12 +512,82 @@ void Schedule::stage_move( body); } +// annotate the tensor to be joined +void Schedule::join_to(const Tensor& target, + Stage source, + Stage dest, + StreamType stream_type, + int channel_depth) { + + Stage target_stage = (*this)[target]; + size_t num_stage = (*this)->stages.size(); + Buffer target_buffer; + + const PlaceholderOpNode* op = target_stage->op.as(); + bool is_placeholder = op ? true : false; + if (is_placeholder) { + for (size_t i = 0; i < num_stage; i++) { + Stage s = (*this)->stages[i]; + if (const ExternOpNode* op = s->op.as()) { + for (size_t j = 0; j < op->inputs.size(); j++) { + if (target == op->inputs[j]) { + target_buffer = op->input_placeholders[j]; + } + } + } + } + } else { // mark device scope of consumers & update kernel stmts + const ExternOpNode* op = target_stage->op.as(); + target_buffer = op->output_placeholders[0]; + } + + CHECK(source.defined()); + const ExternOpNode* src_op = source->op.as(); + CHECK(src_op) << "cannot join placeholder stage " << source; + + InfoUpdater::channelCount += 1; + auto index = InfoUpdater::channelCount; + + CHECK(target_buffer.defined()); + VarExpr node(target_buffer->data.node_); + + if (dest.defined()) { + // insert attr into collector op + const ExternOpNode* dest_op = dest->op.as(); + CHECK(dest_op) << "cannot join to placeholder stage " << dest; + Stmt body = dest_op->body; + + Stmt dest_body = AttrStmt::make( + node, + attr::device_scope, + IntImm::make(Int(32), index), + dest_op->body); + dest->op = ExternOpNode::make(dest_op->name, dest_op->tag, + dest_op->axis, dest_op->inputs, + dest_op->input_placeholders, + dest_op->output_placeholders, + dest_body); + + } else { // create result collector stage + + } + Stmt src_body = AttrStmt::make( + node, + attr::device_scope, + IntImm::make(Int(32), -1 * index), + src_op->body); + source->op = ExternOpNode::make( + src_op->name, src_op->tag, src_op->axis, src_op->inputs, + src_op->input_placeholders, src_op->output_placeholders, src_body); +} + // move data to device Tensor Schedule::move_to(const Tensor& target, + Stage parent, DeviceType device_type, StreamType stream_type, int channel_depth, - int occurrence) { + Array dev_ports) { Stage target_stage = (*this)[target]; std::vector consumers; size_t num_stage = (*this)->stages.size(); @@ -510,9 +595,16 @@ Tensor Schedule::move_to(const Tensor& target, ArrayNode* stages = (*this)->stages.CopyOnWrite(); Buffer target_buffer; + // parse the memory module interface + CHECK(dev_ports.size() == 2); + auto dev_type = dev_ports[0].as()->value; + auto mem_port = dev_ports[1].as()->value; + // StorageType dev = static_cast(dev_type); + // create producer and consumer stages for placeholder const PlaceholderOpNode* op = target_stage->op.as(); bool is_placeholder = op ? true : false; + if (is_placeholder) { min_pos = 0; for (size_t i = 0; i < num_stage; i++) { @@ -531,31 +623,34 @@ Tensor Schedule::move_to(const Tensor& target, min_pos = FindNodeRef(stages, target_stage) + 1; const ExternOpNode* op = target_stage->op.as(); target_buffer = op->output_placeholders[0]; - int use_count = 1; // use count for (size_t i = 0; i < num_stage; i++) { Stage s = (*this)->stages[i]; if (const ExternOpNode* stage_op = s->op.as()) { for (size_t j = 0; j < stage_op->inputs.size(); j++) { if (op->output_placeholders[0] == stage_op->input_placeholders[j]) { + consumers.push_back(s); + } + } + } + } + } - // find out the last usage of target tensor - if (occurrence == 0) { - min_pos = i + 1; - consumers.push_back(s); - break; - - // udpate minimal pos until hit occurrence - } else if (use_count < occurrence) { - min_pos = i + 1; - use_count += 1; - break; - - // add stages after hitting the boundary - } else if (occurrence > 1) { - consumers.push_back(s); - break; - } + if (parent.defined()) { // stream modified tensor + target_stage = parent; + min_pos = FindNodeRef(stages, parent) + 1; + const ExternOpNode* op = parent->op.as(); + CHECK(op) << parent << " not a extern op"; + CHECK(target_buffer.defined()) + << " not found buffer for target tensor"; + consumers.clear(); + for (size_t i = 0; i < num_stage; i++) { + Stage s = (*this)->stages[i]; + if (const ExternOpNode* stage_op = s->op.as()) { + for (size_t j = 0; j < stage_op->inputs.size(); j++) { + if (op->output_placeholders[0] == + stage_op->input_placeholders[j]) { + consumers.push_back(s); } } } @@ -566,8 +661,9 @@ Tensor Schedule::move_to(const Tensor& target, Array consumer_inputs; Array consumer_input_placeholders; Array consumer_output_placeholders; - std::string consumer_name = target_buffer->name + ".channel"; - // to be binded with the (channel) tensor + std::string consumer_name = target->op->name + ".channel"; + if (parent.defined()) consumer_name = target->op->name + ".update.channel"; + Buffer channel_buffer = BufferNode::make( Var(consumer_name, Handle()), target->dtype, @@ -576,38 +672,58 @@ Tensor Schedule::move_to(const Tensor& target, Expr(), consumer_name, "", 0, 0); - // input target tensor and output channel buffer - consumer_inputs.push_back(target); - consumer_input_placeholders.push_back(target_buffer); + + if (!parent.defined()) { + consumer_inputs.push_back(target); + consumer_input_placeholders.push_back(target_buffer); + } else { + const ExternOpNode* prt = parent->op.as(); + CHECK(prt) << "stage " << parent << " not extern op"; + consumer_inputs.push_back(parent->op.output(0)); + consumer_input_placeholders.push_back(prt->output_placeholders[0]); + } consumer_output_placeholders.push_back(channel_buffer); // create statement index + Array consumer_axis; std::vector csm_indices; std::vector csm_loop_vars; for (size_t i = 0; i < target->shape.size(); i++) { - VarExpr iter(target_buffer->name + std::to_string(i)); + VarExpr iter(target->op->name + std::to_string(i)); csm_indices.push_back(iter); csm_loop_vars.push_back(iter); + IterVar inner = IterVarNode::make( + Range(0, target->shape[i]), Var(iter.node_), kDataPar); + consumer_axis.push_back(inner); } - Expr csm_index = FlattenIndices(csm_indices, target->shape); - Expr load_expr = Load::make(target->dtype, - VarExpr(target_buffer.node_), - csm_index, - UIntImm::make(UInt(1), 1)); - Stmt consumer_body = StreamStmt::make(VarExpr(channel_buffer.node_), - load_expr, stream_type, channel_depth); - Array consumer_axis; - for (size_t j = 0; j < target->shape.size(); j++) { - auto iter = csm_loop_vars[j]; - consumer_axis.push_back(IterVarNode::make( - Range(0, target->shape[j]), Var(iter.node_), kDataPar)); - consumer_body = For::make( - VarExpr(iter.node_), - 0, target->shape[j], - ForType::Serial, - DeviceAPI::None, - consumer_body); + Expr csm_index = FlattenIndices(csm_indices, target->shape); + Expr load_expr = Load::make(target->dtype, VarExpr(target_buffer.node_), + csm_index, UIntImm::make(UInt(1), 1)); + + Stmt consumer_body = StreamStmt::make( + VarExpr(channel_buffer.node_), + load_expr, stream_type, channel_depth); + + // mark dev and port information + Array mark_keys, mark_vals; + mark_keys.push_back(StringImm::make("dev")); + mark_keys.push_back(StringImm::make("port")); + mark_vals.push_back(IntImm::make(Int(32), dev_type)); + mark_vals.push_back(IntImm::make(Int(32), mem_port)); + Stmt info = StreamStmt::make(VarExpr(channel_buffer.node_), + Expr("config"), StreamType::FIFO, 0, mark_keys, mark_vals); + consumer_body = Block::make(info, consumer_body); + + // make for loops for sender side + for (int j = target->shape.size()-1; j >= 0; j--) { + auto iter = csm_loop_vars[j]; + auto inner = consumer_axis[j]; + // inner loop scope attr stmt + consumer_body = AttrStmt::make(inner, attr::loop_scope, + inner->var, consumer_body); + consumer_body = For::make(VarExpr(iter.node_), 0, target->shape[j], + ForType::Serial, DeviceAPI::None, consumer_body); } // create new stage and return stream tensors @@ -632,7 +748,8 @@ Tensor Schedule::move_to(const Tensor& target, Array producer_output_placeholders; // new buffer copy of original data - std::string producer_name = target_buffer->name + ".new"; + std::string producer_name = target->op->name + ".new"; + if (parent.defined()) producer_name = target->op->name + ".update.new"; Buffer output_buffer = BufferNode::make( Var(producer_name, Handle()), target->dtype, @@ -649,10 +766,14 @@ Tensor Schedule::move_to(const Tensor& target, // create for loops for tensor init std::vector indices; std::vector loop_vars; + Array producer_axis; for (size_t i = 0; i < target->shape.size(); i++) { - VarExpr iter(target_buffer->name + std::to_string(i)); + VarExpr iter(target->op->name + std::to_string(i)); indices.push_back(iter); loop_vars.push_back(iter); + IterVar inner = IterVarNode::make( + Range(0, target->shape[i]), Var(iter.node_), kDataPar); + producer_axis.push_back(inner); } Expr index = FlattenIndices(indices, target->shape); // streaming producer tensor reading from channel @@ -663,23 +784,19 @@ Tensor Schedule::move_to(const Tensor& target, Stmt for_stmt = Store::make(VarExpr(output_buffer.node_), stream, index, UIntImm::make(UInt(1), 1)); - Array producer_axis; - for (size_t j = 0; j < target->shape.size(); j++) { - auto iter = loop_vars[j]; - producer_axis.push_back(IterVarNode::make( - Range(0, target->shape[j]), Var(iter.node_), kDataPar)); - for_stmt = For::make( - VarExpr(iter.node_), - 0, target->shape[j], - ForType::Serial, - DeviceAPI::None, - for_stmt); + for (int j = target->shape.size()-1; j >= 0; j--) { + auto iter = loop_vars[j]; + auto inner = producer_axis[j]; + // inner loop scope attr stmt + for_stmt = AttrStmt::make(inner, attr::loop_scope, inner->var, for_stmt); + for_stmt = For::make(VarExpr(iter.node_), 0, target->shape[j], + ForType::Serial, DeviceAPI::None, for_stmt); } Stmt body = for_stmt; // same buffer under different device scoep Tensor producer = ExternOpNode::make( - target_buffer->name + ".new", + producer_name, "", producer_axis, producer_inputs, @@ -687,12 +804,12 @@ Tensor Schedule::move_to(const Tensor& target, producer_output_placeholders, body).output(0); - // recv stage creation + return tensor Stage producer_stage = Stage(producer->op); producer_stage->device_type = static_cast(device_type); size_t pos = FindNodeRef(stages, consumer_stage); stages->data.insert(stages->data.begin() + pos, producer_stage.node_); (*this)->stage_map.Set(producer->op, producer_stage); + // add producer as output stage if output moved to host if (target_stage->is_output && static_cast(device_type) == DeviceType::devHost) { @@ -702,28 +819,38 @@ Tensor Schedule::move_to(const Tensor& target, } // update consumer stages with new tensor and buffer - // std::unordered_map vsub; - // vsub[target_buffer->data.as()] = output_buffer->data; std::unordered_map vsub; std::unordered_map vsub2newvar; vsub[target] = producer; vsub2newvar[target_buffer->data.as()] = output_buffer; + for (Stage s : consumers) { CHECK(s->op.as()); Operation repl_op = s->op->ReplaceInputs(s->op, vsub); - CHECK(!repl_op.same_as(s->op)) - << "Cannot find " << target - << " in the inputs of " << s->op; + + // udpate stage not having orginal tensor input auto op = repl_op.as(); Stmt repl_body = LoadReplacer(vsub2newvar).Mutate(op->body); + + Array new_inputs; + Array new_input_placeholders; + if (parent.defined()) { + new_inputs.push_back(producer); + new_input_placeholders.push_back(output_buffer); + } else { + new_inputs = op->inputs; + new_input_placeholders = op->input_placeholders; + } + s->op = ExternOpNode::make( op->name, op->tag, op->axis, - op->inputs, - op->input_placeholders, + new_inputs, + new_input_placeholders, op->output_placeholders, repl_body); + (*this)->stage_map.Set(s->op, s); } producer_stage->group = target_stage->group; if (producer_stage->group.defined()) { @@ -924,7 +1051,18 @@ void Schedule::reshape(const Tensor& target, Array new_shape) { Stage target_stage = (*this)[target]; const ExternOpNode* op = target_stage->op.as(); Buffer target_buffer = op->output_placeholders[0]; - // TODO: check the #elem is the same for both shapes + // check the #elem is the same for both shapes + size_t size = 1, origin = 1; + for (auto& dim : new_shape) { + CHECK(dim.as()) << dim << " must be a positive integrer"; + size *= dim.as()->value; + } + for (auto& dim : target_buffer->shape) { + CHECK(dim.as()) << dim << " must be a positive integrer"; + origin *= dim.as()->value; + } + CHECK_EQ(origin, size) + << "new shape must have same element number as original shape"; target_buffer->shape = new_shape; } diff --git a/tvm/src/schedule/schedule_lang.cc b/tvm/src/schedule/schedule_lang.cc index 624c159a1..c48cfdaa4 100644 --- a/tvm/src/schedule/schedule_lang.cc +++ b/tvm/src/schedule/schedule_lang.cc @@ -237,7 +237,6 @@ void ComputeAt(StageNode* producer, Stmt producer_stmt = producer_op->body; Stmt consumer_stmt = consumer_op->body; Buffer producer_buf = producer_op->output_placeholders[0]; - Array reuse_shape; std::unordered_map sub; Stmt new_stmt = PerformComputeAt(producer_stmt, consumer_stmt, producer_buf, var, attach_level, sub); producer->op = ExternOpNode::make(producer_op->name, diff --git a/tvm/src/schedule/schedule_ops.cc b/tvm/src/schedule/schedule_ops.cc index 8156844f5..503151bec 100644 --- a/tvm/src/schedule/schedule_ops.cc +++ b/tvm/src/schedule/schedule_ops.cc @@ -343,13 +343,14 @@ Stmt ScheduleOps( CHECK_EQ(g->leaf_iter_vars.size(), 0U); } // reverse the post DFS order. + Array not_found_stages; for (size_t i = sch->stages.size(); i != 0; --i) { Stage s = sch->stages[i - 1]; CHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops"; CHECK(s->op.defined()); // no need to specify place holder op. - if (auto op = s->op.as()) continue; + if (s->op.as()) continue; // Remove grouping sugar, get the real attach spec. Stage attach_spec = s.GetAttachSpec(); @@ -368,11 +369,20 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, sch); body = mutator.Mutate(body); - CHECK(mutator.found_attach) + if (!mutator.found_attach) { + not_found_stages.push_back(s); + LOG(WARNING) << "did not find attachment point for " << s << " in " - << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar - << ", body:\n" - << body; + << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar; + } + + } + } + if (not_found_stages.size() > 0) { + for (auto s : not_found_stages) { + Stage attach_spec = s.GetAttachSpec(); + InjectAttach mutator(s, attach_spec, dom_map, sch); + body = mutator.Mutate(body); } } SchedulePostProc post_proc; diff --git a/tvm/src/schedule/schedule_reorder.cc b/tvm/src/schedule/schedule_reorder.cc index b741bb545..26c765403 100644 --- a/tvm/src/schedule/schedule_reorder.cc +++ b/tvm/src/schedule/schedule_reorder.cc @@ -18,17 +18,97 @@ namespace schedule { using namespace ir; +void TraceExternMods(const Array& roots, + const ReadGraph& g, + std::unordered_map>& extern_mods) { + std::unordered_set visited; + std::vector stack; + stack.push_back(roots[0]); + while (!stack.empty()) { + Operation op = stack.back(); + stack.pop_back(); + + CHECK(g.count(op)) << "not found " << op; + if (auto extern_op = op.as()) { + if (extern_op->body.as()) { + // LOG(INFO) << extern_op->body; + for (const auto& t : g.at(op)) { + extern_mods[op].insert(t->op->name); + if (g.count(t->op) && t->op->name.find(".new") == std::string::npos) { + for (auto& pt : g.at(t->op)) + extern_mods[op].insert(pt->op->name); + } + } + } + } + for (const auto& t : g.at(op)) { + if (t->op.defined()) { + if (visited.count(t->op.get()) == 0) { + visited.insert(t->op.get()); + stack.push_back(t->op); + } + } + } + } +} + +// create dfs post ordered attch attr stmt +Stmt AttachScopeReorder(Array& post_order, + std::vector& merged_ops) { + Stmt body; + Stmt no_op = Evaluate::make(0); + CHECK(post_order.size() > 0); + + for (int i = post_order.size() - 1; i >= 0; i--) { + auto& op = post_order[i]; + if (auto extern_op = op.as()) { + Buffer buf = extern_op->output_placeholders[0]; + if (extern_op->name == "_top") { + continue; + } + if (!body.defined()) { + body = AttrStmt::make(VarExpr(buf.node_), + attr::attach_scope, StringImm::make("_top"), no_op); + } else { + body = AttrStmt::make(VarExpr(buf.node_), + attr::attach_scope, StringImm::make("_top"), body); + } + if (extern_op->name.find(".new") != std::string::npos) { + CHECK_GT(i-1, 0) << "wrong op ordering fonud"; + if (post_order[i-1]->name.find(".new") == std::string::npos) { + // LOG(INFO) << "insert attachment before " << extern_op->name; + for (auto& sub_op : merged_ops) { + auto sub_ext_op = sub_op.as(); + Buffer sub_buf = sub_ext_op->output_placeholders[0]; + CHECK(body.defined()); + body = AttrStmt::make(VarExpr(sub_buf.node_), + attr::attach_scope, StringImm::make("_top"), body); + } + } + } + } + } + CHECK(body.defined()); + return body; +} + std::unordered_set ExtractAncestors(Operation root, const ReadGraph& g) { std::vector stack; std::unordered_set visited; std::unordered_set ops; stack.push_back(root); visited.insert(root.get()); - // for (auto& kv : g) { - // LOG(INFO) << "------------"; - // LOG(INFO) << kv.first; - // for (auto& t : kv.second) LOG(INFO) << t->op; - // } + + // print read graph + if (false) { + LOG(INFO) << "---------------------"; + for (auto& kv : g) { + LOG(INFO) << "------------"; + LOG(INFO) << kv.first; + for (auto& t : kv.second) LOG(INFO) << t->op; + } + } while (!stack.empty()) { Operation op = stack.back(); @@ -106,7 +186,8 @@ std::vector ExtractSubGraph( } inputs.push_back(input); outputs.push_back(output); - // LOG(INFO) << input << ":" << output; + CHECK(input.size() > 0) + << "cannot found boundary for output " << output; // GetSubGraph(RemapTensor(sch, output), // RemapTensor(sch, input), true); } @@ -120,6 +201,7 @@ std::vector ExtractSubGraph( visited.insert(t->op.get()); } + CHECK(!stack.empty()); std::unordered_set shared; while (!stack.empty()) { Operation op = stack.back(); @@ -139,7 +221,9 @@ std::vector ExtractSubGraph( if (t->op.defined()) { if (visited.count(t->op.get()) == 0) { visited.insert(t->op.get()); - if (!reach_bound) stack.push_back(t->op); + if (!reach_bound) { + stack.push_back(t->op); + } } else { // visited ancestor shared.insert(t->op.get()); } @@ -163,14 +247,18 @@ std::vector ExtractSubGraph( } } - for (Operation op : output_ops) { - if (auto extern_op = op.as()) { - for (auto& buffer : extern_op->output_placeholders) { - aggregate->output_placeholders.push_back(buffer); - } - } - } + // for (Operation op : output_ops) { + // if (auto extern_op = op.as()) { + // for (auto& buffer : extern_op->output_placeholders) { + // aggregate->output_placeholders.push_back(buffer); + // } + // } + // } + Buffer aggregate_buffer = BufferNode::make(Var(aggregate->name, Handle()), + Int(32), Array(), Array(), Expr(), aggregate->name, "", 0, 0); + aggregate->output_placeholders.push_back(aggregate_buffer); + // rearrange the op in subgraph CHECK(subgraph.size() > 0); std::vector new_subgraph; size_t op_count = 0; @@ -188,26 +276,50 @@ std::vector ExtractSubGraph( } } - std::unordered_map inserted; + // find the updated tensors in extern mod subgraph + // insert the extern mod into aggregate node + std::unordered_map inserted; + std::unordered_map> op2modifed; // for(auto op : new_subgraph) LOG(INFO) << op; - // for(auto kv : atts_map) { - // inserted[kv.first] = false; - // LOG(INFO) << kv.first << ":------------"; - // for (auto& k : kv.second) LOG(INFO) << k; - // } + + // remove unrelated ops + std::unordered_set nodes; + for(auto op : new_subgraph) { + nodes.insert(op->name); + } + for(auto& kv : atts_map) { + inserted[kv.first] = 0; + // LOG(INFO) << kv.first << ":------------"; + // for (auto& k : kv.second) LOG(INFO) << k; + for (auto& v : kv.second) { + if (nodes.find(v) != nodes.end()) { + if (v.find(".new") == std::string::npos) + op2modifed[kv.first].insert(v); + } + } + } + Stmt body = Evaluate::make(0); for (Operation op : new_subgraph) { CHECK(op.as()) << op; if (auto extern_op = op.as()) { + // insert standalone subgraph op + CHECK(extern_op->output_placeholders.size()); + Buffer out_buf = extern_op->output_placeholders[0]; + Stmt attr = AttrStmt::make(VarExpr(out_buf.node_), + "attach_scope", StringImm::make("test"), no_op); + body = Block::make(body, attr); + // check if subgraph op in extern module inputs - bool in_extern_mod = false; - for (auto& kv : atts_map) { + // the extern module acts as upadter of these ops + for (auto& kv : op2modifed) { if (kv.second.count(op->name) && op->name.find(".new") == std::string::npos) { - in_extern_mod = true; - if (!inserted[kv.first]) { - inserted[kv.first] = true; + + inserted[kv.first] += 1; + // insert extern op after dependent stages + if (inserted[kv.first] == (signed)kv.second.size()) { auto mod_op = kv.first.as(); Buffer mod_buf = mod_op->output_placeholders[0]; // LOG(INFO) << "insert " << kv.first << ":" << mod_buf; @@ -217,15 +329,6 @@ std::vector ExtractSubGraph( } } } - - // insert standalone subgraph op - if (!in_extern_mod) { - CHECK(extern_op->output_placeholders.size()); - Buffer out_buf = extern_op->output_placeholders[0]; - Stmt attr = AttrStmt::make(VarExpr(out_buf.node_), - "attach_scope", StringImm::make("test"), no_op); - body = Block::make(body, attr); - } } } @@ -246,6 +349,7 @@ std::vector ExtractSubGraph( } aggregate->body = AttrStmt::make( VarExpr(), attr::device_scope, scope, body); + // LOG(INFO) << aggregate->body; merged_ops.push_back(Operation(aggregate)); return new_subgraph; } @@ -253,20 +357,56 @@ std::vector ExtractSubGraph( static int bound_index = 0; // extract the bounded op arrays from subgraph root +// needed to add to extracted subgrapg ( since subgraph +// does not capture the ops in extern module ) void PostDFSBoundary(const Operation& op, const ReadGraph& g, std::unordered_set* visited, - Array* bounded_ops) { + Array* post_order, + Array* bounded_ops, + std::unordered_map>& extern_mods, + std::unordered_set& sub_ops) { if (visited->count(op)) return; visited->insert(op); - CHECK(op.as()); + // CHECK(op.as()) << op; for (const auto& t : g.at(op)) { - if (op->name.find(".new") == std::string::npos) - PostDFSBoundary(t->op, g, visited, bounded_ops); + PostDFSBoundary(t->op, g, visited, post_order, + bounded_ops, extern_mods, sub_ops); } + + if (op.as()) { + post_order->push_back(op); + return; + } + // record ops before .new ops - bounded_ops->push_back(op); + bool in_ext_mod = false; + if (op.as()->body.as()) in_ext_mod = true; + for (auto& kv : extern_mods) { + // the op required to be a child stage + if ((kv.second.find(op->name) != kv.second.end()) && + // ignore the moved tensor (part of test stage) + (op->name.find(".new") == std::string::npos) && + // should be part of subgraph + (sub_ops.find(op->name) == sub_ops.end())) + in_ext_mod = true; + } + + if (in_ext_mod) + bounded_ops->push_back(op); + + // record ops outside subgraph + if ((sub_ops.find(op->name) == sub_ops.end()) && + (!op.as()->body.as())) { + for (auto& kv : extern_mods) { + // the op should not be a child stage of extern modules + if (kv.second.find(op->name) == kv.second.end()) { + post_order->push_back(op); + } + } + } } // schedule the ops with subgraphs @@ -311,31 +451,23 @@ Array PostDFSSplit( std::unordered_map> extern_mods; - // check the external module for (Operation op : roots) { dev[op.get()] = DeviceType::devHost; - if (auto extern_op = op.as()) { - if (extern_op->body.as()) { - CHECK(g.count(op)) << "not found " << op; - for (const auto& t : g.at(op)) { - extern_mods[op].insert(t->op->name); - // TODO: find a better abstraction - if (g.count(t->op) && t->op->name.find(".new") == std::string::npos) { - for (auto& pt : g.at(t->op)) - extern_mods[op].insert(pt->op->name); - } - } - } - } } + // check the external module + TraceExternMods(roots, g, extern_mods); + for (Stage stage : sch->stages) { + // LOG(INFO) << stage->op + // << ":" << static_cast(stage->device_type); if (dev.count(stage->op.get())) CHECK(dev[stage->op.get()] == DeviceType::devHost) << "output " << stage << " should be placed on host scope"; dev[stage->op.get()] = stage->device_type; - if (stage->device_type != DeviceType::devHost) + if (stage->device_type != DeviceType::devHost) { boundary.insert(boundary.begin(), stage->op); + } } bound_index = 0; @@ -356,19 +488,32 @@ Array PostDFSSplit( Array post_order; Array bounded_ops; for (Operation op : roots) { - if (extern_mods.count(op)) { + if (extern_mods.size() > 0) { // create op array of extern module (from .new to super stage root) - // return inner ops inside extern module (must be bounded by .new ops) + // i.e. inner ops inside extern module (must be bounded by .new ops) + // the result is returned in bounded_ops + bool dev_scope = false; - for (auto& input : extern_mods.at(op)) { - if (input.find(".new") != std::string::npos) - dev_scope = true; + // TODO: consider multiple extern modules + for (auto& kv : extern_mods) { + for (auto& input : kv.second) { + if (input.find(".new") != std::string::npos) + dev_scope = true; + } } if (dev_scope) { + LOG(INFO) << "inputs in device scope"; std::unordered_set visited_ops; - PostDFSBoundary(op, g, &visited_ops, &bounded_ops); - } else { // in host scope, for sim + // extract all stages in extern module for updating logic + std::unordered_set sub_ops; + for (auto& op : subgraph) sub_ops.insert(op->name); + PostDFSBoundary(op, g, &visited_ops, &post_order, + &bounded_ops, extern_mods, sub_ops); + // LOG(INFO) << bounded_ops; + + } else { + LOG(WARNING) << "input tensors of IP core on host scope (sim mode only)"; PostDFSSplit(op, g, &visited, &post_order, dev, subgraph); } } else { @@ -379,11 +524,17 @@ Array PostDFSSplit( } // op array index to insert subgraph - if (bound_index > 0) { + // for (auto& op : subgraph) LOG(INFO) << op; + bool inserted = false; + if (merged_ops.size() > 0) { Array results; for (size_t k = 0; k < post_order.size(); k++) { - // scope switching right after index-th last op - if (k == post_order.size() - (bound_index - 1)) { + // fix: insert right before the first .new + // if (k == post_order.size() - (bound_index - 1)) + auto sname = post_order[k]->name; + if (!inserted && sname.find(".new") != std::string::npos) { + inserted = true; + // LOG(INFO) << "insert beofre " << post_order[k]; if (extern_mods.size() == 0) { for (auto& sub_op : subgraph) @@ -391,32 +542,55 @@ Array PostDFSSplit( for (auto& sub_op : merged_ops) results.push_back(sub_op); - // replace the modfied ops with extern module + // replace the modfied tensor ops with extern module // i.e. ops in the keys of corresponding module } else { + + // missing graph info in extern module CHECK(bounded_ops.size() > 0); for (Operation op : bounded_ops) { results.push_back(op); } + for (auto& sub_op : subgraph) { - bool found_in_module = false; - for (auto& kv : extern_mods) { - if (kv.second.count(sub_op->name)) { - found_in_module = true; - } - } - if (!found_in_module) - results.push_back(sub_op); + results.push_back(sub_op); + // bool found_in_module = false; + // for (auto& kv : extern_mods) { + // if (kv.second.count(sub_op->name)) { + // found_in_module = true; + // } + // } + // if (!found_in_module) + // results.push_back(sub_op); } + for (auto& sub_op : merged_ops) results.push_back(sub_op); } } Operation op = post_order[k]; + + // fix: re-arrange attr stmt inside + if (op->name == "_top") { + Stmt no_op = Evaluate::make(0); + std::shared_ptr new_op = + std::make_shared(); + new_op->name = op->name; + // top op input / output buffers + auto extern_op = op.as(); + CHECK(extern_op) << "invalid _top op node"; + new_op->inputs = std::move(extern_op->inputs); + new_op->input_placeholders = std::move(extern_op->input_placeholders); + new_op->output_placeholders = std::move(extern_op->output_placeholders); + // rearrange attachment scope attr inside _top body + new_op->body = AttachScopeReorder(post_order, merged_ops); + op = Operation(new_op); + } results.push_back(op); } CHECK(results.size() >= sch->stages.size()) - << "schedule op array error " << results; + << "missing ops in result. size " << results.size() << ":" << sch->stages.size() + << results; return results; } @@ -445,6 +619,8 @@ Schedule ScopePartition(const Schedule& sch) { visited.insert(op.get()); } + // for (auto& kv : sch->stage_map) + // LOG(INFO) << kv.first << ":" << kv.second; while (!stack.empty()) { Operation op = stack.back(); stack.pop_back(); @@ -507,6 +683,14 @@ Schedule ScopePartition(const Schedule& sch) { std::make_shared(*s.operator->()); scopy = Stage(snode); smap[s] = scopy; + // replace stage op body for _top + if (scopy->op->name == "_top") { + // LOG(INFO) << scopy; + // LOG(INFO) << op.as()->body; + // LOG(INFO) << scopy->op.as()->body; + scopy = Stage(op); + n->stage_map.Set(op, scopy); + } n->stages.push_back(scopy); } }