From 2edf15cc758480ff431f01ade13afdef414c0e88 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Fri, 31 Mar 2023 12:31:55 -0700 Subject: [PATCH] [Primitives] Add multiport memory (#1254) * [Primitives] Add multiport memory primitive * [MLIR] Add support for multiport memory --------- Co-authored-by: rsetaluri --- magma/__init__.py | 24 ++++++++-- magma/backend/mlir/hardware_module.py | 45 ++++++++++++++++++ magma/backend/mlir/mem_utils.py | 44 ++++++++++++++++++ magma/primitives/__init__.py | 1 + magma/primitives/multiport_memory.py | 46 +++++++++++++++++++ tests/test_backend/test_mlir/examples.py | 41 ++++++++++++++++- .../test_mlir/golds/multiport_memory.mlir | 20 ++++++++ .../test_mlir/golds/multiport_memory.v | 26 +++++++++++ .../test_mlir/golds/multiport_memory_re.mlir | 30 ++++++++++++ .../test_mlir/golds/multiport_memory_re.v | 34 ++++++++++++++ tests/test_backend/test_mlir/test_utils.py | 2 + 11 files changed, 309 insertions(+), 4 deletions(-) create mode 100644 magma/primitives/multiport_memory.py create mode 100644 tests/test_backend/test_mlir/golds/multiport_memory.mlir create mode 100644 tests/test_backend/test_mlir/golds/multiport_memory.v create mode 100644 tests/test_backend/test_mlir/golds/multiport_memory_re.mlir create mode 100644 tests/test_backend/test_mlir/golds/multiport_memory_re.v diff --git a/magma/__init__.py b/magma/__init__.py index f617dcadb..573328f87 100644 --- a/magma/__init__.py +++ b/magma/__init__.py @@ -102,10 +102,28 @@ def set_mantle_target(t): as_bits, from_bits ) from magma.primitives import ( - LUT, Mux, mux, dict_lookup, list_lookup, - Register, get_slice, set_slice, slice, - Memory, set_index, register, Wire, + # LUT primitives. + LUT, + # Mux primitives. + Mux, + mux, + # Lookup primitives. + dict_lookup, + list_lookup, + # Slice getter/setter primitives. + slice, + get_slice, + set_slice, + set_index, + # Memory primitives. + Memory, + MultiportMemory, + # Register primitives. AbstractRegister, + Register, + register, + # Wire primitives. + Wire, ) from magma.types import (BitPattern, Valid, ReadyValid, Consumer, Producer, diff --git a/magma/backend/mlir/hardware_module.py b/magma/backend/mlir/hardware_module.py index 610923772..6dd3343ef 100644 --- a/magma/backend/mlir/hardware_module.py +++ b/magma/backend/mlir/hardware_module.py @@ -30,7 +30,11 @@ make_mem_reg, make_mem_read, emit_conditional_assign, + emit_conditional_assigns, make_index_op, + collect_multiport_memory_operands, + make_multiport_memory_index_ops, + make_multiport_memory_read_ops, ) from magma.backend.mlir.mlir import ( MlirType, MlirValue, MlirSymbol, MlirAttribute, MlirBlock, push_block, @@ -58,6 +62,7 @@ has_default_linked_module, get_default_linked_module, ) +from magma.primitives.multiport_memory import MultiportMemory from magma.primitives.mux import Mux from magma.primitives.register import Register from magma.primitives.when import iswhen @@ -674,6 +679,44 @@ def visit_array_slice(self, module: ModuleWrapper) -> bool: def visit_when(self, module: ModuleWrapper) -> bool: return WhenCompiler(self, module).compile() + @wrap_with_not_implemented_error + def visit_multiport_memory(self, module: ModuleWrapper) -> bool: + inst = module.module + defn = type(inst) + clk = module.operands[0] + read_ports_out = module.results + read_port_len = 1 + defn.has_read_enable + read_ports = collect_multiport_memory_operands( + module.operands, 1, defn.num_read_ports, read_port_len + ) + write_ports = collect_multiport_memory_operands( + module.operands, + 1 + defn.num_read_ports * read_port_len, + defn.num_write_ports, + 3 + ) + elt_type = hw.InOutType(magma_type_to_mlir_type(defn.T)) + mem = make_mem_reg(self._ctx, inst.name, defn.height, elt_type.T) + read_results = make_multiport_memory_index_ops( + self._ctx, read_ports, mem + ) + read_targets = make_multiport_memory_read_ops( + self._ctx, read_results, read_ports, read_ports_out + ) + write_results = make_multiport_memory_index_ops( + self._ctx, write_ports, mem + ) + write_targets = ( + (write_results[i], *write_ports[i][1:3]) + for i in range(len(write_results)) + ) + always = sv.AlwaysFFOp(operands=[clk], clock_edge="posedge").body_block + with push_block(always): + emit_conditional_assigns(write_targets) + if len(read_targets): + emit_conditional_assigns(read_targets) + return True + @wrap_with_not_implemented_error def visit_primitive(self, module: ModuleWrapper) -> bool: inst = module.module @@ -687,6 +730,8 @@ def visit_primitive(self, module: ModuleWrapper) -> bool: return self.visit_coreir_primitive(module) if defn.coreir_lib == "commonlib": return self.visit_commonlib_primitive(module) + if isinstance(defn, MultiportMemory): + return self.visit_multiport_memory(module) if isinstance(defn, InlineVerilogExpression): assert len(module.operands) == 0 assert len(module.results) > 0 diff --git a/magma/backend/mlir/mem_utils.py b/magma/backend/mlir/mem_utils.py index 35374d62e..b2d3438ec 100644 --- a/magma/backend/mlir/mem_utils.py +++ b/magma/backend/mlir/mem_utils.py @@ -31,3 +31,47 @@ def make_index_op(ctx, value, idx): result = ctx.new_value(hw.InOutType(value.type.T.T)) sv.ArrayIndexInOutOp(operands=[value, idx], results=[result]) return result + + +def collect_multiport_memory_operands( + operands, start_idx, num_ports, num_operands_per_port +): + """Collect flat list of operands for read or write ports. + start_idx: offset in `operands` list to start from + num_ports: number of ports to collect + num_operands_per_port: number of operands per port + (e.g. 3 for waddr, wdata, wen) + """ + port_operands = [] + curr_idx = start_idx + for i in range(num_ports): + port_operands.append( + operands[curr_idx:curr_idx + num_operands_per_port] + ) + curr_idx += num_operands_per_port + return port_operands + + +def make_multiport_memory_index_ops(ctx, ports, mem): + """For each port, emit an array index op and return a list of results.""" + return [make_index_op(ctx, mem, port[0]) for port in ports] + + +def make_multiport_memory_read_ops( + ctx, read_results, read_ports, read_ports_out +): + """If ren, emit an intermediate register to hold read value.""" + read_targets = [] + for i, (target, port) in enumerate(zip(read_results, read_ports)): + has_en = len(port) == 2 + read_reg, read_temp = make_mem_read( + ctx, target, read_ports_out[i], has_en, f"read_reg_{i}" + ) + if has_en: + read_targets.append((read_reg, read_temp, port[1])) + return read_targets + + +def emit_conditional_assigns(targets): + for target, data, en in targets: + emit_conditional_assign(target, data, en) diff --git a/magma/primitives/__init__.py b/magma/primitives/__init__.py index 665bdd412..c15a6ef91 100644 --- a/magma/primitives/__init__.py +++ b/magma/primitives/__init__.py @@ -1,5 +1,6 @@ from magma.primitives.lut import LUT from magma.primitives.memory import Memory +from magma.primitives.multiport_memory import MultiportMemory from magma.primitives.mux import Mux, mux, dict_lookup, list_lookup from magma.primitives.register import Register, register, AbstractRegister from magma.primitives.slice import get_slice, set_slice, slice diff --git a/magma/primitives/multiport_memory.py b/magma/primitives/multiport_memory.py new file mode 100644 index 000000000..48b2d6bc3 --- /dev/null +++ b/magma/primitives/multiport_memory.py @@ -0,0 +1,46 @@ +from magma.bits import Bits +from magma.bitutils import clog2 +from magma.clock import Enable +from magma.clock_io import ClockIO +from magma.generator import Generator2 +from magma.interface import IO +from magma.t import In, Out, Kind + + +class MultiportMemory(Generator2): + def __init__( + self, + height: int, + T: Kind, + num_read_ports: int = 1, + num_write_ports: int = 1, + has_read_enable: bool = False + ): + if num_read_ports < 1: + raise ValueError("At least one read port is required") + if num_write_ports < 0: + raise ValueError("Number of write ports must be non-negative") + + self.num_read_ports = num_read_ports + self.has_read_enable = has_read_enable + self.num_write_ports = num_write_ports + self.T = T + self.height = height + + addr_width = clog2(height) + + self.io = ClockIO() + for i in range(num_read_ports): + self.io += IO(**{ + f"RADDR_{i}": In(Bits[addr_width]), + f"RDATA_{i}": Out(T) + }) + if has_read_enable: + self.io += IO(**{f"RE_{i}": In(Enable)}) + for i in range(num_write_ports): + self.io += IO(**{ + f"WADDR_{i}": In(Bits[addr_width]), + f"WDATA_{i}": In(T), + f"WE_{i}": In(Enable) + }) + self.primitive = True diff --git a/tests/test_backend/test_mlir/examples.py b/tests/test_backend/test_mlir/examples.py index 319dac980..5e4df092e 100644 --- a/tests/test_backend/test_mlir/examples.py +++ b/tests/test_backend/test_mlir/examples.py @@ -511,6 +511,9 @@ class simple_memory_wrapper(m.Circuit): ) +m.passes.clock.WireClockPass(simple_memory_wrapper).run() + + class sync_memory_wrapper(m.Circuit): T = m.Bits[12] height = 128 @@ -533,7 +536,43 @@ class sync_memory_wrapper(m.Circuit): ) -m.passes.clock.WireClockPass(simple_memory_wrapper).run() +def _make_multiport_memory(has_re: bool): + + class _test_multiport_memory(m.Circuit): + name = "multiport_memory" + ("_re" if has_re else "") + io = m.IO( + raddr_0=m.In(m.Bits[2]), + rdata_0=m.Out(m.UInt[5]), + raddr_1=m.In(m.Bits[2]), + rdata_1=m.Out(m.UInt[5]), + waddr_0=m.In(m.Bits[2]), + wdata_0=m.In(m.UInt[5]), + we_0=m.In(m.Enable), + waddr_1=m.In(m.Bits[2]), + wdata_1=m.In(m.UInt[5]), + we_1=m.In(m.Enable), + clk=m.In(m.Clock) + ) + if has_re: + for i in range(2): + io += m.IO(**{f"re_{i}": m.In(m.Enable)}) + + mem = m.MultiportMemory(4, m.UInt[5], 2, 2, has_re)() + + for i in range(2): + m.wire(getattr(mem, f"RADDR_{i}"), getattr(io, f"raddr_{i}")) + m.wire(getattr(mem, f"RDATA_{i}"), getattr(io, f"rdata_{i}")) + if has_re: + m.wire(getattr(mem, f"RE_{i}"), getattr(io, f"re_{i}")) + m.wire(getattr(mem, f"WADDR_{i}"), getattr(io, f"waddr_{i}")) + m.wire(getattr(mem, f"WDATA_{i}"), getattr(io, f"wdata_{i}")) + m.wire(getattr(mem, f"WE_{i}"), getattr(io, f"we_{i}")) + + return _test_multiport_memory + + +multiport_memory = _make_multiport_memory(has_re=False) +multiport_memory_re = _make_multiport_memory(has_re=True) class simple_undriven_instances(m.Circuit): diff --git a/tests/test_backend/test_mlir/golds/multiport_memory.mlir b/tests/test_backend/test_mlir/golds/multiport_memory.mlir new file mode 100644 index 000000000..8888abaf4 --- /dev/null +++ b/tests/test_backend/test_mlir/golds/multiport_memory.mlir @@ -0,0 +1,20 @@ +module attributes {circt.loweringOptions = "locationInfoStyle=none"} { + hw.module @multiport_memory(%raddr_0: i2, %raddr_1: i2, %waddr_0: i2, %wdata_0: i5, %we_0: i1, %waddr_1: i2, %wdata_1: i5, %we_1: i1, %clk: i1) -> (rdata_0: i5, rdata_1: i5) { + %2 = sv.reg name "MultiportMemory_inst0" : !hw.inout> + %3 = sv.array_index_inout %2[%raddr_0] : !hw.inout>, i2 + %4 = sv.array_index_inout %2[%raddr_1] : !hw.inout>, i2 + %0 = sv.read_inout %3 : !hw.inout + %1 = sv.read_inout %4 : !hw.inout + %5 = sv.array_index_inout %2[%waddr_0] : !hw.inout>, i2 + %6 = sv.array_index_inout %2[%waddr_1] : !hw.inout>, i2 + sv.alwaysff(posedge %clk) { + sv.if %we_0 { + sv.passign %5, %wdata_0 : i5 + } + sv.if %we_1 { + sv.passign %6, %wdata_1 : i5 + } + } + hw.output %0, %1 : i5, i5 + } +} diff --git a/tests/test_backend/test_mlir/golds/multiport_memory.v b/tests/test_backend/test_mlir/golds/multiport_memory.v new file mode 100644 index 000000000..5bc36e8d1 --- /dev/null +++ b/tests/test_backend/test_mlir/golds/multiport_memory.v @@ -0,0 +1,26 @@ +// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313 +module multiport_memory( + input [1:0] raddr_0, + raddr_1, + waddr_0, + input [4:0] wdata_0, + input we_0, + input [1:0] waddr_1, + input [4:0] wdata_1, + input we_1, + clk, + output [4:0] rdata_0, + rdata_1 +); + + reg [3:0][4:0] MultiportMemory_inst0; + always_ff @(posedge clk) begin + if (we_0) + MultiportMemory_inst0[waddr_0] <= wdata_0; + if (we_1) + MultiportMemory_inst0[waddr_1] <= wdata_1; + end // always_ff @(posedge) + assign rdata_0 = MultiportMemory_inst0[raddr_0]; + assign rdata_1 = MultiportMemory_inst0[raddr_1]; +endmodule + diff --git a/tests/test_backend/test_mlir/golds/multiport_memory_re.mlir b/tests/test_backend/test_mlir/golds/multiport_memory_re.mlir new file mode 100644 index 000000000..9cb685356 --- /dev/null +++ b/tests/test_backend/test_mlir/golds/multiport_memory_re.mlir @@ -0,0 +1,30 @@ +module attributes {circt.loweringOptions = "locationInfoStyle=none"} { + hw.module @multiport_memory_re(%raddr_0: i2, %raddr_1: i2, %waddr_0: i2, %wdata_0: i5, %we_0: i1, %waddr_1: i2, %wdata_1: i5, %we_1: i1, %clk: i1, %re_0: i1, %re_1: i1) -> (rdata_0: i5, rdata_1: i5) { + %2 = sv.reg name "MultiportMemory_inst0" : !hw.inout> + %3 = sv.array_index_inout %2[%raddr_0] : !hw.inout>, i2 + %4 = sv.array_index_inout %2[%raddr_1] : !hw.inout>, i2 + %5 = sv.read_inout %3 : !hw.inout + %6 = sv.reg name "read_reg_0" : !hw.inout + %0 = sv.read_inout %6 : !hw.inout + %7 = sv.read_inout %4 : !hw.inout + %8 = sv.reg name "read_reg_1" : !hw.inout + %1 = sv.read_inout %8 : !hw.inout + %9 = sv.array_index_inout %2[%waddr_0] : !hw.inout>, i2 + %10 = sv.array_index_inout %2[%waddr_1] : !hw.inout>, i2 + sv.alwaysff(posedge %clk) { + sv.if %we_0 { + sv.passign %9, %wdata_0 : i5 + } + sv.if %we_1 { + sv.passign %10, %wdata_1 : i5 + } + sv.if %re_0 { + sv.passign %6, %5 : i5 + } + sv.if %re_1 { + sv.passign %8, %7 : i5 + } + } + hw.output %0, %1 : i5, i5 + } +} diff --git a/tests/test_backend/test_mlir/golds/multiport_memory_re.v b/tests/test_backend/test_mlir/golds/multiport_memory_re.v new file mode 100644 index 000000000..ba4ea6f65 --- /dev/null +++ b/tests/test_backend/test_mlir/golds/multiport_memory_re.v @@ -0,0 +1,34 @@ +// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313 +module multiport_memory_re( + input [1:0] raddr_0, + raddr_1, + waddr_0, + input [4:0] wdata_0, + input we_0, + input [1:0] waddr_1, + input [4:0] wdata_1, + input we_1, + clk, + re_0, + re_1, + output [4:0] rdata_0, + rdata_1 +); + + reg [3:0][4:0] MultiportMemory_inst0; + reg [4:0] read_reg_0; + reg [4:0] read_reg_1; + always_ff @(posedge clk) begin + if (we_0) + MultiportMemory_inst0[waddr_0] <= wdata_0; + if (we_1) + MultiportMemory_inst0[waddr_1] <= wdata_1; + if (re_0) + read_reg_0 <= MultiportMemory_inst0[raddr_0]; + if (re_1) + read_reg_1 <= MultiportMemory_inst0[raddr_1]; + end // always_ff @(posedge) + assign rdata_0 = read_reg_0; + assign rdata_1 = read_reg_1; +endmodule + diff --git a/tests/test_backend/test_mlir/test_utils.py b/tests/test_backend/test_mlir/test_utils.py index 1a250a1c9..1f8189ab9 100644 --- a/tests/test_backend/test_mlir/test_utils.py +++ b/tests/test_backend/test_mlir/test_utils.py @@ -190,6 +190,8 @@ def get_local_examples() -> List[DefineCircuitKind]: examples.complex_undriven, examples.simple_memory_wrapper, examples.sync_memory_wrapper, + examples.multiport_memory, + examples.multiport_memory_re, examples.simple_undriven_instances, examples.simple_neg, examples.simple_array_slice,