Skip to content

Commit

Permalink
Add comb_mem_d2 and seq_mem_d2 to eDSL (#2099)
Browse files Browse the repository at this point in the history
* Add comb_mem_d2 and seq_mem_d2

* Fix 'is_comb' errors

* Use both comb_mem and seq_mem in matrix.py

* Add documentation
  • Loading branch information
polybeandip authored Jun 6, 2024
1 parent a831b04 commit 8ab0bda
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 13 deletions.
99 changes: 91 additions & 8 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ def comb_mem_d1(
name, ast.Stdlib.comb_mem_d1(bitwidth, len, idx_size), is_external, is_ref
)

def comb_mem_d2(
self,
name: str,
bitwidth: int,
len0: int,
len1: int,
idx_size0: int,
idx_size1: int,
is_external: bool = False,
is_ref: bool = False,
) -> CellBuilder:
"""Generate a StdMemD2 cell."""
self.prog.import_("primitives/memories/comb.futil")
return self.cell(
name, ast.Stdlib.comb_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref
)

def seq_mem_d1(
self,
name: str,
Expand All @@ -409,6 +426,23 @@ def seq_mem_d1(
name, ast.Stdlib.seq_mem_d1(bitwidth, len, idx_size), is_external, is_ref
)

def seq_mem_d2(
self,
name: str,
bitwidth: int,
len0: int,
len1: int,
idx_size0: int,
idx_size1: int,
is_external: bool = False,
is_ref: bool = False,
) -> CellBuilder:
"""Generate a SeqMemD2 cell."""
self.prog.import_("primitives/memories/seq.futil")
return self.cell(
name, ast.Stdlib.seq_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref
)

def binary(
self,
operation: str,
Expand Down Expand Up @@ -751,11 +785,12 @@ def reg_store(self, reg, val, groupname=None):
reg_grp.done = reg.done
return reg_grp

def mem_load_d1(self, mem, i, reg, groupname, is_comb=False):
def mem_load_d1(self, mem, i, reg, groupname):
"""Inserts wiring into `self` to perform `reg := mem[i]`,
where `mem` is a seq_d1 memory or a comb_mem_d1 memory (if `is_comb` is True)
where `mem` is a seq_d1 memory or a comb_mem_d1 memory
"""
assert mem.is_seq_mem_d1() if not is_comb else mem.is_comb_mem_d1()
assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1()
is_comb = mem.is_comb_mem_d1()
with self.group(groupname) as load_grp:
mem.addr0 = i
if is_comb:
Expand All @@ -768,11 +803,31 @@ def mem_load_d1(self, mem, i, reg, groupname, is_comb=False):
load_grp.done = reg.done
return load_grp

def mem_store_d1(self, mem, i, val, groupname, is_comb=False):
def mem_load_d2(self, mem, i, j, reg, groupname):
"""Inserts wiring into `self` to perform `reg := mem[i]`,
where `mem` is a seq_d2 memory or a comb_mem_d2 memory
"""
assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2()
is_comb = mem.is_comb_mem_d2()
with self.group(groupname) as load_grp:
mem.addr0 = i
mem.addr1 = j
if is_comb:
reg.write_en = 1
reg.in_ = mem.read_data
else:
mem.content_en = 1
reg.write_en = mem.done @ 1
reg.in_ = mem.done @ mem.read_data
load_grp.done = reg.done
return load_grp

def mem_store_d1(self, mem, i, val, groupname):
"""Inserts wiring into `self` to perform `mem[i] := val`,
where `mem` is a seq_d1 memory or a comb_mem_d1 memory (if `is_comb` is True)
where `mem` is a seq_d1 memory or a comb_mem_d1 memory
"""
assert mem.is_seq_mem_d1() if not is_comb else mem.is_comb_mem_d1()
assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1()
is_comb = mem.is_comb_mem_d1()
with self.group(groupname) as store_grp:
mem.addr0 = i
mem.write_en = 1
Expand All @@ -782,6 +837,22 @@ def mem_store_d1(self, mem, i, val, groupname, is_comb=False):
mem.content_en = 1
return store_grp

def mem_store_d2(self, mem, i, j, val, groupname):
"""Inserts wiring into `self` to perform `mem[i] := val`,
where `mem` is a seq_d2 memory or a comb_mem_d2 memory
"""
assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2()
is_comb = mem.is_comb_mem_d2()
with self.group(groupname) as store_grp:
mem.addr0 = i
mem.addr1 = j
mem.write_en = 1
mem.write_data = val
store_grp.done = mem.done
if not is_comb:
mem.content_en = 1
return store_grp

def mem_load_to_mem(self, mem, i, ans, j, groupname):
"""Inserts wiring into `self` to perform `ans[j] := mem[i]`,
where `mem` and `ans` are both comb_mem_d1 memories.
Expand Down Expand Up @@ -1237,10 +1308,18 @@ def is_comb_mem_d1(self) -> bool:
"""Check if the cell is a StdMemD1 cell."""
return self.is_primitive("comb_mem_d1")

def is_comb_mem_d2(self) -> bool:
"""Check if the cell is a StdMemD2 cell."""
return self.is_primitive("comb_mem_d2")

def is_seq_mem_d1(self) -> bool:
"""Check if the cell is a SeqMemD1 cell."""
return self.is_primitive("seq_mem_d1")

def is_seq_mem_d2(self) -> bool:
"""Check if the cell is a SeqMemD2 cell."""
return self.is_primitive("seq_mem_d2")

def infer_width_reg(self) -> int:
"""Infer the width of a register. That is, the width of `reg.in`."""
assert self._cell.comp.id == "std_reg", "Cell is not a register"
Expand Down Expand Up @@ -1273,14 +1352,18 @@ def infer_width(self, port_name) -> int:
):
if port_name in ("left", "right"):
return inst.args[0]
if prim in ("comb_mem_d1", "seq_mem_d1"):
if prim in ("comb_mem_d1", "seq_mem_d1", "comb_mem_d2", "seq_mem_d2"):
if port_name == "write_en":
return 1
if "d2" in prim and port_name == "addr0":
return inst.args[3]
if "d2" in prim and port_name == "addr1":
return inst.args[4]
if port_name == "addr0":
return inst.args[2]
if port_name == "in":
return inst.args[0]
if prim == "seq_mem_d1" and port_name == "content_en":
if "seq_mem" in prim and port_name == "content_en":
return 1
if prim in (
"std_mult_pipe",
Expand Down
8 changes: 4 additions & 4 deletions calyx-py/calyx/gen_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,9 @@ def build_base_not_e(degree, width, int_width, is_signed) -> Program:
ret = main.comb_mem_d1("ret", width, 1, 1, is_external=True)
f = main.comp_instance("f", "fp_pow_full")

read_base = main.mem_load_d1(b, 0, base_reg, "read_base", is_comb=True)
read_exp = main.mem_load_d1(x, 0, exp_reg, "read_exp", is_comb=True)
write_to_memory = main.mem_store_d1(ret, 0, f.out, "write_to_memory", is_comb=True)
read_base = main.mem_load_d1(b, 0, base_reg, "read_base")
read_exp = main.mem_load_d1(x, 0, exp_reg, "read_exp")
write_to_memory = main.mem_store_d1(ret, 0, f.out, "write_to_memory")

main.control += [
read_base,
Expand Down Expand Up @@ -741,7 +741,7 @@ def build_base_is_e(degree, width, int_width, is_signed) -> Program:
t.write_en = 1
init.done = t.done

write_to_memory = main.mem_store_d1(ret, 0, e.out, "write_to_memory", is_comb=True)
write_to_memory = main.mem_store_d1(ret, 0, e.out, "write_to_memory")

main.control += [
init,
Expand Down
101 changes: 101 additions & 0 deletions calyx-py/test/correctness/matrix.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"A": {
"data": [
[
1,
0,
4,
6
],
[
2,
5,
0,
3
],
[
1,
2,
3,
5
],
[
2,
1,
2,
3
]
],
"format": {
"numeric_type": "bitnum",
"is_signed": false,
"width": 32
}
},
"B": {
"data": [
[
1,
0,
4,
6
],
[
2,
5,
0,
3
],
[
1,
2,
3,
5
],
[
2,
1,
2,
3
]
],
"format": {
"numeric_type": "bitnum",
"is_signed": false,
"width": 32
}
},
"C": {
"data": [
[
0,
0,
0,
0
],
[
0,
0,
0,
0
],
[
0,
0,
0,
0
],
[
0,
0,
0,
0
]
],
"format": {
"numeric_type": "bitnum",
"is_signed": false,
"width": 32
}
}
}
80 changes: 80 additions & 0 deletions calyx-py/test/correctness/matrix.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"A": [
[
1,
0,
4,
6
],
[
2,
5,
0,
3
],
[
1,
2,
3,
5
],
[
2,
1,
2,
3
]
],
"B": [
[
1,
0,
4,
6
],
[
2,
5,
0,
3
],
[
1,
2,
3,
5
],
[
2,
1,
2,
3
]
],
"C": [
[
17,
14,
28,
44
],
[
18,
28,
14,
36
],
[
18,
21,
23,
42
],
[
12,
12,
20,
34
]
]
}
Loading

0 comments on commit 8ab0bda

Please sign in to comment.