Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic TPU] Add support for second minor broadcasts with packed types #25636

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3334,21 +3334,38 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape);
if (needs_physical_broadcast ==
std::array{true, false}) { // Sublane broadcast
if (layout_in.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only 32-bit supported for sublane broadcast");
}
const int bitwidth = layout_in.bitwidth();
const int packing = layout_in.packing();
if (num_tiles != 1) {
return op.emitOpError(
"Not implemented: Only native tiling supported");
}
TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1);
TPU_ASSERT_OP(offsets_in[0].has_value());
const int64_t offset = *offsets_in[0];
const int64_t sublane_offset = *offsets_in[0] / packing;
const int64_t subelement_offset = *offsets_in[0] % packing;
const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr(
SmallVector<int32_t>(ctx.target_shape[0], offset));
SmallVector<int32_t>(ctx.target_shape[0], sublane_offset));
src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
Value *const src_tile) {
Value *const src_vreg) {
Value dst_vreg = *src_vreg;
// Replicate the value within each sublane.
if (packing != 1) {
auto vreg_int_ty = getNativeVregType(
builder.getIntegerType(bitwidth), ctx.target_shape);
auto src_vreg_int =
builder.create<tpu::BitcastVregOp>(vreg_int_ty, dst_vreg);
auto unpack_elem = builder.create<tpu::UnpackSubelementsOp>(
getNativeVregType(builder.getI32Type(), ctx.target_shape),
src_vreg_int, subelement_offset, tpu::PackFormat::kInterleaved);
SmallVector<Value> packed_vregs(packing, unpack_elem);
auto vreg_int = builder.create<tpu::PackSubelementsOp>(
vreg_int_ty, packed_vregs, tpu::PackFormat::kInterleaved);
dst_vreg = builder.create<tpu::BitcastVregOp>(dst_vreg.getType(),
vreg_int);
}
dst_vreg = builder.create<tpu::GatherOp>(dst_vreg.getType(), dst_vreg,
indices, 0);
SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
Expand All @@ -3360,10 +3377,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
dst_limits[i] = dst_starts[i] + 1;
}
}
updateSlice<Value>(dst_tiles,
builder.create<tpu::GatherOp>(
src_tile->getType(), *src_tile, indices, 0),
dst_starts, dst_limits);
updateSlice<Value>(dst_tiles, dst_vreg, dst_starts, dst_limits);
});
} else if (needs_physical_broadcast ==
std::array{false, true}) { // Lane broadcast
Expand Down
8 changes: 3 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1062,16 +1062,14 @@ class VectorLayoutInferer {
// should always use that when sublane broadcasting is required.
if (src_tiled_ishape[0] != dst_tiled_ishape[0] &&
layout.offsets()[0] != std::nullopt) {
if (layout.bitwidth() != kNativeBitwidth) {
NYI("Only 32-bit broadcasts supported");
}
LayoutOffsets offsets = layout.offsets();
// At the moment relayout can only produce replicated sublanes when
// converting to (8, 128) if the input was in (1, 128) tiling
if (layout.tiling()[0] == 1) {
if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) {
offsets[0] = std::nullopt;
}
layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_,
layout = VectorLayout(layout.bitwidth(), offsets,
nativeTiling(layout.bitwidth()),
layout.implicit_dim());
}
LayoutOffsets offsets = layout.offsets();
Expand Down
17 changes: 17 additions & 0 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def kernel(x_ref, y_ref, out_ref):
)(x, y)
np.testing.assert_array_equal(out, inp.reshape(m * 2, n))

@parameterized.parameters([jnp.int32, jnp.int16, jnp.int8])
def test_row_broadcast(self, dtype):
if not jtu.if_cloud_tpu_at_least(2024, 1, 9):
self.skipTest("Requires libtpu built after 2024-01-09")
if not self.INTERPRET and jtu.get_tpu_version() < 5:
self.skipTest("Requires TPUv5+")
def kernel(x_ref, y_ref):
y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape)
m, n = 4, 1024
x = jax.random.randint(
jax.random.key(12), (m, n), minval=-1000, maxval=1000, dtype=jnp.int32
).astype(dtype)
y = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype)
)(x)
np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape))

def test_tpu_unsigned_int(self):
def body(x_ref, o_ref):
# Test cast from uint16 -> uint32
Expand Down
Loading