WIP: Add rewrite to fuse nested BlockDiag Ops #1671
Hey @jessegrabowski! As suggested in the issue discussion, I’ve opened this draft PR to start working on the BlockDiag rewrite.
I'd suggest you work in a test-driven way. Add a test to For an example of how to count ops in a graph for the test, look here (BUT the whole class is overkill for your case, just take the pieces from it and write an inline version). For a good rewrite template to get you started, I think this one is pretty readable. You will need to 1) check that the input is a BlockDiag op, 2) check that at least one of the inputs to the BlockDiag is a BlockDiag, 3) pull out the inputs from the inner BlockDiag, 4) make a new BlockDiag with
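The steps listed in this comment can be sketched on a toy model before touching the real graph machinery. This is only an illustration under stated assumptions: `ToyBlockDiag`, `fuse_nested`, and `fuse_to_fixed_point` are hypothetical names invented here, not PyTensor classes or functions, and the final step (one node over the flattened inputs) is my reading of the intent, not something the comment spells out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyBlockDiag:
    """Hypothetical stand-in for a block-diagonal node (not a PyTensor class)."""

    inputs: tuple

    @property
    def n_inputs(self) -> int:
        return len(self.inputs)


def fuse_nested(node) -> Optional[ToyBlockDiag]:
    """Flatten one level of nesting; return None when nothing changes."""
    # 1) Only act on block-diagonal nodes.
    if not isinstance(node, ToyBlockDiag):
        return None
    # 2) Do nothing unless at least one input is itself block-diagonal.
    if not any(isinstance(inp, ToyBlockDiag) for inp in node.inputs):
        return None
    # 3) Splice the inner inputs in place of each nested node, keeping order.
    new_inputs = []
    for inp in node.inputs:
        if isinstance(inp, ToyBlockDiag):
            new_inputs.extend(inp.inputs)
        else:
            new_inputs.append(inp)
    # 4) Build a single node over the flattened inputs.
    return ToyBlockDiag(tuple(new_inputs))


def fuse_to_fixed_point(node):
    """Re-apply the one-level rewrite until it reports no change."""
    while (fused := fuse_nested(node)) is not None:
        node = fused
    return node


# Three levels of nesting collapse to one node with four inputs.
deep = ToyBlockDiag((ToyBlockDiag((ToyBlockDiag(("a", "b")), "c")), "d"))
flat = fuse_to_fixed_point(deep)
```

Each call only removes one level, so the loop stands in for whatever re-applies a rewrite until the graph stops changing. Returning `None` for the "nothing to do" case mirrors the convention requested later in this review.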
@jessegrabowski Added a rewrite to fuse nested BlockDiagonal ops and updated test_linalg.py with a test for nested BlockDiagonal fusion.
Really great first pass. You're missing some of the boilerplate around rewrites; have a look here (or anywhere in this file, really) to see how to register a rewrite and how to tell it which Op to track (pay attention to the decorators).
You also need to use pytensor.function to compile your block diagonal graph and check that the rewrite was triggered, rather than calling it directly.
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.graph import Apply

def fuse_blockdiagonal(node):
You need to register the rewrite using one or more of the rewrite registration decorators. I suggest @register_canonicalize to start. You also need to pass in which Op you are registering. Check the other rewrites to see how it works.
    solve_triangular,
)

from pytensor.tensor.slinalg import BlockDiagonal
Make sure you have pre-commit installed and that you've run pre-commit install in your dev environment. You have doubled imports and other issues this tool will help you catch.
    # Apply the rewrite
    fused = fuse_blockdiagonal(outer)
You don't want to actually call the rewrite. Instead, compile the function using pytensor.function, then check that the rewrite was correctly applied by looking at the compiled graph. Check here for a template to follow.
Calling rewrite_graph followed by assert_equal_computations is also a fine test, unless you are too uncertain and want to evaluate against something provably correct.
    fused = fuse_blockdiagonal(outer)

    # Count number of BlockDiagonal ops after fusion
    nodes_after = ancestors([fused])
No need to look at only ancestors. Once you have a compiled function, you can look at all the nodes with fn.maker.fgraph.apply_nodes (see the SVD test I linked above).
def test_nested_blockdiag_fusion():
    # Create matrix variables
    x = pt.matrix("x")
Instead of pt.matrix, use pt.tensor('x', shape=(3, 3)), for example, and give all the variables static shapes. The reason is that I want to test that the fused blockwise comes out with the correct static shape.
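To know which static shape such a test should expect, note that a block-diagonal matrix has as many rows as its blocks have in total, and likewise for columns. The helper below is a hypothetical, pure-Python sketch of that arithmetic (it is not part of PyTensor), using example shapes chosen here for illustration.

```python
def expected_block_diag_shape(*shapes):
    """Shape of a block-diagonal matrix built from blocks of the given 2D shapes."""
    rows = sum(shape[0] for shape in shapes)
    cols = sum(shape[1] for shape in shapes)
    return (rows, cols)


# Flat call over three blocks.
flat_shape = expected_block_diag_shape((3, 3), (2, 2), (4, 4))

# Nested call: the inner result is itself one block of the outer call.
inner_shape = expected_block_diag_shape((3, 3), (2, 2))
nested_shape = expected_block_diag_shape(inner_shape, (4, 4))
```

Both routes give the same shape, which is why the fused op's output should carry the same static shape as the nested graph it replaces.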
        1 for node in nodes_before
        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
    assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
You control the setup, so directly assert initial_count == 2.
But on that note, make sure to test a deeper nesting as well.
        1 for node in nodes_after
        if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
    )
    assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
Also test that the n_inputs property of the new BlockDiagonal is correctly set.
    if changed:
        # Return a new fused BlockDiagonal with all inputs
        return BlockDiagonal(len(new_inputs))(*new_inputs)
    return node
Return None from a rewrite if it didn't do anything.
def fuse_blockdiagonal(node):
    # Only process if this node is a BlockDiagonal
    if not isinstance(node.owner.op, BlockDiagonal):
        return node
Return None from the rewrite if it didn't do anything.
Description
This is a draft PR for issue #1593.
I’m setting up the local environment and exploring how to implement a rewrite that fuses nested BlockDiag Ops into a single one.
I’ll update this PR with code once the setup is complete and I have an initial version of the rewrite.
Related Issue
BlockDiag Ops #1593
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1671.org.readthedocs.build/en/1671/