[Mosaic] NFC: Pull out vreg related functions to util.
These functions are vreg-manipulation helpers shared across several rules.

PiperOrigin-RevId: 711484002
WindQAQ authored and Google-ML-Automation committed Jan 2, 2025
1 parent df36c29 commit 57b2154
Showing 5 changed files with 567 additions and 184 deletions.
jaxlib/mosaic/BUILD (15 changes: 15 additions & 0 deletions)
@@ -43,6 +43,7 @@ cc_library(
         "dialect/tpu/tpu_dialect.cc",
         "dialect/tpu/tpu_ops.cc",
         "dialect/tpu/util.cc",
+        "dialect/tpu/vreg_util.cc",
         ":extension_srcs",
     ] + glob([
         "dialect/tpu/transforms/*.cc",
@@ -51,6 +52,7 @@ cc_library(
         "dialect/tpu/layout.h",
         "dialect/tpu/tpu_dialect.h",
         "dialect/tpu/util.h",
+        "dialect/tpu/vreg_util.h",
     ] + glob([
         "dialect/tpu/transforms/*.h",
     ]),
@@ -232,6 +234,19 @@ cc_library(
     alwayslink = True,
 )
 
+cc_test(
+    name = "vreg_util_test",
+    srcs = ["dialect/tpu/vreg_util_test.cc"],
+    deps = [
+        ":tpu_dialect",
+        "//testing/base/public:gunit_main",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:VectorDialect",
+    ],
+)
+
 filegroup(
     name = "extension_srcs",
     srcs = [
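The diffs for the three new files (vreg_util.h, vreg_util.cc, vreg_util_test.cc) did not load in this view. Below is a plausible reconstruction of the extracted interface based solely on the call sites visible in the apply_vector_layout.cc diff that follows; the exact signatures, return types, and namespace are assumptions.

// vreg_util.h (sketch reconstructed from call sites, not the actual header).
#include <array>
#include <cstdint>

#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"
#include "xla/array.h"

namespace mlir::tpu {

// Native vreg/vmask types for an element type on a (sublanes, lanes) target.
VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
                                    std::array<int64_t, 2> target_shape);
VectorType getNativeVregType(Type elem_ty,
                             std::array<int64_t, 2> target_shape);

// Constant vregs: every element set to value, or to zero.
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
                                     VectorType vty, Attribute value);
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
                                      VectorType vty);

// Masks out padding_bottom sublanes and padding_right lanes in the last row
// and column of native-tiled vregs, in place (previously the maskVregs lambda
// in tpu_matmul_rule).
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
                                    xla::Array<Value> &vregs,
                                    std::array<int64_t, 2> target_shape,
                                    int64_t padding_bottom,
                                    int64_t padding_right);

}  // namespace mlir::tpu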
jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc (220 changes: 36 additions & 184 deletions)
@@ -29,7 +29,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -52,6 +51,7 @@
 #include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "llvm/include/llvm/ADT/APInt.h"
+#include "llvm/include/llvm/Support/LogicalResult.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -64,6 +64,7 @@
 #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
 #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
 #include "jaxlib/mosaic/dialect/tpu/util.h"
+#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
 #include "xla/array.h"
 #include "xla/layout.h"
 #include "xla/util.h"
@@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
   CHECK(data_it == data.end());
 }
 
-FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
-  if (isa<FloatType>(ty)) {
-    return TypedAttr(FloatAttr::get(ty, 0));
-  }
-  if (isa<IntegerType>(ty)) {
-    return TypedAttr(IntegerAttr::get(ty, 0));
-  }
-  return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
-}
-
 FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
   if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
     if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
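getZeroIntOrFloatAttr moves into vreg_util.cc (not loaded in this view), where the getZerosVector helper used later in this diff most likely composes it with a splat constant. A minimal sketch under that assumption; the real implementation may differ, for instance in how it reports unsupported element types:

// Sketch only; assumes getZerosVector splats the zero attribute built by the
// moved getZeroIntOrFloatAttr logic.
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
                                      VectorType vty) {
  Type elem_ty = vty.getElementType();
  // Same int-or-float dispatch as the removed getZeroIntOrFloatAttr.
  TypedAttr zero = isa<FloatType>(elem_ty)
                       ? TypedAttr(FloatAttr::get(elem_ty, 0))
                       : TypedAttr(IntegerAttr::get(elem_ty, 0));
  return cast<TypedValue<VectorType>>(
      builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, zero))
          .getResult());
}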
@@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
   return argument;
 }
 
-VectorType getNativeVregOrVmaskTypeImpl(
-    Type elem_ty, const int8_t bitwidth,
-    const std::array<int64_t, 2> target_shape) {
-  if (bitwidth == 32) {
-    return VectorType::get(target_shape, elem_ty);
-  }
-  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
-                         elem_ty);
-}
-
-VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
-                                    const std::array<int64_t, 2> target_shape) {
-  int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
-  if (bitwidth == 1) {
-    bitwidth = layout_bitwidth;
-  } else {
-    CHECK_EQ(bitwidth, layout_bitwidth);
-  }
-  return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
-}
-
-VectorType getNativeVregType(Type elem_ty,
-                             const std::array<int64_t, 2> target_shape) {
-  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
-                                      target_shape);
-}
-
 // Masks all values outside of bounds.
 //
 // Arguments:
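A concrete reading of the moved type helpers (the element counts follow directly from the code above; the {8, 128} target shape is an assumption for illustration):

// getNativeVregType(f32,  {8, 128}) -> vector<8x128xf32>    (bitwidth 32)
// getNativeVregType(bf16, {8, 128}) -> vector<8x128x2xbf16> (32/16 = 2 packed)
// getNativeVregType(i8,   {8, 128}) -> vector<8x128x4xi8>   (32/8  = 4 packed)
// For i1 vmasks, getNativeVregOrVmaskType substitutes layout_bitwidth for the
// 1-bit element width when choosing the packing.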
@@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
 // Returns:
 //   An MLIR value of the same type as the value argument, with all entries
 //   outside of bounds replaced by neutral.
-FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
+FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
                          TypedValue<VectorType> value,
                          const VRegDataBounds &bounds,
                          const Attribute neutral) {
@@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
         value.getLoc(),
         VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
   }
-  auto neutral_vec = builder.create<arith::ConstantOp>(
-      value.getLoc(), native_vreg_ty,
-      DenseElementsAttr::get(native_vreg_ty, neutral));
+  Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
   return builder
       .create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
       .getResult();
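getFullVector presumably wraps the same splat constant that was inlined here; note that maskOOB's builder parameter changes from OpBuilder to ImplicitLocOpBuilder above, so the helper can create the constant at the builder's implicit location. A short sketch under that assumption:

// Sketch: equivalent of the removed three-line inline constant.
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
                                     VectorType vty, Attribute value) {
  return cast<TypedValue<VectorType>>(
      builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
          .getResult());
}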
@@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
   TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
 
-  const VectorType i32_vreg_ty =
-      getNativeVregType(builder.getI32Type(), ctx.target_shape);
-  auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
-    CHECK(dim == 0 || dim == 1);
-    CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::CmpIOp>(
-                arith::CmpIPredicate::slt,
-                builder.create<tpu::IotaOp>(i32_vreg_ty,
-                                            builder.getI32IntegerAttr(dim)),
-                builder.create<arith::ConstantOp>(DenseElementsAttr::get(
-                    i32_vreg_ty, builder.getI32IntegerAttr(
-                                     ctx.target_shape[dim] - padding))))
-            .getResult());
-  };
-
-  // We can also extend this helper function with padding_top and padding_left
-  // based on the offsets in vregs.
-  const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(),
-      DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
-  const Value i32_max_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(), DenseElementsAttr::get(
-                       i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
-  auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
-                       int64_t padding_right) {
-    auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
-    int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
-    // Mask out the bottom.
-    if (padding_bottom > 0) {
-      // We have limited the row size of LHS and RHS need to be a multiple of
-      // native tiling at the beginning of this rule. Therefore, it is safe to
-      // bitcast to x32 vreg for masking.
-      int sub_padding = padding_bottom % packing;
-      int x32_padding_bottom = padding_bottom / packing;
-      auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
-      // Create an int32 vreg which contains subelement masking and then
-      // logical_and with target vreg to mask out the unaligned paddings.
-      // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
-      // [8, 128], then the mask will be:
-      //
-      // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
-      // sublane 6: [0         , 0         , ..., 0         ]
-      // sublane 7: [0         , 0         , ..., 0         ]
-      //
-      // Through this way, in order to mask sub-elements, each target vreg only
-      // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
-      // + packing).
-      Value partial_sublane_mask = builder.create<arith::ConstantOp>(
-          op.getLoc(),
-          DenseElementsAttr::get(
-              i32_vreg_ty,
-              builder.getI32IntegerAttr(
-                  0xffffffff >>
-                  (sub_padding * vreg_ty.getElementTypeBitWidth()))));
-      // Insert 0xffffffff above the blended sublane.
-      Value sublane_mask = builder.create<arith::SelectOp>(
-          getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
-          partial_sublane_mask);
-      // Insert 0 below the blended sublane.
-      sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
-                                                     i32_zeros_vreg);
-      for (int64_t i = 0; i < vregs.dim(1); ++i) {
-        Value &vreg = vregs({vregs.dim(0) - 1, i});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        if (sub_padding > 0) {
-          i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
-        } else {
-          i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
-                                                     i32_zeros_vreg);
-        }
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-    // Mask out the right.
-    if (padding_right > 0) {
-      auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
-      for (int64_t i = 0; i < vregs.dim(0); ++i) {
-        Value &vreg = vregs({i, vregs.dim(1) - 1});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
-                                                   i32_zeros_vreg);
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-  };
-
-  // Create a vreg filled with zeros.
-  auto getZerosVergLike =
-      [&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
-    const VectorType vreg_type = cast<VectorType>(vreg.getType());
-    FAILUREOR_ASSIGN_OR_RETURN(
-        const Attribute zero_attr,
-        getZeroIntOrFloatAttr(vreg_type.getElementType()));
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::ConstantOp>(
-                op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
-            .getResult());
-  };
-
-  FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
-                             getZerosVergLike(*lhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
-                             getZerosVergLike(*rhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
-                             getZerosVergLike(*acc_vregs.begin()));
+  auto lhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
+  auto rhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
+  auto acc_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
 
   // Only mask out the paddings on contracting dim of LHS and RHS.
-  maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
+  RETURN_IF_FAILED(
+      maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
+                            /*padding_bottom=*/0,
+                            /*padding_right=*/padded_lhs_cols - lhs_shape[1]));
   if (transpose_rhs) {
-    maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
+    RETURN_IF_FAILED(maskNativeTilingVregs(
+        builder, rhs_vregs, ctx.target_shape,
+        /*padding_bottom=*/0,
+        /*padding_right=*/padded_rhs_cols - rhs_shape[1]));
   } else {
-    maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
+    RETURN_IF_FAILED(
+        maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
+                              /*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
+                              /*padding_right=*/0));
   }
 
   // At this point, all paddings on vregs are masked out. For now, we
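The maskVregs lambda removed above now lives in vreg_util.cc as maskNativeTilingVregs (per the new calls), keeping the same sub-element masking trick. Working through the numbers in its comment:

// Illustration of the mask arithmetic, using the removed comment's numbers:
// padding_bottom = 5, packing = 2 (16-bit elements), vreg shape [8, 128].
int x32_padding_bottom = 5 / 2;  // = 2: sublanes 6 and 7 are fully zeroed
int sub_padding = 5 % 2;         // = 1: sublane 5 is only half-masked
// Partial-sublane mask = 0xffffffff >> (1 * 16) = 0x0000ffff, so the AndIOp
// keeps the low 16-bit element of each 32-bit lane in sublane 5 and zeroes
// the high one, matching the sublane table in the removed comment.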
@@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(1));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 1))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 1)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);
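An illustrative trace of this branch (the {8, 128} target shape is an assumption):

// vty = vector<8x512xi32>, dim = 1, native vreg = vector<8x128xi32>:
//   num_tiles = 512 / 128 = 4
//   tiles[i]  = vreg_iota + getFullVector(builder, native_vreg_ty, i * 128)
// so the four vregs hold lanes [0,128), [128,256), [256,384), [384,512).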
@@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(0));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 2))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 2)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);
@@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
     SmallVector<Value> tiles;
     tiles.reserve(vty.getDimSize(*dimension));
     for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
-      tiles.push_back(builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(native_vreg_ty,
-                                 IntegerAttr::get(vty.getElementType(), i))));
+      tiles.push_back(getFullVector(builder, native_vreg_ty,
+                                    IntegerAttr::get(vty.getElementType(), i)));
     }
     xla::Array<Value> out_tiles(tile_array_shape);
     out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
       const int64_t offset = *offsets_in[1];
       const int64_t lane_offset = offset % ctx.target_shape[1];
       const int64_t tile_offset = offset / ctx.target_shape[1];
-      const auto idx_ty =
-          VectorType::get(ctx.target_shape, builder.getI32Type());
-      auto lane_offset_cst = builder.create<arith::ConstantOp>(
-          broadcast_op.getLoc(), idx_ty,
-          DenseElementsAttr::get(idx_ty,
-                                 builder.getI32IntegerAttr(lane_offset)));
+      Value lane_offset_cst = getFullVector(
+          builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
+          builder.getI32IntegerAttr(lane_offset));
       DenseI32ArrayAttr sublane_pattern;
       if (num_tiles != 1) {
         SmallVector<int32_t> pattern;
@@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           getNativeVregType(src_i32.getType(), ctx.target_shape);
       auto tile_i32 =
           builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
-      auto zeros = builder.create<arith::ConstantOp>(
-          broadcast_op.getLoc(), tile_i32.getType(),
-          DenseElementsAttr::get(tile_i32.getType(),
-                                 builder.getI32IntegerAttr(0)));
+      Value zeros = getZerosVector(builder, tile_i32.getType());
       auto tile =
           builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
               .getResult();