
Reuse LU decomposition in Solve #1396


Open · wants to merge 2 commits into main from decompose_lu_solve

Conversation

@ricardoV94 (Member) commented May 8, 2025

Requires #1394

This PR adds rewrites to reuse an LU decomposition of the same A matrix across multiple solves:

  1. Distinct Solve operations (such as in a graph containing both the forward and backward pass of Solve)
  2. Blockwise Solve where A is broadcast by b
  3. Scan Solve where A is composed only of non-sequence inputs, but b is not

It does not propagate the check_finite flag, which I think we should rework out of Solve.
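
As an illustration of case 1, here is a minimal sketch (assuming the scipy-style solve from pytensor.tensor.slinalg; not code from this PR) of a graph where the same A feeds both the forward solve and the solve introduced by its gradient:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
b = pt.vector("b")

x = solve(A, b, assume_a="gen")   # forward solve
g = pytensor.grad(x.sum(), b)     # the backward pass introduces a second solve against A (transposed)

fn = pytensor.function([A, b], [x, g])
pytensor.dprint(fn)  # with the rewrite, a single LU factorization of A should feed both solves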

Closes #1374


📚 Documentation preview 📚: https://pytensor--1396.org.readthedocs.build/en/1396/

@ricardoV94 ricardoV94 marked this pull request as draft May 8, 2025 17:10
@ricardoV94 ricardoV94 force-pushed the decompose_lu_solve branch 3 times, most recently from 50146f6 to 903a86e on May 9, 2025 10:39
@ricardoV94 (Member, Author) commented May 9, 2025

There's a test failing, @jessegrabowski: it seems our lu_solve gives output with different precision than getrs (I assume worse, but I didn't check).

Anyway, this makes me wonder: don't we want to wrap that LAPACK routine as an Op instead of using our double-triangular-and-pivots op combination? It would reduce the cost of splitting the Op, because all that logic would live inside LAPACK. It's also a cleaner graph?
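
For reference, here is a rough local check of the two numerical paths being compared (plain scipy, under my own assumptions about which operations correspond to which path; not code from this PR). scipy's lu_solve calls getrs on the packed factors, while the split graph applies the pivots and then two triangular solves:

import numpy as np
from scipy.linalg import lu_factor, lu_solve, solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(64, 64)).astype("float32")
b = rng.normal(size=64).astype("float32")

lu, piv = lu_factor(A)            # one getrf factorization

x_getrs = lu_solve((lu, piv), b)  # getrs: row interchanges + both triangular solves inside LAPACK

# Pivots + two explicit triangular solves, mimicking the split graph
perm = np.arange(b.size)
for i, p in enumerate(piv):       # convert LAPACK row swaps into a permutation
    perm[i], perm[p] = perm[p], perm[i]
y = solve_triangular(lu, b[perm], lower=True, unit_diagonal=True)
x_split = solve_triangular(lu, y, lower=False)

print(np.abs(x_getrs - x_split).max())  # any discrepancy is at float32 rounding scale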

@jessegrabowski (Member)

Yes, we can wrap the Op directly. I was just having trouble with gradients when I did it that way. If you recall, we had a call where I compared the two approaches (core ops vs. a new op).

@ricardoV94 (Member, Author)

You couldn't figure out the gradients of just the LAPACK Op, or was it something else?

@jessegrabowski (Member)

Yes, I couldn't get correct gradients. What I thought would be the correct, straightforward answer ended up being wrong. I got frustrated quickly and didn't spend very long on it.

@ricardoV94 (Member, Author)

Okay, I'm fine with this; we just need to tweak the tolerance, and perhaps keep an issue open to revisit it more carefully later if it turns out to be problematic. It's a float32 thing and the differences are not crazy.

@jessegrabowski (Member) commented May 10, 2025

My guess is that there's sequential loss of precision from doing two solves vs. one. We do have the fig leaf of "well, this is what JAX does" at least!

I'll at least open a branch/PR with the Op version of lu_solve, and we can work on getting the gradient to work when we have some free time (never).

@ricardoV94 ricardoV94 marked this pull request as ready for review May 14, 2025 04:11
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR enables the reuse of LU decompositions across Solve calls for distinct, blockwise, and scan-based operations. Key changes include updated test tolerances and modes in blockwise and rewriting tests, new tests to verify the LU rewrite behavior, and implementation of several LU-decomposition rewrite functions along with adjustments in the scan rewriting module.

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Summary per file:
  • tests/tensor/test_blockwise.py: Adjusted tolerance settings and updated mode exclusion parameters.
  • tests/tensor/linalg/test_rewriting.py: Added tests to validate the new LU decomposition rewrite logic.
  • pytensor/tensor/rewriting/linalg.py: Minor update to is_matrix_transpose to correctly handle expanded dims.
  • pytensor/tensor/_linalg/solve/rewriting.py: Implemented LU reuse rewrites and helper functions for Solve operations.
  • pytensor/tensor/_linalg/solve/__init__.py: Registered the new solve rewrites.
  • pytensor/tensor/_linalg/__init__.py: Registered LU decomposition rewrites.
  • pytensor/tensor/__init__.py: Imported the updated linalg module.
  • pytensor/scan/rewriting.py: Adjusted positions for scan rewrite registrations.

codecov bot commented May 14, 2025

Codecov Report

Attention: Patch coverage is 92.45283% with 8 lines in your changes missing coverage. Please review.

Project coverage is 82.12%. Comparing base (24a2234) to head (7278076).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/_linalg/solve/rewriting.py 94.11% 2 Missing and 4 partials ⚠️
pytensor/tensor/rewriting/linalg.py 0.00% 1 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (92.45%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1396      +/-   ##
==========================================
+ Coverage   82.08%   82.12%   +0.04%     
==========================================
  Files         208      211       +3     
  Lines       49565    49682     +117     
  Branches     8792     8812      +20     
==========================================
+ Hits        40685    40802     +117     
+ Misses       6706     6702       -4     
- Partials     2174     2178       +4     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 84.72% <ø> (ø)
pytensor/scan/rewriting.py 82.89% <ø> (ø)
pytensor/tensor/_linalg/__init__.py 100.00% <100.00%> (ø)
pytensor/tensor/_linalg/solve/__init__.py 100.00% <100.00%> (ø)
pytensor/tensor/rewriting/linalg.py 92.28% <0.00%> (-0.48%) ⬇️
pytensor/tensor/_linalg/solve/rewriting.py 94.11% <94.11%> (ø)

... and 4 files with indirect coverage changes


@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
     "more_mem",
     "scan",
     "scan_pushout",
-    position=5,
+    position=6,
Member

Is this ordering necessary?

Member Author

Yeah, we want the rewrite that splits the LU to run before the pushout, which is the one that actually removes it from the inner graph.

I could have used decimal positions, but it makes sense to have a whole number between the previous rewrite and this one.
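
To make the ordering concrete, here is a rough sketch of how position-based registration in a rewrite SequenceDB behaves (the rewriters and names below are placeholders, not the actual registrations in pytensor/scan/rewriting.py):

from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB

@node_rewriter(None)  # placeholder local rewrite that never fires
def dummy_split_lu(fgraph, node):
    return None

@node_rewriter(None)
def dummy_pushout(fgraph, node):
    return None

db = SequenceDB()
# Lower positions run first, so the LU split (5) fires before the pushout (6),
# which is the pass that actually moves the factorization out of the inner graph.
db.register("split_lu_solve_in_scan", in2out(dummy_split_lu), "fast_run", position=5)
db.register("scan_push_out_non_seq", in2out(dummy_pushout), "fast_run", position=6)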

from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
Member

Why did you choose to put these rewrites in tensor._linalg.solve.rewriting instead of in tensor.rewriting._linalg.solve? It breaks the usual pattern.

Member Author

I don't think it does. For instance, the random rewrites are in tensor/random/rewriting, not tensor/rewriting/random.


assume_a = node.op.core_op.assume_a

if assume_a != "gen":
Member

Suggested change:
-    if assume_a != "gen":
+    if assume_a not in SUPPORTED_ASSUMPTIONS:

You hard-coded a check for gen in a few places; it might be easier to update in future if we just have a single constant with all the assume_a arguments we're allowing?
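
A minimal sketch of what that could look like (SUPPORTED_ASSUMPTIONS and the helper below are hypothetical, not something this PR defines):

# Hypothetical module-level constant; only "gen" is handled by these rewrites today,
# but other assume_a values could be added once their decompositions are supported.
SUPPORTED_ASSUMPTIONS = ("gen",)

def _solve_is_supported(node) -> bool:
    assume_a = node.op.core_op.assume_a
    return assume_a in SUPPORTED_ASSUMPTIONS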

Member Author

It will actually be a bit trickier than this, because different backends have different support for these decompositions. You can see what it looks like in the last commit of #1382.

replace_dict = eager_split_lu_solve_steps.transform(
    new_scan_fgraph, inner_node
)
assert isinstance(replace_dict, dict) and len(replace_dict) > 0
Member

Please raise an actual error here. I understand it's a sanity check, but hitting these in the future is very frustrating (see pymc-devs/pymc#7780).

Member Author

Added an error message to the assert.
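
Presumably something along the lines of the snippet above with a message attached to the assert (a sketch of the described change; the exact wording in the PR may differ):

replace_dict = eager_split_lu_solve_steps.transform(new_scan_fgraph, inner_node)
assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
    "Could not split the LU solve inside the Scan inner graph"
)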

@jessegrabowski (Member) May 19, 2025

Why resistant to an actual error? Because the user never sees it (it just causes the rewrite to abort)?

Member Author

It's a sanity check, but an assert raises an error as well: AssertionError ;)

@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR introduces optimizations to reuse LU decompositions across multiple Solve operations, reducing redundant factorizations in both blockwise and scanned contexts.

  • Add rewrites (reuse_lu_decomposition_multiple_solves, scan_split_non_sequence_lu_decomposition_solve) to factor A once and reuse it.
  • Add tests in test_rewriting.py covering forward/backward, blockwise, and scan scenarios.
  • Register the new rewrites via imports and adjust scan-rewrite ordering.
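
For instance, the Scan case targeted by scan_split_non_sequence_lu_decomposition_solve looks roughly like this (an illustrative sketch, not one of the PR's tests):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")    # non-sequence: constant across scan iterations
bs = pt.matrix("bs")  # sequence of right-hand-side vectors, one per row

xs, _ = pytensor.scan(
    fn=lambda b_t, A_inner: solve(A_inner, b_t, assume_a="gen"),
    sequences=[bs],
    non_sequences=[A],
)
fn = pytensor.function([A, bs], xs)
# After the rewrite, A should be LU-factorized once outside the loop, leaving only
# the cheap triangular solves inside the scan.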

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

Summary per file:
  • tests/tensor/test_blockwise.py: Update RNG seed logic and exclude the new rewrite in tests
  • tests/tensor/linalg/test_rewriting.py: New tests for LU-reuse rewrites
  • pytensor/tensor/rewriting/linalg.py: Extend is_matrix_transpose to allow left expand_dims
  • pytensor/tensor/_linalg/solve/rewriting.py: Implement LU-decomposition reuse rewrites
  • pytensor/tensor/_linalg/solve/__init__.py: Register the new rewrite module
  • pytensor/tensor/_linalg/__init__.py: Import solve package to trigger rewrite registration
  • pytensor/tensor/__init__.py: Import _linalg submodule for rewrite registration
  • pytensor/scan/rewriting.py: Adjust positions of existing scan-pushout rewrites
Comments suppressed due to low confidence (3)

tests/tensor/test_blockwise.py:331

  • [nitpick] Using the sum of character codes loses ordering information and can lead to seed collisions; consider using a hash (e.g., hashlib) or include order-sensitive data to generate a more robust RNG seed.
seed = sum(map(ord, str(cls.core_op) + cls.signature))
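
A sketch of the hash-based alternative being suggested (hypothetical helper, not what the test does):

import hashlib

def stable_seed(core_op_str: str, signature: str) -> int:
    # Order-sensitive: "ab" and "ba" hash differently, unlike a sum of character codes.
    digest = hashlib.sha256((core_op_str + signature).encode()).digest()
    return int.from_bytes(digest[:8], "little")

seed = stable_seed("Solve{assume_a='gen'}", "(m,m),(m,k)->(m,k)")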

pytensor/tensor/_linalg/solve/rewriting.py:97

  • The zip(..., strict=True) keyword is only available in Python 3.10+ and may break compatibility; remove strict=True or guard its use for older Python versions.
for a_bcast, b_bcast in zip(

pytensor/scan/rewriting.py:2571

  • Changing the registration position of the scan_push_out_non_seq rewrite may affect the overall rewrite ordering; verify that the new ordering doesn't interfere with other scan optimizations.
    position=3,

@ricardoV94 ricardoV94 force-pushed the decompose_lu_solve branch from 249a69a to 8c84fcf on May 19, 2025 11:41
@ricardoV94 ricardoV94 force-pushed the decompose_lu_solve branch from 8c84fcf to 2b4132b on May 19, 2025 11:49
@ricardoV94 ricardoV94 force-pushed the decompose_lu_solve branch from 2b4132b to 7278076 on May 19, 2025 12:08
Development

Successfully merging this pull request may close these issues.

Solve to Solve LU optimization
2 participants