Skip to content

Commit 971b852

Browse files
authored
[mlir][NFC] Simplify type checks with isa predicates (llvm#87183)
For more context on isa predicates, see: llvm#83753.
1 parent a7206a6 commit 971b852

File tree

28 files changed

+83
-118
lines changed

28 files changed

+83
-118
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
545545
ConversionPatternRewriter &rewriter,
546546
const LLVMTypeConverter &converter) {
547547
TypeRange operandTypes(operands);
548-
if (llvm::none_of(operandTypes,
549-
[](Type type) { return isa<VectorType>(type); })) {
548+
if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
550549
return rewriter.notifyMatchFailure(op, "expected vector operand");
551550
}
552551
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
202202
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203203
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
204204
return false;
205-
return llvm::all_of(extOp->getUsers(), [](Operation *user) {
206-
return isa<vector::ContractionOp>(user);
207-
});
205+
return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208206
}
209207

210208
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
345343
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
346344
bool useNvGpu) {
347345
auto hasVectorDest = [](Operation *op) {
348-
return llvm::any_of(op->getResultTypes(),
349-
[](Type t) { return isa<VectorType>(t); });
346+
return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
350347
};
351348
BackwardSliceOptions backwardSliceOptions;
352349
backwardSliceOptions.filter = hasVectorDest;
353350

354351
auto hasVectorSrc = [](Operation *op) {
355-
return llvm::any_of(op->getOperandTypes(),
356-
[](Type t) { return isa<VectorType>(t); });
352+
return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
357353
};
358354
ForwardSliceOptions forwardSliceOptions;
359355
forwardSliceOptions.filter = hasVectorSrc;

mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
136136

137137
bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
138138
// Any memref-typed iteration arguments are treated as serializing.
139-
if (llvm::any_of(forOp.getResultTypes(),
140-
[](Type type) { return isa<BaseMemRefType>(type); }))
139+
if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
141140
return false;
142141

143142
// Collect all load and store ops in loop nest rooted at 'forOp'.

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
609609
}
610610

611611
static NestedPattern &vectorTransferPattern() {
612-
static auto pattern = affine::matcher::Op([](Operation &op) {
613-
return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
614-
});
612+
static auto pattern = affine::matcher::Op(
613+
llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
615614
return pattern;
616615
}
617616

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
211211
unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
212212

213213
// Return common loop depth for loads if there are no store ops.
214-
if (all_of(targetDstOps,
215-
[&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
214+
if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
216215
return loopDepth;
217216

218217
// Check dependences on all pairs of ops in 'targetDstOps' and store the

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ struct FuncOpInterface
326326
static bool supportsUnstructuredControlFlow() { return true; }
327327

328328
bool hasTensorSemantics(Operation *op) const {
329-
auto isaTensor = [](Type type) { return isa<TensorType>(type); };
329+
auto isaTensor = llvm::IsaPred<TensorType>;
330330

331331
// A function has tensor semantics if it has tensor arguments/results.
332332
auto funcOp = cast<FuncOp>(op);

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
6868
#include "mlir/Dialect/Func/IR/FuncOps.h"
6969
#include "mlir/Dialect/MemRef/IR/MemRef.h"
70+
#include "mlir/IR/BuiltinTypes.h"
7071
#include "mlir/IR/Operation.h"
7172

7273
using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
277278

278279
/// Return "true" if the given function signature has tensor semantics.
279280
static bool hasTensorSignature(func::FuncOp funcOp) {
280-
auto isaTensor = [](Type t) { return isa<TensorType>(t); };
281-
return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
282-
llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
281+
return llvm::any_of(funcOp.getFunctionType().getInputs(),
282+
llvm::IsaPred<TensorType>) ||
283+
llvm::any_of(funcOp.getFunctionType().getResults(),
284+
llvm::IsaPred<TensorType>);
283285
}
284286

285287
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
224224
}
225225
}
226226

227-
if (llvm::any_of(getResultTypes(),
228-
[](Type type) { return isa<ArrayType>(type); })) {
227+
if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
229228
return emitOpError() << "cannot return array type";
230229
}
231230

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

+8-16
Original file line numberDiff line numberDiff line change
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
296296
"scf.forall op requires a mapping attribute");
297297
}
298298

299-
bool hasBlockMapping =
300-
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
301-
return isa<GPUBlockMappingAttr>(attr);
302-
});
303-
bool hasWarpgroupMapping =
304-
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
305-
return isa<GPUWarpgroupMappingAttr>(attr);
306-
});
307-
bool hasWarpMapping =
308-
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
309-
return isa<GPUWarpMappingAttr>(attr);
310-
});
311-
bool hasThreadMapping =
312-
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
313-
return isa<GPUThreadMappingAttr>(attr);
314-
});
299+
bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
300+
llvm::IsaPred<GPUBlockMappingAttr>);
301+
bool hasWarpgroupMapping = llvm::any_of(
302+
forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
303+
bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
304+
llvm::IsaPred<GPUWarpMappingAttr>);
305+
bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
306+
llvm::IsaPred<GPUThreadMappingAttr>);
315307
int64_t countMappingTypes = 0;
316308
countMappingTypes += hasBlockMapping ? 1 : 0;
317309
countMappingTypes += hasWarpgroupMapping ? 1 : 0;

mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
232232
// control flow code.
233233
static bool areAllUsersExecuteOrAwait(Value token) {
234234
return !token.use_empty() &&
235-
llvm::all_of(token.getUsers(), [](Operation *user) {
236-
return isa<async::ExecuteOp, async::AwaitOp>(user);
237-
});
235+
llvm::all_of(token.getUsers(),
236+
llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
238237
}
239238

240239
// Add the `asyncToken` as dependency as needed after `op`.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
27862786
if (!resultType)
27872787
return success();
27882788

2789-
auto isVector = [](Type type) {
2790-
return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
2791-
type);
2792-
};
2789+
auto isVector =
2790+
llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
27932791

27942792
// Due to bitcast requiring both operands to be of the same size, it is not
27952793
// possible for only one of the two to be a pointer of vectors.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/IR/AffineExprVisitor.h"
2929
#include "mlir/IR/AffineMap.h"
3030
#include "mlir/IR/BuiltinAttributes.h"
31+
#include "mlir/IR/BuiltinTypeInterfaces.h"
3132
#include "mlir/IR/Matchers.h"
3233
#include "mlir/IR/OpImplementation.h"
3334
#include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
119120
TypeRange inputTypes, TypeRange outputTypes,
120121
ArrayRef<NamedAttribute> attrs,
121122
RegionBuilderFn regionBuilder) {
122-
assert(llvm::all_of(outputTypes,
123-
[](Type t) { return llvm::isa<ShapedType>(t); }));
123+
assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
124124

125125
SmallVector<Type, 8> argTypes;
126126
SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
162162
resultTensorTypes.value_or(TypeRange());
163163
if (!resultTensorTypes)
164164
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
165-
[](Type type) { return llvm::isa<RankedTensorType>(type); });
165+
llvm::IsaPred<RankedTensorType>);
166166

167167
state.addOperands(inputs);
168168
state.addOperands(outputs);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
2727

2828
// TODO: The conversion pattern can be made to work for `any_of` here, but
2929
// it's more complex as it requires tracking which operands are scalars.
30-
return llvm::all_of(op->getOperandTypes(),
31-
[](Type type) { return isa<RankedTensorType>(type); });
30+
return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
3231
}
3332

3433
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -3537,15 +3537,14 @@ struct Conv1DGenerator
35373537
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
35383538
// must be block arguments or extension of block arguments.
35393539
bool setOperKind(Operation *reduceOp) {
3540-
int numBlockArguments = llvm::count_if(
3541-
reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
3540+
int numBlockArguments =
3541+
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
35423542
switch (numBlockArguments) {
35433543
case 1: {
35443544
// Will be convolution if feeder is a MulOp.
35453545
// Otherwise, if it can be pooling.
3546-
auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
3547-
return !isa<BlockArgument>(v);
3548-
});
3546+
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3547+
llvm::IsaPred<BlockArgument>);
35493548
Operation *feedOp = (*feedValIt).getDefiningOp();
35503549
if (isCastOfBlockArgument(feedOp)) {
35513550
oper = Pool;

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
457457
}
458458

459459
static bool isComputeOperation(Operation *op) {
460-
return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
460+
return isa<acc::ParallelOp, acc::LoopOp>(op);
461461
}
462462

463463
namespace {

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
125125
if (getMatrixOperands()) {
126126
Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
127127
typeC.getElementType()};
128-
if (!llvm::all_of(elementTypes,
129-
[](Type ty) { return isa<IntegerType>(ty); })) {
128+
if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
130129
return emitOpError("Matrix Operands require all matrix element types to "
131130
"be Integer Types");
132131
}

mlir/lib/Dialect/Shape/IR/Shape.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
6565
}
6666

6767
static bool isErrorPropagationPossible(TypeRange operandTypes) {
68-
return llvm::any_of(operandTypes, [](Type ty) {
69-
return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
70-
});
68+
return llvm::any_of(operandTypes,
69+
llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
7170
}
7271

7372
static LogicalResult verifySizeOrIndexOp(Operation *op) {

mlir/lib/Dialect/Traits.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
188188
/// Returns a tuple corresponding to whether range has tensor or vector type.
189189
template <typename iterator_range>
190190
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
191-
return std::make_tuple(
192-
llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
193-
llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
191+
return {llvm::any_of(types, llvm::IsaPred<TensorType>),
192+
llvm::any_of(types, llvm::IsaPred<VectorType>)};
194193
}
195194

196195
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
202201
};
203202
if (inferred.size() != existing.size())
204203
return false;
205-
for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
204+
for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
206205
if (!isCompatible(inferredDim, existingDim))
207206
return false;
208207
return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
238237
std::get<1>(resultsHasTensorVectorType)))
239238
return op->emitError("cannot broadcast vector with tensor");
240239

241-
auto rankedOperands = make_filter_range(
242-
op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
240+
auto rankedOperands =
241+
make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
243242

244243
// If all operands are unranked, then all result shapes are possible.
245244
if (rankedOperands.empty())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
257256
return op->emitOpError("operands don't have broadcast-compatible shapes");
258257
}
259258

260-
auto rankedResults = make_filter_range(
261-
op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
259+
auto rankedResults =
260+
make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
262261

263262
// If all of the results are unranked then no further verification.
264263
if (rankedResults.empty())

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
819819
assert(outputs.size() == 1 && "expected one output");
820820
return llvm::all_of(
821821
std::initializer_list<Type>{inputs.front(), outputs.front()},
822-
[](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
822+
llvm::IsaPred<transform::TransformHandleTypeInterface>);
823823
}
824824

825825
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(
898898

899899
AffineMap resMap = op.getIndexingMapsArray()[2];
900900
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
901-
/*symCount=*/0, extents, ctx);
901+
/*symbolCount=*/0, extents, ctx);
902902
// Compose the resMap with the extentsMap, which is a constant map.
903903
AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
904-
assert(
905-
llvm::all_of(expectedMap.getResults(),
906-
[](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
907-
"expected constant extent along all dimensions.");
904+
assert(llvm::all_of(expectedMap.getResults(),
905+
llvm::IsaPred<AffineConstantExpr>) &&
906+
"expected constant extent along all dimensions.");
908907
// Extract the expected shape and build the type.
909908
auto expectedShape = llvm::to_vector<4>(
910909
llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {

0 commit comments

Comments
 (0)