Skip to content

Commit

Permalink
Revert "[When] Split when blocks with cycles (#1309)"
Browse files Browse the repository at this point in the history
This reverts commit 512c7e1.
  • Loading branch information
rsetaluri committed Sep 23, 2023
1 parent 512c7e1 commit ad05e7e
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 344 deletions.
13 changes: 13 additions & 0 deletions magma/backend/mlir/build_magma_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,24 @@ def _visit_inputs(
)


def _check_for_when_cycles(graph: Graph):
"""Temporary guard against https://github.com/phanrahan/magma/issues/1248"""
for node in graph.nodes():
if not iswhen(node):
continue
for predecessor in graph.predecessors(node):
if isinstance(type(node), Register): # registers break cycles
break
if predecessor is node:
raise MlirWhenCycleError()


def build_magma_graph(
ckt: DefineCircuitKind,
opts: BuildMagmaGrahOpts = BuildMagmaGrahOpts()) -> Graph:
ctx = ModuleContext(Graph(), opts)
_visit_inputs(ctx, ckt, opts.flatten_all_tuples)
for inst in ckt.instances:
_visit_inputs(ctx, inst, opts.flatten_all_tuples)
_check_for_when_cycles(ctx.graph)
return ctx.graph
74 changes: 0 additions & 74 deletions magma/passes/split_when_utils.py

This file was deleted.

11 changes: 1 addition & 10 deletions magma/passes/when.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from magma.passes.passes import DefinitionPass, pass_lambda
from magma.passes.split_when_utils import split_when_cycles
from magma.primitives.when import WhenBuilder
from magma.when import WhenBlock


class WhenPass(DefinitionPass):
def __call__(self, definition):
with definition.open():
for builder in list(definition._context_._builders):
for builder in definition._context_._builders:
if isinstance(builder, WhenBuilder):
self.process_when_builder(builder, definition)

Expand All @@ -26,11 +24,6 @@ def process_when_builder(self, builder, defn):
builder.emit_when_assertions(self._flatten_all_tuples)


class SplitCycles(WhenPass):
def process_when_builder(self, builder, defn):
split_when_cycles(builder, defn)


class FinalizeWhens(WhenPass):
def process_when_builder(self, builder, defn):
assert not builder._finalized
Expand All @@ -39,7 +32,6 @@ def process_when_builder(self, builder, defn):

infer_latch_pass = pass_lambda(InferLatches)
emit_when_assert_pass = pass_lambda(EmitWhenAsserts)
split_cycles_pass = pass_lambda(SplitCycles)
finalize_when_pass = pass_lambda(FinalizeWhens)


Expand All @@ -51,5 +43,4 @@ def run_when_passes(
infer_latch_pass(main)
if emit_when_assertions:
emit_when_assert_pass(main, flatten_all_tuples)
split_cycles_pass(main)
finalize_when_pass(main)
4 changes: 0 additions & 4 deletions magma/primitives/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ def output_to_index(self) -> Dict[Type, int]:
def output_to_name(self):
return self._output_to_name

@property
def input_to_name(self):
return self._input_to_name

def check_existing_derived_ref(self, value, value_to_name, value_to_index):
"""If value is a child of an array or tuple that has already been added,
we return the child of the existing value, rather than adding a new
Expand Down
42 changes: 21 additions & 21 deletions magma/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ class ConditionalWire:


class _BlockBase(contextlib.AbstractContextManager):
def __init__(self, parent: Optional['WhenBlock']):
def __init__(self, parent: Optional['_WhenBlock']):
self._parent = parent
self._children = list()
self._conditional_wires = list()
self._default_drivers = dict()

def spawn(self, info: '_WhenBlockInfo') -> 'WhenBlock':
child = WhenBlock(self, info)
def spawn(self, info: '_WhenBlockInfo') -> '_WhenBlock':
child = _WhenBlock(self, info)
self._children.append(child)
return child

Expand Down Expand Up @@ -108,7 +108,7 @@ def add_conditional_wire(self, i, o):

@property
@abc.abstractmethod
def root(self) -> 'WhenBlock':
def root(self) -> '_WhenBlock':
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -252,7 +252,7 @@ class _ElseWhenBlockInfo(_WhenBlockInfo):
pass


class WhenBlock(_BlockBase):
class _WhenBlock(_BlockBase):
def __init__(
self,
parent: Optional[_BlockBase],
Expand All @@ -272,14 +272,14 @@ def __init__(
self._builder.debug_info = debug_info

def new_elsewhen_block(self, info: '_ElseWhenBlockInfo'):
block = ElseWhenBlock(self, info)
block = _ElseWhenBlock(self, info)
self._elsewhens.append(block)
return block

def new_otherwise_block(self):
if self._otherwise is not None:
raise WhenSyntaxError()
block = OtherwiseBlock(self)
block = _OtherwiseBlock(self)
self._otherwise = block
return block

Expand All @@ -290,7 +290,7 @@ def builder(self) -> Any:
return self._builder

@property
def root(self) -> 'WhenBlock':
def root(self) -> '_WhenBlock':
if self._parent is None:
return self
return self._parent.root
Expand All @@ -316,7 +316,7 @@ def _get_exit_block(self):
return self._parent


class ElseWhenBlock(_BlockBase):
class _ElseWhenBlock(_BlockBase):
def __init__(self, parent: Optional[_BlockBase], info: _ElseWhenBlockInfo):
super().__init__(parent)
self._info = info
Expand All @@ -328,7 +328,7 @@ def new_otherwise_block(self):
return self._parent.new_otherwise_block()

@property
def root(self) -> WhenBlock:
def root(self) -> _WhenBlock:
return self._parent.root

@property
Expand All @@ -346,15 +346,15 @@ def _get_exit_block(self):
return self._parent._get_exit_block()


class OtherwiseBlock(_BlockBase):
class _OtherwiseBlock(_BlockBase):
def new_elsewhen_block(self, _):
raise ElsewhenWithoutPrecedingWhenError()

def new_otherwise_block(self):
raise OtherwiseWithoutPrecedingWhenError()

@property
def root(self) -> WhenBlock:
def root(self) -> _WhenBlock:
return self._parent.root

@property
Expand Down Expand Up @@ -434,18 +434,18 @@ def _reset_context():


def _get_else_block(block: _BlockBase) -> Optional[_BlockBase]:
if isinstance(block, OtherwiseBlock):
if isinstance(block, _OtherwiseBlock):
return None
if isinstance(block, ElseWhenBlock):
# TODO(rsetaluri): We could augment ElseWhenBlock to keep track of its
if isinstance(block, _ElseWhenBlock):
# TODO(rsetaluri): We could augment _ElseWhenBlock to keep track of its
# index in the parent block, or alternatively keep an explicit pointer
# to the next else/otherwise block in each ElseWhenBlock.
# to the next else/otherwise block in each _ElseWhenBlock.
parent_elsewhen_blocks = list(block._parent.elsewhen_blocks())
index = parent_elsewhen_blocks.index(block) + 1
if index == len(parent_elsewhen_blocks):
return block._parent.otherwise_block
return parent_elsewhen_blocks[index]
if isinstance(block, WhenBlock):
if isinstance(block, _WhenBlock):
try:
else_block = next(block.elsewhen_blocks())
except StopIteration:
Expand All @@ -464,9 +464,9 @@ def _get_then_ops(block: _BlockBase) -> Iterable[_Op]:


def _get_else_ops(block: _BlockBase) -> Iterable[_Op]:
if isinstance(block, OtherwiseBlock):
if isinstance(block, _OtherwiseBlock):
return _get_then_ops(block)
if isinstance(block, ElseWhenBlock):
if isinstance(block, _ElseWhenBlock):
return (block,)
raise TypeError(block)

Expand Down Expand Up @@ -535,7 +535,7 @@ def _get_assignees_and_latches(ops: Iterable[_Op]) -> Tuple[Set, Set]:


def find_inferred_latches(block: _BlockBase) -> Set:
if not (isinstance(block, WhenBlock) and block.root is block):
if not (isinstance(block, _WhenBlock) and block.root is block):
raise TypeError("Can only find inferred latches on root when block")
ops = tuple(block.default_drivers()) + (block,)
_, latches = _get_assignees_and_latches(ops)
Expand Down Expand Up @@ -676,7 +676,7 @@ def when(cond):
info = _WhenBlockInfo(cond)
curr_block = _get_curr_block()
if curr_block is None:
return WhenBlock(None, info, get_debug_info(3))
return _WhenBlock(None, info, get_debug_info(3))
return curr_block.spawn(info)


Expand Down
3 changes: 0 additions & 3 deletions magma/wire_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,6 @@ def has_elaborated_children(self):
def _enumerate_children(self):
raise NotImplementedError()

def enumerate_children(self):
return self._enumerate_children()

def _get_conditional_drivee_info(self):
"""
* wired_when_contexts: list of contexts in which this value appears as
Expand Down
25 changes: 0 additions & 25 deletions tests/gold/test_when_alwcomb_order.v

This file was deleted.

34 changes: 0 additions & 34 deletions tests/gold/test_when_alwcomb_order_complex.v

This file was deleted.

Loading

0 comments on commit ad05e7e

Please sign in to comment.