diff --git a/hls4ml/templates/pynq/patches/resnet_axi.cpp b/hls4ml/templates/pynq/patches/resnet_axi.cpp deleted file mode 100644 index 4d4495e788..0000000000 --- a/hls4ml/templates/pynq/patches/resnet_axi.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "resnet_axi.h" - -void resnet_axi( - input_axi_t in[N_IN], - output_axi_t out[N_OUT] - ){ - - #pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS - #pragma HLS INTERFACE m_axi depth=N_IN port=in offset=slave bundle=IN_BUS - #pragma HLS INTERFACE m_axi depth=N_OUT port=out offset=slave bundle=OUT_BUS - - unsigned short in_size = 0; - unsigned short out_size = 0; - - hls::stream in_local("input_1"); - hls::stream out_local("output_1"); - - #pragma HLS STREAM variable=in_local depth=N_IN - #pragma HLS STREAM variable=out_local depth=N_OUT - - for(unsigned i = 0; i < N_IN / input_t::size; ++i) { - input_t ctype; - #pragma HLS DATA_PACK variable=ctype - for(unsigned j = 0; j < input_t::size; j++) { - ap_ufixed<16,8> tmp = in[i * input_t::size + j]; // store 8 bit input in a larger temp variable - ap_ufixed<8,0> tmp2 = tmp >> 8; // shift right by 8 (div by 256) and select only the decimal of the larger temp variable - ctype[j] = typename input_t::value_type(tmp2); - } - in_local.write(ctype); - } - - resnet(in_local, out_local, in_size, out_size); - - for(unsigned i = 0; i < N_OUT / layer11_t::size; ++i) { - layer11_t ctype = out_local.read(); - for(unsigned j = 0; j < layer11_t::size; j++) { - out[i * layer11_t::size + j] = output_axi_t(ctype[j]); - } - } -} diff --git a/hls4ml/templates/pynq/patches/resnet_axi.h b/hls4ml/templates/pynq/patches/resnet_axi.h deleted file mode 100644 index d1c2f7b62a..0000000000 --- a/hls4ml/templates/pynq/patches/resnet_axi.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef RESNET_AXI_H_ -#define RESNET_AXI_H_ - -#include "resnet.h" - -static const unsigned N_IN = 3072; -static const unsigned N_OUT = 10; -typedef ap_uint<8> input_axi_t; -typedef ap_fixed<8,6> output_axi_t; - -void resnet_axi( - input_axi_t in[N_IN], - output_axi_t out[N_OUT] - ); -#endif diff --git a/hls4ml/templates/pynq/patches/resnet_bridge.cpp b/hls4ml/templates/pynq/patches/resnet_bridge.cpp deleted file mode 100644 index c1885ea266..0000000000 --- a/hls4ml/templates/pynq/patches/resnet_bridge.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#ifndef RESNET_BRIDGE_H_ -#define RESNET_BRIDGE_H_ - -#include "firmware/resnet.h" -#include "firmware/resnet_axi.h" -#include "firmware/nnet_utils/nnet_helpers.h" -#include -#include - -namespace nnet { - bool trace_enabled = false; - std::map *trace_outputs = NULL; - size_t trace_type_size = sizeof(double); -} - -extern "C" { - -struct trace_data { - const char *name; - void *data; -}; - -void allocate_trace_storage(size_t element_size) { - nnet::trace_enabled = true; - nnet::trace_outputs = new std::map; - nnet::trace_type_size = element_size; -} - -void free_trace_storage() { - for (std::map::iterator i = nnet::trace_outputs->begin(); i != nnet::trace_outputs->end(); i++) { - void *ptr = i->second; - free(ptr); - } - nnet::trace_outputs->clear(); - delete nnet::trace_outputs; - nnet::trace_outputs = NULL; - nnet::trace_enabled = false; -} - -void collect_trace_output(struct trace_data *c_trace_outputs) { - int ii = 0; - for (std::map::iterator i = nnet::trace_outputs->begin(); i != nnet::trace_outputs->end(); i++) { - c_trace_outputs[ii].name = i->first.c_str(); - c_trace_outputs[ii].data = i->second; - ii++; - } -} - -// Wrapper of top level function for Python bridge -void resnet_float( - float input_1[N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1], - float layer11_out[N_LAYER_11], - unsigned short &const_size_in_1, - unsigned short &const_size_out_1 -) { - - input_axi_t input_1_ap[N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1]; - nnet::convert_data(input_1, input_1_ap); - - output_axi_t layer11_out_ap[N_LAYER_11]; - - resnet_axi(input_1_ap, layer11_out_ap); - - nnet::convert_data(layer11_out_ap, layer11_out); -} - -void resnet_double( - double input_1[N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1], - double layer11_out[N_LAYER_11], - unsigned short &const_size_in_1, - unsigned short &const_size_out_1 -) { - input_axi_t input_1_ap[N_INPUT_1_1*N_INPUT_2_1*N_INPUT_3_1]; - nnet::convert_data(input_1, input_1_ap); - - output_axi_t layer11_out_ap[N_LAYER_11]; - - resnet_axi(input_1_ap, layer11_out_ap); - - nnet::convert_data(layer11_out_ap, layer11_out); -} - -} - -#endif diff --git a/hls4ml/templates/pynq/patches/resnet_test.cpp b/hls4ml/templates/pynq/patches/resnet_test.cpp deleted file mode 100644 index a6f735e698..0000000000 --- a/hls4ml/templates/pynq/patches/resnet_test.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "firmware/resnet_axi.h" -#include "firmware/nnet_utils/nnet_helpers.h" - -#define CHECKPOINT 5000 - -namespace nnet { - bool trace_enabled = true; - std::map *trace_outputs = NULL; - size_t trace_type_size = sizeof(double); -} - -int main(int argc, char **argv) -{ - //load input data from text file - std::ifstream fin("tb_data/tb_input_features.dat"); - //load predictions from text file - std::ifstream fpr("tb_data/tb_output_predictions.dat"); - -#ifdef RTL_SIM - std::string RESULTS_LOG = "tb_data/rtl_cosim_results.log"; -#else - std::string RESULTS_LOG = "tb_data/csim_results.log"; -#endif - std::ofstream fout(RESULTS_LOG); - - std::string iline; - std::string pline; - int e = 0; - - if (fin.is_open() && fpr.is_open()) { - while ( std::getline(fin,iline) && std::getline (fpr,pline) ) { - if (e % CHECKPOINT == 0) std::cout << "Processing input " << e << std::endl; - char* cstr=const_cast(iline.c_str()); - char* current; - std::vector in; - current=strtok(cstr," "); - while(current!=NULL) { - in.push_back(atof(current)); - current=strtok(NULL," "); - } - cstr=const_cast(pline.c_str()); - std::vector pr; - current=strtok(cstr," "); - while(current!=NULL) { - pr.push_back(atof(current)); - current=strtok(NULL," "); - } - - - //hls-fpga-machine-learning insert data - input_axi_t inputs[N_IN]; - nnet::copy_data, 0, N_IN>(in, inputs); // Copy floating point values in the temp array - - output_axi_t outputs[N_OUT]; - - //hls-fpga-machine-learning insert top-level-function - resnet_axi(inputs,outputs); - - if (e % CHECKPOINT == 0) { - std::cout << "Predictions" << std::endl; - //hls-fpga-machine-learning insert predictions - for(int i = 0; i < N_OUT; i++) { - std::cout << pr[i] << " "; - } - std::cout << std::endl; - std::cout << "Quantized predictions" << std::endl; - //hls-fpga-machine-learning insert quantized - nnet::print_result(outputs, std::cout, true); - } - e++; - - //hls-fpga-machine-learning insert tb-output - nnet::print_result(outputs, fout); - - } - fin.close(); - fpr.close(); - } else { - std::cout << "INFO: Unable to open input/predictions file, using default input." << std::endl; - //hls-fpga-machine-learning insert zero - input_axi_t inputs[N_IN]; - nnet::fill_zero(inputs); - - //hls-fpga-machine-learning insert top-level-function - output_axi_t outputs[N_OUT]; - resnet_axi(inputs,outputs); - - //hls-fpga-machine-learning insert output - nnet::print_result(outputs, std::cout, true); - - //hls-fpga-machine-learning insert tb-output - nnet::print_result(outputs, fout); - } - - fout.close(); - std::cout << "INFO: Saved inference results to file: " << RESULTS_LOG << std::endl; - - return 0; -} diff --git a/hls4ml/writer/pynq_writer.py b/hls4ml/writer/pynq_writer.py index dd0acd4411..a58472bd23 100644 --- a/hls4ml/writer/pynq_writer.py +++ b/hls4ml/writer/pynq_writer.py @@ -35,7 +35,7 @@ def write_axi_wrapper(self, model): inp = model_inputs[0] out = model_outputs[0] inp_axi_t = self.next_axi_type(inp.type.precision) - out_axi_t = self.next_axi_type(inp.type.precision) + out_axi_t = self.next_axi_type(out.type.precision) indent = ' ' @@ -158,7 +158,7 @@ def write_axi_wrapper(self, model): def modify_build_script(self, model): ''' - Modify the build_prj.tcl script to add the extra wrapper files and set the top function + Modify the build_prj.tcl and build_lib.sh scripts to add the extra wrapper files and set the top function ''' filedir = os.path.dirname(os.path.abspath(__file__)) oldfile = '{}/build_prj.tcl'.format(model.config.get_output_dir()) @@ -179,79 +179,141 @@ def modify_build_script(self, model): fout.close() os.rename(newfile, oldfile) - def write_board_script(self, model): + ################### + # build_lib.sh + ################### + + f = open(os.path.join(filedir,'../templates/pynq/build_lib.sh'),'r') + fout = open('{}/build_lib.sh'.format(model.config.get_output_dir()),'w') + + for line in f.readlines(): + line = line.replace('myproject', model.config.get_project_name()) + line = line.replace('mystamp', model.config.get_config_value('Stamp')) + + fout.write(line) + f.close() + fout.close() + + def apply_patches(self, model): ''' - Write the tcl scripts to create a Vivado IPI project for the Pynq + Apply patches. ''' filedir = os.path.dirname(os.path.abspath(__file__)) - copyfile(os.path.join(filedir,'../templates/pynq/pynq_design.tcl'), '{}/pynq_design.tcl'.format(model.config.get_output_dir())) - f = open('{}/project.tcl'.format(model.config.get_output_dir()),'w') - f.write('variable myproject\n') - f.write('set myproject "{}"\n'.format(model.config.get_project_name())) - - def write_build_script(self, model): - ################### - # build_prj.tcl - ################### + indent = ' ' - filedir = os.path.dirname(os.path.abspath(__file__)) + ################### + # patch myproject_axi.h + ################### + oldfile = '{}/firmware/{}_axi.h'.format(model.config.get_output_dir(), model.config.get_project_name()) + newfile = '{}/firmware/{}_axi_patch.h'.format(model.config.get_output_dir(), model.config.get_project_name()) - f = open(os.path.join(filedir,'../templates/vivado/build_prj.tcl'),'r') - fout = open('{}/build_prj.tcl'.format(model.config.get_output_dir()),'w') + f = open(oldfile,'r') + fout = open(newfile, 'w') for line in f.readlines(): + if 'typedef' in line and 'input_axi_t;' in line: + # hardcoded ap_uint<8> input + newline = 'typedef ap_uint<8> input_axi_t;\n' + else: + newline = line + fout.write(newline) - line = line.replace('myproject',model.config.get_project_name()) + f.close() + fout.close() + os.rename(newfile, oldfile) - if 'set_part {xcku115-flvb2104-2-i}' in line: - line = 'set_part {{{}}}\n'.format(model.config.get_config_value('XilinxPart')) - elif 'create_clock -period 5 -name default' in line: - line = 'create_clock -period {} -name default\n'.format(model.config.get_config_value('ClockPeriod')) + ################### + # patch myproject_axi.cpp + ################### + oldfile = '{}/firmware/{}_axi.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) + newfile = '{}/firmware/{}_axi_patch.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) + + f = open(oldfile,'r') + fout = open(newfile, 'w') + + for line in f.readlines(): + if 'ctype[j] = typename input_t::value_type' in line: + # these lines are hardcoded to do the bitshift by 256 + newline = indent + indent + indent + 'ap_ufixed<16,8> tmp = in[i * input_t::size + j]; // store 8 bit input in a larger temp variable\n' + newline += indent + indent + indent + 'ctype[j] = typename input_t::value_type(tmp >> 8); // shift right by 8 (div by 256) and select only the decimal of the larger temp variable\n' + else: + newline = line + fout.write(newline) - fout.write(line) f.close() fout.close() - + os.rename(newfile, oldfile) ################### - # vivado_synth.tcl + # patch myproject_test.cpp ################### + oldfile = '{}/{}_test.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) + newfile = '{}/{}_test_patch.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) + + f = open(oldfile,'r') + fout = open(newfile, 'w') + + inp = model.get_input_variables()[0] + out = model.get_output_variables()[0] - f = open(os.path.join(filedir,'../templates/vivado/vivado_synth.tcl'),'r') - fout = open('{}/vivado_synth.tcl'.format(model.config.get_output_dir()),'w') for line in f.readlines(): - line = line.replace('myproject', model.config.get_project_name()) - if '-part' in line: - line = 'synth_design -top {} -part {}\n'.format(model.config.get_project_name(), model.config.get_config_value('XilinxPart')) + if '{}.h'.format(model.config.get_project_name()) in line: + newline = line.replace('{}.h'.format(model.config.get_project_name()), '{}_axi.h'.format(model.config.get_project_name())) + elif self.variable_definition_cpp(model, inp) in line: + newline = line.replace(self.variable_definition_cpp(model, inp), 'input_axi_t inputs[N_IN]') + elif self.variable_definition_cpp(model, out) in line: + newline = line.replace(self.variable_definition_cpp(model, out), 'output_axi_t outputs[N_OUT]') + elif 'unsigned short' in line: + newline = '' + elif '{}('.format(model.config.get_project_name()) in line: + indent_amount = line.split(model.config.get_project_name())[0] + newline = indent_amount + '{}_axi(inputs,outputs);\n'.format(model.config.get_project_name()) + elif inp.size_cpp() in line or inp.cppname in line or inp.type.name in line: + newline = line.replace(inp.size_cpp(),'N_IN').replace(inp.cppname, 'inputs').replace(inp.type.name, 'input_axi_t') + elif out.size_cpp() in line or out.cppname in line or out.type.name in line: + newline = line.replace(out.size_cpp(),'N_OUT').replace(out.cppname, 'outputs').replace(out.type.name, 'output_axi_t') + else: + newline = line + fout.write(newline) - fout.write(line) f.close() fout.close() + os.rename(newfile, oldfile) ################### - # build_lib.sh + # patch myproject_bridge.cpp ################### + oldfile = '{}/{}_bridge.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) + newfile = '{}/{}_bridge_patch.cpp'.format(model.config.get_output_dir(), model.config.get_project_name()) - f = open(os.path.join(filedir,'../templates/pynq/build_lib.sh'),'r') - fout = open('{}/build_lib.sh'.format(model.config.get_output_dir()),'w') + f = open(oldfile,'r') + fout = open(newfile, 'w') + + inp = model.get_input_variables()[0] + out = model.get_output_variables()[0] for line in f.readlines(): - line = line.replace('myproject', model.config.get_project_name()) - line = line.replace('mystamp', model.config.get_config_value('Stamp')) + if '{}.h'.format(model.config.get_project_name()) in line: + newline = line.replace('{}.h'.format(model.config.get_project_name()), '{}_axi.h'.format(model.config.get_project_name())) + elif self.variable_definition_cpp(model, inp, name_suffix='_ap') in line: + newline = line.replace(self.variable_definition_cpp(model, inp, name_suffix='_ap'), 'input_axi_t {}_ap[N_IN]'.format(inp.cppname)) + elif self.variable_definition_cpp(model, out, name_suffix='_ap') in line: + newline = line.replace(self.variable_definition_cpp(model, out, name_suffix='_ap'), 'output_axi_t {}_ap[N_OUT]'.format(out.cppname)) + elif '{}('.format(model.config.get_project_name()) in line: + indent_amount = line.split(model.config.get_project_name())[0] + newline = indent_amount + '{}_axi({}_ap,{}_ap);\n'.format(model.config.get_project_name(), inp.cppname,out.cppname) + elif inp.size_cpp() in line or inp.cppname in line or inp.type.name in line: + newline = line.replace(inp.size_cpp(),'N_IN').replace(inp.type.name, 'input_axi_t') + elif out.size_cpp() in line or out.cppname in line or out.type.name in line: + newline = line.replace(out.size_cpp(),'N_OUT').replace(out.type.name, 'output_axi_t') + else: + newline = line + fout.write(newline) - fout.write(line) f.close() fout.close() - - def apply_patches(self, model): - - filedir = os.path.dirname(os.path.abspath(__file__)) - if model.config.get_project_name() == 'resnet': - copyfile(os.path.join(filedir,'../templates/pynq/patches/resnet_axi.h'), '{}/firmware/resnet_axi.h'.format(model.config.get_output_dir())) - copyfile(os.path.join(filedir,'../templates/pynq/patches/resnet_axi.cpp'), '{}/firmware/resnet_axi.cpp'.format(model.config.get_output_dir())) - copyfile(os.path.join(filedir,'../templates/pynq/patches/resnet_bridge.cpp'), '{}/resnet_bridge.cpp'.format(model.config.get_output_dir())) - copyfile(os.path.join(filedir,'../templates/pynq/patches/resnet_test.cpp'), '{}/resnet_test.cpp'.format(model.config.get_output_dir())) + os.rename(newfile, oldfile) def write_hls(self, model): ''' @@ -260,6 +322,6 @@ def write_hls(self, model): super(PynqWriter, self).write_hls(model) self.write_axi_wrapper(model) self.modify_build_script(model) - self.write_board_script(model) - self.apply_patches(model) + if model.config.get_config_value('ApplyPatches'): + self.apply_patches(model)