diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 17a3ce9165..9b51f0593d 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -60,11 +60,36 @@ solve_triangular, ) +from pytensor.tensor.slinalg import BlockDiagonal logger = logging.getLogger(__name__) MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) +from pytensor.tensor.slinalg import BlockDiagonal +from pytensor.graph import Apply + +def fuse_blockdiagonal(node): + # Only process if this node is a BlockDiagonal + if not isinstance(node.owner.op, BlockDiagonal): + return node + + new_inputs = [] + changed = False + for inp in node.owner.inputs: + # If input is itself a BlockDiagonal, flatten its inputs + if inp.owner and isinstance(inp.owner.op, BlockDiagonal): + new_inputs.extend(inp.owner.inputs) + changed = True + else: + new_inputs.append(inp) + + if changed: + # Return a new fused BlockDiagonal with all inputs + return BlockDiagonal(len(new_inputs))(*new_inputs) + return node + + def is_matrix_transpose(x: TensorVariable) -> bool: """Check if a variable corresponds to a transpose of the last two axes""" node = x.owner diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 515120e446..d426f1a039 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -43,7 +43,50 @@ from tests import unittest_tools as utt from tests.test_rop import break_op +from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal + +def test_nested_blockdiag_fusion(): + # Create matrix variables + x = pt.matrix("x") + y = pt.matrix("y") + z = pt.matrix("z") + + # Nested BlockDiagonal + inner = BlockDiagonal(2)(x, y) + outer = BlockDiagonal(2)(inner, z) + + # Count number of BlockDiagonal ops before fusion + nodes_before = ancestors([outer]) + initial_count = sum( + 1 for node in nodes_before + if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) + ) + assert initial_count > 1, "Setup failed: should have nested BlockDiagonal" + + # Apply the rewrite + fused = fuse_blockdiagonal(outer) + + # Count number of BlockDiagonal ops after fusion + nodes_after = ancestors([fused]) + fused_count = sum( + 1 for node in nodes_after + if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) + ) + assert fused_count == 1, "Nested BlockDiagonal ops were not fused" + + # Check that all original inputs are preserved + fused_inputs = [ + inp + for node in ancestors([fused]) + if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) + for inp in node.owner.inputs + ] + assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused" + + + + def test_matrix_inverse_rop_lop(): rtol = 1e-7 if config.floatX == "float64" else 1e-5 mx = matrix("mx")