From 41564b804ff10bcfd0adb0118329901162bb22ec Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 12 Jun 2025 11:36:01 +0200 Subject: [PATCH 01/11] Don't do symbolic upcasting in `local_upcast_elemwise_constants` This reduces the number of rewrite passes, by avoiding constant fold of cast/expand_dims/alloc --- pytensor/tensor/rewriting/elemwise.py | 93 +++++++++------------------ 1 file changed, 31 insertions(+), 62 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index afe69a198b..5dd3b59096 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -30,13 +30,9 @@ from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( MakeVector, - alloc, - cast, constant, - get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( alloc_like, @@ -44,7 +40,6 @@ register_canonicalize, register_specialize, ) -from pytensor.tensor.shape import shape_padleft from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): """ if len(node.outputs) > 1: - return - try: - shape_i = fgraph.shape_feature.shape_i - except AttributeError: - shape_i = None - if isinstance(node.op, Elemwise): - scalar_op = node.op.scalar_op - # print "aa", scalar_op.output_types_preference - if getattr(scalar_op, "output_types_preference", None) in ( - ps.upgrade_to_float, - ps.upcast_out, - ): - # this is the kind of op that we can screw with the input - # dtypes by upcasting explicitly - output_dtype = node.outputs[0].type.dtype - new_inputs = [] - for i in node.inputs: - if i.type.dtype == output_dtype: - new_inputs.append(i) - else: - try: - cval_i = get_underlying_scalar_constant_value( - i, only_process_constants=True - ) - if all(i.broadcastable): - new_inputs.append( - shape_padleft(cast(cval_i, output_dtype), i.ndim) - ) - else: - if shape_i is None: - return - new_inputs.append( - alloc( - cast(cval_i, output_dtype), - *[shape_i(d)(i) for d in range(i.ndim)], - ) - ) - # print >> sys.stderr, "AAA", - # *[Shape_i(d)(i) for d in range(i.ndim)] - except NotScalarConstantError: - # for the case of a non-scalar - if isinstance(i, TensorConstant): - new_inputs.append(cast(i, output_dtype)) - else: - new_inputs.append(i) + return None + + if getattr(node.op.scalar_op, "output_types_preference", None) not in ( + ps.upgrade_to_float, + ps.upcast_out, + ): + return None - if new_inputs != node.inputs: - rval = [node.op(*new_inputs)] - if not node.outputs[0].type.is_super(rval[0].type): - # This can happen for example when floatX=float32 - # and we do the true division between and int64 - # and a constant that will get typed as int8. + # this is the kind of op that we can screw with the input + # dtypes by upcasting explicitly + [old_out] = node.outputs + output_dtype = old_out.type.dtype + new_inputs = list(node.inputs) + changed = False + for i, inp in enumerate(node.inputs): + if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): + new_inputs[i] = constant(inp.data.astype(output_dtype)) + changed = True + + if not changed: + return None - # As this is just to allow merging more case, if - # the upcast don't work, we can just skip it. 
-                return
+    rval = node.op(*new_inputs)
+    if not old_out.type.is_super(rval.type):
+        # This can happen for example when floatX=float32
+        # and we do the true division between an int64
+        # and a constant that will get typed as int8.
+        # As this is just to allow merging more cases, if
+        # the upcast doesn't work, we can just skip it.
+        return None
 
-            # Copy over output stacktrace from before upcasting
-            copy_stack_trace(node.outputs[0], rval)
-            return rval
+    # Copy over output stacktrace from before upcasting
+    copy_stack_trace(old_out, rval)
+    return [rval]
 
 
 @node_rewriter([add, mul])

From 3c4bf023718db32f7dc32db4027001273e6ecf02 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 12 Jun 2025 11:36:36 +0200
Subject: [PATCH 02/11] Don't return useless subtensor for
 `local_useless_slice`

---
 pytensor/tensor/rewriting/subtensor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index be16c4fb61..030c5db905 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -351,7 +351,8 @@ def local_useless_slice(fgraph, node):
             new_idxs[dim] = slice(start, stop, step)
 
     if change_flag or ((last_useful_idx + 1) < len(idxs)):
-        out = x[tuple(new_idxs[: last_useful_idx + 1])]
+        new_idxs = tuple(new_idxs[: last_useful_idx + 1])
+        out = x[new_idxs] if new_idxs else x
         # Copy over previous output stacktrace
         copy_stack_trace(node.outputs, out)
         return [out]

From 7b6f8642929304e421a49210c4db5e708f6f1f54 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 9 Jul 2025 10:30:09 +0200
Subject: [PATCH 03/11] Don't try to create invalid `BatchedDot` in
 `specialize_matmul_to_batched_dot` rewrite

---
 pytensor/tensor/rewriting/blas.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py
index e626b0720b..9fc3603db4 100644
--- a/pytensor/tensor/rewriting/blas.py
+++ b/pytensor/tensor/rewriting/blas.py
@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     """
     x, y = node.inputs
 
+    if x.type.ndim < 3:
+        # This doesn't actually have a batch dimension
+        return None
+
     # BatchedDot does not allow implicit broadcasting of the batch dimensions
     # We do not want to explicitly broadcast as it may result in huge arrays
     if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     if len(x_shape) > 3:
         # If we have more than one batch dim, ravel it
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
+
     if len(y_shape) > 3:
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))
     new_out = _batched_dot(x, y)

From d30f84fedf1ebe31ee594ea012b0a0ff7f9b7b2a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 9 Jul 2025 11:24:19 +0200
Subject: [PATCH 04/11] Add error message for incompatible static shape in Dot
 Op

---
 pytensor/tensor/math.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 714f597b32..6c27246889 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -3025,6 +3025,11 @@ def make_node(self, *inputs):
             )
 
         sx, sy = (input.type.shape for input in inputs)
+        if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
+            raise ValueError(
+                f"Incompatible shared dimension for dot product: {sx}, {sy}"
+            )
+
         if len(sy) == 2:
             sz = sx[:-1] + sy[-1:]
         elif len(sy) == 1:
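A quick sketch of the behavior PATCH 04 adds (not part of the diff; the
variable names are illustrative, only the error text comes from the patch):

    import pytensor.tensor as pt

    x = pt.matrix("x", shape=(2, 3))
    y = pt.matrix("y", shape=(4, 5))

    # The shared dimension is statically known to mismatch (3 != 4), so the
    # graph now fails at construction time instead of erroring at runtime:
    # ValueError: Incompatible shared dimension for dot product: (2, 3), (4, 5)
    pt.dot(x, y)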
From fa7ac97b32b39d65028644313ea8c7c1288b759d Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 13 Jun 2025 10:29:46 +0200
Subject: [PATCH 05/11] Benchmark partial jacobian

---
 tests/test_gradient.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index 89712c19dd..8de9c24b18 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -32,7 +32,7 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.scan.op import Scan
-from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, tanh
+from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.random import RandomStream
 from pytensor.tensor.type import (
@@ -1143,6 +1143,24 @@ def test_benchmark(self, vectorize, benchmark):
         fn = function([x], jac_y, trust_input=True)
         benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype))
 
+    def test_benchmark_partial_jacobian(self, vectorize, benchmark):
+        # Example from https://github.com/jax-ml/jax/discussions/5904#discussioncomment-422956
+        N = 1000
+        rng = np.random.default_rng(2025)
+        x_test = rng.random((N,))
+
+        f_mat = rng.random((N, N))
+        x = vector("x", dtype="float64")
+
+        def f(x):
+            return sqrt(f_mat @ x / N)
+
+        full_jacobian = jacobian(f(x), x, vectorize=vectorize)
+        partial_jacobian = full_jacobian[:5, :5]
+
+        fn = pytensor.function([x], partial_jacobian, trust_input=True)
+        benchmark(fn, x_test)
+
 
 def test_hessian():
     x = vector()

From 9fa3885dd482b1f6a01aea64ffd466579c377934 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 12 Jun 2025 11:37:22 +0200
Subject: [PATCH 06/11] Avoid canonicalization of slices when merging
 non-overlapping slices in `local_subtensor_merge`

---
 pytensor/tensor/rewriting/subtensor.py | 131 +++++++++++++------------
 1 file changed, 68 insertions(+), 63 deletions(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 030c5db905..8e9fba22e4 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
     """
     from pytensor.scan.op import Scan
 
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner and isinstance(u.owner.op, Subtensor):
-            # We can merge :)
-            # x actual tensor on which we are picking slices
-            x = u.owner.inputs[0]
-            # slices of the first applied subtensor
-            slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
-            slices2 = get_idx_list(node.inputs, node.op.idx_list)
-
-            # Don't try to do the optimization on do-while scan outputs,
-            # as it will create a dependency on the shape of the outputs
-            if (
-                x.owner is not None
-                and isinstance(x.owner.op, Scan)
-                and x.owner.op.info.as_while
-            ):
-                return None
+    u = node.inputs[0]
+    if not (u.owner is not None and isinstance(u.owner.op, Subtensor)):
+        return None
 
-            # Get the shapes of the vectors !
- try: - # try not to introduce new shape into the graph - xshape = fgraph.shape_feature.shape_of[x] - ushape = fgraph.shape_feature.shape_of[u] - except AttributeError: - # Following the suggested use of shape_feature which should - # consider the case when the compilation mode doesn't - # include the ShapeFeature - xshape = x.shape - ushape = u.shape - - merged_slices = [] - pos_2 = 0 - pos_1 = 0 - while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): - slice1 = slices1[pos_1] - if isinstance(slice1, slice): - merged_slices.append( - merge_two_slices( - fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] - ) - ) - pos_2 += 1 - else: - merged_slices.append(slice1) - pos_1 += 1 - - if pos_2 < len(slices2): - merged_slices += slices2[pos_2:] - else: - merged_slices += slices1[pos_1:] + # We can merge :) + # x actual tensor on which we are picking slices + x = u.owner.inputs[0] + # slices of the first applied subtensor + slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) + slices2 = get_idx_list(node.inputs, node.op.idx_list) - merged_slices = tuple(as_index_constant(s) for s in merged_slices) - subtens = Subtensor(merged_slices) + # Don't try to do the optimization on do-while scan outputs, + # as it will create a dependency on the shape of the outputs + if ( + x.owner is not None + and isinstance(x.owner.op, Scan) + and x.owner.op.info.as_while + ): + return None - sl_ins = get_slice_elements( - merged_slices, lambda x: isinstance(x, Variable) + # Get the shapes of the vectors ! + try: + # try not to introduce new shape into the graph + xshape = fgraph.shape_feature.shape_of[x] + ushape = fgraph.shape_feature.shape_of[u] + except AttributeError: + # Following the suggested use of shape_feature which should + # consider the case when the compilation mode doesn't + # include the ShapeFeature + xshape = x.shape + ushape = u.shape + + merged_slices = [] + pos_2 = 0 + pos_1 = 0 + while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): + slice1 = slices1[pos_1] + if isinstance(slice1, slice): + merged_slices.append( + merge_two_slices( + fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] + ) ) - # Do not call make_node for test_value - out = subtens(x, *sl_ins) + pos_2 += 1 + else: + merged_slices.append(slice1) + pos_1 += 1 - # Copy over previous output stacktrace - # and stacktrace from previous slicing operation. - # Why? Because, the merged slicing operation could have failed - # because of either of the two original slicing operations - orig_out = node.outputs[0] - copy_stack_trace([orig_out, node.inputs[0]], out) - return [out] + if pos_2 < len(slices2): + merged_slices += slices2[pos_2:] + else: + merged_slices += slices1[pos_1:] + + merged_slices = tuple(as_index_constant(s) for s in merged_slices) + subtens = Subtensor(merged_slices) + + sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable)) + # Do not call make_node for test_value + out = subtens(x, *sl_ins) + + # Copy over previous output stacktrace + # and stacktrace from previous slicing operation. + # Why? 
Because, the merged slicing operation could have failed + # because of either of the two original slicing operations + orig_out = node.outputs[0] + copy_stack_trace([orig_out, node.inputs[0]], out) + return [out] @register_specialize @@ -788,6 +787,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): if not isinstance(slice1, slice): raise ValueError("slice1 should be of type `slice`") + # Simple case where one of the slices is useless + if is_full_slice(slice1): + return slice2 + elif is_full_slice(slice2): + return slice1 + sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl2, reverse2 = get_canonical_form_slice(slice2, len2) From 62d2ab291ebf0ebcd94c608e083087aa8ad56bf3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 12 Jun 2025 11:38:46 +0200 Subject: [PATCH 07/11] Generalize `local_subtensor_of_elemwise` to Blockwise --- pytensor/tensor/rewriting/subtensor_lift.py | 38 +++++++++++--- tests/tensor/rewriting/test_subtensor_lift.py | 52 ++++++++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 5a367a302a..0dd907473a 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -20,6 +20,7 @@ join, register_infer_shape, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import squeeze @@ -169,8 +170,8 @@ def local_subtensor_of_dot(fgraph, node): @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Subtensor]) -def local_subtensor_of_elemwise(fgraph, node): - """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior. +def local_subtensor_of_batch_dims(fgraph, node): + """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior. exp(x)[:, 0] -> exp(x[:, 0]) add(x, y)[0] -> add(x[0], y[0]) @@ -178,7 +179,7 @@ def local_subtensor_of_elemwise(fgraph, node): """ elem, *idx = node.inputs - if not (elem.owner and isinstance(elem.owner.op, Elemwise)): + if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)): return None if len(fgraph.clients[elem]) > 1: @@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node): idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + batch_ndim = ( + elem.owner.op.batch_ndim(elem.owner) + if isinstance(elem.owner.op, Blockwise) + else elem.ndim + ) + + if len(idx_tuple) > batch_ndim: + # Indexing on core dimensions of Blockwise. 
We split the indices and lift the batch ones only + batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] + if all(is_full_slice(idx) for idx in batch_indices): + # No batch indices, nothing to do + return None + elem_with_batch_indices = elem[batch_indices] + [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform( + fgraph, elem_with_batch_indices.owner + ) + # Reapply the core_indices + core_ndim = elem.type.ndim - batch_ndim + # Number of batch dims may have changed with the lifting of indices, so we recompute + new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim + new_indices = (*(slice(None),) * new_batch_ndim, *core_indices) + new_elem = elem_with_batch_indices_lifted[new_indices] + copy_stack_trace(node.outputs[0], new_elem) + return [new_elem] + elem_inputs = elem.owner.inputs - elem_bcast = elem.type.broadcastable - if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs): + elem_bcast = elem.type.broadcastable[:batch_ndim] + if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs): # No need to worry about implicit broadcasting. indexed_inputs = [inp[idx_tuple] for inp in elem_inputs] @@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node): zip( idx_tuple, elem_bcast, - *(inp.type.broadcastable for inp in elem_inputs), + *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs), # Indices can be shorter than input ndims strict=False, ) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 933d1a1577..1d7418adae 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -14,6 +14,7 @@ from pytensor.graph import ( Constant, FunctionGraph, + Op, RewriteDatabaseQuery, Type, rewrite_graph, @@ -23,6 +24,7 @@ from pytensor.printing import debugprint from pytensor.tensor import ( add, + dvector, exp, iscalar, iscalars, @@ -37,11 +39,12 @@ vector, ) from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.subtensor_lift import ( local_subtensor_make_vector, - local_subtensor_of_elemwise, + local_subtensor_of_batch_dims, local_subtensor_shape_constant, ) from pytensor.tensor.shape import SpecifyShape, _shape @@ -58,7 +61,7 @@ NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) -class TestLocalSubtensorOfElemwise: +class TestLocalSubtensorOfBatchDims: def test_unary_multiple_clients(self): # as test0, but we reuse the output of the elemwise # So we should not lift the subtensor @@ -144,7 +147,7 @@ def test_multinary_multiple_clients(self): ), ], ) - def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): + def test_elemwise(self, original_fn, expected_fn): rng = np.random.default_rng(257) x = pt.matrix("x", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3)) @@ -163,7 +166,7 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): out.eval({x: x_test, y: y_test}, **eval_kwargs), ) - def test_local_subtensor_of_elemwise_multiple_clients(self): + def test_elemwise_multiple_clients(self): x = pt.matrix("x", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3)) out1 = add(x, y) @@ -171,11 +174,48 @@ def test_local_subtensor_of_elemwise_multiple_clients(self): # Rewrite should fail when another node uses out1 directly (in this case it's an extra 
output) fgraph = FunctionGraph([x, y], [out1, out2], clone=False) - assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None + assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None # Otherwise it should work fgraph.remove_output(0) - assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None + assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None + + def test_blockwise(self): + class CoreTestOp(Op): + itypes = [dvector, dvector] + otypes = [dvector] + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.convolve(*inputs, mode="valid") + + core_test_op = CoreTestOp() + block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)") + + x = tensor3("x", shape=(7, 5, 11), dtype="float64") + y = tensor("y", shape=(7, 33), dtype="float64") + out = block_test_op(x, y[:, None, :]) + assert isinstance(out.owner.op, Blockwise) + + out_sliced = out[2:][:, 3:] + rewritten_out_sliced = rewrite_graph(out_sliced) + expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :]) + assert equal_computations([rewritten_out_sliced], [expected_out_sliced]) + + rng = np.random.default_rng(191) + x_test = rng.normal(size=x.type.shape).astype(x.type.dtype) + y_test = rng.normal(size=y.type.shape).astype(y.type.dtype) + np.testing.assert_allclose( + rewritten_out_sliced.eval( + {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE + ), + out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE), + ) + + # Check slice on core dims + out_sliced = out[2:][:, 0][:, 4:] + rewritten_out_sliced = rewrite_graph(out_sliced) + expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:] + assert equal_computations([rewritten_out_sliced], [expected_out_sliced]) @pytest.mark.parametrize( From 4c9be3d87b5bbec36dd61bd54b3e0208baa3209d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Jun 2025 11:44:36 +0200 Subject: [PATCH 08/11] Lift subtensor through squeeze --- pytensor/tensor/rewriting/subtensor_lift.py | 35 +++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 0dd907473a..87d838659c 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node): return [out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_squeeze(fgraph, node): + """Lift subtensor through a squeeze operation""" + x, *idxs_vars = node.inputs + if not ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + return None + + [x_before_squeeze] = x.owner.inputs + idxs = indices_from_subtensor(idxs_vars, node.op.idx_list) + dropped_dims = x.owner.op.drop + + # Apply indices directly on x + # Add empty slices on the axis that squeeze would have removed + new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None)) + x_indexed = x_before_squeeze[tuple(new_idxs)] + + # Reapply squeeze + # Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims + new_dropped_dims = np.array(dropped_dims) + for i, new_idx in reversed(tuple(enumerate(new_idxs))): + if not isinstance(new_idx, slice): + # If it's not a slice, it's an integer which drops the dimension + new_dropped_dims[new_dropped_dims > i] -= 1 + new_x = x_indexed.squeeze(tuple(new_dropped_dims)) + + copy_stack_trace(x, new_x) + return 
[new_x] + + @register_canonicalize @register_specialize @node_rewriter([Subtensor]) From b0bc867cf02235cf8c344a4c2dd4c2f56bbfda38 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 12 Jun 2025 12:41:32 +0200 Subject: [PATCH 09/11] Generalize dot rewrites to work with Blockwise --- pytensor/tensor/rewriting/math.py | 21 +++---- pytensor/tensor/rewriting/subtensor_lift.py | 64 +++++++++++++-------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index d126502bde..55ed99fa59 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -44,6 +44,7 @@ Prod, Sum, _conj, + _dot, _inner_prod, _matrix_matrix_matmul, _matrix_vec_prod, @@ -98,6 +99,7 @@ register_useless, ) from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift +from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -175,21 +177,20 @@ def local_lift_transpose_through_dot(fgraph, node): These rewrites "lift" (propagate towards the inputs) `DimShuffle` through dot product. It allows to put the graph in a more standard shape, and to later merge consecutive `DimShuffle`\s. - - The transformation should be apply whether or not the transpose is - inplace. The newly-introduced transpositions are not inplace, this will - be taken care of in a later rewrite phase. - """ - if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)): - return False - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + + if not ( + is_matrix_transpose(node.out) + and node.inputs[0].owner + and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul)) + ): return False + x, y = node.inputs[0].owner.inputs - if x.ndim == y.ndim == 2: + if x.ndim >= y.ndim >= 2: # Output is dot product of transposed inputs in reverse order - ret = [dot(y.T, x.T)] + ret = [dot_op(y.mT, x.mT)] # Copy over stack trace to output from result of dot-product copy_stack_trace(node.inputs[0], ret) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 87d838659c..7ae234151d 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -5,7 +5,7 @@ from pytensor import Variable from pytensor.compile import optdb -from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar import basic as ps @@ -119,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node): the remaining entries of ``idxs`` (if any), modified to skip the second-to-last dimension of ``B`` (because dot sums over this dimension). """ - if not isinstance(node.op, Subtensor): - return - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + x, *idx_vars = node.inputs + if not ( + x.owner is not None + and ( + isinstance(x.owner.op, Dot) + or ( + isinstance(x.owner.op, Blockwise) + and isinstance(x.owner.op.core_op, Dot) + ) + ) + ): return # If there is other node that use the outputs of the dot # We don't want to compute twice the sub part. 
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, which can be handled by another rewrite
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an either in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
     return [r]

From 58831af53d4b095d67d275c6f08f8e2400d1902d Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 12 Jun 2025 12:03:13 +0200
Subject: [PATCH 10/11] Define all batched dot operations as matmul

A new rewrite is added to convert unpaired batched row/column matvec or
vec products to equivalent matmul products.
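The core trick, sketched with numpy below (shapes are illustrative; this
snippet is not part of the diff): when only one operand has real batch
dimensions, the batched matmul equals a single core matmul on a reshaped
operand.

    import numpy as np

    x = np.random.rand(7, 5, 4, 3)  # batch dims (7, 5) live on x only
    y = np.random.rand(3, 2)        # y has no batch dims

    # What the Blockwise (batched) matmul computes ...
    batched = x @ y
    # ... equals one core matmul with x's batch dims raveled into its rows
    core = (x.reshape(-1, 3) @ y).reshape(7, 5, 4, 2)
    np.testing.assert_allclose(batched, core)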
--- pytensor/tensor/math.py | 48 +++---- pytensor/tensor/rewriting/blas.py | 4 +- pytensor/tensor/rewriting/elemwise.py | 2 + pytensor/tensor/rewriting/linalg.py | 4 +- pytensor/tensor/rewriting/math.py | 188 +++++++++++++++++++------- tests/tensor/rewriting/test_blas.py | 41 ++++-- tests/tensor/rewriting/test_math.py | 87 +++++++++++- tests/tensor/test_math.py | 11 +- 8 files changed, 281 insertions(+), 104 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 6c27246889..743da35c84 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -# Predefine all batched variations of Dot -_inner_prod = Blockwise( - _dot, - signature="(n),(n)->()", -) - -_matrix_vec_prod = Blockwise( - _dot, - signature="(m,k),(k)->(m)", -) - -_vec_matrix_prod = Blockwise( - _dot, - signature="(k),(k,n)->(n)", -) - -_matrix_matrix_matmul = Blockwise( +_matmul = Blockwise( _dot, signature="(m,k),(k,n)->(m,n)", gufunc_spec=("numpy.matmul", 2, 1), @@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None if x1.type.ndim == 1 and x2.type.ndim == 1: out = _dot(x1, x2) elif x1.type.ndim == 1: - out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) + out = vecmat(x1, x2) elif x2.type.ndim == 1: - out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) + out = matvec(x1, x2) else: - out = _matrix_matrix_matmul(x1, x2) + out = _matmul(x1, x2) if dtype is not None: out = out.astype(dtype) @@ -4047,7 +4031,7 @@ def vecdot( >>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,) >>> # Equivalent to numpy.vecdot(x_batch, y_batch) """ - out = _inner_prod(x1, x2) + out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1)) if dtype is not None: out = out.astype(dtype) @@ -4096,7 +4080,7 @@ def matvec( >>> result = pt.matvec(batched_A, batched_v) # shape (2, 3) >>> # Equivalent to numpy.matvec(batched_A, batched_v) """ - out = _matrix_vec_prod(x1, x2) + out = matmul(x1, x2[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4134,18 +4118,18 @@ def vecmat( -------- >>> import pytensor.tensor as pt >>> # Vector-matrix product - >>> v = pt.vector("v", shape=(3,)) # shape (3,) - >>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4) + >>> v = pt.vector("v", shape=(3,)) + >>> A = pt.matrix("A", shape=(3, 4)) >>> result = pt.vecmat(v, A) # shape (4,) >>> # Equivalent to numpy.vecmat(v, A) >>> >>> # Batched vector-matrix product - >>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3) - >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4) + >>> batched_v = pt.matrix("v", shape=(2, 3)) + >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) >>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4) >>> # Equivalent to numpy.vecmat(batched_v, batched_A) """ - out = _vec_matrix_prod(x1, x2) + out = matmul(x2.mT, x1[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y): old_y_ndim = old_y.type.ndim match (old_x_ndim, old_y_ndim): case (1, 1): - batch_op = _inner_prod + batch_fn = vecdot case (2, 1): - batch_op = _matrix_vec_prod + batch_fn = matvec case (1, 2): - batch_op = _vec_matrix_prod + batch_fn = vecmat case (2, 2): - batch_op = _matrix_matrix_matmul + batch_fn = matmul case _: raise ValueError( f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D." 
) - return batch_op(batched_x, batched_y).owner + return batch_fn(batched_x, batched_y).owner def nan_to_num(x, nan=0.0, posinf=None, neginf=None): diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 9fc3603db4..6fd94f9b33 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -98,7 +98,7 @@ from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( Dot, - _matrix_matrix_matmul, + _matmul, add, mul, neg, @@ -908,7 +908,7 @@ def local_dot22_to_dot22scalar(fgraph, node): @register_specialize -@node_rewriter([_matrix_matrix_matmul]) +@node_rewriter([_matmul]) def specialize_matmul_to_batched_dot(fgraph, node): """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 5dd3b59096..f08f19f06c 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -39,6 +39,7 @@ broadcasted_by, register_canonicalize, register_specialize, + register_stabilize, ) from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input): @register_canonicalize +@register_stabilize @register_specialize @node_rewriter([DimShuffle]) def local_dimshuffle_lift(fgraph, node): diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 2a1a71ae40..45ce2a4605 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -26,7 +26,7 @@ from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod +from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node): # This rewrite only applies to matrix Dot and A.owner.inputs[0].type.ndim == 2 ) - or (A.owner.op == _matrix_matrix_matmul) + or (A.owner.op == _matmul) ) ): return diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 55ed99fa59..7fbf01d939 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -28,6 +28,7 @@ as_tensor_variable, cast, constant, + expand_dims, get_underlying_scalar_constant_value, moveaxis, ones_like, @@ -35,7 +36,6 @@ switch, zeros_like, ) -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays @@ -45,10 +45,7 @@ Sum, _conj, _dot, - _inner_prod, - _matrix_matrix_matmul, - _matrix_vec_prod, - _vec_matrix_prod, + _matmul, add, digamma, dot, @@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node): return ret -@register_stabilize -@register_specialize -@node_rewriter(tracks=[Blockwise]) -def local_batched_matmul_to_core_matmul(fgraph, node): - """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul. 
+def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool): + """Move batch dimensions of matmul operands to core matmul - Example, if x has batch dimensions, but y not: + Example, if x has batch dimensions that don't overlap with batch dimensions of y x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1]) - It also works when y has batch dimensions, but x not. - """ + It also works for batch dimensions of y that don't overlap with batch dimensions of x - # Check whether we have a matmul operation in this node - if not ( - isinstance(node.op.core_op, Dot) - and len(node.op.inputs_sig[0]) == 2 - and len(node.op.inputs_sig[1]) == 2 - ): - return None + The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False` + """ x, y = node.inputs batch_ndim = node.op.batch_ndim(node) - # Check if x has batch dimensions, but y not (or only broadcastable dimensions) - if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all( - y.type.broadcastable[:-2] - ): - x_stacked = x.reshape((-1, x.shape[-1])) - out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim))) - out = out_stacked.reshape((*x.shape[:-1], y.shape[-1])) - return [out] - - # Otherwise, check if y has batch dimension, but x not - elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all( - x.type.broadcastable[:-2] - ): - # For the y batch case we need to first move the batch axes and then reshape - # y.shape == (*b, k, n) - y_tr = moveaxis(y, -2, 0) # (k, *b, n) - y_stacked = y_tr.reshape((y.shape[-2], -1)) # (k, *b * n) - out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked # (m, *b * n) - out_stacked_tr = out_stacked.reshape( - (x.shape[-2], *y.shape[:-2], y.shape[-1]) - ) # (m, *b, n) - out = moveaxis(out_stacked_tr, 0, -2) # (*b, m, n) - return [out] - - # Both x and y have batch dimensions, nothing to do here - return None + x_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_y and not bcast_x + ] + + y_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_x and not bcast_y + ] + + if not (x_axis_to_merge or y_axis_to_merge): + return None + + x_shape = tuple(x.shape) + y_shape = tuple(y.shape) + x_is_row = x.type.broadcastable[-2] + y_is_col = y.type.broadcastable[-1] + n_x_axis_to_merge = len(x_axis_to_merge) + n_y_axis_to_merge = len(y_axis_to_merge) + n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge + + x_stacked, y_stacked = x, y + dims_were_merged = False + + if n_x_axis_to_merge: + # ravel batch dimensions of x on the core (m) axis + x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2)) + x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination) + if x_is_row: + # x was a row matrix, squeeze it to clean up the graph + x_stacked = x_stacked.squeeze(-2) + if n_x_axis_to_merge > 1 or not x_is_row: + if not allow_reshape: + # TODO: We could allow the y rewrite to go on + # Or just move one axis (the largest) if x is row + return None + + # Ravel moved batch dims together with (m) if needed + x_stacked_shape = tuple(x_stacked.shape) + x_stacked = x_stacked.reshape( + (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1]) + ) + dims_were_merged = True + + if n_y_axis_to_merge: + # ravel batch dimensions of y on the core (n) axis + y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1)) + y_stacked = moveaxis(y, 
y_axis_to_merge, y_axis_destination)
+        if y_is_col:
+            # y was a column matrix, squeeze it to clean up the graph
+            y_stacked = y_stacked.squeeze(-1)
+        if n_y_axis_to_merge > 1 or not y_is_col:
+            if not allow_reshape:
+                # TODO: We could allow the x rewrite to go on
+                # Or just move one axis (the largest) if y is col
+                return None
+            # Ravel moved batch dims together with (n) if needed
+            y_stacked_shape = tuple(y_stacked.shape)
+            y_stacked = y_stacked.reshape(
+                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
+            )
+            dims_were_merged = True
+
+    # Squeeze x_dims corresponding to merged dimensions of y
+    x_axis_to_squeeze = np.array(y_axis_to_merge)
+    for i in reversed(x_axis_to_merge):
+        # The corresponding dimensions of y may have shifted when we merged dimensions of x
+        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
+    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
+
+    # Same for y
+    y_axis_to_squeeze = np.array(x_axis_to_merge)
+    for i in reversed(y_axis_to_merge):
+        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
+    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
+
+    out_stacked = x_stacked @ y_stacked
+
+    # Split back any merged dimensions
+    if dims_were_merged:
+        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
+        if not x_is_row:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            x_merged_shapes.append(x_shape[-2])
+        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
+        if not y_is_col:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            y_merged_shapes.append(y_shape[-1])
+        out_stacked_shape = tuple(out_stacked.shape)
+        out_unstacked = out_stacked.reshape(
+            (
+                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
+                *x_merged_shapes,
+                *y_merged_shapes,
+            )
+        )
+    else:
+        out_unstacked = out_stacked
+
+    # Add back dummy row, col axes
+    # We do this separately to avoid the reshape as much as we can
+    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -1)
+    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
+
+    # Move batch axes back to their original locations
+    source = range(-n_axis_to_merge - 2, 0)
+    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
+    out = moveaxis(out_unstacked, source, destination)
+    return [out]
+
+
+@register_canonicalize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    # Allow passing batch dimensions of matmul to core row / column matrices
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
+
+
+@register_specialize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
+    # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
+    # We only apply this in specialize, because graphs with reshape are hard to work with
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
 
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def local_blockwise_dot_to_mul(fgraph, node):
     """Rewrite blockwise dots that correspond to multiplication without summation.
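(For intuition — an illustrative numpy sketch, not part of the diff: a
blockwise dot "without summation" arises when the contracted dimension is
statically 1, in which case the product is just a broadcasted elementwise
multiplication.)

    import numpy as np

    x = np.random.rand(5, 3, 1)  # contracted core dimension has length 1
    y = np.random.rand(5, 1, 4)

    # (3, 1) @ (1, 4) "sums" over a single element, i.e. no real summation:
    np.testing.assert_allclose(x @ y, x * y)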
diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py index d939ceedce..10e040367c 100644 --- a/tests/tensor/rewriting/test_blas.py +++ b/tests/tensor/rewriting/test_blas.py @@ -1,10 +1,10 @@ import numpy as np import pytest -from pytensor import function +from pytensor import config, function from pytensor import tensor as pt from pytensor.compile import get_default_mode -from pytensor.graph import FunctionGraph +from pytensor.graph import FunctionGraph, ancestors from pytensor.tensor import ( col, dscalar, @@ -21,7 +21,6 @@ vectorize, ) from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.blas import ( _as_scalar, @@ -37,8 +36,11 @@ def XYZab(): return matrix(), matrix(), matrix(), scalar(), scalar() -@pytest.mark.parametrize("valid_case", (True, False)) -def test_specialize_matmul_to_batched_dot(valid_case): +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", reason="Test requires specialization rewrites" +) +@pytest.mark.parametrize("aligned", (True, False)) +def test_specialize_matmul_to_batched_dot(aligned): signature = BatchedDot.gufunc_signature rewrite = specialize_matmul_to_batched_dot.__name__ @@ -49,23 +51,36 @@ def core_np(x, y): return np.matmul(x, y) x = tensor(shape=(7, 5, 3, 3)) - if valid_case: + if aligned: y = tensor(shape=(7, 5, 3, 3)) else: y = tensor(shape=(5, 3, 3)) + out = vectorize(core_pt, signature=signature)(x, y) + + assert ( + sum( + isinstance(var.owner.op, BatchedDot) + for var in ancestors([out]) + if var.owner + ) + == 0 + ) + vectorize_pt = function( [x, y], - vectorize(core_pt, signature=signature)(x, y), + out, mode=get_default_mode().including(rewrite), ) - blocwkise_node = any( - isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + + assert ( + sum( + isinstance(var.owner.op, BatchedDot) + for var in ancestors(vectorize_pt.maker.fgraph.outputs) + if var.owner + ) + == 1 ) - if valid_case: - assert not blocwkise_node - else: - assert blocwkise_node x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 3699a3fcff..f82353dd1f 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -42,6 +42,7 @@ Prod, Sum, _conj, + _matmul, add, arccosh, arcsinh, @@ -4612,6 +4613,88 @@ def test_local_batched_matmul_to_core_matmul(): np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) +@pytest.mark.parametrize( + "mat_shape, vec_shape", + [ + [(1, 2, 2), (5, 2)], + [(5, 2, 2), (1, 2)], + [(1, 1, 2, 2), (7, 5, 2)], + [(7, 5, 2, 2), (1, 1, 5, 2)], + [(1, 5, 1, 2, 2), (7, 5, 7, 2)], + [(7, 5, 7, 2, 2), (1, 5, 1, 2)], + [(5, 1, 3, 1, 2, 2), (1, 7, 3, 7, 2)], + [(1, 7, 3, 7, 2, 2), (5, 1, 3, 1, 2)], + ], + ids=str, +) +@pytest.mark.parametrize("func", ("matvec", "vecmat", "vecdot")) +def test_batch_matvec_to_matmul(func, mat_shape, vec_shape): + def count_matvec_nodes(graph): + # Counts how many matmul nodes actually correspond to matvec or vecmat + return len( + [ + var + for var in ancestors([graph]) + if ( + var.owner is not None + and var.owner.op == _matmul + and ( + (var.owner.inputs[0].type.shape[-2] == 1) + or (var.owner.inputs[1].type.shape[-1] == 1) + ) + ) + ] + ) + + mat = pt.tensor("mat", shape=mat_shape, dtype="float64") + vec = pt.tensor("vec", shape=vec_shape, 
dtype="float64") + + if func == "matvec": + out = pt.matvec(mat, vec) + elif func == "vecmat": + out = pt.vecmat(vec, mat) + elif func == "vecdot": + out = pt.vecdot(mat[..., 0], vec) + else: + raise NotImplementedError(func) + + assert count_matvec_nodes(out) == 1 + + rewritten_out = rewrite_graph( + out, + include=( + "canonicalize", + "specialize", + ), + exclude=( + "local_eager_useless_unbatched_blockwise", + "specialize_matmul_to_batched_dot", + ), + ) + # No `matvec` in the rewritten out if one of the vector can be treated as a matrix + expected = not any( + mat_dim == 1 and vec_dim != 1 + for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2]) + ) + if not expected and func == "vecdot": + # In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as matrix + expected = not any( + mat_dim != 1 and vec_dim == 1 + for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2]) + ) + + assert count_matvec_nodes(rewritten_out) == expected + + rng = np.random.default_rng(mat_shape + vec_shape) + eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)} + # Evaluate results are correct without further rewrites + no_optimization = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + rewritten_out.eval(eval_dict, mode=no_optimization), + out.eval(eval_dict, mode=no_optimization), + ) + + def test_log_kv_stabilization(): x = pt.scalar("x") out = log(kv(4.5, x)) @@ -4662,8 +4745,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): out = dot(a, b) if batched: - batch_a = tensor("batch_a", shape=(1, 5, *a_shape)) - batch_b = tensor("batch_b", shape=(7, 1, *b_shape)) + batch_a = tensor("batch_a", shape=(2, 1, 5, *a_shape)) + batch_b = tensor("batch_b", shape=(2, 7, 1, *b_shape)) out = vectorize_graph(out, {a: batch_a, b: batch_b}) a = batch_a b = batch_b diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 38207d0f5d..af6f87d79f 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -2081,9 +2081,9 @@ def is_super_shape(var1, var2): def test_matrix_vector_ops(): """Test vecdot, matvec, and vecmat helper functions.""" - rng = np.random.default_rng(seed=utt.fetch_seed()) + rng = np.random.default_rng(2089) - # Create test data with batch dimension (2) + atol = 1e-7 if config.floatX == "float32" else 1e-15 batch_size = 2 dim_k = 4 # Common dimension dim_m = 3 # Matrix rows @@ -2098,7 +2098,6 @@ def test_matrix_vector_ops(): mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX) vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX) - # Create tensor variables with matching dtype mat_mk = tensor( name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX ) @@ -2119,7 +2118,7 @@ def test_matrix_vector_ops(): expected_vecdot = np.zeros((batch_size,), dtype=np.int32) for i in range(batch_size): expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i]) - np.testing.assert_allclose(result, expected_vecdot) + np.testing.assert_allclose(result, expected_vecdot, atol=atol) # Test 2: matvec - matrix-vector product matvec_out = matvec(mat_mk, vec_k) @@ -2130,7 +2129,7 @@ def test_matrix_vector_ops(): expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX) for i in range(batch_size): expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i]) - np.testing.assert_allclose(result_matvec, expected_matvec) + np.testing.assert_allclose(result_matvec, expected_matvec, atol=atol) # Test 3: vecmat - vector-matrix product 
vecmat_out = vecmat(vec_k, mat_kn) @@ -2141,7 +2140,7 @@ def test_matrix_vector_ops(): expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX) for i in range(batch_size): expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i]) - np.testing.assert_allclose(result_vecmat, expected_vecmat) + np.testing.assert_allclose(result_vecmat, expected_vecmat, atol=atol) class TestTensordot: From 425859b5ff5dec1dcca23b6695a672fba998608b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Jun 2025 11:42:22 +0200 Subject: [PATCH 11/11] Prioritize gemv/gerc over dot22scalar The marked xfail test was failing because Ger wasn't introduced, not because of the complex dtype. --- pytensor/tensor/rewriting/blas.py | 4 ++-- tests/tensor/test_blas.py | 8 -------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 6fd94f9b33..74b4d235dc 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -758,7 +758,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ignore_newtrees=False, ), "fast_run", - position=15, + position=11, ) @@ -903,7 +903,7 @@ def local_dot22_to_dot22scalar(fgraph, node): "local_dot22_to_dot22scalar", in2out(local_dot22_to_dot22scalar), "fast_run", - position=11, + position=12, ) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index f3fcf72cc5..1332266e3d 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -1903,17 +1903,9 @@ def test_f32_1_2(self): def test_f64_4_5(self): return self.given_dtype("float64", 4, 5, destructive=False) - @pytest.mark.xfail( - condition=config.floatX == "float32", - reason="GER from complex64 is not introduced in float32 mode", - ) def test_c64_7_1(self): return self.given_dtype("complex64", 7, 1) - @pytest.mark.xfail( - raises=AssertionError, - reason="Unclear how this test was supposed to work with complex128", - ) def test_c128_1_9(self): return self.given_dtype("complex128", 1, 9)
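A sketch of what the reordering enables (an assumed example based on the
un-xfailed tests above, not taken from the PR): with gemv/ger now registered
at position 11, ahead of dot22scalar at position 12,
local_dot22_to_ger_or_gemv gets the first chance at rank-1 update graphs, so
a complex-typed rank-1 update should be specialized to Ger:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.blas import Ger

    A = pt.matrix("A", dtype="complex128")
    a = pt.scalar("a", dtype="complex128")
    x = pt.vector("x", dtype="complex128")
    y = pt.vector("y", dtype="complex128")

    # Rank-1 update; expected (assumed, untested here) to compile to Ger
    # under the default FAST_RUN mode, where the BLAS rewrites run.
    fn = pytensor.function([A, a, x, y], A + a * pt.outer(x, y))
    assert any(isinstance(node.op, Ger) for node in fn.maker.fgraph.apply_nodes)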