From ef81ae5b3cf3e7c8f07cc8318ce9be772c15f182 Mon Sep 17 00:00:00 2001 From: Parth Sarkar <93054509+parthsarkar17@users.noreply.github.com> Date: Sun, 9 Jun 2024 23:51:08 -0400 Subject: [PATCH] new dynamic ohe branch due to weird git stuff (#2104) * new dynamic ohe branch due to weird git stuff * created one-hot-query factorization * formatting fix * tracked down transition guard issue --- .../src/passes/top_down_compile_control.rs | 272 ++++++++++++++---- runt.toml | 21 ++ tests/passes/tdcc-ohe/par.expect | 73 +++++ tests/passes/tdcc-ohe/par.futil | 40 +++ 4 files changed, 358 insertions(+), 48 deletions(-) create mode 100644 tests/passes/tdcc-ohe/par.expect create mode 100644 tests/passes/tdcc-ohe/par.futil diff --git a/calyx-opt/src/passes/top_down_compile_control.rs b/calyx-opt/src/passes/top_down_compile_control.rs index 872604233a..453993b16e 100644 --- a/calyx-opt/src/passes/top_down_compile_control.rs +++ b/calyx-opt/src/passes/top_down_compile_control.rs @@ -209,6 +209,11 @@ fn compute_unique_ids(con: &mut ir::Control, cur_state: u64) -> u64 { } } +enum Encoding { + Binary, + OneHot, +} + /// Represents the dyanmic execution schedule of a control program. struct Schedule<'b, 'a: 'b> { /// A mapping from groups to corresponding FSM state ids @@ -325,15 +330,55 @@ impl<'b, 'a> Schedule<'b, 'a> { }); } + /// Queries the FSM by building a new slicer and corresponding assignments if + /// the query hasn't yet been made. If this query has been made before, it + /// reuses the old query. Returns a new guard representing the query. + fn build_one_hot_query( + builder: &mut ir::Builder, + used_slicers: &mut HashMap>, + fsm: &ir::RRC, + signal_on: &ir::RRC, + state: &u64, + fsm_size: &u64, + ) -> ir::Guard { + match used_slicers.get(state) { + None => { + // construct slicer for this bit query + structure!( + builder; + let slicer = prim std_bit_slice(*fsm_size, *state, *state, 1); + ); + // build wire from fsm to slicer + let fsm_to_slicer = builder.build_assignment( + slicer.borrow().get("in"), + fsm.borrow().get("out"), + ir::Guard::True, + ); + // add continuous assignments to slicer + builder.component.continuous_assignments.push(fsm_to_slicer); + // create a guard representing when to allow next-state transition + let state_guard = guard!(slicer["out"] == signal_on["out"]); + used_slicers.insert(*state, slicer); + state_guard + } + Some(slicer) => { + let state_guard = guard!(slicer["out"] == signal_on["out"]); + state_guard + } + } + } + /// Implement a given [Schedule] and return the name of the [ir::Group] that /// implements it. fn realize_schedule( self, dump_fsm: bool, fsm_groups: &mut HashSet, + one_hot_cutoff: u64, ) -> RRC { self.validate(); + // build tdcc group let group = self.builder.add_group("tdcc"); if dump_fsm { self.display(format!( @@ -343,16 +388,40 @@ impl<'b, 'a> Schedule<'b, 'a> { )); } + // calculate fsm size and encoding let final_state = self.last_state(); - let fsm_size = get_bit_width_from( - final_state + 1, /* represent 0..final_state */ - ); - structure!(self.builder; - let fsm = prim std_reg(fsm_size); - let signal_on = constant(1, 1); - let last_state = constant(final_state, fsm_size); - let first_state = constant(0, fsm_size); - ); + let encoding = if final_state <= one_hot_cutoff { + Encoding::OneHot + } else { + Encoding::Binary + }; + + // build necessary primitives dependent on encoding + let signal_on = self.builder.add_constant(1, 1); + let (fsm, first_state, last_state_opt, fsm_size) = match encoding { + Encoding::Binary => { + let fsm_size = get_bit_width_from( + final_state + 1, /* represent 0..final_state */ + ); + structure!(self.builder; + let fsm = prim std_reg(fsm_size); + let last_state = constant(final_state, fsm_size); + let first_state = constant(0, fsm_size); + ); + (fsm, first_state, Some(last_state), fsm_size) + } + Encoding::OneHot => { + let fsm_size = final_state + 1; /* represent 0..final_state */ + + let fsm = self.builder.add_primitive( + "fsm", + "init_one_reg", + &[fsm_size], + ); + let first_state = self.builder.add_constant(1, fsm_size); + (fsm, first_state, None, fsm_size) + } + }; // Add last state to JSON info let mut states = self.groups_to_states.iter().cloned().collect_vec(); @@ -369,33 +438,78 @@ impl<'b, 'a> Schedule<'b, 'a> { states, })); - // Enable assignments + // keep track of used slicers if using one hot encoding + let mut used_slicers = HashMap::new(); + + // enable assignments group.borrow_mut().assignments.extend( self.enables .into_iter() .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) - .flat_map(|(state, mut assigns)| { - let state_const = - self.builder.add_constant(state, fsm_size); - let state_guard = guard!(fsm["out"] == state_const["out"]); - assigns.iter_mut().for_each(|asgn| { - asgn.guard.update(|g| g.and(state_guard.clone())) - }); - assigns + .flat_map(|(state, mut assigns)| match encoding { + Encoding::Binary => { + let state_const = + self.builder.add_constant(state, fsm_size); + let state_guard = + guard!(fsm["out"] == state_const["out"]); + assigns.iter_mut().for_each(|asgn| { + asgn.guard.update(|g| g.and(state_guard.clone())) + }); + assigns + } + Encoding::OneHot => { + let state_guard = Self::build_one_hot_query( + self.builder, + &mut used_slicers, + &fsm, + &signal_on, + &state, + &fsm_size, + ); + assigns.iter_mut().for_each(|asgn| { + asgn.guard.update(|g| g.and(state_guard.clone())) + }); + assigns + } }), ); - // Transition assignments + // transition assignments group.borrow_mut().assignments.extend( self.transitions.into_iter().flat_map(|(s, e, guard)| { - structure!(self.builder; - let end_const = constant(e, fsm_size); - let start_const = constant(s, fsm_size); - ); + let (end_const, trans_guard) = match encoding { + Encoding::Binary => { + structure!(self.builder; + let end_const = constant(e, fsm_size); + let start_const = constant(s, fsm_size); + ); + let trans_guard = + guard!((fsm["out"] == start_const["out"]) & guard); + + (end_const, trans_guard) + } + Encoding::OneHot => { + let end_constant_value = u64::pow( + 2, + e.try_into().expect("failed to convert to u32"), + ); + + let trans_guard = Self::build_one_hot_query( + self.builder, + &mut used_slicers, + &fsm, + &signal_on, + &s, + &fsm_size, + ); + let end_const = self + .builder + .add_constant(end_constant_value, fsm_size); + + (end_const, trans_guard.and(guard)) + } + }; let ec_borrow = end_const.borrow(); - let trans_guard = - guard!((fsm["out"] == start_const["out"]) & guard); - vec![ self.builder.build_assignment( fsm.borrow().get("in"), @@ -411,20 +525,54 @@ impl<'b, 'a> Schedule<'b, 'a> { }), ); - // Done condition for group - let last_guard = guard!(fsm["out"] == last_state["out"]); - let done_assign = self.builder.build_assignment( - group.borrow().get("done"), - signal_on.borrow().get("out"), - last_guard.clone(), - ); - group.borrow_mut().assignments.push(done_assign); + // done condition for group + let reset_fsm = match last_state_opt { + // binary branch; only binary needs last state constant + Some(last_state) => { + let last_guard = guard!(fsm["out"] == last_state["out"]); + let done_assign = self.builder.build_assignment( + group.borrow().get("done"), + signal_on.borrow().get("out"), + last_guard.clone(), + ); + group.borrow_mut().assignments.push(done_assign); - // Cleanup: Add a transition from last state to the first state. - let reset_fsm = build_assignments!(self.builder; - fsm["in"] = last_guard ? first_state["out"]; - fsm["write_en"] = last_guard ? signal_on["out"]; - ); + // Cleanup: Add a transition from last state to the first state. + let reset_fsm = build_assignments!(self.builder; + fsm["in"] = last_guard ? first_state["out"]; + fsm["write_en"] = last_guard ? signal_on["out"]; + ); + + reset_fsm.to_vec() + } + + // ohe branch does not need last state constant + None => { + let last_guard = Self::build_one_hot_query( + self.builder, + &mut used_slicers, + &fsm, + &signal_on, + &final_state, + &fsm_size, + ); + let done_assign = self.builder.build_assignment( + group.borrow().get("done"), + signal_on.borrow().get("out"), + last_guard.clone(), + ); + group.borrow_mut().assignments.push(done_assign); + // Cleanup: Add a transition from last state to the first state. + let reset_fsm = build_assignments!(self.builder; + fsm["in"] = last_guard ? first_state["out"]; + fsm["write_en"] = last_guard ? signal_on["out"]; + ); + + reset_fsm.to_vec() + } + }; + + // extend with conditions to set fsm to initial state self.builder .component .continuous_assignments @@ -846,6 +994,9 @@ pub struct TopDownCompileControl { early_transitions: bool, /// Bookkeeping for FSM ids for groups across all FSMs in the program fsm_groups: HashSet, + /// How many states the dynamic FSM must have before we pick binary encoding over + /// one-hot + one_hot_cutoff: u64, } impl ConstructVisitor for TopDownCompileControl { @@ -860,6 +1011,9 @@ impl ConstructVisitor for TopDownCompileControl { dump_fsm_json: opts[&"dump-fsm-json"].not_null_outstream(), early_transitions: opts[&"early-transitions"].bool(), fsm_groups: HashSet::new(), + one_hot_cutoff: opts[&"one-hot-cutoff"] + .pos_num() + .expect("requires non-negative OHE cutoff parameter"), }) } @@ -897,6 +1051,12 @@ impl Named for TopDownCompileControl { ParseVal::Bool(false), PassOpt::parse_bool, ), + PassOpt::new( + "one-hot-cutoff", + "The threshold at and below which a one-hot encoding is used for dynamic group scheduling", + ParseVal::Num(0), + PassOpt::parse_num, + ), ] } } @@ -957,8 +1117,11 @@ impl Visitor for TopDownCompileControl { let mut sch = Schedule::from(&mut builder); sch.calculate_states_seq(s, self.early_transitions)?; // Compile schedule and return the group. - let seq_group = - sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups); + let seq_group = sch.realize_schedule( + self.dump_fsm, + &mut self.fsm_groups, + self.one_hot_cutoff, + ); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(seq_group); @@ -984,8 +1147,11 @@ impl Visitor for TopDownCompileControl { // Compile schedule and return the group. sch.calculate_states_if(i, self.early_transitions)?; - let if_group = - sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups); + let if_group = sch.realize_schedule( + self.dump_fsm, + &mut self.fsm_groups, + self.one_hot_cutoff, + ); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(if_group); @@ -1011,8 +1177,11 @@ impl Visitor for TopDownCompileControl { sch.calculate_states_while(w, self.early_transitions)?; // Compile schedule and return the group. - let if_group = - sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups); + let if_group = sch.realize_schedule( + self.dump_fsm, + &mut self.fsm_groups, + self.one_hot_cutoff, + ); // Add NODE_ID to compiled group. let mut en = ir::Control::enable(if_group); @@ -1060,7 +1229,11 @@ impl Visitor for TopDownCompileControl { _ => { let mut sch = Schedule::from(&mut builder); sch.calculate_states(con, self.early_transitions)?; - sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups) + sch.realize_schedule( + self.dump_fsm, + &mut self.fsm_groups, + self.one_hot_cutoff, + ) } }; @@ -1131,8 +1304,11 @@ impl Visitor for TopDownCompileControl { let mut sch = Schedule::from(&mut builder); // Add assignments for the final states sch.calculate_states(&control.borrow(), self.early_transitions)?; - let comp_group = - sch.realize_schedule(self.dump_fsm, &mut self.fsm_groups); + let comp_group = sch.realize_schedule( + self.dump_fsm, + &mut self.fsm_groups, + self.one_hot_cutoff, + ); if let Some(json_out_file) = &self.dump_fsm_json { let _ = serde_json::to_writer_pretty( json_out_file.get_write(), diff --git a/runt.toml b/runt.toml index a5b2622e18..d66bdebb00 100644 --- a/runt.toml +++ b/runt.toml @@ -212,6 +212,27 @@ fud exec --from calyx --to jq \ """ timeout = 120 +[[tests]] +name = "correctness dynamic one-hot encoding" +paths = [ + "tests/correctness/*.futil", + "tests/correctness/ref-cells/*.futil", + "tests/correctness/sync/*.futil", + "tests/correctness/static-interface/*.futil", +] +cmd = """ +fud exec --from calyx --to jq \ + --through verilog \ + --through dat \ + -s calyx.exec './target/debug/calyx' \ + -s calyx.flags '-x tdcc:one-hot-cutoff=500 -d static-promotion' \ + -s verilog.cycle_limit 500 \ + -s verilog.data {}.data \ + -s jq.expr ".memories" \ + {} -q +""" +timeout = 120 + [[tests]] name = "correctness static timing" paths = [ diff --git a/tests/passes/tdcc-ohe/par.expect b/tests/passes/tdcc-ohe/par.expect new file mode 100644 index 0000000000..81f84d8e1f --- /dev/null +++ b/tests/passes/tdcc-ohe/par.expect @@ -0,0 +1,73 @@ +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + @generated pd = std_reg(1); + @generated pd0 = std_reg(1); + @generated pd1 = std_reg(1); + @generated fsm = init_one_reg(4); + @generated slicer = std_bit_slice(4, 0, 0, 1); + @generated slicer0 = std_bit_slice(4, 1, 1, 1); + @generated slicer1 = std_bit_slice(4, 2, 2, 1); + @generated slicer2 = std_bit_slice(4, 3, 3, 1); + } + wires { + group A { + a.in = 2'd0; + a.write_en = 1'd1; + A[done] = a.done; + } + group B { + b.in = 2'd1; + b.write_en = 1'd1; + B[done] = b.done; + } + group C { + c.in = 2'd2; + c.write_en = 1'd1; + C[done] = c.done; + } + group par0 { + A[go] = !(pd.out | A[done]) ? 1'd1; + pd.in = A[done] ? 1'd1; + pd.write_en = A[done] ? 1'd1; + B[go] = !(pd0.out | B[done]) ? 1'd1; + pd0.in = B[done] ? 1'd1; + pd0.write_en = B[done] ? 1'd1; + C[go] = !(pd1.out | C[done]) ? 1'd1; + pd1.in = C[done] ? 1'd1; + pd1.write_en = C[done] ? 1'd1; + par0[done] = pd.out & pd0.out & pd1.out ? 1'd1; + } + group tdcc { + A[go] = !A[done] & slicer.out == 1'd1 ? 1'd1; + par0[go] = !par0[done] & slicer0.out == 1'd1 ? 1'd1; + B[go] = !B[done] & slicer1.out == 1'd1 ? 1'd1; + fsm.in = slicer.out == 1'd1 & A[done] ? 4'd2; + fsm.write_en = slicer.out == 1'd1 & A[done] ? 1'd1; + fsm.in = slicer0.out == 1'd1 & par0[done] ? 4'd4; + fsm.write_en = slicer0.out == 1'd1 & par0[done] ? 1'd1; + fsm.in = slicer1.out == 1'd1 & B[done] ? 4'd8; + fsm.write_en = slicer1.out == 1'd1 & B[done] ? 1'd1; + tdcc[done] = slicer2.out == 1'd1 ? 1'd1; + } + pd.in = pd.out & pd0.out & pd1.out ? 1'd0; + pd.write_en = pd.out & pd0.out & pd1.out ? 1'd1; + pd0.in = pd.out & pd0.out & pd1.out ? 1'd0; + pd0.write_en = pd.out & pd0.out & pd1.out ? 1'd1; + pd1.in = pd.out & pd0.out & pd1.out ? 1'd0; + pd1.write_en = pd.out & pd0.out & pd1.out ? 1'd1; + slicer.in = fsm.out; + slicer0.in = fsm.out; + slicer1.in = fsm.out; + slicer2.in = fsm.out; + fsm.in = slicer2.out == 1'd1 ? 4'd1; + fsm.write_en = slicer2.out == 1'd1 ? 1'd1; + } + control { + tdcc; + } +} diff --git a/tests/passes/tdcc-ohe/par.futil b/tests/passes/tdcc-ohe/par.futil new file mode 100644 index 0000000000..922b2c62e4 --- /dev/null +++ b/tests/passes/tdcc-ohe/par.futil @@ -0,0 +1,40 @@ +// -x tdcc:one-hot-cutoff=20 -p tdcc + +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; + +component main() -> () { + cells { + a = std_reg(2); + b = std_reg(2); + c = std_reg(2); + } + + wires { + group A { + a.in = 2'd0; + a.write_en = 1'b1; + A[done] = a.done; + } + + group B { + b.in = 2'd1; + b.write_en = 1'b1; + B[done] = b.done; + } + + group C { + c.in = 2'd2; + c.write_en = 1'b1; + C[done] = c.done; + } + } + + control { + seq { + A; + par { A; B; C; } + B; + } + } +}