[MLIR] when support for flatten_all_tuples #1253

Open
leonardt opened this issue Mar 22, 2023 · 1 comment

Comments

@leonardt
Collaborator

--- a/tests/test_when.py
+++ b/tests/test_when.py
@@ -342,7 +342,7 @@ def test_memory(T, bits_to_fault_value):
             io.out @= m.from_bits(T, m.Bits[8](0xFF))
 
     m.compile(f"build/test_when_memory_{T_str}", test_when_memory,
-              output="mlir", flatten_all_tuples=True)
+              output="mlir")
 
     if check_gold(__file__, f"test_when_memory_{T_str}.mlir"):
         return
@@ -377,7 +377,6 @@ def test_memory(T, bits_to_fault_value):
     tester.compile_and_run("verilator", magma_output="mlir-verilog",
                            directory=os.path.join(os.path.dirname(__file__),
                                                   "build"),
-                           magma_opts={"flatten_all_tuples": True},
                            flags=['-Wno-UNUSED'])
 
     update_gold(__file__, f"test_when_memory_{T_str}.mlir")

Removing the flag first requires a fix in the tuple pipeline (magma/backend/mlir/magma_ops.py):

--- a/magma/backend/mlir/magma_ops.py
+++ b/magma/backend/mlir/magma_ops.py
@@ -8,11 +8,11 @@ from magma.digital import DigitalMeta
 from magma.generator import Generator2
 from magma.interface import IO
 from magma.t import In, Out
-from magma.tuple import TupleMeta, ProductMeta
+from magma.tuple import TupleMeta, AnonymousProductMeta


 def _get_tuple_field_type(T: TupleMeta, index: Union[int, str]):
-    if isinstance(T, ProductMeta):
+    if isinstance(T, AnonymousProductMeta):
         return T.field_dict[index]
     index = int(index)
     return T.field_dict[index]

With that fix applied, compilation then fails with:

magma/compile.py:104: in compile
    result = compiler.compile()
magma/backend/mlir/mlir_compiler.py:58: in compile
    compile_to_mlir(
magma/backend/mlir/compile_to_mlir.py:24: in compile_to_mlir
    translation_unit.compile()
magma/backend/mlir/translation_unit.py:145: in compile
    hardware_module.compile()
magma/backend/mlir/hardware_module.py:1195: in compile
    self._hw_module = self._compile()
magma/backend/mlir/hardware_module.py:1254: in _compile
    visitor.visit(self._magma_defn_or_decl)
magma/backend/mlir/hardware_module.py:950: in visit
    self._process_magma_module(module)
magma/backend/mlir/hardware_module.py:939: in _process_magma_module
    self.visit(predecessor)
magma/backend/mlir/hardware_module.py:951: in visit
    self.visit_module(ModuleWrapper.make(module, self._ctx))
magma/backend/mlir/common.py:11: in wrapped
    ret = fn(*args, **kwargs)
magma/backend/mlir/hardware_module.py:929: in visit_module
    return self.visit_instance(module)
magma/backend/mlir/common.py:11: in wrapped
    ret = fn(*args, **kwargs)
magma/backend/mlir/hardware_module.py:850: in visit_instance
    return self.visit_primitive(module)
magma/backend/mlir/common.py:11: in wrapped
    ret = fn(*args, **kwargs)
magma/backend/mlir/hardware_module.py:677: in visit_primitive
    return self.visit_when(module)
magma/backend/mlir/common.py:11: in wrapped
    ret = fn(*args, **kwargs)
magma/backend/mlir/hardware_module.py:656: in visit_when
    return WhenCompiler(self, module).compile()
magma/backend/mlir/when_utils.py:282: in compile
    self._process_when_block(self._builder.block)
magma/backend/mlir/when_utils.py:265: in _process_when_block
    self._process_connections(block)
magma/backend/mlir/when_utils.py:246: in _process_connections
    self._make_assignments(connections)
magma/backend/mlir/when_utils.py:235: in _make_assignments
    for wire, value in self._build_wire_map(connections).items():
magma/backend/mlir/when_utils.py:174: in _build_wire_map
    operand = self._get_operand(driver_elt)
magma/backend/mlir/when_utils.py:46: in _get_operand
    return self._operands[self._get_input_index(value)]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <magma.backend.mlir.when_utils.WhenCompiler object at 0x7fa1d1258400>, value = Memory_inst0.RDATA.x

    def _get_input_index(self, value):
>       return self._input_to_index[value]
E       KeyError: Memory_inst0.RDATA.x
@rsetaluri
Collaborator

I looked into this issue briefly. From what I can tell, the problem comes from this snippet:

with m.when(io.en0):
    mem[io.addr0] @= io.data0
    io.out @= mem[io.addr1]
with m.elsewhen(io.en1):
    mem[io.addr1] @= io.data1
    io.out @= mem[io.addr0]
with m.otherwise():
    io.out @= m.from_bits(T, m.Bits[8](0xFF))

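It seems the otherwise case causes the tuple RDATA to be elaborated, which produces separate connections for the x and y fields. These per-field connections then cannot be reconciled with the keys of the operand map (_input_to_index), which contain the unflattened tuple value in this case. That is why the lookup of Memory_inst0.RDATA.x raises a KeyError.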
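A toy model of the mismatch, purely for illustration: it does not use magma, and it represents values as dotted path strings, which is an assumption. The hypothetical dictionary stands in for WhenCompiler._input_to_index and is keyed by the whole tuple, so the whole-tuple lookup succeeds while the per-field lookups coming from the elaborated branch fail, just like the traceback.

```python
from typing import Dict, List

# Hypothetical stand-in for WhenCompiler._input_to_index: operands are keyed
# by the unflattened tuple value (modeled here as a dotted path string).
input_to_index: Dict[str, int] = {"Memory_inst0.RDATA": 0}


def get_input_index(value: str) -> int:
    # Mirrors the failing line: a plain dict lookup with no fallback.
    return input_to_index[value]


def unresolved(drivers: List[str]) -> List[str]:
    """Return the drivers whose lookup raises KeyError."""
    missing = []
    for driver in drivers:
        try:
            get_input_index(driver)
        except KeyError:
            missing.append(driver)
    return missing


# The whole tuple resolves, but the per-field drivers produced by the
# elaborated otherwise branch do not.
print(unresolved(["Memory_inst0.RDATA"]))  # []
print(unresolved(["Memory_inst0.RDATA.x", "Memory_inst0.RDATA.y"]))
```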

One option is to add a pass over the when connections that consolidates all aggregate types, even if they have been elaborated. A potential complication is the case where the when structures differ between fields of the same aggregate.
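A rough sketch of what such a consolidation pass might look like, again on a toy model rather than magma's real data structures: connections within one when block are a hypothetical dict from target path to driver path, and aggregate layouts are passed in explicitly. A group of field connections is folded back into a single aggregate connection only when every field is driven in that block and field f of the target is driven by field f of one common source. Anything else is left elaborated, which is roughly where the complication mentioned above (fields with different when structures) would need real handling.

```python
from typing import Dict, List

# Hypothetical model: target path -> driver path for a single when block.
Connections = Dict[str, str]


def consolidate(connections: Connections,
                aggregates: Dict[str, List[str]]) -> Connections:
    """Fold per-field connections back into whole-aggregate connections.

    aggregates maps an aggregate target path to its field names. A group is
    folded only when every field is connected in this block and field f of
    the target is driven by field f of one common source.
    """
    result = dict(connections)
    for target, fields in aggregates.items():
        field_targets = [f"{target}.{f}" for f in fields]
        if not all(t in connections for t in field_targets):
            continue  # only partially driven in this block: leave elaborated
        drivers = [connections[t] for t in field_targets]
        sources = {d.rsplit(".", 1)[0] for d in drivers}
        if len(sources) != 1:
            continue  # fields are driven from different sources
        source = sources.pop()
        if drivers != [f"{source}.{f}" for f in fields]:
            continue  # fields are permuted or otherwise mismatched
        for t in field_targets:
            del result[t]
        result[target] = source
    return result


conns = {"io.out.x": "Memory_inst0.RDATA.x",
         "io.out.y": "Memory_inst0.RDATA.y"}
print(consolidate(conns, {"io.out": ["x", "y"]}))
```

Folding only per block keeps the pass conservative: a partially consolidated aggregate would reintroduce exactly the key mismatch that causes the KeyError.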
