Skip to content

Commit

Permalink
Merge pull request Xilinx#1076 from Xilinx/bugfix/rtl_mvau
Browse files Browse the repository at this point in the history
RTL MVAU cross-lane accumulation overflow fix
  • Loading branch information
auphelia authored May 15, 2024
2 parents 43fc12b + 963a38d commit ae87807
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
21 changes: 14 additions & 7 deletions finn-rtllib/mvu/mvu_4sx4u.sv
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,12 @@ module mvu_4sx4u #(
// Count leaves reachable from each node
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1}; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop

uwire signed [ACCU_WIDTH -1:0] up4;
uwire signed [ACCU_WIDTH -8:0] hi4[3];
uwire [$clog2(SIMD)+7:0] lo4[3];
uwire signed [ACCU_WIDTH-1:0] up4;
uwire signed [$clog2(2**(ACCU_WIDTH-8)+SIMD):0] hi4[3]; // min LO_WIDTH=7
uwire [$clog2(SIMD)+7 :0] lo4[3]; // max LO_WIDTH=8
for(genvar i = 0; i < 4; i++) begin
localparam int unsigned LO_WIDTH = D[i+1] - D[i];
localparam int unsigned HI_WIDTH = ACCU_WIDTH - LO_WIDTH;
localparam int unsigned HI_WIDTH = 1 + $clog2(2**(ACCU_WIDTH-LO_WIDTH-1)+SIMD);

// Conclusive high part accumulation
if(i >= PE_REM && i < 3) begin : genHi
Expand All @@ -469,15 +469,22 @@ module mvu_4sx4u #(
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) Hi4 <= (L[4]? 0 : Hi4) + $signed(tree[0]);
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
end
Hi4 <= h;
end
end
assign hi4[i] = Hi4;
end : genHi
else if (i < 3) begin : genHiZero
assign hi4[i] = '0;
end : genHiZero

// Conclusive low part accumulation
// Conclusive low part accumulation (all unsigned arithmetic)
if(i >= PE_REM) begin : blkLo
// Adder Tree across all SIMD low contributions
localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1));
Expand All @@ -486,7 +493,7 @@ module mvu_4sx4u #(
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1));
uwire [NODE_WIDTH-1:0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
uwire [NODE_WIDTH-1:0] s = tree[2*n+1] + tree[2*n+2];
assign tree[n] = s;
end

Expand Down
42 changes: 33 additions & 9 deletions finn-rtllib/mvu/mvu_8sx8u_dsp48.sv
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ module mvu_8sx8u_dsp48 #(
localparam int unsigned PE_END = PE < 2*(c+1)? PE : 2*(c+1);
localparam int unsigned PE_REM = 2*(c+1) - PE_END;

uwire [57:0] p3[SIMD];
uwire [47:0] p3[SIMD];
uwire signed [ 1:0] h3[SIMD];
for(genvar s = 0; s < SIMD; s++) begin : genSIMD

Expand Down Expand Up @@ -447,13 +447,30 @@ module mvu_8sx8u_dsp48 #(
// Count leaves reachable from each node
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 0}; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop

uwire signed [ACCU_WIDTH -1:0] up4;
uwire signed [ACCU_WIDTH -SINGLE_PROD_WIDTH:0] hi4;
uwire [$clog2(SIMD)+SINGLE_PROD_WIDTH-1:0] lo4;
// Range of Cross-lane Contribution Tracked in Hi4
/*
* - Assumption: ACCU_WIDTH bounds the right lane value at any point in time.
* - The value x beyond the lane boundary is hence bounded by:
* -2^(w-1) <= x <= 2^(w-1)-1 with w = ACCU_WIDTH - D[1]
* - This value decomposes into the tracked overflow h and the overflow l
* from the low SIMD lane reduction with:
* 0 <= l <= SIMD
* - From x = l + h follows:
* h = x - l
* -2^(w-1) - SIMD <= h <= 2^(w-1)-1
* - The required bit width of the two's complement representation of this
* signed value is determined by its lower bound to be at least:
* 1 + $clog2(2^(w-1)+SIMD)
*/
localparam int unsigned HI_WIDTH = 1 + $clog2(2**(ACCU_WIDTH-D[1]-1)+SIMD);

uwire signed [ACCU_WIDTH -1:0] up4;
uwire signed [HI_WIDTH -1:0] hi4;
uwire [$clog2(SIMD)+D[1]-1:0] lo4;

// Conclusive high part accumulation
if(PE_REM == 0) begin : genHi
localparam int unsigned HI_WIDTH = ACCU_WIDTH - D[1];

// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s];
Expand All @@ -466,8 +483,15 @@ module mvu_8sx8u_dsp48 #(
// High Sideband Accumulation
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) Hi4 <= (L[4]? 0 : Hi4) + $signed(tree[0]);
if(rst) Hi4 <= 0;
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
end
Hi4 <= h;
end
end
assign hi4 = Hi4;
end : genHi
Expand All @@ -479,14 +503,14 @@ module mvu_8sx8u_dsp48 #(
localparam int unsigned LO_WIDTH = D[i+1] - D[i];
// Conclusive low part accumulation
if(i >= PE_REM) begin : blkLo
// Adder Tree across all SIMD low contributions
// Adder Tree across all SIMD low contributions (all unsigned arithmetic)
localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1));
uwire [2*SIMD-2:0][ROOT_WIDTH-1:0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][D[i]+:LO_WIDTH];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1));
uwire [NODE_WIDTH-1:0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
uwire [NODE_WIDTH-1:0] s = tree[2*n+1] + tree[2*n+2];
assign tree[n] = s;
end

Expand Down
11 changes: 5 additions & 6 deletions finn-rtllib/mvu/tb/mvu_axi_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ module mvu_axi_tb();
// Matrix & parallelism config
localparam bit IS_MVU = 1;
localparam string COMPUTE_CORE = "mvu_4sx4u";
localparam int unsigned MW = 120;
localparam int unsigned MH = 40;
localparam int unsigned SIMD = 20;
localparam int unsigned PE = 10;
localparam int unsigned MW = 96;
localparam int unsigned MH = 32;
localparam int unsigned SIMD = 48;
localparam int unsigned PE = 16;
localparam int unsigned SEGMENTLEN = 2.0;
localparam bit FORCE_BEHAVIORAL = 1;
localparam bit M_REG_LUT = 1;
// Bit-width config
localparam int unsigned ACTIVATION_WIDTH = 4;
localparam int unsigned WEIGHT_WIDTH = 4;
localparam int unsigned ACCU_WIDTH = ACTIVATION_WIDTH+WEIGHT_WIDTH+$clog2(MW);
localparam bit SIGNED_ACTIVATIONS = 0;
localparam bit SIGNED_ACTIVATIONS = 1;
// Simulation constants
localparam int unsigned NF = MH/PE;
localparam int unsigned SF = MW/SIMD;
Expand Down Expand Up @@ -142,7 +142,6 @@ module mvu_axi_tb();

// Function to compute golden output
// a: [SF][SIMD-1:0][ACTIVATION_WIDTH-1:0]
// a: [SF][SIMD-1:0][ACTIVATION_WIDTH-1:0]
// a: [SF][PE*SIMD-1:0][ACTIVATION_WIDTH-1:0]
// w: [NF][SF][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0]
typedef logic signed [PE-1:0][ACCU_WIDTH-1:0] output_t;
Expand Down

0 comments on commit ae87807

Please sign in to comment.