+from copy import copy
+
+from pytensor.graph import Constant, graph_inputs
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
+from pytensor.scan.op import Scan
+from pytensor.scan.rewriting import scan_seqopt1
+from pytensor.tensor.basic import atleast_Nd
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.linalg import is_matrix_transpose
+from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
+from pytensor.tensor.variable import TensorVariable
+
+
+def decompose_A(A, assume_a):
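+    """Return a reusable factorization of ``A`` for the given ``assume_a``.
+
+    Only the general case (``"gen"``) is currently supported, for which an
+    LU factorization (``lu_factor``) is returned.
+    """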
+    if assume_a == "gen":
+        return lu_factor(A, check_finite=False)
+    else:
+        raise NotImplementedError
+
+
+def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
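+    """Solve ``A @ x = b`` (or ``A.T @ x = b`` if ``transposed``) from a precomputed factorization.
+
+    Only ``assume_a="gen"`` is currently supported, dispatching to ``lu_solve``
+    on the LU factors produced by ``decompose_A``.
+    """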
+    if assume_a == "gen":
+        return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
+    else:
+        raise NotImplementedError
+
+
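+# `assume_a` flavors for which the decomposition/solve split above is implemented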
+_SPLITTABLE_SOLVE_ASSUME_A = {"gen"}
+
+
+def _split_lu_solve_steps(fgraph, node, *, eager: bool):
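+    """Replace every ``Solve`` that reuses this node's ``A`` with a shared factorization plus ``lu_solve``.
+
+    Returns a replacement dict mapping the old solve outputs to the new ones,
+    or ``None`` when the rewrite does not apply. When ``eager`` is False, a
+    lone solve is only split if ``A`` is broadcast along a batch dimension of
+    ``b`` (i.e., the same ``A`` is implicitly reused inside a Blockwise loop).
+    """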
+    if not isinstance(node.op.core_op, Solve):
+        return None
+
+    def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
+        # Find the root variable of the first input to Solve
+        # If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
+        # the root variable is the pre-DimShuffled input.
+        # Otherwise, `a` is considered the root variable.
+        # We also return whether the root `a` is transposed.
+        transposed = False
+        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
+            if a.owner.op.is_left_expand_dims:
+                [a] = a.owner.inputs
+            elif is_matrix_transpose(a):
+                [a] = a.owner.inputs
+                transposed = True
+        return a, transposed
+
+    def find_solve_clients(var, assume_a):
+        clients = []
+        for cl, idx in fgraph.clients[var]:
+            if (
+                idx == 0
+                and isinstance(cl.op, Blockwise)
+                and isinstance(cl.op.core_op, Solve)
+                and (cl.op.core_op.assume_a == assume_a)
+            ):
+                clients.append(cl)
+            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
+                # If it's a left expand_dims, recurse on the output
+                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
+        return clients
+
+    assume_a = node.op.core_op.assume_a
+
+    if assume_a not in _SPLITTABLE_SOLVE_ASSUME_A:
+        return None
+
+    A, _ = get_root_A(node.inputs[0])
+
+    # Find Solve using A (or left expand_dims of A)
+    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
+    #  that to the A_decomp outputs
+    A_solve_clients_and_transpose = [
+        (client, False) for client in find_solve_clients(A, assume_a)
+    ]
+
+    # Find Solves using A.T
+    for cl, _ in fgraph.clients[A]:
+        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
+            A_T = cl.out
+            A_solve_clients_and_transpose.extend(
+                (client, True) for client in find_solve_clients(A_T, assume_a)
+            )
+
+    if not eager and len(A_solve_clients_and_transpose) == 1:
+        # If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager)
+        # That's a "reuse" inside the inner vectorized loop
+        batch_ndim = node.op.batch_ndim(node)
+        (client, _) = A_solve_clients_and_transpose[0]
+        original_A, b = client.inputs
+        if not any(
+            a_bcast and not b_bcast
+            for a_bcast, b_bcast in zip(
+                original_A.type.broadcastable[:batch_ndim],
+                b.type.broadcastable[:batch_ndim],
+                strict=True,
+            )
+        ):
+            return None
+
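+    # Factor A once; every solve client found above reuses the same decomposition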
+    A_decomp = decompose_A(A, assume_a=assume_a)
+
+    replacements = {}
+    for client, transposed in A_solve_clients_and_transpose:
+        _, b = client.inputs
+        b_ndim = client.op.core_op.b_ndim
+        new_x = solve_lu_decomposed_system(
+            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
+        )
+        [old_x] = client.outputs
+        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
+        copy_stack_trace(old_x, new_x)
+        replacements[old_x] = new_x
+
+    return replacements
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def reuse_lu_decomposition_multiple_solves(fgraph, node):
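+    """Reuse a single LU decomposition across multiple solves with the same ``A`` (or ``A.T``).
+
+    A minimal sketch of the intended effect (the variable names here are
+    illustrative, not part of this module): compiling
+
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.tensor.slinalg import solve
+
+        A = pt.matrix("A")
+        b1 = pt.vector("b1")
+        b2 = pt.vector("b2")
+        fn = pytensor.function([A, b1, b2], [solve(A, b1), solve(A, b2)])
+
+    in a mode that runs the "specialize" rewrites should end up computing one
+    ``lu_factor(A)`` followed by two ``lu_solve`` calls, instead of two full solves.
+    """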
+    return _split_lu_solve_steps(fgraph, node, eager=False)
+
+
+@node_rewriter([Blockwise])
+def eager_split_lu_solve_steps(fgraph, node):
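+    """Eager variant of the split (applies even to a single solve); used on Scan inner graphs below."""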
+    return _split_lu_solve_steps(fgraph, node, eager=True)
+
+
+@node_rewriter([Scan])
+def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
+    """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
+
+    The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.
+    """
+    scan_op: Scan = node.op
+    non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
+    new_scan_fgraph = scan_op.fgraph
+
+    changed = False
+    while True:
+        for inner_node in new_scan_fgraph.toposort():
+            if (
+                isinstance(inner_node.op, Blockwise)
+                and isinstance(inner_node.op.core_op, Solve)
+                and inner_node.op.core_op.assume_a in _SPLITTABLE_SOLVE_ASSUME_A
+            ):
+                A, b = inner_node.inputs
+                if all(
+                    (isinstance(root_inp, Constant) or (root_inp in non_sequences))
+                    for root_inp in graph_inputs([A])
+                ):
+                    if new_scan_fgraph is scan_op.fgraph:
+                        # Clone the first time to avoid mutating the original fgraph
+                        new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
+                        non_sequences = {equiv[non_seq] for non_seq in non_sequences}
+                        inner_node = equiv[inner_node]
+
+                    replace_dict = eager_split_lu_solve_steps.transform(
+                        new_scan_fgraph, inner_node
+                    )
+                    assert (
+                        isinstance(replace_dict, dict) and len(replace_dict) > 0
+                    ), "Rewrite failed"
+                    new_scan_fgraph.replace_all(replace_dict.items())
+                    changed = True
+                    break  # Break to start over with a fresh toposort
+        else:  # no_break
+            break  # Nothing else changed
+
+    if not changed:
+        return
+
+    # Return a new scan to indicate that a rewrite was done
+    new_scan_op = copy(scan_op)
+    new_scan_op.fgraph = new_scan_fgraph
+    new_outs = new_scan_op.make_node(*node.inputs).outputs
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs
+
+
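+# Registered in `scan_seqopt1` so that, once the inner-loop solve has been split, the
+# `scan_pushout_non_sequences` rewrite can hoist the LU factorization out of the loop.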
+scan_seqopt1.register(
+    scan_split_non_sequence_lu_decomposition_solve.__name__,
+    in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+    "scan_pushout",
+    position=2,
+)