diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index a928073a15..5509c59280 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -446,6 +446,24 @@ def comb_mem_d2( is_external, is_ref, ) + + def comb_mem_dn( + self, + name: str, + bitwidth: int, + lens: List[int], + idx_size: int, + is_external: bool = False, + is_ref: bool = False, + ) -> CellBuilder: + """Generate a StdMemD1 cell that abstracts to an n-dimensional memory.""" + self.prog.import_("primitives/memories/comb.futil") + prod = 1 + for l in lens: + prod *= l + return self.cell( + name, ast.Stdlib.comb_mem_d1(bitwidth, prod, idx_size), is_external, is_ref + ) def seq_mem_d1( self, @@ -482,6 +500,24 @@ def seq_mem_d2( is_ref, ) + def seq_mem_dn( + self, + name: str, + bitwidth: int, + lens : List[int], + idx_size: int, + is_external: bool = False, + is_ref: bool=False + ) -> CellBuilder: + """Generate a SeqMemD1 cell that abstracts to an n-dimensional memory.""" + self.prog.import_("primitives/memories/seq.futil") + prod = 1 + for l in lens: + prod *= l + return self.cell( + name, ast.Stdlib.seq_mem_d1(bitwidth, prod, idx_size), is_external, is_ref + ) + def binary( self, operation: str, @@ -898,6 +934,48 @@ def mem_latch_d2(self, mem, i, j, groupname): latch_grp.done = mem.done return latch_grp + def flatten_idx(self, dims, indices): + """Translate an n-dimensional index into a corresponding 1d index""" + assert len(dims) == len(indices) + i = len(indices) - 1 + prod = 1 + total = indices[-1] + while i > 0: + prod *= dims[i] + total += prod * indices[i-1] + i -= 1 + return total + + def mem_load_dn(self, mem, dims, indices, reg, groupname): + """Inserts wiring into `self` to perform `reg := mem[i1][i2]...[in]`, + where `mem` is a seq_dn memory or a comb_mem_d1 memory + """ + assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() + is_comb = mem.is_comb_mem_d1() + with self.group(groupname) as load_grp: + mem.addr0 = self.flatten_idx(dims, indices) + if is_comb: + reg.write_en = 1 + reg.in_ = mem.read_data + else: + mem.content_en = 1 + reg.write_en = mem.done @ 1 + reg.in_ = mem.done @ mem.read_data + load_grp.done = reg.done + return load_grp + + def mem_latch_dn(self, mem, dims, indices, groupname): + """Inserts wiring into `self` to latch `mem[i]`, + where `mem` is a seq_mem_d1 memory. + A user can later read `mem.out` and get the latched value. + """ + assert mem.is_seq_mem_d1() + with self.group(groupname) as latch_grp: + mem.addr0 = self.flatten_idx(dims, indices) + mem.content_en = HI + latch_grp.done = mem.done + return latch_grp + def mem_store_d1(self, mem, i, val, groupname): """Inserts wiring into `self` to perform `mem[i] := val`, where `mem` is a seq_d1 memory or a comb_mem_d1 memory @@ -928,6 +1006,21 @@ def mem_store_d2(self, mem, i, j, val, groupname): if not is_comb: mem.content_en = 1 return store_grp + + def mem_store_dn(self, mem, dims, indices, val, groupname): + """Inserts wiring into `self` to perform `mem[i] := val`, + where `mem` is a seq_d2 memory or a comb_mem_d2 memory + """ + assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() + is_comb = mem.is_comb_mem_d1() + with self.group(groupname) as store_grp: + mem.addr0 = self.flatten_idx(dims, indices) + mem.write_en = 1 + mem.write_data = val + store_grp.done = mem.done + if not is_comb: + mem.content_en = 1 + return store_grp def mem_load_to_mem(self, mem, i, ans, j, groupname): """Inserts wiring into `self` to perform `ans[j] := mem[i]`,