diff --git a/finn-rtllib/thresholding/hdl/thresholding_axi.sv b/finn-rtllib/thresholding/hdl/thresholding_axi.sv index 5c7182b214..39756e5c2b 100644 --- a/finn-rtllib/thresholding/hdl/thresholding_axi.sv +++ b/finn-rtllib/thresholding/hdl/thresholding_axi.sv @@ -39,8 +39,9 @@ *****************************************************************************/ module thresholding_axi #( - int unsigned N, // output precision - int unsigned K, // input/threshold precision + int unsigned N, // output precision + int unsigned WI, // input precision + int unsigned WT, // threshold precision int unsigned C = 1, // Channels int unsigned PE = 1, // Processing Parallelism, requires C = k*PE @@ -96,7 +97,7 @@ module thresholding_axi #( //- AXI Stream - Input -------------- output logic s_axis_tready, input logic s_axis_tvalid, - input logic [((PE*K+7)/8)*8-1:0] s_axis_tdata, + input logic [((PE*WI+7)/8)*8-1:0] s_axis_tdata, //- AXI Stream - Output ------------- input logic m_axis_tready, @@ -109,13 +110,13 @@ module thresholding_axi #( uwire cfg_en; uwire cfg_we; uwire [ADDR_BITS-3:0] cfg_a; - uwire [K -1:0] cfg_d; + uwire [WT -1:0] cfg_d; uwire cfg_rack; - uwire [K -1:0] cfg_q; + uwire [WT -1:0] cfg_q; if(USE_AXILITE) begin uwire [ADDR_BITS-1:0] cfg_a0; - axi4lite_if #(.ADDR_WIDTH(ADDR_BITS), .DATA_WIDTH(32), .IP_DATA_WIDTH(K)) axi ( + axi4lite_if #(.ADDR_WIDTH(ADDR_BITS), .DATA_WIDTH(32), .IP_DATA_WIDTH(WT)) axi ( .aclk(ap_clk), .aresetn(ap_rst_n), .awready(s_axilite_AWREADY), .awvalid(s_axilite_AWVALID), .awaddr(s_axilite_AWADDR), .awprot('x), @@ -143,10 +144,42 @@ module thresholding_axi #( assign cfg_d = 'x; end + //----------------------------------------------------------------------- + // Cast Inputs into Threshold Data Type + uwire [PE-1:0][WT-1:0] idat; + for(genvar pe = 0; pe < PE; pe++) begin + if(WT == WI) begin : genCopy + assign idat[pe] = s_axis_tdata[pe*WI+:WI]; + end : genCopy + else begin + initial begin + if(FPARG) begin + $error("%m: Can't cast floating-point 
type."); + $finish; + end + end + + if(WT > WI) begin : genWiden + assign idat[pe] = { {(WT-WI){SIGNED? s_axis_tdata[(pe+1)*WI-1] : 1'b0}}, s_axis_tdata[pe*WI+:WI] }; + end : genWiden + else begin : genNarrow + // Saturate for clipping inputs; a signed value only fits if all dropped bits equal the retained sign bit (WT-1) + if(!SIGNED) begin + assign idat[pe] = |s_axis_tdata[pe*WI+WT+:WI-WT]? '1 : s_axis_tdata[pe*WI+:WT]; + end + else begin + assign idat[pe] = + (s_axis_tdata[pe*WI+WT-1+:WI-WT+1] == '1) || (s_axis_tdata[pe*WI+WT-1+:WI-WT+1] == '0)? s_axis_tdata[pe*WI+:WT] : + {s_axis_tdata[(pe+1)*WI-1], {(WT-1){!s_axis_tdata[(pe+1)*WI-1]}}}; + end + end : genNarrow + end + end + //----------------------------------------------------------------------- // Kernel Implementation thresholding #( - .N(N), .K(K), .C(C), .PE(PE), + .N(N), .K(WT), .C(C), .PE(PE), .SIGNED(SIGNED), .FPARG(FPARG), .BIAS(BIAS), .THRESHOLDS_PATH(THRESHOLDS_PATH), .USE_CONFIG(USE_AXILITE), .DEPTH_TRIGGER_URAM(DEPTH_TRIGGER_URAM), .DEPTH_TRIGGER_BRAM(DEPTH_TRIGGER_BRAM), @@ -157,7 +190,7 @@ module thresholding_axi #( .cfg_en, .cfg_we, .cfg_a, .cfg_d, .cfg_rack, .cfg_q, - .irdy(s_axis_tready), .ivld(s_axis_tvalid), .idat(s_axis_tdata), + .irdy(s_axis_tready), .ivld(s_axis_tvalid), .idat, .ordy(m_axis_tready), .ovld(m_axis_tvalid), .odat(m_axis_tdata) ); diff --git a/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v b/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v index f35db156f6..49a1f2bd8b 100644 --- a/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v +++ b/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v @@ -33,8 +33,9 @@ */ module $MODULE_NAME_AXI_WRAPPER$ #( - parameter N = $N$, // output precision - parameter K = $M$, // input/threshold precision + parameter N = $N$, // output precision + parameter WI = $WI$, // input precision + parameter WT = $WT$, // threshold precision parameter C = $C$, // Channels parameter PE = $PE$, // Processing Parallelism, requires C = k*PE @@ -87,7 +88,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #( //- AXI 
Stream - Input -------------- output in0_V_TREADY, input in0_V_TVALID, - input [((PE*K+7)/8)*8-1:0] in0_V_TDATA, + input [((PE*WI+7)/8)*8-1:0] in0_V_TDATA, //- AXI Stream - Output ------------- input out_V_TREADY, @@ -96,7 +97,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #( ); thresholding_axi #( - .N(N), .K(K), .C(C), .PE(PE), + .N(N), .WI(WI), .WT(WT), .C(C), .PE(PE), .SIGNED(SIGNED), .FPARG(FPARG), .BIAS(BIAS), diff --git a/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv b/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv index 429fb7776f..cfd875f5c4 100644 --- a/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv +++ b/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv @@ -110,7 +110,7 @@ module thresholding_axi_tb #( uwire ovld; uwire [PE-1:0][N-1:0] odat; - thresholding_axi #(.N(N), .K(K), .C(C), .PE(PE), .SIGNED(0), .USE_AXILITE(1)) dut ( + thresholding_axi #(.N(N), .WI(K), .WT(K), .C(C), .PE(PE), .SIGNED(0), .USE_AXILITE(1)) dut ( .ap_clk(clk), .ap_rst_n(!rst), // Configuration diff --git a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py index ec875858ff..9584c3ae5f 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py @@ -180,15 +180,15 @@ def prepare_codegen_rtl_values(self, model): # Additionally, increase number of threshold steps to reflect new shape expected_thresholds = 2**o_bitwidth - 1 n_thres_steps = self.get_nodeattr("numSteps") + wdt = self.get_weight_datatype() if expected_thresholds != n_thres_steps: - min_val = DataType[input_data_type].min() + min_val = wdt.min() thresholds = np.insert(thresholds, 0, min_val, axis=1) bias = bias - 1 n_thres_steps += 1 # add dummy dimension as final dimension (that's what gets packed with next call) t_expand = np.expand_dims(thresholds, axis=-1) - wdt = self.get_weight_datatype() bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 4) t_packed = 
pack_innermost_dim_as_hex_string( t_expand, @@ -242,9 +242,10 @@ def prepare_codegen_rtl_values(self, model): i_bitwidth = DataType[input_data_type].bitwidth() code_gen_dict["$N$"] = [str(o_bitwidth)] # output precision - convert bitwidth to string - code_gen_dict["$M$"] = [ - str(i_bitwidth) - ] # input/threshold precision - convert bitwidth to string + code_gen_dict["$WT$"] = [ + str(wdt.bitwidth()) + ] # threshold precision - convert bitwidth to string + code_gen_dict["$WI$"] = [str(i_bitwidth)] # input precision - convert bitwidth to string code_gen_dict["$C$"] = [str(num_channels)] # number of channels code_gen_dict["$BIAS$"] = [str(bias)] # activation bias value code_gen_dict["$PE$"] = [str(pe)] # requires C = M*PE diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index 897d714bf8..e14181b140 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -199,10 +199,16 @@ def apply(self, model): thl_in_shape = model.get_tensor_shape(thl_input) thl_thres_shape = model.get_tensor_shape(thl_threshold) idt = model.get_tensor_datatype(thl_input) - + tdt = model.get_tensor_datatype(thl_threshold) # skip conversion for layers with float input if not idt.is_integer(): continue + assert tdt.is_integer(), ( + node.name + + """: MultiThreshold cannot be converted + because thresholds are float type. 
Input data type is integer, + please run RoundAndClipThresholds to convert thresholds to integer.""" + ) # check layout of inputs/outputs, and convert if needed # check layout and convert if necessary @@ -253,7 +259,7 @@ def apply(self, model): PE=pe, numSteps=thl_thres_shape[1], inputDataType=idt.name, - weightDataType=idt.name, + weightDataType=tdt.name, outputDataType=odt.name, numInputVectors=list(thl_in_shape[:-1]), ActVal=actval, diff --git a/tests/fpgadataflow/test_fpgadataflow_thresholding.py b/tests/fpgadataflow/test_fpgadataflow_thresholding.py index 88e4247c2a..6501dba33e 100644 --- a/tests/fpgadataflow/test_fpgadataflow_thresholding.py +++ b/tests/fpgadataflow/test_fpgadataflow_thresholding.py @@ -55,7 +55,7 @@ def generate_random_threshold_values( - input_data_type, num_input_channels, num_steps, narrow=False, per_tensor=False + data_type, num_input_channels, num_steps, narrow=False, per_tensor=False ): if per_tensor: num_input_channels = 1 @@ -63,8 +63,8 @@ def generate_random_threshold_values( num_steps -= 1 return np.random.randint( - input_data_type.min(), - input_data_type.max() + 1, + data_type.min(), + data_type.max() + 1, (num_input_channels, num_steps), ).astype(np.float32) @@ -76,6 +76,7 @@ def sort_thresholds_increasing(thresholds): def make_single_multithresholding_modelwrapper( thresholds, input_data_type, + threshold_data_type, output_data_type, activation_bias, num_input_vecs, @@ -115,7 +116,7 @@ def make_single_multithresholding_modelwrapper( model.set_tensor_datatype("inp", input_data_type) model.set_tensor_datatype("outp", output_data_type) - model.set_tensor_datatype("thresh", input_data_type) + model.set_tensor_datatype("thresh", threshold_data_type) model.set_initializer("thresh", thresholds) return model @@ -129,7 +130,15 @@ def make_single_multithresholding_modelwrapper( ], ) @pytest.mark.parametrize("activation", [DataType["INT4"], DataType["BIPOLAR"]]) -@pytest.mark.parametrize("input_data_type", [DataType["INT8"], 
DataType["UINT8"]]) +@pytest.mark.parametrize( + "idt_tdt_cfg", + [ + (DataType["INT8"], DataType["INT8"]), + (DataType["INT8"], DataType["INT9"]), + (DataType["UINT8"], DataType["UINT8"]), + (DataType["UINT8"], DataType["UINT9"]), + ], +) @pytest.mark.parametrize("fold", [-1, 1, 2]) @pytest.mark.parametrize("narrow", [True, False]) @pytest.mark.parametrize("per_tensor", [True, False]) @@ -143,7 +152,7 @@ def test_fpgadataflow_thresholding( num_input_channels, num_input_vecs, activation, - input_data_type, + idt_tdt_cfg, fold, narrow, per_tensor, @@ -161,6 +170,7 @@ def test_fpgadataflow_thresholding( ) if narrow and activation == DataType["BIPOLAR"]: pytest.skip("Narrow needs to be false with biploar activation.") + input_data_type, threshold_data_type = idt_tdt_cfg num_steps = activation.get_num_possible_values() - 1 if fold == -1: @@ -179,7 +189,7 @@ def test_fpgadataflow_thresholding( # Generate random thresholds and sort in ascending order thresholds = generate_random_threshold_values( - input_data_type, num_input_channels, num_steps, narrow, per_tensor + threshold_data_type, num_input_channels, num_steps, narrow, per_tensor ) # provide non-decreasing/ascending thresholds @@ -189,6 +199,7 @@ def test_fpgadataflow_thresholding( model = make_single_multithresholding_modelwrapper( thresholds, input_data_type, + threshold_data_type, output_data_type, activation_bias, num_input_vecs,