@eby0303 eby0303 commented Oct 16, 2025

Description

This is a draft PR for issue #1593.
I’m setting up the local environment and exploring how to implement a rewrite that fuses nested BlockDiag Ops into a single one.
I’ll update this PR with code once the setup is complete and I have an initial version of the rewrite.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1671.org.readthedocs.build/en/1671/

@eby0303
Author

eby0303 commented Oct 16, 2025

Hey @jessegrabowski! As suggested in the issue discussion, I’ve opened this draft PR to start working on the BlockDiag rewrite.
I’ll be setting up locally. Please feel free to share any tips or guidance on where to begin.
I will update this PR as I make progress.

@jessegrabowski
Member

I'd suggest you work in a test-driven way. Add a test to tests/tensor/rewriting/test_linalg.py that builds a simple nested block_diag graph, counts the number of BlockDiag ops, and asserts that there is only 1. Confirm that this test fails. Then add a rewrite to tensor/rewriting/linalg.py that looks for a BlockDiag with a BlockDiag among its inputs and, if it finds one, merges them.

For an example of how to count ops in a graph for the test, look here (BUT the whole class is overkill for your case, just take the pieces from it and write an inline version).

For a good rewrite template to get you started, I think this one is pretty readable. You will need to 1) check that the input is a BlockDiag op, 2) check that at least one of the inputs to the BlockDiag is a BlockDiag, 3) pull out the inputs from the inner BlockDiag, 4) make a new BlockDiag whose n_inputs counts all the flattened inputs (old_n_inputs + 1 in the simple case of a two-input inner BlockDiag) and return it, passing in all 3 inputs.
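
The four steps above can be sketched without PyTensor at all. What follows is a toy model, not the PyTensor API: `Leaf`, `BlockDiag`, and `fuse_nested` are hypothetical names invented for illustration, and a real rewrite would operate on Apply nodes and be registered with the rewrite decorators. It only shows the flattening logic, including returning None when there is nothing to fuse.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Leaf:
    """Hypothetical stand-in for a plain matrix variable."""

    name: str


@dataclass(frozen=True)
class BlockDiag:
    """Hypothetical stand-in for a BlockDiag node (not the PyTensor class)."""

    inputs: tuple

    @property
    def n_inputs(self):
        return len(self.inputs)


def fuse_nested(node):
    """Remove one level of nesting; return None if nothing changed."""
    # Step 1: only handle BlockDiag nodes.
    if not isinstance(node, BlockDiag):
        return None
    # Step 2: bail out unless at least one input is itself a BlockDiag.
    if not any(isinstance(inp, BlockDiag) for inp in node.inputs):
        return None
    # Step 3: splice each inner node's inputs in place, preserving order.
    new_inputs = []
    for inp in node.inputs:
        if isinstance(inp, BlockDiag):
            new_inputs.extend(inp.inputs)
        else:
            new_inputs.append(inp)
    # Step 4: build a single node over all collected inputs.
    return BlockDiag(tuple(new_inputs))


x, y, z, w = Leaf("x"), Leaf("y"), Leaf("z"), Leaf("w")

# Simple case: block_diag(block_diag(x, y), z)
fused = fuse_nested(BlockDiag((BlockDiag((x, y)), z)))

# Deeper nesting: apply repeatedly until nothing is left to fuse.
result = BlockDiag((BlockDiag((BlockDiag((x, y)), z)), w))
while (step := fuse_nested(result)) is not None:
    result = step
```

Each call removes a single level of nesting, so the deeper graph is flattened by reapplying the function until it returns None. This sketch assumes the real rewrite would be reapplied in a similar way by the rewriter it is registered with.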

@eby0303
Author

eby0303 commented Oct 16, 2025

@jessegrabowski, I added a rewrite to fuse nested BlockDiagonal ops and updated test_linalg.py with a test for nested BlockDiagonal fusion.

Member

@jessegrabowski jessegrabowski left a comment

Really great first pass. You're missing some of the boilerplate around rewrites; have a look here (or anywhere in this file, really) to see how to register a rewrite and how to tell it which Op to track (pay attention to the decorators).

You also need to use pytensor.function to compile your block diagonal graph and check that the rewrite was triggered, rather than calling it directly.

from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.graph import Apply

def fuse_blockdiagonal(node):
Member

You need to register the rewrite using one or more of the rewrite registration decorators. I suggest @register_canonicalize to start. You also need to pass in which Op you are registering. Check the other rewrites to see how it works.

solve_triangular,
)

from pytensor.tensor.slinalg import BlockDiagonal
Member

Make sure you have pre-commit installed and that you've run pre-commit install in your dev environment. You have duplicated imports and other issues that this tool will help you catch.

Comment on lines +67 to +68
# Apply the rewrite
fused = fuse_blockdiagonal(outer)
Member

You don't want to call the rewrite directly. Instead, compile the function using pytensor.function, then check that the rewrite was correctly applied by looking at the compiled graph. Check here for a template to follow.

Member

Calling rewrite_graph followed by assert_equal_computations is also a fine test, unless you are unsure and want to evaluate against something provably correct.

fused = fuse_blockdiagonal(outer)

# Count number of BlockDiagonal ops after fusion
nodes_after = ancestors([fused])
Member

No need to look at only ancestors. Once you have a compiled function, you can look at all the nodes with fn.maker.fgraph.apply_nodes (see the SVD test I linked above)


def test_nested_blockdiag_fusion():
    # Create matrix variables
    x = pt.matrix("x")
Member

Instead of pt.matrix, use pt.tensor('x', shape=(3, 3)), for example, and give all the variables static shapes. The reason is that I want to test that the fused BlockDiagonal comes out with the correct static shape.
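
On why the static shape is worth testing: a block-diagonal matrix has as many rows as its blocks have in total, and likewise for columns, and nesting does not change the result. The pure-Python sketch below checks both facts on a small example. The `block_diag` helper here is a hypothetical function on lists of lists written for this illustration, not the PyTensor function.

```python
def block_diag(*mats):
    """Pure-Python block diagonal of list-of-lists matrices (toy helper)."""
    n_cols = sum(len(m[0]) for m in mats)
    out = []
    col = 0  # column offset of the current block
    for m in mats:
        width = len(m[0])
        for row in m:
            # Pad with zeros to the left and right of the block's row.
            out.append([0] * col + list(row) + [0] * (n_cols - col - width))
        col += width
    return out


A = [[1, 2], [3, 4]]  # 2 x 2
B = [[5]]             # 1 x 1
C = [[6, 7, 8]]       # 1 x 3

nested = block_diag(block_diag(A, B), C)
flat = block_diag(A, B, C)
shape = (len(flat), len(flat[0]))
```

Here `nested` and `flat` are the same 4 x 6 matrix: the row count is 2 + 1 + 1 and the column count is 2 + 1 + 3. The fused Op's output should report a static shape that follows the same sums.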

1 for node in nodes_before
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
Member

You control the setup, so directly assert initial_count == 2.

But on that note, make sure to test a deeper nesting as well.

1 for node in nodes_after
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
Member

Also test that the n_inputs property of the new BlockDiagonal is correctly set.

    if changed:
        # Return a new fused BlockDiagonal with all inputs
        return BlockDiagonal(len(new_inputs))(*new_inputs)
    return node
Member

Return None from a rewrite if it didn't do anything

def fuse_blockdiagonal(node):
    # Only process if this node is a BlockDiagonal
    if not isinstance(node.owner.op, BlockDiagonal):
        return node
Member

Return None from the rewrite if it didn't do anything

Development

Successfully merging this pull request may close these issues.

Add rewrite to fuse nested BlockDiag Ops

3 participants