Skip to content

Commit

Permalink
[mlir][NFC] Simplify type checks with isa predicates (llvm#87183)
Browse files Browse the repository at this point in the history
For more context on isa predicates, see:
llvm#83753.
  • Loading branch information
kuhar authored Apr 1, 2024
1 parent a7206a6 commit 971b852
Show file tree
Hide file tree
Showing 28 changed files with 83 additions and 118 deletions.
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
if (llvm::none_of(operandTypes,
[](Type type) { return isa<VectorType>(type); })) {
if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
return rewriter.notifyMatchFailure(op, "expected vector operand");
}
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
Expand Down
10 changes: 3 additions & 7 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
return false;
return llvm::all_of(extOp->getUsers(), [](Operation *user) {
return isa<vector::ContractionOp>(user);
});
return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
Expand Down Expand Up @@ -345,15 +343,13 @@ getSliceContract(Operation *op,
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
bool useNvGpu) {
auto hasVectorDest = [](Operation *op) {
return llvm::any_of(op->getResultTypes(),
[](Type t) { return isa<VectorType>(t); });
return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
};
BackwardSliceOptions backwardSliceOptions;
backwardSliceOptions.filter = hasVectorDest;

auto hasVectorSrc = [](Operation *op) {
return llvm::any_of(op->getOperandTypes(),
[](Type t) { return isa<VectorType>(t); });
return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
};
ForwardSliceOptions forwardSliceOptions;
forwardSliceOptions.filter = hasVectorSrc;
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {

bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
// Any memref-typed iteration arguments are treated as serializing.
if (llvm::any_of(forOp.getResultTypes(),
[](Type type) { return isa<BaseMemRefType>(type); }))
if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
return false;

// Collect all load and store ops in loop nest rooted at 'forOp'.
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
}

static NestedPattern &vectorTransferPattern() {
static auto pattern = affine::matcher::Op([](Operation &op) {
return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
});
static auto pattern = affine::matcher::Op(
llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
return pattern;
}

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

// Return common loop depth for loads if there are no store ops.
if (all_of(targetDstOps,
[&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
return loopDepth;

// Check dependences on all pairs of ops in 'targetDstOps' and store the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ struct FuncOpInterface
static bool supportsUnstructuredControlFlow() { return true; }

bool hasTensorSemantics(Operation *op) const {
auto isaTensor = [](Type type) { return isa<TensorType>(type); };
auto isaTensor = llvm::IsaPred<TensorType>;

// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
Expand Down Expand Up @@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
auto isaTensor = [](Type t) { return isa<TensorType>(t); };
return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
return llvm::any_of(funcOp.getFunctionType().getInputs(),
llvm::IsaPred<TensorType>) ||
llvm::any_of(funcOp.getFunctionType().getResults(),
llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
}
}

if (llvm::any_of(getResultTypes(),
[](Type type) { return isa<ArrayType>(type); })) {
if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
return emitOpError() << "cannot return array type";
}

Expand Down
24 changes: 8 additions & 16 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
"scf.forall op requires a mapping attribute");
}

bool hasBlockMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
return isa<GPUBlockMappingAttr>(attr);
});
bool hasWarpgroupMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
return isa<GPUWarpgroupMappingAttr>(attr);
});
bool hasWarpMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
return isa<GPUWarpMappingAttr>(attr);
});
bool hasThreadMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
return isa<GPUThreadMappingAttr>(attr);
});
bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
llvm::IsaPred<GPUBlockMappingAttr>);
bool hasWarpgroupMapping = llvm::any_of(
forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
llvm::IsaPred<GPUWarpMappingAttr>);
bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
llvm::IsaPred<GPUThreadMappingAttr>);
int64_t countMappingTypes = 0;
countMappingTypes += hasBlockMapping ? 1 : 0;
countMappingTypes += hasWarpgroupMapping ? 1 : 0;
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// control flow code.
static bool areAllUsersExecuteOrAwait(Value token) {
return !token.use_empty() &&
llvm::all_of(token.getUsers(), [](Operation *user) {
return isa<async::ExecuteOp, async::AwaitOp>(user);
});
llvm::all_of(token.getUsers(),
llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
}

// Add the `asyncToken` as dependency as needed after `op`.
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
if (!resultType)
return success();

auto isVector = [](Type type) {
return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
type);
};
auto isVector =
llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;

// Due to bitcast requiring both operands to be of the same size, it is not
// possible for only one of the two to be a pointer of vectors.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
Expand Down Expand Up @@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) {
assert(llvm::all_of(outputTypes,
[](Type t) { return llvm::isa<ShapedType>(t); }));
assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
Expand Down Expand Up @@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
resultTensorTypes.value_or(TypeRange());
if (!resultTensorTypes)
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
[](Type type) { return llvm::isa<RankedTensorType>(type); });
llvm::IsaPred<RankedTensorType>);

state.addOperands(inputs);
state.addOperands(outputs);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {

// TODO: The conversion pattern can be made to work for `any_of` here, but
// it's more complex as it requires tracking which operands are scalars.
return llvm::all_of(op->getOperandTypes(),
[](Type type) { return isa<RankedTensorType>(type); });
return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
}

/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3537,15 +3537,14 @@ struct Conv1DGenerator
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
bool setOperKind(Operation *reduceOp) {
int numBlockArguments = llvm::count_if(
reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
// Otherwise, if it can be pooling.
auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
return !isa<BlockArgument>(v);
});
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
oper = Pool;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
}

static bool isComputeOperation(Operation *op) {
return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
return isa<acc::ParallelOp, acc::LoopOp>(op);
}

namespace {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
if (getMatrixOperands()) {
Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
typeC.getElementType()};
if (!llvm::all_of(elementTypes,
[](Type ty) { return isa<IntegerType>(ty); })) {
if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
return emitOpError("Matrix Operands require all matrix element types to "
"be Integer Types");
}
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
return llvm::any_of(operandTypes, [](Type ty) {
return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
});
return llvm::any_of(operandTypes,
llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
Expand Down
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
/// Returns a tuple corresponding to whether range has tensor or vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
return std::make_tuple(
llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
return {llvm::any_of(types, llvm::IsaPred<TensorType>),
llvm::any_of(types, llvm::IsaPred<VectorType>)};
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
Expand All @@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
};
if (inferred.size() != existing.size())
return false;
for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
if (!isCompatible(inferredDim, existingDim))
return false;
return true;
Expand Down Expand Up @@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
std::get<1>(resultsHasTensorVectorType)))
return op->emitError("cannot broadcast vector with tensor");

auto rankedOperands = make_filter_range(
op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
auto rankedOperands =
make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

// If all operands are unranked, then all result shapes are possible.
if (rankedOperands.empty())
Expand All @@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
return op->emitOpError("operands don't have broadcast-compatible shapes");
}

auto rankedResults = make_filter_range(
op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
auto rankedResults =
make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

// If all of the results are unranked then no further verification.
if (rankedResults.empty())
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
assert(outputs.size() == 1 && "expected one output");
return llvm::all_of(
std::initializer_list<Type>{inputs.front(), outputs.front()},
[](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
llvm::IsaPred<transform::TransformHandleTypeInterface>);
}

//===----------------------------------------------------------------------===//
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(

AffineMap resMap = op.getIndexingMapsArray()[2];
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
/*symCount=*/0, extents, ctx);
/*symbolCount=*/0, extents, ctx);
// Compose the resMap with the extentsMap, which is a constant map.
AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
assert(
llvm::all_of(expectedMap.getResults(),
[](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
"expected constant extent along all dimensions.");
assert(llvm::all_of(expectedMap.getResults(),
llvm::IsaPred<AffineConstantExpr>) &&
"expected constant extent along all dimensions.");
// Extract the expected shape and build the type.
auto expectedShape = llvm::to_vector<4>(
llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
Expand Down
Loading

0 comments on commit 971b852

Please sign in to comment.