From 22fc4bb02c69d145f29c6b3aa0485b22fb175328 Mon Sep 17 00:00:00 2001 From: Ayaka Yorihiro <36107281+ayakayorihiro@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:56:44 -0400 Subject: [PATCH] Attribute to retain cells during optimization (#2303) I added a new boolean attribute that indicates which cells (and uses of cells) should be preserved during optimizations. I check for the attribute in the `dead-cell-removal` pass, so the pass doesn't optimize those cells when it would otherwise do so. (@ekiwi and I got to the bottom of the errors I was experiencing yesterday) The attribute is currently called `@protected`, but if anyone has suggestions about what the attribute should be named that would be helpful! ## Example The test file `tests/passes/dead-cell-removal-protected.futil` differs from `tests/passes/dead-cell-removal.futil` by a single line, where we add `@protected` to unused register `unused_reg`: ``` $ diff tests/passes/dead-cell-removal.futil tests/passes/dead-cell-removal-protected.futil 30c30 < unused_reg = std_reg(32); --- > @protected unused_reg = std_reg(32); ``` The expect files showing the result of running `dead-cell-removal` indicate that the pass retained `unused_reg` because of the `@protected` attribute: ``` $ diff tests/passes/dead-cell-removal.expect tests/passes/dead-cell-removal-protected.expect 26a27 > @protected unused_reg = std_reg(32); ``` --- calyx-frontend/src/attribute.rs | 3 ++ calyx-opt/src/passes/cell_share.rs | 6 ++- calyx-opt/src/passes/dead_cell_removal.rs | 3 +- docs/lang/attributes.md | 4 ++ tests/passes/cell-share/protected.expect | 31 ++++++++++++++ tests/passes/cell-share/protected.futil | 35 ++++++++++++++++ .../passes/dead-cell-removal-protected.expect | 41 +++++++++++++++++++ .../passes/dead-cell-removal-protected.futil | 41 +++++++++++++++++++ 8 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 tests/passes/cell-share/protected.expect create mode 100644 tests/passes/cell-share/protected.futil create mode 100644 tests/passes/dead-cell-removal-protected.expect create mode 100644 tests/passes/dead-cell-removal-protected.futil diff --git a/calyx-frontend/src/attribute.rs b/calyx-frontend/src/attribute.rs index b860005dac..16b2a2d768 100644 --- a/calyx-frontend/src/attribute.rs +++ b/calyx-frontend/src/attribute.rs @@ -76,6 +76,9 @@ pub enum BoolAttr { #[strum(serialize = "fast")] /// https://github.com/calyxir/calyx/issues/1828 Fast, + #[strum(serialize = "protected")] + /// Indicate that the cell should not be removed or shared during optimization. + Protected, } impl From for Attribute { diff --git a/calyx-opt/src/passes/cell_share.rs b/calyx-opt/src/passes/cell_share.rs index 80c202c987..228f2e9dc4 100644 --- a/calyx-opt/src/passes/cell_share.rs +++ b/calyx-opt/src/passes/cell_share.rs @@ -5,8 +5,8 @@ use crate::analysis::{ use crate::traversal::{ Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor, }; -use calyx_ir::rewriter; use calyx_ir::{self as ir}; +use calyx_ir::{rewriter, BoolAttr}; use calyx_utils::{CalyxResult, OutputFile}; use itertools::Itertools; use serde_json::{json, Value}; @@ -242,6 +242,10 @@ impl CellShare { if self.cont_ref_cells.contains(&cell.name()) { return false; } + // Cells that have @protected cannot be shared (even if they have share/state_share attributes) + if cell.attributes.has(BoolAttr::Protected) { + return false; + } if let Some(ref name) = cell.type_name() { self.state_shareable.contains(name) || self.shareable.contains(name) } else { diff --git a/calyx-opt/src/passes/dead_cell_removal.rs b/calyx-opt/src/passes/dead_cell_removal.rs index ceefddb94e..5ac58c8bbd 100644 --- a/calyx-opt/src/passes/dead_cell_removal.rs +++ b/calyx-opt/src/passes/dead_cell_removal.rs @@ -129,13 +129,14 @@ impl Visitor for DeadCellRemoval { _sigs: &ir::LibrarySignatures, _comps: &[ir::Component], ) -> VisResult { - // Add @external cells and ref cells. + // Add @external cells, @protected cells and ref cells. self.all_reads.extend( comp.cells .iter() .filter(|c| { let cell = c.borrow(); cell.attributes.get(ir::BoolAttr::External).is_some() + || cell.attributes.has(ir::BoolAttr::Protected) || cell.is_reference() }) .map(|c| c.borrow().name()), diff --git a/docs/lang/attributes.md b/docs/lang/attributes.md index 8b368a7425..d7e11bc50d 100644 --- a/docs/lang/attributes.md +++ b/docs/lang/attributes.md @@ -247,3 +247,7 @@ as before. [externalize]: https://docs.rs/calyx-opt/latest/calyx_opt/passes/struct.Externalize.html [promotable]: #promotable(n) [interval]: #interval(n) + +### `@protected` + +Marks that the cell should not be removed or shared during optimization. \ No newline at end of file diff --git a/tests/passes/cell-share/protected.expect b/tests/passes/cell-share/protected.expect new file mode 100644 index 0000000000..ed52865172 --- /dev/null +++ b/tests/passes/cell-share/protected.expect @@ -0,0 +1,31 @@ +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + @protected add0 = std_add(32); + add1 = std_add(32); + x_0 = std_reg(32); + } + wires { + group upd0 { + add0.left = x_0.out; + add0.right = 32'd1; + x_0.in = add0.out; + x_0.write_en = 1'd1; + upd0[done] = x_0.done ? 1'd1; + } + group upd1 { + add1.left = x_0.out; + add1.right = 32'd1; + x_0.in = add1.out; + x_0.write_en = 1'd1; + upd1[done] = x_0.done ? 1'd1; + } + } + control { + seq { + upd0; + upd1; + } + } +} diff --git a/tests/passes/cell-share/protected.futil b/tests/passes/cell-share/protected.futil new file mode 100644 index 0000000000..f67777a862 --- /dev/null +++ b/tests/passes/cell-share/protected.futil @@ -0,0 +1,35 @@ +//-p cell-share -p remove-ids + +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; + +// testing that @protected overrides @share +component main() -> () { + cells { + @protected add0 = std_add(32); + add1 = std_add(32); + x_0 = std_reg(32); + } + wires { + group upd0 { + add0.left = x_0.out; + add0.right = 32'd1; + x_0.in = add0.out; + x_0.write_en = 1'd1; + upd0[done] = x_0.done ? 1'd1; + } + group upd1 { + add1.left = x_0.out; + add1.right = 32'd1; + x_0.in = add1.out; + x_0.write_en = 1'd1; + upd1[done] = x_0.done ? 1'd1; + } + } + control { + seq { + upd0; + upd1; + } + } +} diff --git a/tests/passes/dead-cell-removal-protected.expect b/tests/passes/dead-cell-removal-protected.expect new file mode 100644 index 0000000000..817ce5d27a --- /dev/null +++ b/tests/passes/dead-cell-removal-protected.expect @@ -0,0 +1,41 @@ +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; +component add(left: 32, right: 32, @go go: 1, @clk clk: 1, @reset reset: 1) -> (out: 32, @done done: 1) { + cells { + adder = std_add(32); + outpt = std_reg(32); + } + wires { + group do_add { + adder.left = left; + adder.right = right; + outpt.in = adder.out; + outpt.write_en = 1'd1; + do_add[done] = outpt.done; + } + } + control { + seq { + do_add; + } + } +} +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (out: 32, @done done: 1) { + cells { + used_reg = std_reg(32); + used_le = std_le(1); + @protected unused_reg = std_reg(32); + my_add = add(); + add_input = std_reg(32); + } + wires { + used_reg.in = used_le.out ? 32'd10; + out = used_reg.out; + } + control { + invoke my_add( + left = add_input.out, + right = add_input.out + )(); + } +} diff --git a/tests/passes/dead-cell-removal-protected.futil b/tests/passes/dead-cell-removal-protected.futil new file mode 100644 index 0000000000..5eb53caff3 --- /dev/null +++ b/tests/passes/dead-cell-removal-protected.futil @@ -0,0 +1,41 @@ +// -p dead-cell-removal +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; + +component add(left: 32, right: 32) -> (out: 32) { + cells { + adder = std_add(32); + outpt = std_reg(32); + } + wires { + group do_add { + adder.left = left; + adder.right = right; + outpt.in = adder.out; + outpt.write_en = 1'd1; + do_add[done] = outpt.done; + } + } + control { + seq { + do_add; + } + } +} + +component main() -> (out: 32) { + cells { + used_reg = std_reg(32); + used_le = std_le(1); + @protected unused_reg = std_reg(32); + my_add = add(); + add_input = std_reg(32); + } + wires { + used_reg.in = used_le.out ? 32'd10; + out = used_reg.out; + } + control { + invoke my_add(left = add_input.out, right = add_input.out)(); + } +}