diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index 852bcc20f3..3c2eea81eb 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -394,6 +394,23 @@ def comb_mem_d1( name, ast.Stdlib.comb_mem_d1(bitwidth, len, idx_size), is_external, is_ref ) + def comb_mem_d2( + self, + name: str, + bitwidth: int, + len0: int, + len1: int, + idx_size0: int, + idx_size1: int, + is_external: bool = False, + is_ref: bool = False, + ) -> CellBuilder: + """Generate a StdMemD2 cell.""" + self.prog.import_("primitives/memories/comb.futil") + return self.cell( + name, ast.Stdlib.comb_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref + ) + def seq_mem_d1( self, name: str, @@ -409,6 +426,23 @@ def seq_mem_d1( name, ast.Stdlib.seq_mem_d1(bitwidth, len, idx_size), is_external, is_ref ) + def seq_mem_d2( + self, + name: str, + bitwidth: int, + len0: int, + len1: int, + idx_size0: int, + idx_size1: int, + is_external: bool = False, + is_ref: bool = False, + ) -> CellBuilder: + """Generate a SeqMemD2 cell.""" + self.prog.import_("primitives/memories/seq.futil") + return self.cell( + name, ast.Stdlib.seq_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref + ) + def binary( self, operation: str, @@ -751,11 +785,12 @@ def reg_store(self, reg, val, groupname=None): reg_grp.done = reg.done return reg_grp - def mem_load_d1(self, mem, i, reg, groupname, is_comb=False): + def mem_load_d1(self, mem, i, reg, groupname): """Inserts wiring into `self` to perform `reg := mem[i]`, - where `mem` is a seq_d1 memory or a comb_mem_d1 memory (if `is_comb` is True) + where `mem` is a seq_d1 memory or a comb_mem_d1 memory """ - assert mem.is_seq_mem_d1() if not is_comb else mem.is_comb_mem_d1() + assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() + is_comb = mem.is_comb_mem_d1() with self.group(groupname) as load_grp: mem.addr0 = i if is_comb: @@ -768,11 +803,31 @@ def mem_load_d1(self, mem, i, reg, groupname, is_comb=False): load_grp.done = reg.done return load_grp - def mem_store_d1(self, mem, i, val, groupname, is_comb=False): + def mem_load_d2(self, mem, i, j, reg, groupname): + """Inserts wiring into `self` to perform `reg := mem[i]`, + where `mem` is a seq_d2 memory or a comb_mem_d2 memory + """ + assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2() + is_comb = mem.is_comb_mem_d2() + with self.group(groupname) as load_grp: + mem.addr0 = i + mem.addr1 = j + if is_comb: + reg.write_en = 1 + reg.in_ = mem.read_data + else: + mem.content_en = 1 + reg.write_en = mem.done @ 1 + reg.in_ = mem.done @ mem.read_data + load_grp.done = reg.done + return load_grp + + def mem_store_d1(self, mem, i, val, groupname): """Inserts wiring into `self` to perform `mem[i] := val`, - where `mem` is a seq_d1 memory or a comb_mem_d1 memory (if `is_comb` is True) + where `mem` is a seq_d1 memory or a comb_mem_d1 memory """ - assert mem.is_seq_mem_d1() if not is_comb else mem.is_comb_mem_d1() + assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() + is_comb = mem.is_comb_mem_d1() with self.group(groupname) as store_grp: mem.addr0 = i mem.write_en = 1 @@ -782,6 +837,22 @@ def mem_store_d1(self, mem, i, val, groupname, is_comb=False): mem.content_en = 1 return store_grp + def mem_store_d2(self, mem, i, j, val, groupname): + """Inserts wiring into `self` to perform `mem[i] := val`, + where `mem` is a seq_d2 memory or a comb_mem_d2 memory + """ + assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2() + is_comb = mem.is_comb_mem_d2() + with self.group(groupname) as store_grp: + mem.addr0 = i + mem.addr1 = j + mem.write_en = 1 + mem.write_data = val + store_grp.done = mem.done + if not is_comb: + mem.content_en = 1 + return store_grp + def mem_load_to_mem(self, mem, i, ans, j, groupname): """Inserts wiring into `self` to perform `ans[j] := mem[i]`, where `mem` and `ans` are both comb_mem_d1 memories. @@ -1237,10 +1308,18 @@ def is_comb_mem_d1(self) -> bool: """Check if the cell is a StdMemD1 cell.""" return self.is_primitive("comb_mem_d1") + def is_comb_mem_d2(self) -> bool: + """Check if the cell is a StdMemD2 cell.""" + return self.is_primitive("comb_mem_d2") + def is_seq_mem_d1(self) -> bool: """Check if the cell is a SeqMemD1 cell.""" return self.is_primitive("seq_mem_d1") + def is_seq_mem_d2(self) -> bool: + """Check if the cell is a SeqMemD2 cell.""" + return self.is_primitive("seq_mem_d2") + def infer_width_reg(self) -> int: """Infer the width of a register. That is, the width of `reg.in`.""" assert self._cell.comp.id == "std_reg", "Cell is not a register" @@ -1273,14 +1352,18 @@ def infer_width(self, port_name) -> int: ): if port_name in ("left", "right"): return inst.args[0] - if prim in ("comb_mem_d1", "seq_mem_d1"): + if prim in ("comb_mem_d1", "seq_mem_d1", "comb_mem_d2", "seq_mem_d2"): if port_name == "write_en": return 1 + if "d2" in prim and port_name == "addr0": + return inst.args[3] + if "d2" in prim and port_name == "addr1": + return inst.args[4] if port_name == "addr0": return inst.args[2] if port_name == "in": return inst.args[0] - if prim == "seq_mem_d1" and port_name == "content_en": + if "seq_mem" in prim and port_name == "content_en": return 1 if prim in ( "std_mult_pipe", diff --git a/calyx-py/calyx/gen_exp.py b/calyx-py/calyx/gen_exp.py index bee7f0c732..560d542c05 100644 --- a/calyx-py/calyx/gen_exp.py +++ b/calyx-py/calyx/gen_exp.py @@ -700,9 +700,9 @@ def build_base_not_e(degree, width, int_width, is_signed) -> Program: ret = main.comb_mem_d1("ret", width, 1, 1, is_external=True) f = main.comp_instance("f", "fp_pow_full") - read_base = main.mem_load_d1(b, 0, base_reg, "read_base", is_comb=True) - read_exp = main.mem_load_d1(x, 0, exp_reg, "read_exp", is_comb=True) - write_to_memory = main.mem_store_d1(ret, 0, f.out, "write_to_memory", is_comb=True) + read_base = main.mem_load_d1(b, 0, base_reg, "read_base") + read_exp = main.mem_load_d1(x, 0, exp_reg, "read_exp") + write_to_memory = main.mem_store_d1(ret, 0, f.out, "write_to_memory") main.control += [ read_base, @@ -741,7 +741,7 @@ def build_base_is_e(degree, width, int_width, is_signed) -> Program: t.write_en = 1 init.done = t.done - write_to_memory = main.mem_store_d1(ret, 0, e.out, "write_to_memory", is_comb=True) + write_to_memory = main.mem_store_d1(ret, 0, e.out, "write_to_memory") main.control += [ init, diff --git a/calyx-py/test/correctness/matrix.data b/calyx-py/test/correctness/matrix.data new file mode 100644 index 0000000000..6711a60878 --- /dev/null +++ b/calyx-py/test/correctness/matrix.data @@ -0,0 +1,101 @@ +{ + "A": { + "data": [ + [ + 1, + 0, + 4, + 6 + ], + [ + 2, + 5, + 0, + 3 + ], + [ + 1, + 2, + 3, + 5 + ], + [ + 2, + 1, + 2, + 3 + ] + ], + "format": { + "numeric_type": "bitnum", + "is_signed": false, + "width": 32 + } + }, + "B": { + "data": [ + [ + 1, + 0, + 4, + 6 + ], + [ + 2, + 5, + 0, + 3 + ], + [ + 1, + 2, + 3, + 5 + ], + [ + 2, + 1, + 2, + 3 + ] + ], + "format": { + "numeric_type": "bitnum", + "is_signed": false, + "width": 32 + } + }, + "C": { + "data": [ + [ + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0 + ] + ], + "format": { + "numeric_type": "bitnum", + "is_signed": false, + "width": 32 + } + } +} \ No newline at end of file diff --git a/calyx-py/test/correctness/matrix.expect b/calyx-py/test/correctness/matrix.expect new file mode 100644 index 0000000000..e34e5805fa --- /dev/null +++ b/calyx-py/test/correctness/matrix.expect @@ -0,0 +1,80 @@ +{ + "A": [ + [ + 1, + 0, + 4, + 6 + ], + [ + 2, + 5, + 0, + 3 + ], + [ + 1, + 2, + 3, + 5 + ], + [ + 2, + 1, + 2, + 3 + ] + ], + "B": [ + [ + 1, + 0, + 4, + 6 + ], + [ + 2, + 5, + 0, + 3 + ], + [ + 1, + 2, + 3, + 5 + ], + [ + 2, + 1, + 2, + 3 + ] + ], + "C": [ + [ + 17, + 14, + 28, + 44 + ], + [ + 18, + 28, + 14, + 36 + ], + [ + 18, + 21, + 23, + 42 + ], + [ + 12, + 12, + 20, + 34 + ] + ] +} diff --git a/calyx-py/test/correctness/matrix.py b/calyx-py/test/correctness/matrix.py new file mode 100644 index 0000000000..d7965e4fbc --- /dev/null +++ b/calyx-py/test/correctness/matrix.py @@ -0,0 +1,107 @@ +import calyx.builder as cb + +def insert_matmul_component(prog, n): + """Inserts the component `matmul` into the program. + + It has: + - one 2d combinational ref memory, A + - two 2d sequential ref memories, B and C + + Interpreting A and B as n x n matrices, matmul computes the matrix product + A*B and writes this into C. + """ + + logn = n.bit_length() + + matmul = prog.component("matmul") + + # matrices + A = matmul.comb_mem_d2("A", 32, n, n, logn, logn, is_ref=True) + B = matmul.seq_mem_d2( "B", 32, n, n, logn, logn, is_ref=True) + C = matmul.seq_mem_d2( "C", 32, n, n, logn, logn, is_ref=True) + + mult = matmul.mult_pipe(32) + add = matmul.add(32) + + acc = matmul.reg(32) + + # iterators: i, j, k ∈ [0, n) + i = matmul.reg(3) + j = matmul.reg(3) + k = matmul.reg(3) + + # matrix entries + a = matmul.reg(32) + b = matmul.reg(32) + + + zero_acc = matmul.reg_store(acc, 0) # acc := 0 + zero_i = matmul.reg_store(i, 0) # i := 0 + zero_j = matmul.reg_store(j, 0) # j := 0 + zero_k = matmul.reg_store(k, 0) # k := 0 + + cond_i = matmul.lt_use(i.out, n) # i < n + cond_j = matmul.lt_use(j.out, n) # j < n + cond_k = matmul.lt_use(k.out, n) # k < n + + read_A = matmul.mem_load_d2(A, i.out, k.out, a, "read_A") # a := A[i][k] + read_B = matmul.mem_load_d2(B, k.out, j.out, b, "read_B") # b := B[k][j] + + # C[i][j] := c + write_C = matmul.mem_store_d2(C, i.out, j.out, acc.out, "write") + + # acc := acc + (a * b) + with matmul.group("upd") as upd: + # compute a*b + mult.go = 1 + mult.left = a.out + mult.right = b.out + + # compute acc + (a*b) + add.left = mult.done @ mult.out + add.right = mult.done @ acc.out + + # store acc + (a*b) in acc + acc.in_ = mult.done @ add.out + acc.write_en = mult.done @ cb.HI + upd.done = mult.done @ acc.done + + matmul.control += [ + zero_i, + cb.while_with(cond_i, + [ + zero_j, + cb.while_with(cond_j, + [ + zero_k, + zero_acc, + cb.while_with(cond_k, [read_A, read_B, upd, matmul.incr(k)]), + write_C, + matmul.incr(j) + ]), + matmul.incr(i) + ]) + ] + + return matmul + +def insert_main(prog): + main = prog.component("main") + + n = 4 + logn = n.bit_length() + + A = main.comb_mem_d2("A", 32, n, n, logn, logn, is_external=True) + B = main.seq_mem_d2( "B", 32, n, n, logn, logn, is_external=True) + C = main.seq_mem_d2( "C", 32, n, n, logn, logn, is_external=True) + + matmul = insert_matmul_component(prog, n) + matmul = main.cell("matmul", matmul) + + main.control += [cb.invoke(matmul, ref_A=A, ref_B=B, ref_C=C)] + +if __name__ == "__main__": + prog = cb.Builder() + insert_main(prog) + prog.program.emit() + diff --git a/frontends/ntt-pipeline/gen-ntt-pipeline.py b/frontends/ntt-pipeline/gen-ntt-pipeline.py index 1d13f9f0ab..237edc5ebe 100755 --- a/frontends/ntt-pipeline/gen-ntt-pipeline.py +++ b/frontends/ntt-pipeline/gen-ntt-pipeline.py @@ -216,7 +216,7 @@ def preamble_group(comp: cb.ComponentBuilder, row): def epilogue_group(comp: cb.ComponentBuilder, row): input = comp.get_cell("a") A = comp.get_cell(f"A{row}") - comp.mem_store_d1(input, row, A.out, f"epilogue_{row}", is_comb=True) + comp.mem_store_d1(input, row, A.out, f"epilogue_{row}") def insert_cells(comp: cb.ComponentBuilder): # memories