@@ -209,23 +209,18 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
209
209
return false ;
210
210
}
211
211
212
- FailureOr<int64_t > getIntConst (Value v, bool silent = false ) {
213
- if (auto constant_op = v.getDefiningOp <arith::ConstantOp>()) {
214
- if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue ())) {
215
- return integer_attr.getValue ().getSExtValue ();
216
- }
217
- }
218
- if (silent) {
219
- return failure ();
212
+ FailureOr<int64_t > expectIntConst (Value v) {
213
+ if (auto cst = getIntConst (v)) {
214
+ return cst.value ();
220
215
}
221
216
return emitError (v.getLoc (), " Expected an integer constant" );
222
217
}
223
218
224
- FailureOr<SmallVector<int64_t >> getIntConstsFromOperandRange (
225
- ValueRange vals, bool silent = false ) {
219
+ FailureOr<SmallVector<int64_t >> expectIntConstsFromOperandRange (
220
+ ValueRange vals) {
226
221
SmallVector<int64_t > res (vals.size ());
227
222
for (int i = 0 ; i < vals.size (); ++i) {
228
- FAILUREOR_ASSIGN_OR_RETURN (res[i], getIntConst (vals[i], silent ));
223
+ FAILUREOR_ASSIGN_OR_RETURN (res[i], expectIntConst (vals[i]));
229
224
}
230
225
return res;
231
226
}
@@ -265,7 +260,7 @@ FailureOr<std::pair<Value, SmallVector<int64_t>>> sliceRef(
265
260
Value c0 = nullptr ;
266
261
SmallVector<int64_t > indices_within_slice (indices.size () - tiling.size (), 0 );
267
262
for (auto tiled_idx : indices.take_back (tiling.size ())) {
268
- if (auto cst = getIntConst (tiled_idx, /* silent= */ true ); succeeded (cst )) {
263
+ if (auto cst = getIntConst (tiled_idx)) {
269
264
indices_within_slice.push_back (*cst);
270
265
if (!c0) {
271
266
c0 = builder.create <arith::ConstantOp>(i32,
@@ -1548,7 +1543,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
1548
1543
}
1549
1544
FAILUREOR_ASSIGN_OR_RETURN (
1550
1545
const SmallVector<int64_t > indices,
1551
- getIntConstsFromOperandRange (load_op.getIndices ()));
1546
+ expectIntConstsFromOperandRange (load_op.getIndices ()));
1552
1547
TPU_ASSERT_EQ_OP (indices.size (), 2 );
1553
1548
if (indices[1 ] % ctx.target_shape [1 ] != 0 ) {
1554
1549
return op.emitOpError (" Not implemented: Lane index is not a multiple of " )
@@ -1606,8 +1601,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
1606
1601
if (strides[rank - 1 ] != 1 ) {
1607
1602
return op.emitOpError (" Not Implemented: Stride on last dim is not 1" );
1608
1603
}
1609
- auto last_idx = getIntConst (indices[rank - 1 ], /* silent= */ true );
1610
- if (failed ( last_idx)) {
1604
+ auto last_idx = getIntConst (indices[rank - 1 ]);
1605
+ if (! last_idx. has_value ( )) {
1611
1606
return op.emitOpError (" Not Implemented: Dynamic index on last dim" );
1612
1607
} else if (last_idx.value () != 0 ) {
1613
1608
return op.emitOpError (" Not Implemented: Index on last dim is not 0" );
@@ -1975,7 +1970,7 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
1975
1970
tpu::StoreOp store_op = cast<tpu::StoreOp>(op);
1976
1971
FAILUREOR_ASSIGN_OR_RETURN (
1977
1972
const SmallVector<int64_t > indices,
1978
- getIntConstsFromOperandRange (store_op.getIndices ()));
1973
+ expectIntConstsFromOperandRange (store_op.getIndices ()));
1979
1974
TPU_ASSERT_EQ_OP (indices.size (), 2 );
1980
1975
if (indices[1 ] % ctx.target_shape [1 ] != 0 ) {
1981
1976
return op.emitOpError (" Not implemented: Lane index is not a multiple of " )
@@ -2143,15 +2138,14 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2143
2138
return op.emitOpError (" Not implemented: unsupported layout for input" );
2144
2139
}
2145
2140
LayoutOffsets expected_offsets_out = layout_in.offsets ();
2146
- auto shift = getIntConst (amount, /* silent=*/ true );
2147
- const bool has_static_shift = succeeded (shift);
2141
+ auto shift = getIntConst (amount);
2148
2142
int rotated_tiled_dim = op.getDimension () - (op.getType ().getRank () - 2 );
2149
2143
bool has_padding_along_rotation =
2150
2144
(rotated_tiled_dim == 0 || rotated_tiled_dim == 1 ) &&
2151
2145
op.getType ().getShape ()[op.getDimension ()] %
2152
2146
layout.tiling ()[rotated_tiled_dim] !=
2153
2147
0 ;
2154
- if (has_static_shift && has_padding_along_rotation) {
2148
+ if (shift. has_value () && has_padding_along_rotation) {
2155
2149
// We checked above that there are no implicit dims.
2156
2150
const int64_t dim_size = op.getType ().getShape ()[op.getDimension ()];
2157
2151
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
@@ -2173,7 +2167,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2173
2167
// TODO(b/411170715): Allow sublane rotation once the bug is fixed.
2174
2168
// TODO(b/337384645): Support non-zero stride.
2175
2169
if (has_padding_along_rotation &&
2176
- (!has_static_shift ||
2170
+ (!shift. has_value () ||
2177
2171
(rotated_tiled_dim == 0 ||
2178
2172
(rotated_tiled_dim == 1 && op.getStride ().value_or (0 ) != 0 )))) {
2179
2173
return op.emitOpError (" Not implemented: unsupported unaligned shape" );
@@ -2200,19 +2194,19 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2200
2194
builder.getIntegerAttr (builder.getIndexType (), d));
2201
2195
};
2202
2196
auto modI = [&](const Value &v, unsigned d) -> Value {
2203
- if (auto cst = getIntConst (v, /* silent= */ true ); succeeded (cst )) {
2197
+ if (auto cst = getIntConst (v)) {
2204
2198
return mlirI32Const (cst.value () % d);
2205
2199
}
2206
2200
return builder.create <arith::RemUIOp>(v, mlirI32Const (d));
2207
2201
};
2208
2202
auto divI = [&](const Value &v, unsigned d) -> Value {
2209
- if (auto cst = getIntConst (v, /* silent= */ true ); succeeded (cst )) {
2203
+ if (auto cst = getIntConst (v)) {
2210
2204
return mlirI32Const (cst.value () / d);
2211
2205
}
2212
2206
return builder.create <arith::DivUIOp>(v, mlirI32Const (d));
2213
2207
};
2214
2208
auto addI = [&](const Value &v, unsigned d) -> Value {
2215
- if (auto cst = getIntConst (v, /* silent= */ true ); succeeded (cst )) {
2209
+ if (auto cst = getIntConst (v)) {
2216
2210
return mlirI32Const (cst.value () + d);
2217
2211
}
2218
2212
return builder.create <arith::AddIOp>(v, mlirI32Const (d));
@@ -2239,8 +2233,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2239
2233
auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0 ) {
2240
2234
CHECK (dim == 0 || dim == 1 );
2241
2235
Value padding_vreg;
2242
- if (auto padding_cst = getIntConst (padding, /* silent=*/ true );
2243
- succeeded (padding_cst)) {
2236
+ if (auto padding_cst = getIntConst (padding)) {
2244
2237
CHECK_GE (padding_cst.value (), 0 );
2245
2238
CHECK_LE (padding_cst.value (), ctx.target_shape [dim]);
2246
2239
padding_vreg = builder.create <arith::ConstantOp>(DenseElementsAttr::get (
@@ -2269,8 +2262,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2269
2262
// and blend the data from contiguous vregs to emulate circular rotation.
2270
2263
auto rotateOnTilingDim = [&](const xla::Array<Value> &vregs,
2271
2264
const Value &shift, int axis, int stride = 0 ) {
2272
- if (auto shift_cst = getIntConst (shift, /* silent=*/ true );
2273
- succeeded (shift_cst)) {
2265
+ if (auto shift_cst = getIntConst (shift)) {
2274
2266
if (shift_cst.value () == 0 && stride == 0 ) {
2275
2267
return vregs;
2276
2268
}
@@ -2395,8 +2387,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2395
2387
CHECK ((tiling_dim != 1 && stride == 0 ) || (tiling_dim == 1 && stride >= 0 ));
2396
2388
SmallVector<xla::Array<Value>, 4 > chunks;
2397
2389
// Handle rotation with static shift.
2398
- if (auto shift_cst = getIntConst (shift, /* silent=*/ true );
2399
- succeeded (shift_cst)) {
2390
+ if (auto shift_cst = getIntConst (shift)) {
2400
2391
int64_t static_shift = shift_cst.value ();
2401
2392
if (has_padding_along_rotation) {
2402
2393
return lazyRotate (vregs, static_shift, axis);
@@ -2519,8 +2510,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
2519
2510
vty.getDimSize (dim));
2520
2511
// After applying stride, we expect all shifts in a vreg are less or
2521
2512
// equal to the vreg's lane count for now.
2522
- if (auto base_amount_cst = getIntConst (base_amount, /* silent=*/ true );
2523
- succeeded (base_amount_cst)) {
2513
+ if (auto base_amount_cst = getIntConst (base_amount)) {
2524
2514
int64_t static_base_amount = base_amount_cst.value ();
2525
2515
auto max_shift_in_vreg = static_base_amount % ctx.target_shape [1 ] +
2526
2516
(ctx.target_shape [0 ] - 1 ) * stride;
@@ -3163,7 +3153,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
3163
3153
bool must_support_unaligned_dynamic_index = false ;
3164
3154
if (load_op.getIndices ().size () > 1 ) {
3165
3155
auto second_minor_idx = load_op.getIndices ().take_back (2 )[0 ];
3166
- if (failed ( getIntConst (second_minor_idx, /* silent= */ true ) ) &&
3156
+ if (! getIntConst (second_minor_idx). has_value ( ) &&
3167
3157
!isGuaranteedDivisible (second_minor_idx, memref_tiling[0 ])) {
3168
3158
must_support_unaligned_dynamic_index = true ;
3169
3159
}
@@ -3196,7 +3186,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
3196
3186
}
3197
3187
3198
3188
auto add_idx = [&](const Value &v, int64_t d) -> Value {
3199
- if (auto cst = getIntConst (v, /* silent= */ true ); succeeded (cst )) {
3189
+ if (auto cst = getIntConst (v)) {
3200
3190
return IdxConst (cst.value () + d, builder, op.getLoc ());
3201
3191
}
3202
3192
return builder.create <arith::AddIOp>(v, IdxConst (d, builder, op.getLoc ()));
@@ -4476,7 +4466,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
4476
4466
bool must_support_unaligned_dynamic_index = false ;
4477
4467
if (store_op.getIndices ().size () > 1 ) {
4478
4468
auto second_minor_idx = store_op.getIndices ().take_back (2 )[0 ];
4479
- if (failed ( getIntConst (second_minor_idx, /* silent= */ true ) ) &&
4469
+ if (! getIntConst (second_minor_idx). has_value ( ) &&
4480
4470
!isGuaranteedDivisible (second_minor_idx, memref_tiling[0 ])) {
4481
4471
must_support_unaligned_dynamic_index = true ;
4482
4472
}
@@ -4507,7 +4497,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
4507
4497
}
4508
4498
4509
4499
auto add_idx = [&](const Value &v, int64_t d) -> Value {
4510
- if (auto cst = getIntConst (v, /* silent= */ true ); succeeded (cst )) {
4500
+ if (auto cst = getIntConst (v)) {
4511
4501
return IdxConst (cst.value () + d, builder, op.getLoc ());
4512
4502
}
4513
4503
return builder.create <arith::AddIOp>(v, IdxConst (d, builder, op.getLoc ()));
0 commit comments