Skip to content

Commit

Permalink
[When] Split when blocks with cycles (#1322)
Browse files Browse the repository at this point in the history
Adds a pass that splits when blocks with a cyclic dependency.

Fixes #1269.
  • Loading branch information
rsetaluri committed Sep 23, 2023
1 parent e2fc445 commit 4475c36
Show file tree
Hide file tree
Showing 12 changed files with 353 additions and 71 deletions.
13 changes: 0 additions & 13 deletions magma/backend/mlir/build_magma_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,24 +176,11 @@ def _visit_inputs(
)


def _check_for_when_cycles(graph: Graph):
"""Temporary guard against https://github.com/phanrahan/magma/issues/1248"""
for node in graph.nodes():
if not iswhen(node):
continue
for predecessor in graph.predecessors(node):
if isinstance(type(node), Register): # registers break cycles
break
if predecessor is node:
raise MlirWhenCycleError()


def build_magma_graph(
ckt: DefineCircuitKind,
opts: BuildMagmaGrahOpts = BuildMagmaGrahOpts()) -> Graph:
ctx = ModuleContext(Graph(), opts)
_visit_inputs(ctx, ckt, opts.flatten_all_tuples)
for inst in ckt.instances:
_visit_inputs(ctx, inst, opts.flatten_all_tuples)
_check_for_when_cycles(ctx.graph)
return ctx.graph
74 changes: 74 additions & 0 deletions magma/passes/split_when_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from magma.when import (
when,
elsewhen,
otherwise,
WhenBlock,
ElseWhenBlock,
OtherwiseBlock
)


def _check_to_split(value, outputs, to_split):
source = value.trace()
for elem in source.root_iter():
if any(elem is output for output in outputs):
to_split.append(source)
return
if value.has_children() and value.has_elaborated_children():
for _, child in value.enumerate_children():
_check_to_split(child, outputs, to_split)


def _get_builder_ports(builder, names):
"""Filter out removed ports."""
return [getattr(builder, name) for name in names if hasattr(builder, name)]


def _find_values_to_split(builder):
"""Detect output values that feed into inputs."""
to_split = []
outputs = _get_builder_ports(builder, builder.output_to_name.values())
inputs = _get_builder_ports(builder, builder.input_to_name.values())
for value in inputs:
_check_to_split(value, outputs, to_split)
return to_split


def _emit_new_when_assign(value, driver_map, curr_block):
"""Reconstruct when logic in new set of blocks."""
if isinstance(curr_block, WhenBlock):
new_block = when(curr_block._info.condition)
elif isinstance(curr_block, ElseWhenBlock):
new_block = elsewhen(curr_block._info.condition)
elif isinstance(curr_block, OtherwiseBlock):
new_block = otherwise()
with new_block:
if curr_block in driver_map:
for driver in driver_map[curr_block]:
value @= driver
for child in curr_block.children():
_emit_new_when_assign(value, driver_map, child)
for _elsewhen in curr_block.elsewhen_blocks():
_emit_new_when_assign(value, driver_map, _elsewhen)
if curr_block.otherwise_block:
_emit_new_when_assign(value, driver_map, curr_block.otherwise_block)


def _build_driver_map(drivee):
# driver_map stores the driver for each context that drivee is driven.
driver_map = {}
for ctx in drivee._wired_when_contexts:
wires = ctx.get_conditional_wires_for_drivee(drivee)
driver_map[ctx] = (wire.driver for wire in wires)
return driver_map


def split_when_cycles(builder, defn):
to_split = _find_values_to_split(builder)
for value in to_split:
driving = value.driving()
driver_map = _build_driver_map(driving[0])
for drivee in driving:
drivee.unwire()
root = next(iter(driver_map.keys())).root
_emit_new_when_assign(drivee, driver_map, root)
11 changes: 10 additions & 1 deletion magma/passes/when.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from magma.passes.passes import DefinitionPass, pass_lambda
from magma.passes.split_when_utils import split_when_cycles
from magma.primitives.when import WhenBuilder
from magma.when import WhenBlock


class WhenPass(DefinitionPass):
def __call__(self, definition):
with definition.open():
for builder in definition._context_._builders:
for builder in list(definition._context_._builders):
if isinstance(builder, WhenBuilder):
self.process_when_builder(builder, definition)

Expand All @@ -20,6 +22,11 @@ def process_when_builder(self, builder, defn):
builder.emit_when_assertions()


class SplitCycles(WhenPass):
def process_when_builder(self, builder, defn):
split_when_cycles(builder, defn)


class FinalizeWhens(WhenPass):
def process_when_builder(self, builder, defn):
assert not builder._finalized
Expand All @@ -28,11 +35,13 @@ def process_when_builder(self, builder, defn):

infer_latch_pass = pass_lambda(InferLatches)
emit_when_assert_pass = pass_lambda(EmitWhenAsserts)
split_cycles_pass = pass_lambda(SplitCycles)
finalize_when_pass = pass_lambda(FinalizeWhens)


def run_when_passes(main, emit_when_assertions: bool = False):
infer_latch_pass(main)
if emit_when_assertions:
emit_when_assert_pass(main)
split_cycles_pass(main)
finalize_when_pass(main)
4 changes: 4 additions & 0 deletions magma/primitives/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ def output_to_index(self) -> Dict[Type, int]:
def output_to_name(self):
return self._output_to_name

@property
def input_to_name(self):
return self._input_to_name

def check_existing_derived_ref(self, value, value_to_name, value_to_index):
"""If value is a child of an array or tuple that has already been added,
we return the child of the existing value, rather than adding a new
Expand Down
60 changes: 30 additions & 30 deletions magma/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ class ConditionalWire:


class _BlockBase(contextlib.AbstractContextManager):
def __init__(self, parent: Optional['_WhenBlock']):
def __init__(self, parent: Optional['WhenBlock']):
self._parent = parent
self._children = list()
self._conditional_wires = list()
self._default_drivers = dict()

def spawn(self, info: '_WhenBlockInfo') -> '_WhenBlock':
child = _WhenBlock(self, info)
def spawn(self, info: 'WhenBlockInfo') -> 'WhenBlock':
child = WhenBlock(self, info)
self._children.append(child)
return child

Expand Down Expand Up @@ -108,11 +108,11 @@ def add_conditional_wire(self, i, o):

@property
@abc.abstractmethod
def root(self) -> '_WhenBlock':
def root(self) -> 'WhenBlock':
raise NotImplementedError()

@abc.abstractmethod
def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'):
def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'):
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -241,22 +241,22 @@ def get_all_blocks(


@dataclasses.dataclass
class _WhenBlockInfo:
class WhenBlockInfo:
# NOTE(leonardt): We can't use Bit here because of circular dependency, so
# instead we enforce it in the public constructors (e.g. m.when).
condition: Any


@dataclasses.dataclass
class _ElseWhenBlockInfo(_WhenBlockInfo):
class ElseWhenBlockInfo(WhenBlockInfo):
pass


class _WhenBlock(_BlockBase):
class WhenBlock(_BlockBase):
def __init__(
self,
parent: Optional[_BlockBase],
info: _WhenBlockInfo,
info: WhenBlockInfo,
debug_info: Optional[DebugInfo] = None,
):
super().__init__(parent)
Expand All @@ -271,15 +271,15 @@ def __init__(
if debug_info is not None:
self._builder.debug_info = debug_info

def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'):
block = _ElseWhenBlock(self, info)
def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'):
block = ElseWhenBlock(self, info)
self._elsewhens.append(block)
return block

def new_otherwise_block(self):
if self._otherwise is not None:
raise WhenSyntaxError()
block = _OtherwiseBlock(self)
block = OtherwiseBlock(self)
self._otherwise = block
return block

Expand All @@ -290,7 +290,7 @@ def builder(self) -> Any:
return self._builder

@property
def root(self) -> '_WhenBlock':
def root(self) -> 'WhenBlock':
if self._parent is None:
return self
return self._parent.root
Expand All @@ -316,19 +316,19 @@ def _get_exit_block(self):
return self._parent


class _ElseWhenBlock(_BlockBase):
def __init__(self, parent: Optional[_BlockBase], info: _ElseWhenBlockInfo):
class ElseWhenBlock(_BlockBase):
def __init__(self, parent: Optional[_BlockBase], info: ElseWhenBlockInfo):
super().__init__(parent)
self._info = info

def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'):
def new_elsewhen_block(self, info: 'ElseWhenBlockInfo'):
return self._parent.new_elsewhen_block(info)

def new_otherwise_block(self):
return self._parent.new_otherwise_block()

@property
def root(self) -> _WhenBlock:
def root(self) -> WhenBlock:
return self._parent.root

@property
Expand All @@ -346,15 +346,15 @@ def _get_exit_block(self):
return self._parent._get_exit_block()


class _OtherwiseBlock(_BlockBase):
class OtherwiseBlock(_BlockBase):
def new_elsewhen_block(self, _):
raise ElsewhenWithoutPrecedingWhenError()

def new_otherwise_block(self):
raise OtherwiseWithoutPrecedingWhenError()

@property
def root(self) -> _WhenBlock:
def root(self) -> WhenBlock:
return self._parent.root

@property
Expand Down Expand Up @@ -434,18 +434,18 @@ def _reset_context():


def _get_else_block(block: _BlockBase) -> Optional[_BlockBase]:
if isinstance(block, _OtherwiseBlock):
if isinstance(block, OtherwiseBlock):
return None
if isinstance(block, _ElseWhenBlock):
# TODO(rsetaluri): We could augment _ElseWhenBlock to keep track of its
if isinstance(block, ElseWhenBlock):
# TODO(rsetaluri): We could augment ElseWhenBlock to keep track of its
# index in the parent block, or alternatively keep an explicit pointer
# to the next else/otherwise block in each _ElseWhenBlock.
# to the next else/otherwise block in each ElseWhenBlock.
parent_elsewhen_blocks = list(block._parent.elsewhen_blocks())
index = parent_elsewhen_blocks.index(block) + 1
if index == len(parent_elsewhen_blocks):
return block._parent.otherwise_block
return parent_elsewhen_blocks[index]
if isinstance(block, _WhenBlock):
if isinstance(block, WhenBlock):
try:
else_block = next(block.elsewhen_blocks())
except StopIteration:
Expand All @@ -464,9 +464,9 @@ def _get_then_ops(block: _BlockBase) -> Iterable[_Op]:


def _get_else_ops(block: _BlockBase) -> Iterable[_Op]:
if isinstance(block, _OtherwiseBlock):
if isinstance(block, OtherwiseBlock):
return _get_then_ops(block)
if isinstance(block, _ElseWhenBlock):
if isinstance(block, ElseWhenBlock):
return (block,)
raise TypeError(block)

Expand Down Expand Up @@ -535,7 +535,7 @@ def _get_assignees_and_latches(ops: Iterable[_Op]) -> Tuple[Set, Set]:


def find_inferred_latches(block: _BlockBase) -> Set:
if not (isinstance(block, _WhenBlock) and block.root is block):
if not (isinstance(block, WhenBlock) and block.root is block):
raise TypeError("Can only find inferred latches on root when block")
ops = tuple(block.default_drivers()) + (block,)
_, latches = _get_assignees_and_latches(ops)
Expand Down Expand Up @@ -651,17 +651,17 @@ def emit_when_assertions(block, builder, precond=None,
def when(cond):
if not isinstance(cond, _bit_type()):
raise TypeError(f"Invalid when cond {cond}, should be Bit")
info = _WhenBlockInfo(cond)
info = WhenBlockInfo(cond)
curr_block = _get_curr_block()
if curr_block is None:
return _WhenBlock(None, info, get_debug_info(3))
return WhenBlock(None, info, get_debug_info(3))
return curr_block.spawn(info)


def elsewhen(cond):
if not isinstance(cond, _bit_type()):
raise TypeError(f"Invalid elsewhen cond {cond}, should be Bit")
info = _ElseWhenBlockInfo(cond)
info = ElseWhenBlockInfo(cond)
prev_block = _get_prev_block()
if prev_block is None:
raise ElsewhenWithoutPrecedingWhenError()
Expand Down
3 changes: 3 additions & 0 deletions magma/wire_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def has_elaborated_children(self):
def _enumerate_children(self):
raise NotImplementedError()

def enumerate_children(self):
return self._enumerate_children()

def _get_conditional_drivee_info(self):
"""
* wired_when_contexts: list of contexts in which this value appears as
Expand Down
25 changes: 25 additions & 0 deletions tests/gold/test_when_alwcomb_order.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Generated by CIRCT firtool-1.48.0-34-g7018fb13b
module test_when_alwcomb_order(
input [7:0] I,
input S,
output [7:0] O
);

reg [7:0] _GEN;
always_comb begin
if (S)
_GEN = I;
else
_GEN = ~I;
end // always_comb
reg [7:0] _GEN_0;
always_comb begin
_GEN_0 = _GEN;
if (S) begin
end
else
_GEN_0 = ~_GEN;
end // always_comb
assign O = _GEN_0;
endmodule

Loading

0 comments on commit 4475c36

Please sign in to comment.