diff --git a/hazardflow-designs/src/gemmini/execute/mod.rs b/hazardflow-designs/src/gemmini/execute/mod.rs index a20319f..f5dd31e 100644 --- a/hazardflow-designs/src/gemmini/execute/mod.rs +++ b/hazardflow-designs/src/gemmini/execute/mod.rs @@ -501,11 +501,12 @@ struct SramReadRespInfos { } /// generate inputs for mesh_with_delays +#[allow(clippy::type_complexity)] fn mesh_inputs( cntl: I, { Dep::Helpful }>, spad_resps: [Vr; SP_BANKS], acc_resps: [Vr; ACC_BANKS], -) -> (Vr, { Dep::Helpful }>, I, { Dep::Helpful }>) { +) -> (Vr<(A, B, D)>, I, { Dep::Helpful }>) { let (mesh_req, cntl) = cntl.map_resolver_inner::<(TagsInProgress, ())>(|er| er.0).lfork(); let req: I, { Dep::Helpful }> = mesh_req.map(|cntl| MeshReq { @@ -516,14 +517,12 @@ fn mesh_inputs( rows: cntl.c_shape.map(|c| c.rows).unwrap_or(0.into_u()), cols: cntl.c_shape.map(|c| c.cols).unwrap_or(0.into_u()), }, - pe_control: PeControl { - dataflow: cntl.dataflow, - propagate: if cntl.prop { Propagate::Reg2 } else { Propagate::Reg1 }, - shift: cntl.shift, - }, + dataflow: cntl.dataflow, + propagate_flip: cntl.prop, + shift: cntl.shift, transpose_a: cntl.transpose_a, transpose_bd: cntl.transpose_bd, - flush: 0.into_u(), + flush: false, }); let (info, zeros) = cntl @@ -761,12 +760,12 @@ fn mesh_inputs( let b_data = [b_data, b_zero].merge(); let d_data = [d_data, d_zero].merge(); - let data = [ + let data = ( a_data.filter_map(|p| p).reg_fwd(true), b_data.filter_map(|p| p).reg_fwd(true), d_data.filter_map(|p| p).reg_fwd(true), - ] - .join_vr(); + ) + .join_vr(); (data, req.reg_fwd(true)) } @@ -1558,16 +1557,14 @@ where let MeshControlSignals { counters, signals, .. } = p; MeshReq { - pe_control: PeControl { - dataflow: signals.dataflow, - propagate: if counters.in_prop_flush { Propagate::Reg2 } else { Propagate::Reg1 }, - shift: signals.shift, - }, + dataflow: signals.dataflow, + propagate_flip: counters.in_prop_flush, + shift: signals.shift, transpose_a: false, transpose_bd: false, total_rows: BLOCK_SIZE.into_u(), tag: MeshTag { rob_id: None, addr: LocalAddr::garbage(), rows: 0.into_u(), cols: 0.into_u() }, - flush: 1.into_u(), + flush: true, } }) .map_resolver_inner(|_| ()); diff --git a/hazardflow-designs/src/gemmini/execute/systolic_array/mesh_with_delays.rs b/hazardflow-designs/src/gemmini/execute/systolic_array/mesh_with_delays.rs index b97fab7..bee50a4 100644 --- a/hazardflow-designs/src/gemmini/execute/systolic_array/mesh_with_delays.rs +++ b/hazardflow-designs/src/gemmini/execute/systolic_array/mesh_with_delays.rs @@ -41,301 +41,265 @@ pub enum TransposeFlag { /// Mesh tag #[derive(Debug, Clone, Copy)] pub struct MeshTag { - /// rob_id + /// ROB ID. pub rob_id: HOption>, - /// local_addr + /// SRAM write address. pub addr: LocalAddr, - /// rows + /// Number of rows. pub rows: U<{ clog2(BLOCK_SIZE + 1) }>, - /// cols + /// Number of cols. pub cols: U<{ clog2(BLOCK_SIZE + 1) }>, } -impl MeshTag { - /// Generate garbage tag. - pub fn get_garbage_tag() -> Self { - let garbage_addr = LocalAddr::from(GARBAGE_ADDR.into_u()); - Self { rob_id: None, addr: garbage_addr, rows: 0.into_u(), cols: 0.into_u() } +impl Default for MeshTag { + /// Returns garbage tag. + fn default() -> Self { + Self { rob_id: None, addr: LocalAddr::from(GARBAGE_ADDR.into_u()), rows: 0.into_u(), cols: 0.into_u() } } } /// Request signals to the mesh. #[derive(Debug, Clone, Copy)] pub struct MeshReq { - /// pe_control - pub pe_control: PeControl, - /// a_transpose + /// Dataflow value used in the PE. 
+ pub dataflow: Dataflow, + /// Indicates whether the propagate value should be flipped. + pub propagate_flip: bool, + /// Shift value used in the PE. + pub shift: U<{ clog2(ACC_BITS) }>, + /// Indicates that `A` should be transposed, used to invoke a transposer. pub transpose_a: bool, - /// bd_transpos + /// Indicates that either `B` or `D` should be transposed, used to invoke a transposer. pub transpose_bd: bool, - /// total_rows + /// Specifies the number of rows in the matmul operation. pub total_rows: U<{ clog2(BLOCK_SIZE + 1) }>, - /// tag + /// Tag. pub tag: MeshTag, - /// flush - pub flush: U<2>, + /// Indicates whether the request represents a flush. + pub flush: bool, } /// Response signals from the mesh. #[derive(Debug, Clone, Copy)] pub struct MeshResp { - /// total_rows + /// Specifies the number of rows in the matmul operation. pub total_rows: U<{ clog2(BLOCK_SIZE + 1) }>, - /// tag + /// Tag. pub tag: MeshTag, - /// last + /// Indicates that the row represents the last row. pub last: bool, - /// data + /// Data. pub data: Array, MESH_COLS>, } -/// Helper type to update configurations. +/// Matmul operation configuration. #[derive(Debug, Default, Clone, Copy)] pub struct Config { /// Matmul ID. pub matmul_id: U, /// Propagation. - pub in_prop: bool, + pub propagate: Propagate, } -/// Helper type to manage fire_counter and flush_counter. -#[derive(Debug, Default, Clone, Copy)] -struct Counter { - fire_counter: U<{ clog2(BLOCK_SIZE) }>, - flush_counter: U<2>, +impl Config { + /// Creates a new configuration. + pub fn new(matmul_id: U, propagate: Propagate) -> Self { + Self { matmul_id, propagate } + } + + /// Returns the updated global configuration based on the incoming request. + /// + /// For more details, see Section 2.3.1 of the assignment documentation. + /// + /// # Arguments + /// + /// - `self`: The current configuration state. + /// - `propagate_flip`: A boolean indicating whether to toggle the propagate value in processing elements (PEs). + pub fn update(self, propagate_flip: bool) -> Self { + todo!("assignment 6") + } } +/// Wrapper type of request and configuration. #[derive(Debug, Clone, Copy)] -struct ReqExtended { - req: MeshReq, - config: Config, +pub struct ReqExtended { + /// Mesh request. + pub req: MeshReq, + /// Matmul operation configuration. + pub config: Config, } -/// Helper function -/// -/// Updates the global configuration of `mesh_delay`. -/// This function should be used in a combinator that manages internal state. -/// -/// # Arguments +/// Manages two FIFOs containing the mesh tag and total rows, and returns the metadata transferred from the FIFO. /// -/// * `prop` - Propagation information from `MeshReq`. -/// * `config` - Previously stored configuration state. +/// These metadata are used to store the Mesh output in the SRAM (Scratchpad + Accumulator). /// -/// # Returns -/// -/// New configuration state. -/// -/// # Behavior -/// -/// For more details, refer to section 2.3.1 of the assignment description. -fn update_config(prop: Propagate, config: Config) -> Config { - todo!("assignment 6") -} - -/// Helper function -/// -/// Increase fire counter whenever all data (a, b, and d) comes in. -/// This function should be used in a `fsm_egress`. +/// For more details, see Section 2.3.8 of the assignment documentation. /// /// # Arguments /// -/// * `req_ext` - Request and updated configuration. -/// * `counter` - Previously stored conter state. -/// -/// # Returns -/// -/// * `(ReqExtended, bool)` - request and 1-bit signal indicating last fire. 
-/// * `Counter` - Updated counter state. -/// * `bool` - The 1-bit signal indicating whether `fsm_egress` is ready to receieve new request or not. -/// -/// # Behavior -/// -/// For more details, refer to section 2.3.2 of the assignment description. -#[allow(clippy::type_complexity)] -fn update_counter(req_ext: ReqExtended, counter: Counter) -> ((ReqExtended, bool), Counter, bool) { - let req = req_ext.req; - - let last_fire: bool = todo!("Is this last fire?"); - let s_next: Counter = todo!("Calculate next fire_counter and flush_counter"); - - // `is_last` indicates fsm_egress is ready to receive next payload. - let is_last: bool = todo!("is_last for WS dataflow") || todo!("is_last for OS dataflow"); - - ((req_ext, last_fire), s_next, is_last) -} - -/// Helper function to manage two fifo. -/// -/// Manage `tag_fifo` and `total_rows_fifo`. -/// The `tag` and `total_rows` are metadata to store results in SRAM (Scratchpad or Accumulator). -/// -/// # Arguments -/// -/// * `req` - Request which contains `tag` and `total_rows`. -/// Resolver signal which contains active computation in systolic array. -/// * `control` - control signal which contains `mesh_id` and `last` signals. -/// -/// # Returns -/// -/// * `MeshTag` - Tag for the just finished computation in systolic array. -/// * `U<{ clog2(BLOCK_SIZE + 1) }>` - Total rows for the just finished computation in systolic array. -/// -/// # Behavior -/// -/// For more details, refer to section 2.3.6 of the assignment description. +/// - `req`: A request containing metadata. It sends the tags in the `tags_fifo` as the resolver. +/// - `control`: Control signals from the Mesh output. fn fifos( req: I, { Dep::Helpful }>, control: Valid, -) -> Valid<(MeshTag, U<{ clog2(BLOCK_SIZE + 1) }>)> { +) -> (Valid, Valid>) { // Duplicate control signal and request, because we need to address two fifo. - let (control_tag, control_row) = control.lfork(); - - // Refer to section 2.3.8 (1) - let req: I, { Dep::Helpful }> = todo!("filter flush operation"); - let (req_tagq, req_rowq) = req.map_resolver_inner::<(TagsInProgress, ())>(|(tags, _)| tags).lfork(); - - // Refer to section 2.3.8 (2) - // Calculate future `matmul_id` when the computation is completed. - let req_tag: I, ReqExtended), TagsInProgress>, { Dep::Helpful }> = - todo!("Calculate matmul_id_of_output"); - let req_row: Vr<(U, ReqExtended)> = todo!("Calculate matmul id of current"); - - // Refer to section 2.3.8 (3) - // Convert the resolver type to use `fifo` family combinator - let req_tag: I, MeshTag), ((), FifoS<(U, MeshTag), FIFO_LENGTH>)>, { Dep::Helpful }> = + let (control_to_tag_fifo, control_to_total_rows_fifo) = control.lfork(); + + // Section 2.3.8 (1) Filter out flush request. + let req: I, { Dep::Helpful }> = todo!("assignment 6"); + + // Section 2.3.8 (2) Calculate future `matmul_id`. + let (tag, total_rows) = req + .map_resolver_inner::<(TagsInProgress, ())>(|(tags, _)| tags) + .map(|ReqExtended { req, config }| { + let tag_id = todo!("assignment 6"); + let total_rows_id = todo!("assignment 6"); + ((tag_id, req.tag), (total_rows_id, req.total_rows)) + }) + .unzip(); + + // Section 2.3.8 (3) Convert resolver type and calculate `TagsInProgress`. 
+ let tag: I, MeshTag), ((), FifoS<(U, MeshTag), FIFO_LENGTH>)>, { Dep::Helpful }> = todo!("Caculate the resolver signal `TagsInProgress` here"); - let req_row = todo!("Refer `req_tag` and convert resolver type"); + let total_rows = total_rows.map_resolver_inner::<((), FifoS<(U, U<5>), FIFO_LENGTH>)>(|_| ()); // FIFO - let tag_fifo = req_tag.multi_headed_transparent_fifo(); - let row_fifo = todo!("Use this -> req_row.multi_headed_transparent_fifo();"); - - // Refer to the section 2.3.8 (4) - // We need PeColControl signal. - let tag = (tag_fifo, control_tag).join(); - let total_rows = todo!("Use this -> (row_fifo, control_row).join();"); - - let tag: Valid = todo!("Carefully read condition for popping and transferring data."); - let total_rows: Valid> = todo!("Carefully read condition for popping and transferring data."); - - // Refer to the section 2.3.8 (5) - // Return metadata - (tag, total_rows).zip_any_valid().map(|(tag, total_rows)| { - let tag: MeshTag = todo!("If tag is invalid payload, replace it with garbage tag signal"); - let total_rows: U<5> = todo!("If total_rows in invalid payload, replace it with garbage total_rows signal"); - - (tag, total_rows) - }) + let tag_fifo = tag.multi_headed_transparent_fifo().filter_map(|p| p.head()); + let total_rows_fifo = total_rows.multi_headed_transparent_fifo().filter_map(|p| p.head()); + + // Section 2.3.8 (4) Pop one element and get metadata. + let tag = (tag_fifo, control_to_tag_fifo) + .join() + .map_resolver_inner_with_p::<()>(|ip, _| { + let pop: bool = todo!("assignment 6"); + ((), if pop { 1.into_u() } else { 0.into_u() }) + }) + .filter_map::(|((head_id, tag), mesh_out_control)| { + let transfer: bool = todo!("assignment 6"); + if transfer { + Some(tag) + } else { + None + } + }); + let total_rows = (total_rows_fifo, control_to_total_rows_fifo) + .join() + .map_resolver_inner_with_p::<()>(|ip, _| { + let pop: bool = todo!("assignment 6"); + ((), if pop { 1.into_u() } else { 0.into_u() }) + }) + .filter_map::>(|((head_id, total_rows), mesh_out_control)| { + let transfer: bool = todo!("assignment 6"); + if transfer { + Some(total_rows) + } else { + None + } + }); + + // NOTE: Converting to the valid interface is safe as there are no longer any hazards. + (tag.always_into_valid(), total_rows.always_into_valid()) } -/// Helper function to invoke transposer. -/// -/// Transpose 0 to 1 matrix. -/// -/// # Arguments +/// Invokes a Transposer. /// -/// * `data` - It contains request and 3 matrices. The request contains `dataflow`, `transpose_a` and `transpose_bd`. +/// Returns three matrices, with at most one matrix is transposed. /// -/// # Returns +/// For more details, see Section 2.3.5 of the assignment documentation. /// -/// Three matrices. Either one is transposed, or none are transposed. -/// -/// # Behavior +/// # Arguments /// -/// For more details, refer to section 2.3.5 of the assignment description. -/// The figure would be pretty helpful! -fn transpose(data: Valid<(MeshReq, A, B, D)>) -> (Valid, Valid, Valid) { - // You need to attach selector bit to use `branch` combinator later. - // Selector bit for whether the matrix should be tranposed or not. - // Refer to section 2.3.5 (1). - let a_with_sel: Valid<(A, BoundedU<2>)> = todo!("attach selector bit for A matrix"); - let b_with_sel: Valid<(B, BoundedU<2>)> = todo!("attach selector bit for B matrix"); - let d_with_sel: Valid<(D, BoundedU<2>)> = todo!("attach selector bit for D matrix"); - - // Perform branch based on selector bit. 
- // Refer to section 2.3.5 (2) - let [a, a_transpose]: [Valid; 2] = todo!("branch from a_with_sel"); - let [b, b_transpose]: [Valid; 2] = todo!("branch from b_with_sel"); - let [d, d_transpose]: [Valid; 2] = todo!("branch from d_with_sel"); - - // Attach tag where transposed data come from. - // Refer to section 2.3.5 (3) - let a_transpose: Valid<(TransposeFlag, A)> = todo!("attach tag"); - let b_transpose: Valid<(TransposeFlag, B)> = todo!("attach tag"); - let d_transpose: Valid<(TransposeFlag, D)> = - todo!("attach tag. It should be reversed before going into transposer!"); - - // Get transpose_target. - // section 2.3.5 (4) - let flag: Valid = todo!("Which matrices among A, B, and D go into the Transposer?"); - // Valid doesn't mean that A should be tranposed. Actually the type of A, B, and D are the same. - let transpose_target: Valid = todo!("Select matrix among a, b, and d"); +/// - `data`: It contains request and 3 matrices. The request contains `dataflow`, `transpose_a` and `transpose_bd`. +fn transpose(data: Valid<(MeshReq, A, B, D)>) -> Valid<(A, B, D)> { + // Section 2.3.5 (1) Attach selector. + let (a_with_sel, b_with_sel, d_with_sel): ( + Valid<(A, BoundedU<2>)>, + Valid<(B, BoundedU<2>)>, + Valid<(D, BoundedU<2>)>, + ) = todo!("assignment 6"); + + // Section 2.3.5 (2) Branch interface. + let [a, a_transpose] = a_with_sel.branch(); + let [b, b_transpose] = b_with_sel.branch(); + let [d, d_transpose] = d_with_sel.branch(); + + let a_transpose = a_transpose.map(|p| (TransposeFlag::A, p)); + let b_transpose = b_transpose.map(|p| (TransposeFlag::B, p)); + let d_transpose = d_transpose.map(|p| (TransposeFlag::D, p.reverse())); + + // Section 2.3.5 (3) Select transpose target. + // NOTE: `Valid` does not mean that `A` should be transposed, actually the types `A`, `B`, and `D` are the same. + let (flag, transpose_target): (Valid, Valid) = todo!("assignment 6"); let transposed = transpose_target.map(|vec| vec.concat()).comb(transposer_ffi); - // Identify which matrix is transposed among A, B, or D. - // section 2.3.5 (5) + // Section 2.3.5 (4) Identify which matrix is transposed among A, B, or D. let [a_transposed, b_transposed, d_transposed]: [Valid; 3] = - (flag, transposed).join_valid().map(|(flag, arr)| todo!("which matrix?")).branch(); + (flag, transposed).join_valid().map(|(flag, arr)| todo!("assignment 6")).branch(); - // Select one among `a` and `a_transposed` - // Section 2.3.5 (6) - let a: Valid = todo!("Select one among `a` and `a_transposed`"); - let b: Valid = todo!("Select one among `b` and `b_transposed`"); - let d: Valid = todo!("Select one among `d` and `d_transposed`. It should be reversed before going out."); + // Section 2.3.5 (5) Select one among `X` and `X_transposed`. + let a = [a_transposed, a].merge(); + let b = [b_transposed, b].merge(); + let d = [d_transposed.map(|p| p.reverse()), d].merge(); - (a, b, d) + (a, b, d).join_valid() } /// Mesh with delays. pub fn mesh_with_delays( - data: Vr>, + data: Vr<(A, B, D)>, req: I, { Dep::Helpful }>, ) -> Valid { - // Section 2.3.1 update configurations. - // Update configuration and fork interfaces. - let req: I, { Dep::Helpful }> = todo!("update configuration. from req"); + // Section 2.3.1 Update Configurations. + let req: I, { Dep::Helpful }> = todo!("assignment 6"); let (mesh_req, fifo_req) = req.map_resolver_inner::<((), TagsInProgress)>(|(_, tags)| tags).lfork(); - // Section 2.3.2 manage request buffer - let mesh_req: Vr<(ReqExtended, bool)> = todo!("Request buffer. 
from mesh_req"); - - // Section 2.3.3 perform a branch with either flush_req or matmul_req - // It could be (flush_req, matmul_req) - let [flush_req, matmul_req]: [Vr<(ReqExtended, bool)>; 2] = todo!("Branch request. from mesh_req"); - - let (matmul_req, req_and_data) = (matmul_req, data).join_vr().always_into_valid().lfork(); - - // Section 2.3.4 merge `flush_req` and `matmul_req`. - // Only one request between `flush_req` and `matmul_req` can be transferred to mesh. - let matmul_req = matmul_req.map(|(req, _)| req); - let flush_req = flush_req.always_into_valid(); - let req: Valid<(MeshReq, Config, bool)> = todo!("Merged Req"); - - // Section 2.3.5 Invoke transposer. - let (a, b, d): (Valid, Valid, Valid) = - todo!("Use matmul_req_and_data. Get these Valid interfaces from `transpose` helper function."); - - // Preprocessing for mesh input such as, - // Applying ShiftRegister and interface type conversion - let (in_left, in_top) = (a, b, d, req).comb(mesh_i).comb(shift_i); - - let mesh_output = mesh_ffi(in_left, in_top); - - // Preprocessing for mesh output such as, - // Applying ShiftRegister and interface type conversion - let (output_data, output_config) = mesh_output.comb(shift_o).comb(mesh_o); - let (fifo_control, last) = output_config.lfork(); - - // FIFO - let metadata = fifos(fifo_req, fifo_control); - - let output = (output_data, last.map(|p| p.last)).join_valid(); - - // Refer to the section 2.3.9 - // Return `MeshResp` - todo!("Use `metadata` and `output.`") + // Section 2.3.2 Request Buffer. + let mesh_req: Vr<(ReqExtended, bool)> = mesh_req.fsm_egress::<(ReqExtended, bool), U<{ clog2(BLOCK_SIZE) }>>( + U::default(), + true, + true, + |req_ext, counter| todo!("assignment 6"), + ); + + // Section 2.3.3 Branch Request. + let [mesh_req_flush, mesh_req_matmul]: [Vr<(ReqExtended, bool)>; 2] = todo!("assignment 6"); + + // NOTE: Converting to the valid interface is safe as there are no longer any hazards. + let mesh_req_flush = mesh_req_flush.always_into_valid(); + let (mesh_req_matmul, mesh_data) = (mesh_req_matmul, data) + .join_vr() + .always_into_valid() + .map(|((req_ext, last), (a, b, d))| { + let matmul_req = (req_ext, last); + let mesh_data = (req_ext.req, a, b, d); + (matmul_req, mesh_data) + }) + .unzip(); + + // Section 2.3.4 Merging Requests. + let mesh_req: Valid<(ReqExtended, bool)> = todo!("assignment 6"); + + // Section 2.3.5 Invoke a Transposer. + let mesh_data_transposed = mesh_data.comb(transpose); + + // Section 2.3.6 Mesh IO type conversion + Section 2.3.7 Shift. + let mesh_out = (mesh_data_transposed, mesh_req) + .comb(preprocess_type) + .comb(preprocess_shift) + .comb(move |(in_row, in_col)| mesh_ffi(in_row, in_col)) + .comb(postprocess_shift) + .comb(postprocess_type); + + let (mesh_out, mesh_out_control) = mesh_out.map(|p| (p, p.1)).unzip(); + + // Section 2.3.8 FIFO. + let (tag, total_rows) = fifos(fifo_req, mesh_out_control); + + // Section 2.3.9 Return Mesh Response. + (tag, total_rows, mesh_out).zip_any_valid().filter_map(|(tag, total_rows, mesh_out)| todo!("assignment 6")) } /// Mesh with delays. 
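For reference, the boolean `propagate_flip` in the new `MeshReq` replaces the explicit `PeControl::propagate` field at the call sites: the PEs still consume a `Propagate` value (see `preprocess_type` below, which passes `config.propagate` through unchanged), and performing the flip is what `Config::update` is left to implement via `todo!("assignment 6")`. A minimal sketch of the toggle, not part of the patch, assuming `Propagate` has exactly the two variants used elsewhere in this diff (`Reg1`, `Reg2`); the helper name `toggle_propagate` is hypothetical:

fn toggle_propagate(p: Propagate) -> Propagate {
    // Swap the double-buffered propagate register selection.
    match p {
        Propagate::Reg1 => Propagate::Reg2,
        Propagate::Reg2 => Propagate::Reg1,
    }
}

With a helper of this shape, the `if prop { Propagate::Reg2 } else { Propagate::Reg1 }` logic removed from the two call sites in `mod.rs` lives in the mesh-side configuration instead of being duplicated by every caller.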
@@ -346,6 +310,6 @@ pub fn mesh_with_delays_default( d: Vr, req: I, { Dep::Helpful }>, ) -> Valid { - let data = [a, b, d].join_vr(); + let data = (a, b, d).join_vr(); mesh_with_delays(data, req) } diff --git a/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs b/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs index 982d394..fafd8a6 100644 --- a/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs +++ b/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs @@ -5,7 +5,7 @@ use super::*; /// Bit width of the register type. -const ACC_BITS: usize = 32; +pub const ACC_BITS: usize = 32; /// PE row data signals. #[derive(Debug, Clone, Copy)] diff --git a/hazardflow-designs/src/gemmini/execute/systolic_array/utils.rs b/hazardflow-designs/src/gemmini/execute/systolic_array/utils.rs index aeb80e8..95ab909 100644 --- a/hazardflow-designs/src/gemmini/execute/systolic_array/utils.rs +++ b/hazardflow-designs/src/gemmini/execute/systolic_array/utils.rs @@ -9,71 +9,65 @@ use super::*; macro_rules! shift_reg { ($first: ident, $( $x:ident ),*) => {{ [ [$first], $( - [ $x.shift_reg_fwd::<{${index()} + 1}>() ] + [ $x.shift_reg_fwd::<{ ${index()} + 1 }>() ] ), *] }}; (($fx:ident, $fy: ident), $(($x:ident, $y:ident)), *) => {{ [ [($fx, $fy)], $( - [ ($x.shift_reg_fwd::<{${index()} + 1}>(), $y.shift_reg_fwd::<{${index()} + 1}>()) ] + [ ($x.shift_reg_fwd::<{ ${index()} + 1 }>(), $y.shift_reg_fwd::<{ ${index()} + 1 }>()) ] ), *] }}; } -macro_rules! shift_reg_reverse { +macro_rules! shift_reg_rev { ($($x:ident),* ; $last:ident) => {{ - [ $( [ $x.shift_reg_fwd::<{TOTAL_ROWS - 1 - ${index()}}>() ] + [ $( [ $x.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>() ] ),*, [ $last ] ] }}; ($(($x:ident, $y:ident)),* ; ($lx:ident, $ly:ident)) => {{ - [ $( [ ($x.shift_reg_fwd::<{TOTAL_ROWS - 1 - ${index()}}>(), $y.shift_reg_fwd::<{TOTAL_ROWS - 1 - ${index()}}>()) ] + [ $( [ ($x.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>(), $y.shift_reg_fwd::<{ TOTAL_ROWS - 1 - ${index()} }>()) ] ),*, [ ($lx, $ly) ] ] }}; } /// Shift input interface. 
-pub fn shift_i((in_left, in_top): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) { - let [[in_left0], [in_left1], [in_left2], [in_left3], [in_left4], [in_left5], [in_left6], [in_left7], [in_left8], [in_left9], [in_left10], [in_left11], [in_left12], [in_left13], [in_left14], [in_left15]] = - in_left; - let [[(t0d, t0c)], [(t1d, t1c)], [(t2d, t2c)], [(t3d, t3c)], [(t4d, t4c)], [(t5d, t5c)], [(t6d, t6c)], [(t7d, t7c)], [(t8d, t8c)], [(t9d, t9c)], [(t10d, t10c)], [(t11d, t11c)], [(t12d, t12c)], [(t13d, t13c)], [(t14d, t14c)], [(t15d, t15c)]] = - in_top; +pub fn preprocess_shift((in_row, in_col): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) { + let [[r0], [r1], [r2], [r3], [r4], [r5], [r6], [r7], [r8], [r9], [r10], [r11], [r12], [r13], [r14], [r15]] = in_row; + let [[(c0d, c0c)], [(c1d, c1c)], [(c2d, c2c)], [(c3d, c3c)], [(c4d, c4c)], [(c5d, c5c)], [(c6d, c6c)], [(c7d, c7c)], [(c8d, c8c)], [(c9d, c9c)], [(c10d, c10c)], [(c11d, c11c)], [(c12d, c12c)], [(c13d, c13c)], [(c14d, c14c)], [(c15d, c15c)]] = + in_col; ( + shift_reg!(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15), shift_reg!( - in_left0, in_left1, in_left2, in_left3, in_left4, in_left5, in_left6, in_left7, in_left8, in_left9, - in_left10, in_left11, in_left12, in_left13, in_left14, in_left15 - ), - shift_reg!( - (t0d, t0c), - (t1d, t1c), - (t2d, t2c), - (t3d, t3c), - (t4d, t4c), - (t5d, t5c), - (t6d, t6c), - (t7d, t7c), - (t8d, t8c), - (t9d, t9c), - (t10d, t10c), - (t11d, t11c), - (t12d, t12c), - (t13d, t13c), - (t14d, t14c), - (t15d, t15c) + (c0d, c0c), + (c1d, c1c), + (c2d, c2c), + (c3d, c3c), + (c4d, c4c), + (c5d, c5c), + (c6d, c6c), + (c7d, c7c), + (c8d, c8c), + (c9d, c9c), + (c10d, c10c), + (c11d, c11c), + (c12d, c12c), + (c13d, c13c), + (c14d, c14c), + (c15d, c15c) ), ) } /// Shift output interface. -pub fn shift_o((row_output, col_output): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) { - let [[row0], [row1], [row2], [row3], [row4], [row5], [row6], [row7], [row8], [row9], [row10], [row11], [row12], [row13], [row14], [row15]] = - row_output; +pub fn postprocess_shift((out_row, out_col): (MeshRowData, MeshColData)) -> (MeshRowData, MeshColData) { + let [[r0], [r1], [r2], [r3], [r4], [r5], [r6], [r7], [r8], [r9], [r10], [r11], [r12], [r13], [r14], [r15]] = + out_row; let [[(c0d, c0c)], [(c1d, c1c)], [(c2d, c2c)], [(c3d, c3c)], [(c4d, c4c)], [(c5d, c5c)], [(c6d, c6c)], [(c7d, c7c)], [(c8d, c8c)], [(c9d, c9c)], [(c10d, c10c)], [(c11d, c11c)], [(c12d, c12c)], [(c13d, c13c)], [(c14d, c14c)], [(c15d, c15c)]] = - col_output; + out_col; ( - shift_reg_reverse!( - row0, row1, row2, row3, row4, row5, row6, row7, row8, row9, row10, row11, row12, row13, row14; row15 - ), - shift_reg_reverse!( + shift_reg_rev!(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14; r15), + shift_reg_rev!( (c0d, c0c), (c1d, c1c), (c2d, c2c), @@ -96,17 +90,19 @@ pub fn shift_o((row_output, col_output): (MeshRowData, MeshColData)) -> (MeshRow /// Interface type conversion. #[allow(clippy::type_complexity)] -pub fn mesh_i( - (a, b, d, req): (Valid, Valid, Valid, Valid<(MeshReq, Config, bool)>), -) -> (MeshRowData, MeshColData) { +pub fn preprocess_type((data, req): (Valid<(A, B, D)>, Valid<(ReqExtended, bool)>)) -> (MeshRowData, MeshColData) { // # Safety // // All the input and output interfaces are `Valid` type. 
unsafe { - (a, b, d, req).fsm::<(MeshRowData, MeshColData), ()>((), |(a_in, b_in, d_in, req_in), _, ()| { + (data, req).fsm::<(MeshRowData, MeshColData), ()>((), |(data_in, req_in), _, ()| { let default_row = None::.repeat::<1>().repeat::(); let default_col = (None::, None::).repeat::<1>().repeat::(); + let a_in = data_in.map(|p| p.0); + let b_in = data_in.map(|p| p.1); + let d_in = data_in.map(|p| p.2); + let col_in = match (b_in.zip(d_in), req_in) { (Some(bd), Some(req)) => Some((Some(bd), Some(req))), (None, Some(req)) => Some((None, Some(req))), @@ -114,8 +110,8 @@ pub fn mesh_i( }; let in_left = col_in.map_or(default_row, |_| { - if let Some(a_in) = a_in { - a_in.map(|tile_v| tile_v.map(|v| Some(PeRowData { a: v }))) + if let Some(mesh_row) = a_in { + mesh_row.map(|tile_row| tile_row.map(|a| Some(PeRowData { a }))) } else { range::().map(|_| Some(PeRowData { a: S::from(0.into_u::()) }).repeat::<1>()) } @@ -123,13 +119,9 @@ pub fn mesh_i( let in_top = col_in.map_or(default_col, |(bd, mesh_req)| { // Reqeust is always valid due to the match statement above. - let (bd, (req, config, last_fire)) = mesh_req.map(|req| (bd, req)).unwrap(); - let pe_control = PeControl { - dataflow: req.pe_control.dataflow, - propagate: if config.in_prop { Propagate::Reg1 } else { Propagate::Reg2 }, - shift: req.pe_control.shift, - }; - let column_control = Some(PeColControl { control: pe_control, id: config.matmul_id, last: last_fire }); + let (bd, (ReqExtended { req, config }, last)) = mesh_req.map(|req| (bd, req)).unwrap(); + let pe_control = PeControl { dataflow: req.dataflow, propagate: config.propagate, shift: req.shift }; + let column_control = Some(PeColControl { control: pe_control, id: config.matmul_id, last }); if let Some((b, d)) = bd { b.zip(d).map(|(b, d)| { @@ -148,20 +140,20 @@ pub fn mesh_i( } }); - ((in_left, in_top), ((), (), (), ()), ()) + ((in_left, in_top), ((), ()), ()) }) } } /// Interface type conversion. -pub fn mesh_o( - (row_output, col_output): (MeshRowData, MeshColData), -) -> (Valid, MESH_COLS>>, Valid) { +pub fn postprocess_type( + (out_row, out_col): (MeshRowData, MeshColData), +) -> Valid<(Array, MESH_COLS>, PeColControl)> { // # Safety // // All the input and output interfaces are `Valid` type. unsafe { - (row_output, col_output).fsm::<(Valid, MESH_COLS>>, Valid), ()>( + (out_row, out_col).fsm::, MESH_COLS>, PeColControl)>, ()>( (), |(_, col_data), _, ()| { let out_valid = col_data[0][0].0.is_some(); @@ -176,12 +168,11 @@ pub fn mesh_o( let matmul_result = if dataflow_os { out_c } else { out_b }; - let matmul_result = if out_valid { Some(matmul_result) } else { None }; - let output_control = if out_valid { col_data[0][0].1 } else { None }; - + let ep = if out_valid { Some((matmul_result, col_data[0][0].1.unwrap())) } else { None }; let ir0 = ().repeat::<1>().repeat::(); let ir1 = ((), ()).repeat::<1>().repeat::(); - ((matmul_result, output_control), (ir0, ir1), ()) + + (ep, (ir0, ir1), ()) }, ) } diff --git a/hazardflow-designs/src/std/combinators/fifo.rs b/hazardflow-designs/src/std/combinators/fifo.rs index 29533cc..90cf123 100644 --- a/hazardflow-designs/src/std/combinators/fifo.rs +++ b/hazardflow-designs/src/std/combinators/fifo.rs @@ -35,6 +35,15 @@ where [(); clog2(N + 1)]:, [(); clog2(N) + 1]:, { + /// Returns the head of the FIFO. + pub fn head(self) -> HOption
<P>
{ + if self.len == 0.into_u() { + None + } else { + Some(self.inner[self.raddr]) + } + } + /// Returns inner elements with valid bit. pub fn inner_with_valid(self) -> Array, N> { range::().map(|i| { diff --git a/hazardflow-designs/src/std/combinators/unzip.rs b/hazardflow-designs/src/std/combinators/unzip.rs index 3805f76..1a2a8c4 100644 --- a/hazardflow-designs/src/std/combinators/unzip.rs +++ b/hazardflow-designs/src/std/combinators/unzip.rs @@ -146,18 +146,14 @@ impl I` | `(HOption, HOption)` | /// | **Bwd** | `Ready<(R1, R2)>` | `(Ready, Ready)` | #[allow(clippy::type_complexity)] - pub fn unzip(self) -> (I, { Dep::Demanding }>, I, { Dep::Demanding }>) { + pub fn unzip(self) -> (I, D>, I, D>) { unsafe { - Interface::fsm(self, (), |ip, er: (Ready, Ready), ()| { - let ready = er.0.ready && er.1.ready; - let ep = if ready && ip.is_some() { - let (p1, p2) = ip.unwrap(); - (Some(p1), Some(p2)) - } else { - (None, None) - }; - let ir = Ready::new(ready, (er.0.inner, er.1.inner)); - (ep, ir, ()) + Interface::fsm::<(I, D>, I, D>), ()>(self, (), |ip, (er1, er2), ()| { + let ep1 = if er2.ready { ip.map(|(p, _)| p) } else { None }; + let ep2 = if er1.ready { ip.map(|(_, p)| p) } else { None }; + let ir = Ready::new(er1.ready && er2.ready, (er1.inner, er2.inner)); + + ((ep1, ep2), ir, ()) }) } } @@ -238,7 +234,7 @@ impl Vr<(P1, P2), D> { /// | :-------: | ------------------- | ---------------------------- | /// | **Fwd** | `HOption<(P1, P2)>` | `(HOption, HOption)` | /// | **Bwd** | `Ready<()>` | `(Ready<()>, Ready<()>)` | - pub fn unzip(self) -> (Vr, Vr) { + pub fn unzip(self) -> (Vr, Vr) { self.map_resolver::<((), ())>(|_| ()).unzip() } } diff --git a/scripts/gemmini/unit_tests/mesh_with_delays/test_mesh_with_delays.py b/scripts/gemmini/unit_tests/mesh_with_delays/test_mesh_with_delays.py index 0cb6f9a..96ac09c 100644 --- a/scripts/gemmini/unit_tests/mesh_with_delays/test_mesh_with_delays.py +++ b/scripts/gemmini/unit_tests/mesh_with_delays/test_mesh_with_delays.py @@ -75,9 +75,9 @@ def decompose_data(data: BinaryValue, width): "InpCtrl", signals=[ "payload_discriminant", - "payload_Some_0_pe_control_dataflow_discriminant", - "payload_Some_0_pe_control_propagate_discriminant", - "payload_Some_0_pe_control_shift", + "payload_Some_0_dataflow_discriminant", + "payload_Some_0_propagate_flip", + "payload_Some_0_shift", "payload_Some_0_transpose_a", "payload_Some_0_transpose_bd", "payload_Some_0_total_rows", @@ -104,10 +104,10 @@ def os_flush_request(propagate, shift): return InpCtrlTransaction( payload_Some_0_transpose_a=False, payload_Some_0_transpose_bd=False, - payload_Some_0_flush=1, - payload_Some_0_pe_control_dataflow_discriminant=OS, - payload_Some_0_pe_control_propagate_discriminant=propagate, - payload_Some_0_pe_control_shift=shift, + payload_Some_0_flush=True, + payload_Some_0_dataflow_discriminant=OS, + payload_Some_0_propagate_flip=propagate, + payload_Some_0_shift=shift, payload_Some_0_tag_addr_accumulate=False, payload_Some_0_tag_addr_data=0, payload_Some_0_tag_addr_garbage=False, @@ -127,10 +127,10 @@ def req_with_none_rob_id(mode, transpose_a, transpose_bd, propagate): return InpCtrlTransaction( payload_Some_0_transpose_a=transpose_a, payload_Some_0_transpose_bd=transpose_bd, - payload_Some_0_flush=0, - payload_Some_0_pe_control_dataflow_discriminant=mode, - payload_Some_0_pe_control_propagate_discriminant=propagate, - payload_Some_0_pe_control_shift=0, + payload_Some_0_flush=False, + payload_Some_0_dataflow_discriminant=mode, + payload_Some_0_propagate_flip=propagate, + 
payload_Some_0_shift=0, payload_Some_0_total_rows=16, payload_Some_0_tag_rob_id_discriminant=False, payload_Some_0_tag_rob_id_Some_0=0, @@ -148,9 +148,9 @@ def req_with_none_rob_id(mode, transpose_a, transpose_bd, propagate): def req_with_rob_id(mode, transpose_a, transpose_bd, propagate): return InpCtrlTransaction( - payload_Some_0_pe_control_dataflow_discriminant=mode, - payload_Some_0_pe_control_propagate_discriminant=propagate, - payload_Some_0_pe_control_shift=0, + payload_Some_0_dataflow_discriminant=mode, + payload_Some_0_propagate_flip=propagate, + payload_Some_0_shift=0, payload_Some_0_transpose_a=transpose_a, payload_Some_0_transpose_bd=transpose_bd, payload_Some_0_total_rows=16, @@ -165,7 +165,7 @@ def req_with_rob_id(mode, transpose_a, transpose_bd, propagate): payload_Some_0_tag_addr_data=0, payload_Some_0_tag_rows=16, payload_Some_0_tag_cols=16, - payload_Some_0_flush=0, + payload_Some_0_flush=False, ) @@ -458,11 +458,7 @@ async def ws_transpose_b(dut): await tb.in_b_data_req.send(InpDataTransaction(payload_Some_0_0=0)) # When running WS dataflow, weight should be sent in reverse order # due to the way the weight is preloaded in the PEs. - await tb.in_d_data_req.send( - InpDataTransaction( - payload_Some_0_0=concatenate_data(weight[15 - i], 8), - ) - ) + await tb.in_d_data_req.send(InpDataTransaction(payload_Some_0_0=concatenate_data(weight[15 - i], 8))) # 2. Wait until the data is loaded await tb.in_ctrl_req.send(req_with_rob_id(WS, False, True, REG2)) @@ -474,12 +470,8 @@ async def ws_transpose_b(dut): # 3. Send activation and bias data await tb.in_ctrl_req.send(req_with_none_rob_id(WS, False, True, REG1)) for i in range(16): - await tb.in_a_data_req.send( - InpDataTransaction(payload_Some_0_0=concatenate_data(activation[i], 8)) - ) - await tb.in_b_data_req.send( - InpDataTransaction(payload_Some_0_0=concatenate_data(bias[i], 8)) - ) + await tb.in_a_data_req.send(InpDataTransaction(payload_Some_0_0=concatenate_data(activation[i], 8))) + await tb.in_b_data_req.send(InpDataTransaction(payload_Some_0_0=concatenate_data(bias[i], 8))) await tb.in_d_data_req.send(InpDataTransaction(payload_Some_0_0=0)) output_data = [] @@ -538,21 +530,9 @@ async def os_no_transpose(dut): # 1. Preload bias data await tb.in_ctrl_req.send(req_with_rob_id(OS, False, False, REG2)) for i in range(16): - await tb.in_a_data_req.send( - InpDataTransaction( - payload_Some_0_0=concatenate_data(activation[i], 8), - ) - ) - await tb.in_b_data_req.send( - InpDataTransaction( - payload_Some_0_0=0, - ) - ) - await tb.in_d_data_req.send( - InpDataTransaction( - payload_Some_0_0=concatenate_data(bias[15 - i], 8), - ) - ) + await tb.in_a_data_req.send(InpDataTransaction(payload_Some_0_0=concatenate_data(activation[i], 8))) + await tb.in_b_data_req.send(InpDataTransaction(payload_Some_0_0=0)) + await tb.in_d_data_req.send(InpDataTransaction(payload_Some_0_0=concatenate_data(bias[15 - i], 8))) await tb.in_ctrl_req.send(req_with_none_rob_id(OS, False, False, REG1)) for i in range(16):