Reuse LU decomposition in Solve #1396
Conversation
Force-pushed 50146f6 to 903a86e
There's a test failing @jessegrabowski; it seems like our lu_solve gives different precision output compared to getrs (I assume worse, but I didn't check). Anyway, this makes me wonder: don't we want to wrap that LAPACK Op instead of using our double-triangular-solve-and-pivots Op combination? It would reduce the cost of splitting the Op, because all that logic would live inside LAPACK. It's also a cleaner graph?
Yes, we can directly wrap the Op. I was just having trouble with gradients when I did it that way. If you recall, we had a call where I compared the two approaches (core Ops vs a new Op).
You couldn't figure out the gradients of just the LAPACK Op, or was it something else?
Yes, I couldn't get correct gradients. What I thought would be the correct, straightforward answer ended up being wrong. I got frustrated quickly and didn't spend a super long time on it.
Okay, I'm fine with this; we just need to tweak the tolerance, and perhaps keep an issue open to revisit more carefully later if it turns out problematic. It's a float32 thing and the differences are not crazy.
My guess is that there's sequential loss of precision from doing two solves instead of one. We do have the fig leaf of "well, this is what JAX does", at least! I'll open a branch/PR with the Op version of lu_solve, and we can work on getting the gradient to work when we have some free time (never).
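A minimal NumPy/SciPy sketch (not PyTensor code) of the two approaches being compared here: one fused getrs-style call via lu_solve, versus the split pivot-then-two-triangular-solves path, whose results can drift slightly apart in float32:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve, solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5)).astype("float32")
b = rng.normal(size=5).astype("float32")

lu, piv = lu_factor(A)

# Fused path: a single LAPACK call (getrs under the hood)
x_fused = lu_solve((lu, piv), b)

# Split path: apply the row pivots, then two triangular solves
b_perm = b.copy()
for i, p in enumerate(piv):
    b_perm[[i, p]] = b_perm[[p, i]]
y = solve_triangular(lu, b_perm, lower=True, unit_diagonal=True)
x_split = solve_triangular(lu, y, lower=False)

print(np.max(np.abs(x_fused - x_split)))  # tiny, but can be nonzero in float32
```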
Pull Request Overview
This PR enables the reuse of LU decompositions across Solve calls for distinct, blockwise, and scan-based operations. Key changes include updated test tolerances and modes in blockwise and rewriting tests, new tests to verify the LU rewrite behavior, and implementation of several LU-decomposition rewrite functions along with adjustments in the scan rewriting module.
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
tests/tensor/test_blockwise.py | Adjusted tolerance settings and updated mode exclusion parameters. |
tests/tensor/linalg/test_rewriting.py | Added tests to validate the new LU decomposition rewrite logic. |
pytensor/tensor/rewriting/linalg.py | Minor update to is_matrix_transpose to correctly handle expanded dims. |
pytensor/tensor/_linalg/solve/rewriting.py | Implemented LU reuse rewrites and helper functions for Solve operations. |
pytensor/tensor/_linalg/solve/__init__.py | Registered the new solve rewrites. |
pytensor/tensor/_linalg/__init__.py | Registered LU decomposition rewrites. |
pytensor/tensor/__init__.py | Imported the updated linalg module. |
pytensor/scan/rewriting.py | Adjusted positions for scan rewrite registrations. |
Codecov Report

Attention: ❌ Your patch check has failed because the patch coverage (92.45%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1396      +/-   ##
==========================================
+ Coverage   82.08%   82.12%   +0.04%
==========================================
  Files         208      211       +3
  Lines       49565    49682     +117
  Branches     8792     8812      +20
==========================================
+ Hits        40685    40802     +117
+ Misses       6706     6702       -4
- Partials     2174     2178       +4
```
```diff
@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
     "more_mem",
     "scan",
     "scan_pushout",
-    position=5,
+    position=6,
```
Is this ordering necessary?
Yeah, we want the rewrite that splits the LU to run before the pushout, which is the one that actually removes it from the inner graph.
I could have used decimal positions, but it makes sense to have a whole number between the previous rewrite and this one.
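A toy sketch of the ordering mechanics being discussed (this is not PyTensor's actual rewrite-database API; the names and the position assigned to the split rewrite are illustrative only): rewrites fire in ascending position, so giving the LU-split rewrite a position below the pushout's guarantees it runs first.

```python
# Toy model of position-ordered rewrite registration.
rewrites = []

def register(name, position):
    rewrites.append((position, name))

register("scan_push_out_non_seq", position=3)
register("split_lu_solve_steps", position=5)  # hypothetical new rewrite
register("scan_push_out_dot1", position=6)    # moved from 5 to 6 in this PR

for _, name in sorted(rewrites):
    print(name)  # fires in ascending-position order: non_seq, split, dot1
```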
```python
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
```
Why did you choose to put these rewrites in tensor._linalg.solve.rewriting instead of in tensor.rewriting._linalg.solve? It breaks the usual pattern.
I don't think it does. For instance, the random rewrites are in tensor/random/rewriting, not tensor/rewriting/random.
```python
assume_a = node.op.core_op.assume_a

if assume_a != "gen":
```
```diff
-if assume_a != "gen":
+if assume_a not in SUPPORTED_ASSUMPTIONS:
```
You hard-coded a check for "gen" in a few places; it might be easier to update in the future if we just have a single constant with all the assume_a arguments we're allowing?
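A hypothetical sketch of this suggestion (SUPPORTED_ASSUMPTIONS and can_reuse_decomposition are made-up names, not part of the PR): centralize the allowed assume_a values in one constant instead of repeating the "gen" literal.

```python
# Made-up names for illustration; only node.op.core_op.assume_a is from the PR.
SUPPORTED_ASSUMPTIONS = ("gen",)  # could later grow, e.g. ("gen", "pos")

def can_reuse_decomposition(node) -> bool:
    assume_a = node.op.core_op.assume_a
    return assume_a in SUPPORTED_ASSUMPTIONS
```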
It will actually be a bit trickier than this, because different backends have different support for these decompositions. You can see what it looks like in the last commit of #1382.
```python
replace_dict = eager_split_lu_solve_steps.transform(
    new_scan_fgraph, inner_node
)
assert isinstance(replace_dict, dict) and len(replace_dict) > 0
```
Please raise an actual error here. I understand it's a sanity check, but hitting these in the future is very frustrating (see pymc-devs/pymc#7780)
Added an error message to the assert
Why the resistance to raising an actual error? Is it because the user never sees this (it just causes the rewrite to abort)?
It's a sanity check but an assert raises an error as well: AssertionError ;)
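A sketch of the resolution reached here: keep the assert, but attach a message so a future failure is diagnosable. The helper below is hypothetical; only the transform call and the dict check mirror the snippet above.

```python
def apply_rewrite_with_check(rewrite, fgraph, node):
    """Hypothetical helper illustrating an assert with an attached message."""
    replace_dict = rewrite.transform(fgraph, node)
    # An assert still raises (AssertionError), but the message makes a
    # future failure much easier to diagnose than a bare `assert`:
    assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
        f"{rewrite} was expected to return a non-empty replacement dict, "
        f"got {replace_dict!r}"
    )
    return replace_dict
```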
Force-pushed 1d49175 to 249a69a
Pull Request Overview
This PR introduces optimizations to reuse LU decompositions across multiple Solve operations, reducing redundant factorizations in both blockwise and scanned contexts.

- Add rewrites (reuse_lu_decomposition_multiple_solves, scan_split_non_sequence_lu_decomposition_solve) to factor A once and reuse it.
- Add tests in test_rewriting.py covering forward/backward, blockwise, and scan scenarios.
- Register the new rewrites via imports and adjust scan-rewrite ordering.
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/tensor/test_blockwise.py | Update RNG seed logic and exclude the new rewrite in tests |
tests/tensor/linalg/test_rewriting.py | New tests for LU-reuse rewrites |
pytensor/tensor/rewriting/linalg.py | Extend is_matrix_transpose to allow left expand_dims |
pytensor/tensor/_linalg/solve/rewriting.py | Implement LU-decomposition reuse rewrites |
pytensor/tensor/_linalg/solve/__init__.py | Register the new rewrite module |
pytensor/tensor/_linalg/__init__.py | Import solve package to trigger rewrite registration |
pytensor/tensor/__init__.py | Import _linalg submodule for rewrite registration |
pytensor/scan/rewriting.py | Adjust positions of existing scan-pushout rewrites |
Comments suppressed due to low confidence (3)

tests/tensor/test_blockwise.py:331
- [nitpick] Using the sum of character codes loses ordering information and can lead to seed collisions; consider using a hash (e.g., hashlib) or including order-sensitive data to generate a more robust RNG seed.
  seed = sum(map(ord, str(cls.core_op) + cls.signature))

pytensor/tensor/_linalg/solve/rewriting.py:97
- The zip(..., strict=True) keyword is only available in Python 3.10+ and may break compatibility; remove strict=True or guard its use for older Python versions.
  for a_bcast, b_bcast in zip(

pytensor/scan/rewriting.py:2571
- Changing the registration position of the scan_push_out_non_seq rewrite may affect the overall rewrite ordering; verify that the new ordering doesn't interfere with other scan optimizations.
  position=3,
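A sketch of the order-sensitive seeding the first suppressed comment suggests (hashlib is standard library; stable_seed is a made-up name): hashing the whole string avoids the collisions a plain sum of character codes allows, e.g. "ab" vs "ba".

```python
import hashlib

def stable_seed(core_op, signature) -> int:
    # Hash the full string so character order matters, then take 4 bytes
    # to get a small, stable integer seed.
    key = f"{core_op}{signature}".encode()
    return int.from_bytes(hashlib.sha256(key).digest()[:4], "little")

# seed = stable_seed(cls.core_op, cls.signature)  # instead of sum(map(ord, ...))
```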
Force-pushed 249a69a to 8c84fcf
Force-pushed 8c84fcf to 2b4132b
Force-pushed 2b4132b to 7278076
Requires #1394
This PR adds the rewrites to reuse an LU decomposition of the same A matrix across multiple solves. It does not propagate the check_finite flag, which I think we should rework out of Solve.

Closes #1374
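A minimal sketch (assuming PyTensor's public solve API; the rewrite itself runs automatically during compilation) of the pattern this PR targets: two solves against the same A should share one LU factorization after the rewrite.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
b1 = pt.vector("b1")
b2 = pt.vector("b2")

# Two solves against the same general matrix A
x1 = solve(A, b1, assume_a="gen")
x2 = solve(A, b2, assume_a="gen")

fn = pytensor.function([A, b1, b2], [x1, x2])
pytensor.dprint(fn)  # A should be factorized once and reused by both solves
```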
📚 Documentation preview 📚: https://pytensor--1396.org.readthedocs.build/en/1396/