diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py
index 656eba1a92..fd8af3acba 100644
--- a/pytensor/tensor/rewriting/subtensor_lift.py
+++ b/pytensor/tensor/rewriting/subtensor_lift.py
@@ -8,6 +8,7 @@
 )
 
 from pytensor import Variable
+from pytensor.compile import optdb
 from pytensor.graph import Constant, FunctionGraph, node_rewriter
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.scalar import basic as ps
@@ -43,8 +44,10 @@
 )
 from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
+    AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
+    _non_contiguous_adv_indexing,
     as_index_literal,
     get_canonical_form_slice,
     get_constant_idx,
@@ -52,7 +55,7 @@
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -818,3 +821,79 @@ def local_subtensor_shape_constant(fgraph, node):
         return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
     elif shape_parts:
         return [as_tensor(1, dtype=np.int64)]
+
+
+@node_rewriter([Subtensor])
+def local_subtensor_of_adv_subtensor(fgraph, node):
+    """Lift a simple Subtensor through an AdvancedSubtensor when the basic index dimensions are to the left of any advanced ones.
+
+    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
+    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]
+
+    Restricted to a single contiguous block of integer advanced indices.
+
+    An alternative approach could have fused the basic and advanced indices,
+    so it is not clear whether this rewrite should be canonical or a specialization.
+    Users must include it manually if it fits their use case.
+    """
+    adv_subtensor, *idxs = node.inputs
+
+    if not (
+        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
+    ):
+        return None
+
+    if len(fgraph.clients[adv_subtensor]) > 1:
+        # AdvancedSubtensor makes a full copy, so we don't want to compute it twice
+        return None
+
+    x, *adv_idxs = adv_subtensor.owner.inputs
+
+    # Advanced indexing is a minefield; avoid all cases except consecutive integer indices
+    if any(
+        (
+            isinstance(adv_idx.type, NoneTypeT)
+            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
+            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
+        )
+        for adv_idx in adv_idxs
+    ) or _non_contiguous_adv_indexing(adv_idxs):
+        return None
+
+    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
+        # We already made sure there were only full slices besides integer indices
+        if isinstance(adv_idx.type, TensorType):
+            break
+    else:  # no-break
+        # Not sure if this should ever happen, but better safe than sorry
+        return None
+
+    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
+    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
+    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
+        first_adv_idx_dim:
+    ]
+
+    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
+        # Nothing to lift: all basic indices left of the advanced ones are full slices
+        return None
+
+    [basic_subtensor] = node.outputs
+    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)
+
+    x_indexed = x[basic_idxs_lifted]
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
+
+    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
+    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
+
+    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
+    return [new_out]
+
+
+# The rewrite is only applied when explicitly included by name
+r = local_subtensor_of_adv_subtensor
+optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False)
+optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False)
+del r
diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py
index 78e529178e..e02fdc1083 100644
--- a/tests/tensor/rewriting/test_subtensor_lift.py
+++ b/tests/tensor/rewriting/test_subtensor_lift.py
@@ -52,7 +52,7 @@
 )
 from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
 from pytensor.tensor.special import softmax
-from pytensor.tensor.subtensor import Subtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 
 
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
@@ -756,3 +756,47 @@ def __eq__(self, other):
     x = shape(Variable(MyType(), None, None))[0]
 
     assert not local_subtensor_shape_constant.transform(None, x.owner)
+
+
+@pytest.mark.parametrize(
+    "original_fn, supported",
+    [
+        (lambda x: x[:, [0, 1]][0], True),
+        (lambda x: x[:, [0, 1], [0, 0]][1:], True),
+        (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
+        # Not supported: basic indexing on an advanced indexing dim
+        (lambda x: x[[0, 1]][0], False),
+        # Not implemented: basic indexing to the right of advanced indexing
+        (lambda x: x[[0, 1]][:, 0], False),
+        # Not implemented: complex flavors of advanced indexing
+        (lambda x: x[:, None, [0, 1]][0], False),
+        (lambda x: x[:, 5:, [0, 1]][0], False),
+        (lambda x: x[:, :, np.array([True, False, False])][0], False),
+        (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
+    ],
+)
+def test_local_subtensor_of_adv_subtensor(original_fn, supported):
+    rng = np.random.default_rng(257)
+    x = pt.tensor3("x", shape=(7, 5, 3))
+    x_test = rng.normal(size=x.type.shape)
+
+    out = original_fn(x)
+    opt_out = rewrite_graph(
+        out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
+    )
+    # The generated graphs are too complicated to assert on directly.
+    # We simply check that the Subtensor happens before the AdvancedSubtensor.
+    toposort = FunctionGraph(outputs=[opt_out], clone=False).toposort()
+    [idx_subtensor] = [
+        i for i, node in enumerate(toposort) if isinstance(node.op, Subtensor)
+    ]
+    [idx_adv_subtensor] = [
+        i for i, node in enumerate(toposort) if isinstance(node.op, AdvancedSubtensor)
+    ]
+    swapped = idx_subtensor < idx_adv_subtensor
+    correct = swapped if supported else not swapped
+    assert correct, debugprint(opt_out, print_type=True)
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
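Note: because the rewrite is registered with use_db_name_as_tag=False, it is only applied when requested by name. A minimal usage sketch, mirroring the first supported test case above; the rewrite_graph import path is assumed to match the one used by the test module:

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

# x[:, [0, 1]][0] has a basic index to the left of the advanced one, so it qualifies
x = pt.tensor3("x", shape=(7, 5, 3))
out = x[:, [0, 1]][0]

# Explicitly include the rewrite by name alongside the canonicalization pass
opt_out = rewrite_graph(
    out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
)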