From 57b21541a246500a7b54d05abf0d10f48c8d8f82 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Thu, 2 Jan 2025 11:49:43 -0800
Subject: [PATCH] [Mosaic] NFC: Pull out vreg related functions to util.

These functions are related to vreg manipulation and are used in
different rules.

PiperOrigin-RevId: 711484002
---
 jaxlib/mosaic/BUILD                           |  15 ++
 .../tpu/transforms/apply_vector_layout.cc     | 220 +++--------------
 jaxlib/mosaic/dialect/tpu/vreg_util.cc        | 206 ++++++++++++++++
 jaxlib/mosaic/dialect/tpu/vreg_util.h         |  82 +++++++
 jaxlib/mosaic/dialect/tpu/vreg_util_test.cc   | 228 ++++++++++++++++++
 5 files changed, 567 insertions(+), 184 deletions(-)
 create mode 100644 jaxlib/mosaic/dialect/tpu/vreg_util.cc
 create mode 100644 jaxlib/mosaic/dialect/tpu/vreg_util.h
 create mode 100644 jaxlib/mosaic/dialect/tpu/vreg_util_test.cc

diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD
index 62cffd26f829..37f9a35596d6 100644
--- a/jaxlib/mosaic/BUILD
+++ b/jaxlib/mosaic/BUILD
@@ -43,6 +43,7 @@ cc_library(
         "dialect/tpu/tpu_dialect.cc",
         "dialect/tpu/tpu_ops.cc",
         "dialect/tpu/util.cc",
+        "dialect/tpu/vreg_util.cc",
         ":extension_srcs",
     ] + glob([
         "dialect/tpu/transforms/*.cc",
@@ -51,6 +52,7 @@ cc_library(
         "dialect/tpu/layout.h",
         "dialect/tpu/tpu_dialect.h",
         "dialect/tpu/util.h",
+        "dialect/tpu/vreg_util.h",
     ] + glob([
         "dialect/tpu/transforms/*.h",
     ]),
@@ -232,6 +234,19 @@ cc_library(
     alwayslink = True,
 )
 
+cc_test(
+    name = "vreg_util_test",
+    srcs = ["dialect/tpu/vreg_util_test.cc"],
+    deps = [
+        ":tpu_dialect",
+        "//testing/base/public:gunit_main",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:VectorDialect",
+    ],
+)
+
 filegroup(
     name = "extension_srcs",
     srcs = [
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index c006950def3b..3a8263573544 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -29,7 +29,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -52,6 +51,7 @@
 #include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "llvm/include/llvm/ADT/APInt.h"
+#include "llvm/include/llvm/Support/LogicalResult.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -64,6 +64,7 @@
 #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
 #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
 #include "jaxlib/mosaic/dialect/tpu/util.h"
+#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
 #include "xla/array.h"
 #include "xla/layout.h"
 #include "xla/util.h"
@@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
   CHECK(data_it == data.end());
 }
 
-FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
-  if (isa<FloatType>(ty)) {
-    return TypedAttr(FloatAttr::get(ty, 0));
-  }
-  if (isa<IntegerType>(ty)) {
-    return TypedAttr(IntegerAttr::get(ty, 0));
-  }
-  return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
-}
-
 FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
   if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
     if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
@@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
   return argument;
 }
 
-VectorType getNativeVregOrVmaskTypeImpl(
-    Type elem_ty, const int8_t bitwidth,
-    const std::array<int64_t, 2> target_shape) {
-  if (bitwidth == 32) {
-    return VectorType::get(target_shape, elem_ty);
-  }
-  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
-                         elem_ty);
-}
-
-VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
-                                    const std::array<int64_t, 2> target_shape) {
-  int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
-  if (bitwidth == 1) {
-    bitwidth = layout_bitwidth;
-  } else {
-    CHECK_EQ(bitwidth, layout_bitwidth);
-  }
-  return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
-}
-
-VectorType getNativeVregType(Type elem_ty,
-                             const std::array<int64_t, 2> target_shape) {
-  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
-                                      target_shape);
-}
-
 // Masks all values outside of bounds.
 //
 // Arguments:
@@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
 // Returns:
 //   An MLIR value of the same type as the value argument, with all entries
 //   outside of bounds replaced by neutral.
-FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
+FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
                          TypedValue<VectorType> value,
                          const VRegDataBounds &bounds,
                          const Attribute neutral) {
@@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
         value.getLoc(),
         VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
   }
-  auto neutral_vec = builder.create<arith::ConstantOp>(
-      value.getLoc(), native_vreg_ty,
-      DenseElementsAttr::get(native_vreg_ty, neutral));
+  Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
   return builder
       .create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
       .getResult();
@@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
   TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
 
-  const VectorType i32_vreg_ty =
-      getNativeVregType(builder.getI32Type(), ctx.target_shape);
-  auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
-    CHECK(dim == 0 || dim == 1);
-    CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::CmpIOp>(
-                arith::CmpIPredicate::slt,
-                builder.create<tpu::IotaOp>(i32_vreg_ty,
-                                            builder.getI32IntegerAttr(dim)),
-                builder.create<arith::ConstantOp>(DenseElementsAttr::get(
-                    i32_vreg_ty, builder.getI32IntegerAttr(
-                                     ctx.target_shape[dim] - padding))))
-            .getResult());
-  };
-
-  // We can also extend this helper function with padding_top and padding_left
-  // based on the offsets in vregs.
-  const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(),
-      DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
-  const Value i32_max_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(), DenseElementsAttr::get(
-                       i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
-  auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
-                       int64_t padding_right) {
-    auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
-    int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
-    // Mask out the bottom.
-    if (padding_bottom > 0) {
-      // We have limited the row size of LHS and RHS need to be a multiple of
-      // native tiling at the beginning of this rule. Therefore, it is safe to
-      // bitcast to x32 vreg for masking.
-      int sub_padding = padding_bottom % packing;
-      int x32_padding_bottom = padding_bottom / packing;
-      auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
-      // Create an int32 vreg which contains subelement masking and then
-      // logical_and with target vreg to mask out the unaligned paddings.
-      // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
-      // [8, 128], then the mask will be:
-      //
-      // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
-      // sublane 6: [0         , 0         , ..., 0         ]
-      // sublane 7: [0         , 0         , ..., 0         ]
-      //
-      // Through this way, in order to mask sub-elements, each target vreg only
-      // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
-      // + packing).
-      Value partial_sublane_mask = builder.create<arith::ConstantOp>(
-          op.getLoc(),
-          DenseElementsAttr::get(
-              i32_vreg_ty,
-              builder.getI32IntegerAttr(
-                  0xffffffff >>
-                  (sub_padding * vreg_ty.getElementTypeBitWidth()))));
-      // Insert 0xffffffff above the blended sublane.
-      Value sublane_mask = builder.create<arith::SelectOp>(
-          getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
-          partial_sublane_mask);
-      // Insert 0 below the blended sublane.
-      sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
-                                                     i32_zeros_vreg);
-      for (int64_t i = 0; i < vregs.dim(1); ++i) {
-        Value &vreg = vregs({vregs.dim(0) - 1, i});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        if (sub_padding > 0) {
-          i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
-        } else {
-          i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
-                                                     i32_zeros_vreg);
-        }
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-    // Mask out the right.
-    if (padding_right > 0) {
-      auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
-      for (int64_t i = 0; i < vregs.dim(0); ++i) {
-        Value &vreg = vregs({i, vregs.dim(1) - 1});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
-                                                   i32_zeros_vreg);
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-  };
-
-  // Create a vreg filled with zeros.
-  auto getZerosVergLike =
-      [&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
-    const VectorType vreg_type = cast<VectorType>(vreg.getType());
-    FAILUREOR_ASSIGN_OR_RETURN(
-        const Attribute zero_attr,
-        getZeroIntOrFloatAttr(vreg_type.getElementType()));
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::ConstantOp>(
-                op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
-            .getResult());
-  };
-
-  FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
-                             getZerosVergLike(*lhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
-                             getZerosVergLike(*rhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
-                             getZerosVergLike(*acc_vregs.begin()));
+  auto lhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
+  auto rhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
+  auto acc_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
 
   // Only mask out the paddings on contracting dim of LHS and RHS.
-  maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
+  RETURN_IF_FAILED(
+      maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
+                            /*padding_bottom=*/0,
+                            /*padding_right=*/padded_lhs_cols - lhs_shape[1]));
   if (transpose_rhs) {
-    maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
+    RETURN_IF_FAILED(maskNativeTilingVregs(
+        builder, rhs_vregs, ctx.target_shape,
+        /*padding_bottom=*/0,
+        /*padding_right=*/padded_rhs_cols - rhs_shape[1]));
   } else {
-    maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
+    RETURN_IF_FAILED(
+        maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
+                              /*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
+                              /*padding_right=*/0));
   }
 
   // At this point, all paddings on vregs are masked out. For now, we
@@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(1));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 1))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 1)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);
@@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(0));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 2))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 2)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);
@@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
     SmallVector<Value> tiles;
     tiles.reserve(vty.getDimSize(*dimension));
     for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
-      tiles.push_back(builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(native_vreg_ty,
-                                 IntegerAttr::get(vty.getElementType(), i))));
+      tiles.push_back(getFullVector(builder, native_vreg_ty,
+                                    IntegerAttr::get(vty.getElementType(), i)));
     }
     xla::Array<Value> out_tiles(tile_array_shape);
     out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           const int64_t offset = *offsets_in[1];
           const int64_t lane_offset = offset % ctx.target_shape[1];
           const int64_t tile_offset = offset / ctx.target_shape[1];
-          const auto idx_ty =
-              VectorType::get(ctx.target_shape, builder.getI32Type());
-          auto lane_offset_cst = builder.create<arith::ConstantOp>(
-              broadcast_op.getLoc(), idx_ty,
-              DenseElementsAttr::get(idx_ty,
-                                     builder.getI32IntegerAttr(lane_offset)));
+          Value lane_offset_cst = getFullVector(
+              builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
+              builder.getI32IntegerAttr(lane_offset));
           DenseI32ArrayAttr sublane_pattern;
           if (num_tiles != 1) {
             SmallVector<int32_t> pattern;
@@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
             getNativeVregType(src_i32.getType(), ctx.target_shape);
         auto tile_i32 =
             builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
-        auto zeros = builder.create<arith::ConstantOp>(
-            broadcast_op.getLoc(), tile_i32.getType(),
-            DenseElementsAttr::get(tile_i32.getType(),
-                                   builder.getI32IntegerAttr(0)));
+        Value zeros = getZerosVector(builder, tile_i32.getType());
         auto tile =
             builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
                 .getResult();
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
new file mode 100644
index 000000000000..7dc5c13c073e
--- /dev/null
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
@@ -0,0 +1,206 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
+
+#include <array>
+#include <cstdint>
+
+#include "absl/log/check.h"
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/include/mlir/IR/Attributes.h"
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"
+#include "mlir/include/mlir/IR/BuiltinTypes.h"
+#include "mlir/include/mlir/IR/Diagnostics.h"
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/include/mlir/IR/Types.h"
+#include "mlir/include/mlir/IR/Value.h"
+#include "mlir/include/mlir/Support/LLVM.h"
+#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
+#include "jaxlib/mosaic/dialect/tpu/util.h"
+#include "xla/array.h"
+
+namespace mlir::tpu {
+
+namespace {
+
+VectorType getNativeVregOrVmaskTypeImpl(
+    Type elem_ty, const int8_t bitwidth,
+    const std::array<int64_t, 2> target_shape) {
+  if (bitwidth == 32) {
+    return VectorType::get(target_shape, elem_ty);
+  }
+  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
+                         elem_ty);
+}
+
+}  // namespace
+
+VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
+                                    const std::array<int64_t, 2> target_shape) {
+  int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
+  if (bitwidth == 1) {
+    bitwidth = layout_bitwidth;
+  } else {
+    CHECK_EQ(bitwidth, layout_bitwidth);
+  }
+  return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
+}
+
+VectorType getNativeVregType(Type elem_ty,
+                             const std::array<int64_t, 2> target_shape) {
+  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
+                                      target_shape);
+}
+
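+// For example, with target_shape = {8, 128}: f32 maps to vector<8x128xf32>
+// (bitwidth 32, no packing), bf16 maps to vector<8x128x2xbf16>, and i8 maps
+// to vector<8x128x4xi8>, with the 32 / bitwidth packed subelements appended
+// as a trailing dimension.
+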
+TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
+                                     VectorType vty, Attribute value) {
+  return cast<TypedValue<VectorType>>(
+      builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
+          .getResult());
+}
+
+TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
+                                         TypedValue<VectorType> vec,
+                                         Attribute value) {
+  return getFullVector(builder, vec.getType(), value);
+}
+
+TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
+                                      VectorType vty) {
+  return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType()));
+}
+
+TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
+                                          TypedValue<VectorType> vec) {
+  return getZerosVector(builder, vec.getType());
+}
+
+FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
+    ImplicitLocOpBuilder &builder, int64_t padding,
+    const std::array<int64_t, 2> target_shape, int64_t dim) {
+  VectorType i32_vreg_ty =
+      getNativeVregType(builder.getI32Type(), target_shape);
+  if (dim != 0 && dim != 1) {
+    return builder.emitError() << "Expected dim to be 0 or 1, got " << dim;
+  }
+
+  if (padding < 0 || padding > target_shape[dim]) {
+    return builder.emitError()
+           << "Padding must be in [0, target_shape[dim]]. Padding: " << padding
+           << ", target_shape[dim]: " << target_shape[dim];
+  }
+
+  Value padding_vreg =
+      getFullVector(builder, i32_vreg_ty,
+                    builder.getI32IntegerAttr(target_shape[dim] - padding));
+
+  return cast<TypedValue<VectorType>>(
+      builder
+          .create<arith::CmpIOp>(
+              arith::CmpIPredicate::slt,
+              builder.create<tpu::IotaOp>(i32_vreg_ty,
+                                          builder.getI32IntegerAttr(dim)),
+              padding_vreg)
+          .getResult());
+}
+
+LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
+                                    xla::Array<Value> &vregs,
+                                    std::array<int64_t, 2> target_shape,
+                                    int64_t padding_bottom,
+                                    int64_t padding_right) {
+  auto vreg_ty = dyn_cast<VectorType>(vregs.begin()->getType());
+  if (!vreg_ty) {
+    return builder.emitError() << "Expected a vector type";
+  }
+
+  VectorType i32_vreg_ty =
+      getNativeVregType(builder.getI32Type(), target_shape);
+  Value i32_zeros_vreg = getZerosVector(builder, i32_vreg_ty);
+  Value i32_max_vreg = getFullVector(builder, i32_vreg_ty,
+                                     builder.getI32IntegerAttr(0xffffffff));
+
+  int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
+  // Mask out the bottom.
+  if (padding_bottom > 0) {
+    // The function is only called when the vreg has native tiling. Therefore,
+    // it is safe to bitcast to x32 vreg for masking.
+    int sub_padding = padding_bottom % packing;
+    int x32_padding_bottom = padding_bottom / packing;
+    FAILUREOR_ASSIGN_OR_RETURN(
+        Value mask_top, getX32VmaskByPaddingEnd(builder, x32_padding_bottom + 1,
+                                                target_shape, /*dim=*/0));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        Value mask_bottom,
+        getX32VmaskByPaddingEnd(builder, x32_padding_bottom, target_shape,
+                                /*dim=*/0));
+    // Create an int32 vreg which contains subelement masking and then
+    // logical_and with target vreg to mask out the unaligned paddings.
+    // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
+    // [8, 128], then the mask will be:
+    //
+    // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+    // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+    // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+    // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+    // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+    // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
+    // sublane 6: [0         , 0         , ..., 0         ]
+    // sublane 7: [0         , 0         , ..., 0         ]
+    //
+    // Through this way, in order to mask sub-elements, each target vreg only
+    // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
+    // + packing).
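+    // In the example above, sub_padding = 5 % 2 = 1 and each packed element
+    // is 16 bits wide, so the partial mask below is
+    // 0xffffffff >> (1 * 16) = 0x0000ffff, i.e. only the first packed row of
+    // sublane 5 is kept.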
+    Value partial_sublane_mask = getFullVector(
+        builder, i32_vreg_ty,
+        builder.getI32IntegerAttr(
+            0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth())));
+    // Insert 0xffffffff above the blended sublane.
+    Value sublane_mask = builder.create<arith::SelectOp>(
+        mask_top, i32_max_vreg, partial_sublane_mask);
+    // Insert 0 below the blended sublane.
+    sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
+                                                   i32_zeros_vreg);
+    for (int64_t i = 0; i < vregs.dim(1); ++i) {
+      Value &vreg = vregs({vregs.dim(0) - 1, i});
+      Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
+      if (sub_padding > 0) {
+        i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
+      } else {
+        i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
+                                                   i32_zeros_vreg);
+      }
+      vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
+    }
+  }
+  // Mask out the right.
+  if (padding_right > 0) {
+    FAILUREOR_ASSIGN_OR_RETURN(
+        Value mask_right, getX32VmaskByPaddingEnd(builder, padding_right,
+                                                  target_shape, /*dim=*/1));
+    for (int64_t i = 0; i < vregs.dim(0); ++i) {
+      Value &vreg = vregs({i, vregs.dim(1) - 1});
+      Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
+      i32_vreg =
+          builder.create<arith::SelectOp>(mask_right, i32_vreg, i32_zeros_vreg);
+      vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
+    }
+  }
+  return success();
+}
+
+}  // namespace mlir::tpu
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h
new file mode 100644
index 000000000000..5892582a9f4a
--- /dev/null
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h
@@ -0,0 +1,82 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
+#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
+
+#include <array>
+#include <cstdint>
+
+#include "mlir/include/mlir/IR/Attributes.h"
+#include "mlir/include/mlir/IR/Builders.h"
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/include/mlir/IR/Types.h"
+#include "mlir/include/mlir/IR/Value.h"
+#include "mlir/include/mlir/Support/LLVM.h"
+#include "xla/array.h"
+
+namespace mlir::tpu {
+
+// Returns the native vreg or vmask type for the given element type and target
+// shape. The layout bitwidth is used for i1 (vmask) elements.
+VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
+                                    std::array<int64_t, 2> target_shape);
+VectorType getNativeVregType(Type elem_ty,
+                             std::array<int64_t, 2> target_shape);
+
+// Returns a zero constant of the same type as `vty`.
+TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
+                                      VectorType vty);
+// Same as above, but takes a `vec` as input.
+TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
+                                          TypedValue<VectorType> vec);
+
+// Returns a constant of the same type as `vty` with the given `value`.
+TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
+                                     VectorType vty, Attribute value);
+// Same as above, but takes a `vec` as input.
+TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
+                                         TypedValue<VectorType> vec,
+                                         Attribute value);
+
+// Creates a vmask that is true everywhere except for the last `padding`
+// elements along `dim` (the bottom rows for dim = 0, the rightmost lanes for
+// dim = 1); i.e. the first (dim_size - padding) flags along `dim` are true.
+//
+// For example, assume vmask shape is (4, 8)
+//
+// getX32VmaskByPaddingEnd(padding=3, dim=1) creates:
+// [T, T, T, T, T, F, F, F]
+// [T, T, T, T, T, F, F, F]
+// [T, T, T, T, T, F, F, F]
+// [T, T, T, T, T, F, F, F]
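+//
+// An illustrative call site, using the utilities from this header:
+//
+//   FAILUREOR_ASSIGN_OR_RETURN(
+//       TypedValue<VectorType> mask_right,
+//       getX32VmaskByPaddingEnd(builder, /*padding=*/3, target_shape,
+//                               /*dim=*/1));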
+// TODO(b/385204135): Unify with getVmaskByPaddingEnd in tpu_rotate_rule, and
+// improve the codegen.
+FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
+    ImplicitLocOpBuilder &builder, int64_t padding,
+    std::array<int64_t, 2> target_shape, int64_t dim);
+
+// Masks out the padding in the bottom and right of the vregs. vregs are
+// expected to have native tiling, and the masked vregs are mutated in place in
+// `vregs`. `padding_bottom` and `padding_right` are the numbers of elements to
+// mask out at the bottom and right.
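+// An illustrative call (mirroring the use in tpu_matmul_rule):
+//
+//   RETURN_IF_FAILED(
+//       maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
+//                             /*padding_bottom=*/0,
+//                             /*padding_right=*/padded_lhs_cols - lhs_shape[1]));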
+LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
+                                    xla::Array<Value> &vregs,
+                                    std::array<int64_t, 2> target_shape,
+                                    int64_t padding_bottom,
+                                    int64_t padding_right);
+
+}  // namespace mlir::tpu
+
+#endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
new file mode 100644
index 000000000000..dadbac133fbf
--- /dev/null
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
@@ -0,0 +1,228 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
+
+#include <array>
+#include <cstdint>
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_join.h"
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/include/mlir/IR/Builders.h"
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/IR/BuiltinTypes.h"
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/include/mlir/IR/MLIRContext.h"
+#include "mlir/include/mlir/IR/OwningOpRef.h"
+#include "mlir/include/mlir/IR/Value.h"
+#include "mlir/include/mlir/Support/DebugStringHelper.h"
+#include "mlir/include/mlir/Support/LLVM.h"
+#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
+
+namespace mlir::tpu {
+
+namespace {
+
+using ::testing::Eq;
+using ::testing::Optional;
+
+MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
+  auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
+  if (constant_op == nullptr) {
+    *result_listener << "Expected a constant op, got " << debugString(arg);
+    return false;
+  }
+  auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
+  if (dense_attr == nullptr) {
+    *result_listener << "Expected a dense elements attr, got "
+                     << debugString(arg);
+    return false;
+  }
+  if (dense_attr.getType() != type) {
+    *result_listener << "Expected a dense elements attr with type "
+                     << debugString(type) << ", got "
+                     << debugString(dense_attr.getType());
+    return false;
+  }
+  if (!dense_attr.isSplat()) {
+    *result_listener << "Expected a splat dense elements attr, got "
+                     << debugString(dense_attr);
+    return false;
+  }
+  if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
+      s != splat_value) {
+    *result_listener << "Expected a splat dense elements attr with value "
+                     << splat_value << ", got " << s;
+    return false;
+  }
+  return true;
+}
+
+MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
+  auto vty = dyn_cast<VectorType>(arg);
+  if (vty == nullptr) {
+    *result_listener << "Expected a vector type, got " << debugString(arg);
+    return false;
+  }
+  if (vty.getShape() != ArrayRef<int64_t>(shape)) {
+    *result_listener << "Expected a vector type with shape "
+                     << absl::StrJoin(shape, ",") << ", got "
+                     << absl::StrJoin(vty.getShape(), ",");
+    return false;
+  }
+  if (vty.getElementType() != elem_ty) {
+    *result_listener << "Expected a vector type with element type "
+                     << debugString(elem_ty) << ", got "
+                     << debugString(vty.getElementType());
+    return false;
+  }
+  return true;
+}
+
+class VregUtilTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    context_.loadDialect<arith::ArithDialect, vector::VectorDialect,
+                         tpu::TPUDialect>();
+    mlir::Location loc = mlir::UnknownLoc::get(&context_);
+    mlir::OpBuilder b(&context_);
+    module_ = b.create<ModuleOp>(loc);
+    builder_ = std::make_unique<mlir::ImplicitLocOpBuilder>(
+        module_->getLoc(), module_->getBodyRegion());
+  }
+
+  void TearDown() override {
+    builder_.reset();
+    // Reset the module to prevent memory leaks.
+    module_ = nullptr;
+  }
+
+  mlir::ImplicitLocOpBuilder& Builder() { return *builder_; }
+
+ private:
+  MLIRContext context_;
+  std::unique_ptr<mlir::ImplicitLocOpBuilder> builder_;
+  OwningOpRef<ModuleOp> module_;
+};
+
+TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeBitwidthMismatch) {
+  EXPECT_DEATH(getNativeVregOrVmaskType(Builder().getI16Type(),
+                                        /*layout_bitwidth=*/8, {2, 4}),
+               "");
+}
+
+TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeI1) {
+  EXPECT_THAT(getNativeVregOrVmaskType(Builder().getI1Type(),
+                                       /*layout_bitwidth=*/8, {2, 4}),
+              IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 4},
+                                    Builder().getI1Type()));
+}
+
+TEST_F(VregUtilTest, GetNativeVregF32) {
+  EXPECT_THAT(getNativeVregType(Builder().getF32Type(), {2, 4}),
+              IsVectorTypeWithShape(std::array<int64_t, 2>{2, 4},
+                                    Builder().getF32Type()));
+}
+
+TEST_F(VregUtilTest, GetNativeVregBf16) {
+  EXPECT_THAT(getNativeVregType(Builder().getBF16Type(), {2, 4}),
+              IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 2},
+                                    Builder().getBF16Type()));
+}
+
+TEST_F(VregUtilTest, GetFullVector) {
+  VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
+  TypedValue<VectorType> vec =
+      getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));
+
+  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
+}
+
+TEST_F(VregUtilTest, GetFullLikeVector) {
+  VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
+  TypedValue<VectorType> in_vec = Builder().create<vector::BroadcastOp>(
+      vty, Builder().create<arith::ConstantOp>(
+               vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
+  TypedValue<VectorType> vec =
+      getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));
+
+  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
+}
+
+TEST_F(VregUtilTest, GetZerosVector) {
+  VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
+  TypedValue<VectorType> vec = getZerosVector(Builder(), vty);
+
+  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
+}
+
+TEST_F(VregUtilTest, GetZerosLikeVector) {
+  VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
+  TypedValue<VectorType> in_vec = Builder().create<vector::BroadcastOp>(
+      vty, Builder().create<arith::ConstantOp>(
+               vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
+  TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
+
+  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
+}
+
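+// The two tests below pin down the IR emitted by getX32VmaskByPaddingEnd: an
+// arith.cmpi slt comparing a tpu.iota along `dim` against a splat constant
+// holding target_shape[dim] - padding.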
+TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
+  constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
+  FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
+      Builder(), /*padding=*/1, /*target_shape=*/kTargetShape,
+      /*dim=*/0);
+  ASSERT_TRUE(succeeded(vec));
+
+  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
+  ASSERT_TRUE(cmp_op != nullptr);
+  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
+
+  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
+  ASSERT_TRUE(iota_op != nullptr);
+  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));
+
+  EXPECT_THAT(
+      cmp_op.getRhs(),
+      IsConstantOpWithSplatValue(
+          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
+}
+
+TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
+  constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
+  FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
+      Builder(), /*padding=*/3, /*target_shape=*/kTargetShape,
+      /*dim=*/1);
+  ASSERT_TRUE(succeeded(vec));
+
+  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
+  ASSERT_TRUE(cmp_op != nullptr);
+  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
+
+  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
+  ASSERT_TRUE(iota_op != nullptr);
+  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));
+
+  EXPECT_THAT(
+      cmp_op.getRhs(),
+      IsConstantOpWithSplatValue(
+          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
+}
+
+}  // namespace
+
+}  // namespace mlir::tpu