Skip to content

Commit

Permalink
Add missing argument to ShapeFeature.get_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 28, 2023
1 parent 1706bb3 commit 4134b22
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aeppl/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def process_shape_info(cls, output: Variable, pstate: Optional[PrinterStateType]
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
_s_i_out = shape_feature.get_shape(output, i)
_s_i_out = shape_feature.get_shape(pstate.fgraph, output, i)

if not isinstance(_s_i_out, (Constant, TensorVariable)):
s_i_out = pstate.pprinter.process(_s_i_out, pstate)
Expand Down
3 changes: 2 additions & 1 deletion aeppl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def update_scan_value_vars(
# graph, so we use the shape feature to (hopefully) get the shape
# without the entire `Scan` itself.
full_out_shape = tuple(
fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim)
fgraph.shape_feature.get_shape(fgraph, full_out, i)
for i in range(full_out.ndim)
)
new_val_var = at.empty(full_out_shape, dtype=full_out.dtype)

Expand Down

0 comments on commit 4134b22

Please sign in to comment.