Skip to content

Commit d47fc8d

Browse files
bythew3iGoogle-ML-Automation
authored andcommitted
[Mosaic TPU][NFC] Consolidate getIntConst.
PiperOrigin-RevId: 759728429
1 parent 5cbd2d9 commit d47fc8d

File tree

2 files changed

+25
-36
lines changed

2 files changed

+25
-36
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -209,23 +209,18 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
209209
return false;
210210
}
211211

212-
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
213-
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
214-
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
215-
return integer_attr.getValue().getSExtValue();
216-
}
217-
}
218-
if (silent) {
219-
return failure();
212+
FailureOr<int64_t> expectIntConst(Value v) {
213+
if (auto cst = getIntConst(v)) {
214+
return cst.value();
220215
}
221216
return emitError(v.getLoc(), "Expected an integer constant");
222217
}
223218

224-
FailureOr<SmallVector<int64_t>> getIntConstsFromOperandRange(
225-
ValueRange vals, bool silent = false) {
219+
FailureOr<SmallVector<int64_t>> expectIntConstsFromOperandRange(
220+
ValueRange vals) {
226221
SmallVector<int64_t> res(vals.size());
227222
for (int i = 0; i < vals.size(); ++i) {
228-
FAILUREOR_ASSIGN_OR_RETURN(res[i], getIntConst(vals[i], silent));
223+
FAILUREOR_ASSIGN_OR_RETURN(res[i], expectIntConst(vals[i]));
229224
}
230225
return res;
231226
}
@@ -265,7 +260,7 @@ FailureOr<std::pair<Value, SmallVector<int64_t>>> sliceRef(
265260
Value c0 = nullptr;
266261
SmallVector<int64_t> indices_within_slice(indices.size() - tiling.size(), 0);
267262
for (auto tiled_idx : indices.take_back(tiling.size())) {
268-
if (auto cst = getIntConst(tiled_idx, /*silent=*/true); succeeded(cst)) {
263+
if (auto cst = getIntConst(tiled_idx)) {
269264
indices_within_slice.push_back(*cst);
270265
if (!c0) {
271266
c0 = builder.create<arith::ConstantOp>(i32,
@@ -1548,7 +1543,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
15481543
}
15491544
FAILUREOR_ASSIGN_OR_RETURN(
15501545
const SmallVector<int64_t> indices,
1551-
getIntConstsFromOperandRange(load_op.getIndices()));
1546+
expectIntConstsFromOperandRange(load_op.getIndices()));
15521547
TPU_ASSERT_EQ_OP(indices.size(), 2);
15531548
if (indices[1] % ctx.target_shape[1] != 0) {
15541549
return op.emitOpError("Not implemented: Lane index is not a multiple of ")
@@ -1606,8 +1601,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
16061601
if (strides[rank - 1] != 1) {
16071602
return op.emitOpError("Not Implemented: Stride on last dim is not 1");
16081603
}
1609-
auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true);
1610-
if (failed(last_idx)) {
1604+
auto last_idx = getIntConst(indices[rank - 1]);
1605+
if (!last_idx.has_value()) {
16111606
return op.emitOpError("Not Implemented: Dynamic index on last dim");
16121607
} else if (last_idx.value() != 0) {
16131608
return op.emitOpError("Not Implemented: Index on last dim is not 0");
@@ -1975,7 +1970,7 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
19751970
tpu::StoreOp store_op = cast<tpu::StoreOp>(op);
19761971
FAILUREOR_ASSIGN_OR_RETURN(
19771972
const SmallVector<int64_t> indices,
1978-
getIntConstsFromOperandRange(store_op.getIndices()));
1973+
expectIntConstsFromOperandRange(store_op.getIndices()));
19791974
TPU_ASSERT_EQ_OP(indices.size(), 2);
19801975
if (indices[1] % ctx.target_shape[1] != 0) {
19811976
return op.emitOpError("Not implemented: Lane index is not a multiple of ")
@@ -2143,15 +2138,14 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
21432138
return op.emitOpError("Not implemented: unsupported layout for input");
21442139
}
21452140
LayoutOffsets expected_offsets_out = layout_in.offsets();
2146-
auto shift = getIntConst(amount, /*silent=*/true);
2147-
const bool has_static_shift = succeeded(shift);
2141+
auto shift = getIntConst(amount);
21482142
int rotated_tiled_dim = op.getDimension() - (op.getType().getRank() - 2);
21492143
bool has_padding_along_rotation =
21502144
(rotated_tiled_dim == 0 || rotated_tiled_dim == 1) &&
21512145
op.getType().getShape()[op.getDimension()] %
21522146
layout.tiling()[rotated_tiled_dim] !=
21532147
0;
2154-
if (has_static_shift && has_padding_along_rotation) {
2148+
if (shift.has_value() && has_padding_along_rotation) {
21552149
// We checked above that there are no implicit dims.
21562150
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
21572151
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
@@ -2173,7 +2167,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
21732167
// TODO(b/411170715): Allow sublane rotation once the bug is fixed.
21742168
// TODO(b/337384645): Support non-zero stride.
21752169
if (has_padding_along_rotation &&
2176-
(!has_static_shift ||
2170+
(!shift.has_value() ||
21772171
(rotated_tiled_dim == 0 ||
21782172
(rotated_tiled_dim == 1 && op.getStride().value_or(0) != 0)))) {
21792173
return op.emitOpError("Not implemented: unsupported unaligned shape");
@@ -2200,19 +2194,19 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
22002194
builder.getIntegerAttr(builder.getIndexType(), d));
22012195
};
22022196
auto modI = [&](const Value &v, unsigned d) -> Value {
2203-
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
2197+
if (auto cst = getIntConst(v)) {
22042198
return mlirI32Const(cst.value() % d);
22052199
}
22062200
return builder.create<arith::RemUIOp>(v, mlirI32Const(d));
22072201
};
22082202
auto divI = [&](const Value &v, unsigned d) -> Value {
2209-
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
2203+
if (auto cst = getIntConst(v)) {
22102204
return mlirI32Const(cst.value() / d);
22112205
}
22122206
return builder.create<arith::DivUIOp>(v, mlirI32Const(d));
22132207
};
22142208
auto addI = [&](const Value &v, unsigned d) -> Value {
2215-
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
2209+
if (auto cst = getIntConst(v)) {
22162210
return mlirI32Const(cst.value() + d);
22172211
}
22182212
return builder.create<arith::AddIOp>(v, mlirI32Const(d));
@@ -2239,8 +2233,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
22392233
auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) {
22402234
CHECK(dim == 0 || dim == 1);
22412235
Value padding_vreg;
2242-
if (auto padding_cst = getIntConst(padding, /*silent=*/true);
2243-
succeeded(padding_cst)) {
2236+
if (auto padding_cst = getIntConst(padding)) {
22442237
CHECK_GE(padding_cst.value(), 0);
22452238
CHECK_LE(padding_cst.value(), ctx.target_shape[dim]);
22462239
padding_vreg = builder.create<arith::ConstantOp>(DenseElementsAttr::get(
@@ -2269,8 +2262,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
22692262
// and blend the data from contiguous vregs to emulate circular rotation.
22702263
auto rotateOnTilingDim = [&](const xla::Array<Value> &vregs,
22712264
const Value &shift, int axis, int stride = 0) {
2272-
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
2273-
succeeded(shift_cst)) {
2265+
if (auto shift_cst = getIntConst(shift)) {
22742266
if (shift_cst.value() == 0 && stride == 0) {
22752267
return vregs;
22762268
}
@@ -2395,8 +2387,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
23952387
CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0));
23962388
SmallVector<xla::Array<Value>, 4> chunks;
23972389
// Handle rotation with static shift.
2398-
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
2399-
succeeded(shift_cst)) {
2390+
if (auto shift_cst = getIntConst(shift)) {
24002391
int64_t static_shift = shift_cst.value();
24012392
if (has_padding_along_rotation) {
24022393
return lazyRotate(vregs, static_shift, axis);
@@ -2519,8 +2510,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
25192510
vty.getDimSize(dim));
25202511
// After applying stride, we expect all shifts in a vreg are less or
25212512
// equal to the vreg's lane count for now.
2522-
if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true);
2523-
succeeded(base_amount_cst)) {
2513+
if (auto base_amount_cst = getIntConst(base_amount)) {
25242514
int64_t static_base_amount = base_amount_cst.value();
25252515
auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] +
25262516
(ctx.target_shape[0] - 1) * stride;
@@ -3163,7 +3153,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
31633153
bool must_support_unaligned_dynamic_index = false;
31643154
if (load_op.getIndices().size() > 1) {
31653155
auto second_minor_idx = load_op.getIndices().take_back(2)[0];
3166-
if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
3156+
if (!getIntConst(second_minor_idx).has_value() &&
31673157
!isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
31683158
must_support_unaligned_dynamic_index = true;
31693159
}
@@ -3196,7 +3186,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
31963186
}
31973187

31983188
auto add_idx = [&](const Value &v, int64_t d) -> Value {
3199-
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
3189+
if (auto cst = getIntConst(v)) {
32003190
return IdxConst(cst.value() + d, builder, op.getLoc());
32013191
}
32023192
return builder.create<arith::AddIOp>(v, IdxConst(d, builder, op.getLoc()));
@@ -4476,7 +4466,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
44764466
bool must_support_unaligned_dynamic_index = false;
44774467
if (store_op.getIndices().size() > 1) {
44784468
auto second_minor_idx = store_op.getIndices().take_back(2)[0];
4479-
if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
4469+
if (!getIntConst(second_minor_idx).has_value() &&
44804470
!isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
44814471
must_support_unaligned_dynamic_index = true;
44824472
}
@@ -4507,7 +4497,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
45074497
}
45084498

45094499
auto add_idx = [&](const Value &v, int64_t d) -> Value {
4510-
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
4500+
if (auto cst = getIntConst(v)) {
45114501
return IdxConst(cst.value() + d, builder, op.getLoc());
45124502
}
45134503
return builder.create<arith::AddIOp>(v, IdxConst(d, builder, op.getLoc()));

jaxlib/mosaic/dialect/tpu/util.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ inline arith::ConstantOp I32Const(int32_t value, ArrayRef<int64_t> shape,
283283
builder.getIntegerAttr(builder.getI32Type(), value)));
284284
}
285285

286-
// TODO(jevinjiang): consolidate this with getIntConst in apply-vector-layout.
287286
std::optional<int64_t> getIntConst(Value v);
288287
} // namespace mlir::tpu
289288

0 commit comments

Comments
 (0)