@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
188
188
// / Returns a tuple corresponding to whether range has tensor or vector type.
189
189
template <typename iterator_range>
190
190
static std::tuple<bool , bool > hasTensorOrVectorType (iterator_range types) {
191
- return std::make_tuple (
192
- llvm::any_of (types, [](Type t) { return isa<TensorType>(t); }),
193
- llvm::any_of (types, [](Type t) { return isa<VectorType>(t); }));
191
+ return {llvm::any_of (types, llvm::IsaPred<TensorType>),
192
+ llvm::any_of (types, llvm::IsaPred<VectorType>)};
194
193
}
195
194
196
195
static bool isCompatibleInferredReturnShape (ArrayRef<int64_t > inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
202
201
};
203
202
if (inferred.size () != existing.size ())
204
203
return false ;
205
- for (auto [inferredDim, existingDim] : llvm::zip (inferred, existing))
204
+ for (auto [inferredDim, existingDim] : llvm::zip_equal (inferred, existing))
206
205
if (!isCompatible (inferredDim, existingDim))
207
206
return false ;
208
207
return true ;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
238
237
std::get<1 >(resultsHasTensorVectorType)))
239
238
return op->emitError (" cannot broadcast vector with tensor" );
240
239
241
- auto rankedOperands = make_filter_range (
242
- op->getOperandTypes (), [](Type t) { return isa <RankedTensorType>(t); } );
240
+ auto rankedOperands =
241
+ make_filter_range ( op->getOperandTypes (), llvm::IsaPred <RankedTensorType>);
243
242
244
243
// If all operands are unranked, then all result shapes are possible.
245
244
if (rankedOperands.empty ())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
257
256
return op->emitOpError (" operands don't have broadcast-compatible shapes" );
258
257
}
259
258
260
- auto rankedResults = make_filter_range (
261
- op->getResultTypes (), [](Type t) { return isa <RankedTensorType>(t); } );
259
+ auto rankedResults =
260
+ make_filter_range ( op->getResultTypes (), llvm::IsaPred <RankedTensorType>);
262
261
263
262
// If all of the results are unranked then no further verification.
264
263
if (rankedResults.empty ())
0 commit comments