Skip to content

Commit 8c84fcf

Browse files
committed
Reuse LU decomposition in Solve
1 parent cff7587 commit 8c84fcf

File tree

10 files changed

+377
-7
lines changed

10 files changed

+377
-7
lines changed

pytensor/compile/mode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
490490
"fusion",
491491
"inplace",
492492
"scan_save_mem_prealloc",
493+
"reuse_lu_decomposition_multiple_solves",
494+
"scan_split_non_sequence_lu_decomposition_solve",
493495
],
494496
),
495497
)

pytensor/scan/rewriting.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,26 +2561,24 @@ def scan_push_out_dot1(fgraph, node):
25612561
position=1,
25622562
)
25632563

2564-
25652564
scan_seqopt1.register(
25662565
"scan_push_out_non_seq",
25672566
in2out(scan_push_out_non_seq, ignore_newtrees=True),
25682567
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
25692568
"fast_run",
25702569
"scan",
25712570
"scan_pushout",
2572-
position=2,
2571+
position=3,
25732572
)
25742573

2575-
25762574
scan_seqopt1.register(
25772575
"scan_push_out_seq",
25782576
in2out(scan_push_out_seq, ignore_newtrees=True),
25792577
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
25802578
"fast_run",
25812579
"scan",
25822580
"scan_pushout",
2583-
position=3,
2581+
position=4,
25842582
)
25852583

25862584

@@ -2592,7 +2590,7 @@ def scan_push_out_dot1(fgraph, node):
25922590
"more_mem",
25932591
"scan",
25942592
"scan_pushout",
2595-
position=4,
2593+
position=5,
25962594
)
25972595

25982596

@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
26052603
"more_mem",
26062604
"scan",
26072605
"scan_pushout",
2608-
position=5,
2606+
position=6,
26092607
)
26102608

26112609
scan_eqopt2.register(

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
114114

115115

116116
# isort: off
117+
import pytensor.tensor._linalg
117118
from pytensor.tensor import linalg
118119
from pytensor.tensor import special
119120
from pytensor.tensor import signal

pytensor/tensor/_linalg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Register rewrites
2+
import pytensor.tensor._linalg.solve
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Register rewrites in the database
2+
import pytensor.tensor._linalg.solve.rewriting
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from copy import copy
2+
3+
from pytensor.graph import Constant, graph_inputs
4+
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
5+
from pytensor.scan.op import Scan
6+
from pytensor.scan.rewriting import scan_seqopt1
7+
from pytensor.tensor.basic import atleast_Nd
8+
from pytensor.tensor.blockwise import Blockwise
9+
from pytensor.tensor.elemwise import DimShuffle
10+
from pytensor.tensor.rewriting.basic import register_specialize
11+
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
12+
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
13+
from pytensor.tensor.variable import TensorVariable
14+
15+
16+
def decompose_A(A, assume_a):
    """Compute a reusable decomposition of `A` for the given structural assumption.

    Only ``assume_a == "gen"`` (general matrix) is currently supported, in
    which case an LU factorization is returned.
    """
    if assume_a != "gen":
        raise NotImplementedError
    return lu_factor(A, check_finite=False)
21+
22+
23+
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
    """Solve a linear system from a precomputed decomposition of `A`.

    Only ``assume_a == "gen"`` is currently supported; `A_decomp` is then the
    LU factorization produced by `decompose_A`. When `transposed` is true the
    system is solved against the transpose of the decomposed matrix.
    """
    if assume_a != "gen":
        raise NotImplementedError
    return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
28+
29+
30+
_SPLITTABLE_SOLVE_ASSUME_A = {"gen"}
31+
32+
33+
def _split_lu_solve_steps(fgraph, node, *, eager: bool):
    """Replace Solve nodes sharing a common `A` by one decomposition + cheap solves.

    Starting from one Blockwise Solve `node`, find the root `A` it solves
    against, collect every Solve client of `A` (and of `A.T`) with the same
    ``assume_a``, decompose `A` once, and return a replacement dict mapping
    each old solve output to a solve built on the shared decomposition.

    When `eager` is False, a single-use `A` is only rewritten if it is being
    broadcast against `b` inside the Blockwise (the decomposition would
    otherwise be redone per batch entry). Returns ``None`` when no rewrite
    applies.
    """
    if not isinstance(node.op.core_op, Solve):
        return None

    def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
        # Find the root variable of the first input to Solve
        # If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
        # the root variable is the pre-DimShuffled input.
        # Otherwise, `a` is considered the root variable.
        # We also return whether the root `a` is transposed.
        transposed = False
        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
            if a.owner.op.is_left_expand_dims:
                [a] = a.owner.inputs
            elif is_matrix_transpose(a):
                [a] = a.owner.inputs
                transposed = True
        return a, transposed

    def find_solve_clients(var, assume_a):
        # Collect Blockwise Solve clients that use `var` as their `A` input
        # (client input index 0) and share the same `assume_a`.
        clients = []
        for cl, idx in fgraph.clients[var]:
            if (
                idx == 0
                and isinstance(cl.op, Blockwise)
                and isinstance(cl.op.core_op, Solve)
                and (cl.op.core_op.assume_a == assume_a)
            ):
                clients.append(cl)
            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
                # If it's a left expand_dims, recurse on the output
                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
        return clients

    assume_a = node.op.core_op.assume_a

    if assume_a not in _SPLITTABLE_SOLVE_ASSUME_A:
        return None

    A, _ = get_root_A(node.inputs[0])

    # Find Solve using A (or left expand_dims of A)
    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
    # that to the A_decomp outputs
    A_solve_clients_and_transpose = [
        (client, False) for client in find_solve_clients(A, assume_a)
    ]

    # Find Solves using A.T
    for cl, _ in fgraph.clients[A]:
        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
            A_T = cl.out
            A_solve_clients_and_transpose.extend(
                (client, True) for client in find_solve_clients(A_T, assume_a)
            )

    if not eager and len(A_solve_clients_and_transpose) == 1:
        # If there's a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
        # That's a "reuse" inside the inner vectorized loop
        batch_ndim = node.op.batch_ndim(node)
        (client, _) = A_solve_clients_and_transpose[0]
        original_A, b = client.inputs
        # Rewrite only if some batch dim of A is broadcast while b's isn't:
        # then the same A is solved against multiple b's within one Blockwise.
        if not any(
            a_bcast and not b_bcast
            for a_bcast, b_bcast in zip(
                original_A.type.broadcastable[:batch_ndim],
                b.type.broadcastable[:batch_ndim],
                strict=True,
            )
        ):
            return None

    # Decompose A once; every collected Solve client reuses this result.
    A_decomp = decompose_A(A, assume_a=assume_a)

    replacements = {}
    for client, transposed in A_solve_clients_and_transpose:
        _, b = client.inputs
        b_ndim = client.op.core_op.b_ndim
        new_x = solve_lu_decomposed_system(
            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
        )
        [old_x] = client.outputs
        # Restore any left expand_dims that were stripped when finding the
        # root A, and keep the original output dtype.
        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
        copy_stack_trace(old_x, new_x)
        replacements[old_x] = new_x

    return replacements
120+
121+
122+
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    """Reuse one LU decomposition across multiple Solves of the same `A`.

    Non-eager variant: `_split_lu_solve_steps` only rewrites when the
    decomposition is actually reused (several solve clients, or a broadcast
    `A` within a single Blockwise).
    """
    replacements = _split_lu_solve_steps(fgraph, node, eager=False)
    return replacements
126+
127+
128+
@node_rewriter([Blockwise])
def eager_split_lu_solve_steps(fgraph, node):
    """Eager variant of `reuse_lu_decomposition_multiple_solves`.

    Splits the decomposition and solve steps even for a single Solve client.
    Used by the Scan rewrite so the decomposition can be pushed out of loops.
    """
    replacements = _split_lu_solve_steps(fgraph, node, eager=True)
    return replacements
131+
132+
133+
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.

    The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.

    Returns the outputs of a new Scan op (with a rewritten inner fgraph) when
    at least one inner Solve was split, or ``None`` when nothing changed.
    """
    scan_op: Scan = node.op
    non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
    new_scan_fgraph = scan_op.fgraph

    changed = False
    while True:
        # Re-toposort after every successful split, since the rewrite
        # replaces nodes in the inner fgraph.
        for inner_node in new_scan_fgraph.toposort():
            if (
                isinstance(inner_node.op, Blockwise)
                and isinstance(inner_node.op.core_op, Solve)
                and inner_node.op.core_op.assume_a in _SPLITTABLE_SOLVE_ASSUME_A
            ):
                A, b = inner_node.inputs
                # Only split when A depends exclusively on constants and
                # non-sequences, so the decomposition is loop-invariant.
                if all(
                    (isinstance(root_inp, Constant) or (root_inp in non_sequences))
                    for root_inp in graph_inputs([A])
                ):
                    if new_scan_fgraph is scan_op.fgraph:
                        # Clone the first time to avoid mutating the original fgraph
                        new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
                        non_sequences = {equiv[non_seq] for non_seq in non_sequences}
                        inner_node = equiv[inner_node]

                    replace_dict = eager_split_lu_solve_steps.transform(
                        new_scan_fgraph, inner_node
                    )
                    assert (
                        isinstance(replace_dict, dict) and len(replace_dict) > 0
                    ), "Rewrite failed"
                    new_scan_fgraph.replace_all(replace_dict.items())
                    changed = True
                    break  # Break to start over with a fresh toposort
        else:  # no_break
            break  # Nothing else changed

    if not changed:
        return

    # Return a new scan to indicate that a rewrite was done
    new_scan_op = copy(scan_op)
    new_scan_op.fgraph = new_scan_fgraph
    new_outs = new_scan_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
183+
184+
185+
# Register at position=2, i.e. before the scan_pushout rewrites
# (scan_push_out_non_seq at position 3 onwards), so the loop-invariant
# LU decomposition split happens first and can then be pushed out.
scan_seqopt1.register(
    scan_split_non_sequence_lu_decomposition_solve.__name__,
    in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
    "fast_run",
    "scan",
    "scan_pushout",
    position=2,
)

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
7575
if ndims < 2:
7676
return False
7777
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
78+
79+
# Allow expand_dims on the left of the transpose
80+
if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
81+
transpose_order = (
82+
*(["x"] * diff),
83+
*transpose_order,
84+
)
7885
return node.op.new_order == transpose_order
7986
return False
8087

tests/tensor/linalg/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)