Skip to content

Commit

Permalink
[When] Add support for unflattened tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
rsetaluri committed Sep 23, 2023
1 parent 4475c36 commit 14196a8
Show file tree
Hide file tree
Showing 41 changed files with 8,541 additions and 216 deletions.
6,299 changes: 6,299 additions & 0 deletions examples/riscv_mini/tests/gold/test_riscv_mini_unflattened_tuples.v

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions examples/riscv_mini/tests/test_unflattened_tuples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import magma as m
from magma.config import config as magma_config
from magma.testing.utils import check_gold

from riscv_mini.tile import Tile


def test_riscv_mini_unflattened_tuples():
magma_config.compile_dir = 'callee_file_dir'
tile = Tile(32)
m.compile(
"build/test_riscv_mini_unflattened_tuples",
tile,
output="mlir-verilog",
disallow_local_variables=True,
disallow_packed_struct_assignments=True
)
assert check_gold(__file__, "test_riscv_mini_unflattened_tuples.v")
2 changes: 2 additions & 0 deletions magma/backend/mlir/compile_to_mlir_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ class CompileToMlirOpts:
explicit_bitcast: bool = False
disallow_expression_inlining_in_ports: bool = False
disallow_local_variables: bool = False
disallow_packed_struct_assignments: bool = False
split_verilog: bool = False
omit_version_comment: bool = True
emit_when_assertions: bool = False
10 changes: 7 additions & 3 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def make_constant(
hw.ArrayCreateOp(operands=operands, results=[result])
return result
if isinstance(T, TupleMeta):
fields = T.field_dict.items()
fields = list(T.field_dict.items())
value = value if value is not None else {k: None for k, _ in fields}
operands = [self.make_constant(t, value[k]) for k, t in fields]
hw.StructCreateOp(operands=operands, results=[result])
Expand Down Expand Up @@ -342,6 +342,12 @@ def make_array_ref(self, arr: MlirValue, i: Union[int, tuple]) -> MlirValue:
hw.ArrayGetOp(operands=[arr, start], results=[operand])
return operand

@functools.lru_cache()
def make_struct_ref(self, struct: MlirValue, key: str) -> MlirValue:
operand = self.ctx.new_value(struct.type.get_field(key))
hw.StructExtractOp(field=key, operands=[struct], results=[operand])
return operand

def make_concat(self, operands, result):
"""Collect single elements and put them into an array create op,
this allows slices to be concatenated with the surround elements.
Expand Down Expand Up @@ -860,7 +866,6 @@ def visit_magma_xmr_sink(self, module: ModuleWrapper) -> bool:
else:
return True
paths = get_xmr_paths(self._ctx, xmr)
assert len(paths) == len(module.operands)
self._ctx.parent.xmr_paths[xmr] = paths
return True

Expand All @@ -871,7 +876,6 @@ def visit_magma_xmr_source(self, module: ModuleWrapper) -> bool:
assert isinstance(defn, XMRSource)
xmr = defn.value
paths = self._ctx.parent.xmr_paths[xmr]
assert len(paths) == len(module.results)
base = defn.value.parent_view.path()
for result, path in zip(module.results, paths):
in_out = self.ctx.new_value(hw.InOutType(result.type))
Expand Down
8 changes: 7 additions & 1 deletion magma/backend/mlir/hw.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, ClassVar, List, Mapping, Optional, Tuple
from typing import ClassVar, List, Optional, Tuple

from magma.backend.mlir.common import default_field
from magma.backend.mlir.mlir import (
Expand Down Expand Up @@ -44,6 +44,12 @@ def emit(self) -> str:
field_str = ", ".join(f"{k}: {t.emit()}" for k, t in self.fields)
return f"!hw.struct<{field_str}>"

def get_field(self, field: str) -> MlirType:
for key, value in self.fields:
if key == field:
return value
raise KeyError(f"Could not find struct key {field} in {self.fields}")


@dataclasses.dataclass(frozen=True)
class InOutType(MlirType):
Expand Down
3 changes: 1 addition & 2 deletions magma/backend/mlir/magma_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def MagmaTupleCreateOp(T: TupleMeta):
assert isinstance(T, TupleMeta)
T = T.undirected_t
name = f"magma_tuple_create_op_{value_or_type_to_string(T)}"
fields = T.field_dict
ports = {f"I{k}": In(t) for k, t in fields.items()}
ports = {f"I{k}": In(t) for k, t in T.field_dict.items()}
ports.update(dict(O=Out(T)))
return InstanceWrapper(name, ports, {})

Expand Down
8 changes: 6 additions & 2 deletions magma/backend/mlir/mlir_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,18 @@ def suffix(self):
return "mlir"

def run_pre_uniquification_passes(self):
if self.opts.get("flatten_all_tuples", False):
if self._compile_to_mlir_opts.flatten_all_tuples:
elaborate_tuples(self.main)
# NOTE(leonardt): when passes must happen after any
# passes that modify the circuit. This is because passes
# could introduce more conditional logic, or they could
# trigger elaboration on values used in existing coditiona
# logic (which modifies the when builder).
run_when_passes(self.main, self.opts.get("emit_when_assertions", False))
run_when_passes(
self.main,
self._compile_to_mlir_opts.flatten_all_tuples,
self._compile_to_mlir_opts.emit_when_assertions
)

def _run_passes(self):
raise_logs_as_exceptions_pass(self.main)
Expand Down
2 changes: 2 additions & 0 deletions magma/backend/mlir/translation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _set_module_attrs(mlir_module: builtin.ModuleOp, opts: CompileToMlirOpts):
lowering_options.append("disallowLocalVariables")
if opts.omit_version_comment:
lowering_options.append("omitVersionComment")
if opts.disallow_packed_struct_assignments:
lowering_options.append("disallowPackedStructAssignments")
if lowering_options:
mlir_module.attr_dict["circt.loweringOptions"] = builtin.StringAttr(
f"{','.join(lowering_options)}"
Expand Down
113 changes: 75 additions & 38 deletions magma/backend/mlir/when_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,25 @@
from magma.backend.mlir.sv import sv
from magma.common import sort_by_value
from magma.primitives.when import iswhen
from magma.ref import ArrayRef, TupleRef, DerivedRef
from magma.ref import TupleRef, DerivedRef
from magma.value_utils import make_selector


class _IndexBuilder:
"""Stages getitem/getattr calls by appending them to a tuple"""

def __init__(self):
self.index = tuple()

def __getitem__(self, idx):
self.index += (idx,)
return self

def __getattr__(self, idx):
self.index += (idx,)
return self


class WhenCompiler:
def __init__(self, module_visitor, module):
self._module_visitor = module_visitor
Expand Down Expand Up @@ -56,7 +71,10 @@ def _flatten_index_map(self, builder_map):

def _index_map_visit(value):
nonlocal counter, value_to_index
if isinstance(value.name, TupleRef) and value in value_to_index:
if (
self._flatten_all_tuples and
isinstance(value.name, TupleRef) and value in value_to_index
):
# tuples are flattened by the
# `visit_magma_value_or_value_wrapper_by_direction`, so we avoid
# adding them twice which invalidates the count logic
Expand Down Expand Up @@ -87,7 +105,7 @@ def _index_map_visit(value):
return value_to_index

def _make_output_wires(self):
"""Create the mlir values corresponding to each output"""
"""Create the mlir values corresponding to each output."""
wires = [
self._module_visitor.ctx.new_value(hw.InOutType(result.type))
for result in self._module.results
Expand All @@ -110,34 +128,27 @@ def _flatten_value(self, value):

def _get_parent(self, val, collection, to_index):
"""Search ancestor tree until we find either not an Array or we find a
an array that is in the index map"""
an array that is in the index map.
"""
for ref in val.name.root_iter(
stop_if=lambda ref: not isinstance(ref, ArrayRef)
stop_if=lambda ref: not isinstance(ref, DerivedRef)
):
try:
idx = to_index[ref.array]
idx = to_index[ref.parent_value]
except KeyError:
pass # try next parent
else:
return collection[idx], ref.array
return collection[idx], ref.parent_value
return None, None # didn't find parent

def _check_array_child_wire(self, val, collection, to_index):
"""If val is a child of an array in the index map, get the parent wire
(so we add to a collection of drivers for a bulk assign)
"""If val is a child of an array or tuple in the index map, get the
parent wire (so we add to a collection of drivers for a bulk assign).
"""
wire, parent = self._get_parent(val, collection, to_index)
if wire is None:
return None, None

class _IndexBuilder:
def __init__(self):
self.index = tuple()

def __getitem__(self, idx):
self.index += (idx, )
return self

builder = _IndexBuilder()
make_selector(val, stop_at=parent).select(builder)
return wire, builder.index
Expand All @@ -152,7 +163,11 @@ def _make_operand(self, wire, index):
# convert to tuple for hashing
i = (i.start, i.stop, i.step)
with push_block(self._outer_block):
operand = self._module_visitor.make_array_ref(wire, i)
if isinstance(wire.type, (hw.ArrayType, builtin.IntegerType)):
operand = self._module_visitor.make_array_ref(wire, i)
else:
assert isinstance(wire.type, hw.StructType)
operand = self._module_visitor.make_struct_ref(wire, i)
return self._make_operand(operand, index[1:])

def _build_wire_map(self, connections):
Expand Down Expand Up @@ -188,37 +203,57 @@ def _build_wire_map(self, connections):
wire_map[drivee_wire] = operand
return wire_map

def _make_arr_list(self, T):
"""Create a nested list structure matching the dimensions of T, used to
populate the elements of an array create op"""
def _make_recursive_collection(self, T):
"""Create a nested data structure matching T, used to populate the
elements of an array or struct create op.
"""
if isinstance(T, builtin.IntegerType):
return [None for _ in range(T.n)]
if isinstance(T, hw.StructType):
return {k: self._make_recursive_collection(v) for k, v in T.fields}
assert isinstance(T, hw.ArrayType), T
return [self._make_arr_list(T.T) for _ in range(T.dims[0])]
return [self._make_recursive_collection(T.T) for _ in range(T.dims[0])]

def _build_array_value(self, T, value):
"""Unpack the contents of value into a nested list structure"""
# TODO(leonardt): we could use an ndarray here, would simplify indexing
arr = self._make_arr_list(T)
def _populate_recursive_collection(self, T, value):
"""Unpack the contents of value into the corresponding nested structure
create in `_make_recursive_collection`.
"""
arr = self._make_recursive_collection(T)
for idx, elem in value.items():
curr = arr
for i in idx[:-1]: # descend up to last index
curr = curr[i]
curr[idx[-1]] = elem # use last index for setitem
return arr

def _combine_array_assign(self, T, value):
"""Sort drivers by index, use concat or create depending on type"""
def _create_struct_from_collection(self, T, value, result):
value = [
self._create_from_recursive_collection(v, value[k])
for k, v in sorted(T.fields, key=lambda x: x[0])
]
hw.StructCreateOp(
operands=value,
results=[result]
)
return result

def _create_from_recursive_collection(self, T, value):
"""Sort drivers by index, use concat or create depending on type."""
if isinstance(value, MlirValue):
return value # found whole value, no need to combine
if all(x is None for x in value):
return None # found empty value, covered by previous slice
result = self._module_visitor.ctx.new_value(T)
if not isinstance(T, builtin.IntegerType):
# recursive combine children
if isinstance(T, hw.StructType):
return self._create_struct_from_collection(T, value, result)
assert isinstance(T, hw.ArrayType)
assert len(T.dims) == 1, "Expected 1d array"
value = [self._combine_array_assign(T.T, value[i])
for i in range(T.dims[0])]
value = [
self._create_from_recursive_collection(T.T, value[i])
for i in range(T.dims[0])
]
# Filter None elements (indices covered by a previous slice)
value = [x for x in reversed(value) if x is not None]
self._module_visitor.make_concat(value, result)
Expand All @@ -228,14 +263,16 @@ def _make_assignments(self, connections):
"""
* _build_wire_map: contructs mapping from output wire to driver
* _build_array_value,
_combine_array_assign: handle collection elaborated drivers for a bulk
assign
* _populate_recursive_collection,
_create_from_recursive_collection: handle collection of elaborated
drivers for a bulk assign
"""
for wire, value in self._build_wire_map(connections).items():
if isinstance(value, dict):
value = self._build_array_value(wire.type.T, value)
value = self._combine_array_assign(wire.type.T, value)
value = self._populate_recursive_collection(wire.type.T, value)
value = self._create_from_recursive_collection(
wire.type.T, value
)
sv.BPAssignOp(operands=[wire, value])

def _process_connections(self, block):
Expand All @@ -250,10 +287,10 @@ def _process_connections(self, block):
def _process_when_block(self, block):
"""
If no condition, we are in an otherwise case and simply emit the
block body (which is inside a previous IfOp)
block body (which is inside a previous IfOp).
Otherwise, we emit an IfOp with the true body corresponding to this
block, then process the sibilings in the else block
block, then process the sibilings in the else block.
"""
if block.condition is None:
return self._process_connections(block)
Expand All @@ -276,7 +313,7 @@ def _process_when_block(self, block):
return if_op

def compile(self):
"""Emit default drivers then process the when block chain"""
"""Emit default drivers then process the when block chain."""
with push_block(sv.AlwaysCombOp().body_block):
self._make_assignments(self._builder.default_drivers.items())
self._process_when_block(self._builder.block)
Expand Down
12 changes: 7 additions & 5 deletions magma/backend/mlir/xmr_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Tuple

from magma.backend.mlir.magma_common import (
value_or_value_wrapper_to_tree as magma_value_or_value_wrapper_to_tree
value_or_value_wrapper_to_tree as magma_value_or_value_wrapper_to_tree,
tuple_key_to_str
)
from magma.ref import TupleRef, ArrayRef
from magma.view import PortView
Expand Down Expand Up @@ -54,7 +55,7 @@ def _ascend_to_leaf(value, leaves):
ref = value.name
if isinstance(ref, TupleRef):
path = _ascend_to_leaf(ref.tuple, leaves)
return path + [f".{ref.index}"]
return path + [f".{tuple_key_to_str(ref.index, type(ref.tuple))}"]
if isinstance(ref, ArrayRef):
path = _ascend_to_leaf(ref.array, leaves)
return path + [f"[{ref.index}]"]
Expand Down Expand Up @@ -95,21 +96,22 @@ def get_xmr_paths(ctx: 'HardwareModule', xmr: PortView) -> List[Tuple[str]]:
root, flatten_all_tuples=ctx.opts.flatten_all_tuples)
assert tree.has_node(root)

separator = "_" if ctx.opts.flatten_all_tuples else "."
# (1)
if tree.has_node(xmr.port): # visited
path = _get_path_string(tree, xmr.port, "_")
path = _get_path_string(tree, xmr.port, separator)
# (1a)
if tree.out_degree(xmr.port) == 0: # is leaf
return [(path,)]
# (1b)
leaves = _get_leaf_descendants(tree, xmr.port)
leaves = sorted(leaves, key=lambda n: tree.nodes[n]["index"])
return [
(_get_path_string(tree, leaf, "_"),)
(_get_path_string(tree, leaf, separator),)
for leaf in leaves
]
# (2)
leaves = list(_get_leaf_descendants(tree, root, include_self=True))
leaf, *path = _ascend_to_leaf(xmr.port, leaves)
path = _get_path_string(tree, leaf, "_") + "".join(path)
path = _get_path_string(tree, leaf, separator) + "".join(path)
return [(path,)]
Loading

0 comments on commit 14196a8

Please sign in to comment.