-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backend][hlib] External IPs Integration Support for HeteroCL (#170)
* init * extern module ir * decorator * demo * codegen fix * rtl ip * fft sim works * verify ir * fix test * update interface * clean up * remove blackbox * change name * remove config.ini
- Loading branch information
Showing
44 changed files
with
821 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,3 +22,5 @@ out | |
*.params | ||
*.gz | ||
|
||
# Generated files | ||
project |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from . import op | ||
from . import ip | ||
from . import frontend | ||
from . import utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .fft import single_fft_hls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import heterocl as hcl | ||
import numpy as np | ||
from heterocl import util | ||
from heterocl.tvm import make as _make | ||
from heterocl.tvm import stmt as _stmt | ||
from heterocl.tvm import ir_pass as _pass | ||
from heterocl.tvm._api_internal import _ExternOp | ||
from heterocl.schedule import Schedule, Stage | ||
from heterocl.mutator import Mutator | ||
from collections import OrderedDict | ||
import os | ||
from hlib.op.extern import * | ||
|
||
# Module-level data type object: a default-constructed hcl.Int.
# NOTE(review): nothing in the visible code of this module reads `dtype`;
# confirm whether importers rely on it before removing.
dtype = hcl.Int() | ||
|
||
@register_extern_ip(vendor="xilinx")
def single_fft_hls(X_real, X_imag, F_real=None, F_imag=None, name=None):
    """Describe a 1-D FFT both as HeteroCL stages and as a Xilinx HLS FFT call.

    The body of the "ExternModule" stage is a software model: the inputs are
    permuted through a bit-reversal index table and then updated in place by
    nested butterfly loops. The same stage is afterwards annotated (through
    ``create_extern_module``) with the C++ header and call sequence that
    invoke ``hls::fft``.

    Parameters
    ----------
    X_real, X_imag
        Input tensors holding the real and imaginary parts. Their shapes
        must be equal and the first dimension must be a power of two.
    F_real, F_imag
        Optional output tensors. When both are None, new tensors are
        created with ``hcl.compute``; otherwise the passed-in tensors are
        overwritten with ``hcl.update``.
    name
        Name stored in the annotation dictionary; defaults to
        ``"hls::fft<config>"``.

    Returns
    -------
    ``(F_real, F_imag)`` when the outputs were created here, otherwise None.

    Raises
    ------
    AssertionError
        If the input shapes differ or the length is not a power of two.
    """
    if name is None:
        name = "hls::fft<config>"
    L = X_real.shape[0]
    assert X_real.shape == X_imag.shape
    assert np.log2(L) % 1 == 0, "length must be power of 2: " + str(L)

    # functional behavior
    with hcl.Stage("ExternModule") as Module:
        num_stages = int(np.log2(L))
        bit_width = int(np.log2(L))

        # IndexTable[i] is i with its bit_width-bit binary form reversed.
        IndexTable = np.zeros((L), dtype='int')
        for i in range(L):
            b = '{:0{width}b}'.format(i, width=bit_width)
            IndexTable[i] = int(b[::-1], 2)

        return_tensors = False
        Table = hcl.copy(IndexTable, "table", dtype=hcl.Int())
        if (F_real is None) and (F_imag is None):
            return_tensors = True
            F_real = hcl.compute((L,), lambda i: X_real[Table[i]], name='F_real')
            F_imag = hcl.compute((L,), lambda i: X_imag[Table[i]], name='F_imag')
        else:  # use passed-in tensors
            hcl.update(F_real, lambda i: X_real[Table[i]], name='F_real_update')
            hcl.update(F_imag, lambda i: X_imag[Table[i]], name='F_imag_update')

        with hcl.Stage("Out"):
            one = hcl.scalar(1, dtype="int32")
            with hcl.for_(0, num_stages) as stage:
                DFTpts = one[0] << (stage + 1)
                numBF = DFTpts / 2
                e = -2 * np.pi / DFTpts
                a = hcl.scalar(0)
                with hcl.for_(0, numBF) as j:
                    c = hcl.scalar(hcl.cos(a[0]))
                    s = hcl.scalar(hcl.sin(a[0]))
                    a[0] = a[0] + e
                    with hcl.for_(j, L + DFTpts - 1, DFTpts) as i:
                        i_lower = i + numBF
                        temp_r = hcl.scalar(F_real[i_lower] * c - F_imag[i_lower] * s)
                        temp_i = hcl.scalar(F_imag[i_lower] * c + F_real[i_lower] * s)
                        F_real[i_lower] = F_real[i] - temp_r[0]
                        F_imag[i_lower] = F_imag[i] - temp_i[0]
                        F_real[i] = F_real[i] + temp_r[0]
                        F_imag[i] = F_imag[i] + temp_i[0]

    dicts = {}
    dicts["name"] = name
    tensors = [X_real, X_imag, F_real, F_imag]
    dicts["args"] = [(_.name, _.dtype) for _ in tensors]

    # declare headers and typedef
    dicts["header"] = """
#include \"hls_fft.h\"
#include <complex>
struct config : hls::ip_fft::params_t {
  static const unsigned ordering_opt = hls::ip_fft::natural_order;
  static const unsigned config_width = 16; // FFT_CONFIG_WIDTH
};
typedef std::complex<ap_fixed<16,1>> fxpComplex;
"""
    # extern ip function
    # Fixes relative to the previous template:
    #   * the result loop indexes xk[i] (xk is an array, not a scalar);
    #   * the buffers use the fxpComplex typedef from the header instead of
    #     an unqualified `complex<...>`;
    #   * the status object is a status_t and hls::fft receives the status
    #     pointer before the config pointer.
    # NOTE(review): setDir(0) is kept as-is; confirm it selects the same
    # transform direction as the software model above.
    dicts["func"] = """
      hls::ip_fft::config_t<config> fft_config;
      hls::ip_fft::status_t<config> fft_status;
      fft_config.setDir(0);
      fft_config.setSch(0x2AB);
      fxpComplex xn[{length}];
      fxpComplex xk[{length}];
      for (int i = 0; i < {length}; i++)
        xn[i] = fxpComplex({x_real}[i], {x_imag}[i]);
      hls::fft<config>(xn, xk, &fft_status, &fft_config);
      for (int i = 0; i < {length}; i++) {{
        {f_real}[i] = xk[i].real();
        {f_imag}[i] = xk[i].imag();
      }}
    """.format(length=L,
               x_real=X_real.name, x_imag=X_imag.name,
               f_real=F_real.name, f_imag=F_imag.name)

    create_extern_module(Module, dicts, ip_type="hls")
    if return_tensors:
        return F_real, F_imag
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from . import math | ||
from . import nn | ||
from . import op | ||
from . import extern |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import heterocl as hcl | ||
import numpy as np | ||
from heterocl import util | ||
from heterocl.tvm import make as _make | ||
from heterocl.tvm import stmt as _stmt | ||
from heterocl.tvm import ir_pass as _pass | ||
from heterocl.tvm._api_internal import _ExternOp | ||
from heterocl.schedule import Schedule, Stage | ||
from heterocl.mutator import Mutator | ||
from collections import OrderedDict | ||
import os | ||
|
||
# Module-level data type object: a default-constructed hcl.Int.
# NOTE(review): the only other `dtype` in the visible code of this module is
# a loop variable inside create_extern_module, which shadows this name there.
dtype = hcl.Int() | ||
|
||
class ModuleMarker(Mutator):
    """Mutator that wraps a loop body in an "rtl" ExternModule node.

    While walking the IR it records, per buffer variable, the most recent
    index expression seen in a Load or Store, and per loop variable the
    value ``extent - min``. A counter is bumped on entry to every For node;
    when, after a loop's body has been rewritten, the counter equals
    ``axis``, that body is replaced by an ExternModule annotated with
    ``info`` plus one ``"index<i>"`` entry for each name in ``args``.
    """

    def __init__(self, axis, info, args):
        self.axis = axis        # counter value at which the body is wrapped
        self.info = info        # annotation dict; "index<i>" keys are added
        self.args = args        # buffer names whose indices are exported
        self.count = 0          # number of For nodes entered since reset
        self.range_ = {}        # loop var -> extent.value - min.value
        self.index_map = {}     # buffer var -> last recorded index expr

    def record_index(self, var, index):
        """Store `index` for `var`, with every range-1 loop var set to 0."""
        unit_loop_vars = [k for k, v in self.range_.items() if v == 1]
        for loop_var in unit_loop_vars:
            replaced = _pass.Substitute(index, {loop_var: 0})
            index = _pass.Simplify(replaced)
        self.index_map[var] = index

    def mutate_Load(self, node):
        """Rebuild a Load with casts stripped from its index; record it."""
        buffer_var = self.mutate(node.buffer_var)
        index = util.CastRemover().mutate(self.mutate(node.index))
        self.record_index(buffer_var, index)

        predicate = self.mutate(node.predicate)
        return _make.Load(node.dtype, buffer_var, index, predicate)

    def mutate_Store(self, node):
        """Rebuild a Store with casts stripped from its index; record it."""
        buffer_var = self.mutate(node.buffer_var)
        index = util.CastRemover().mutate(self.mutate(node.index))
        self.record_index(buffer_var, index)

        value = self.mutate(node.value)
        predicate = self.mutate(node.predicate)
        return _make.Store(buffer_var, value, index, predicate)

    def mutate_For(self, node):
        """Rebuild a For node, wrapping its body when the counter hits axis."""
        self.count += 1
        loop_var = self.mutate(node.loop_var)
        _min = self.mutate(node.min)
        extent = self.mutate(node.extent)
        self.range_[loop_var] = extent.value - _min.value
        body = self.mutate(node.body)

        if self.count == self.axis:
            self.count = 0
            # Drop an enclosing attribute statement, keeping what it wraps.
            if isinstance(body, _stmt.AttrStmt):
                body = body.body

            # insert index map
            by_name = {}
            for buf_var, expr in self.index_map.items():
                by_name[buf_var.name] = expr
            for pos, arg in enumerate(self.args):
                self.info["index" + str(pos)] = str(by_name[arg])

            keys = list(self.info.keys())
            values = list(self.info.values())
            body = _make.ExternModule(
                "rtl", _make.StringImm("test"), body, keys, values)

        return _make.For(
            loop_var, _min, extent, node.for_type, node.device_api, body)
|
||
|
||
def register_extern_ip(**attrs):
    """Return a decorator that attaches `attrs` to a function as attributes.

    The decorated object itself is returned (it is not wrapped), with one
    attribute set per keyword argument given here.
    """
    def with_attrs(f):
        for key in attrs:
            setattr(f, key, attrs[key])
        return f
    return with_attrs
|
||
|
||
# create hls ip invoked within the top function | ||
def create_hls_ip(op, name, args, ip_type="hls", path=None):
    """Placeholder for creating an HLS IP invoked within the top function.

    Must be called within a superstage: it asserts that a current Stage
    exists, then reads the input ops and the output buffer of the most
    recently recorded stage. Nothing is built or returned, and the
    parameters are not used.
    """
    assert Stage._current
    latest = Schedule.last_stages[-1]
    input_ops = [stage._op for stage in latest.input_stages]
    output_bufs = [latest._buf]
|
||
|
||
# include external ip files | ||
def create_extern_module(op, dicts, ip_type="hls", path=None):
    """Replace the latest stage's op with one wrapped in an ExternModule.

    Parameters
    ----------
    op
        Stage object whose ``_op.op`` supplies the body, name, tag and axis
        of the new extern op.
    dicts
        Annotation dictionary. It must contain ``"args"``, an iterable of
        ``(name, dtype)`` pairs; each pair becomes an ``"input::<name>"``
        annotation and ``"args"`` itself is not forwarded. The dictionary
        is not modified, so it can be reused by the caller.
    ip_type
        One of ``"rtl"``, ``"hls"`` or ``"host"``.
    path
        Unused.

    Raises
    ------
    AssertionError
        If ``"args"`` is missing from `dicts` or `ip_type` is not supported.
    """
    # Validate before touching any scheduler state.
    assert "args" in dicts, "dicts must provide an 'args' entry"
    assert ip_type in ["rtl", "hls", "host"]

    curr = Schedule.last_stages[-1]
    input_ops = [stage._op for stage in curr.input_stages]
    input_bufs = [stage._buf for stage in curr.input_stages]
    output_bufs = [curr._buf]

    # input and output arguments
    # Build a fresh dict instead of aliasing `dicts`: the previous code
    # added keys to and deleted "args" from the caller's dictionary.
    annotate_dict = {k: v for k, v in dicts.items() if k != "args"}
    for arg_name, arg_dtype in dicts["args"]:
        annotate_dict["input::" + arg_name] = arg_dtype

    stage_op = op._op.op
    body = _make.ExternModule(
        "top", _make.StringImm(ip_type), stage_op.body,
        list(annotate_dict.keys()), list(annotate_dict.values()))

    new_op = _ExternOp(
        stage_op.name, stage_op.tag, stage_op.axis,
        input_ops, input_bufs, output_bufs, body)
    curr._op = new_op.output(0)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import heterocl as hcl | ||
import numpy as np | ||
import numpy.testing as tst | ||
import hlib | ||
import os | ||
from itertools import permutations | ||
|
||
# Module-level data type object: hcl.Float(64).
# NOTE(review): the tests visible below call hcl.init(hcl.Float()) and do not
# read this name; confirm whether it is still needed.
dtype = hcl.Float(64) | ||
|
||
def test_fft_hls():
    """Check hlib.ip.single_fft_hls numerically and in the lowered IR.

    For each length in (32, 512, 1024) the default build is run and compared
    against np.fft.fft; afterwards, for the same lengths, the schedule is
    lowered for the aws_f1 platform and the IR text is searched for the
    expected extern call.
    """
    lengths = (32, 512, 1024)

    def _test_llvm(length):
        hcl.init(hcl.Float())
        X_real = hcl.placeholder((length,), name="X_real")
        X_imag = hcl.placeholder((length,), name="X_imag")

        def math_func(A, B):
            return hlib.ip.single_fft_hls(A, B)

        schedule = hcl.create_schedule([X_real, X_imag], math_func)
        kernel = hcl.build(schedule)

        # Random complex input and its NumPy reference transform.
        real_in = np.random.random((length))
        imag_in = np.random.random((length))
        expected = np.fft.fft(real_in + 1j * imag_in)

        real_in_hcl = hcl.asarray(real_in)
        imag_in_hcl = hcl.asarray(imag_in)
        real_out_hcl = hcl.asarray(np.zeros((length)))
        imag_out_hcl = hcl.asarray(np.zeros((length)))

        kernel(real_in_hcl, imag_in_hcl, real_out_hcl, imag_out_hcl)

        np.testing.assert_allclose(
            expected.real, real_out_hcl.asnumpy(), rtol=1e-02, atol=1e-3)
        np.testing.assert_allclose(
            expected.imag, imag_out_hcl.asnumpy(), rtol=1e-02, atol=1e-3)

    def _test_sim(length):
        hcl.init(hcl.Float())
        X_real = hcl.placeholder((length,), name="X_real")
        X_imag = hcl.placeholder((length,), name="X_imag")

        def math_func(A, B):
            real, imag = hlib.ip.single_fft_hls(A, B)
            return hcl.compute(
                (length,),
                lambda x: hcl.sqrt(real[x] * real[x] + imag[x] * imag[x]),
                name="abs")

        schedule = hcl.create_schedule([X_real, X_imag], math_func)
        target = hcl.platform.aws_f1
        target.config(compile="vitis", backend="vhls")
        schedule.to([X_real, X_imag], target.xcel)
        schedule.to(math_func.abs, target.host)
        ir = str(hcl.lower(schedule))

        # The two inputs may appear in either order in the extern call.
        pattern = "test({}.channel, {}.channel, abs.channel)"
        candidates = [
            pattern.format(*order)
            for order in permutations(["X_real", "X_imag"])
        ]
        assert any(text in ir for text in candidates)
        # TODO: once the simulation flow is enabled, build with
        # hcl.build(s, target), run it on random inputs and compare against
        # np.fft.fft with the same tolerances as _test_llvm.

    for length in lengths:
        _test_llvm(length)

    for length in lengths:
        _test_sim(length)
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_fft_hls()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.