Skip to content

Commit

Permalink
[Primitives] Add multiport memory (#1254)
Browse files Browse the repository at this point in the history
* [Primitives] Add multiport memory primitive

* [MLIR] Add support for multiport memory

---------

Co-authored-by: rsetaluri <[email protected]>
  • Loading branch information
leonardt and rsetaluri committed Mar 31, 2023
1 parent df5b4bb commit 2edf15c
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 4 deletions.
24 changes: 21 additions & 3 deletions magma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,28 @@ def set_mantle_target(t):
as_bits, from_bits
)
from magma.primitives import (
LUT, Mux, mux, dict_lookup, list_lookup,
Register, get_slice, set_slice, slice,
Memory, set_index, register, Wire,
# LUT primitives.
LUT,
# Mux primitives.
Mux,
mux,
# Lookup primitives.
dict_lookup,
list_lookup,
# Slice getter/setter primitives.
slice,
get_slice,
set_slice,
set_index,
# Memory primitives.
Memory,
MultiportMemory,
# Register primitives.
AbstractRegister,
Register,
register,
# Wire primitives.
Wire,
)

from magma.types import (BitPattern, Valid, ReadyValid, Consumer, Producer,
Expand Down
45 changes: 45 additions & 0 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
make_mem_reg,
make_mem_read,
emit_conditional_assign,
emit_conditional_assigns,
make_index_op,
collect_multiport_memory_operands,
make_multiport_memory_index_ops,
make_multiport_memory_read_ops,
)
from magma.backend.mlir.mlir import (
MlirType, MlirValue, MlirSymbol, MlirAttribute, MlirBlock, push_block,
Expand Down Expand Up @@ -58,6 +62,7 @@
has_default_linked_module,
get_default_linked_module,
)
from magma.primitives.multiport_memory import MultiportMemory
from magma.primitives.mux import Mux
from magma.primitives.register import Register
from magma.primitives.when import iswhen
Expand Down Expand Up @@ -674,6 +679,44 @@ def visit_array_slice(self, module: ModuleWrapper) -> bool:
def visit_when(self, module: ModuleWrapper) -> bool:
return WhenCompiler(self, module).compile()

@wrap_with_not_implemented_error
def visit_multiport_memory(self, module: ModuleWrapper) -> bool:
inst = module.module
defn = type(inst)
clk = module.operands[0]
read_ports_out = module.results
read_port_len = 1 + defn.has_read_enable
read_ports = collect_multiport_memory_operands(
module.operands, 1, defn.num_read_ports, read_port_len
)
write_ports = collect_multiport_memory_operands(
module.operands,
1 + defn.num_read_ports * read_port_len,
defn.num_write_ports,
3
)
elt_type = hw.InOutType(magma_type_to_mlir_type(defn.T))
mem = make_mem_reg(self._ctx, inst.name, defn.height, elt_type.T)
read_results = make_multiport_memory_index_ops(
self._ctx, read_ports, mem
)
read_targets = make_multiport_memory_read_ops(
self._ctx, read_results, read_ports, read_ports_out
)
write_results = make_multiport_memory_index_ops(
self._ctx, write_ports, mem
)
write_targets = (
(write_results[i], *write_ports[i][1:3])
for i in range(len(write_results))
)
always = sv.AlwaysFFOp(operands=[clk], clock_edge="posedge").body_block
with push_block(always):
emit_conditional_assigns(write_targets)
if len(read_targets):
emit_conditional_assigns(read_targets)
return True

@wrap_with_not_implemented_error
def visit_primitive(self, module: ModuleWrapper) -> bool:
inst = module.module
Expand All @@ -687,6 +730,8 @@ def visit_primitive(self, module: ModuleWrapper) -> bool:
return self.visit_coreir_primitive(module)
if defn.coreir_lib == "commonlib":
return self.visit_commonlib_primitive(module)
if isinstance(defn, MultiportMemory):
return self.visit_multiport_memory(module)
if isinstance(defn, InlineVerilogExpression):
assert len(module.operands) == 0
assert len(module.results) > 0
Expand Down
44 changes: 44 additions & 0 deletions magma/backend/mlir/mem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,47 @@ def make_index_op(ctx, value, idx):
result = ctx.new_value(hw.InOutType(value.type.T.T))
sv.ArrayIndexInOutOp(operands=[value, idx], results=[result])
return result


def collect_multiport_memory_operands(
operands, start_idx, num_ports, num_operands_per_port
):
"""Collect flat list of operands for read or write ports.
start_idx: offset in `operands` list to start from
num_ports: number of ports to collect
num_operands_per_port: number of operands per port
(e.g. 3 for waddr, wdata, wen)
"""
port_operands = []
curr_idx = start_idx
for i in range(num_ports):
port_operands.append(
operands[curr_idx:curr_idx + num_operands_per_port]
)
curr_idx += num_operands_per_port
return port_operands


def make_multiport_memory_index_ops(ctx, ports, mem):
"""For each port, emit an array index op and return a list of results."""
return [make_index_op(ctx, mem, port[0]) for port in ports]


def make_multiport_memory_read_ops(
ctx, read_results, read_ports, read_ports_out
):
"""If ren, emit an intermediate register to hold read value."""
read_targets = []
for i, (target, port) in enumerate(zip(read_results, read_ports)):
has_en = len(port) == 2
read_reg, read_temp = make_mem_read(
ctx, target, read_ports_out[i], has_en, f"read_reg_{i}"
)
if has_en:
read_targets.append((read_reg, read_temp, port[1]))
return read_targets


def emit_conditional_assigns(targets):
for target, data, en in targets:
emit_conditional_assign(target, data, en)
1 change: 1 addition & 0 deletions magma/primitives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from magma.primitives.lut import LUT
from magma.primitives.memory import Memory
from magma.primitives.multiport_memory import MultiportMemory
from magma.primitives.mux import Mux, mux, dict_lookup, list_lookup
from magma.primitives.register import Register, register, AbstractRegister
from magma.primitives.slice import get_slice, set_slice, slice
Expand Down
46 changes: 46 additions & 0 deletions magma/primitives/multiport_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from magma.bits import Bits
from magma.bitutils import clog2
from magma.clock import Enable
from magma.clock_io import ClockIO
from magma.generator import Generator2
from magma.interface import IO
from magma.t import In, Out, Kind


class MultiportMemory(Generator2):
def __init__(
self,
height: int,
T: Kind,
num_read_ports: int = 1,
num_write_ports: int = 1,
has_read_enable: bool = False
):
if num_read_ports < 1:
raise ValueError("At least one read port is required")
if num_write_ports < 0:
raise ValueError("Number of write ports must be non-negative")

self.num_read_ports = num_read_ports
self.has_read_enable = has_read_enable
self.num_write_ports = num_write_ports
self.T = T
self.height = height

addr_width = clog2(height)

self.io = ClockIO()
for i in range(num_read_ports):
self.io += IO(**{
f"RADDR_{i}": In(Bits[addr_width]),
f"RDATA_{i}": Out(T)
})
if has_read_enable:
self.io += IO(**{f"RE_{i}": In(Enable)})
for i in range(num_write_ports):
self.io += IO(**{
f"WADDR_{i}": In(Bits[addr_width]),
f"WDATA_{i}": In(T),
f"WE_{i}": In(Enable)
})
self.primitive = True
41 changes: 40 additions & 1 deletion tests/test_backend/test_mlir/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ class simple_memory_wrapper(m.Circuit):
)


m.passes.clock.WireClockPass(simple_memory_wrapper).run()


class sync_memory_wrapper(m.Circuit):
T = m.Bits[12]
height = 128
Expand All @@ -533,7 +536,43 @@ class sync_memory_wrapper(m.Circuit):
)


m.passes.clock.WireClockPass(simple_memory_wrapper).run()
def _make_multiport_memory(has_re: bool):

class _test_multiport_memory(m.Circuit):
name = "multiport_memory" + ("_re" if has_re else "")
io = m.IO(
raddr_0=m.In(m.Bits[2]),
rdata_0=m.Out(m.UInt[5]),
raddr_1=m.In(m.Bits[2]),
rdata_1=m.Out(m.UInt[5]),
waddr_0=m.In(m.Bits[2]),
wdata_0=m.In(m.UInt[5]),
we_0=m.In(m.Enable),
waddr_1=m.In(m.Bits[2]),
wdata_1=m.In(m.UInt[5]),
we_1=m.In(m.Enable),
clk=m.In(m.Clock)
)
if has_re:
for i in range(2):
io += m.IO(**{f"re_{i}": m.In(m.Enable)})

mem = m.MultiportMemory(4, m.UInt[5], 2, 2, has_re)()

for i in range(2):
m.wire(getattr(mem, f"RADDR_{i}"), getattr(io, f"raddr_{i}"))
m.wire(getattr(mem, f"RDATA_{i}"), getattr(io, f"rdata_{i}"))
if has_re:
m.wire(getattr(mem, f"RE_{i}"), getattr(io, f"re_{i}"))
m.wire(getattr(mem, f"WADDR_{i}"), getattr(io, f"waddr_{i}"))
m.wire(getattr(mem, f"WDATA_{i}"), getattr(io, f"wdata_{i}"))
m.wire(getattr(mem, f"WE_{i}"), getattr(io, f"we_{i}"))

return _test_multiport_memory


multiport_memory = _make_multiport_memory(has_re=False)
multiport_memory_re = _make_multiport_memory(has_re=True)


class simple_undriven_instances(m.Circuit):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_backend/test_mlir/golds/multiport_memory.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none"} {
hw.module @multiport_memory(%raddr_0: i2, %raddr_1: i2, %waddr_0: i2, %wdata_0: i5, %we_0: i1, %waddr_1: i2, %wdata_1: i5, %we_1: i1, %clk: i1) -> (rdata_0: i5, rdata_1: i5) {
%2 = sv.reg name "MultiportMemory_inst0" : !hw.inout<!hw.array<4xi5>>
%3 = sv.array_index_inout %2[%raddr_0] : !hw.inout<!hw.array<4xi5>>, i2
%4 = sv.array_index_inout %2[%raddr_1] : !hw.inout<!hw.array<4xi5>>, i2
%0 = sv.read_inout %3 : !hw.inout<i5>
%1 = sv.read_inout %4 : !hw.inout<i5>
%5 = sv.array_index_inout %2[%waddr_0] : !hw.inout<!hw.array<4xi5>>, i2
%6 = sv.array_index_inout %2[%waddr_1] : !hw.inout<!hw.array<4xi5>>, i2
sv.alwaysff(posedge %clk) {
sv.if %we_0 {
sv.passign %5, %wdata_0 : i5
}
sv.if %we_1 {
sv.passign %6, %wdata_1 : i5
}
}
hw.output %0, %1 : i5, i5
}
}
26 changes: 26 additions & 0 deletions tests/test_backend/test_mlir/golds/multiport_memory.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313
module multiport_memory(
input [1:0] raddr_0,
raddr_1,
waddr_0,
input [4:0] wdata_0,
input we_0,
input [1:0] waddr_1,
input [4:0] wdata_1,
input we_1,
clk,
output [4:0] rdata_0,
rdata_1
);

reg [3:0][4:0] MultiportMemory_inst0;
always_ff @(posedge clk) begin
if (we_0)
MultiportMemory_inst0[waddr_0] <= wdata_0;
if (we_1)
MultiportMemory_inst0[waddr_1] <= wdata_1;
end // always_ff @(posedge)
assign rdata_0 = MultiportMemory_inst0[raddr_0];
assign rdata_1 = MultiportMemory_inst0[raddr_1];
endmodule

30 changes: 30 additions & 0 deletions tests/test_backend/test_mlir/golds/multiport_memory_re.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none"} {
hw.module @multiport_memory_re(%raddr_0: i2, %raddr_1: i2, %waddr_0: i2, %wdata_0: i5, %we_0: i1, %waddr_1: i2, %wdata_1: i5, %we_1: i1, %clk: i1, %re_0: i1, %re_1: i1) -> (rdata_0: i5, rdata_1: i5) {
%2 = sv.reg name "MultiportMemory_inst0" : !hw.inout<!hw.array<4xi5>>
%3 = sv.array_index_inout %2[%raddr_0] : !hw.inout<!hw.array<4xi5>>, i2
%4 = sv.array_index_inout %2[%raddr_1] : !hw.inout<!hw.array<4xi5>>, i2
%5 = sv.read_inout %3 : !hw.inout<i5>
%6 = sv.reg name "read_reg_0" : !hw.inout<i5>
%0 = sv.read_inout %6 : !hw.inout<i5>
%7 = sv.read_inout %4 : !hw.inout<i5>
%8 = sv.reg name "read_reg_1" : !hw.inout<i5>
%1 = sv.read_inout %8 : !hw.inout<i5>
%9 = sv.array_index_inout %2[%waddr_0] : !hw.inout<!hw.array<4xi5>>, i2
%10 = sv.array_index_inout %2[%waddr_1] : !hw.inout<!hw.array<4xi5>>, i2
sv.alwaysff(posedge %clk) {
sv.if %we_0 {
sv.passign %9, %wdata_0 : i5
}
sv.if %we_1 {
sv.passign %10, %wdata_1 : i5
}
sv.if %re_0 {
sv.passign %6, %5 : i5
}
sv.if %re_1 {
sv.passign %8, %7 : i5
}
}
hw.output %0, %1 : i5, i5
}
}
34 changes: 34 additions & 0 deletions tests/test_backend/test_mlir/golds/multiport_memory_re.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313
module multiport_memory_re(
input [1:0] raddr_0,
raddr_1,
waddr_0,
input [4:0] wdata_0,
input we_0,
input [1:0] waddr_1,
input [4:0] wdata_1,
input we_1,
clk,
re_0,
re_1,
output [4:0] rdata_0,
rdata_1
);

reg [3:0][4:0] MultiportMemory_inst0;
reg [4:0] read_reg_0;
reg [4:0] read_reg_1;
always_ff @(posedge clk) begin
if (we_0)
MultiportMemory_inst0[waddr_0] <= wdata_0;
if (we_1)
MultiportMemory_inst0[waddr_1] <= wdata_1;
if (re_0)
read_reg_0 <= MultiportMemory_inst0[raddr_0];
if (re_1)
read_reg_1 <= MultiportMemory_inst0[raddr_1];
end // always_ff @(posedge)
assign rdata_0 = read_reg_0;
assign rdata_1 = read_reg_1;
endmodule

2 changes: 2 additions & 0 deletions tests/test_backend/test_mlir/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def get_local_examples() -> List[DefineCircuitKind]:
examples.complex_undriven,
examples.simple_memory_wrapper,
examples.sync_memory_wrapper,
examples.multiport_memory,
examples.multiport_memory_re,
examples.simple_undriven_instances,
examples.simple_neg,
examples.simple_array_slice,
Expand Down

0 comments on commit 2edf15c

Please sign in to comment.