diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index 5999ccd857..efbd9f2311 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -344,6 +344,18 @@ def sub(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdSub cell.""" return self.binary("sub", size, name, signed) + def div_pipe( + self, size: int, name: str = None, signed: bool = False + ) -> CellBuilder: + """Generate a Div_Pipe cell.""" + return self.binary("div_pipe", size, name, signed) + + def mult_pipe( + self, size: int, name: str = None, signed: bool = False + ) -> CellBuilder: + """Generate a Mult_Pipe cell.""" + return self.binary("mult_pipe", size, name, signed) + def gt(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdGt cell.""" return self.binary("gt", size, name, signed) @@ -466,6 +478,28 @@ def binary_use(self, left, right, cell, groupname=None): cell.right = right return CellAndGroup(cell, comb_group) + def binary_use_names(self, cellname, leftname, rightname, groupname=None): + """Accepts the name of a cell that performs some computation on two values. + Accepts the names of cells that contain those two values. + Creates a group that wires up the cell with those values. + Returns the group created. + + group `groupname` { + `cellname`.left = `leftname`.out; + `cellname`.right = `rightname`.out; + `groupname`.go = 1; + `groupname`.done = `cellname`.done; + } + """ + cell = self.get_cell(cellname) + groupname = groupname or f"{cellname}_group" + with self.group(groupname) as group: + cell.left = self.get_cell(leftname).out + cell.right = self.get_cell(rightname).out + cell.go = HI + group.done = cell.done + return group + def try_infer_width(self, width, left, right): """If `width` is None, try to infer it from `left` or `right`. If that fails, raise an error. diff --git a/frontends/ntt-pipeline/gen-ntt-pipeline.py b/frontends/ntt-pipeline/gen-ntt-pipeline.py index 8c66a9ba01..913ae42486 100755 --- a/frontends/ntt-pipeline/gen-ntt-pipeline.py +++ b/frontends/ntt-pipeline/gen-ntt-pipeline.py @@ -3,7 +3,6 @@ from prettytable import PrettyTable import numpy as np import calyx.py_ast as ast -from calyx.py_ast import CompVar, Cell, Stdlib import calyx.builder as cb from calyx.utils import bits_needed @@ -28,7 +27,7 @@ def reduce_parallel_control_pass(component: ast.Component, N: int, input_size: i ... """ assert ( - N is not None and 0 < N < input_size and (not (N & (N - 1))) + N and 0 < N < input_size and (not (N & (N - 1))) ), f"""N: {N} should be a power of two within bounds (0, {input_size}).""" reduced_controls = [] @@ -168,15 +167,12 @@ def fresh_comp_index(op): def mul_group(comp: cb.ComponentBuilder, stage, mul_tuple): mul_index, k, phi_index = mul_tuple - - mul = comp.get_cell(f"mult_pipe{mul_index}") - phi = comp.get_cell(f"phi{phi_index}") - reg = comp.get_cell(f"r{k}") - with comp.group(f"s{stage}_mul{mul_index}") as g: - mul.left = phi.out - mul.right = reg.out - mul.go = 1 - g.done = mul.done + comp.binary_use_names( + f"mult_pipe{mul_index}", + f"phi{phi_index}", + f"r{k}", + f"s{stage}_mul{mul_index}", + ) def op_mod_group(comp: cb.ComponentBuilder, stage, row, operations_tuple): lhs, op, mul_index = operations_tuple @@ -201,10 +197,7 @@ def op_mod_group(comp: cb.ComponentBuilder, stage, row, operations_tuple): def precursor_group(comp: cb.ComponentBuilder, row): r = comp.get_cell(f"r{row}") A = comp.get_cell(f"A{row}") - with comp.group(f"precursor_{row}") as g: - r.in_ = A.out - r.write_en = 1 - g.done = r.done + comp.reg_store(r, A.out, f"precursor_{row}") def preamble_group(comp: cb.ComponentBuilder, row): reg = comp.get_cell(f"r{row}") @@ -223,70 +216,24 @@ def preamble_group(comp: cb.ComponentBuilder, row): def epilogue_group(comp: cb.ComponentBuilder, row): input = comp.get_cell("a") A = comp.get_cell(f"A{row}") - with comp.group(f"epilogue_{row}") as epilogue: - input.addr0 = row - input.write_en = 1 - input.write_data = A.out - epilogue.done = input.done - - def cells(): - input = CompVar("a") - phis = CompVar("phis") - memories = [ - Cell( - input, Stdlib.comb_mem_d1(input_bitwidth, n, bitwidth), is_external=True - ), - Cell( - phis, Stdlib.comb_mem_d1(input_bitwidth, n, bitwidth), is_external=True - ), - ] - r_regs = [ - Cell(CompVar(f"r{r}"), Stdlib.register(input_bitwidth)) for r in range(n) - ] - A_regs = [ - Cell(CompVar(f"A{r}"), Stdlib.register(input_bitwidth)) for r in range(n) - ] - mul_regs = [ - Cell(CompVar(f"mul{i}"), Stdlib.register(input_bitwidth)) - for i in range(n // 2) - ] - phi_regs = [ - Cell(CompVar(f"phi{r}"), Stdlib.register(input_bitwidth)) for r in range(n) - ] - mod_pipes = [ - Cell( - CompVar(f"mod_pipe{r}"), - Stdlib.op("div_pipe", input_bitwidth, signed=True), - ) - for r in range(n) - ] - mult_pipes = [ - Cell( - CompVar(f"mult_pipe{i}"), - Stdlib.op("mult_pipe", input_bitwidth, signed=True), - ) - for i in range(n // 2) - ] - adds = [ - Cell(CompVar(f"add{i}"), Stdlib.op("add", input_bitwidth, signed=True)) - for i in range(n // 2) - ] - subs = [ - Cell(CompVar(f"sub{i}"), Stdlib.op("sub", input_bitwidth, signed=True)) - for i in range(n // 2) - ] - - return ( - memories - + r_regs - + A_regs - + mul_regs - + phi_regs - + mod_pipes - + mult_pipes - + adds - + subs - ) + comp.mem_store_comb_mem_d1(input, row, A.out, f"epilogue_{row}") + + def insert_cells(comp: cb.ComponentBuilder): + # memories + comp.comb_mem_d1("a", input_bitwidth, n, bitwidth, is_external=True) + comp.comb_mem_d1("phis", input_bitwidth, n, bitwidth, is_external=True) + + for r in range(n): + comp.reg(input_bitwidth, f"r{r}") # r_regs + comp.reg(input_bitwidth, f"A{r}") # A_regs + comp.reg(input_bitwidth, f"phi{r}") # phi_regs + comp.div_pipe(input_bitwidth, f"mod_pipe{r}", signed=True) # mod_pipes + + for i in range(n // 2): + comp.reg(input_bitwidth, f"mult{i}") # mul_regs + comp.mult_pipe(input_bitwidth, f"mult_pipe{i}", signed=True) # mult_pipes + comp.add(input_bitwidth, f"add{i}", signed=True) # adds + comp.sub(input_bitwidth, f"sub{i}", signed=True) # subs def wires(main: cb.ComponentBuilder): for r in range(n): @@ -325,9 +272,8 @@ def control(): pp_table(operations, multiplies, n, num_stages) prog = cb.Builder() - prog.import_("primitives/binary_operators.futil") - prog.import_("primitives/memories/comb.futil") - main = prog.component("main", cells()) + main = prog.component("main") + insert_cells(main) wires(main) main.component.controls = control() return prog.program diff --git a/tests/frontend/ntt-pipeline/ntt-4-reduced-2.expect b/tests/frontend/ntt-pipeline/ntt-4-reduced-2.expect index 7f93296aeb..afbf82d6cc 100644 --- a/tests/frontend/ntt-pipeline/ntt-4-reduced-2.expect +++ b/tests/frontend/ntt-pipeline/ntt-4-reduced-2.expect @@ -7,35 +7,35 @@ // | 3 | a[1] - a[3] * phis[1] | a[2] - a[3] * phis[3] | // +---+-----------------------+-----------------------+ import "primitives/core.futil"; -import "primitives/binary_operators.futil"; import "primitives/memories/comb.futil"; +import "primitives/binary_operators.futil"; component main() -> () { cells { @external a = comb_mem_d1(32, 4, 3); @external phis = comb_mem_d1(32, 4, 3); r0 = std_reg(32); - r1 = std_reg(32); - r2 = std_reg(32); - r3 = std_reg(32); A0 = std_reg(32); - A1 = std_reg(32); - A2 = std_reg(32); - A3 = std_reg(32); - mul0 = std_reg(32); - mul1 = std_reg(32); phi0 = std_reg(32); - phi1 = std_reg(32); - phi2 = std_reg(32); - phi3 = std_reg(32); mod_pipe0 = std_sdiv_pipe(32); + r1 = std_reg(32); + A1 = std_reg(32); + phi1 = std_reg(32); mod_pipe1 = std_sdiv_pipe(32); + r2 = std_reg(32); + A2 = std_reg(32); + phi2 = std_reg(32); mod_pipe2 = std_sdiv_pipe(32); + r3 = std_reg(32); + A3 = std_reg(32); + phi3 = std_reg(32); mod_pipe3 = std_sdiv_pipe(32); + mult0 = std_reg(32); mult_pipe0 = std_smult_pipe(32); - mult_pipe1 = std_smult_pipe(32); add0 = std_sadd(32); - add1 = std_sadd(32); sub0 = std_ssub(32); + mult1 = std_reg(32); + mult_pipe1 = std_smult_pipe(32); + add1 = std_sadd(32); sub1 = std_ssub(32); } wires { diff --git a/tests/frontend/ntt-pipeline/ntt-4.expect b/tests/frontend/ntt-pipeline/ntt-4.expect index e4ada83fa1..5f23d7d0b8 100644 --- a/tests/frontend/ntt-pipeline/ntt-4.expect +++ b/tests/frontend/ntt-pipeline/ntt-4.expect @@ -7,35 +7,35 @@ // | 3 | a[1] - a[3] * phis[1] | a[2] - a[3] * phis[3] | // +---+-----------------------+-----------------------+ import "primitives/core.futil"; -import "primitives/binary_operators.futil"; import "primitives/memories/comb.futil"; +import "primitives/binary_operators.futil"; component main() -> () { cells { @external a = comb_mem_d1(32, 4, 3); @external phis = comb_mem_d1(32, 4, 3); r0 = std_reg(32); - r1 = std_reg(32); - r2 = std_reg(32); - r3 = std_reg(32); A0 = std_reg(32); - A1 = std_reg(32); - A2 = std_reg(32); - A3 = std_reg(32); - mul0 = std_reg(32); - mul1 = std_reg(32); phi0 = std_reg(32); - phi1 = std_reg(32); - phi2 = std_reg(32); - phi3 = std_reg(32); mod_pipe0 = std_sdiv_pipe(32); + r1 = std_reg(32); + A1 = std_reg(32); + phi1 = std_reg(32); mod_pipe1 = std_sdiv_pipe(32); + r2 = std_reg(32); + A2 = std_reg(32); + phi2 = std_reg(32); mod_pipe2 = std_sdiv_pipe(32); + r3 = std_reg(32); + A3 = std_reg(32); + phi3 = std_reg(32); mod_pipe3 = std_sdiv_pipe(32); + mult0 = std_reg(32); mult_pipe0 = std_smult_pipe(32); - mult_pipe1 = std_smult_pipe(32); add0 = std_sadd(32); - add1 = std_sadd(32); sub0 = std_ssub(32); + mult1 = std_reg(32); + mult_pipe1 = std_smult_pipe(32); + add1 = std_sadd(32); sub1 = std_ssub(32); } wires {