From eb3d5212d20fde55791af81971d9041269cd133e Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 21:50:45 +0200
Subject: [PATCH 1/3] Add test of lower/upper flags

---
 tests/tensor/linalg/test_rewriting.py | 50 +++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py
index f1ea2e1af3..c6f5fae851 100644
--- a/tests/tensor/linalg/test_rewriting.py
+++ b/tests/tensor/linalg/test_rewriting.py
@@ -251,3 +251,53 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
         assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
     with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
         assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)
+
+
+@pytest.mark.parametrize(
+    "lower_first", [True, False], ids=["lower_first", "upper_first"]
+)
+def test_cho_solve_handles_lower_flags(lower_first):
+    rewrite_name = reuse_decomposition_multiple_solves.__name__
+    A = tensor("A", shape=(5, None))
+    b = tensor("b", shape=(5,))
+
+    x1 = solve(A, b, assume_a="pos", lower=lower_first, check_finite=False)
+    x2 = solve(A.mT, b, assume_a="pos", lower=not lower_first, check_finite=False)
+
+    dx1_dA = grad(x1.sum(), A)
+    dx2_dA = grad(x2.sum(), A)
+
+    fn = function([A, b], [x1, dx1_dA, x2, dx2_dA])
+    fn_no_rewrite = function(
+        [A, b],
+        [x1, dx1_dA, x2, dx2_dA],
+        mode=get_default_mode().excluding(rewrite_name),
+    )
+
+    rng = np.random.default_rng()
+    L_values = rng.normal(size=(5, 5)).astype(config.floatX)
+    A_values = L_values @ L_values.T  # Ensure A is positive definite
+
+    if lower_first:
+        A_values[np.triu_indices(5, k=1)] = np.nan
+    else:
+        A_values[np.tril_indices(5, k=-1)] = np.nan
+
+    b_values = rng.normal(size=(5,)).astype(config.floatX)
+
+    # This computation should not raise an error, and none of the outputs should be NaN
+    res = fn(A_values, b_values)
+    expected_res = fn_no_rewrite(A_values, b_values)
+
+    for x, expected_x in zip(res, expected_res):
+        assert np.isfinite(x).all()
+        np.testing.assert_allclose(
+            x,
+            expected_x,
+            atol=1e-6 if config.floatX == "float64" else 1e-3,
+            rtol=1e-6 if config.floatX == "float64" else 1e-3,
+        )
+
+    # If we put the NaN in the wrong place, it should raise an error
+    with pytest.raises(np.linalg.LinAlgError):
+        fn(A_values.T, b_values)

From 8a87fc4a6eb820999d3163c82cff48d5bf55c0e2 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 23:02:11 +0200
Subject: [PATCH 2/3] More carefully handle `lower` flag in `Solve`

---
 pytensor/tensor/_linalg/solve/rewriting.py | 37 ++++++++++++++--------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py
index c0a1c5cce8..eca3c47ff4 100644
--- a/pytensor/tensor/_linalg/solve/rewriting.py
+++ b/pytensor/tensor/_linalg/solve/rewriting.py
@@ -100,6 +100,7 @@ def find_solve_clients(var, assume_a):
             elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
                 # If it's a left expand_dims, recurse on the output
                 clients.extend(find_solve_clients(cl.outputs[0], assume_a))
+
         return clients
 
     assume_a = node.op.core_op.assume_a
@@ -107,33 +108,35 @@ def find_solve_clients(var, assume_a):
     if assume_a not in allowed_assume_a:
         return None
 
-    A, _ = get_root_A(node.inputs[0])
+    root_A, root_A_transposed = get_root_A(node.inputs[0])
 
     # Find Solve using A (or left expand_dims of A)
    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
     # that to the A_decomp outputs
-    A_solve_clients_and_transpose = [
-        (client, False) for client in find_solve_clients(A, assume_a)
+    root_A_solve_clients_and_transpose = [
+        (client, False) for client in find_solve_clients(root_A, assume_a)
     ]
 
     # Find Solves using A.T
-    for cl, _ in fgraph.clients[A]:
+    for cl, _ in fgraph.clients[root_A]:
         if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
             A_T = cl.out
-            A_solve_clients_and_transpose.extend(
+            root_A_solve_clients_and_transpose.extend(
                 (client, True) for client in find_solve_clients(A_T, assume_a)
             )
 
-    if not eager and len(A_solve_clients_and_transpose) == 1:
+    if not eager and len(root_A_solve_clients_and_transpose) == 1:
         # If there's a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
         # That's a "reuse" inside the inner vectorized loop
         batch_ndim = node.op.batch_ndim(node)
-        (client, _) = A_solve_clients_and_transpose[0]
-        original_A, b = client.inputs
+        (client, _) = root_A_solve_clients_and_transpose[0]
+
+        A, b = client.inputs
+
         if not any(
             a_bcast and not b_bcast
             for a_bcast, b_bcast in zip(
-                original_A.type.broadcastable[:batch_ndim],
+                A.type.broadcastable[:batch_ndim],
                 b.type.broadcastable[:batch_ndim],
                 strict=True,
             )
@@ -142,19 +145,27 @@ def find_solve_clients(var, assume_a):
 
     # If any Op had check_finite=True, we also do it for the LU decomposition
     check_finite_decomp = False
-    for client, _ in A_solve_clients_and_transpose:
+    for client, _ in root_A_solve_clients_and_transpose:
         if client.op.core_op.check_finite:
             check_finite_decomp = True
             break
 
-    lower = node.op.core_op.lower
+    (first_solve, transposed) = root_A_solve_clients_and_transpose[0]
+    lower = first_solve.op.core_op.lower
+    if transposed:
+        lower = not lower
+
     A_decomp = decompose_A(
-        A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
+        root_A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
     )
 
     replacements = {}
-    for client, transposed in A_solve_clients_and_transpose:
+    for client, transposed in root_A_solve_clients_and_transpose:
         _, b = client.inputs
+
+        lower = client.op.core_op.lower
+        if transposed:
+            lower = not lower
+
         new_x = solve_decomposed_system(
             A_decomp,
             b,

From 025c0258aed06a33801b043b63eff68a4134ebb4 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 21 Jun 2025 23:02:19 +0200
Subject: [PATCH 3/3] Revert "More carefully handle `lower` flag in `Solve`"

This reverts commit 388e93e5d0c01694461c75b36b07c1080b476800.
---
 pytensor/tensor/_linalg/solve/rewriting.py | 37 ++++++++--------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py
index eca3c47ff4..c0a1c5cce8 100644
--- a/pytensor/tensor/_linalg/solve/rewriting.py
+++ b/pytensor/tensor/_linalg/solve/rewriting.py
@@ -100,7 +100,6 @@ def find_solve_clients(var, assume_a):
             elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
                 # If it's a left expand_dims, recurse on the output
                 clients.extend(find_solve_clients(cl.outputs[0], assume_a))
-
         return clients
 
     assume_a = node.op.core_op.assume_a
@@ -108,35 +107,33 @@ def find_solve_clients(var, assume_a):
     if assume_a not in allowed_assume_a:
         return None
 
-    root_A, root_A_transposed = get_root_A(node.inputs[0])
+    A, _ = get_root_A(node.inputs[0])
 
     # Find Solve using A (or left expand_dims of A)
     # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
     # that to the A_decomp outputs
-    root_A_solve_clients_and_transpose = [
-        (client, False) for client in find_solve_clients(root_A, assume_a)
+    A_solve_clients_and_transpose = [
+        (client, False) for client in find_solve_clients(A, assume_a)
     ]
 
     # Find Solves using A.T
-    for cl, _ in fgraph.clients[root_A]:
+    for cl, _ in fgraph.clients[A]:
         if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
             A_T = cl.out
-            root_A_solve_clients_and_transpose.extend(
+            A_solve_clients_and_transpose.extend(
                 (client, True) for client in find_solve_clients(A_T, assume_a)
             )
 
-    if not eager and len(root_A_solve_clients_and_transpose) == 1:
+    if not eager and len(A_solve_clients_and_transpose) == 1:
         # If there's a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
         # That's a "reuse" inside the inner vectorized loop
         batch_ndim = node.op.batch_ndim(node)
-        (client, _) = root_A_solve_clients_and_transpose[0]
-
-        A, b = client.inputs
-
+        (client, _) = A_solve_clients_and_transpose[0]
+        original_A, b = client.inputs
         if not any(
             a_bcast and not b_bcast
             for a_bcast, b_bcast in zip(
-                A.type.broadcastable[:batch_ndim],
+                original_A.type.broadcastable[:batch_ndim],
                 b.type.broadcastable[:batch_ndim],
                 strict=True,
             )
@@ -145,27 +142,19 @@ def find_solve_clients(var, assume_a):
 
     # If any Op had check_finite=True, we also do it for the LU decomposition
     check_finite_decomp = False
-    for client, _ in root_A_solve_clients_and_transpose:
+    for client, _ in A_solve_clients_and_transpose:
         if client.op.core_op.check_finite:
             check_finite_decomp = True
             break
 
-    (first_solve, transposed) = root_A_solve_clients_and_transpose[0]
-    lower = first_solve.op.core_op.lower
-    if transposed:
-        lower = not lower
-
+    lower = node.op.core_op.lower
     A_decomp = decompose_A(
-        root_A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
+        A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
     )
 
     replacements = {}
-    for client, transposed in root_A_solve_clients_and_transpose:
+    for client, transposed in A_solve_clients_and_transpose:
         _, b = client.inputs
-
-        lower = client.op.core_op.lower
-        if transposed:
-            lower = not lower
-
         new_x = solve_decomposed_system(
             A_decomp,
             b,
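
A note on the invariant this series exercises: `reuse_decomposition_multiple_solves`
folds several `solve(..., assume_a="pos")` clients onto a single Cholesky
factorization, and `lower` decides which triangle of the input each client
actually reads, so the flag must be taken per client (flipped whenever the
client reaches the root `A` through a transpose) rather than read once from
the rewritten node. The sketch below restates that invariant in plain
NumPy/SciPy, with `cho_factor`/`cho_solve` standing in for the rewrite's
`decompose_A`/`solve_decomposed_system`; the SciPy calls are real, but the
correspondence to the PyTensor rewrite is only an analogy.

    import numpy as np
    from scipy.linalg import cho_factor, cho_solve

    rng = np.random.default_rng(0)
    L = rng.normal(size=(5, 5))
    A = L @ L.T  # symmetric positive definite

    # As in the new test: only the lower triangle carries valid data, so any
    # client declared with lower=True must never read above the diagonal.
    A_lower_only = A.copy()
    A_lower_only[np.triu_indices(5, k=1)] = np.nan

    b = rng.normal(size=(5,))

    # Factor once, reading only the lower triangle. check_finite=False (as in
    # the test) tolerates the NaNs parked in the unused triangle.
    c = cho_factor(A_lower_only, lower=True, check_finite=False)

    # Client 1: solve(A, b, lower=True) can reuse this factorization as-is.
    x1 = cho_solve(c, b, check_finite=False)
    np.testing.assert_allclose(x1, np.linalg.solve(A, b), rtol=1e-6)

    # Client 2: solve(A.T, b, lower=False). The valid triangle of A.T is the
    # upper one, i.e. the lower triangle of A, so the client's lower=False
    # must be flipped to lower=True before it can share c. Factoring the
    # transposed data with its own (un-flipped) flag reaches the same answer,
    # confirming the two views agree; since A is symmetric, so do the solves.
    c_T = cho_factor(A_lower_only.T, lower=False, check_finite=False)
    x2 = cho_solve(c_T, b, check_finite=False)
    np.testing.assert_allclose(x2, x1, rtol=1e-6)

Factoring with the un-flipped flag instead, e.g.
`cho_factor(A_lower_only, lower=False, check_finite=False)`, would consume the
NaN triangle, so the decomposition fails or yields NaN solutions; that is the
failure mode the NaN placement in `test_cho_solve_handles_lower_flags` is
built to catch, and the reason [PATCH 2/3] takes `lower` from each client
rather than from the node being rewritten.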