diff --git a/aeppl/printing.py b/aeppl/printing.py index c27541be..287919e8 100644 --- a/aeppl/printing.py +++ b/aeppl/printing.py @@ -380,7 +380,7 @@ def process_shape_info(cls, output: Variable, pstate: Optional[PrinterStateType] try: old_precedence = getattr(pstate, "precedence", None) pstate.precedence = new_precedence - _s_i_out = shape_feature.get_shape(output, i) + _s_i_out = shape_feature.get_shape(pstate.fgraph, output, i) if not isinstance(_s_i_out, (Constant, TensorVariable)): s_i_out = pstate.pprinter.process(_s_i_out, pstate) diff --git a/aeppl/scan.py b/aeppl/scan.py index bc649e99..eeae1c2f 100644 --- a/aeppl/scan.py +++ b/aeppl/scan.py @@ -439,7 +439,8 @@ def update_scan_value_vars( # graph, so we use the shape feature to (hopefully) get the shape # without the entire `Scan` itself. full_out_shape = tuple( - fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim) + fgraph.shape_feature.get_shape(fgraph, full_out, i) + for i in range(full_out.ndim) ) new_val_var = at.empty(full_out_shape, dtype=full_out.dtype)