Skip to content

Commit f967e4c

Browse files
Reverts f6101c5
PiperOrigin-RevId: 825412424
1 parent e6a1918 commit f967e4c

File tree

6 files changed

+5
-180
lines changed

6 files changed

+5
-180
lines changed

jax/_src/lax/lax.py

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6729,101 +6729,6 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
67296729
)
67306730

67316731

6732-
def tile(operand: ArrayLike, reps: Sequence[int]) -> Array:
6733-
"""Tiles an array by repeating it along each dimension.
6734-
6735-
Args:
6736-
operand: an array to tile.
6737-
reps: a sequence of integers representing the number of repeats for each
6738-
dimension. Must have the same length as ``operand.ndim``.
6739-
6740-
Returns:
6741-
A tiled array with shape ``(operand.shape[0] * reps[0], ...,
6742-
operand.shape[-1] * reps[-1])``.
6743-
6744-
Examples:
6745-
>>> x = jnp.array([[1, 2], [3, 4]])
6746-
>>> lax.tile(x, (2, 3))
6747-
Array([[1, 2, 1, 2, 1, 2],
6748-
[3, 4, 3, 4, 3, 4],
6749-
[1, 2, 1, 2, 1, 2],
6750-
[3, 4, 3, 4, 3, 4]], dtype=int32)
6751-
6752-
>>> y = jnp.array([1, 2, 3])
6753-
>>> lax.tile(y, (2,))
6754-
Array([1, 2, 3, 1, 2, 3], dtype=int32)
6755-
6756-
>>> z = jnp.array([[1], [2]])
6757-
>>> lax.tile(z, (1, 3))
6758-
Array([[1, 1, 1],
6759-
[2, 2, 2]], dtype=int32)
6760-
"""
6761-
return tile_p.bind(operand, reps=tuple(reps))
6762-
6763-
6764-
def _tile_abstract_eval(operand, *, reps):
6765-
if len(reps) != operand.ndim:
6766-
raise ValueError(
6767-
'tile reps must have length equal to operand.ndim, '
6768-
f'got reps={reps} for operand.ndim={operand.ndim}'
6769-
)
6770-
out_shape = tuple(d * r for d, r in zip(operand.shape, reps))
6771-
return operand.update(shape=out_shape)
6772-
6773-
6774-
def _tile_impl(operand, *, reps):
6775-
out_shape = tuple(d * r for d, r in zip(operand.shape, reps))
6776-
bcast_shape = []
6777-
bcast_dims = []
6778-
for d, r in zip(operand.shape, reps):
6779-
if d == 1 or r == 1:
6780-
bcast_dims.append(len(bcast_shape))
6781-
bcast_shape.append(d * r)
6782-
else:
6783-
bcast_dims.append(len(bcast_shape) + 1)
6784-
bcast_shape.extend((r, d))
6785-
bcast = broadcast_in_dim(operand, tuple(bcast_shape), tuple(bcast_dims))
6786-
return reshape(bcast, out_shape)
6787-
6788-
6789-
def _tile_transpose(ct, operand, *, reps):
6790-
assert ad.is_undefined_primal(operand)
6791-
if type(ct) is ad_util.Zero:
6792-
return ad_util.Zero(operand.aval)
6793-
reshape_shape = []
6794-
reduce_dims = []
6795-
for d, r in zip(operand.aval.shape, reps):
6796-
if r == 1:
6797-
reshape_shape.append(d)
6798-
elif d == 1:
6799-
reduce_dims.append(len(reshape_shape))
6800-
reshape_shape.append(r)
6801-
else:
6802-
reduce_dims.append(len(reshape_shape))
6803-
reshape_shape.extend((r, d))
6804-
reshaped_ct = reshape(ct, tuple(reshape_shape))
6805-
return [reduce_sum(reshaped_ct, tuple(reduce_dims))]
6806-
6807-
6808-
def _tile_batching_rule(batched_args, batch_dims, *, reps):
6809-
(operand,) = batched_args
6810-
(bdim,) = batch_dims
6811-
if bdim is None:
6812-
return tile(operand, reps), None
6813-
reps = list(reps)
6814-
reps.insert(bdim, 1)
6815-
return tile(operand, reps), bdim
6816-
6817-
6818-
tile_p = core.Primitive('tile')
6819-
tile_p.def_impl(_tile_impl)
6820-
tile_p.def_abstract_eval(_tile_abstract_eval)
6821-
ad.deflinear2(tile_p, _tile_transpose)
6822-
batching.primitive_batchers[tile_p] = _tile_batching_rule
6823-
mlir.register_lowering(
6824-
tile_p, mlir.lower_fun(_tile_impl, multiple_results=False))
6825-
6826-
68276732
def _clamp_shape_rule(min, operand, max):
68286733
if min.shape and min.shape != operand.shape:
68296734
raise TypeError("clamp requires min.shape == operand.shape or min.shape == "

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4519,13 +4519,11 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
45194519
reps_tup = tuple(reps) # type: ignore[arg-type]
45204520
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
45214521
for rep in reps_tup)
4522-
4523-
# Prepend 1s to reps to match A.ndim
4524-
if len(reps_tup) < A.ndim:
4525-
reps_tup = (1,) * (A.ndim - len(reps_tup)) + reps_tup
4526-
if len(reps_tup) > A.ndim:
4527-
A = lax.expand_dims(A, list(range(len(reps_tup) - A.ndim)))
4528-
return lax.tile(A, reps_tup)
4522+
A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A)
4523+
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
4524+
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
4525+
[k for pair in zip(reps_tup, A_shape) for k in pair])
4526+
return reshape(result, tuple(np.multiply(A_shape, reps_tup)))
45294527

45304528
def _concatenate_array(arr: ArrayLike, axis: int | None,
45314529
dtype: DTypeLike | None = None) -> Array:

jax/lax/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@
229229
tan_p as tan_p,
230230
tanh as tanh,
231231
tanh_p as tanh_p,
232-
tile as tile,
233-
tile_p as tile_p,
234232
top_k as top_k,
235233
top_k_p as top_k_p,
236234
transpose as transpose,

tests/lax_autodiff_test.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,21 +1180,5 @@ def testPowShapeMismatch(self):
11801180
self.assertArraysEqual(actual, expected)
11811181

11821182

1183-
@jtu.sample_product(
1184-
[
1185-
dict(arg_shape=arg_shape, reps=reps)
1186-
for arg_shape, reps in [
1187-
[(3,), (2,)],
1188-
[(2, 3), (1, 2)],
1189-
]
1190-
],
1191-
dtype=grad_float_dtypes,
1192-
)
1193-
def testTileAutodiff(self, arg_shape, reps, dtype):
1194-
rng = jtu.rand_default(self.rng())
1195-
args_maker = lambda: [rng(arg_shape, dtype)]
1196-
op = lambda x: lax.tile(x, reps)
1197-
check_grads(op, args_maker(), order=3, modes=["fwd", "rev"], eps=1.)
1198-
11991183
if __name__ == '__main__':
12001184
absltest.main(testLoader=jtu.JaxTestLoader())

tests/lax_test.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,26 +1526,6 @@ def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape, dimensions):
15261526
numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
15271527
self._CheckAgainstNumpy(numpy_op, op, args_maker)
15281528

1529-
@jtu.sample_product(
1530-
[
1531-
dict(arg_shape=arg_shape, reps=reps)
1532-
for arg_shape, reps in [
1533-
[(3,), (2,)],
1534-
[(2, 3), (1, 2)],
1535-
[(2, 3), (2, 1)],
1536-
[(2, 1, 3), (1, 2, 3)],
1537-
]
1538-
],
1539-
dtype=lax_test_util.default_dtypes,
1540-
)
1541-
def testTile(self, arg_shape, reps, dtype):
1542-
rng = jtu.rand_default(self.rng())
1543-
args_maker = lambda: [rng(arg_shape, dtype)]
1544-
op = lambda x: lax.tile(x, reps)
1545-
numpy_op = lambda x: np.tile(x, reps)
1546-
self._CompileAndCheck(op, args_maker)
1547-
self._CheckAgainstNumpy(numpy_op, op, args_maker)
1548-
15491529
@parameterized.parameters(
15501530
{"inshape": inshape, "dimensions": dimensions, "error_type": error_type,
15511531
"err_msg": err_msg}

tests/lax_vmap_test.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -791,45 +791,5 @@ def g(a, b):
791791
self.assertAllClose(output, expected, check_dtypes=False)
792792

793793

794-
@jtu.sample_product(
795-
[
796-
dict(arg_shape=arg_shape, reps=reps)
797-
for arg_shape, reps in [
798-
[(3,), (2,)],
799-
[(2, 3), (1, 2)],
800-
[(2, 3), (2, 1)],
801-
[(2, 1, 3), (1, 2, 3)],
802-
]
803-
],
804-
in_axes=[0, 1, -1],
805-
out_axes=[0, 1, -1],
806-
)
807-
def testTileBatching(self, arg_shape, reps, in_axes, out_axes):
808-
rng = jtu.rand_default(self.rng())
809-
dtype = np.float32
810-
args_maker = lambda: [rng(arg_shape, dtype)]
811-
op = lambda x: lax.tile(x, reps)
812-
args = args_maker()
813-
814-
# Construct batched arguments based on in_axes
815-
if in_axes == 0:
816-
batched_args = [jnp.stack([arg, arg], axis=0) for arg in args]
817-
elif in_axes == 1:
818-
batched_args = [jnp.stack([arg, arg], axis=1) for arg in args]
819-
else: # in_axes == -1
820-
batched_args = [jnp.stack([arg, arg], axis=-1) for arg in args]
821-
822-
# Compute expected output
823-
out = op(*args)
824-
if out_axes == 0:
825-
expected = jnp.stack([out, out], axis=0)
826-
elif out_axes == 1:
827-
expected = jnp.stack([out, out], axis=1)
828-
else: # out_axes == -1
829-
expected = jnp.stack([out, out], axis=-1)
830-
831-
actual = jax.vmap(op, in_axes=in_axes, out_axes=out_axes)(*batched_args)
832-
self.assertAllClose(expected, actual)
833-
834794
if __name__ == '__main__':
835795
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)