From 4475c366403d4de5ec969d4a5c203ec52739e83d Mon Sep 17 00:00:00 2001 From: rsetaluri Date: Fri, 22 Sep 2023 22:08:05 -0700 Subject: [PATCH] [When] Split when blocks with cycles (#1322) Adds a pass that splits when blocks with a cyclic dependency. Fixes https://github.com/phanrahan/magma/issues/1269. --- magma/backend/mlir/build_magma_graph.py | 13 --- magma/passes/split_when_utils.py | 74 ++++++++++++ magma/passes/when.py | 11 +- magma/primitives/when.py | 4 + magma/when.py | 60 +++++----- magma/wire_container.py | 3 + tests/gold/test_when_alwcomb_order.v | 25 ++++ tests/gold/test_when_alwcomb_order_complex.v | 34 ++++++ tests/gold/test_when_alwcomb_order_nested.v | 30 +++++ tests/gold/test_when_alwcomb_order_nested_2.v | 36 ++++++ tests/gold/test_when_recursive_non_port.mlir | 27 +++++ tests/test_when.py | 107 +++++++++++++----- 12 files changed, 353 insertions(+), 71 deletions(-) create mode 100644 magma/passes/split_when_utils.py create mode 100644 tests/gold/test_when_alwcomb_order.v create mode 100644 tests/gold/test_when_alwcomb_order_complex.v create mode 100644 tests/gold/test_when_alwcomb_order_nested.v create mode 100644 tests/gold/test_when_alwcomb_order_nested_2.v diff --git a/magma/backend/mlir/build_magma_graph.py b/magma/backend/mlir/build_magma_graph.py index fb049a3f0..19778463d 100644 --- a/magma/backend/mlir/build_magma_graph.py +++ b/magma/backend/mlir/build_magma_graph.py @@ -176,18 +176,6 @@ def _visit_inputs( ) -def _check_for_when_cycles(graph: Graph): - """Temporary guard against https://github.com/phanrahan/magma/issues/1248""" - for node in graph.nodes(): - if not iswhen(node): - continue - for predecessor in graph.predecessors(node): - if isinstance(type(node), Register): # registers break cycles - break - if predecessor is node: - raise MlirWhenCycleError() - - def build_magma_graph( ckt: DefineCircuitKind, opts: BuildMagmaGrahOpts = BuildMagmaGrahOpts()) -> Graph: @@ -195,5 +183,4 @@ def build_magma_graph( _visit_inputs(ctx, ckt, opts.flatten_all_tuples) for inst in ckt.instances: _visit_inputs(ctx, inst, opts.flatten_all_tuples) - _check_for_when_cycles(ctx.graph) return ctx.graph diff --git a/magma/passes/split_when_utils.py b/magma/passes/split_when_utils.py new file mode 100644 index 000000000..de94f879e --- /dev/null +++ b/magma/passes/split_when_utils.py @@ -0,0 +1,74 @@ +from magma.when import ( + when, + elsewhen, + otherwise, + WhenBlock, + ElseWhenBlock, + OtherwiseBlock +) + + +def _check_to_split(value, outputs, to_split): + source = value.trace() + for elem in source.root_iter(): + if any(elem is output for output in outputs): + to_split.append(source) + return + if value.has_children() and value.has_elaborated_children(): + for _, child in value.enumerate_children(): + _check_to_split(child, outputs, to_split) + + +def _get_builder_ports(builder, names): + """Filter out removed ports.""" + return [getattr(builder, name) for name in names if hasattr(builder, name)] + + +def _find_values_to_split(builder): + """Detect output values that feed into inputs.""" + to_split = [] + outputs = _get_builder_ports(builder, builder.output_to_name.values()) + inputs = _get_builder_ports(builder, builder.input_to_name.values()) + for value in inputs: + _check_to_split(value, outputs, to_split) + return to_split + + +def _emit_new_when_assign(value, driver_map, curr_block): + """Reconstruct when logic in new set of blocks.""" + if isinstance(curr_block, WhenBlock): + new_block = when(curr_block._info.condition) + elif isinstance(curr_block, ElseWhenBlock): + new_block = elsewhen(curr_block._info.condition) + elif isinstance(curr_block, OtherwiseBlock): + new_block = otherwise() + with new_block: + if curr_block in driver_map: + for driver in driver_map[curr_block]: + value @= driver + for child in curr_block.children(): + _emit_new_when_assign(value, driver_map, child) + for _elsewhen in curr_block.elsewhen_blocks(): + _emit_new_when_assign(value, driver_map, _elsewhen) + if curr_block.otherwise_block: + _emit_new_when_assign(value, driver_map, curr_block.otherwise_block) + + +def _build_driver_map(drivee): + # driver_map stores the driver for each context that drivee is driven. + driver_map = {} + for ctx in drivee._wired_when_contexts: + wires = ctx.get_conditional_wires_for_drivee(drivee) + driver_map[ctx] = (wire.driver for wire in wires) + return driver_map + + +def split_when_cycles(builder, defn): + to_split = _find_values_to_split(builder) + for value in to_split: + driving = value.driving() + driver_map = _build_driver_map(driving[0]) + for drivee in driving: + drivee.unwire() + root = next(iter(driver_map.keys())).root + _emit_new_when_assign(drivee, driver_map, root) diff --git a/magma/passes/when.py b/magma/passes/when.py index b9ccd79d6..1ca3fd74b 100644 --- a/magma/passes/when.py +++ b/magma/passes/when.py @@ -1,11 +1,13 @@ from magma.passes.passes import DefinitionPass, pass_lambda +from magma.passes.split_when_utils import split_when_cycles from magma.primitives.when import WhenBuilder +from magma.when import WhenBlock class WhenPass(DefinitionPass): def __call__(self, definition): with definition.open(): - for builder in definition._context_._builders: + for builder in list(definition._context_._builders): if isinstance(builder, WhenBuilder): self.process_when_builder(builder, definition) @@ -20,6 +22,11 @@ def process_when_builder(self, builder, defn): builder.emit_when_assertions() +class SplitCycles(WhenPass): + def process_when_builder(self, builder, defn): + split_when_cycles(builder, defn) + + class FinalizeWhens(WhenPass): def process_when_builder(self, builder, defn): assert not builder._finalized @@ -28,6 +35,7 @@ def process_when_builder(self, builder, defn): infer_latch_pass = pass_lambda(InferLatches) emit_when_assert_pass = pass_lambda(EmitWhenAsserts) +split_cycles_pass = pass_lambda(SplitCycles) finalize_when_pass = pass_lambda(FinalizeWhens) @@ -35,4 +43,5 @@ def run_when_passes(main, emit_when_assertions: bool = False): infer_latch_pass(main) if emit_when_assertions: emit_when_assert_pass(main) + split_cycles_pass(main) finalize_when_pass(main) diff --git a/magma/primitives/when.py b/magma/primitives/when.py index 4141a0de2..94f44e5c4 100644 --- a/magma/primitives/when.py +++ b/magma/primitives/when.py @@ -168,6 +168,10 @@ def output_to_index(self) -> Dict[Type, int]: def output_to_name(self): return self._output_to_name + @property + def input_to_name(self): + return self._input_to_name + def check_existing_derived_ref(self, value, value_to_name, value_to_index): """If value is a child of an array or tuple that has already been added, we return the child of the existing value, rather than adding a new diff --git a/magma/when.py b/magma/when.py index f3b632655..62f0d6178 100644 --- a/magma/when.py +++ b/magma/when.py @@ -70,14 +70,14 @@ class ConditionalWire: class _BlockBase(contextlib.AbstractContextManager): - def __init__(self, parent: Optional['_WhenBlock']): + def __init__(self, parent: Optional['WhenBlock']): self._parent = parent self._children = list() self._conditional_wires = list() self._default_drivers = dict() - def spawn(self, info: '_WhenBlockInfo') -> '_WhenBlock': - child = _WhenBlock(self, info) + def spawn(self, info: 'WhenBlockInfo') -> 'WhenBlock': + child = WhenBlock(self, info) self._children.append(child) return child @@ -108,11 +108,11 @@ def add_conditional_wire(self, i, o): @property @abc.abstractmethod - def root(self) -> '_WhenBlock': + def root(self) -> 'WhenBlock': raise NotImplementedError() @abc.abstractmethod - def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'): + def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'): raise NotImplementedError() @abc.abstractmethod @@ -241,22 +241,22 @@ def get_all_blocks( @dataclasses.dataclass -class _WhenBlockInfo: +class WhenBlockInfo: # NOTE(leonardt): We can't use Bit here because of circular dependency, so # instead we enforce it in the public constructors (e.g. m.when). condition: Any @dataclasses.dataclass -class _ElseWhenBlockInfo(_WhenBlockInfo): +class ElseWhenBlockInfo(WhenBlockInfo): pass -class _WhenBlock(_BlockBase): +class WhenBlock(_BlockBase): def __init__( self, parent: Optional[_BlockBase], - info: _WhenBlockInfo, + info: WhenBlockInfo, debug_info: Optional[DebugInfo] = None, ): super().__init__(parent) @@ -271,15 +271,15 @@ def __init__( if debug_info is not None: self._builder.debug_info = debug_info - def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'): - block = _ElseWhenBlock(self, info) + def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'): + block = ElseWhenBlock(self, info) self._elsewhens.append(block) return block def new_otherwise_block(self): if self._otherwise is not None: raise WhenSyntaxError() - block = _OtherwiseBlock(self) + block = OtherwiseBlock(self) self._otherwise = block return block @@ -290,7 +290,7 @@ def builder(self) -> Any: return self._builder @property - def root(self) -> '_WhenBlock': + def root(self) -> 'WhenBlock': if self._parent is None: return self return self._parent.root @@ -316,19 +316,19 @@ def _get_exit_block(self): return self._parent -class _ElseWhenBlock(_BlockBase): - def __init__(self, parent: Optional[_BlockBase], info: _ElseWhenBlockInfo): +class ElseWhenBlock(_BlockBase): + def __init__(self, parent: Optional[_BlockBase], info: ElseWhenBlockInfo): super().__init__(parent) self._info = info - def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'): + def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'): return self._parent.new_elsewhen_block(info) def new_otherwise_block(self): return self._parent.new_otherwise_block() @property - def root(self) -> _WhenBlock: + def root(self) -> WhenBlock: return self._parent.root @property @@ -346,7 +346,7 @@ def _get_exit_block(self): return self._parent._get_exit_block() -class _OtherwiseBlock(_BlockBase): +class OtherwiseBlock(_BlockBase): def new_elsewhen_block(self, _): raise ElsewhenWithoutPrecedingWhenError() @@ -354,7 +354,7 @@ def new_otherwise_block(self): raise OtherwiseWithoutPrecedingWhenError() @property - def root(self) -> _WhenBlock: + def root(self) -> WhenBlock: return self._parent.root @property @@ -434,18 +434,18 @@ def _reset_context(): def _get_else_block(block: _BlockBase) -> Optional[_BlockBase]: - if isinstance(block, _OtherwiseBlock): + if isinstance(block, OtherwiseBlock): return None - if isinstance(block, _ElseWhenBlock): - # TODO(rsetaluri): We could augment _ElseWhenBlock to keep track of its + if isinstance(block, ElseWhenBlock): + # TODO(rsetaluri): We could augment ElseWhenBlock to keep track of its # index in the parent block, or alternatively keep an explicit pointer - # to the next else/otherwise block in each _ElseWhenBlock. + # to the next else/otherwise block in each ElseWhenBlock. parent_elsewhen_blocks = list(block._parent.elsewhen_blocks()) index = parent_elsewhen_blocks.index(block) + 1 if index == len(parent_elsewhen_blocks): return block._parent.otherwise_block return parent_elsewhen_blocks[index] - if isinstance(block, _WhenBlock): + if isinstance(block, WhenBlock): try: else_block = next(block.elsewhen_blocks()) except StopIteration: @@ -464,9 +464,9 @@ def _get_then_ops(block: _BlockBase) -> Iterable[_Op]: def _get_else_ops(block: _BlockBase) -> Iterable[_Op]: - if isinstance(block, _OtherwiseBlock): + if isinstance(block, OtherwiseBlock): return _get_then_ops(block) - if isinstance(block, _ElseWhenBlock): + if isinstance(block, ElseWhenBlock): return (block,) raise TypeError(block) @@ -535,7 +535,7 @@ def _get_assignees_and_latches(ops: Iterable[_Op]) -> Tuple[Set, Set]: def find_inferred_latches(block: _BlockBase) -> Set: - if not (isinstance(block, _WhenBlock) and block.root is block): + if not (isinstance(block, WhenBlock) and block.root is block): raise TypeError("Can only find inferred latches on root when block") ops = tuple(block.default_drivers()) + (block,) _, latches = _get_assignees_and_latches(ops) @@ -651,17 +651,17 @@ def emit_when_assertions(block, builder, precond=None, def when(cond): if not isinstance(cond, _bit_type()): raise TypeError(f"Invalid when cond {cond}, should be Bit") - info = _WhenBlockInfo(cond) + info = WhenBlockInfo(cond) curr_block = _get_curr_block() if curr_block is None: - return _WhenBlock(None, info, get_debug_info(3)) + return WhenBlock(None, info, get_debug_info(3)) return curr_block.spawn(info) def elsewhen(cond): if not isinstance(cond, _bit_type()): raise TypeError(f"Invalid elsewhen cond {cond}, should be Bit") - info = _ElseWhenBlockInfo(cond) + info = ElseWhenBlockInfo(cond) prev_block = _get_prev_block() if prev_block is None: raise ElsewhenWithoutPrecedingWhenError() diff --git a/magma/wire_container.py b/magma/wire_container.py index 2ac469d17..89ba36931 100644 --- a/magma/wire_container.py +++ b/magma/wire_container.py @@ -329,6 +329,9 @@ def has_elaborated_children(self): def _enumerate_children(self): raise NotImplementedError() + def enumerate_children(self): + return self._enumerate_children() + def _get_conditional_drivee_info(self): """ * wired_when_contexts: list of contexts in which this value appears as diff --git a/tests/gold/test_when_alwcomb_order.v b/tests/gold/test_when_alwcomb_order.v new file mode 100644 index 000000000..1c846015e --- /dev/null +++ b/tests/gold/test_when_alwcomb_order.v @@ -0,0 +1,25 @@ +// Generated by CIRCT firtool-1.48.0-34-g7018fb13b +module test_when_alwcomb_order( + input [7:0] I, + input S, + output [7:0] O +); + + reg [7:0] _GEN; + always_comb begin + if (S) + _GEN = I; + else + _GEN = ~I; + end // always_comb + reg [7:0] _GEN_0; + always_comb begin + _GEN_0 = _GEN; + if (S) begin + end + else + _GEN_0 = ~_GEN; + end // always_comb + assign O = _GEN_0; +endmodule + diff --git a/tests/gold/test_when_alwcomb_order_complex.v b/tests/gold/test_when_alwcomb_order_complex.v new file mode 100644 index 000000000..20c380f9a --- /dev/null +++ b/tests/gold/test_when_alwcomb_order_complex.v @@ -0,0 +1,34 @@ +// Generated by CIRCT firtool-1.48.0-34-g7018fb13b +module test_when_alwcomb_order_complex( + input [7:0] I, + input [1:0] S, + output [7:0] O +); + + reg [7:0] _GEN; + always_comb begin + if (S[0]) begin + _GEN = I; + if (S[1]) + _GEN = ~I; + end + else if (^S) + _GEN = 8'h0; + else + _GEN = ~I; + end // always_comb + reg [7:0] _GEN_0; + always_comb begin + _GEN_0 = _GEN; + if (S[0]) begin + if (S[1]) + _GEN_0 = _GEN & 8'hDE; + end + else if (^S) begin + end + else + _GEN_0 = ~_GEN; + end // always_comb + assign O = _GEN_0; +endmodule + diff --git a/tests/gold/test_when_alwcomb_order_nested.v b/tests/gold/test_when_alwcomb_order_nested.v new file mode 100644 index 000000000..793e95e37 --- /dev/null +++ b/tests/gold/test_when_alwcomb_order_nested.v @@ -0,0 +1,30 @@ +// Generated by CIRCT firtool-1.48.0-34-g7018fb13b +module test_when_alwcomb_order_nested( + input struct packed {logic x; logic [7:0] y; } I, + input S, + output struct packed {logic x; logic [7:0] y; } O +); + + reg _GEN; + always_comb begin + if (S) + _GEN = I.x; + else + _GEN = ~I.x; + end // always_comb + reg [7:0] _GEN_0; + always_comb begin + if (S) + _GEN_0 = I.y; + else + _GEN_0 = ~I.y; + end // always_comb + struct packed {logic x; logic [7:0] y; } _GEN_1; + always_comb begin + _GEN_1 = I; + if (S) + _GEN_1 = '{x: _GEN, y: _GEN_0}; + end // always_comb + assign O = _GEN_1; +endmodule + diff --git a/tests/gold/test_when_alwcomb_order_nested_2.v b/tests/gold/test_when_alwcomb_order_nested_2.v new file mode 100644 index 000000000..b72e2b6a5 --- /dev/null +++ b/tests/gold/test_when_alwcomb_order_nested_2.v @@ -0,0 +1,36 @@ +// Generated by CIRCT firtool-1.48.0-34-g7018fb13b +module test_when_alwcomb_order_nested_2( + input struct packed {logic x; logic [7:0] y; }[2:0] I, + input S, + output struct packed {logic x; logic [7:0] y; } O +); + + reg _GEN; + always_comb begin + if (S) + _GEN = I[2'h1].x; + else + _GEN = I[2'h2].x; + end // always_comb + reg [7:0] _GEN_0; + always_comb begin + if (S) + _GEN_0 = I[2'h1].y; + else + _GEN_0 = I[2'h2].y; + end // always_comb + reg _GEN_1; + reg [7:0] _GEN_2; + always_comb begin + _GEN_1 = I[2'h0].x; + _GEN_2 = I[2'h0].y; + if (S) begin + end + else begin + _GEN_1 = _GEN; + _GEN_2 = _GEN_0; + end + end // always_comb + assign O = '{x: _GEN_1, y: _GEN_2}; +endmodule + diff --git a/tests/gold/test_when_recursive_non_port.mlir b/tests/gold/test_when_recursive_non_port.mlir index e69de29bb..a85522cc2 100644 --- a/tests/gold/test_when_recursive_non_port.mlir +++ b/tests/gold/test_when_recursive_non_port.mlir @@ -0,0 +1,27 @@ +module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} { + hw.module @test_recursive_non_port(%I: i2, %S: i1) -> (O0: i1, O1: i1) { + %0 = comb.extract %I from 0 : (i2) -> i1 + %1 = comb.extract %I from 1 : (i2) -> i1 + %3 = sv.reg : !hw.inout + %2 = sv.read_inout %3 : !hw.inout + sv.alwayscomb { + sv.if %S { + sv.bpassign %3, %0 : i1 + } else { + sv.bpassign %3, %1 : i1 + } + } + %6 = sv.reg : !hw.inout + %4 = sv.read_inout %6 : !hw.inout + %7 = sv.reg : !hw.inout + %5 = sv.read_inout %7 : !hw.inout + sv.alwayscomb { + sv.if %S { + sv.bpassign %7, %2 : i1 + } else { + sv.bpassign %7, %2 : i1 + } + } + hw.output %2, %5 : i1, i1 + } +} diff --git a/tests/test_when.py b/tests/test_when.py index 1135cd08f..a78832a7d 100644 --- a/tests/test_when.py +++ b/tests/test_when.py @@ -481,19 +481,8 @@ class _Test(m.Circuit): io.O0 @= x basename = "test_when_recursive_non_port" - # TODO(leonardt): Remove this when proper analysis for when cycles is - # added. - # NOTE(leonardt): We disable assert checking for these tests because the - # insertion of the wire module prevents seeing the dependency. In, general - # we do not traverse across instances because we do not determine whether - # the cycle might be broken by a register inside the hierachy. This is - # essentially the classic combinational loop problem. For now, rather - # than attempt to solve the problem for this analysis, we should eventually - # avoid the analysis by appropriately splitting up the when blocks or if - # statements to avoid the cycle issue - with pytest.raises(MlirWhenCycleError): - m.compile(f"build/{basename}", _Test, output="mlir") - assert check_gold(__file__, f"{basename}.mlir") + m.compile(f"build/{basename}", _Test, output="mlir") + assert check_gold(__file__, f"{basename}.mlir") def test_internal_instantiation(): @@ -1738,15 +1727,23 @@ class test_when_emit_asserts_tuple_elab(m.Circuit): assert check_gold(__file__, "test_when_emit_asserts_tuple_elab.mlir") +def _check_or_update(circ): + # We check verilog here because the alwcomb order was "legal" MLIR. + m.compile(f"build/{circ.name}", circ, output="mlir-verilog") + + _file = f"{circ.name}.v" + if check_gold(__file__, _file): + return + verilator_path = os.path.join( + os.path.dirname(__file__), + "build", + _file + ) + assert not os.system(f"verilator --lint-only {verilator_path}") + update_gold(__file__, _file) + + def test_when_alwcomb_order(): - # NOTE(leonardt): We disable assert checking for these tests because the - # insertion of the wire module prevents seeing the dependency. In, general - # we do not traverse across instances because we do not determine whether - # the cycle might be broken by a register inside the hierachy. This is - # essentially the classic combinational loop problem. For now, rather - # than attempt to solve the problem for this analysis, we should eventually - # avoid the analysis by appropriately splitting up the when blocks or if - # statements to avoid the cycle issue class test_when_alwcomb_order(m.Circuit): io = m.IO(I=m.In(m.Bits[8]), S=m.In(m.Bits[1]), O=m.Out(m.Bits[8])) @@ -1759,12 +1756,68 @@ class test_when_alwcomb_order(m.Circuit): x @= ~io.I io.O @= ~x - with pytest.raises(MlirWhenCycleError): - m.compile( - "build/test_when_alwcomb_order", - test_when_alwcomb_order, - output="mlir" - ) + _check_or_update(test_when_alwcomb_order) + + +def test_when_alwcomb_order_complex(): + class test_when_alwcomb_order_complex(m.Circuit): + io = m.IO(I=m.In(m.Bits[8]), S=m.In(m.Bits[2]), O=m.Out(m.Bits[8])) + x = m.Bits[8]() + + io.O @= x + with m.when(io.S[0]): + x @= io.I + with m.when(io.S[1]): + x @= ~io.I + io.O @= x & 0xDE + with m.elsewhen(io.S[0] ^ io.S[1]): + x @= io.I ^ io.I + with m.otherwise(): + io.O @= ~x + x @= ~io.I + + _check_or_update(test_when_alwcomb_order_complex) + + +def test_when_alwcomb_order_nested(): + class T(m.Product): + x = m.Bit + y = m.Bits[8] + + class test_when_alwcomb_order_nested(m.Circuit): + io = m.IO(I=m.In(T), S=m.In(m.Bit), O=m.Out(T)) + x = T() + + io.O @= io.I + with m.when(io.S): + io.O @= x + x.x @= io.I.x + x.y @= io.I.y + with m.otherwise(): + x.x @= ~io.I.x + x.y @= ~io.I.y + + _check_or_update(test_when_alwcomb_order_nested) + + +def test_when_alwcomb_order_nested_2(): + class T(m.Product): + x = m.Bit + y = m.Bits[8] + + class test_when_alwcomb_order_nested_2(m.Circuit): + io = m.IO(I=m.In(m.Array[3, T]), S=m.In(m.Bit), O=m.Out(T)) + x = T() + + io.O @= io.I[0] + with m.when(io.S): + x @= io.I[1] + with m.otherwise(): + io.O.x @= x.x + io.O.y @= x.y + x @= io.I[2] + + _check_or_update(test_when_alwcomb_order_nested_2) # TODO: In this case, we'll generate elaborated assignments, but it should