@@ -6729,101 +6729,6 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
67296729)
67306730
67316731
6732- def tile (operand : ArrayLike , reps : Sequence [int ]) -> Array :
6733- """Tiles an array by repeating it along each dimension.
6734-
6735- Args:
6736- operand: an array to tile.
6737- reps: a sequence of integers representing the number of repeats for each
6738- dimension. Must have the same length as ``operand.ndim``.
6739-
6740- Returns:
6741- A tiled array with shape ``(operand.shape[0] * reps[0], ...,
6742- operand.shape[-1] * reps[-1])``.
6743-
6744- Examples:
6745- >>> x = jnp.array([[1, 2], [3, 4]])
6746- >>> lax.tile(x, (2, 3))
6747- Array([[1, 2, 1, 2, 1, 2],
6748- [3, 4, 3, 4, 3, 4],
6749- [1, 2, 1, 2, 1, 2],
6750- [3, 4, 3, 4, 3, 4]], dtype=int32)
6751-
6752- >>> y = jnp.array([1, 2, 3])
6753- >>> lax.tile(y, (2,))
6754- Array([1, 2, 3, 1, 2, 3], dtype=int32)
6755-
6756- >>> z = jnp.array([[1], [2]])
6757- >>> lax.tile(z, (1, 3))
6758- Array([[1, 1, 1],
6759- [2, 2, 2]], dtype=int32)
6760- """
6761- return tile_p .bind (operand , reps = tuple (reps ))
6762-
6763-
6764- def _tile_abstract_eval (operand , * , reps ):
6765- if len (reps ) != operand .ndim :
6766- raise ValueError (
6767- 'tile reps must have length equal to operand.ndim, '
6768- f'got reps={ reps } for operand.ndim={ operand .ndim } '
6769- )
6770- out_shape = tuple (d * r for d , r in zip (operand .shape , reps ))
6771- return operand .update (shape = out_shape )
6772-
6773-
6774- def _tile_impl (operand , * , reps ):
6775- out_shape = tuple (d * r for d , r in zip (operand .shape , reps ))
6776- bcast_shape = []
6777- bcast_dims = []
6778- for d , r in zip (operand .shape , reps ):
6779- if d == 1 or r == 1 :
6780- bcast_dims .append (len (bcast_shape ))
6781- bcast_shape .append (d * r )
6782- else :
6783- bcast_dims .append (len (bcast_shape ) + 1 )
6784- bcast_shape .extend ((r , d ))
6785- bcast = broadcast_in_dim (operand , tuple (bcast_shape ), tuple (bcast_dims ))
6786- return reshape (bcast , out_shape )
6787-
6788-
6789- def _tile_transpose (ct , operand , * , reps ):
6790- assert ad .is_undefined_primal (operand )
6791- if type (ct ) is ad_util .Zero :
6792- return ad_util .Zero (operand .aval )
6793- reshape_shape = []
6794- reduce_dims = []
6795- for d , r in zip (operand .aval .shape , reps ):
6796- if r == 1 :
6797- reshape_shape .append (d )
6798- elif d == 1 :
6799- reduce_dims .append (len (reshape_shape ))
6800- reshape_shape .append (r )
6801- else :
6802- reduce_dims .append (len (reshape_shape ))
6803- reshape_shape .extend ((r , d ))
6804- reshaped_ct = reshape (ct , tuple (reshape_shape ))
6805- return [reduce_sum (reshaped_ct , tuple (reduce_dims ))]
6806-
6807-
6808- def _tile_batching_rule (batched_args , batch_dims , * , reps ):
6809- (operand ,) = batched_args
6810- (bdim ,) = batch_dims
6811- if bdim is None :
6812- return tile (operand , reps ), None
6813- reps = list (reps )
6814- reps .insert (bdim , 1 )
6815- return tile (operand , reps ), bdim
6816-
6817-
6818- tile_p = core .Primitive ('tile' )
6819- tile_p .def_impl (_tile_impl )
6820- tile_p .def_abstract_eval (_tile_abstract_eval )
6821- ad .deflinear2 (tile_p , _tile_transpose )
6822- batching .primitive_batchers [tile_p ] = _tile_batching_rule
6823- mlir .register_lowering (
6824- tile_p , mlir .lower_fun (_tile_impl , multiple_results = False ))
6825-
6826-
68276732def _clamp_shape_rule (min , operand , max ):
68286733 if min .shape and min .shape != operand .shape :
68296734 raise TypeError ("clamp requires min.shape == operand.shape or min.shape == "
0 commit comments