     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1578,6 +1581,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
 
     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
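
The docstring above describes the target pattern. As a rough illustration of when this rewrite is expected to apply, here is a minimal sketch assuming the public vectorize_graph helper from pytensor.graph.replace; the variable names and shapes are illustrative only, not part of the commit:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    # Core graph: slice a vector with symbolic, non-batched bounds.
    x = pt.vector("x")
    a = pt.iscalar("a")
    b = pt.iscalar("b")
    out = x[a:b]

    # Batch only the indexed tensor; a and b stay scalar.
    batch_x = pt.matrix("batch_x")
    batch_out = vectorize_graph(out, replace={x: batch_x})

    # After rewriting, the compiled graph should behave like batch_x[:, a:b]
    # rather than keeping a Blockwise(Subtensor) node around.
    fn = pytensor.function([batch_x, a, b], batch_out)
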
@@ -1598,64 +1604,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of IncSubtensor/AdvancedIncSubtensor as an inc_subtensor on the batched inputs.
 
-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions coming from alloc of shapes/reshape that can be removed by rewrites,
 
-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None
 
-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None
 
-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices.
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new batch dimensions of x
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
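
For the batched-index case handled in Step 3, the rewrite effectively replaces the implicit loop of Blockwise with an explicit arange index over each batch dimension. A small numpy sketch of that equivalence, with illustrative shapes that are not part of the commit:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 6))   # batch_ndim = 1, core shape (6,)
    idx = np.array([0, 2, 5, 1])  # one integer index per batch member
    y = rng.normal(size=(4,))

    # Loop semantics, i.e. what Blockwise(inc_subtensor) computes per batch member:
    expected = x.copy()
    for b in range(4):
        expected[b, idx[b]] += y[b]

    # Rewritten form: index the batch dimension with an arange so the batched
    # index is applied per row without an explicit loop.
    out = x.copy()
    np.add.at(out, (np.arange(4), idx), y)

    np.testing.assert_allclose(out, expected)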
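
The Step 3 comment about non-consecutive advanced indexing refers to numpy moving the broadcast index dimensions to the front when advanced indices are separated by basic ones, which is why the rewrite avoids mixing slice(None) and arange() indices for the batch dimensions. A minimal numpy illustration (shapes are arbitrary):

    import numpy as np

    x = np.empty((5, 6, 7, 8))
    i = np.array([0, 1])
    j = np.array([2, 3])

    # Adjacent advanced indices: their broadcast dimension stays in place.
    print(x[:, i, j, :].shape)  # (5, 2, 8)

    # Advanced indices separated by a slice: the broadcast dimension moves to the front.
    print(x[:, i, :, j].shape)  # (2, 5, 7)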