diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index 3c2eea81eb..55c0835249 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -408,7 +408,10 @@ def comb_mem_d2( """Generate a StdMemD2 cell.""" self.prog.import_("primitives/memories/comb.futil") return self.cell( - name, ast.Stdlib.comb_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref + name, + ast.Stdlib.comb_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), + is_external, + is_ref, ) def seq_mem_d1( @@ -440,7 +443,10 @@ def seq_mem_d2( """Generate a SeqMemD2 cell.""" self.prog.import_("primitives/memories/seq.futil") return self.cell( - name, ast.Stdlib.seq_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), is_external, is_ref + name, + ast.Stdlib.seq_mem_d2(bitwidth, len0, len1, idx_size0, idx_size1), + is_external, + is_ref, ) def binary( @@ -787,7 +793,7 @@ def reg_store(self, reg, val, groupname=None): def mem_load_d1(self, mem, i, reg, groupname): """Inserts wiring into `self` to perform `reg := mem[i]`, - where `mem` is a seq_d1 memory or a comb_mem_d1 memory + where `mem` is a seq_d1 memory or a comb_mem_d1 memory """ assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() is_comb = mem.is_comb_mem_d1() @@ -805,7 +811,7 @@ def mem_load_d1(self, mem, i, reg, groupname): def mem_load_d2(self, mem, i, j, reg, groupname): """Inserts wiring into `self` to perform `reg := mem[i]`, - where `mem` is a seq_d2 memory or a comb_mem_d2 memory + where `mem` is a seq_d2 memory or a comb_mem_d2 memory """ assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2() is_comb = mem.is_comb_mem_d2() @@ -824,7 +830,7 @@ def mem_load_d2(self, mem, i, j, reg, groupname): def mem_store_d1(self, mem, i, val, groupname): """Inserts wiring into `self` to perform `mem[i] := val`, - where `mem` is a seq_d1 memory or a comb_mem_d1 memory + where `mem` is a seq_d1 memory or a comb_mem_d1 memory """ assert mem.is_seq_mem_d1() or mem.is_comb_mem_d1() is_comb = mem.is_comb_mem_d1() @@ -839,7 +845,7 @@ def mem_store_d1(self, mem, i, val, groupname): def mem_store_d2(self, mem, i, j, val, groupname): """Inserts wiring into `self` to perform `mem[i] := val`, - where `mem` is a seq_d2 memory or a comb_mem_d2 memory + where `mem` is a seq_d2 memory or a comb_mem_d2 memory """ assert mem.is_seq_mem_d2() or mem.is_comb_mem_d2() is_comb = mem.is_comb_mem_d2() @@ -1092,6 +1098,16 @@ def invoke(cell: CellBuilder, **kwargs) -> ast.Invoke: The keyword arguments should have the form `in_*`, `out_*`, or `ref_*`, where `*` is the name of an input port, output port, or ref cell on the invoked cell. """ + + def try_infer_width(x): + width = cell.infer_width(x) + if not width: + raise WidthInferenceError( + f"Could not infer width of input '{x}' when invoking cell '{cell.name}'. " + "Consider using `const(width, value)` instead of `value`." + ) + return width + return ast.Invoke( cell._cell.id, [ @@ -1099,7 +1115,7 @@ def invoke(cell: CellBuilder, **kwargs) -> ast.Invoke: k[3:], ( ( - const(cell.infer_width(k[3:]), v).expr + const(try_infer_width(k[3:]), v).expr if isinstance(v, int) else ExprBuilder.unwrap(v) )